/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.weka;

import ai.libs.jaicore.basic.Maps;
import ai.libs.jaicore.ml.weka.RankingByPairwiseComparisonConfig;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import org.api4.java.ai.ml.core.exception.PredictionException;
import org.api4.java.ai.ml.core.exception.TrainingException;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.core.Instance;
import weka.core.Instances;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Add;
import weka.filters.unsupervised.attribute.Remove;

/**
 * Label ranker based on ranking by pairwise comparison.
 *
 * <p>{@link #fit(Instances, int)} trains one Weka classifier for every pair {@code (a, b)} of label
 * attributes. Each classifier predicts the nominal target {@code a>b}, which is {@code "true"} iff
 * the value of label attribute {@code a} is strictly greater than that of label attribute {@code b}.
 * All label attributes are removed from the feature space before training.
 *
 * <p>{@link #predict(Instance)} aggregates the pairwise predictions into one vote per label and
 * returns the label names in descending order of their vote.
 *
 * <p>This class holds mutable state and performs no synchronization.
 */
public class RankingByPairwiseComparison {

    /** Name of the synthetic binary target attribute appended to every pairwise dataset. */
    private static final String PAIRWISE_TARGET_NAME = "a>b";
    /** Nominal values of the pairwise target; index 0 ("true") means "label a beats label b". */
    private static final String PAIRWISE_TARGET_LABELS = "true,false";
    private static final String VALUE_TRUE = "true";
    private static final String VALUE_FALSE = "false";
    /** Voting strategy that gives each pairwise winner one full vote (hard voting). */
    private static final String VOTING_CLASSIFY = "classify";

    private final RankingByPairwiseComparisonConfig config;
    /** Attribute indices of the label attributes; {@code null} until {@link #fit} succeeds. */
    private List<Integer> labelIndices;
    /** Names of the label attributes seen by the last successful {@link #fit}. */
    private final Set<String> labelSet = new HashSet<>();
    /** One trained model per unordered pair of labels, from the last successful {@link #fit}. */
    private final List<PairWiseClassifier> pwClassifiers = new ArrayList<>();

    /**
     * Creates an unfitted ranker.
     *
     * @param config supplies the base learner class name and the voting strategy; must not be null
     */
    public RankingByPairwiseComparison(RankingByPairwiseComparisonConfig config) {
        this.config = Objects.requireNonNull(config, "config");
    }

    /**
     * Builds the feature-only dataset shared by all pairwise problems.
     *
     * <p>The given label attributes are removed, and an empty nominal {@code a>b} target
     * ({@code true,false}) is appended as the last attribute and set as class.
     *
     * @param dataset the dataset to transform (not modified)
     * @param indices attribute indices of the label attributes to remove
     * @return a new, filtered dataset with the same instance order
     * @throws Exception if a Weka filter cannot be configured or applied
     */
    private static Instances applyFiltersToDataset(Instances dataset, List<Integer> indices) throws Exception {
        Remove removeFilter = new Remove();
        removeFilter.setAttributeIndicesArray(indices.stream().mapToInt(Integer::intValue).toArray());
        removeFilter.setInvertSelection(false);
        removeFilter.setInputFormat(dataset);
        Instances filtered = Filter.useFilter(dataset, removeFilter);

        Add addTarget = new Add();
        addTarget.setAttributeIndex("last");
        addTarget.setNominalLabels(PAIRWISE_TARGET_LABELS);
        addTarget.setAttributeName(PAIRWISE_TARGET_NAME);
        addTarget.setInputFormat(filtered);
        filtered = Filter.useFilter(filtered, addTarget);
        filtered.setClassIndex(filtered.numAttributes() - 1);
        return filtered;
    }

    /**
     * Resolves which attributes hold labels.
     *
     * @param labels {@code n > 0}: the first {@code n} attributes are labels; {@code n < 0}: the
     *     last {@code |n|} attributes are labels, listed from the last attribute backwards;
     *     {@code 0}: no labels
     * @param dataset the dataset whose attribute count bounds {@code labels}
     * @return the label attribute indices
     * @throws IllegalArgumentException if {@code |labels|} exceeds the number of attributes
     */
    private static List<Integer> getLabelIndices(int labels, Instances dataset) {
        int numAttributes = dataset.numAttributes();
        // Without this check, out-of-range counts produce negative or too-large indices.
        if (labels > numAttributes || labels < -numAttributes) {
            throw new IllegalArgumentException("Cannot use " + labels + " label attributes in a dataset with "
                    + numAttributes + " attributes");
        }
        List<Integer> indices = new ArrayList<>(Math.abs(labels));
        if (labels < 0) {
            for (int i = numAttributes - 1; i >= numAttributes + labels; --i) {
                indices.add(i);
            }
        } else {
            for (int i = 0; i < labels; ++i) {
                indices.add(i);
            }
        }
        return indices;
    }

