/*
 * Decompiled with CFR 0.152.
 */
package elki.projection;

import elki.data.DoubleVector;
import elki.data.FeatureVector;
import elki.data.type.SimpleTypeInformation;
import elki.data.type.TypeInformation;
import elki.data.type.TypeUtil;
import elki.data.type.VectorFieldTypeInformation;
import elki.database.Database;
import elki.database.datastore.DataStore;
import elki.database.datastore.DataStoreFactory;
import elki.database.datastore.WritableDataStore;
import elki.database.ids.DBIDArrayIter;
import elki.database.ids.DBIDRef;
import elki.database.ids.DBIDs;
import elki.database.relation.MaterializedRelation;
import elki.database.relation.Relation;
import elki.logging.Logging;
import elki.logging.progress.AbstractProgress;
import elki.logging.progress.FiniteProgress;
import elki.logging.statistics.Duration;
import elki.logging.statistics.LongStatistic;
import elki.logging.statistics.Statistic;
import elki.math.MathUtil;
import elki.projection.AffinityMatrix;
import elki.projection.AffinityMatrixBuilder;
import elki.projection.NearestNeighborAffinityMatrixBuilder;
import elki.projection.TSNE;
import elki.utilities.Priority;
import elki.utilities.documentation.Reference;
import elki.utilities.documentation.Title;
import elki.utilities.exceptions.AbortException;
import elki.utilities.io.FormatUtil;
import elki.utilities.optionhandling.OptionID;
import elki.utilities.optionhandling.constraints.CommonConstraints;
import elki.utilities.optionhandling.constraints.ParameterConstraint;
import elki.utilities.optionhandling.parameterization.Parameterization;
import elki.utilities.optionhandling.parameters.DoubleParameter;
import elki.utilities.random.RandomFactory;
import java.util.ArrayList;
import java.util.Arrays;

/**
 * t-SNE using a Barnes-Hut approximation for the repulsive forces.
 * <p>
 * Repulsive forces are approximated by traversing a space-partitioning tree
 * ({@link QuadTree}), summarizing sufficiently distant cells by their center
 * of mass. Attractive forces are computed exactly over the entries enumerated
 * by the {@link AffinityMatrix} iterator.
 *
 * @param <O> Object type
 */
@Title("t-SNE using Barnes-Hut-Approximation")
@Reference(authors = "L. J. P. van der Maaten", //
    title = "Accelerating t-SNE using Tree-Based Algorithms", //
    booktitle = "Journal of Machine Learning Research 15", //
    url = "http://dl.acm.org/citation.cfm?id=2697068", //
    bibkey = "DBLP:journals/jmlr/Maaten14")
@Priority(199)
public class BarnesHutTSNE<O> extends TSNE<O> {
    /** Class logger. */
    private static final Logging LOG = Logging.getLogger(BarnesHutTSNE.class);

    /** Error tolerance, presumably for a perplexity search; not referenced in this class. */
    protected static final double PERPLEXITY_ERROR = 1.0E-4;

    /** Iteration limit, presumably for a perplexity search; not referenced in this class. */
    protected static final int PERPLEXITY_MAXITER = 25;

    /** Tree cells whose squared extent is at most this value are not split further. */
    private static final double QUADTREE_MIN_RESOLUTION = 1.0E-10;

    /** Factor the affinities are multiplied by during the first iterations. */
    private static final double EXAGGERATION_FACTOR = 4.0;

    /** Iteration index after which the exaggeration of the affinities is undone. */
    private static final int EXAGGERATION_ITERATIONS = 50;

    /** Squared Barnes-Hut threshold (theta squared). */
    protected double sqtheta;

    /**
     * Constructor.
     *
     * @param affinity Affinity matrix builder
     * @param dim Output dimensionality
     * @param finalMomentum Final momentum
     * @param learningRate Learning rate (multiplied by 4 internally)
     * @param maxIterations Maximum number of iterations
     * @param random Random generator for the initial solution
     * @param keep Flag passed through to the superclass constructor
     * @param theta Barnes-Hut approximation threshold
     */
    public BarnesHutTSNE(AffinityMatrixBuilder<? super O> affinity, int dim, double finalMomentum, double learningRate, int maxIterations, RandomFactory random, boolean keep, double theta) {
        // The gradient computed by this class omits the constant factor 4 of
        // the t-SNE gradient; the learning rate is scaled to compensate.
        super(affinity, dim, finalMomentum, learningRate * 4.0, maxIterations, random, keep);
        // Only the square is needed, because squared sizes and distances are compared.
        this.sqtheta = theta * theta;
    }

