/*
 * Decompiled with CFR 0.152.
 */
package jsat.datatransform.visualization;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.DataSet;
import jsat.classifiers.DataPoint;
import jsat.datatransform.DataTransform;
import jsat.datatransform.visualization.VisualizationTransform;
import jsat.distributions.Normal;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.linear.VecPaired;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.linear.vectorcollection.VPTree;
import jsat.linear.vectorcollection.VPTreeMV;
import jsat.math.FastMath;
import jsat.math.FunctionBase;
import jsat.math.optimization.stochastic.Adam;
import jsat.math.rootfinding.Zeroin;
import jsat.utils.FakeExecutor;
import jsat.utils.SystemInfo;
import jsat.utils.concurrent.AtomicDouble;
import jsat.utils.concurrent.ParallelUtils;
import jsat.utils.random.RandomUtil;

public class TSNE
implements VisualizationTransform {
    private double alpha = 4.0;
    private double exageratedPortion = 0.25;
    private DistanceMetric dm = new EuclideanDistance();
    private int T = 1000;
    private double perplexity = 30.0;
    private double theta = 0.5;
    private int s = 2;

    public void setAlpha(double alpha) {
        if (alpha <= 0.0 || Double.isNaN(alpha) || Double.isInfinite(alpha)) {
            throw new IllegalArgumentException("alpha must be positive, not " + alpha);
        }
        this.alpha = alpha;
    }

    public double getAlpha() {
        return this.alpha;
    }

    public void setPerplexity(double perplexity) {
        if (perplexity <= 0.0 || Double.isNaN(perplexity) || Double.isInfinite(perplexity)) {
            throw new IllegalArgumentException("perplexity must be positive, not " + perplexity);
        }
        this.perplexity = perplexity;
    }

    public double getPerplexity() {
        return this.perplexity;
    }

    public void setIterations(int T) {
        if (T <= 1) {
            throw new IllegalArgumentException("number of iterations must be positive, not " + T);
        }
        this.T = T;
    }

    public int getIterations() {
        return this.T;
    }

    @Override
    public <Type extends DataSet> Type transform(DataSet<Type> d) {
        return this.transform(d, new FakeExecutor());
    }

    @Override
    public <Type extends DataSet> Type transform(DataSet<Type> d, ExecutorService ex) {
        Random rand = RandomUtil.getRandom();
        final int N = d.getSampleSize();
        final int knn = (int)Math.min(Math.floor(3.0 * this.perplexity), (double)(N - 1));
        final double[][] nearMePij = new double[N][knn];
        final int[][] nearMe = new int[N][knn];
        TSNE.computeP(d, ex, rand, knn, nearMe, nearMePij, this.dm, this.perplexity);
        Normal normalDIst = new Normal(0.0, 1.0E-4);
        final double[] y = normalDIst.sample(N * this.s, rand);
        final double[] y_grad = new double[y.length];
        DenseVector y_vec = DenseVector.toDenseVec(y);
        DenseVector y_grad_vec = DenseVector.toDenseVec(y_grad);
        Adam gradUpdater = new Adam();
        gradUpdater.setup(y.length);
        for (int iter = 0; iter < this.T; ++iter) {
            final int ITER = iter;
            Arrays.fill(y_grad, 0.0);
            final Quadtree qt = new Quadtree(y);
            final AtomicDouble Z = new AtomicDouble(0.0);
            final CountDownLatch latch_g0 = new CountDownLatch(SystemInfo.LogicalCores);
            int id = 0;
            while (id < SystemInfo.LogicalCores) {
                final int ID = id++;
                ex.submit(new Runnable(){

                    @Override
                    public void run() {
                        double[] workSpace = new double[TSNE.this.s];
                        double local_Z = 0.0;
                        for (int i = ID; i < N; i += SystemInfo.LogicalCores) {
                            Arrays.fill(workSpace, 0.0);
                            local_Z += TSNE.this.computeF_rep(qt.root, i, y, workSpace);
                            for (int k = 0; k < TSNE.this.s; ++k) {
                                TSNE.inc_z_ij(workSpace[k], i, k, y_grad, TSNE.this.s);
                            }
                        }
                        Z.addAndGet(local_Z);
                        latch_g0.countDown();
                    }
                });
            }
            try {
                latch_g0.await();
            }
            catch (InterruptedException ex1) {
                Logger.getLogger(TSNE.class.getName()).log(Level.SEVERE, null, ex1);
            }
            double zNorm = 4.0 / (Z.get() + 1.0E-13);
            int i = 0;
            while (i < y.length) {
                int n = i++;
                y_grad[n] = y_grad[n] * zNorm;
            }
            final CountDownLatch latch_g1 = new CountDownLatch(SystemInfo.LogicalCores);
            int id2 = 0;
            while (id2 < SystemInfo.LogicalCores) {
                final int ID = id2++;
                ex.submit(new Runnable(){

                    @Override
                    public void run() {
                        int start = ParallelUtils.getStartBlock(N, ID, SystemInfo.LogicalCores);
                        int end = ParallelUtils.getEndBlock(N, ID, SystemInfo.LogicalCores);
                        for (int i = start; i < end; ++i) {
                            for (int j_indx = 0; j_indx < knn; ++j_indx) {
                                int j = nearMe[i][j_indx];
                                if (i == j) continue;
                                double pij = nearMePij[i][j_indx];
                                if ((double)ITER < (double)TSNE.this.T * TSNE.this.exageratedPortion) {
                                    pij *= TSNE.this.alpha;
                                }
                                double cnst = pij * TSNE.q_ijZ(i, j, y, TSNE.this.s) * 4.0;
                                for (int k = 0; k < TSNE.this.s; ++k) {
                                    double diff = TSNE.z_ij(i, k, y, TSNE.this.s) - TSNE.z_ij(j, k, y, TSNE.this.s);
                                    TSNE.inc_z_ij(cnst * diff, i, k, y_grad, TSNE.this.s);
                                }
                            }
                        }
                        latch_g1.countDown();
                    }
                });
            }
            try {
                latch_g1.await();
            }
            catch (InterruptedException ex1) {
                Logger.getLogger(TSNE.class.getName()).log(Level.SEVERE, null, ex1);
            }
            double eta = 200.0;
            gradUpdater.update(y_vec, y_grad_vec, eta);
        }
        DataSet<Type> transformed = d.shallowClone();
        final IdentityHashMap<DataPoint, Integer> indexMap = new IdentityHashMap<DataPoint, Integer>(N);
        for (int i = 0; i < N; ++i) {
            indexMap.put(d.getDataPoint(i), i);
        }
        transformed.applyTransform(new DataTransform(){

            @Override
            public DataPoint transform(DataPoint dp) {
                int i = (Integer)indexMap.get(dp);
                DenseVector dv = new DenseVector(TSNE.this.s);
                for (int k = 0; k < TSNE.this.s; ++k) {
                    dv.set(k, y[i * 2 + k]);
                }
                return new DataPoint(dv, dp.getCategoricalValues(), dp.getCategoricalData(), dp.getWeight());
            }

            @Override
            public void fit(DataSet data) {
            }

            @Override
            public DataTransform clone() {
                return this;
            }
        });
        return (Type)transformed;
    }

    protected static void computeP(DataSet d, ExecutorService ex, Random rand, final int knn, final int[][] nearMe, final double[][] nearMePij, final DistanceMetric dm, final double perplexity) {
        final List<Vec> vecs = d.getDataVectors();
        final List<Double> accelCache = dm.getAccelerationCache(vecs, ex);
        final int N = vecs.size();
        final VPTreeMV<Vec> vp = new VPTreeMV<Vec>(vecs, dm, VPTree.VPSelection.Random, rand, 2, 1, ex);
        final ArrayList neighbors = new ArrayList(N);
        for (int i = 0; i < N; ++i) {
            neighbors.add(null);
        }
        final IdentityHashMap<Vec, Integer> vecIndex = new IdentityHashMap<Vec, Integer>(N);
        for (int i = 0; i < N; ++i) {
            vecIndex.put(vecs.get(i), i);
        }
        final CountDownLatch latch = new CountDownLatch(SystemInfo.LogicalCores);
        int id = 0;
        while (id < SystemInfo.LogicalCores) {
            final int ID = id++;
            ex.submit(new Runnable(){

                @Override
                public void run() {
                    for (int i = ID; i < N; i += SystemInfo.LogicalCores) {
                        Vec x_i = (Vec)vecs.get(i);
                        List closest = vp.search(x_i, knn + 1);
                        neighbors.set(i, closest);
                        for (int j = 1; j < closest.size(); ++j) {
                            nearMe[i][j - 1] = (Integer)vecIndex.get(closest.get(j).getVector());
                        }
                    }
                    latch.countDown();
                }
            });
        }
        try {
            latch.await();
        }
        catch (InterruptedException ex1) {
            Logger.getLogger(TSNE.class.getName()).log(Level.SEVERE, null, ex1);
        }
        final double[] sigma = new double[N];
        final AtomicDouble minSigma = new AtomicDouble(Double.POSITIVE_INFINITY);
        final AtomicDouble maxSigma = new AtomicDouble(0.0);
        for (int i = 0; i < N; ++i) {
            List n_i = (List)neighbors.get(i);
            double min = (Double)((VecPaired)n_i.get(1)).getPair();
            double max = (Double)((VecPaired)n_i.get(Math.min(knn, n_i.size() - 1))).getPair();
            minSigma.set(Math.min(minSigma.get(), Math.max(min, 1.0E-9)));
            maxSigma.set(Math.max(maxSigma.get(), max));
        }
        final CountDownLatch latch0 = new CountDownLatch(SystemInfo.LogicalCores);
        int id2 = 0;
        while (id2 < SystemInfo.LogicalCores) {
            final int ID = id2++;
            ex.submit(new Runnable(){

                @Override
                public void run() {
                    for (int i = ID; i < N; i += SystemInfo.LogicalCores) {
                        final int I = i;
                        boolean tryAgain = false;
                        do {
                            tryAgain = false;
                            try {
                                double sigma_i;
                                sigma[i] = sigma_i = Zeroin.root(0.01, 100, minSigma.get(), maxSigma.get(), 0, new FunctionBase(){

                                    @Override
                                    public double f(Vec x) {
                                        return TSNE.perp(I, nearMe, x.get(0), neighbors, vecs, accelCache, dm) - perplexity;
                                    }
                                }, new double[0]);
                            }
                            catch (ArithmeticException exception) {
                                if (maxSigma.get() >= 8.988465674311579E307) {
                                    sigma[i] = 1.0E100;
                                    continue;
                                }
                                tryAgain = true;
                                minSigma.set(Math.max(minSigma.get() / 2.0, 1.0E-6));
                                maxSigma.set(Math.min(maxSigma.get() * 2.0, 8.988465674311579E307));
                            }
                        } while (tryAgain);
                    }
                    latch0.countDown();
                }
            });
        }
        try {
            latch0.await();
        }
        catch (InterruptedException ex1) {
            Logger.getLogger(TSNE.class.getName()).log(Level.SEVERE, null, ex1);
        }
        final CountDownLatch latch1 = new CountDownLatch(SystemInfo.LogicalCores);
        int id3 = 0;
        while (id3 < SystemInfo.LogicalCores) {
            final int ID = id3++;
            ex.submit(new Runnable(){

                @Override
                public void run() {
                    for (int i = ID; i < N; i += SystemInfo.LogicalCores) {
                        for (int j_indx = 0; j_indx < knn; ++j_indx) {
                            int j = nearMe[i][j_indx];
                            nearMePij[i][j_indx] = TSNE.p_ij(i, j, sigma[i], sigma[j], neighbors, vecs, accelCache, dm);
                        }
                    }
                    latch1.countDown();
                }
            });
        }
        try {
            latch1.await();
        }
        catch (InterruptedException ex1) {
            Logger.getLogger(TSNE.class.getName()).log(Level.SEVERE, null, ex1);
        }
    }

    private double computeF_rep(Quadtree.Node node, int i, double[] z, double[] workSpace) {
        if (node == null || node.N_cell == 0 || node.indx == i) {
            return 0.0;
        }
        double x = z[i * 2];
        double y = z[i * 2 + 1];
        double r_cell = Math.max(node.maxX - node.minX, node.maxY - node.minY);
        r_cell *= r_cell;
        double mass_x = node.x_mass / (double)node.N_cell;
        double mass_y = node.y_mass / (double)node.N_cell;
        double dot = (mass_x - x) * (mass_x - x) + (mass_y - y) * (mass_y - y);
        if (node.NW == null || r_cell < this.theta * dot) {
            if (node.indx == i) {
                return 0.0;
            }
            double Z = 1.0 / (1.0 + dot);
            double q_cell_Z_sqrd = (double)(-node.N_cell) * (Z * Z);
            workSpace[0] = workSpace[0] + q_cell_Z_sqrd * (x - mass_x);
            workSpace[1] = workSpace[1] + q_cell_Z_sqrd * (y - mass_y);
            return Z * (double)node.N_cell;
        }
        double Z_sum = 0.0;
        for (Quadtree.Node child : node) {
            Z_sum += this.computeF_rep(child, i, z, workSpace);
        }
        return Z_sum;
    }

    private static void inc_z_ij(double val, int i, int j, double[] z, int s) {
        int n = i * s + j;
        z[n] = z[n] + val;
    }

    private static double z_ij(int i, int j, double[] z, int s) {
        return z[i * s + j];
    }

    private static double q_ijZ(int i, int j, double[] z, int s) {
        double denom = 1.0;
        for (int k = 0; k < s; ++k) {
            double diff = TSNE.z_ij(i, k, z, s) - TSNE.z_ij(j, k, z, s);
            denom += diff * diff;
        }
        return 1.0 / denom;
    }

    private static double p_j_i(int j, int i, double sigma, List<List<? extends VecPaired<Vec, Double>>> neighbors, List<Vec> vecs, List<Double> accelCache, DistanceMetric dm) {
        if (i == j) {
            return 0.0;
        }
        Vec x_j = neighbors.get(j).get(0).getVector();
        double sigmaSqrdInv = 1.0 / (2.0 * (sigma * sigma));
        double numer = 0.0;
        double denom = 0.0;
        boolean jIsNearBy = false;
        List<? extends VecPaired<Vec, Double>> neighbors_i = neighbors.get(i);
        for (int k = 1; k < neighbors_i.size(); ++k) {
            VecPaired<Vec, Double> neighbor_ik = neighbors_i.get(k);
            double d_ik = neighbor_ik.getPair();
            denom += FastMath.exp(-(d_ik * d_ik) * sigmaSqrdInv);
            if (neighbor_ik.getVector() != x_j) continue;
            jIsNearBy = true;
            numer = FastMath.exp(-(d_ik * d_ik) * sigmaSqrdInv);
        }
        if (!jIsNearBy) {
            double d_ij = dm.dist(i, j, vecs, accelCache);
            numer = FastMath.exp(-(d_ij * d_ij) * sigmaSqrdInv);
        }
        return numer / (denom + 1.0E-9);
    }

    private static double p_ij(int i, int j, double sigma_i, double sigma_j, List<List<? extends VecPaired<Vec, Double>>> neighbors, List<Vec> vecs, List<Double> accelCache, DistanceMetric dm) {
        return (TSNE.p_j_i(j, i, sigma_i, neighbors, vecs, accelCache, dm) + TSNE.p_j_i(i, j, sigma_j, neighbors, vecs, accelCache, dm)) / (double)(2 * neighbors.size());
    }

    private static double perp(int i, int[][] nearMe, double sigma, List<List<? extends VecPaired<Vec, Double>>> neighbors, List<Vec> vecs, List<Double> accelCache, DistanceMetric dm) {
        double hp = 0.0;
        for (int j_indx = 0; j_indx < nearMe[i].length; ++j_indx) {
            double p_ji = TSNE.p_j_i(nearMe[i][j_indx], i, sigma, neighbors, vecs, accelCache, dm);
            if (!(p_ji > 0.0)) continue;
            hp += p_ji * FastMath.log2(p_ji);
        }
        return FastMath.pow2(hp *= -1.0);
    }

    @Override
    public int getTargetDimension() {
        return 2;
    }

    @Override
    public boolean setTargetDimension(int target) {
        return target == 2;
    }

    private class Quadtree {
        public Node root = new Node();

        public Quadtree(double[] z) {
            int i;
            this.root.minY = Double.POSITIVE_INFINITY;
            this.root.minX = Double.POSITIVE_INFINITY;
            this.root.maxY = Double.NEGATIVE_INFINITY;
            this.root.maxX = Double.NEGATIVE_INFINITY;
            for (i = 0; i < z.length / 2; ++i) {
                double x = z[i * 2];
                double y = z[i * 2 + 1];
                this.root.minX = Math.min(this.root.minX, x);
                this.root.maxX = Math.max(this.root.maxX, x);
                this.root.minY = Math.min(this.root.minY, y);
                this.root.maxY = Math.max(this.root.maxY, y);
            }
            this.root.maxX = Math.nextUp(this.root.maxX);
            this.root.maxY = Math.nextUp(this.root.maxY);
            for (i = 0; i < z.length / 2; ++i) {
                this.root.insert(1, i, z);
            }
        }

        private class Node
        implements Iterable<Node> {
            public int indx = -1;
            public double x_mass = 0.0;
            public double y_mass = 0.0;
            public int N_cell = 0;
            public double minX;
            public double maxX;
            public double minY;
            public double maxY;
            public Node NW = null;
            public Node NE = null;
            public Node SE = null;
            public Node SW = null;

            public Node() {
            }

            public Node(double minX, double maxX, double minY, double maxY) {
                this();
                this.minX = minX;
                this.maxX = maxX;
                this.minY = minY;
                this.maxY = maxY;
            }

            public boolean contains(int i, double[] z) {
                double x = z[i * 2];
                double y = z[i * 2 + 1];
                return this.minX <= x && x < this.maxX && this.minY <= y && y < this.maxY;
            }

            public void insert(int weight, int i, double[] z) {
                this.x_mass += z[i * 2];
                this.y_mass += z[i * 2 + 1];
                this.N_cell += weight;
                if (this.NW == null && this.indx < 0) {
                    this.indx = i;
                } else {
                    if (this.indx >= 0 && Math.abs(z[this.indx * 2] - z[i * 2]) < 1.0E-13 && Math.abs(z[this.indx * 2 + 1] - z[i * 2 + 1]) < 1.0E-13) {
                        return;
                    }
                    if (this.NW == null) {
                        double w2 = (this.maxX - this.minX) / 2.0;
                        double h2 = (this.maxY - this.minY) / 2.0;
                        this.NW = new Node(this.minX, this.minX + w2, this.minY + h2, this.maxY);
                        this.NE = new Node(this.minX + w2, this.maxX, this.minY + h2, this.maxY);
                        this.SW = new Node(this.minX, this.minX + w2, this.minY, this.minY + h2);
                        this.SE = new Node(this.minX + w2, this.maxX, this.minY, this.minY + h2);
                        for (Node child : this) {
                            if (!child.contains(this.indx, z)) continue;
                            child.insert(this.N_cell, this.indx, z);
                            break;
                        }
                        this.indx = -1;
                    }
                    for (Node child : this) {
                        if (!child.contains(i, z)) continue;
                        child.insert(weight, i, z);
                        break;
                    }
                }
            }

            public double diagLen() {
                double w = this.maxX - this.minX;
                double h = this.maxY - this.minY;
                return Math.sqrt(w * w + h * h);
            }

            @Override
            public Iterator<Node> iterator() {
                if (this.NW == null) {
                    return Collections.emptyIterator();
                }
                return Arrays.asList(this.NW, this.NE, this.SW, this.SE).iterator();
            }
        }
    }
}

