/*
 * Decompiled with CFR 0.152.
 */
package hex.kmeans;

import hex.ClusteringModel;
import hex.ClusteringModelBuilder;
import hex.Model;
import hex.ModelMetrics;
import hex.ModelMetricsClustering;
import hex.ToEigenVec;
import hex.genmodel.GenModel;
import hex.kmeans.KMeans;
import hex.util.LinearAlgebraUtils;
import java.util.Arrays;
import water.DKV;
import water.Job;
import water.Key;
import water.Keyed;
import water.MRTask;
import water.codegen.CodeGenerator;
import water.codegen.CodeGeneratorPipeline;
import water.exceptions.JCodeSB;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.util.ArrayUtils;
import water.util.JCodeGen;
import water.util.SBPrintStream;

public class KMeansModel
extends ClusteringModel<KMeansModel, KMeansParameters, KMeansOutput> {
    public ToEigenVec getToEigenVec() {
        return LinearAlgebraUtils.toEigen;
    }

    public KMeansModel(Key selfKey, KMeansParameters parms, KMeansOutput output) {
        super(selfKey, (ClusteringModel.ClusteringParameters)parms, (ClusteringModel.ClusteringOutput)output);
    }

    public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
        assert (domain == null);
        return new ModelMetricsClustering.MetricBuilderClustering(((KMeansOutput)this._output).nfeatures(), ((KMeansOutput)this._output)._k[((KMeansOutput)this._output)._k.length - 1]);
    }

    protected Frame predictScoreImpl(Frame orig, Frame adaptedFr, String destination_key, final Job j, boolean computeMetrics) {
        if (!((KMeansParameters)this._parms)._pred_indicator) {
            return super.predictScoreImpl(orig, adaptedFr, destination_key, j, computeMetrics);
        }
        final int len = ((KMeansOutput)this._output)._k[((KMeansOutput)this._output)._k.length - 1];
        String prefix = "cluster_";
        Frame adaptFrm = new Frame(adaptedFr);
        for (int c = 0; c < len; ++c) {
            adaptFrm.add(prefix + Double.toString(c + 1), adaptFrm.anyVec().makeZero());
        }
        new MRTask(){

            public void map(Chunk[] chks) {
                if (this.isCancelled() || j != null && j.stop_requested()) {
                    return;
                }
                double[] tmp = new double[((KMeansOutput)KMeansModel.this._output)._names.length];
                double[] preds = new double[len];
                for (int row = 0; row < chks[0]._len; ++row) {
                    Arrays.fill(preds, 0.0);
                    double[] p = KMeansModel.this.score_indicator(chks, row, tmp, preds);
                    for (int c = 0; c < preds.length; ++c) {
                        chks[((KMeansOutput)KMeansModel.this._output)._names.length + c].set(row, p[c]);
                    }
                }
                if (j != null) {
                    j.update(1L);
                }
            }
        }.doAll(adaptFrm);
        int x = ((KMeansOutput)this._output)._names.length;
        int y = adaptFrm.numCols();
        Frame f = adaptFrm.extractFrame(x, y);
        f = new Frame(Key.make((String)destination_key), f.names(), f.vecs());
        DKV.put((Keyed)f);
        this.makeMetricBuilder(null).makeModelMetrics((Model)this, orig, null, null);
        return f;
    }

    public double[] score_indicator(Chunk[] chks, int row_in_chunk, double[] tmp, double[] preds) {
        assert (((KMeansParameters)this._parms)._pred_indicator);
        assert (tmp.length == ((KMeansOutput)this._output)._names.length && preds.length == ((KMeansOutput)this._output)._centers_raw.length);
        for (int i = 0; i < tmp.length; ++i) {
            tmp[i] = chks[i].atd(row_in_chunk);
        }
        double[] clus = new double[1];
        this.score0(tmp, clus);
        assert (preds != null && ArrayUtils.l2norm2((double[])preds) == 0.0) : "preds must be a vector of all zeros, got " + Arrays.toString(preds);
        assert (clus[0] >= 0.0 && clus[0] < (double)preds.length) : "Cluster number must be an integer in [0," + String.valueOf(preds.length) + ")";
        preds[(int)clus[0]] = 1.0;
        return preds;
    }

    public double[] score_ratio(Chunk[] chks, int row_in_chunk, double[] tmp) {
        assert (((KMeansParameters)this._parms)._pred_indicator);
        assert (tmp.length == ((KMeansOutput)this._output)._names.length);
        for (int i = 0; i < tmp.length; ++i) {
            tmp[i] = chks[i].atd(row_in_chunk);
        }
        double[][] centers = ((KMeansParameters)this._parms)._standardize ? ((KMeansOutput)this._output)._centers_std_raw : ((KMeansOutput)this._output)._centers_raw;
        double[] preds = GenModel.KMeans_simplex((double[][])centers, (double[])tmp, (String[][])((KMeansOutput)this._output)._domains);
        assert (preds.length == ((KMeansOutput)this._output)._k[((KMeansOutput)this._output)._k.length - 1]);
        assert (Math.abs(ArrayUtils.sum((double[])preds) - 1.0) < 1.0E-6) : "Sum of k-means distance ratios should equal 1";
        return preds;
    }

    protected double[] score0(double[] data, double[] preds, double weight, double offset) {
        if (weight == 0.0) {
            return data;
        }
        assert (weight == 1.0);
        return this.score0(data, preds);
    }

    protected double[] score0(double[] data, double[] preds) {
        double[][] centers = ((KMeansParameters)this._parms)._standardize ? ((KMeansOutput)this._output)._centers_std_raw : ((KMeansOutput)this._output)._centers_raw;
        GenModel.Kmeans_preprocessData((double[])data, (double[])((KMeansOutput)this._output)._normSub, (double[])((KMeansOutput)this._output)._normMul, (int[])((KMeansOutput)this._output)._mode);
        preds[0] = GenModel.KMeans_closest((double[][])centers, (double[])data, (String[][])((KMeansOutput)this._output)._domains);
        return preds;
    }

    protected double data(Chunk[] chks, int row, int col) {
        return GenModel.Kmeans_preprocessData((double)chks[col].atd(row), (int)col, (double[])((KMeansOutput)this._output)._normSub, (double[])((KMeansOutput)this._output)._normMul, (int[])((KMeansOutput)this._output)._mode);
    }

    protected void toJavaPredictBody(SBPrintStream body, CodeGeneratorPipeline classCtx, CodeGeneratorPipeline fileCtx, boolean verboseCode) {
        final String mname = JCodeGen.toJavaId((String)this._key.toString());
        if (((KMeansParameters)this._parms)._standardize) {
            fileCtx.add((Object)new CodeGenerator(){

                public void generate(JCodeSB out) {
                    JCodeGen.toClassWithArray((JCodeSB)out, null, (String)(mname + "_MEANS"), (double[])((KMeansOutput)KMeansModel.this._output)._normSub, (String)"Column means of training data");
                    JCodeGen.toClassWithArray((JCodeSB)out, null, (String)(mname + "_MULTS"), (double[])((KMeansOutput)KMeansModel.this._output)._normMul, (String)"Reciprocal of column standard deviations of training data");
                    JCodeGen.toClassWithArray((JCodeSB)out, null, (String)(mname + "_MODES"), (int[])((KMeansOutput)KMeansModel.this._output)._mode, (String)"Mode for categorical columns");
                    JCodeGen.toClassWithArray((JCodeSB)out, null, (String)(mname + "_CENTERS"), (double[][])((KMeansOutput)KMeansModel.this._output)._centers_std_raw, (String)"Normalized cluster centers[K][features]");
                }
            });
            body.ip("Kmeans_preprocessData(data,").pj(mname + "_MEANS", "VALUES,").pj(mname + "_MULTS", "VALUES,").pj(mname + "_MODES", "VALUES").p(");").nl();
            body.ip("preds[0] = KMeans_closest(").pj(mname + "_CENTERS", "VALUES").p(", data, DOMAINS); ").nl();
        } else {
            fileCtx.add((Object)new CodeGenerator(){

                public void generate(JCodeSB out) {
                    JCodeGen.toClassWithArray((JCodeSB)out, null, (String)(mname + "_CENTERS"), (double[][])((KMeansOutput)KMeansModel.this._output)._centers_raw, (String)"Denormalized cluster centers[K][features]");
                }
            });
            body.ip("preds[0] = KMeans_closest(").pj(mname + "_CENTERS", "VALUES").p(",data, DOMAINS);").nl();
        }
    }

    protected boolean toJavaCheckTooBig() {
        return ((KMeansParameters)this._parms)._standardize ? (double)(((KMeansOutput)this._output)._centers_std_raw.length * ((KMeansOutput)this._output)._centers_std_raw[0].length) > 1000000.0 : (double)(((KMeansOutput)this._output)._centers_raw.length * ((KMeansOutput)this._output)._centers_raw[0].length) > 1000000.0;
    }

    public static class KMeansOutput
    extends ClusteringModel.ClusteringOutput {
        public int _iterations;
        public double[] _withinss;
        public double _tot_withinss;
        public double[] _history_withinss = new double[]{Double.NaN};
        public double _totss;
        public double _betweenss;
        public int _categorical_column_count;
        public long[] _training_time_ms = new long[]{System.currentTimeMillis()};
        public double[] _reassigned_count = new double[]{Double.NaN};
        public int[] _k = new int[]{0};

        public KMeansOutput(KMeans b) {
            super((ClusteringModelBuilder)b);
        }
    }

    public static class KMeansParameters
    extends ClusteringModel.ClusteringParameters {
        public int _max_iterations = 10;
        public boolean _standardize = true;
        public KMeans.Initialization _init = KMeans.Initialization.Furthest;
        public Key<Frame> _user_points;
        public boolean _pred_indicator = false;
        public boolean _estimate_k = false;

        public String algoName() {
            return "KMeans";
        }

        public String fullName() {
            return "K-means";
        }

        public String javaName() {
            return KMeansModel.class.getName();
        }

        public long progressUnits() {
            return this._estimate_k ? (long)this._k : (long)this._max_iterations;
        }
    }
}