    /**
     * Project the given relation with Barnes-Hut t-SNE.
     *
     * @param database Database (not used by this method)
     * @param relation Input data
     * @return New relation containing the projected vectors
     */
    public Relation<DoubleVector> run(Database database, Relation<O> relation) {
        // The second argument is presumably the exaggeration factor; it is
        // undone in optimizetSNE after EXAGGERATION_ITERATIONS iterations.
        AffinityMatrix neighbors = this.affinity.computeAffinityMatrix(relation, EXAGGERATION_FACTOR);
        double[][] solution = randomInitialSolution(neighbors.size(), this.dim, this.random.getSingleThreadedRandom());
        this.projectedDistances = 0L;
        this.optimizetSNE(neighbors, solution);
        LOG.statistics(new LongStatistic(this.getClass().getName() + ".projected-distances", this.projectedDistances));
        this.removePreviousRelation(relation);
        DBIDs ids = relation.getDBIDs();
        WritableDataStore<DoubleVector> proj = DataStoreFactory.FACTORY.makeStorage(ids, 30, DoubleVector.class);
        VectorFieldTypeInformation<DoubleVector> otype = new VectorFieldTypeInformation<>(DoubleVector.FACTORY, this.dim);
        // Solution rows are indexed by the offset of the affinity matrix iterator.
        for (DBIDArrayIter it = neighbors.iterDBIDs(); it.valid(); it.advance()) {
            proj.put(it, DoubleVector.wrap(solution[it.getOffset()]));
        }
        return new MaterializedRelation<>("Barnes-Hut t-SNE", otype, ids, proj);
    }

    /**
     * Iteratively optimize the embedding.
     * <p>
     * Working memory holds {@code 3 * dim} doubles per point: the gradient
     * (first {@code dim} values, written by {@link #computeGradient}), a section
     * managed by {@code updateSolution}, and a section initialized to 1.0
     * (presumably per-coordinate gains; confirm in {@code updateSolution}).
     *
     * @param pij Affinity matrix (rescaled in place after the exaggeration phase)
     * @param sol Solution to optimize in place
     */
    @Override
    protected void optimizetSNE(AffinityMatrix pij, double[][] sol) {
        final int size = pij.size();
        if ((long) size * 3L * (long) this.dim > 0x7FFFFFFAL) {
            throw new AbortException("Memory exceeds Java array size limit.");
        }
        final int dim3 = this.dim * 3;
        double[] meta = new double[size * dim3];
        for (int off = 2 * this.dim; off < meta.length; off += dim3) {
            Arrays.fill(meta, off, off + this.dim, 1.0);
        }
        FiniteProgress prog = LOG.isVerbose() ? new FiniteProgress("Iterative Optimization", this.iterations, LOG) : null;
        Duration timer = LOG.isStatistics() ? LOG.newDuration(this.getClass().getName() + ".runtime.optimization").begin() : null;
        for (int i = 0; i < this.iterations; ++i) {
            this.computeGradient(pij, sol, meta);
            this.updateSolution(sol, meta, i);
            if (i == EXAGGERATION_ITERATIONS) {
                // End of the exaggeration phase: undo the factor used in run().
                pij.scale(1.0 / EXAGGERATION_FACTOR);
            }
            LOG.incrementProcessed(prog);
        }
        LOG.ensureCompleted(prog);
        if (timer != null) {
            LOG.statistics(timer.end());
        }
    }

