/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.tsc.util;

import ai.libs.jaicore.ml.core.dataset.TimeSeriesInstance;
import ai.libs.jaicore.ml.core.dataset.attribute.IAttributeValue;
import ai.libs.jaicore.ml.core.dataset.attribute.timeseries.TimeSeriesAttributeValue;
import ai.libs.jaicore.ml.core.exception.TrainingException;
import ai.libs.jaicore.ml.tsc.dataset.TimeSeriesDataset;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.lang3.ArrayUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import weka.classifiers.Classifier;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;

/**
 * Static helpers for converting between jaicore time series datasets, ND4J matrices and Weka
 * {@link Instances}, and for training Weka classifiers directly on time series datasets.
 *
 * <p>Class attributes built by the dataset converters are nominal. The class value stored in each
 * instance is the integer target itself, which Weka interprets as the index of a nominal value.
 * When the class labels are derived automatically, they are therefore {@code "0" .. "max"}, so
 * that index and label always coincide.
 */
public class WekaUtil {
    private static final String I_NAME = "Instances";

    private WekaUtil() {
    }

    /**
     * Concatenates the given matrices column-wise (horizontally), in list order.
     *
     * @param matrices the matrices to concatenate; all must have the same size along axis 0
     * @return a new matrix holding the concatenation, or a 0x0 matrix if {@code matrices} is empty
     * @throws IllegalArgumentException if the first dimensions of the matrices differ
     */
    private static INDArray hstackINDArrays(List<INDArray> matrices) {
        if (matrices.isEmpty()) {
            return Nd4j.create(0, 0);
        }
        long numRows = matrices.get(0).shape()[0];
        for (int i = 1; i < matrices.size(); ++i) {
            if (matrices.get(i).shape()[0] != numRows) {
                throw new IllegalArgumentException("First dimensionality of the given matrices must be equal!");
            }
        }
        // dup() so that even the single-matrix case returns a copy, never the caller's own matrix
        INDArray combinedMatrix = matrices.get(0).dup();
        for (int i = 1; i < matrices.size(); ++i) {
            combinedMatrix = Nd4j.hstack(new INDArray[] {combinedMatrix, matrices.get(i)});
        }
        return combinedMatrix;
    }

    /**
     * Converts a jaicore time series instance into a Weka instance. The instance's values are the
     * flattened, column-wise concatenation of all its time series attribute values. Attribute values
     * of any other type are ignored.
     *
     * @param instance the instance to convert; its target value is cast to {@code String}
     * @return a new Weka instance with weight 1.0
     */
    public static Instance tsInstanceToWekaInstance(TimeSeriesInstance<?> instance) {
        List<INDArray> indArrays = new ArrayList<>();
        for (IAttributeValue<?> attValue : instance.getAllAttributeValues()) {
            if (attValue instanceof TimeSeriesAttributeValue) {
                indArrays.add((INDArray) ((TimeSeriesAttributeValue) attValue).getValue());
            }
        }
        INDArray combinedMatrix = hstackINDArrays(indArrays);
        DenseInstance finalInstance = new DenseInstance(1.0, Nd4j.toFlattened(new INDArray[] {combinedMatrix}).toDoubleVector());
        // NOTE(review): finalInstance has no dataset assigned here, but Weka normally needs one to
        // resolve the class index for setClassValue -- confirm this conversion works as intended.
        finalInstance.setClassValue((double) ai.libs.jaicore.ml.WekaUtil.getIntValOfClassName(finalInstance, (String) instance.getTargetValue()));
        return finalInstance;
    }

    /**
     * Wraps a plain feature vector into a Weka instance with weight 1.0.
     *
     * @param instance the attribute values of the instance
     * @return the new, dataset-less Weka instance
     */
    public static Instance simplifiedTSInstanceToWekaInstance(double[] instance) {
        return new DenseInstance(1.0, instance);
    }

    /**
     * Converts the given dataset via {@link #timeSeriesDatasetToWekaInstances} and trains the
     * classifier on the result.
     *
     * @param classifier the Weka classifier to train
     * @param timeSeriesDataset the training data
     * @throws TrainingException if Weka fails to build the classifier (original cause attached)
     */
    public static <L> void buildWekaClassifierFromTS(Classifier classifier, ai.libs.jaicore.ml.core.dataset.TimeSeriesDataset<L> timeSeriesDataset) throws TrainingException {
        Instances trainingInstances = timeSeriesDatasetToWekaInstances(timeSeriesDataset);
        try {
            classifier.buildClassifier(trainingInstances);
        } catch (Exception e) {
            throw new TrainingException("Could not train classifier " + classifier.getClass().getName() + " due to a Weka exception.", e);
        }
    }

