/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.classification.fs;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.PropertyException;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.IntStream;
import org.tribuo.Dataset;
import org.tribuo.FeatureSelector;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.SelectedFeatureSet;
import org.tribuo.classification.Label;
import org.tribuo.classification.fs.FSMatrix;
import org.tribuo.provenance.DatasetProvenance;
import org.tribuo.provenance.FeatureSelectorProvenance;
import org.tribuo.provenance.FeatureSetProvenance;
import org.tribuo.provenance.impl.FeatureSelectorProvenanceImpl;

public final class mRMR
implements FeatureSelector<Label> {
    private static final Logger logger = Logger.getLogger(mRMR.class.getName());
    @Config(mandatory=true, description="Number of bins to use when discretising continuous features.")
    private int numBins;
    @Config(description="Number of features to select, defaults to ranking all features.")
    private int k = -1;
    @Config(description="Number of computation threads to use.")
    private int numThreads = 1;

    private mRMR() {
    }

    public mRMR(int k, int numBins, int numThreads) {
        this.k = k;
        this.numBins = numBins;
        this.numThreads = numThreads;
        if (k != -1 && k < 1) {
            throw new IllegalArgumentException("k must be -1 to select all features, or a positive number, found " + k);
        }
        if (numBins < 2) {
            throw new IllegalArgumentException("numBins must be >= 2, found " + numBins);
        }
    }

    public void postConfig() {
        if (this.k != -1 && this.k < 1) {
            throw new PropertyException("", "k", "k must be -1 to select all features, or a positive number, found " + this.k);
        }
        if (this.numBins < 2) {
            throw new PropertyException("", "numBins", "numBins must be >= 2, found " + this.numBins);
        }
    }

    public boolean isOrdered() {
        return true;
    }

    public SelectedFeatureSet select(Dataset<Label> dataset) {
        int i;
        double[] miCache;
        FSMatrix data = FSMatrix.buildMatrix(dataset, this.numBins);
        ImmutableFeatureMap fmap = data.getFeatureMap();
        int max = this.k == -1 ? fmap.size() : Math.min(this.k, fmap.size());
        int numFeatures = fmap.size();
        boolean[] unselectedFeatures = new boolean[numFeatures];
        Arrays.fill(unselectedFeatures, true);
        int[] selectedFeatures = new int[max];
        double[] selectedScores = new double[max];
        double[] redundancyCache = new double[numFeatures];
        ForkJoinPool fjp = null;
        if (this.numThreads > 1) {
            fjp = new ForkJoinPool(this.numThreads);
            try {
                miCache = (double[])((ForkJoinTask)fjp.submit(() -> IntStream.range(0, numFeatures).parallel().mapToDouble(data::mi).toArray())).get();
            }
            catch (InterruptedException | ExecutionException e) {
                throw new RuntimeException(e);
            }
        } else {
            miCache = IntStream.range(0, numFeatures).mapToDouble(data::mi).toArray();
        }
        int curIdx = -1;
        double curVal = -1.0;
        for (i = 0; i < numFeatures; ++i) {
            if (!(miCache[i] > curVal)) continue;
            curIdx = i;
            curVal = miCache[i];
        }
        selectedFeatures[0] = curIdx;
        unselectedFeatures[curIdx] = false;
        selectedScores[0] = curVal;
        logger.log(Level.INFO, "Itr 0: selected feature " + fmap.get(curIdx).getName() + ", score = " + selectedScores[0]);
        for (i = 1; i < max; ++i) {
            int maxIdx;
            Pair maxPair;
            int j;
            if (this.numThreads > 1) {
                int prevIdx = selectedFeatures[i - 1];
                int curI = i;
                try {
                    double[] updates = (double[])((ForkJoinTask)fjp.submit(() -> IntStream.range(0, numFeatures).parallel().mapToDouble(j -> unselectedFeatures[j] ? data.mi(j, prevIdx) : 0.0).toArray())).get();
                    for (j = 0; j < redundancyCache.length; ++j) {
                        int n = j;
                        redundancyCache[n] = redundancyCache[n] + updates[j];
                    }
                    maxPair = (Pair)((ForkJoinTask)fjp.submit(() -> IntStream.range(0, numFeatures).parallel().filter(j -> unselectedFeatures[j]).mapToObj(j -> new Pair((Object)j, (Object)(miCache[j] - redundancyCache[j] / (double)curI))).max(Comparator.comparingDouble(Pair::getB)).get())).get();
                }
                catch (InterruptedException | ExecutionException e) {
                    throw new RuntimeException(e);
                }
            } else {
                int maxIndex = -1;
                double maxScore = Double.NEGATIVE_INFINITY;
                for (j = 0; j < numFeatures; ++j) {
                    if (!unselectedFeatures[j]) continue;
                    int prevIdx = selectedFeatures[i - 1];
                    int n = j;
                    redundancyCache[n] = redundancyCache[n] + data.mi(j, prevIdx);
                    double sum = miCache[j] - redundancyCache[j] / (double)i;
                    if (!(sum > maxScore)) continue;
                    maxScore = sum;
                    maxIndex = j;
                }
                maxPair = new Pair((Object)maxIndex, (Object)maxScore);
            }
            selectedFeatures[i] = maxIdx = ((Integer)maxPair.getA()).intValue();
            unselectedFeatures[maxIdx] = false;
            selectedScores[i] = (Double)maxPair.getB();
            logger.log(Level.INFO, "Itr " + i + ": selected feature " + fmap.get(maxIdx).getName() + ", score = " + maxPair.getB() + ", average score = " + selectedScores[i]);
        }
        if (fjp != null) {
            fjp.shutdown();
        }
        ArrayList<String> names = new ArrayList<String>();
        ArrayList<Double> scores = new ArrayList<Double>();
        for (int i2 = 0; i2 < max; ++i2) {
            names.add(fmap.get(selectedFeatures[i2]).getName());
            scores.add(selectedScores[i2]);
        }
        FeatureSetProvenance provenance = new FeatureSetProvenance(SelectedFeatureSet.class.getName(), (DatasetProvenance)dataset.getProvenance(), this.getProvenance());
        return new SelectedFeatureSet(names, scores, this.isOrdered(), provenance);
    }

    public FeatureSelectorProvenance getProvenance() {
        return new FeatureSelectorProvenanceImpl((FeatureSelector)this);
    }
}