    /**
     * Compute the gradient (without the constant factor 4) for all points.
     *
     * @param pij Affinity matrix
     * @param solution Current solution
     * @param grad Working memory; the first {@code dim} values of every
     *        {@code 3 * dim} record receive the gradient
     */
    private void computeGradient(AffinityMatrix pij, double[][] solution, double[] grad) {
        final int dim3 = 3 * this.dim;
        for (int off = 0; off < grad.length; off += dim3) {
            Arrays.fill(grad, off, off + this.dim, 0.0);
        }
        QuadTree tree = QuadTree.build(this.dim, solution);
        // Unnormalized repulsive forces, and the normalization constant Z.
        double z = 0.0;
        for (int i = 0, off = 0; i < solution.length; ++i, off += dim3) {
            z += this.computeRepulsiveForces(grad, off, solution[i], tree);
        }
        // Normalize by Z; the negative sign makes repulsion oppose attraction.
        final double s = -1.0 / z;
        for (int off = 0; off < grad.length; off += dim3) {
            for (int j = 0; j < this.dim; ++j) {
                grad[off + j] *= s;
            }
        }
        this.computeAttractiveForces(grad, pij, solution);
    }

    /**
     * Add the attractive forces to the gradient.
     *
     * @param attr Working memory to add the forces to
     * @param pij Affinity matrix
     * @param sol Current solution
     */
    private void computeAttractiveForces(double[] attr, AffinityMatrix pij, double[][] sol) {
        final int dim3 = 3 * this.dim;
        for (int i = 0, off = 0; off < attr.length; ++i, off += dim3) {
            final double[] sol_i = sol[i];
            for (int offj = pij.iter(i); pij.iterValid(i, offj); offj = pij.iterAdvance(i, offj)) {
                final double[] sol_j = sol[pij.iterDim(i, offj)];
                final double a = pij.iterValue(i, offj) / (1.0 + this.sqDist(sol_i, sol_j));
                for (int k = 0; k < this.dim; ++k) {
                    attr[off + k] += a * (sol_i[k] - sol_j[k]);
                }
            }
        }
    }

    /**
     * Accumulate the unnormalized repulsive forces acting on one point.
     * <p>
     * A node is summarized by its center of mass if it holds a single point,
     * or if its squared size divided by the squared distance is below the
     * squared threshold. Otherwise its directly stored points and its children
     * are processed individually.
     * <p>
     * NOTE(review): the query point itself is not excluded. When reached it
     * contributes nothing to the force but adds 1 to the returned sum.
     *
     * @param rep_i Working memory to add the forces to
     * @param off Offset of the point's gradient section in {@code rep_i}
     * @param sol_i Position of the point
     * @param node Tree node to process
     * @return Contribution of this subtree to the normalization constant Z
     */
    private double computeRepulsiveForces(double[] rep_i, int off, double[] sol_i, QuadTree node) {
        final double[] center = node.center;
        final double dist = this.sqDist(sol_i, center);
        if (node.weight == 1 || node.squareSize / dist < this.sqtheta) {
            final double u = 1.0 / (1.0 + dist);
            final double d = node.weight * u;
            final double a = d * u;
            for (int k = 0; k < this.dim; ++k) {
                rep_i[off + k] += a * (sol_i[k] - center[k]);
            }
            return d;
        }
        double z = 0.0;
        if (node.points != null) {
            for (double[] point : node.points) {
                final double pz = 1.0 / (1.0 + this.sqDist(sol_i, point));
                final double a = pz * pz;
                for (int k = 0; k < this.dim; ++k) {
                    rep_i[off + k] += a * (sol_i[k] - point[k]);
                }
                z += pz;
            }
        }
        if (node.children != null) {
            for (QuadTree child : node.children) {
                z += this.computeRepulsiveForces(rep_i, off, sol_i, child);
            }
        }
        return z;
    }

    @Override
    public TypeInformation[] getInputTypeRestriction() {
        return TypeUtil.array(this.affinity.getInputTypeRestriction());
    }

    /**
     * Parameterization class.
     *
     * @param <O> Object type
     */
    public static class Par<O> extends TSNE.Par<O> {
        /** Parameter for the Barnes-Hut approximation threshold. */
        public static final OptionID THETA_ID = new OptionID("tsne.theta", "Approximation quality parameter");

        /** Approximation threshold (default 0.5, must be non-negative). */
        public double theta;

