/*
 * Decompiled with CFR 0.152.
 */
package hex.kmeans;

import hex.Model;
import hex.ModelBuilder;
import hex.kmeans.KMeansModel;
import hex.schemas.KMeansV2;
import hex.schemas.ModelBuilderSchema;
import java.util.ArrayList;
import java.util.Random;
import water.H2O;
import water.Job;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.Log;
import water.util.RandomUtils;

public class KMeans
extends ModelBuilder<KMeansModel, KMeansModel.KMeansParameters, KMeansModel.KMeansOutput> {
    private int _ncats;
    private transient int _reinit_attempts;

    public Model.ModelCategory[] can_build() {
        return new Model.ModelCategory[]{Model.ModelCategory.Clustering};
    }

    public KMeans(KMeansModel.KMeansParameters parms) {
        super("K-means", (Model.Parameters)parms);
        this.init(false);
    }

    public ModelBuilderSchema schema() {
        return new KMeansV2();
    }

    public Job<KMeansModel> trainModel() {
        return this.start(new KMeansDriver(), ((KMeansModel.KMeansParameters)this._parms)._max_iters);
    }

    public void init(boolean expensive) {
        super.init(expensive);
        if (((KMeansModel.KMeansParameters)this._parms)._k < 1 || ((KMeansModel.KMeansParameters)this._parms)._k > 10000000) {
            this.error("_k", "k must be between 1 and 1e7");
        }
        if (((KMeansModel.KMeansParameters)this._parms)._max_iters < 1 || ((KMeansModel.KMeansParameters)this._parms)._max_iters > 1000000) {
            this.error("_max_iters", " max_iters must be between 1 and 1e6");
        }
        if (this._train == null) {
            return;
        }
        if (this._train.numRows() < (long)((KMeansModel.KMeansParameters)this._parms)._k) {
            this.error("_k", "Cannot make " + ((KMeansModel.KMeansParameters)this._parms)._k + " clusters out of " + this._train.numRows() + " rows.");
        }
        for (Vec v : this._train.vecs()) {
            if (!v.isEnum()) continue;
            ++this._ncats;
        }
        Vec[] vecs = this._train.vecs();
        int ncats = 0;
        int len = vecs.length;
        while (ncats != len) {
            while (ncats < len && vecs[ncats].isEnum()) {
                ++ncats;
            }
            while (len > 0 && !vecs[len - 1].isEnum()) {
                --len;
            }
            if (ncats >= len - 1) continue;
            this._train.swap(ncats, len - 1);
        }
        this._ncats = ncats;
    }

    private static double minSqr(double[][] clusters, double[] point, int ncats, ClusterDist cd) {
        return KMeans.closest((double[][])clusters, (double[])point, (int)ncats, (ClusterDist)cd, (int)clusters.length)._dist;
    }

    private static double minSqr(double[][] clusters, double[] point, int ncats, ClusterDist cd, int count) {
        return KMeans.closest((double[][])clusters, (double[])point, (int)ncats, (ClusterDist)cd, (int)count)._dist;
    }

    private static ClusterDist closest(double[][] clusters, double[] point, int ncats, ClusterDist cd) {
        return KMeans.closest(clusters, point, ncats, cd, clusters.length);
    }

    private static double distance(double[] cluster, double[] point, int ncats) {
        double d;
        int column;
        double sqr = 0.0;
        int pts = point.length;
        for (column = 0; column < ncats; ++column) {
            d = point[column];
            if (Double.isNaN(d)) {
                --pts;
                continue;
            }
            if (d == cluster[column]) continue;
            sqr += 1.0;
        }
        for (column = ncats; column < cluster.length; ++column) {
            d = point[column];
            if (Double.isNaN(d)) {
                --pts;
                continue;
            }
            double delta = d - cluster[column];
            sqr += delta * delta;
        }
        if (0 < pts && pts < point.length) {
            sqr *= (double)(point.length / pts);
        }
        return sqr;
    }

    private static ClusterDist closest(double[][] clusters, double[] point, int ncats, ClusterDist cd, int count) {
        int min = -1;
        double minSqr = Double.MAX_VALUE;
        for (int cluster = 0; cluster < count; ++cluster) {
            double sqr = KMeans.distance(clusters[cluster], point, ncats);
            if (!(sqr < minSqr)) continue;
            min = cluster;
            minSqr = sqr;
        }
        cd._cluster = min;
        cd._dist = minSqr;
        return cd;
    }

    static int closest(double[][] clusters, double[] point, int ncats) {
        int min = -1;
        double minSqr = Double.MAX_VALUE;
        for (int cluster = 0; cluster < clusters.length; ++cluster) {
            double sqr = KMeans.distance(clusters[cluster], point, ncats);
            if (!(sqr < minSqr)) continue;
            min = cluster;
            minSqr = sqr;
        }
        return min;
    }

    private double[][] recluster(double[][] points, Random rand) {
        double[][] res = new double[((KMeansModel.KMeansParameters)this._parms)._k][];
        res[0] = points[0];
        int count = 1;
        ClusterDist cd = new ClusterDist();
        switch (((KMeansModel.KMeansParameters)this._parms)._init) {
            case None: {
                break;
            }
            case PlusPlus: {
                block5: while (count < res.length) {
                    double sum = 0.0;
                    for (double[] point1 : points) {
                        sum += KMeans.minSqr(res, point1, this._ncats, cd, count);
                    }
                    for (double[] point : points) {
                        if (!(KMeans.minSqr(res, point, this._ncats, cd, count) >= rand.nextDouble() * sum)) continue;
                        res[count++] = point;
                        continue block5;
                    }
                }
                break;
            }
            case Furthest: {
                while (count < res.length) {
                    double max = 0.0;
                    int index = 0;
                    for (int i = 0; i < points.length; ++i) {
                        double sqr = KMeans.minSqr(res, points[i], this._ncats, cd, count);
                        if (!(sqr > max)) continue;
                        max = sqr;
                        index = i;
                    }
                    res[count++] = points[index];
                }
                break;
            }
            default: {
                throw H2O.fail();
            }
        }
        return res;
    }

    private void randomRow(Vec[] vecs, Random rand, double[] cluster, double[] means, double[] mults) {
        long row = Math.max(0L, (long)(rand.nextDouble() * (double)vecs[0].length()) - 1L);
        KMeans.data(cluster, vecs, row, means, mults);
    }

    private static boolean standardize(double sigma) {
        return sigma > 1.0E-6;
    }

    private static double[][] max_cats(double[][] clusters, long[][][] cats) {
        int K = cats.length;
        int ncats = cats[0].length;
        for (int clu = 0; clu < K; ++clu) {
            for (int col = 0; col < ncats; ++col) {
                clusters[clu][col] = ArrayUtils.maxIndex((long[])cats[clu][col]);
            }
        }
        return clusters;
    }

    private static double[][] destandardize(double[][] clusters, int ncats, double[] means, double[] mults) {
        int K = clusters.length;
        int N = clusters[0].length;
        double[][] value = new double[K][N];
        for (int clu = 0; clu < K; ++clu) {
            System.arraycopy(clusters[clu], 0, value[clu], 0, N);
            if (mults == null) continue;
            for (int col = ncats; col < N; ++col) {
                value[clu][col] = value[clu][col] / mults[col] + means[col];
            }
        }
        return value;
    }

    private static void data(double[] values, Vec[] vecs, long row, double[] means, double[] mults) {
        for (int i = 0; i < values.length; ++i) {
            double d = vecs[i].at(row);
            values[i] = KMeans.data(d, i, means, mults, vecs[i].cardinality());
        }
    }

    private static void data(double[] values, Chunk[] chks, int row, double[] means, double[] mults) {
        for (int i = 0; i < values.length; ++i) {
            double d = chks[i].at0(row);
            values[i] = KMeans.data(d, i, means, mults, chks[i].vec().cardinality());
        }
    }

    private static double data(double d, int i, double[] means, double[] mults, int cardinality) {
        if (cardinality == -1) {
            if (Double.isNaN(d)) {
                d = means[i];
            }
            if (mults != null) {
                d -= means[i];
                d *= mults[i];
            }
        } else if (Double.isNaN(d)) {
            d = Math.min(Math.round(means[i]), (long)(cardinality - 1));
        }
        return d;
    }

    private static final class ClusterDist {
        int _cluster;
        double _dist;

        private ClusterDist() {
        }
    }

    private static class Lloyds
    extends MRTask<Lloyds> {
        double[][] _clusters;
        double[] _means;
        double[] _mults;
        final int _ncats;
        final int _k;
        double[][] _cMeans;
        long[][][] _cats;
        double[] _cSqr;
        long[] _rows;
        long _worst_row;
        double _worst_err;

        Lloyds(double[][] clusters, double[] means, double[] mults, int ncats, int k) {
            this._clusters = clusters;
            this._means = means;
            this._mults = mults;
            this._ncats = ncats;
            this._k = k;
        }

        public void map(Chunk[] cs) {
            int N = cs.length;
            assert (this._clusters[0].length == N);
            this._cMeans = new double[this._k][N];
            this._cSqr = new double[this._k];
            this._rows = new long[this._k];
            this._cats = new long[this._k][this._ncats][];
            for (int clu = 0; clu < this._k; ++clu) {
                for (int col = 0; col < this._ncats; ++col) {
                    this._cats[clu][col] = new long[cs[col].vec().cardinality()];
                }
            }
            this._worst_err = 0.0;
            double[] values = new double[N];
            ClusterDist cd = new ClusterDist();
            for (int row = 0; row < cs[0]._len; ++row) {
                int col;
                KMeans.data(values, cs, row, this._means, this._mults);
                KMeans.closest(this._clusters, values, this._ncats, cd);
                int clu = cd._cluster;
                assert (clu != -1);
                int n = clu;
                this._cSqr[n] = this._cSqr[n] + cd._dist;
                for (col = 0; col < this._ncats; ++col) {
                    long[] lArray = this._cats[clu][col];
                    int n2 = (int)values[col];
                    lArray[n2] = lArray[n2] + 1L;
                }
                for (col = this._ncats; col < N; ++col) {
                    double[] dArray = this._cMeans[clu];
                    int n3 = col;
                    dArray[n3] = dArray[n3] + values[col];
                }
                int n4 = clu;
                this._rows[n4] = this._rows[n4] + 1L;
                if (!(cd._dist > this._worst_err)) continue;
                this._worst_err = cd._dist;
                this._worst_row = cs[0].start() + (long)row;
            }
            for (int clu = 0; clu < this._k; ++clu) {
                if (this._rows[clu] == 0L) continue;
                ArrayUtils.div((double[])this._cMeans[clu], (double)this._rows[clu]);
            }
            this._clusters = null;
            this._mults = null;
            this._means = null;
        }

        public void reduce(Lloyds mr) {
            for (int clu = 0; clu < this._k; ++clu) {
                long ra = this._rows[clu];
                long rb = mr._rows[clu];
                double[] ma = this._cMeans[clu];
                double[] mb = mr._cMeans[clu];
                for (int c = 0; c < ma.length; ++c) {
                    if (ra + rb <= 0L) continue;
                    ma[c] = (ma[c] * (double)ra + mb[c] * (double)rb) / (double)(ra + rb);
                }
            }
            ArrayUtils.add((long[][][])this._cats, (long[][][])mr._cats);
            ArrayUtils.add((double[])this._cSqr, (double[])mr._cSqr);
            ArrayUtils.add((long[])this._rows, (long[])mr._rows);
            if (this._worst_err < mr._worst_err) {
                this._worst_err = mr._worst_err;
                this._worst_row = mr._worst_row;
            }
        }
    }

    private static class Sampler
    extends MRTask<Sampler> {
        double[][] _clusters;
        double[] _means;
        double[] _mults;
        final int _ncats;
        final double _sqr;
        final double _probability;
        final long _seed;
        double[][] _sampled;

        Sampler(double[][] clusters, double[] means, double[] mults, int ncats, double sqr, double prob, long seed) {
            this._clusters = clusters;
            this._means = means;
            this._mults = mults;
            this._ncats = ncats;
            this._sqr = sqr;
            this._probability = prob;
            this._seed = seed;
        }

        public void map(Chunk[] cs) {
            double[] values = new double[cs.length];
            ArrayList<Object> list = new ArrayList<Object>();
            Random rand = RandomUtils.getRNG((long[])new long[]{this._seed + cs[0].start()});
            ClusterDist cd = new ClusterDist();
            for (int row = 0; row < cs[0]._len; ++row) {
                KMeans.data(values, cs, row, this._means, this._mults);
                double sqr = KMeans.minSqr(this._clusters, values, this._ncats, cd);
                if (!(this._probability * sqr > rand.nextDouble() * this._sqr)) continue;
                list.add(values.clone());
            }
            this._sampled = new double[list.size()][];
            list.toArray((T[])this._sampled);
            this._clusters = null;
            this._mults = null;
            this._means = null;
        }

        public void reduce(Sampler other) {
            this._sampled = ArrayUtils.append((double[][])this._sampled, (double[][])other._sampled);
        }
    }

    private static class SumSqr
    extends MRTask<SumSqr> {
        double[][] _clusters;
        double[] _means;
        double[] _mults;
        final int _ncats;
        double _sqr;

        SumSqr(double[][] clusters, double[] means, double[] mults, int ncats) {
            this._clusters = clusters;
            this._means = means;
            this._mults = mults;
            this._ncats = ncats;
        }

        public void map(Chunk[] cs) {
            double[] values = new double[cs.length];
            ClusterDist cd = new ClusterDist();
            for (int row = 0; row < cs[0]._len; ++row) {
                KMeans.data(values, cs, row, this._means, this._mults);
                this._sqr += KMeans.minSqr(this._clusters, values, this._ncats, cd);
            }
            this._mults = null;
            this._means = null;
            this._clusters = null;
        }

        public void reduce(SumSqr other) {
            this._sqr += other._sqr;
        }
    }

    private class KMeansDriver
    extends H2O.H2OCountedCompleter<KMeansDriver> {
        private KMeansDriver() {
        }

        protected void compute2() {
            KMeansModel model = null;
            try {
                double[][] clusters;
                ((KMeansModel.KMeansParameters)KMeans.this._parms).lock_frames((Job)KMeans.this);
                KMeans.this.init(true);
                model = new KMeansModel(KMeans.this.dest(), (KMeansModel.KMeansParameters)KMeans.this._parms, new KMeansModel.KMeansOutput(KMeans.this));
                model.delete_and_lock(KMeans.this._key);
                ((KMeansModel.KMeansOutput)model._output)._ncats = KMeans.this._ncats;
                Vec[] vecs = KMeans.this._train.vecs();
                int N = vecs.length;
                double[] means = new double[N];
                for (int i = 0; i < N; ++i) {
                    means[i] = vecs[i].mean();
                }
                double[] mults = null;
                if (((KMeansModel.KMeansParameters)KMeans.this._parms)._standardize) {
                    mults = new double[N];
                    for (int i = 0; i < N; ++i) {
                        double sigma = vecs[i].sigma();
                        mults[i] = KMeans.standardize(sigma) ? 1.0 / sigma : 1.0;
                    }
                }
                Random rand = RandomUtils.getRNG((long[])new long[]{((KMeansModel.KMeansParameters)KMeans.this._parms)._seed - 1L});
                if (((KMeansModel.KMeansParameters)KMeans.this._parms)._init == Initialization.None) {
                    ((KMeansModel.KMeansOutput)model._output)._clusters = new double[((KMeansModel.KMeansParameters)KMeans.this._parms)._k][KMeans.this._train.numCols()];
                    for (double[] cluster : clusters = ((KMeansModel.KMeansOutput)model._output)._clusters) {
                        KMeans.this.randomRow(vecs, rand, cluster, means, mults);
                    }
                } else {
                    clusters = new double[1][vecs.length];
                    KMeans.this.randomRow(vecs, rand, clusters[0], means, mults);
                    while (((KMeansModel.KMeansOutput)model._output)._iters < 5) {
                        SumSqr sqr = (SumSqr)new SumSqr(clusters, means, mults, KMeans.this._ncats).doAll(vecs);
                        Sampler sampler = (Sampler)new Sampler(clusters, means, mults, KMeans.this._ncats, sqr._sqr, ((KMeansModel.KMeansParameters)KMeans.this._parms)._k * 3, ((KMeansModel.KMeansParameters)KMeans.this._parms)._seed).doAll(vecs);
                        clusters = ArrayUtils.append((double[][])clusters, (double[][])sampler._sampled);
                        if (!KMeans.this.isRunning()) {
                            return;
                        }
                        ((KMeansModel.KMeansOutput)model._output)._clusters = KMeans.destandardize(clusters, KMeans.this._ncats, means, mults);
                        ((KMeansModel.KMeansOutput)model._output)._mse = sqr._sqr / (double)KMeans.this._train.numRows();
                        ++((KMeansModel.KMeansOutput)model._output)._iters;
                        model.update(KMeans.this._key);
                    }
                    clusters = KMeans.this.recluster(clusters, rand);
                }
                ((KMeansModel.KMeansOutput)model._output)._iters = 0;
                while (((KMeansModel.KMeansOutput)model._output)._iters < ((KMeansModel.KMeansParameters)KMeans.this._parms)._max_iters) {
                    block25: {
                        if (!KMeans.this.isRunning()) {
                            return;
                        }
                        Lloyds task = (Lloyds)new Lloyds(clusters, means, mults, KMeans.this._ncats, ((KMeansModel.KMeansParameters)KMeans.this._parms)._k).doAll(vecs);
                        KMeans.max_cats(task._cMeans, task._cats);
                        boolean badrow = false;
                        for (int clu = 0; clu < ((KMeansModel.KMeansParameters)KMeans.this._parms)._k; ++clu) {
                            if (task._rows[clu] != 0L) continue;
                            if (badrow) {
                                Log.warn((Object[])new Object[]{"KMeans: Re-running Lloyds to re-init another cluster"});
                                --((KMeansModel.KMeansOutput)model._output)._iters;
                                if (KMeans.this._reinit_attempts++ >= ((KMeansModel.KMeansParameters)KMeans.this._parms)._k) {
                                    KMeans.this._reinit_attempts = 0;
                                    break;
                                }
                                break block25;
                            }
                            long row = task._worst_row;
                            Log.warn((Object[])new Object[]{"KMeans: Re-initializing cluster " + clu + " to row " + row});
                            clusters[clu] = task._cMeans[clu];
                            KMeans.data(clusters[clu], vecs, row, means, mults);
                            task._rows[clu] = 1L;
                            badrow = true;
                        }
                        ((KMeansModel.KMeansOutput)model._output)._clusters = KMeans.destandardize(task._cMeans, KMeans.this._ncats, means, mults);
                        ((KMeansModel.KMeansOutput)model._output)._rows = task._rows;
                        ((KMeansModel.KMeansOutput)model._output)._mses = task._cSqr;
                        double ssq = 0.0;
                        for (int i = 0; i < ((KMeansModel.KMeansParameters)KMeans.this._parms)._k; ++i) {
                            ssq += ((KMeansModel.KMeansOutput)model._output)._mses[i];
                            int n = i;
                            ((KMeansModel.KMeansOutput)model._output)._mses[n] = ((KMeansModel.KMeansOutput)model._output)._mses[n] / (double)task._rows[i];
                        }
                        ((KMeansModel.KMeansOutput)model._output)._mse = ssq / (double)KMeans.this._train.numRows();
                        model.update(KMeans.this._key);
                        KMeans.this.update(1L);
                        double sum = 0.0;
                        for (int clu = 0; clu < ((KMeansModel.KMeansParameters)KMeans.this._parms)._k; ++clu) {
                            sum += KMeans.distance(clusters[clu], task._cMeans[clu], KMeans.this._ncats);
                        }
                        Log.info((Object[])new Object[]{"KMeans: Change in cluster centers=" + (sum /= (double)N)});
                        if (sum < 1.0E-6) {
                            break;
                        }
                        clusters = task._cMeans;
                        StringBuilder sb = new StringBuilder();
                        sb.append("KMeans: iter: ").append(((KMeansModel.KMeansOutput)model._output)._iters).append(", MSE=").append(((KMeansModel.KMeansOutput)model._output)._mse);
                        for (int i = 0; i < ((KMeansModel.KMeansParameters)KMeans.this._parms)._k; ++i) {
                            sb.append(", ").append(task._cSqr[i]).append("/").append(task._rows[i]);
                        }
                        Log.info((Object[])new Object[]{sb});
                    }
                    ++((KMeansModel.KMeansOutput)model._output)._iters;
                }
            }
            catch (Throwable t) {
                t.printStackTrace();
                KMeans.this.cancel2(t);
                throw t;
            }
            finally {
                if (model != null) {
                    model.unlock(KMeans.this._key);
                }
                ((KMeansModel.KMeansParameters)KMeans.this._parms).unlock_frames((Job)KMeans.this);
                KMeans.this.done();
            }
            this.tryComplete();
        }
    }

    public static enum Initialization {
        None,
        PlusPlus,
        Furthest;

    }
}

