/*
 * Decompiled with CFR 0.152.
 */
package hivemall.smile.utils;

import hivemall.annotations.VisibleForTesting;
import hivemall.smile.classification.DecisionTree;
import hivemall.smile.utils.VariableOrder;
import hivemall.utils.collections.arrays.SparseIntArray;
import hivemall.utils.collections.lists.DoubleArrayList;
import hivemall.utils.collections.lists.IntArrayList;
import hivemall.utils.lang.NumberUtils;
import hivemall.utils.lang.Preconditions;
import hivemall.utils.math.MathUtils;
import hivemall.utils.random.PRNG;
import hivemall.utils.random.RandomNumberGeneratorFactory;
import java.util.Arrays;
import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import matrix4j.matrix.ColumnMajorMatrix;
import matrix4j.matrix.Matrix;
import matrix4j.matrix.MatrixUtils;
import matrix4j.vector.VectorProcedure;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.roaringbitmap.RoaringBitmap;
import smile.data.Attribute;
import smile.math.Math;
import smile.sort.QuickSort;

/**
 * Static helpers shared by the Smile-based tree learners: attribute-spec parsing, per-column
 * sort indices, class-label validation, split-rule resolution, shuffling and naming utilities.
 */
public final class SmileExtUtils {
    /** Type code for numeric attributes (not referenced by the methods of this class). */
    public static final byte NUMERIC = 1;
    /** Type code for nominal attributes (not referenced by the methods of this class). */
    public static final byte NOMINAL = 2;

    private SmileExtUtils() {}

    /**
     * Parses a comma-separated list of per-column attribute type codes.
     * <p>
     * Each token must be exactly {@code "Q"} or {@code "C"} (no surrounding whitespace). The
     * 0-based positions of {@code "C"} tokens are collected; {@code "Q"} positions are skipped.
     * Presumably {@code C} marks categorical/nominal and {@code Q} numeric columns.
     *
     * @param opt the type specification, e.g. {@code "Q,C,Q"}; {@code null} yields an empty bitmap
     * @return bitmap of the column indices whose token is {@code "C"}
     * @throws UDFArgumentException if any token is neither {@code "Q"} nor {@code "C"}
     */
    @Nonnull
    public static RoaringBitmap resolveAttributes(@Nullable final String opt)
            throws UDFArgumentException {
        final RoaringBitmap attr = new RoaringBitmap();
        if (opt == null) {
            return attr;
        }
        final String[] opts = opt.split(",");
        for (int i = 0; i < opts.length; i++) {
            final String type = opts[i];
            if ("C".equals(type)) {
                attr.add(i);
            } else if (!"Q".equals(type)) {
                throw new UDFArgumentException("Unsupported attribute type: " + type);
            }
        }
        return attr;
    }

    /**
     * Parses a comma-separated list of nominal column indices, e.g. {@code "0,3,7"}.
     *
     * @param opt the index list; {@code null} yields an empty bitmap
     * @return bitmap containing every listed index
     * @throws UDFArgumentException if a token is not accepted by {@code NumberUtils.isDigits}
     */
    @Nonnull
    public static RoaringBitmap parseNominalAttributeIndicies(@Nullable final String opt)
            throws UDFArgumentException {
        final RoaringBitmap attr = new RoaringBitmap();
        if (opt == null) {
            return attr;
        }
        for (String s : opt.split(",")) {
            if (!NumberUtils.isDigits(s)) {
                throw new UDFArgumentException("Expected integer but got " + s);
            }
            attr.add(NumberUtils.parseInt(s));
        }
        return attr;
    }

    /**
     * Converts Smile attribute descriptors into a bitmap of nominal column positions.
     *
     * @param original attribute descriptors, one per column
     * @return bitmap of the positions whose attribute type is {@code NOMINAL}
     * @throws UnsupportedOperationException if an attribute is neither {@code NOMINAL} nor
     *         {@code NUMERIC}
     */
    @Nonnull
    @VisibleForTesting
    public static RoaringBitmap convertAttributeTypes(@Nonnull final Attribute[] original) {
        final RoaringBitmap nominalAttrs = new RoaringBitmap();
        for (int i = 0; i < original.length; i++) {
            final Attribute o = original[i];
            switch (o.type) {
                case NOMINAL:
                    nominalAttrs.add(i);
                    break;
                case NUMERIC:
                    break;
                default:
                    throw new UnsupportedOperationException("Unsupported type: " + o.type);
            }
        }
        return nominalAttrs;
    }

