/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.classification.fs;

import com.oracle.labs.mlrg.olcut.util.MutableLong;
import java.util.HashMap;
import java.util.Map;
import org.tribuo.CategoricalIDInfo;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Output;
import org.tribuo.RealIDInfo;
import org.tribuo.VariableIDInfo;
import org.tribuo.classification.Label;
import org.tribuo.classification.fs.FSMatrix;
import org.tribuo.math.la.DenseVector;
import org.tribuo.transform.Transformer;
import org.tribuo.transform.transformations.BinningTransformation;
import org.tribuo.util.infotheory.InformationTheory;
import org.tribuo.util.infotheory.impl.CachedPair;
import org.tribuo.util.infotheory.impl.CachedTriple;
import org.tribuo.util.infotheory.impl.PairDistribution;
import org.tribuo.util.infotheory.impl.TripleDistribution;

/**
 * A dense {@link FSMatrix} which stores every feature as an {@code int} column of bin ids
 * (one entry per example) alongside the integer label id of each example.
 * <p>
 * Instances are built by {@link #equalWidthBins(Dataset, int)}, which discretises each feature
 * with an equal-width binning transformer whose output values are {@code 1..numBins}.
 * Information theoretic quantities are computed by counting co-occurrences of the discrete
 * values across all samples and passing the counts to {@link InformationTheory}.
 */
final class DenseFSMatrix implements FSMatrix {
    /** Label id of each example, indexed by example. */
    private final int[] labels;
    /** Binned feature values, indexed as {@code features[featureIndex][exampleIndex]}. */
    private final int[][] features;
    private final ImmutableFeatureMap fmap;
    private final int numBins;
    private final int numLabels;

    private DenseFSMatrix(int[] labels, int[][] features, ImmutableFeatureMap fmap, int numBins, int numLabels) {
        this.labels = labels;
        this.features = features;
        this.fmap = fmap;
        this.numBins = numBins;
        this.numLabels = numLabels;
    }

    @Override
    public int getNumFeatures() {
        return features.length;
    }

    @Override
    public int getNumSamples() {
        return labels.length;
    }

    @Override
    public ImmutableFeatureMap getFeatureMap() {
        return fmap;
    }

    /**
     * Computes the mutual information between the given feature and the label.
     *
     * @param featureIndex The feature id.
     * @return The mutual information.
     */
    @Override
    public double mi(int featureIndex) {
        Map<CachedPair<Integer, Integer>, MutableLong> counts = countPairs(features[featureIndex], labels);
        // numBins / numLabels are passed as the sizes of the two marginals.
        return InformationTheory.mi(PairDistribution.constructFromMap(counts, numBins, numLabels));
    }

    /**
     * Computes the mutual information between two features.
     *
     * @param firstIndex The first feature id.
     * @param secondIndex The second feature id.
     * @return The mutual information.
     */
    @Override
    public double mi(int firstIndex, int secondIndex) {
        Map<CachedPair<Integer, Integer>, MutableLong> counts = countPairs(features[firstIndex], features[secondIndex]);
        return InformationTheory.mi(PairDistribution.constructFromMap(counts, numBins, numBins));
    }

    /**
     * Computes the joint mutual information of the feature pair with the label.
     *
     * @param featureIndex The first feature id.
     * @param jointIndex The second feature id, joined with the first.
     * @return The joint mutual information.
     */
    @Override
    public double jmi(int featureIndex, int jointIndex) {
        return InformationTheory.jointMI(TripleDistribution.constructFromMap(
                countTriples(features[featureIndex], features[jointIndex], labels)));
    }

    /**
     * Computes the joint mutual information of the feature pair with a third (target) feature.
     *
     * @param firstIndex The first feature id.
     * @param jointIndex The second feature id, joined with the first.
     * @param targetIndex The target feature id.
     * @return The joint mutual information.
     */
    @Override
    public double jmi(int firstIndex, int jointIndex, int targetIndex) {
        return InformationTheory.jointMI(TripleDistribution.constructFromMap(
                countTriples(features[firstIndex], features[jointIndex], features[targetIndex])));
    }

    /**
     * Computes the conditional mutual information between the feature and the label,
     * conditioned on another feature.
     *
     * @param featureIndex The feature id.
     * @param conditionIndex The conditioning feature id.
     * @return The conditional mutual information.
     */
    @Override
    public double cmi(int featureIndex, int conditionIndex) {
        return InformationTheory.conditionalMI(TripleDistribution.constructFromMap(
                countTriples(features[featureIndex], labels, features[conditionIndex])));
    }

    /**
     * Computes the conditional mutual information between two features, conditioned on a third.
     *
     * @param firstIndex The first feature id.
     * @param secondIndex The second feature id.
     * @param conditionIndex The conditioning feature id.
     * @return The conditional mutual information.
     */
    @Override
    public double cmi(int firstIndex, int secondIndex, int conditionIndex) {
        return InformationTheory.conditionalMI(TripleDistribution.constructFromMap(
                countTriples(features[firstIndex], features[secondIndex], features[conditionIndex])));
    }

