/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.tsc.classifier.ensemble;

import java.util.Arrays;
import java.util.Random;
import weka.classifiers.Classifier;
import weka.classifiers.evaluation.Evaluation;
import weka.classifiers.meta.Vote;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Utils;

public class MajorityConfidenceVote
extends Vote {
    private static final long serialVersionUID = -7128109840679632228L;
    private int numFolds;
    private double[] classifierWeights;
    private int seed;

    public MajorityConfidenceVote(int numFolds, int seed) {
        this.numFolds = numFolds;
    }

    public void buildClassifier(Instances data) throws Exception {
        int i;
        this.classifierWeights = new double[this.m_Classifiers.length];
        Instances newData = new Instances(data);
        newData.deleteWithMissingClass();
        this.m_structure = new Instances(newData, 0);
        this.getCapabilities().testWithFail(data);
        for (i = 0; i < this.m_Classifiers.length; ++i) {
            if (Thread.currentThread().isInterrupted()) {
                throw new InterruptedException();
            }
            for (int n = 0; n < this.numFolds; ++n) {
                Instances train = data.trainCV(this.numFolds, n, new Random(this.seed));
                Instances test = data.testCV(this.numFolds, n);
                this.getClassifier(i).buildClassifier(train);
                Evaluation eval = new Evaluation(train);
                eval.evaluateModel(this.getClassifier(i), test, new Object[0]);
                int n2 = i;
                this.classifierWeights[n2] = this.classifierWeights[n2] + eval.pctCorrect() / 100.0;
            }
            this.classifierWeights[i] = Math.pow(this.classifierWeights[i], 2.0);
            int n = i;
            this.classifierWeights[n] = this.classifierWeights[n] / (double)this.numFolds;
            this.getClassifier(i).buildClassifier(newData);
        }
        if (Arrays.stream(this.classifierWeights).allMatch(d -> d < 1.0E-6)) {
            for (i = 0; i < this.classifierWeights.length; ++i) {
                this.classifierWeights[i] = 1.0 / (double)this.classifierWeights.length;
            }
        }
    }

    public double[] distributionForInstance(Instance instance) throws Exception {
        int j;
        double[] dist;
        int i;
        double[] probs = new double[instance.numClasses()];
        for (int i2 = 0; i2 < probs.length; ++i2) {
            probs[i2] = 1.0;
        }
        int numPredictions = 0;
        for (i = 0; i < this.m_Classifiers.length; ++i) {
            if (Thread.currentThread().isInterrupted()) {
                throw new InterruptedException();
            }
            dist = this.getClassifier(i).distributionForInstance(instance);
            if (!(Utils.sum((double[])dist) > 0.0)) continue;
            for (j = 0; j < dist.length; ++j) {
                int n = j;
                probs[n] = probs[n] + this.classifierWeights[i] * dist[j];
            }
            ++numPredictions;
        }
        for (i = 0; i < this.m_preBuiltClassifiers.size(); ++i) {
            if (Thread.currentThread().isInterrupted()) {
                throw new InterruptedException();
            }
            dist = ((Classifier)this.m_preBuiltClassifiers.get(i)).distributionForInstance(instance);
            if (!(Utils.sum((double[])dist) > 0.0)) continue;
            for (j = 0; j < dist.length; ++j) {
                int n = j;
                probs[n] = probs[n] * dist[j];
            }
            ++numPredictions;
        }
        if (numPredictions == 0) {
            return new double[instance.numClasses()];
        }
        if (Utils.sum((double[])probs) > 0.0) {
            Utils.normalize((double[])probs);
        }
        return probs;
    }

    public double classifyInstance(Instance instance) throws Exception {
        int index;
        double[] dist = this.distributionForInstance(instance);
        double result = instance.classAttribute().isNominal() ? (dist[index = Utils.maxIndex((double[])dist)] == 0.0 ? Utils.missingValue() : (double)index) : (instance.classAttribute().isNumeric() ? dist[0] : Utils.missingValue());
        return result;
    }
}