    /**
     * Builds, for every non-nominal column of {@code x}, the row ids ordered by that column's
     * values.
     * <p>
     * Rows with {@code samples[i] == 0} are excluded (presumably out-of-bag rows in bagging).
     * For sparse matrices only the non-null entries of each column are considered. Nominal
     * columns, and columns with no remaining rows, get a {@code null} slot in the index.
     *
     * @param nominalAttrs column indices to skip
     * @param x the input matrix
     * @param samples per-row counts indexed by row id; zero means "skip this row"
     * @return the per-column sort index
     */
    @Nonnull
    public static VariableOrder sort(@Nonnull final RoaringBitmap nominalAttrs,
            @Nonnull final Matrix x, @Nonnull final int[] samples) {
        final int n = x.numRows();
        final int p = x.numColumns();
        final SparseIntArray[] index = new SparseIntArray[p];

        if (x.isSparse()) {
            // NOTE(review): n < 10 gives an initial capacity of 0; confirm the list classes grow
            // correctly from an empty buffer.
            final int initSize = n / 10;
            final DoubleArrayList dlist = new DoubleArrayList(initSize);
            final IntArrayList ilist = new IntArrayList(initSize);
            final VectorProcedure proc = new VectorProcedure() {
                @Override
                public void apply(final int i, final double v) {
                    if (samples[i] == 0) {
                        return;
                    }
                    dlist.add(v);
                    ilist.add(i);
                }
            };
            final ColumnMajorMatrix x2 = x.toColumnMajorMatrix();
            for (int j = 0; j < p; j++) {
                if (nominalAttrs.contains(j)) {
                    continue;
                }
                x2.eachNonNullInColumn(j, proc);
                index[j] = sortedRowIndex(dlist, ilist);
            }
        } else {
            final DoubleArrayList dlist = new DoubleArrayList(n);
            final IntArrayList ilist = new IntArrayList(n);
            for (int j = 0; j < p; j++) {
                if (nominalAttrs.contains(j)) {
                    continue;
                }
                for (int i = 0; i < n; i++) {
                    if (samples[i] != 0) {
                        dlist.add(x.get(i, j));
                        ilist.add(i);
                    }
                }
                index[j] = sortedRowIndex(dlist, ilist);
            }
        }
        return new VariableOrder(index);
    }

    /**
     * Orders the buffered row ids by their buffered values with {@code QuickSort}, wraps them in
     * a {@link SparseIntArray}, then clears both buffers for reuse. Returns {@code null} when no
     * rows were buffered.
     */
    @Nullable
    private static SparseIntArray sortedRowIndex(@Nonnull final DoubleArrayList values,
            @Nonnull final IntArrayList rows) {
        if (rows.isEmpty()) {
            return null;
        }
        final int[] rowPtrs = rows.toArray();
        // values.array() may be longer than the logical size, so the sort length is explicit
        QuickSort.sort(values.array(), rowPtrs, rowPtrs.length);
        final SparseIntArray result = new SparseIntArray(rowPtrs);
        values.clear();
        rows.clear();
        return result;
    }

    /**
     * Returns the distinct class labels of {@code y} in ascending order after validating them.
     * <p>
     * Only gaps between consecutive observed labels are rejected; the smallest label is not
     * required to be 0.
     *
     * @param y class label per sample
     * @return sorted distinct labels
     * @throws HiveException if fewer than two distinct labels exist, a label is negative, or
     *         two consecutive observed labels differ by more than one
     */
    @Nonnull
    public static int[] classLabels(@Nonnull final int[] y) throws HiveException {
        final int[] labels = Math.unique(y);
        Arrays.sort(labels);
        if (labels.length < 2) {
            throw new HiveException("Only one class.");
        }
        for (int i = 0; i < labels.length; i++) {
            if (labels[i] < 0) {
                throw new HiveException("Negative class label: " + labels[i]);
            }
            if (i > 0 && labels[i] - labels[i - 1] > 1) {
                throw new HiveException("Missing class: " + (labels[i - 1] + 1));
            }
        }
        return labels;
    }