        @Override
        public void configure(Parameterization config) {
            super.configure(config);
            new DoubleParameter(THETA_ID) //
                .setDefaultValue(0.5) //
                .addConstraint(CommonConstraints.GREATER_EQUAL_ZERO_DOUBLE) //
                .grab(config, x -> this.theta = x);
        }

        @Override
        protected Class<?> getDefaultAffinity() {
            return NearestNeighborAffinityMatrixBuilder.class;
        }

        @Override
        public BarnesHutTSNE<O> make() {
            return new BarnesHutTSNE<>(this.affinity, this.dim, this.finalMomentum, this.learningRate, this.iterations, this.random, this.keep, this.theta);
        }
    }

    /**
     * Space-partitioning tree for the Barnes-Hut approximation.
     * <p>
     * Despite the name, a node is split along every output dimension, so it
     * can have up to {@code 2^dim} children. Points that end up alone in their
     * sub-cell are stored directly in {@link #points}.
     */
    protected static class QuadTree {
        /** Center of mass (mean) of all points below this node. */
        public double[] center;

        /**
         * Points stored directly in this node (may be {@code null}): singleton
         * sub-cells, or all points of a cell that is not split further.
         */
        public double[][] points;

        /** Squared diagonal of the bounding box (sum of squared extents). */
        public double squareSize;

        /** Number of points in this node, including all descendants. */
        public int weight;

        /** Child nodes (may be {@code null}). */
        public QuadTree[] children;

        private QuadTree(double[][] data, QuadTree[] children, double[] mid, int weight, double squareSize) {
            this.center = mid;
            this.points = data;
            this.weight = weight;
            this.squareSize = squareSize;
            this.children = children;
        }

        /**
         * Build a tree over the given points.
         * <p>
         * The row array is copied (shallowly) because building reorders it. The
         * rows themselves are referenced, not copied.
         *
         * @param dim Dimensionality
         * @param data Points
         * @return Root node
         */
        public static QuadTree build(int dim, double[][] data) {
            return build(dim, data.clone(), 0, data.length);
        }

        /**
         * Build a tree node for {@code data[begin, end)}, which may be reordered.
         */
        private static QuadTree build(int dim, double[][] data, int begin, int end) {
            final double[] minmax = computeExtent(dim, data, begin, end);
            final double squareSize = computeSquareSize(minmax);
            final double[] mid = computeCenterofMass(dim, data, begin, end);
            final int size = end - begin;
            if (squareSize <= QUADTREE_MIN_RESOLUTION) {
                // Too small to split further: store all points in a leaf.
                return new QuadTree(Arrays.copyOfRange(data, begin, end), null, mid, size, squareSize);
            }
            ArrayList<double[]> singletons = new ArrayList<>();
            ArrayList<QuadTree> children = new ArrayList<>();
            splitRecursively(data, begin, end, 0, dim, minmax, singletons, children);
            double[][] sing = singletons.isEmpty() ? null : singletons.toArray(new double[singletons.size()][]);
            QuadTree[] chil = children.isEmpty() ? null : children.toArray(new QuadTree[children.size()]);
            return new QuadTree(sing, chil, mid, size, squareSize);
        }