    /**
     * Counts how often each {@code (first[i], second[i])} value pair occurs.
     * All columns in this class have one entry per example, so the arrays are the same length.
     */
    private static Map<CachedPair<Integer, Integer>, MutableLong> countPairs(int[] first, int[] second) {
        Map<CachedPair<Integer, Integer>, MutableLong> counts = new HashMap<>();
        for (int i = 0; i < first.length; i++) {
            CachedPair<Integer, Integer> key = new CachedPair<>(first[i], second[i]);
            counts.computeIfAbsent(key, k -> new MutableLong()).increment();
        }
        return counts;
    }

    /**
     * Counts how often each {@code (first[i], second[i], third[i])} value triple occurs.
     * All columns in this class have one entry per example, so the arrays are the same length.
     */
    private static Map<CachedTriple<Integer, Integer, Integer>, MutableLong> countTriples(int[] first, int[] second, int[] third) {
        Map<CachedTriple<Integer, Integer, Integer>, MutableLong> counts = new HashMap<>();
        for (int i = 0; i < first.length; i++) {
            CachedTriple<Integer, Integer, Integer> key = new CachedTriple<>(first[i], second[i], third[i]);
            counts.computeIfAbsent(key, k -> new MutableLong()).increment();
        }
        return counts;
    }

    /**
     * Builds a dense matrix by discretising every feature of the dataset into
     * {@code numBins} equal-width bins.
     *
     * @param dataset The dataset to discretise.
     * @param numBins The number of bins per feature; must be at least 1.
     * @return A dense feature selection matrix.
     * @throws IllegalArgumentException If {@code numBins} is less than 1.
     */
    static DenseFSMatrix equalWidthBins(Dataset<Label> dataset, int numBins) {
        if (numBins < 1) {
            throw new IllegalArgumentException("numBins must be positive, found " + numBins);
        }
        ImmutableFeatureMap fmap = dataset.getFeatureIDMap();
        ImmutableOutputInfo<Label> labelInfo = dataset.getOutputIDInfo();
        int numFeatures = fmap.size();
        int numExamples = dataset.size();
        int numLabels = dataset.getOutputInfo().size();

        // One equal-width binner per feature, fitted from that feature's recorded statistics.
        Transformer[] transformers = new Transformer[numFeatures];
        for (int j = 0; j < numFeatures; j++) {
            VariableIDInfo info = fmap.get(j);
            transformers[j] = makeBinningTransformer(info, numExamples, numBins);
        }

        int[][] features = new int[numFeatures][numExamples];
        int[] labels = new int[numExamples];
        for (int i = 0; i < numExamples; i++) {
            Example<Label> ex = dataset.getExample(i);
            // Densify without a bias feature. Presumably absent features read as 0.0, which is
            // why makeBinningTransformer widens the range to include 0 for sparse features.
            DenseVector vec = DenseVector.createDenseVector(ex, fmap, false);
            for (int j = 0; j < numFeatures; j++) {
                features[j][i] = (int) transformers[j].transform(vec.get(j));
            }
            labels[i] = labelInfo.getID(ex.getOutput());
        }
        return new DenseFSMatrix(labels, features, fmap, numBins, numLabels);
    }

    /**
     * Creates an equal-width binning transformer spanning the observed range of a feature.
     * Bin {@code i} has upper edge {@code min + (i + 1) * (range / numBins)} and output value
     * {@code i + 1}.
     *
     * @param info The feature's statistics. Must be either {@link CategoricalIDInfo} or {@link RealIDInfo}.
     * @param numExamples The number of examples in the dataset.
     * @param numBins The number of bins.
     * @return The binning transformer.
     * @throws IllegalStateException If {@code info} is of an unsupported subclass.
     */
    private static Transformer makeBinningTransformer(VariableIDInfo info, int numExamples, int numBins) {
        int count = info.getCount();
        double min = Double.POSITIVE_INFINITY;
        double max = Double.NEGATIVE_INFINITY;
        if (info instanceof CategoricalIDInfo) {
            for (double cur : ((CategoricalIDInfo) info).getValues()) {
                min = Math.min(min, cur);
                max = Math.max(max, cur);
            }
        } else if (info instanceof RealIDInfo) {
            RealIDInfo realInfo = (RealIDInfo) info;
            min = realInfo.getMin();
            max = realInfo.getMax();
        } else {
            throw new IllegalStateException("Unknown variable info subclass " + info.getClass());
        }
        // The feature was not observed in every example, so the implicit 0.0 must be binnable too.
        if (numExamples != count) {
            min = Math.min(min, 0.0);
            max = Math.max(max, 0.0);
        }
        double range = Math.abs(max - min);
        double increment = range / numBins;
        double[] bins = new double[numBins];
        double[] values = new double[numBins];
        for (int i = 0; i < numBins; i++) {
            bins[i] = min + (i + 1) * increment;
            values[i] = i + 1;
        }
        return new BinningTransformation.BinningTransformer(BinningTransformation.BinningType.EQUAL_WIDTH, bins, values);
    }
}