    /**
     * Maps a split-rule name (case-insensitive) to a {@link DecisionTree.SplitRule}.
     *
     * @param ruleName {@code "gini"}, {@code "entropy"} or {@code "classification_error"};
     *        {@code null} or any other value falls back to {@code GINI}
     * @return the resolved split rule
     */
    @Nonnull
    public static DecisionTree.SplitRule resolveSplitRule(@Nullable final String ruleName) {
        if ("gini".equalsIgnoreCase(ruleName)) {
            return DecisionTree.SplitRule.GINI;
        }
        if ("entropy".equalsIgnoreCase(ruleName)) {
            return DecisionTree.SplitRule.ENTROPY;
        }
        if ("classification_error".equalsIgnoreCase(ruleName)) {
            return DecisionTree.SplitRule.CLASSIFICATION_ERROR;
        }
        return DecisionTree.SplitRule.GINI;
    }

    /**
     * Resolves the number of input variables from a user setting. This is presumably the number
     * of candidate features a tree learner considers per split.
     * <ul>
     * <li>{@code numVars <= 0}: {@code ceil(sqrt(p))}, where {@code p = x.numColumns()};</li>
     * <li>{@code 0 < numVars <= 1}: {@code numVars * p}, truncated, and never less than 1 when
     * {@code p > 0};</li>
     * <li>{@code numVars > 1}: {@code numVars} truncated to an int.</li>
     * </ul>
     * A NaN {@code numVars} yields 0.
     *
     * @param numVars absolute count, fraction of columns, or a non-positive value for the default
     * @param x the input matrix
     * @return the resolved number of input variables
     */
    public static int computeNumInputVars(final float numVars, @Nonnull final Matrix x) {
        final int numColumns = x.numColumns();
        if (numVars <= 0.f) {
            return (int) java.lang.Math.ceil(java.lang.Math.sqrt(numColumns));
        }
        if (numVars <= 1.f) {
            // truncation can produce 0 (e.g. 0.1 of 5 columns); always keep at least one variable
            final int fraction = (int) (numVars * numColumns);
            return (fraction == 0 && numColumns > 0) ? 1 : fraction;
        }
        return (int) numVars;
    }

    /**
     * Derives an ad-hoc seed from the current thread id and {@link System#nanoTime()}.
     *
     * @return a seed value (not cryptographically random)
     */
    public static long generateSeed() {
        return Thread.currentThread().getId() * System.nanoTime();
    }

    /**
     * Shuffles {@code x} in place with a Fisher-Yates pass driven by {@code rnd}.
     *
     * @param x the array to shuffle
     * @param rnd the random source
     */
    public static void shuffle(@Nonnull final int[] x, @Nonnull final PRNG rnd) {
        for (int i = x.length; i > 1; i--) {
            final int j = rnd.nextInt(i);
            swap(x, i - 1, j);
        }
    }

    /**
     * Shuffles the rows of {@code x} and the entries of {@code y} with the same permutation.
     * <p>
     * If {@code x.swappable()} is true, {@code x} is permuted in place and returned. Otherwise
     * a permutation is built and applied through {@code MatrixUtils.shuffle}. In both cases
     * {@code y} is permuted in place.
     *
     * @param x the feature matrix
     * @param y the labels, one per row of {@code x}
     * @param seed PRNG seed; {@code -1} means derive one with {@link #generateSeed()}
     * @return the shuffled matrix
     * @throws IllegalArgumentException if {@code x.numRows() != y.length}
     */
    @Nonnull
    public static Matrix shuffle(@Nonnull final Matrix x, @Nonnull final int[] y, final long seed) {
        final int numRows = x.numRows();
        checkSameLength(numRows, y.length);
        final PRNG rnd = newPRNG(seed);
        if (x.swappable()) {
            for (int i = numRows; i > 1; i--) {
                final int j = rnd.nextInt(i);
                final int k = i - 1;
                x.swap(k, j);
                swap(y, k, j);
            }
            return x;
        }
        final int[] indices = MathUtils.permutation(numRows);
        for (int i = numRows; i > 1; i--) {
            final int j = rnd.nextInt(i);
            final int k = i - 1;
            swap(indices, k, j);
            swap(y, k, j);
        }
        return MatrixUtils.shuffle(x, indices);
    }