        /**
         * Split {@code data[begin, end)} at the cell midpoints along dimensions
         * {@code initdim..dims-1} and collect the resulting sub-cells.
         *
         * @param data Data array, reordered in place
         * @param begin Range start (inclusive)
         * @param end Range end (exclusive)
         * @param initdim First dimension to split on
         * @param dims Total dimensionality
         * @param minmax Bounding box of the enclosing node (min/max interleaved)
         * @param singletons Output list for sub-cells containing a single point
         * @param children Output list for sub-cells containing several points
         */
        private static void splitRecursively(double[][] data, int begin, int end, int initdim, int dims, double[] minmax, ArrayList<double[]> singletons, ArrayList<QuadTree> children) {
            final int len = end - begin;
            if (len <= 1) {
                if (len == 1) {
                    singletons.add(data[begin]);
                }
                return;
            }
            // Find the next dimension with a non-zero extent in this node.
            int cur = initdim;
            double mid = Double.NaN;
            for (; cur < dims; ++cur) {
                final double min = minmax[cur << 1];
                mid = 0.5 * (min + minmax[(cur << 1) + 1]);
                if (min < mid) {
                    break;
                }
            }
            if (cur == dims) {
                if (initdim > 0) {
                    // All remaining dimensions are constant, but the range was already
                    // split along an earlier dimension, so it is a proper sub-cell.
                    // This is what the regular path produces when the remaining
                    // splits leave all points on one side.
                    children.add(build(dims, data, begin, end));
                }
                else {
                    // Numerically degenerate: no dimension can be split at all.
                    LOG.warning("Unexpected all-constant split.");
                    double[] center = computeCenterofMass(dims, data, begin, end);
                    children.add(new QuadTree(Arrays.copyOfRange(data, begin, end), null, center, len, 0.0));
                }
                return;
            }
            // Two-pointer partition: values <= mid to the left, values >= mid to the right.
            int l = begin;
            int r = end - 1;
            while (l <= r) {
                while (l <= r && data[l][cur] <= mid) {
                    ++l;
                }
                while (l <= r && data[r][cur] >= mid) {
                    --r;
                }
                if (l < r) {
                    assert (data[l][cur] > mid);
                    assert (data[r][cur] < mid);
                    double[] tmp = data[r];
                    data[r] = data[l];
                    data[l] = tmp;
                    ++l;
                    --r;
                }
            }
            assert (l == end || data[l][cur] >= mid);
            assert (l == begin || data[l - 1][cur] <= mid);
            if (++cur < dims) {
                // Continue splitting both halves along the following dimensions.
                if (begin < l) {
                    splitRecursively(data, begin, l, cur, dims, minmax, singletons, children);
                }
                if (l < end) {
                    splitRecursively(data, l, end, cur, dims, minmax, singletons, children);
                }
                return;
            }
            // All dimensions split: each non-empty half becomes its own subtree.
            if (begin < l) {
                children.add(build(dims, data, begin, l));
            }
            if (l < end) {
                children.add(build(dims, data, l, end));
            }
        }

        /**
         * Compute the mean of {@code data[begin, end)}.
         * <p>
         * For a single point, the point's own array is returned (not a copy).
         */
        private static double[] computeCenterofMass(int dim, double[][] data, int begin, int end) {
            final int size = end - begin;
            if (size == 1) {
                return data[begin];
            }
            double[] center = new double[dim];
            for (int i = begin; i < end; ++i) {
                final double[] row = data[i];
                for (int d = 0; d < dim; ++d) {
                    center[d] += row[d];
                }
            }
            final double norm = 1.0 / size;
            for (int d = 0; d < dim; ++d) {
                center[d] *= norm;
            }
            return center;
        }

        /**
         * Compute the bounding box of {@code data[begin, end)}, stored as
         * interleaved {@code [min_0, max_0, min_1, max_1, ...]}.
         */
        private static double[] computeExtent(int dim, double[][] data, int begin, int end) {
            double[] minmax = new double[dim << 1];
            for (int d = 0; d < minmax.length; d += 2) {
                minmax[d] = Double.POSITIVE_INFINITY;
                minmax[d + 1] = Double.NEGATIVE_INFINITY;
            }
            for (int i = begin; i < end; ++i) {
                final double[] row = data[i];
                for (int d = 0, d2 = 0; d < dim; ++d, d2 += 2) {
                    final double v = row[d];
                    minmax[d2] = MathUtil.min(minmax[d2], v);
                    minmax[d2 + 1] = MathUtil.max(minmax[d2 + 1], v);
                }
            }
            return minmax;
        }

        /**
         * Compute the squared diagonal of an interleaved bounding box.
         */
        private static double computeSquareSize(double[] minmax) {
            double sum = 0.0;
            for (int d = 0; d + 1 < minmax.length; d += 2) {
                final double width = minmax[d + 1] - minmax[d];
                sum += width * width;
            }
            return sum;
        }

        @Override
        public String toString() {
            // points and children are null for many nodes; report them as 0.
            return "QuadTree[center=" + FormatUtil.format(this.center) //
                + ", weight=" + this.weight //
                + ", points=" + (this.points != null ? this.points.length : 0) //
                + ", children=" + (this.children != null ? this.children.length : 0) //
                + ", sqSize=" + this.squareSize + "]";
        }
    }
}

