/*
 * Decompiled with CFR 0.152.
 */
package hex.kmeans;

import hex.ClusteringModel;
import hex.Model;
import hex.ModelMetrics;
import hex.ModelMetricsClustering;
import hex.ToEigenVec;
import hex.genmodel.GenModel;
import hex.genmodel.IClusteringModel;
import hex.kmeans.KMeans;
import hex.kmeans.KMeansMojoWriter;
import hex.util.EffectiveParametersUtils;
import hex.util.LinearAlgebraUtils;
import java.util.Arrays;
import water.DKV;
import water.Job;
import water.Key;
import water.MRTask;
import water.codegen.CodeGenerator;
import water.codegen.CodeGeneratorPipeline;
import water.exceptions.JCodeSB;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.udf.CFuncRef;
import water.util.ArrayUtils;
import water.util.JCodeGen;
import water.util.SBPrintStream;

/**
 * K-means clustering model.
 *
 * <p>Holds the fitted cluster centers (raw and, when standardization is enabled, standardized),
 * scores rows by closest-center assignment, optionally emits a one-column-per-cluster indicator
 * frame, and generates POJO / MOJO scoring code.
 */
public class KMeansModel
extends ClusteringModel<KMeansModel, KMeansParameters, KMeansOutput> {

    /** Number of center-matrix cells above which POJO generation is reported as too big. */
    private static final long MAX_POJO_CENTER_CELLS = 1000000L;

    @Override
    public ToEigenVec getToEigenVec() {
        return LinearAlgebraUtils.toEigen;
    }

    public KMeansModel(Key selfKey, KMeansParameters parms, KMeansOutput output) {
        super(selfKey, parms, output);
    }

    /**
     * Resolves "auto" parameter values: fold assignment and the categorical encoding,
     * which defaults to {@code Enum} for k-means.
     */
    @Override
    public void initActualParamValues() {
        super.initActualParamValues();
        EffectiveParametersUtils.initFoldAssignment(_parms);
        EffectiveParametersUtils.initCategoricalEncoding(_parms, Model.Parameters.CategoricalEncodingScheme.Enum);
    }

    /**
     * Creates a clustering metric builder sized by feature count and final cluster count.
     *
     * @param domain must be {@code null}; clustering has no response domain
     */
    @Override
    public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
        assert domain == null;
        return new ModelMetricsClustering.MetricBuilderClustering(_output.nfeatures(), finalK());
    }

    /**
     * Scores a frame. Without {@code _pred_indicator} this defers to the default implementation.
     * Otherwise it produces one 0/1 column per cluster, marking the cluster each row falls in.
     */
    @Override
    protected Model.PredictScoreResult predictScoreImpl(Frame orig, Frame adaptedFr, String destination_key, final Job j, boolean computeMetrics, CFuncRef customMetricFunc) {
        if (!_parms._pred_indicator) {
            return super.predictScoreImpl(orig, adaptedFr, destination_key, j, computeMetrics, customMetricFunc);
        }
        final int len = finalK();
        // Hoisted so the task body does not re-read model state per chunk; the indicator
        // columns are appended right after the model's input columns.
        final int nInputCols = _output._names.length;
        String prefix = "cluster_";
        Frame adaptFrm = new Frame(adaptedFr);
        for (int c = 0; c < len; ++c) {
            // NOTE(review): Double.toString produces names like "cluster_1.0"; kept as-is because
            // downstream consumers may depend on these column names.
            adaptFrm.add(prefix + Double.toString(c + 1), adaptFrm.anyVec().makeZero());
        }
        new MRTask() {
            @Override
            public void map(Chunk[] chks) {
                if (isCancelled() || j != null && j.stop_requested()) {
                    return;
                }
                double[] tmp = new double[nInputCols];
                double[] preds = new double[len];
                for (int row = 0; row < chks[0]._len; ++row) {
                    // score_indicator requires an all-zero vector on entry.
                    Arrays.fill(preds, 0.0);
                    double[] indicator = KMeansModel.this.score_indicator(chks, row, tmp, preds);
                    for (int c = 0; c < preds.length; ++c) {
                        chks[nInputCols + c].set(row, indicator[c]);
                    }
                }
                if (j != null) {
                    j.update(1L);
                }
            }
        }.doAll(adaptFrm);
        // Split the appended indicator columns off into the result frame.
        int firstIndicatorCol = _output._names.length;
        int endCol = adaptFrm.numCols();
        Frame predFrame = adaptFrm.extractFrame(firstIndicatorCol, endCol);
        predFrame = new Frame(Key.make(destination_key), predFrame.names(), predFrame.vecs());
        DKV.put(predFrame);
        ModelMetrics.MetricBuilder mb = makeMetricBuilder(null);
        return new Model.PredictScoreResult(this, mb, predFrame, predFrame);
    }

    /**
     * Scores one row and writes a one-hot cluster indicator into {@code preds}.
     *
     * @param chks         chunks holding the model's input columns first
     * @param row_in_chunk row index within the chunks
     * @param tmp          scratch buffer of length {@code _names.length}; overwritten
     * @param preds        all-zero buffer with one slot per cluster; the winning slot is set to 1
     * @return {@code preds}
     */
    public double[] score_indicator(Chunk[] chks, int row_in_chunk, double[] tmp, double[] preds) {
        assert _parms._pred_indicator;
        assert tmp.length == _output._names.length && preds.length == _output._centers_raw.length;
        readRow(chks, row_in_chunk, tmp);
        double[] clus = new double[1];
        score0(tmp, clus);
        assert (preds != null && ArrayUtils.l2norm2(preds) == 0.0) : "preds must be a vector of all zeros, got " + Arrays.toString(preds);
        assert (clus[0] >= 0.0 && clus[0] < (double) preds.length) : "Cluster number must be an integer in [0," + String.valueOf(preds.length) + ")";
        preds[(int) clus[0]] = 1.0;
        return preds;
    }

    /**
     * Scores one row and returns per-cluster distance ratios from {@code GenModel.KMeans_simplex}.
     * The asserts below expect them to sum to 1.
     *
     * @param chks         chunks holding the model's input columns first
     * @param row_in_chunk row index within the chunks
     * @param tmp          scratch buffer of length {@code _names.length}; overwritten
     * @return array with one ratio per cluster
     */
    public double[] score_ratio(Chunk[] chks, int row_in_chunk, double[] tmp) {
        assert _parms._pred_indicator;
        assert tmp.length == _output._names.length;
        readRow(chks, row_in_chunk, tmp);
        // Impute / standardize exactly like score0 does, so the row lives in the same space as
        // the centers returned by centers() (standardized when _standardize is set).
        GenModel.Kmeans_preprocessData(tmp, _output._normSub, _output._normMul, _output._mode);
        double[] preds = GenModel.KMeans_simplex(centers(), tmp, _output._domains);
        assert preds.length == finalK();
        assert (Math.abs(ArrayUtils.sum(preds) - 1.0) < 1.0E-6) : "Sum of k-means distance ratios should equal 1";
        return preds;
    }

    /** Offsets are not used by k-means; delegates to {@link #score0(double[], double[])}. */
    @Override
    protected double[] score0(double[] data, double[] preds, double offset) {
        return score0(data, preds);
    }

    /**
     * Preprocesses {@code data} in place (via {@code Kmeans_preprocessData}) and stores the index
     * of the closest center in {@code preds[0]}.
     */
    @Override
    protected double[] score0(double[] data, double[] preds) {
        double[][] centers = centers();
        GenModel.Kmeans_preprocessData(data, _output._normSub, _output._normMul, _output._mode);
        preds[0] = GenModel.KMeans_closest(centers, data, _output._domains);
        return preds;
    }

    /** Reads one cell and applies the same per-column preprocessing as {@code score0}. */
    @Override
    protected double data(Chunk[] chks, int row, int col) {
        return GenModel.Kmeans_preprocessData(chks[col].atd(row), col, _output._normSub, _output._normMul, _output._mode);
    }

    @Override
    protected Class<?>[] getPojoInterfaces() {
        return new Class<?>[]{IClusteringModel.class};
    }

    /**
     * Emits the POJO predict body. When standardizing, it also emits the means/mults/modes
     * arrays and preprocesses the input before picking the closest center.
     */
    @Override
    protected void toJavaPredictBody(SBPrintStream body, CodeGeneratorPipeline classCtx, CodeGeneratorPipeline fileCtx, boolean verboseCode) {
        final String mname = JCodeGen.toJavaId(_key.toString());
        if (_parms._standardize) {
            fileCtx.add(new CodeGenerator() {
                @Override
                public void generate(JCodeSB out) {
                    KMeansOutput output = KMeansModel.this._output;
                    JCodeGen.toClassWithArray(out, null, mname + "_MEANS", output._normSub, "Column means of training data");
                    JCodeGen.toClassWithArray(out, null, mname + "_MULTS", output._normMul, "Reciprocal of column standard deviations of training data");
                    JCodeGen.toClassWithArray(out, null, mname + "_MODES", output._mode, "Mode for categorical columns");
                    JCodeGen.toClassWithArray(out, null, mname + "_CENTERS", output._centers_std_raw, "Normalized cluster centers[K][features]");
                }
            });
            body.ip("Kmeans_preprocessData(data,").pj(mname + "_MEANS", "VALUES,").pj(mname + "_MULTS", "VALUES,").pj(mname + "_MODES", "VALUES").p(");").nl();
            body.ip("preds[0] = KMeans_closest(").pj(mname + "_CENTERS", "VALUES").p(", data, DOMAINS); ").nl();
        } else {
            fileCtx.add(new CodeGenerator() {
                @Override
                public void generate(JCodeSB out) {
                    JCodeGen.toClassWithArray(out, null, mname + "_CENTERS", KMeansModel.this._output._centers_raw, "Denormalized cluster centers[K][features]");
                }
            });
            body.ip("preds[0] = KMeans_closest(").pj(mname + "_CENTERS", "VALUES").p(",data, DOMAINS);").nl();
        }
    }

    /** Emits the extra POJO methods {@code distances(double[], double[])} and {@code getNumClusters()}. */
    @Override
    protected SBPrintStream toJavaTransform(SBPrintStream ccsb, CodeGeneratorPipeline fileCtx, boolean verboseCode) {
        ccsb.nl();
        ccsb.ip("// Pass in data in a double[], in a same way as to the score0 function.").nl();
        ccsb.ip("// Cluster distances will be stored into the distances[] array. Function").nl();
        ccsb.ip("// will return the closest cluster. This way the caller can avoid to call").nl();
        ccsb.ip("// score0(..) to retrieve the cluster where the data point belongs.").nl();
        ccsb.ip("public final int distances( double[] data, double[] distances ) {").nl();
        toJavaDistancesBody(ccsb.ii(1));
        ccsb.ip("return cluster;").nl();
        ccsb.di(1).ip("}").nl();
        ccsb.nl();
        ccsb.ip("// Returns number of cluster used by this model.").nl();
        ccsb.ip("public final int getNumClusters() {").nl();
        toJavaGetNumClustersBody(ccsb.ii(1));
        ccsb.ip("return nclusters;").nl();
        ccsb.di(1).ip("}").nl();
        // NOTE(review): nothing is ever added to this pipeline, so it looks like a placeholder;
        // kept so the emitted output stays exactly as before.
        CodeGeneratorPipeline classCtx = new CodeGeneratorPipeline();
        classCtx.generate(ccsb.ii(1));
        ccsb.di(1);
        return ccsb;
    }

    /** Emits the body of the POJO {@code distances} method; it declares the local {@code cluster}. */
    private void toJavaDistancesBody(SBPrintStream body) {
        String mname = JCodeGen.toJavaId(_key.toString());
        if (_parms._standardize) {
            body.ip("Kmeans_preprocessData(data,").pj(mname + "_MEANS", "VALUES,").pj(mname + "_MULTS", "VALUES,").pj(mname + "_MODES", "VALUES").p(");").nl();
            body.ip("int cluster = KMeans_distances(").pj(mname + "_CENTERS", "VALUES").p(", data, DOMAINS, distances); ").nl();
        } else {
            body.ip("int cluster = KMeans_distances(").pj(mname + "_CENTERS", "VALUES").p(",data, DOMAINS, distances);").nl();
        }
    }

    /** Emits the body of the POJO {@code getNumClusters} method; it declares the local {@code nclusters}. */
    private void toJavaGetNumClustersBody(SBPrintStream body) {
        String mname = JCodeGen.toJavaId(_key.toString());
        body.ip("int nclusters = ").pj(mname + "_CENTERS", "VALUES").p(".length;").nl();
    }

    /**
     * Reports whether the embedded center matrix would make the POJO too large.
     * The cell count is computed in {@code long}: an {@code int} product of k * features can
     * overflow, wrap negative, and wrongly pass the check. No centers means nothing to embed.
     */
    @Override
    protected boolean toJavaCheckTooBig() {
        double[][] centers = centers();
        long cells = centers.length == 0 ? 0L : (long) centers.length * centers[0].length;
        return cells > MAX_POJO_CENTER_CELLS;
    }

    @Override
    public KMeansMojoWriter getMojo() {
        return new KMeansMojoWriter(this);
    }

    /** Final cluster count: this class always reads the last entry of {@code _output._k}. */
    private int finalK() {
        return _output._k[_output._k.length - 1];
    }

    /** Centers that scoring compares against: standardized when {@code _standardize}, else raw. */
    private double[][] centers() {
        return _parms._standardize ? _output._centers_std_raw : _output._centers_raw;
    }

    /** Copies row {@code row} of the first {@code tmp.length} chunks into {@code tmp}. */
    private static void readRow(Chunk[] chks, int row, double[] tmp) {
        for (int i = 0; i < tmp.length; ++i) {
            tmp[i] = chks[i].atd(row);
        }
    }

    /** Training output of a k-means model. */
    public static class KMeansOutput
    extends ClusteringModel.ClusteringOutput {
        public int _iterations;
        public double[] _withinss;
        public double _tot_withinss;
        public double[] _history_withinss = new double[]{Double.NaN};
        public double _totss;
        public double _betweenss;
        public int _categorical_column_count;
        public long[] _training_time_ms = new long[]{System.currentTimeMillis()};
        public double[] _reassigned_count = new double[]{Double.NaN};
        /** Cluster counts; KMeansModel reads the last element as the model's number of clusters. */
        public int[] _k = new int[]{0};

        public KMeansOutput(KMeans b) {
            super(b);
        }
    }

    /** User-facing k-means parameters. */
    public static class KMeansParameters
    extends ClusteringModel.ClusteringParameters {
        public int _max_iterations = 10;
        public boolean _standardize = true;
        public KMeans.Initialization _init = KMeans.Initialization.Furthest;
        public Key<Frame> _user_points;
        /** When true, prediction emits one 0/1 indicator column per cluster. */
        public boolean _pred_indicator = false;
        public boolean _estimate_k = false;
        public int[] _cluster_size_constraints = null;

        @Override
        public String algoName() {
            return "KMeans";
        }

        @Override
        public String fullName() {
            return "K-means";
        }

        @Override
        public String javaName() {
            return KMeansModel.class.getName();
        }

        /** Progress is measured in candidate k values when estimating k, else in iterations. */
        @Override
        public long progressUnits() {
            return _estimate_k ? (long) _k : (long) _max_iterations;
        }
    }
}