    /**
     * Shuffles the rows of {@code x} and the entries of {@code y} with the same permutation.
     * <p>
     * If {@code x.swappable()} is true, {@code x} is permuted in place and returned. Otherwise
     * a permutation is built and applied through {@code MatrixUtils.shuffle}. In both cases
     * {@code y} is permuted in place.
     *
     * @param x the feature matrix
     * @param y the targets, one per row of {@code x}
     * @param seed PRNG seed; {@code -1} means derive one with {@link #generateSeed()}
     * @return the shuffled matrix
     * @throws IllegalArgumentException if {@code x.numRows() != y.length}
     */
    @Nonnull
    public static Matrix shuffle(@Nonnull final Matrix x, @Nonnull final double[] y,
            final long seed) {
        final int numRows = x.numRows();
        checkSameLength(numRows, y.length);
        final PRNG rnd = newPRNG(seed);
        if (x.swappable()) {
            for (int i = numRows; i > 1; i--) {
                final int j = rnd.nextInt(i);
                final int k = i - 1;
                x.swap(k, j);
                swap(y, k, j);
            }
            return x;
        }
        final int[] indices = MathUtils.permutation(numRows);
        for (int i = numRows; i > 1; i--) {
            final int j = rnd.nextInt(i);
            final int k = i - 1;
            swap(indices, k, j);
            swap(y, k, j);
        }
        return MatrixUtils.shuffle(x, indices);
    }

    /** Throws if the matrix row count and the target length disagree. */
    private static void checkSameLength(final int numRows, final int yLength) {
        if (numRows != yLength) {
            throw new IllegalArgumentException(
                "x.length (" + numRows + ") != y.length (" + yLength + ')');
        }
    }

    /** Creates a PRNG, substituting {@link #generateSeed()} for the sentinel seed {@code -1}. */
    @Nonnull
    private static PRNG newPRNG(final long seed) {
        return RandomNumberGeneratorFactory.createPRNG(seed == -1L ? generateSeed() : seed);
    }

    private static void swap(final int[] x, final int i, final int j) {
        final int s = x[i];
        x[i] = x[j];
        x[j] = s;
    }

    private static void swap(final double[] x, final int i, final int j) {
        final double s = x[i];
        x[i] = x[j];
        x[j] = s;
    }

    /**
     * Returns whether {@code x} has more (or fewer) columns than there are nominal attributes.
     * The bitmap contents are not checked against the column range.
     *
     * @param x the input matrix
     * @param attributes bitmap of nominal column indices
     * @return {@code true} iff {@code x.numColumns() != attributes.getCardinality()}
     */
    public static boolean containsNumericType(@Nonnull final Matrix x,
            @Nonnull final RoaringBitmap attributes) {
        return x.numColumns() != attributes.getCardinality();
    }

    /**
     * Returns {@code names[index]}, or {@code "feature#" + index} when {@code names} is
     * {@code null} or too short. A negative index is not guarded. NOTE(review): a {@code null}
     * array element is returned as-is.
     */
    @Nonnull
    public static String resolveFeatureName(final int index, @Nullable final String[] names) {
        if (names == null || index >= names.length) {
            return "feature#" + index;
        }
        return names[index];
    }

    /**
     * Returns {@code names[index]}, or the decimal string of {@code index} when {@code names}
     * is {@code null} or too short. A negative index is not guarded. NOTE(review): a {@code null}
     * array element is returned as-is.
     */
    @Nonnull
    public static String resolveName(final int index, @Nullable final String[] names) {
        if (names == null || index >= names.length) {
            return String.valueOf(index);
        }
        return names[index];
    }

    /**
     * Returns {@code n} evenly spaced values {@code i / n} for {@code i = 0..n-1}. They are
     * computed as hue steps of a 360-degree circle scaled back to {@code [0, 1)}.
     *
     * @param n number of values; must be at least 1
     * @return the spaced values
     */
    public static double[] getColorBrew(@Nonnegative final int n) {
        Preconditions.checkArgument(n >= 1);
        final double hue_step = 360.0 / (double) n;
        final double[] colors = new double[n];
        for (int i = 0; i < n; i++) {
            colors[i] = (double) i * hue_step / 360.0;
        }
        return colors;
    }
}