    /**
     * Converts the given dataset via {@link #simplifiedTimeSeriesDatasetToWekaInstances(TimeSeriesDataset)}
     * and trains the classifier on the result.
     *
     * @param classifier the Weka classifier to train
     * @param timeSeriesDataset the training data
     * @throws TrainingException if Weka fails to build the classifier (original cause attached)
     */
    public static void buildWekaClassifierFromSimplifiedTS(Classifier classifier, TimeSeriesDataset timeSeriesDataset) throws TrainingException {
        Instances trainingInstances = simplifiedTimeSeriesDatasetToWekaInstances(timeSeriesDataset);
        try {
            classifier.buildClassifier(trainingInstances);
        } catch (Exception e) {
            throw new TrainingException(String.format("Could not train classifier %s due to a Weka exception.", classifier.getClass().getName()), e);
        }
    }

    /**
     * Copies the attribute values of the given instances into an (instances x attributes) matrix.
     *
     * @param instances the instances to convert; must be non-null and non-empty
     * @param keepClass whether to keep the class attribute's column. If {@code false} and a class
     *        index is set, that column (wherever it is) is omitted. If no class index is set
     *        ({@code classIndex() == -1}), all attributes are kept.
     * @return the new matrix
     * @throws IllegalArgumentException if {@code instances} is null or empty
     */
    public static INDArray wekaInstancesToINDArray(Instances instances, boolean keepClass) {
        if (instances == null || instances.isEmpty()) {
            throw new IllegalArgumentException("Instances must not be null or empty!");
        }
        int classIndex = instances.classIndex();
        // Weka reports an unset class index as -1, so only a non-negative index can be dropped.
        boolean dropClass = !keepClass && classIndex >= 0;
        int totalAttributes = instances.numAttributes();
        int numAttributes = dropClass ? totalAttributes - 1 : totalAttributes;
        int numInstances = instances.numInstances();
        INDArray result = Nd4j.create(numInstances, numAttributes);
        for (int i = 0; i < numInstances; ++i) {
            double[] instValues = instances.get(i).toDoubleArray();
            int column = 0;
            for (int j = 0; j < totalAttributes; ++j) {
                if (dropClass && j == classIndex) {
                    continue;
                }
                result.putScalar(new int[] {i, column++}, instValues[j]);
            }
        }
        return result;
    }

    /**
     * Converts a jaicore time series dataset into Weka instances. Each variable matrix contributes
     * one numeric attribute {@code val_<variable>_<timestep>} per column, followed by a nominal
     * {@code "class"} attribute (the class index).
     *
     * <p>The class labels are {@code "0" .. max(targets)} and each instance stores its raw target as
     * the nominal index, so targets must be non-negative integers.
     *
     * @param dataSet the dataset to convert
     * @return the new Weka instances, with the class index set to the last attribute
     */
    public static <L> Instances timeSeriesDatasetToWekaInstances(ai.libs.jaicore.ml.core.dataset.TimeSeriesDataset<L> dataSet) {
        List<INDArray> matrices = new ArrayList<>();
        for (int v = 0; v < dataSet.getNumberOfVariables(); ++v) {
            matrices.add(dataSet.getValues(v));
        }
        ArrayList<Attribute> attributes = new ArrayList<>();
        for (int m = 0; m < matrices.size(); ++m) {
            long numColumns = matrices.get(m).shape()[1];
            for (long c = 0; c < numColumns; ++c) {
                attributes.add(new Attribute(String.format("val_%d_%d", m, c)));
            }
        }
        INDArray targets = dataSet.getTargets();
        attributes.add(new Attribute("class", classLabelsUpTo((int) targets.maxNumber().longValue())));
        Instances result = new Instances(I_NAME, attributes, (int) dataSet.getNumberOfInstances());
        result.setClassIndex(result.numAttributes() - 1);

        INDArray combinedMatrix = hstackINDArrays(matrices);
        long numInstances = dataSet.getNumberOfInstances();
        for (long i = 0; i < numInstances; ++i) {
            double[] features = Nd4j.toFlattened(new INDArray[] {combinedMatrix.getRow(i)}).toDoubleVector();
            double[] row = Arrays.copyOf(features, features.length + 1);
            row[features.length] = targets.getDouble(i);
            DenseInstance inst = new DenseInstance(1.0, row);
            inst.setDataset(result);
            result.add(inst);
        }
        return result;
    }

    /**
     * Converts a simplified time series dataset into Weka instances, deriving the nominal class
     * labels {@code "0" .. max(targets)} from the targets.
     *
     * <p>Targets are used directly as nominal value indices, so they must be non-negative. Labels
     * start at 0 (not at the minimum target) so that every target is a valid index.
     *
     * @param dataSet the dataset to convert; must have at least one target
     * @return the new Weka instances, with the class index set to the last attribute
     */
    public static Instances simplifiedTimeSeriesDatasetToWekaInstances(TimeSeriesDataset dataSet) {
        int[] targets = dataSet.getTargets();
        int max = Collections.max(Arrays.asList(ArrayUtils.toObject(targets)));
        return simplifiedTimeSeriesDatasetToWekaInstances(dataSet, classLabelsUpTo(max));
    }

