/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.dyadranking.activelearning;

import ai.libs.jaicore.math.linearalgebra.Vector;
import ai.libs.jaicore.ml.core.exception.TrainingException;
import ai.libs.jaicore.ml.dyadranking.Dyad;
import ai.libs.jaicore.ml.dyadranking.activelearning.ARandomlyInitializingDyadRanker;
import ai.libs.jaicore.ml.dyadranking.activelearning.IDyadRankingPoolProvider;
import ai.libs.jaicore.ml.dyadranking.algorithm.PLNetDyadRanker;
import ai.libs.jaicore.ml.dyadranking.dataset.IDyadRankingInstance;
import ai.libs.jaicore.ml.dyadranking.dataset.SparseDyadRankingInstance;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Random;
import org.apache.commons.math3.stat.descriptive.SummaryStatistics;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.clusterers.Clusterer;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;

/**
 * Active dyad ranker that selects queries by clustering confidence intervals of predicted skills.
 *
 * <p>For every instance feature vector of the pool, each of its dyads is mapped to the interval
 * {@code [skill - sd, skill + sd]}. Here {@code skill} is the current ranker's skill for the dyad
 * and {@code sd} is the standard deviation of all skills recorded for that dyad so far. These
 * intervals are clustered with the supplied Weka {@link Clusterer}. Clusters are then processed
 * largest-first. From each cluster, the pair of dyads whose intervals overlap the most is sent to
 * the pool provider as a pairwise query. The returned rankings form the minibatch used to update
 * the ranker.
 */
