/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.classification.fs;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.PropertyException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.IntStream;
import org.tribuo.Dataset;
import org.tribuo.FeatureSelector;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.SelectedFeatureSet;
import org.tribuo.classification.Label;
import org.tribuo.classification.fs.FSMatrix;
import org.tribuo.provenance.DatasetProvenance;
import org.tribuo.provenance.FeatureSelectorProvenance;
import org.tribuo.provenance.FeatureSetProvenance;
import org.tribuo.provenance.impl.FeatureSelectorProvenanceImpl;

public final class CMIM
implements FeatureSelector<Label> {
    private static final Logger logger = Logger.getLogger(CMIM.class.getName());
    @Config(mandatory=true, description="Number of bins to use when discretising continuous features.")
    private int numBins;
    @Config(description="Number of features to select, defaults to ranking all features.")
    private int k = -1;
    @Config(description="Number of computation threads to use.")
    private int numThreads = 1;

    private CMIM() {
    }

    public CMIM(int k, int numBins, int numThreads) {
        this.k = k;
        this.numBins = numBins;
        this.numThreads = numThreads;
        if (k != -1 && k < 1) {
            throw new IllegalArgumentException("k must be -1 to select all features, or a positive number, found " + k);
        }
        if (numBins < 2) {
            throw new IllegalArgumentException("numBins must be >= 2, found " + numBins);
        }
    }

    public void postConfig() {
        if (this.k != -1 && this.k < 1) {
            throw new PropertyException("", "k", "k must be -1 to select all features, or a positive number, found " + this.k);
        }
        if (this.numBins < 2) {
            throw new PropertyException("", "numBins", "numBins must be >= 2, found " + this.numBins);
        }
    }

    public boolean isOrdered() {
        return true;
    }

    public SelectedFeatureSet select(Dataset<Label> dataset) {
        int i;
        double[] miCache;
        FSMatrix data = FSMatrix.buildMatrix(dataset, this.numBins);
        ImmutableFeatureMap fmap = data.getFeatureMap();
        int max = this.k == -1 ? fmap.size() : Math.min(this.k, fmap.size());
        int numFeatures = fmap.size();
        boolean[] unselectedFeatures = new boolean[numFeatures];
        Arrays.fill(unselectedFeatures, true);
        int[] selectedFeatures = new int[max];
        double[] selectedScores = new double[max];
        int[] idxCache = new int[numFeatures];
        if (this.numThreads > 1) {
            ForkJoinPool fjp = new ForkJoinPool(this.numThreads);
            try {
                miCache = (double[])((ForkJoinTask)fjp.submit(() -> IntStream.range(0, numFeatures).parallel().mapToDouble(data::mi).toArray())).get();
            }
            catch (InterruptedException | ExecutionException e) {
                throw new RuntimeException(e);
            }
            fjp.shutdown();
        } else {
            miCache = IntStream.range(0, numFeatures).mapToDouble(data::mi).toArray();
        }
        int curIdx = -1;
        double curVal = -1.0;
        for (i = 0; i < numFeatures; ++i) {
            if (!(miCache[i] > curVal)) continue;
            curIdx = i;
            curVal = miCache[i];
        }
        selectedFeatures[0] = curIdx;
        selectedScores[0] = curVal;
        unselectedFeatures[curIdx] = false;
        logger.log(Level.INFO, "Itr 0: selected feature " + fmap.get(curIdx).getName() + ", score = " + selectedScores[0]);
        for (i = 1; i < max; ++i) {
            double curMaxVal = -1.0;
            int curMaxIdx = -1;
            for (int j = 0; j < numFeatures; ++j) {
                if (!unselectedFeatures[j]) continue;
                while (miCache[j] > curMaxVal && idxCache[j] < i) {
                    double newVal = data.cmi(j, selectedFeatures[idxCache[j]]);
                    if (newVal < miCache[j]) {
                        miCache[j] = newVal;
                    }
                    int n = j;
                    idxCache[n] = idxCache[n] + 1;
                }
                if (!(miCache[j] > curMaxVal)) continue;
                curMaxVal = miCache[j];
                curMaxIdx = j;
            }
            selectedFeatures[i] = curMaxIdx;
            selectedScores[i] = curMaxVal;
            unselectedFeatures[curMaxIdx] = false;
            logger.log(Level.INFO, "Itr " + i + ": selected feature " + fmap.get(curMaxIdx).getName() + ", score = " + curMaxVal);
        }
        ArrayList<String> names = new ArrayList<String>();
        ArrayList<Double> scores = new ArrayList<Double>();
        for (int i2 = 0; i2 < max; ++i2) {
            names.add(fmap.get(selectedFeatures[i2]).getName());
            scores.add(selectedScores[i2]);
        }
        FeatureSetProvenance provenance = new FeatureSetProvenance(SelectedFeatureSet.class.getName(), (DatasetProvenance)dataset.getProvenance(), this.getProvenance());
        return new SelectedFeatureSet(names, scores, this.isOrdered(), provenance);
    }

    public FeatureSelectorProvenance getProvenance() {
        return new FeatureSelectorProvenanceImpl((FeatureSelector)this);
    }
}

