/*
 * Decompiled with CFR 0.152.
 */
package tech.tablesaw.api.ml.classification;

import com.google.common.base.Preconditions;
import java.util.Arrays;
import java.util.Collection;
import java.util.TreeSet;
import smile.classification.KNN;
import tech.tablesaw.api.BooleanColumn;
import tech.tablesaw.api.CategoryColumn;
import tech.tablesaw.api.IntColumn;
import tech.tablesaw.api.NumericColumn;
import tech.tablesaw.api.ShortColumn;
import tech.tablesaw.api.ml.classification.AbstractClassifier;
import tech.tablesaw.api.ml.classification.CategoryConfusionMatrix;
import tech.tablesaw.api.ml.classification.ConfusionMatrix;
import tech.tablesaw.api.ml.classification.StandardConfusionMatrix;
import tech.tablesaw.util.DoubleArrays;

/**
 * K-nearest-neighbour classifier that adapts Smile's {@link KNN} model to column-oriented
 * inputs.
 *
 * <p>Instances are created through the static {@code learn} factories, one per supported label
 * column type. Every factory converts the predictor columns to a {@code double[][]} with
 * {@link DoubleArrays#to2dArray} and the label column to an {@code int[]} before training.
 */
public class Knn
extends AbstractClassifier {
    /** Trained Smile model that every prediction method delegates to. */
    private final KNN<double[]> classifierModel;

    private Knn(KNN<double[]> classifierModel) {
        this.classifierModel = classifierModel;
    }

    /**
     * Trains a classifier from short-valued class labels.
     *
     * @param k          number of neighbours, passed straight to {@code KNN.learn}
     * @param labels     class label of each training row
     * @param predictors feature columns
     * @return a classifier wrapping the trained model
     */
    public static Knn learn(int k, ShortColumn labels, NumericColumn ... predictors) {
        KNN<double[]> model = KNN.learn(DoubleArrays.to2dArray(predictors), labels.toIntArray(), k);
        return new Knn(model);
    }

    /**
     * Trains a classifier from int-valued class labels.
     *
     * @param k          number of neighbours, passed straight to {@code KNN.learn}
     * @param labels     class label of each training row
     * @param predictors feature columns
     * @return a classifier wrapping the trained model
     */
    public static Knn learn(int k, IntColumn labels, NumericColumn ... predictors) {
        KNN<double[]> model = KNN.learn(DoubleArrays.to2dArray(predictors), labels.data().toIntArray(), k);
        return new Knn(model);
    }

    /**
     * Trains a classifier from boolean class labels, using the column's {@code toIntArray()}
     * representation as the label values.
     *
     * @param k          number of neighbours, passed straight to {@code KNN.learn}
     * @param labels     class label of each training row
     * @param predictors feature columns
     * @return a classifier wrapping the trained model
     */
    public static Knn learn(int k, BooleanColumn labels, NumericColumn ... predictors) {
        KNN<double[]> model = KNN.learn(DoubleArrays.to2dArray(predictors), labels.toIntArray(), k);
        return new Knn(model);
    }

    /**
     * Trains a classifier from categorical class labels, using the int values returned by
     * {@code labels.data()} as the label values.
     *
     * @param k          number of neighbours, passed straight to {@code KNN.learn}
     * @param labels     class label of each training row
     * @param predictors feature columns
     * @return a classifier wrapping the trained model
     */
    public static Knn learn(int k, CategoryColumn labels, NumericColumn ... predictors) {
        KNN<double[]> model = KNN.learn(DoubleArrays.to2dArray(predictors), labels.data().toIntArray(), k);
        return new Knn(model);
    }

    /**
     * Predicts the class label for a single observation.
     *
     * @param data feature vector of the observation
     * @return the label returned by the underlying model
     */
    public int predict(double[] data) {
        return this.classifierModel.predict(data);
    }

    /**
     * Builds a confusion matrix by running the model over the given rows.
     *
     * @param labels     label of each row; its distinct values seed the matrix
     * @param predictors feature columns; at least one is required
     * @return the populated confusion matrix
     * @throws IllegalArgumentException if no predictor column is supplied
     */
    public ConfusionMatrix predictMatrix(ShortColumn labels, NumericColumn ... predictors) {
        Preconditions.checkArgument(predictors.length > 0, "at least one predictor column is required");
        TreeSet<Object> labelSet = new TreeSet<Object>(labels.asSet());
        StandardConfusionMatrix confusion = new StandardConfusionMatrix(labelSet);
        this.populateMatrix(labels.toIntArray(), confusion, predictors);
        return confusion;
    }

    /**
     * Builds a confusion matrix by running the model over the given rows.
     *
     * @param labels     label of each row; its distinct values seed the matrix
     * @param predictors feature columns; at least one is required
     * @return the populated confusion matrix
     * @throws IllegalArgumentException if no predictor column is supplied
     */
    public ConfusionMatrix predictMatrix(IntColumn labels, NumericColumn ... predictors) {
        Preconditions.checkArgument(predictors.length > 0, "at least one predictor column is required");
        TreeSet<Object> labelSet = new TreeSet<Object>(labels.asSet());
        StandardConfusionMatrix confusion = new StandardConfusionMatrix(labelSet);
        this.populateMatrix(labels.data().toIntArray(), confusion, predictors);
        return confusion;
    }

    /**
     * Builds a confusion matrix by running the model over the given rows.
     *
     * @param labels     label of each row; its distinct values seed the matrix
     * @param predictors feature columns; at least one is required
     * @return the populated confusion matrix
     * @throws IllegalArgumentException if no predictor column is supplied
     */
    public ConfusionMatrix predictMatrix(BooleanColumn labels, NumericColumn ... predictors) {
        Preconditions.checkArgument(predictors.length > 0, "at least one predictor column is required");
        TreeSet<Object> labelSet = new TreeSet<Object>(labels.asSet());
        StandardConfusionMatrix confusion = new StandardConfusionMatrix(labelSet);
        this.populateMatrix(labels.toIntArray(), confusion, predictors);
        return confusion;
    }

    /**
     * Builds a confusion matrix by running the model over the given rows. The label column
     * itself is handed to the {@link CategoryConfusionMatrix} alongside its distinct values.
     *
     * @param labels     label of each row; its distinct values seed the matrix
     * @param predictors feature columns; at least one is required
     * @return the populated confusion matrix
     * @throws IllegalArgumentException if no predictor column is supplied
     */
    public ConfusionMatrix predictMatrix(CategoryColumn labels, NumericColumn ... predictors) {
        Preconditions.checkArgument(predictors.length > 0, "at least one predictor column is required");
        TreeSet<String> labelSet = new TreeSet<String>(labels.asSet());
        CategoryConfusionMatrix confusion = new CategoryConfusionMatrix(labels, labelSet);
        this.populateMatrix(labels.data().toIntArray(), confusion, predictors);
        return confusion;
    }

    /**
     * Predicts a class label for every row of the given predictor columns.
     *
     * <p>The row count is taken from the first predictor column.
     *
     * @param predictors feature columns; at least one is required
     * @return one predicted label per row, in row order
     * @throws IllegalArgumentException if no predictor column is supplied
     */
    public int[] predict(NumericColumn ... predictors) {
        Preconditions.checkArgument(predictors.length > 0, "at least one predictor column is required");
        int rowCount = predictors[0].size();
        int[] predictedLabels = new int[rowCount];
        for (int row = 0; row < rowCount; ++row) {
            double[] features = new double[predictors.length];
            for (int col = 0; col < predictors.length; ++col) {
                // One slot per predictor column: the feature vector is indexed by column, not by row.
                features[col] = predictors[col].getFloat(row);
            }
            predictedLabels[row] = this.classifierModel.predict(features);
        }
        return predictedLabels;
    }

    /**
     * Hook overridden from {@link AbstractClassifier}; delegates to the trained model.
     *
     * @param data feature vector of one observation
     * @return the label returned by the underlying model
     */
    @Override
    int predictFromModel(double[] data) {
        return this.classifierModel.predict(data);
    }
}