public class ConfidenceIntervalClusteringBasedActiveDyadRanker
extends ARandomlyInitializingDyadRanker {
    private static final Logger log = LoggerFactory.getLogger(ConfidenceIntervalClusteringBasedActiveDyadRanker.class);
    private final Clusterer clusterer;

    /**
     * Creates the active ranker.
     *
     * @param ranker the PLNet ranker whose skills are used and which is updated with queried rankings
     * @param poolProvider the pool that is queried for dyads and true rankings
     * @param seed forwarded to the superclass
     * @param numberRandomQueriesAtStart forwarded to the superclass
     * @param minibatchSize forwarded to the superclass; upper bound on queries per training step
     * @param clusterer clusterer that is rebuilt on the confidence intervals of every instance
     */
    public ConfidenceIntervalClusteringBasedActiveDyadRanker(PLNetDyadRanker ranker, IDyadRankingPoolProvider poolProvider, int seed, int numberRandomQueriesAtStart, int minibatchSize, Clusterer clusterer) {
        super(ranker, poolProvider, seed, numberRandomQueriesAtStart, minibatchSize);
        this.clusterer = clusterer;
    }

    /**
     * Performs one active-learning step.
     *
     * <p>The step clusters the confidence intervals of all pooled dyads and queries one pair from
     * each of the (up to minibatch-size) largest clusters. It then updates the ranker with the
     * collected rankings. Clusters with fewer than two dyads still consume a minibatch slot but
     * produce no query. If fewer clusters exist than the minibatch size, the step queries only
     * the available ones. Training failures are logged, not rethrown.
     */
    @Override
    public void activelyTrainWithOneInstance() {
        PriorityQueue<List<Dyad>> clusterQueue = this.buildClusterQueue();
        HashSet<IDyadRankingInstance> minibatch = new HashSet<>();
        Random random = this.getRandom();
        for (int minibatchIndex = 0; minibatchIndex < this.getMinibatchSize(); ++minibatchIndex) {
            List<Dyad> curDyads = clusterQueue.poll();
            if (curDyads == null) {
                // Queue exhausted (fewer clusters than the minibatch size, or clustering failed).
                break;
            }
            if (curDyads.size() < 2) {
                // A pairwise query needs two distinct dyads.
                continue;
            }
            int[] curPair = this.selectMostOverlappingPair(curDyads, random);
            Dyad first = curDyads.get(curPair[0]);
            Dyad second = curDyads.get(curPair[1]);
            LinkedList<Vector> alternatives = new LinkedList<>();
            alternatives.add(first.getAlternative());
            alternatives.add(second.getAlternative());
            // All dyads of a cluster were built from the same instance (see buildClusterQueue).
            SparseDyadRankingInstance queryInstance = new SparseDyadRankingInstance(first.getInstance(), alternatives);
            IDyadRankingInstance trueRanking = this.poolProvider.query(queryInstance);
            minibatch.add(trueRanking);
        }
        try {
            this.updateRanker(minibatch);
        }
        catch (TrainingException e) {
            log.error("Could not update the ranker with the queried minibatch.", e);
        }
    }

    /**
     * Clusters, per instance, the confidence intervals of the instance's dyads.
     *
     * <p>As a side effect, the current skill of every pooled dyad is appended to that dyad's
     * statistics in {@code getDyadStats()}. If clustering fails for an instance, the failure is
     * logged and that instance contributes no clusters.
     *
     * @return queue of dyad clusters, largest cluster first
     */
    private PriorityQueue<List<Dyad>> buildClusterQueue() {
        PriorityQueue<List<Dyad>> clusterQueue = new PriorityQueue<>(new ListComparator());
        Map<Dyad, SummaryStatistics> dyadStats = this.getDyadStats();
        for (Vector inst : this.getInstanceFeatures()) {
            // Snapshot once so the row index in intervalInstances maps to the same dyad below.
            List<Dyad> dyads = new ArrayList<>(this.poolProvider.getDyadsByInstance(inst));
            ArrayList<Attribute> attributes = new ArrayList<>();
            attributes.add(new Attribute("upper_bound"));
            attributes.add(new Attribute("lower_bound"));
            Instances intervalInstances = new Instances("confidence_intervalls", attributes, dyads.size());
            for (Dyad dyad : dyads) {
                double skill = this.ranker.getSkillForDyad(dyad);
                // NOTE(review): assumes getDyadStats() has an entry for every pooled dyad — TODO confirm.
                SummaryStatistics stats = dyadStats.get(dyad);
                stats.addValue(skill);
                double sd = stats.getStandardDeviation();
                intervalInstances.add(new DenseInstance(1.0, new double[]{skill + sd, skill - sd}));
            }
            try {
                this.clusterer.buildClusterer(intervalInstances);
                int numClusters = this.clusterer.numberOfClusters();
                List<List<Dyad>> instanceClusters = new ArrayList<>(numClusters);
                for (int clusterIndex = 0; clusterIndex < numClusters; ++clusterIndex) {
                    instanceClusters.add(new ArrayList<>());
                }
                for (int row = 0; row < dyads.size(); ++row) {
                    // Reuse the stored row: same interval values, and it carries the dataset header.
                    int cluster = this.clusterer.clusterInstance(intervalInstances.instance(row));
                    instanceClusters.get(cluster).add(dyads.get(row));
                }
                clusterQueue.addAll(instanceClusters);
            }
            catch (Exception e) {
                // Weka's Clusterer API declares plain Exception; log it with its cause and skip this instance.
                log.error("Could not cluster the confidence intervals of instance {}.", inst, e);
            }
        }
        return clusterQueue;
    }

    /**
     * Picks the pair of dyads whose confidence intervals overlap the most.
     *
     * <p>Ties keep the first pair found. If no overlap value compares greater than -1 (for example
     * because all overlaps are NaN), a random pair of distinct indices is chosen instead.
     *
     * @param dyads candidate dyads; must contain at least two elements
     * @param random source of randomness for the fallback choice
     * @return two distinct indices into {@code dyads}
     */
    private int[] selectMostOverlappingPair(List<Dyad> dyads, Random random) {
        double curMax = -1.0;
        int[] curPair = new int[]{0, 1};
        boolean changed = false;
        for (int j = 1; j < dyads.size(); ++j) {
            for (int k = 0; k < j; ++k) {
                double overlap = this.getConfidenceIntervalOverlapForDyads(dyads.get(j), dyads.get(k));
                if (overlap > curMax) {
                    curPair[0] = j;
                    curPair[1] = k;
                    curMax = overlap;
                    changed = true;
                }
            }
        }
        if (!changed) {
            curPair[0] = random.nextInt(dyads.size());
            curPair[1] = random.nextInt(dyads.size());
            while (curPair[0] == curPair[1]) {
                curPair[1] = random.nextInt(dyads.size());
            }
        }
        return curPair;
    }

    /**
     * Computes the length of the overlap of the two dyads' confidence intervals.
     *
     * <p>Each interval is {@code [skill - sd, skill + sd]}, built from the ranker's current skill and
     * the standard deviation recorded in {@code getDyadStats()}.
     *
     * @return 0 for disjoint intervals, otherwise the length of their intersection (NaN propagates)
     */
    private double getConfidenceIntervalOverlapForDyads(Dyad dyad1, Dyad dyad2) {
        double skill1 = this.ranker.getSkillForDyad(dyad1);
        double skill2 = this.ranker.getSkillForDyad(dyad2);
        Map<Dyad, SummaryStatistics> dyadStats = this.getDyadStats();
        double sd1 = dyadStats.get(dyad1).getStandardDeviation();
        double sd2 = dyadStats.get(dyad2).getStandardDeviation();
        double lower1 = skill1 - sd1;
        double upper1 = skill1 + sd1;
        double lower2 = skill2 - sd2;
        double upper2 = skill2 + sd2;
        if (lower1 > upper2 || upper1 < lower2) {
            return 0.0;
        }
        double upperlower = Math.max(lower1, lower2);
        double lowerupper = Math.min(upper1, upper2);
        return Math.abs(lowerupper - upperlower);
    }

    /** Orders dyad lists by size, largest first, so a {@link PriorityQueue} yields big clusters first. */
    private static final class ListComparator
    implements Comparator<List<Dyad>> {
        private ListComparator() {
        }

        @Override
        public int compare(List<Dyad> o1, List<Dyad> o2) {
            // Reversed argument order gives descending size.
            return Integer.compare(o2.size(), o1.size());
        }
    }
}

