/*
 * Decompiled with CFR 0.152.
 */
package hex.aggregator;

import hex.DataInfo;
import hex.Model;
import hex.ModelBuilder;
import hex.ModelCategory;
import hex.ToEigenVec;
import hex.aggregator.AggregatorModel;
import hex.util.LinearAlgebraUtils;
import java.util.Arrays;
import water.DKV;
import water.Iced;
import water.IcedUtils;
import water.Job;
import water.Key;
import water.Keyed;
import water.MRTask;
import water.Scope;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.ArrayUtils;

public class Aggregator
extends ModelBuilder<AggregatorModel, AggregatorModel.AggregatorParameters, AggregatorModel.AggregatorOutput> {
    /**
     * Returns the {@link ToEigenVec} implementation exposed by this builder, namely
     * {@link LinearAlgebraUtils#toEigen}. Presumably consulted by the framework for Eigen categorical
     * encoding (which {@code init(true)} substitutes for AUTO) — confirm against ModelBuilder.
     */
    public ToEigenVec getToEigenVec() {
        final ToEigenVec eigenEncoder = LinearAlgebraUtils.toEigen;
        return eigenEncoder;
    }

    /**
     * Reports this builder's visibility level.
     *
     * @return always {@code BuilderVisibility.Experimental}
     */
    public ModelBuilder.BuilderVisibility builderVisibility() {
        final ModelBuilder.BuilderVisibility level = ModelBuilder.BuilderVisibility.Experimental;
        return level;
    }

    /**
     * Creates the driver whose {@code computeImpl()} performs the aggregation.
     *
     * @return a new {@link AggregatorDriver} bound to this builder
     */
    protected AggregatorDriver trainModelImpl() {
        final AggregatorDriver driver = new AggregatorDriver();
        return driver;
    }

    /**
     * Lists the model categories this builder can produce.
     *
     * @return a new single-element array containing {@link ModelCategory#Clustering}
     */
    public ModelCategory[] can_build() {
        final ModelCategory[] categories = {ModelCategory.Clustering};
        return categories;
    }

    /**
     * Creates a builder for the given parameters and immediately runs the cheap validation pass,
     * {@code init(false)}.
     *
     * @param parms aggregator parameters
     * @throws H2OModelBuilderIllegalArgumentException if {@code init(false)} records any error
     */
    public Aggregator(AggregatorModel.AggregatorParameters parms) {
        // Pass parms with its own type: the super constructor is typed on the builder's
        // parameter class, so widening it to Model.Parameters would not type-check.
        super(parms);
        this.init(false);
    }

    /**
     * Creates a builder with default {@link AggregatorModel.AggregatorParameters}.
     *
     * @param startup_once forwarded unchanged to the ModelBuilder constructor (presumably selects the
     *                     one-time registration path used at startup — confirm)
     */
    public Aggregator(boolean startup_once) {
        super(new AggregatorModel.AggregatorParameters(), startup_once);
    }

    /**
     * Validates the builder's parameters and, when {@code expensive} is true, the training frame.
     * <p>
     * On an expensive pass, an {@code AUTO} categorical encoding is replaced by {@code Eigen} before
     * {@code super.init} runs. Afterwards every training column must be numeric ({@link Vec#T_NUM}).
     *
     * @param expensive true to also run the checks that read the training frame's column types
     * @throws H2OModelBuilderIllegalArgumentException if any validation error has been recorded
     */
    public void init(boolean expensive) {
        if (expensive && ((AggregatorModel.AggregatorParameters)this._parms)._categorical_encoding == Model.Parameters.CategoricalEncodingScheme.AUTO) {
            ((AggregatorModel.AggregatorParameters)this._parms)._categorical_encoding = Model.Parameters.CategoricalEncodingScheme.Eigen;
        }
        super.init(expensive);
        // Guard against a missing training frame: dereferencing it would throw an NPE and hide
        // whatever error super.init(...) recorded; that error is surfaced by the check below.
        if (expensive && this._train != null) {
            for (byte type : this._train.types()) {
                if (type != Vec.T_NUM) {
                    // One message covers every offending column; do not repeat it per column.
                    this.error("_categorical_encoding", "Categorical features must be turned into numeric features. Specify categorical_encoding=\"Eigen\", \"OneHotExplicit\" or \"Binary\"");
                    break;
                }
            }
        }
        if (this.error_count() > 0) {
            throw H2OModelBuilderIllegalArgumentException.makeFromBuilder((ModelBuilder)this);
        }
    }

    /**
     * Rewrites exemplar ids in a single (assignment) column using a {@code GIDMapping}.
     * <p>
     * Every value found among the mapping's "from" ids is replaced with the matching "to" id.
     * Values absent from the mapping are left untouched.
     * NOTE(review): the lookup is a linear scan per row.
     */
    private static class RenumberTask
    extends MRTask<RenumberTask> {
        /** Parallel arrays: {@code _map[0]} holds old ids, {@code _map[1]} their replacements. */
        final long[][] _map;

        public RenumberTask(AggregateTask.GIDMapping mapping) {
            this._map = mapping.unsortedList();
        }

        public void map(Chunk c) {
            final long[] oldIds = this._map[0];
            final long[] newIds = this._map[1];
            for (int row = 0; row < c._len; row++) {
                final int hit = ArrayUtils.find(oldIds, c.at8(row));
                if (hit >= 0) {
                    c.set(row, newIds[hit]);
                }
            }
        }
    }

    private static class AggregateTask
    extends MRTask<AggregateTask> {
        /** Squared merge radius ({@code radius * radius}); a point joins an exemplar whose distance to it is below this. */
        final double _delta;
        /** Key of the {@link DataInfo} used to extract dense numeric rows in {@code map}. */
        final Key<DataInfo> _dataInfoKey;
        /** Key of the job that receives progress updates. */
        final Key<Job> _jobKey;
        /** Exemplars found by this task; trimmed of trailing null slots at the end of map/reduce. */
        Exemplar[] _exemplars;
        /** Gid remappings recorded when exemplars are merged in {@code reduce}. */
        GIDMapping _mapping;

        /**
         * @param dataInfoKey key of the DataInfo describing the data columns
         * @param radius      merge radius; stored squared as {@code _delta}
         * @param jobKey      key of the job to report progress to
         */
        public AggregateTask(Key<DataInfo> dataInfoKey, double radius, Key<Job> jobKey) {
            this._delta = radius * radius;
            this._dataInfoKey = dataInfoKey;
            this._jobKey = jobKey;
        }

        /**
         * Greedily groups the rows of one chunk into exemplars.
         * <p>
         * The last chunk in {@code chks} is the assignment column; the others are data columns.
         * The first row always seeds a new exemplar. Each later row joins the first exemplar found
         * within squared distance {@code _delta} (bumping its count); otherwise it becomes a new
         * exemplar whose gid is the row's global index. The assignment column receives the gid of the
         * exemplar each row belongs to.
         */
        public void map(Chunk[] chks) {
            this._mapping = new GIDMapping();
            final int assignmentIdx = chks.length - 1;
            final Chunk assignmentChk = chks[assignmentIdx];
            final Chunk[] dataChks = Arrays.copyOf(chks, assignmentIdx);
            final DataInfo dinfo = (DataInfo)this._dataInfoKey.get();
            assert (dinfo != null);
            DataInfo.Row row = dinfo.newDenseRow();
            final int numCols = row.nNums;
            Exemplar[] exemplars = new Exemplar[4];
            for (int r = 0; r < chks[0]._len; ++r) {
                final long globalRow = chks[0].start() + (long)r;
                row = dinfo.extractDenseRow(dataChks, r, row);
                final double[] point = Arrays.copyOf(row.numVals, numCols);
                if (r == 0) {
                    // The first row of the chunk always seeds the exemplar list.
                    final Exemplar seed = new Exemplar(point, globalRow);
                    exemplars = Exemplar.addExemplar(exemplars, seed);
                    assignmentChk.set(r, seed.gid);
                    continue;
                }
                double bestDist = Double.MAX_VALUE;
                int bestIdx = 0;
                long bestGid = -1L;
                // Scan the packed (non-null prefix) exemplar list, passing the current best distance
                // as the early-exit threshold; stop as soon as one is within _delta.
                for (int k = 0; k < exemplars.length && exemplars[k] != null; ++k) {
                    final Exemplar candidate = exemplars[k];
                    final double d = candidate.squaredEuclideanDistance(point, bestDist);
                    if (d < bestDist) {
                        bestDist = d;
                        bestIdx = k;
                        bestGid = candidate.gid;
                    }
                    if (bestDist < this._delta) {
                        break;
                    }
                }
                if (bestDist < this._delta) {
                    exemplars[bestIdx]._cnt++;
                    assignmentChk.set(r, bestGid);
                } else {
                    exemplars = Exemplar.addExemplar(exemplars, new Exemplar(point, globalRow));
                    assignmentChk.set(r, globalRow);
                }
            }
            this._exemplars = Exemplar.trim(exemplars);
            assert (this._exemplars.length <= chks[0].len());
            long counted = 0L;
            for (final Exemplar ex : this._exemplars) {
                counted += ex._cnt;
            }
            assert (counted <= (long)chks[0].len());
            ((Job)this._jobKey.get()).update(1L, "Aggregating.");
        }

        /**
         * Folds another task's exemplars and gid mapping into this task.
         * <p>
         * First, the other task's mapping pairs are copied into this task's mapping.
         * <p>
         * Then each remote exemplar is compared with the exemplars held here. That set includes remote
         * exemplars appended earlier in this same call. If the nearest one found lies within
         * {@code _delta}, the remote count is added to it, and
         * {@code _mapping.set(remoteGid, localGid)} records the merge. Because GIDMapping.set also
         * redirects any stored pair whose target is the remote gid, the copy loop must run before the
         * merge loop. Otherwise, a deep copy of the remote exemplar is appended.
         *
         * @param mrt the other task; its {@code _exemplars} is set to null afterwards
         */
        public void reduce(AggregateTask mrt) {
            // Import the remote remappings before any merge (see method doc).
            for (int i = 0; i < mrt._mapping.len; ++i) {
                this._mapping.set(mrt._mapping.pairSet[i].first, mrt._mapping.pairSet[i].second);
            }
            // Alias of the remote exemplars, read below for their gids.
            Exemplar[] exemplars = mrt._exemplars;
            // Pre-merge totals, used only by the sanity assertions at the end.
            long localCounts = 0L;
            for (Exemplar e : this._exemplars) {
                localCounts += e._cnt;
            }
            long remoteCounts = 0L;
            for (Exemplar e : mrt._exemplars) {
                remoteCounts += e._cnt;
            }
            for (int r = 0; r < mrt._exemplars.length; ++r) {
                // Nearest-local-exemplar search: stops at the first null slot (left by addExemplar's
                // growth) or as soon as an exemplar within _delta has been found.
                double distanceToNearestExemplar = Double.MAX_VALUE;
                int closestExemplarIndex = 0;
                int index = 0;
                for (Exemplar le : this._exemplars) {
                    if (null == le) break;
                    double distToExemplar = le.squaredEuclideanDistance(mrt._exemplars[r].data, distanceToNearestExemplar);
                    if (distToExemplar < distanceToNearestExemplar) {
                        distanceToNearestExemplar = distToExemplar;
                        closestExemplarIndex = index;
                    }
                    if (distanceToNearestExemplar < this._delta) break;
                    ++index;
                }
                if (distanceToNearestExemplar < this._delta) {
                    // Absorb the remote exemplar and record that its gid now maps to the local one.
                    this._exemplars[closestExemplarIndex]._cnt += mrt._exemplars[r]._cnt;
                    this._mapping.set(exemplars[r].gid, this._exemplars[closestExemplarIndex].gid);
                    continue;
                }
                // Nothing close enough locally: append a deep copy of the remote exemplar.
                this._exemplars = Exemplar.addExemplar(this._exemplars, (Exemplar)IcedUtils.deepCopy((Iced)mrt._exemplars[r]));
            }
            mrt._exemplars = null;
            // Drop the trailing null slots addExemplar may have left.
            this._exemplars = Exemplar.trim(this._exemplars);
            assert ((long)this._exemplars.length <= localCounts + remoteCounts);
            long sum = 0L;
            for (Exemplar e : this._exemplars) {
                sum += e._cnt;
            }
            // Merging must conserve the total number of represented rows.
            assert (sum == localCounts + remoteCounts);
            ((Job)this._jobKey.get()).update(1L, "Aggregating.");
        }

        /**
         * Append-only list of gid remappings ({@code from -> to}) collected while merging exemplars.
         * <p>
         * {@link #set(long, long)} first redirects every stored pair whose target is {@code from}
         * to {@code to}, then appends {@code (from, to)}. Stored pairs therefore always point at the
         * latest target of a chain of merges.
         */
        private static class GIDMapping
        extends Iced<GIDMapping> {
            // Declaration order matters: instance variable initializers run top to bottom, so
            // capacity must be assigned before pairSet is sized from it. Otherwise pairSet starts
            // with length 0 and the first set() throws ArrayIndexOutOfBoundsException.
            int len = 0;
            int capacity = 32;
            MyPair[] pairSet = new MyPair[this.capacity];

            /**
             * Records that gid {@code from} was merged into gid {@code to}.
             */
            void set(long from, long to) {
                for (int i = 0; i < this.len; ++i) {
                    MyPair p = this.pairSet[i];
                    if (p.second == from) {
                        p.second = to;
                    }
                }
                // Grow based on the real array length so the write below can never overflow, even
                // if the array was ever smaller than the capacity field suggests.
                if (this.len == this.pairSet.length) {
                    this.capacity = Math.max(1, this.pairSet.length * 2);
                    this.pairSet = Arrays.copyOf(this.pairSet, this.capacity);
                }
                this.pairSet[this.len++] = new MyPair(from, to);
            }

            /**
             * Returns the stored pairs as two parallel arrays in insertion order:
             * {@code [0]} holds the {@code from} gids and {@code [1]} the matching {@code to} gids.
             */
            long[][] unsortedList() {
                long[][] li = new long[2][this.len];
                for (int i = 0; i < this.len; ++i) {
                    li[0][i] = this.pairSet[i].first;
                    li[1][i] = this.pairSet[i].second;
                }
                return li;
            }
        }

        /**
         * A mutable pair of longs, ordered by {@code first} only.
         * <p>
         * NOTE(review): {@link #compareTo} ignores {@code second}, and {@code equals} is not
         * overridden here, so the ordering is not consistent with equals.
         */
        static class MyPair
        extends Iced<MyPair>
        implements Comparable<MyPair> {
            long first;
            long second;

            public MyPair(long f, long s) {
                first = f;
                second = s;
            }

            /** No-arg constructor; presumably required for Iced deserialization — confirm. */
            public MyPair() {
            }

            @Override
            public int compareTo(MyPair o) {
                return Long.compare(first, o.first);
            }
        }
    }

    /**
     * Runs one aggregation. It:
     * <ul>
     *   <li>validates the builder and builds the locked model;</li>
     *   <li>groups rows into exemplars with {@link AggregateTask};</li>
     *   <li>rewrites merged gids with {@link RenumberTask};</li>
     *   <li>materializes the frame of exemplars.</li>
     * </ul>
     */
    class AggregatorDriver
    extends ModelBuilder.Driver {
        AggregatorDriver() {
            // Aggregator.this is bound implicitly as the enclosing instance of the inner Driver
            // superclass, so no explicit argument is needed.
            super();
        }

        /**
         * Trains the aggregator model.
         * <p>
         * The merge radius is {@code _radius_scale * 0.1 / ln(numRows)^(1/numCols)}; it is compared in
         * squared form by {@link AggregateTask}.
         * <p>
         * Cleanup always runs, whether or not training succeeds. The model (if created) is unlocked,
         * and its assignment vec and output-frame vecs (if they exist) are untracked from the current
         * {@link Scope}. The temporary {@link DataInfo} is always removed.
         */
        public void computeImpl() {
            AggregatorModel model = null;
            DataInfo di = null;
            try {
                Aggregator.this.init(true);
                if (Aggregator.this.error_count() > 0) {
                    throw new IllegalArgumentException("Found validation errors: " + Aggregator.this.validationErrors());
                }
                AggregatorModel.AggregatorParameters parms = (AggregatorModel.AggregatorParameters)Aggregator.this._parms;
                model = new AggregatorModel(Aggregator.this.dest(), parms, new AggregatorModel.AggregatorOutput(Aggregator.this));
                model.delete_and_lock(Aggregator.this._job);
                Frame orig = Aggregator.this.train();
                Aggregator.this._job.update(1L, "Preprocessing data.");
                di = new DataInfo(orig, null, true, parms._transform, false, false, false);
                DKV.put((Keyed)di);
                double radius = parms._radius_scale * 0.1 / Math.pow(Math.log(orig.numRows()), 1.0 / (double)orig.numCols());
                // Training vecs plus one extra all-zero vec that receives each row's exemplar gid.
                Vec[] vecs = Arrays.copyOf(orig.vecs(), orig.vecs().length + 1);
                Vec assignment = orig.anyVec().makeZero();
                vecs[vecs.length - 1] = assignment;
                Aggregator.this._job.update(1L, "Aggregating.");
                AggregateTask aggTask = (AggregateTask)new AggregateTask((Key<DataInfo>)di._key, radius, (Key<Job>)Aggregator.this._job._key).doAll(vecs);
                Aggregator.this._job.update(1L, "Aggregating exemplar assignments.");
                // Rewrite gids of exemplars that were merged away while reducing.
                new RenumberTask(aggTask._mapping).doAll(new Vec[]{assignment});
                model._exemplars = aggTask._exemplars;
                model._counts = new long[aggTask._exemplars.length];
                for (int i = 0; i < aggTask._exemplars.length; ++i) {
                    model._counts[i] = aggTask._exemplars[i]._cnt;
                }
                model._exemplar_assignment_vec_key = assignment._key;
                AggregatorModel.AggregatorOutput output = (AggregatorModel.AggregatorOutput)model._output;
                output._output_frame = Key.make("aggregated_" + parms._train.toString() + "_by_" + model._key);
                Aggregator.this._job.update(1L, "Creating output frame.");
                model.createFrameOfExemplars(di._adaptedFrame, output._output_frame);
                Aggregator.this._job.update(1L, "Done.");
                model.update(Aggregator.this._job);
            } finally {
                // Nested finally: the DataInfo is removed even if releasing the model throws.
                try {
                    if (model != null) {
                        this.releaseModel(model);
                    }
                } finally {
                    if (di != null) {
                        di.remove();
                    }
                }
            }
        }

        /**
         * Unlocks {@code model} and untracks, from the current {@link Scope}, its assignment vec and
         * the vecs of its output frame. Whichever of these was never created (e.g. training failed
         * early) is skipped instead of causing a NullPointerException.
         */
        private void releaseModel(AggregatorModel model) {
            model.unlock(Aggregator.this._job);
            Key assignmentKey = model._exemplar_assignment_vec_key;
            if (assignmentKey != null) {
                Scope.untrack((Key[])new Key[]{assignmentKey});
            }
            Key outputKey = ((AggregatorModel.AggregatorOutput)model._output)._output_frame;
            Frame outputFrame = outputKey == null ? null : (Frame)outputKey.get();
            if (outputFrame != null) {
                Scope.untrack((Key[])outputFrame.keys());
            }
        }
    }

    /**
     * A representative data point ("exemplar") together with the number of rows it stands for.
     * <p>
     * Exemplar arrays handled by {@link #addExemplar} and {@link #trim} are packed: occupied slots
     * form a prefix and any remaining slots are {@code null}.
     */
    public static class Exemplar
    extends Iced<Exemplar> {
        /** Coordinates of the exemplar; NaN marks a missing value. */
        final double[] data;
        /** Exemplar id; AggregateTask passes the global row index of the row that seeded it. */
        final long gid;
        /** Number of rows represented by this exemplar; starts at 1. */
        long _cnt;

        Exemplar(double[] d, long id) {
            this.data = d;
            this.gid = id;
            this._cnt = 1L;
        }

        /**
         * Stores {@code e} in the first free slot after the packed prefix of {@code es}.
         * <p>
         * If {@code es} has no free slot, it is first copied into an array of twice its length, or
         * length 1 when {@code es} is empty. Otherwise {@code es} is modified in place.
         *
         * @return the array now holding {@code e}: either {@code es} itself or a larger copy
         */
        public static Exemplar[] addExemplar(Exemplar[] es, Exemplar e) {
            int last = es.length - 1;
            while (last >= 0 && es[last] == null) {
                --last;
            }
            Exemplar[] res = es;
            if (last == es.length - 1) {
                // Full (or empty) array: grow it. Doubling alone would leave an empty array at length 0.
                res = Arrays.copyOf(es, Math.max(1, es.length << 1));
            }
            res[last + 1] = e;
            return res;
        }

        /**
         * Returns a copy of {@code es} without its trailing {@code null} slots.
         * An empty or all-{@code null} array yields an empty array.
         */
        public static Exemplar[] trim(Exemplar[] es) {
            int idx = es.length - 1;
            while (idx >= 0 && es[idx] == null) {
                --idx;
            }
            return Arrays.copyOf(es, idx + 1);
        }

        /**
         * Squared Euclidean distance between this exemplar and {@code e2}.
         * <p>
         * Coordinates that are missing in either vector are skipped, and the partial sum is rescaled
         * by {@code ncols / n}, where {@code n} is the number of coordinates actually compared.
         * While no missing value has been seen, the scan stops early once the partial sum exceeds
         * {@code thresh}.
         * NOTE(review): when no coordinate can be compared ({@code n == 0}) the result is NaN.
         * NaN never compares as "closer" in the callers.
         */
        private double squaredEuclideanDistance(double[] e2, double thresh) {
            final double[] e1 = this.data;
            final double ncols = e1.length;
            double sum = 0.0;
            int n = 0;
            boolean missing = false;
            for (int j = 0; j < e1.length; ++j) {
                final double d1 = e1[j];
                final double d2 = e2[j];
                if (isMissing(d1) || isMissing(d2)) {
                    missing = true;
                } else {
                    final double diff = d1 - d2;
                    sum += diff * diff;
                    ++n;
                }
                if (!missing && sum > thresh) {
                    break;
                }
            }
            return sum * (ncols / (double)n);
        }

        private static boolean isMissing(double x) {
            return Double.isNaN(x);
        }
    }
}