    /**
     * Converts a simplified time series dataset into Weka instances. Each non-null variable matrix
     * contributes one numeric attribute {@code val_<variable>_<timestep>} per column. These are
     * followed by a nominal {@code "class"} attribute with the given values.
     *
     * @param dataSet the dataset to convert; null variable matrices are skipped
     * @param classValues the nominal class labels; each target is stored as an index into this list
     * @return the new Weka instances, with the class index set to the last attribute
     */
    public static Instances simplifiedTimeSeriesDatasetToWekaInstances(TimeSeriesDataset dataSet, List<String> classValues) {
        List<double[][]> matrices = new ArrayList<>();
        for (int v = 0; v < dataSet.getNumberOfVariables(); ++v) {
            matrices.add(dataSet.getValues(v));
        }
        ArrayList<Attribute> attributes = new ArrayList<>();
        for (int m = 0; m < matrices.size(); ++m) {
            double[][] matrix = matrices.get(m);
            if (matrix == null || matrix.length == 0) {
                continue;
            }
            for (int c = 0; c < matrix[0].length; ++c) {
                attributes.add(new Attribute(String.format("val_%d_%d", m, c)));
            }
        }
        int[] targets = dataSet.getTargets();
        attributes.add(new Attribute("class", classValues));
        Instances result = new Instances(I_NAME, attributes, dataSet.getNumberOfInstances());
        result.setClassIndex(result.numAttributes() - 1);
        for (int i = 0; i < dataSet.getNumberOfInstances(); ++i) {
            DenseInstance inst = new DenseInstance(1.0, concatenateRow(matrices, i, targets[i]));
            inst.setDataset(result);
            result.add(inst);
        }
        return result;
    }

    /**
     * Converts an (instances x attributes) matrix into Weka instances with numeric attributes
     * {@code val0, val1, ...}. No class index is set.
     *
     * @param matrix a non-empty matrix with exactly two axes
     * @return the new Weka instances
     * @throws IllegalArgumentException if the matrix is null, empty, or not two-dimensional
     */
    public static Instances indArrayToWekaInstances(INDArray matrix) {
        if (matrix == null || matrix.length() == 0L) {
            throw new IllegalArgumentException("Matrix must not be null or empty!");
        }
        if (matrix.shape().length != 2) {
            throw new IllegalArgumentException(String.format("Parameter matrix must be a matrix with 2 axis (instances x attributes). Actual shape: (%s)", Arrays.toString(matrix.shape())));
        }
        int numInstances = (int) matrix.shape()[0];
        int numAttributes = (int) matrix.shape()[1];
        ArrayList<Attribute> attributes = new ArrayList<>();
        for (int i = 0; i < numAttributes; ++i) {
            attributes.add(new Attribute("val" + i));
        }
        Instances result = new Instances(I_NAME, attributes, numInstances);
        for (int i = 0; i < numInstances; ++i) {
            DenseInstance inst = new DenseInstance(1.0, Nd4j.toFlattened(new INDArray[] {matrix.getRow(i)}).toDoubleVector());
            inst.setDataset(result);
            result.add(inst);
        }
        return result;
    }

    /**
     * Converts a row-major {@code double[rows][columns]} matrix into Weka instances with numeric
     * attributes {@code val0, val1, ...}: one instance per row. No class index is set.
     *
     * @param matrix a non-null matrix with at least one row; all rows must have the width of row 0
     * @return the new Weka instances
     * @throws IllegalArgumentException if the matrix is null or has no rows
     */
    public static Instances matrixToWekaInstances(double[][] matrix) {
        if (matrix == null || matrix.length == 0) {
            throw new IllegalArgumentException("Matrix must not be null or empty!");
        }
        ArrayList<Attribute> attributes = new ArrayList<>();
        for (int c = 0; c < matrix[0].length; ++c) {
            attributes.add(new Attribute("val" + c));
        }
        Instances wekaInstances = new Instances(I_NAME, attributes, matrix.length);
        // One instance per ROW (previously bounded by the column count, which broke non-square input).
        for (double[] row : matrix) {
            DenseInstance inst = new DenseInstance(1.0, row);
            inst.setDataset(wekaInstances);
            wekaInstances.add(inst);
        }
        return wekaInstances;
    }

    /** Returns the nominal class labels {@code "0", "1", ..., String.valueOf(max)}. */
    private static List<String> classLabelsUpTo(int max) {
        return IntStream.rangeClosed(0, max).boxed().map(String::valueOf).collect(Collectors.toList());
    }

    /**
     * Concatenates row {@code row} of every non-null matrix, in list order, and appends
     * {@code target} as the final value.
     */
    private static double[] concatenateRow(List<double[][]> matrices, int row, double target) {
        int length = 1;
        for (double[][] matrix : matrices) {
            if (matrix != null) {
                length += matrix[row].length;
            }
        }
        double[] result = new double[length];
        int offset = 0;
        for (double[][] matrix : matrices) {
            if (matrix == null) {
                continue;
            }
            System.arraycopy(matrix[row], 0, result, offset, matrix[row].length);
            offset += matrix[row].length;
        }
        result[offset] = target;
        return result;
    }
}

