/*
 * Decompiled with CFR 0.152.
 */
package hex.genmodel;

import hex.ModelCategory;
import hex.genmodel.CategoricalEncoding;
import hex.genmodel.IGenModel;
import java.awt.Color;
import java.awt.Graphics2D;
import java.awt.image.BufferedImage;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.EnumSet;
import java.util.LinkedList;
import java.util.Random;
import water.genmodel.IGeneratedModel;

/**
 * Abstract base class for scoring models implementing {@link IGenModel} and {@link IGeneratedModel}.
 *
 * <p>An instance stores the column names and categorical domains it was built with, and it provides
 * default implementations of the metadata accessors derived from them. The class also hosts static
 * helpers for:
 * <ul>
 *   <li>probability correction and predicted-class selection (including tie-breaking),</li>
 *   <li>bit-set lookups,</li>
 *   <li>K-Means preprocessing and distances,</li>
 *   <li>GLM inverse link functions,</li>
 *   <li>one-hot input expansion,</li>
 *   <li>image pixel extraction.</li>
 * </ul>
 */
public abstract class GenModel
implements IGenModel,
IGeneratedModel,
Serializable {
    /** Column names; the base {@link #nfeatures()} returns the length of this array. */
    public final String[] _names;
    /**
     * Per-column categorical domains ({@code null} entries mark non-categorical columns). For
     * supervised models the last entry is taken as the response domain, see {@link #getResponseIdx()}.
     */
    public final String[][] _domains;
    /** Name of the response column; {@code null} when built through the deprecated constructor. */
    public final String _responseColumn;
    /** Name of the offset column, or {@code null}; returned by {@link #getOffsetName()}. */
    public String _offsetColumn;

    /**
     * Creates a model description.
     *
     * @param names column names (stored by reference, not copied)
     * @param domains per-column categorical domains (stored by reference, not copied)
     * @param responseColumn name of the response column, or {@code null}
     */
    public GenModel(String[] names, String[][] domains, String responseColumn) {
        this._names = names;
        this._domains = domains;
        this._responseColumn = responseColumn;
    }

    /**
     * Creates a model description without a response column name.
     *
     * @param names column names
     * @param domains per-column categorical domains
     * @deprecated use {@link #GenModel(String[], String[][], String)} instead
     */
    @Deprecated
    public GenModel(String[] names, String[][] domains) {
        this(names, domains, null);
    }

    /** @return whether scoring requires an offset value; {@code false} by default */
    public boolean requiresOffset() {
        return false;
    }

    /** @return whether the model has a response column; {@code false} by default */
    @Override
    public boolean isSupervised() {
        return false;
    }

    /** @return the number of feature columns, i.e. {@code _names.length} by default */
    @Override
    public int nfeatures() {
        return this._names.length;
    }

    /**
     * Counts feature columns that have a categorical domain.
     *
     * @return number of indices {@code i < nfeatures()} whose {@link #getDomainValues()} entry is
     *     non-null
     */
    public int nCatFeatures() {
        String[][] domainValues = this.getDomainValues();
        int nCat = 0;
        for (int i = 0; i < this.nfeatures(); ++i) {
            if (domainValues[i] != null) {
                ++nCat;
            }
        }
        return nCat;
    }

    /** @return a copy of the first {@link #nfeatures()} column names */
    @Override
    public String[] features() {
        return Arrays.copyOf(this._names, this.nfeatures());
    }

    /** @return the number of classes; 0 by default */
    @Override
    public int nclasses() {
        return 0;
    }

    /** @return the category of this model */
    @Override
    public abstract ModelCategory getModelCategory();

    /**
     * Returns the names of the prediction output columns for this model's category.
     *
     * <ul>
     *   <li>AutoEncoder: {@code reconstr_<col>.<level>} for each level, followed by
     *       {@code reconstr_<col>.missing(NA)}, for each of the first {@link #nCatFeatures()}
     *       columns. The remaining columns each produce {@code reconstr_<col>}.</li>
     *   <li>Binomial / Multinomial / Ordinal: {@code predict} followed by the response levels. A
     *       level that parses as an {@code int} gets a {@code p} prefix.</li>
     *   <li>Clustering: {@code cluster}.</li>
     *   <li>Regression: {@code predict}.</li>
     * </ul>
     *
     * @return output column names
     * @throws UnsupportedOperationException for any other model category
     */
    public String[] getOutputNames() {
        ModelCategory category = this.getModelCategory();
        switch (category) {
            case AutoEncoder:
                return this.autoEncoderOutputNames();
            case Binomial:
            case Multinomial:
            case Ordinal:
                return this.classifierOutputNames();
            case Clustering:
                return new String[]{"cluster"};
            case Regression:
                return new String[]{"predict"};
            default:
                throw new UnsupportedOperationException("Getting output column names for model category '" + category + "' is not supported.");
        }
    }

    /**
     * Builds the AutoEncoder reconstruction column names.
     * NOTE(review): this assumes the first nCatFeatures() columns are the categorical ones (a null
     * domain among them would throw NPE) -- confirm that column ordering with the model builder.
     */
    private String[] autoEncoderOutputNames() {
        String[] cnames = this.getNames();
        String[][] domainValues = this.getDomainValues();
        int numCats = this.nCatFeatures();
        ArrayList<String> onames = new ArrayList<String>();
        for (int index = 0; index < numCats; ++index) {
            String prefix = "reconstr_" + cnames[index];
            for (String level : domainValues[index]) {
                onames.add(prefix + "." + level);
            }
            onames.add(prefix + ".missing(NA)");
        }
        for (int index = numCats; index < cnames.length; ++index) {
            onames.add("reconstr_" + cnames[index]);
        }
        return onames.toArray(new String[0]);
    }

    /** Builds {@code predict} plus one probability column name per response level. */
    private String[] classifierOutputNames() {
        String[] responseDomain = this.getDomainValues(this.getResponseIdx());
        String[] outputNames = new String[1 + responseDomain.length];
        outputNames[0] = "predict";
        for (int i = 0; i < responseDomain.length; ++i) {
            String level = responseDomain[i];
            // Integer-looking levels get a "p" prefix so the column name is not a bare number.
            outputNames[i + 1] = GenModel.parsesAsInt(level) ? "p" + level : level;
        }
        return outputNames;
    }

    /** @return whether {@code s} is accepted by {@link Integer#parseInt(String)} */
    private static boolean parsesAsInt(String s) {
        try {
            Integer.parseInt(s);
            return true;
        } catch (NumberFormatException ignored) {
            // Not an integer: the level is used verbatim.
            return false;
        }
    }

    /** @return a singleton set containing {@link #getModelCategory()} */
    @Override
    public EnumSet<ModelCategory> getModelCategories() {
        return EnumSet.of(this.getModelCategory());
    }

    /** @return the model's unique identifier */
    @Override
    public abstract String getUUID();

    /** @return {@link #nfeatures()} */
    @Override
    public int getNumCols() {
        return this.nfeatures();
    }

    /** @return the column name array itself (not a copy) */
    @Override
    public String[] getNames() {
        return this._names;
    }

    /**
     * Returns the number of original input columns, excluding the response.
     *
     * @return 0 when {@link #getOrigNames()} is null or empty; otherwise its length, minus one if the
     *     model is supervised and the last original name equals {@link #getResponseName()}
     */
    public int getOrigNumCols() {
        String[] origNames = this.getOrigNames();
        if (origNames == null || origNames.length == 0) {
            return 0;
        }
        boolean hasResponse = false;
        if (this.isSupervised()) {
            String responseName = this.getResponseName();
            hasResponse = origNames[origNames.length - 1].equals(responseName);
        }
        return hasResponse ? origNames.length - 1 : origNames.length;
    }

    /** @return original column names; {@code null} by default */
    @Override
    public String[] getOrigNames() {
        return null;
    }

    /**
     * Returns the response column name.
     *
     * @return {@code _names[getResponseIdx()]} when that index lies within {@code _names}, else
     *     {@code _responseColumn}
     * @throws UnsupportedOperationException for unsupervised models
     */
    @Override
    public String getResponseName() {
        int r = this.getResponseIdx();
        return r < this._names.length ? this._names[r] : this._responseColumn;
    }

    /**
     * Returns the response column index.
     *
     * @return {@code _domains.length - 1}
     * @throws UnsupportedOperationException for unsupervised models
     */
    @Override
    public int getResponseIdx() {
        if (!this.isSupervised()) {
            throw new UnsupportedOperationException("Cannot provide response index for unsupervised models.");
        }
        return this._domains.length - 1;
    }

    /** @return {@code _offsetColumn}, possibly {@code null} */
    @Override
    public String getOffsetName() {
        return this._offsetColumn;
    }

    /**
     * @param colIdx column index
     * @return number of levels in the column's domain, or -1 if the column is not categorical
     */
    @Override
    public int getNumClasses(int colIdx) {
        String[] domval = this.getDomainValues(colIdx);
        return domval != null ? domval.length : -1;
    }

    /**
     * @return {@link #nclasses()}
     * @throws UnsupportedOperationException if the model is not a classifier
     */
    @Override
    public int getNumResponseClasses() {
        if (!this.isClassifier()) {
            throw new UnsupportedOperationException("Cannot provide number of response classes for non-classifiers.");
        }
        return this.nclasses();
    }

    /** @return {@link CategoricalEncoding#AUTO} by default */
    @Override
    public CategoricalEncoding getCategoricalEncoding() {
        return CategoricalEncoding.AUTO;
    }

    /** @return whether the category is Binomial, Multinomial or Ordinal */
    @Override
    public boolean isClassifier() {
        ModelCategory cat = this.getModelCategory();
        return cat == ModelCategory.Binomial || cat == ModelCategory.Multinomial || cat == ModelCategory.Ordinal;
    }

    /** @return whether the category is AutoEncoder */
    @Override
    public boolean isAutoEncoder() {
        return this.getModelCategory() == ModelCategory.AutoEncoder;
    }

    /**
     * @param name column name
     * @return the column's domain, or {@code null} if the column is unknown or not categorical
     */
    @Override
    public String[] getDomainValues(String name) {
        int colIdx = this.getColIdx(name);
        return colIdx != -1 ? this.getDomainValues(colIdx) : null;
    }

    /**
     * @param i column index
     * @return the column's domain, or {@code null} if not categorical
     */
    @Override
    public String[] getDomainValues(int i) {
        return this.getDomainValues()[i];
    }

    /** @return the domain array itself (not a copy) */
    @Override
    public String[][] getDomainValues() {
        return this._domains;
    }

    /** @return original domains; {@code null} by default */
    @Override
    public String[][] getOrigDomainValues() {
        return null;
    }

    /** @return original projection array; {@code null} by default */
    @Override
    public double[] getOrigProjectionArray() {
        return null;
    }

    /**
     * @param name column name to look up
     * @return index of the first entry of {@link #getNames()} equal to {@code name}, or -1
     */
    @Override
    public int getColIdx(String name) {
        String[] names = this.getNames();
        for (int i = 0; i < names.length; ++i) {
            if (names[i].equals(name)) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Maps a categorical level to its index in the column's domain.
     *
     * @param colIdx column index
     * @param enumValue level to find
     * @return index of {@code enumValue} in the domain, or -1 if not found or the column is not
     *     categorical
     */
    @Override
    public int mapEnum(int colIdx, String enumValue) {
        String[] domain = this.getDomainValues(colIdx);
        if (domain != null) {
            for (int i = 0; i < domain.length; ++i) {
                if (enumValue.equals(domain[i])) {
                    return i;
                }
            }
        }
        return -1;
    }

    /** @return {@code 1 + getNumResponseClasses()} for classifiers, otherwise 2 */
    @Override
    public int getPredsSize() {
        return this.isClassifier() ? 1 + this.getNumResponseClasses() : 2;
    }

    /**
     * @param mc model category to size predictions for
     * @return {@link #nclasses()} for DimReduction, otherwise {@link #getPredsSize()}
     */
    public int getPredsSize(ModelCategory mc) {
        return mc == ModelCategory.DimReduction ? this.nclasses() : this.getPredsSize();
    }

    /** @return {@code k} with {@code ".aux"} appended */
    public static String createAuxKey(String k) {
        return k + ".aux";
    }

    /**
     * Scores one row.
     *
     * @param row input values
     * @param preds prediction buffer
     * @return prediction array (exact contract defined by implementations)
     */
    public abstract double[] score0(double[] row, double[] preds);

    /**
     * Scores one row with an offset. Not supported by default.
     *
     * @throws UnsupportedOperationException always, unless overridden
     */
    public double[] score0(double[] row, double offset, double[] preds) {
        throw new UnsupportedOperationException("`offset` column is not supported");
    }

    /** @return {@code false} by default; {@code preds} is left untouched */
    public boolean calibrateClassProbabilities(double[] preds) {
        return false;
    }

    /**
     * Rescales class probabilities {@code scored[1..]} by the ratio of prior to model class
     * distribution, then renormalizes them to sum to 1 (when their sum is positive). Index 0 is not
     * touched (presumably it holds the predicted label).
     *
     * @param scored prediction array, modified in place
     * @param priorClassDist prior class distribution, indexed by class
     * @param modelClassDist class distribution seen by the model, indexed by class
     * @return {@code scored}
     */
    public static double[] correctProbabilities(double[] scored, double[] priorClassDist, double[] modelClassDist) {
        double probsum = 0.0;
        for (int c = 1; c < scored.length; ++c) {
            double original_fraction = priorClassDist[c - 1];
            double oversampled_fraction = modelClassDist[c - 1];
            assert (!Double.isNaN(scored[c])) : "Predicted NaN class probability";
            // A zero on either side would give 0 or infinity; leave such classes unscaled.
            if (original_fraction != 0.0 && oversampled_fraction != 0.0) {
                scored[c] *= original_fraction / oversampled_fraction;
            }
            probsum += scored[c];
        }
        if (probsum > 0.0) {
            for (int c = 1; c < scored.length; ++c) {
                scored[c] /= probsum;
            }
        }
        return scored;
    }

    /**
     * Chooses the predicted class. A 3-element {@code preds} is treated as binomial.
     *
     * @return zero-based predicted class index
     */
    public static int getPrediction(double[] preds, double[] priorClassDist, double[] data, double threshold) {
        if (preds.length == 3) {
            return GenModel.getPredictionBinomial(preds, threshold);
        }
        return GenModel.getPredictionMultinomial(preds, priorClassDist, data);
    }

    /** @return 1 if {@code preds[2] >= threshold}, else 0 */
    public static int getPredictionBinomial(double[] preds, double threshold) {
        return preds[2] >= threshold ? 1 : 0;
    }

    /**
     * Returns the zero-based index of the class with the highest probability in {@code preds[1..]}.
     *
     * <p>Ties are broken deterministically from a hash of {@code data}, in one of two ways:
     * <ul>
     *   <li>with {@code priorClassDist}, by a draw weighted by the tied classes' priors;</li>
     *   <li>otherwise, or if that draw selects nothing, by picking one tied class from the hash.</li>
     * </ul>
     *
     * @param preds prediction array; class probabilities start at index 1
     * @param priorClassDist prior class distribution, or {@code null}
     * @param data input row used to seed the tie-break, or {@code null}
     * @return zero-based predicted class
     */
    public static int getPredictionMultinomial(double[] preds, double[] priorClassDist, double[] data) {
        // Zero-based indices of the classes tied for the current maximum.
        ArrayList<Integer> ties = new ArrayList<Integer>();
        ties.add(0);
        int best = 1;
        int tieCnt = 0;
        for (int c = 2; c < preds.length; ++c) {
            if (preds[best] < preds[c]) {
                best = c;
                tieCnt = 0;
                // A new maximum invalidates all earlier ties; without this reset, classes tied with
                // a lower maximum stayed eligible in the prior-weighted draw below.
                ties.clear();
                ties.add(c - 1);
            } else if (preds[best] == preds[c]) {
                ++tieCnt;
                ties.add(c - 1);
            }
        }
        if (tieCnt == 0) {
            return best - 1;
        }
        long hash = 0L;
        if (data != null) {
            for (double d : data) {
                hash ^= Double.doubleToRawLongBits(d) >> 6;
            }
        }
        if (priorClassDist != null) {
            assert (preds.length == priorClassDist.length + 1);
            double sum = 0.0;
            for (int i : ties) {
                sum += priorClassDist[i];
            }
            Random rng = new Random(hash);
            double tie = rng.nextDouble();
            double partialSum = 0.0;
            for (int i : ties) {
                partialSum += priorClassDist[i] / sum;
                if (tie <= partialSum) {
                    return i;
                }
            }
            // Nothing selected (e.g. all tied priors are zero): fall back to the hash-based pick.
        }
        double res = preds[best];
        // NOTE(review): (int) hash can be negative, in which case the first tied class is chosen.
        // Kept as-is for result compatibility.
        int idx = (int) hash % (tieCnt + 1);
        for (best = 1; best < preds.length; ++best) {
            if (res == preds[best] && --idx < 0) {
                return best - 1;
            }
        }
        throw new RuntimeException("Should Not Reach Here");
    }

    /**
     * Tests whether bit {@code (int) dnum - bitoff} is set in {@code bits}.
     *
     * @param bits bit set, least significant bit first within each byte
     * @param nbits number of valid bits
     * @param bitoff value corresponding to bit 0
     * @param dnum value to test; must not be NaN
     * @return whether the bit is set
     */
    public static boolean bitSetContains(byte[] bits, int nbits, int bitoff, double dnum) {
        assert (!Double.isNaN(dnum));
        // The offset must be applied outside the assert: assertions are disabled by default, and
        // folding the subtraction into the assert made production lookups use the raw index.
        int idx = (int) dnum - bitoff;
        assert (idx >= 0 && idx < nbits) : "Must have " + bitoff + " <= idx <= " + (bitoff + nbits - 1) + ": " + idx;
        return (bits[idx >> 3] & 1 << (idx & 7)) != 0;
    }

    /** @return whether {@code (int) dnum - bitoff} lies in {@code [0, nbits)}; {@code dnum} must not be NaN */
    public static boolean bitSetIsInRange(int nbits, int bitoff, double dnum) {
        assert (!Double.isNaN(dnum));
        int idx = (int) dnum - bitoff;
        return idx >= 0 && idx < nbits;
    }

    /** Applies {@link #Kmeans_preprocessData(double, int, double[], double[], int[])} to every element of {@code data} in place. */
    public static void Kmeans_preprocessData(double[] data, double[] means, double[] mults, int[] modes) {
        for (int i = 0; i < data.length; ++i) {
            data[i] = GenModel.Kmeans_preprocessData(data[i], i, means, mults, modes);
        }
    }

    /**
     * Preprocesses one K-Means input value.
     *
     * <p>If {@code modes[i] == -1} the column is treated as numeric: NaN becomes {@code means[i]},
     * and, when {@code mults} is non-null, the value becomes {@code (d - means[i]) * mults[i]}.
     * Otherwise NaN becomes {@code modes[i]}.
     *
     * @return the preprocessed value
     */
    public static double Kmeans_preprocessData(double d, int i, double[] means, double[] mults, int[] modes) {
        if (modes[i] == -1) {
            if (Double.isNaN(d)) {
                d = means[i];
            }
            if (mults != null) {
                d -= means[i];
                d *= mults[i];
            }
        } else if (Double.isNaN(d)) {
            d = modes[i];
        }
        return d;
    }

    /** @return index of the first center with the smallest distance to {@code point}, or -1 if none is below {@link Double#MAX_VALUE} */
    public static int KMeans_closest(double[][] centers, double[] point, String[][] domains) {
        int min = -1;
        double minSqr = Double.MAX_VALUE;
        for (int cluster = 0; cluster < centers.length; ++cluster) {
            double sqr = GenModel.KMeans_distance(centers[cluster], point, domains);
            if (sqr < minSqr) {
                min = cluster;
                minSqr = sqr;
            }
        }
        return min;
    }

    /**
     * Fills {@code distances[c]} with the distance from {@code point} to each center.
     *
     * @return index of the first center with the smallest distance, or -1 if none is below {@link Double#MAX_VALUE}
     */
    public static int KMeans_distances(double[][] centers, double[] point, String[][] domains, double[] distances) {
        int min = -1;
        double minSqr = Double.MAX_VALUE;
        for (int cluster = 0; cluster < centers.length; ++cluster) {
            distances[cluster] = GenModel.KMeans_distance(centers[cluster], point, domains);
            if (distances[cluster] < minSqr) {
                min = cluster;
                minSqr = distances[cluster];
            }
        }
        return min;
    }

    /**
     * Returns per-cluster weights proportional to inverse distance, summing to 1.
     *
     * <p>If some distance is zero, the first such cluster gets weight 1. If all distances sum to
     * zero, a uniformly random cluster (unseeded {@link Random}) gets weight 1.
     */
    public static double[] KMeans_simplex(double[][] centers, double[] point, String[][] domains) {
        double[] dist = new double[centers.length];
        double sum = 0.0;
        double inv_sum = 0.0;
        for (int cluster = 0; cluster < centers.length; ++cluster) {
            dist[cluster] = GenModel.KMeans_distance(centers[cluster], point, domains);
            sum += dist[cluster];
            inv_sum += 1.0 / dist[cluster];
        }
        double[] ratios = new double[centers.length];
        if (sum == 0.0) {
            Random rng = new Random();
            ratios[rng.nextInt(centers.length)] = 1.0;
            return ratios;
        }
        int zeroIdx = -1;
        for (int cluster = 0; cluster < centers.length; ++cluster) {
            if (dist[cluster] == 0.0) {
                zeroIdx = cluster;
                break;
            }
        }
        if (zeroIdx != -1) {
            ratios[zeroIdx] = 1.0;
        } else {
            for (int cluster = 0; cluster < centers.length; ++cluster) {
                ratios[cluster] = 1.0 / (dist[cluster] * inv_sum);
            }
        }
        return ratios;
    }

    /**
     * Squared distance from {@code point} to {@code center}, also accumulating column statistics.
     *
     * <p>For a categorical column ({@code modes[column] != -1}) a mismatch with the center adds 1,
     * and {@code colSum[column]} is incremented when the value differs from the mode. For a numeric
     * column the squared difference is added, and the value and its square are accumulated into
     * {@code colSum}/{@code colSumSq}. NaN values are skipped, and the result is scaled by
     * {@code point.length / nonNaNCount} when at least one (but not every) value was NaN.
     *
     * @return the (scaled) squared distance
     */
    public static double KMeans_distance(double[] center, float[] point, int[] modes, double[] colSum, double[] colSumSq) {
        double sqr = 0.0;
        int pts = point.length;
        for (int column = 0; column < center.length; ++column) {
            float d = point[column];
            if (Float.isNaN(d)) {
                --pts;
                continue;
            }
            if (modes[column] != -1) {
                if ((double) d != center[column]) {
                    sqr += 1.0;
                }
                if (d != (float) modes[column]) {
                    colSum[column] += 1.0;
                }
                continue;
            }
            double delta = (double) d - center[column];
            sqr += delta * delta;
            colSum[column] += (double) d;
            // d * d is computed in float precision before widening, as in the original code.
            colSumSq[column] += (double) (d * d);
        }
        if (0 < pts && pts < point.length) {
            sqr *= (double) point.length / (double) pts;
        }
        return sqr;
    }

    /**
     * Squared distance from {@code point} to {@code center}. A column with a non-null
     * {@code domains} entry contributes 1 on mismatch; other columns contribute the squared
     * difference. NaNs are skipped, and the result is rescaled as in the float overload.
     */
    public static double KMeans_distance(double[] center, double[] point, String[][] domains) {
        double sqr = 0.0;
        int pts = point.length;
        for (int column = 0; column < center.length; ++column) {
            double d = point[column];
            if (Double.isNaN(d)) {
                --pts;
                continue;
            }
            if (domains[column] != null) {
                if (d != center[column]) {
                    sqr += 1.0;
                }
                continue;
            }
            double delta = d - center[column];
            sqr += delta * delta;
        }
        if (0 < pts && pts < point.length) {
            sqr *= (double) point.length / (double) pts;
        }
        return sqr;
    }

    /**
     * Replaces {@code preds[k]} (k >= 1) with {@code exp(preds[k] - max)}, where {@code max} is
     * the largest such value; subtracting it avoids overflow in {@code exp}.
     *
     * @return the sum of the rescaled values
     */
    public static double log_rescale(double[] preds) {
        double maxval = Double.NEGATIVE_INFINITY;
        for (int k = 1; k < preds.length; ++k) {
            maxval = Math.max(maxval, preds[k]);
        }
        assert (!Double.isInfinite(maxval)) : "Something is wrong with GBM trees since returned prediction is " + Arrays.toString(preds);
        double dsum = 0.0;
        for (int k = 1; k < preds.length; ++k) {
            preds[k] = Math.exp(preds[k] - maxval);
            dsum += preds[k];
        }
        return dsum;
    }

    /** Applies a softmax to {@code preds[1..]} in place. */
    public static void GBM_rescale(double[] preds) {
        double sum = GenModel.log_rescale(preds);
        for (int k = 1; k < preds.length; ++k) {
            preds[k] /= sum;
        }
    }

    /** @return {@code x} */
    public static double GLM_identityInv(double x) {
        return x;
    }

    /** @return the logistic function {@code 1 / (1 + exp(-x))} */
    public static double GLM_logitInv(double x) {
        return 1.0 / (Math.exp(-x) + 1.0);
    }

    /** @return {@code exp(x)} */
    public static double GLM_logInv(double x) {
        return Math.exp(x);
    }

    /** @return {@code 1 / x}, with {@code |x|} clamped to at least 1e-5 (0 is treated as +1e-5) */
    public static double GLM_inverseInv(double x) {
        double xx = x < 0.0 ? Math.min(-1.0E-5, x) : Math.max(1.0E-5, x);
        return 1.0 / xx;
    }

    /** @return {@link #GLM_logitInv(double)} of {@code x} */
    public static double GLM_ologitInv(double x) {
        return GenModel.GLM_logitInv(x);
    }

    /** @return {@code max(2e-16, exp(x))} when the link power is 0, else {@code x^(1/tweedie_link_power)} */
    public static double GLM_tweedieInv(double x, double tweedie_link_power) {
        return tweedie_link_power == 0.0 ? Math.max(2.0E-16, Math.exp(x)) : Math.pow(x, 1.0 / tweedie_link_power);
    }

    /** @return header text; {@code null} by default */
    public String getHeader() {
        return null;
    }

    /**
     * Expands a raw row into a one-hot categorical block followed by numeric values (float version).
     *
     * @param from raw row: {@code _cats} categorical codes followed by numeric values
     * @param to output of length {@code _nums + _catOffsets[_cats]}; overwritten
     * @param replaceMissingWithZero whether NaN numerics become 0 instead of NaN
     */
    public static void setInput(double[] from, float[] to, int _nums, int _cats, int[] _catOffsets, double[] _normMul, double[] _normSub, boolean useAllFactorLevels, boolean replaceMissingWithZero) {
        double[] nums = new double[_nums];
        int[] cats = new int[_cats];
        GenModel.setCats(from, nums, cats, _cats, _catOffsets, _normMul, _normSub, useAllFactorLevels);
        assert (to.length == _nums + _catOffsets[_cats]);
        Arrays.fill(to, 0.0f);
        for (int i = 0; i < _cats; ++i) {
            // A negative slot marks a dropped reference level: leave the block all-zero.
            if (cats[i] >= 0) {
                to[cats[i]] = 1.0f;
            }
        }
        int numStart = _catOffsets[_cats];
        for (int i = 0; i < _nums; ++i) {
            to[numStart + i] = Double.isNaN(nums[i]) ? (replaceMissingWithZero ? 0.0f : Float.NaN) : (float) nums[i];
        }
    }

    /**
     * Double version of {@link #setInput(double[], float[], int, int, int[], double[], double[], boolean, boolean)},
     * using caller-provided scratch arrays {@code nums} and {@code cats}.
     */
    public static void setInput(double[] from, double[] to, double[] nums, int[] cats, int _nums, int _cats, int[] _catOffsets, double[] _normMul, double[] _normSub, boolean useAllFactorLevels, boolean replaceMissingWithZero) {
        GenModel.setCats(from, nums, cats, _cats, _catOffsets, _normMul, _normSub, useAllFactorLevels);
        assert (to.length == _nums + _catOffsets[_cats]);
        Arrays.fill(to, 0.0);
        for (int i = 0; i < _cats; ++i) {
            if (cats[i] >= 0) {
                to[cats[i]] = 1.0;
            }
        }
        int numStart = _catOffsets[_cats];
        for (int i = 0; i < _nums; ++i) {
            to[numStart + i] = Double.isNaN(nums[i]) ? (replaceMissingWithZero ? 0.0 : Double.NaN) : nums[i];
        }
    }

    /**
     * Computes one-hot slots for the categorical prefix of {@code from} into {@code cats}.
     * Each remaining value is copied into {@code nums}, normalized as {@code (d - sub) * mul}
     * when {@code _normMul} is non-empty.
     */
    public static void setCats(double[] from, double[] nums, int[] cats, int _cats, int[] _catOffsets, double[] _normMul, double[] _normSub, boolean useAllFactorLevels) {
        GenModel.setCats(from, cats, _cats, _catOffsets, useAllFactorLevels);
        boolean normalize = _normMul != null && _normMul.length > 0;
        for (int i = _cats; i < from.length; ++i) {
            int j = i - _cats;
            double d = from[i];
            if (normalize) {
                d = (d - _normSub[j]) * _normMul[j];
            }
            nums[j] = d;
        }
    }

    /**
     * Maps each categorical code {@code from[i]} (i < cats) to a one-hot slot in {@code to[i]}:
     * <ul>
     *   <li>NaN maps to the last slot of the column's range.</li>
     *   <li>With all factor levels, level c maps to {@code c + catOffsets[i]}.</li>
     *   <li>Otherwise level 0 is dropped (-1) and level c maps to {@code c - 1 + catOffsets[i]}.</li>
     * </ul>
     * Slots past the column's range are clamped to its last slot.
     */
    public static void setCats(double[] from, int[] to, int cats, int[] catOffsets, boolean useAllFactorLevels) {
        for (int i = 0; i < cats; ++i) {
            int lastSlot = catOffsets[i + 1] - 1;
            if (Double.isNaN(from[i])) {
                to[i] = lastSlot;
                continue;
            }
            int c = (int) from[i];
            int slot = useAllFactorLevels ? c + catOffsets[i] : (c != 0 ? c - 1 + catOffsets[i] : -1);
            to[i] = slot > lastSlot ? lastSlot : slot;
        }
    }

    /** @return a new float array holding each element of {@code input} narrowed to float */
    public static float[] convertDouble2Float(double[] input) {
        float[] output = new float[input.length];
        for (int index = 0; index < input.length; ++index) {
            output[index] = (float) input[index];
        }
        return output;
    }

    /**
     * Scales {@code img} to {@code w x h} and writes its pixels into {@code pixels}, starting at
     * {@code start}, one plane of {@code w*h} values per channel (R, G, B).
     *
     * <p>For {@code channels == 1} a single plane holds the integer average of R, G and B. When
     * {@code mean} is non-null, {@code mean[k]} is subtracted from the k-th written value (k
     * counted from {@code start}).
     *
     * @throws IOException declared for callers; not thrown by this implementation
     */
    public static void img2pixels(BufferedImage img, int w, int h, int channels, float[] pixels, int start, float[] mean) throws IOException {
        // BufferedImage cannot be constructed with TYPE_CUSTOM, which some decoded images report.
        int type = img.getType() == BufferedImage.TYPE_CUSTOM ? BufferedImage.TYPE_INT_ARGB : img.getType();
        BufferedImage scaledImg = new BufferedImage(w, h, type);
        Graphics2D g2d = scaledImg.createGraphics();
        try {
            g2d.drawImage(img, 0, 0, w, h, null);
        } finally {
            g2d.dispose();
        }
        int r_idx = start;
        int g_idx = r_idx + w * h;
        int b_idx = g_idx + w * h;
        for (int i = 0; i < h; ++i) {
            for (int j = 0; j < w; ++j) {
                Color mycolor = new Color(scaledImg.getRGB(j, i));
                int red = mycolor.getRed();
                int green = mycolor.getGreen();
                int blue = mycolor.getBlue();
                if (channels == 1) {
                    pixels[r_idx] = (red + green + blue) / 3;
                    if (mean != null) {
                        // mean is indexed relative to start, matching the RGB branch below.
                        pixels[r_idx] -= mean[r_idx - start];
                    }
                } else {
                    pixels[r_idx] = red;
                    pixels[g_idx] = green;
                    pixels[b_idx] = blue;
                    if (mean != null) {
                        pixels[r_idx] -= mean[r_idx - start];
                        pixels[g_idx] -= mean[g_idx - start];
                        pixels[b_idx] -= mean[b_idx - start];
                    }
                }
                ++r_idx;
                ++g_idx;
                ++b_idx;
            }
        }
    }
}