    /**
     * Trains one pairwise classifier for every pair of label attributes.
     *
     * <p>Any state from a previous call is replaced. If training fails, the previous state is kept
     * unchanged.
     *
     * @param dataset training data; label values are compared numerically via {@code Instance.value}
     * @param labels number of label attributes; positive counts from the front, negative from the back
     * @throws IllegalArgumentException if {@code |labels|} exceeds the number of attributes
     * @throws TrainingException if filtering, base learner instantiation or training fails
     * @throws Exception declared for backward compatibility
     */
    public void fit(Instances dataset, int labels) throws Exception {
        List<Integer> indices = getLabelIndices(labels, dataset);
        List<PairWiseClassifier> trained = new ArrayList<>();
        try {
            Instances plainPWDataset = applyFiltersToDataset(dataset, indices);
            int targetIndex = plainPWDataset.numAttributes() - 1;
            for (int i = 0; i < indices.size() - 1; ++i) {
                int indexA = indices.get(i);
                for (int j = i + 1; j < indices.size(); ++j) {
                    int indexB = indices.get(j);
                    Classifier classifier = AbstractClassifier.forName((String) this.config.getBaseLearner(), null);
                    Instances pwDataset = new Instances(plainPWDataset);
                    // Filtering keeps instance order, so row k of pwDataset corresponds to row k of dataset.
                    for (int k = 0; k < pwDataset.size(); ++k) {
                        Instance original = dataset.get(k);
                        String target = original.value(indexA) > original.value(indexB) ? VALUE_TRUE : VALUE_FALSE;
                        pwDataset.get(k).setValue(targetIndex, target);
                    }
                    pwDataset.setClassIndex(targetIndex);
                    classifier.buildClassifier(pwDataset);
                    trained.add(new PairWiseClassifier(dataset.attribute(indexA).name(),
                            dataset.attribute(indexB).name(), classifier));
                }
            }
        } catch (Exception e) {
            throw new TrainingException("Could not build ranker", e);
        }

        // Commit only after every pairwise model was built, so a failed fit leaves no partial state.
        this.labelIndices = indices;
        this.labelSet.clear();
        for (int index : indices) {
            this.labelSet.add(dataset.attribute(index).name());
        }
        this.pwClassifiers.clear();
        this.pwClassifiers.addAll(trained);
    }

    /**
     * Ranks the labels for a single instance.
     *
     * <p>With voting strategy {@code "classify"}, each pairwise model gives one vote to the label
     * whose class probability is higher. Ties go to label {@code b}. With any other strategy, each
     * label of the pair receives its predicted probability as a vote (soft voting).
     *
     * <p>Assumes {@code xTest} belongs to a dataset with the same attribute layout as the training
     * data; TODO confirm with callers.
     *
     * @param xTest the instance to rank labels for; must be attached to a dataset
     * @return label names sorted by descending vote (order among equal votes is unspecified)
     * @throws PredictionException if the ranker is not fitted or prediction fails
     */
    public List<String> predict(Instance xTest) throws PredictionException {
        try {
            if (this.labelIndices == null) {
                throw new IllegalStateException("The ranker has not been fitted yet.");
            }
            Instances single = new Instances(xTest.dataset(), 0);
            single.add(xTest);
            Instance filtered = applyFiltersToDataset(single, this.labelIndices).get(0);

            boolean hardVoting = VOTING_CLASSIFY.equals(this.config.getVotingStrategy());
            Map<String, Double> vote = new HashMap<>();
            for (String label : this.labelSet) {
                vote.put(label, 0.0);
            }
            for (PairWiseClassifier pwc : this.pwClassifiers) {
                // dist[0] is P("true") = P(a beats b); dist[1] is P("false").
                double[] dist = pwc.c.distributionForInstance(filtered);
                if (hardVoting) {
                    Maps.increaseCounterInDoubleMap(vote, dist[0] > dist[1] ? pwc.a : pwc.b);
                } else {
                    Maps.increaseCounterInDoubleMap(vote, pwc.a, dist[0]);
                    Maps.increaseCounterInDoubleMap(vote, pwc.b, dist[1]);
                }
            }
            List<String> ranking = new ArrayList<>(vote.keySet());
            ranking.sort((first, second) -> Double.compare(vote.get(second), vote.get(first)));
            return ranking;
        } catch (Exception e) {
            throw new PredictionException("Could not create a prediction.", e);
        }
    }

    /** A trained model that predicts whether label {@code a} beats label {@code b}. */
    private static final class PairWiseClassifier {
        private final String a;
        private final String b;
        private final Classifier c;

        PairWiseClassifier(String a, String b, Classifier c) {
            this.a = a;
            this.b = b;
            this.c = c;
        }
    }
}

