/*
 * Decompiled with CFR 0.152.
 */
package elki.projection;

import elki.database.ids.ArrayDBIDs;
import elki.database.ids.DBIDArrayIter;
import elki.database.ids.DBIDRange;
import elki.database.ids.DBIDRef;
import elki.database.ids.DBIDUtil;
import elki.database.ids.DoubleDBIDListIter;
import elki.database.ids.KNNList;
import elki.database.query.LinearScanQuery;
import elki.database.query.QueryBuilder;
import elki.database.query.knn.KNNSearcher;
import elki.database.relation.Relation;
import elki.distance.Distance;
import elki.logging.Logging;
import elki.logging.progress.AbstractProgress;
import elki.logging.progress.FiniteProgress;
import elki.logging.statistics.DoubleStatistic;
import elki.logging.statistics.Duration;
import elki.logging.statistics.Statistic;
import elki.math.MathUtil;
import elki.math.MeanVariance;
import elki.projection.AffinityMatrix;
import elki.projection.PerplexityAffinityMatrixBuilder;
import elki.projection.SparseAffinityMatrix;
import elki.utilities.datastructures.arraylike.DoubleArray;
import elki.utilities.datastructures.arraylike.IntegerArray;
import elki.utilities.documentation.Reference;
import elki.utilities.exceptions.AbortException;
import net.jafama.FastMath;

/**
 * Affinity matrix builder that only considers the nearest neighbors of each
 * object, producing a {@link SparseAffinityMatrix}.
 * <p>
 * For every object, the neighbors are retrieved with a kNN query, a per-object
 * bandwidth is tuned by bisection so that the entropy of the neighbor
 * distribution matches {@code log(perplexity)}, and the resulting rows are
 * symmetrized and rescaled.
 *
 * @param <O> object type accepted by the distance function
 */
@Reference(authors="L. J. P. van der Maaten", title="Accelerating t-SNE using Tree-Based Algorithms", booktitle="Journal of Machine Learning Research 15", url="http://dl.acm.org/citation.cfm?id=2697068", bibkey="DBLP:journals/jmlr/Maaten14")
public class NearestNeighborAffinityMatrixBuilder<O>
extends PerplexityAffinityMatrixBuilder<O> {
    /** Class logger. */
    private static final Logging LOG = Logging.getLogger(NearestNeighborAffinityMatrixBuilder.class);

    /** Maximum number of bisection steps when tuning the bandwidth. */
    private static final int MAX_BISECTION_STEPS = 50;

    /** Bisection stops once the entropy is this close to {@code log(perplexity)}. */
    private static final double ENTROPY_TOLERANCE = 1.0E-5;

    /** Lower bound applied to every stored affinity value. */
    private static final double MIN_AFFINITY = 1.0E-12;

    /** Number of neighbors to use per object (the query object itself not counted). */
    protected int numberOfNeighbours;

    /**
     * Constructor using {@code ceil(3 * perplexity)} neighbors.
     *
     * @param distance distance function
     * @param perplexity desired perplexity
     */
    public NearestNeighborAffinityMatrixBuilder(Distance<? super O> distance, double perplexity) {
        super(distance, perplexity);
        this.numberOfNeighbours = (int) FastMath.ceil(3.0 * perplexity);
    }

    /**
     * Constructor with an explicit neighbor count.
     *
     * @param distance distance function
     * @param perplexity desired perplexity
     * @param neighbors number of neighbors to use per object
     */
    public NearestNeighborAffinityMatrixBuilder(Distance<? super O> distance, double perplexity, int neighbors) {
        super(distance, perplexity);
        this.numberOfNeighbours = neighbors;
    }

    /**
     * Compute the sparse affinity matrix of a relation.
     *
     * @param relation data relation; its DBIDs must be a {@link DBIDRange}
     * @param initialScale scaling factor applied to the normalized affinities
     * @return sparse affinity matrix
     * @throws AbortException if the relation's DBIDs are not a {@link DBIDRange}
     */
    @Override
    public <T extends O> AffinityMatrix computeAffinityMatrix(Relation<T> relation, double initialScale) {
        // One extra neighbor is requested because the query object itself is
        // filtered out of the result again in convertNeighbors.
        KNNSearcher<DBIDRef> knnq = new QueryBuilder<>(relation, this.distance).kNNByDBID(this.numberOfNeighbours + 1);
        // Multiply in long: k * k overflows int for large neighbor counts.
        if (knnq instanceof LinearScanQuery && (long) this.numberOfNeighbours * this.numberOfNeighbours < relation.size()) {
            LOG.warning("To accelerate Barnes-Hut tSNE, please use an index.");
        }
        if (!(relation.getDBIDs() instanceof DBIDRange)) {
            throw new AbortException("Distance matrixes are currently only supported for DBID ranges (as used by static databases) for performance reasons (Patches welcome).");
        }
        DBIDRange rids = (DBIDRange) relation.getDBIDs();
        int size = rids.size();
        double[][] pij = new double[size][];
        int[][] indices = new int[size][];
        // Distances need to be squared here unless the distance already is squared.
        boolean square = !this.distance.isSquared();
        this.computePij(rids, knnq, square, this.numberOfNeighbours, pij, indices, initialScale);
        return new SparseAffinityMatrix(pij, indices, rids);
    }

    /**
     * Fill the sparse affinity rows: find the neighbors of every object, tune
     * the bandwidth of each row, then symmetrize and rescale all values.
     *
     * @param ids object ids; the offset within this range is the row index
     * @param knnq kNN searcher
     * @param square if {@code true}, the neighbor distances are squared first
     * @param numberOfNeighbours number of neighbors (query object not counted)
     * @param pij output: affinity values, one array per row
     * @param indices output: neighbor offsets matching the entries of {@code pij}
     * @param initialScale scaling factor applied to the normalized affinities
     */
    protected void computePij(DBIDRange ids, KNNSearcher<DBIDRef> knnq, boolean square, int numberOfNeighbours, double[][] pij, int[][] indices, double initialScale) {
        Duration timer = LOG.newDuration(this.getClass().getName() + ".runtime.neighborspijmatrix").begin();
        final double logPerp = FastMath.log(this.perplexity);
        // Scratch buffers, reused for every row.
        DoubleArray dists = new DoubleArray(numberOfNeighbours + 10);
        IntegerArray inds = new IntegerArray(numberOfNeighbours + 10);
        FiniteProgress prog = LOG.isVerbose() ? new FiniteProgress("Finding neighbors and optimizing perplexity", ids.size(), LOG) : null;
        // Only collect bandwidth statistics when they will be logged.
        MeanVariance mv = LOG.isStatistics() ? new MeanVariance() : null;
        for (DBIDArrayIter ix = ids.iter(); ix.valid(); ix.advance()) {
            final int row = ix.getOffset();
            dists.clear();
            inds.clear();
            KNNList neighbours = knnq.getKNN(ix, numberOfNeighbours + 1);
            this.convertNeighbors(ids, ix, square, neighbours, dists, inds);
            final double[] pij_row = new double[dists.size()];
            pij[row] = pij_row;
            final double beta = computeSigma(row, dists, this.perplexity, logPerp, pij_row);
            if (mv != null) {
                // Convert beta back to sigma, since beta = 1 / (2 * sigma^2).
                mv.put(beta > 0.0 ? Math.sqrt(0.5 / beta) : 0.0);
            }
            indices[row] = inds.toArray();
            LOG.incrementProcessed(prog);
        }
        LOG.ensureCompleted(prog);

        // Total mass of all rows, needed for normalization.
        double sum = 0.0;
        for (double[] pij_i : pij) {
            for (double v : pij_i) {
                sum += v;
            }
        }
        // A zero total would give an infinite scale, and 0 * Infinity is NaN;
        // with scale 0 all entries fall back to the minimum affinity instead.
        final double scale = sum > 0.0 ? initialScale / (2.0 * sum) : 0.0;

        // Symmetrize and scale.
        for (int i = 0; i < pij.length; ++i) {
            final double[] pij_i = pij[i];
            final int[] ind_i = indices[i];
            for (int offi = 0; offi < pij_i.length; ++offi) {
                final int j = ind_i[offi];
                assert (i != j);
                final int offj = containsIndex(indices[j], i);
                if (offj < 0) {
                    // j does not list i as neighbor: the reverse affinity is 0.
                    pij_i[offi] = MathUtil.max(pij_i[offi] * scale, MIN_AFFINITY);
                    continue;
                }
                assert (indices[j][offj] == i);
                // Mutual neighbors: handle each pair once (from the smaller
                // index), storing the same value in both rows.
                if (i < j) {
                    final double val = MathUtil.max((pij_i[offi] + pij[j][offj]) * scale, MIN_AFFINITY);
                    pij_i[offi] = val;
                    pij[j][offj] = val;
                }
            }
        }
        LOG.statistics(timer.end());
        if (mv != null) {
            LOG.statistics(new DoubleStatistic(NearestNeighborAffinityMatrixBuilder.class.getName() + ".sigma.average", mv.getMean()));
            LOG.statistics(new DoubleStatistic(NearestNeighborAffinityMatrixBuilder.class.getName() + ".sigma.stddev", mv.getSampleStddev()));
        }
    }

    /**
     * Copy a kNN result into parallel distance and offset buffers, skipping the
     * query object itself.
     *
     * @param ids id range used to translate DBIDs into offsets
     * @param ix query object
     * @param square if {@code true}, store the squared distances
     * @param neighbours kNN result of the query object
     * @param dist output: (optionally squared) distances; appended to
     * @param ind output: neighbor offsets; appended to
     */
    protected void convertNeighbors(DBIDRange ids, DBIDRef ix, boolean square, KNNList neighbours, DoubleArray dist, IntegerArray ind) {
        for (DoubleDBIDListIter iter = neighbours.iter(); iter.valid(); iter.advance()) {
            if (DBIDUtil.equal(iter, ix)) {
                continue; // skip the query object
            }
            final double d = iter.doubleValue();
            dist.add(square ? d * d : d);
            ind.add(ids.getOffset(iter));
        }
    }

    /**
     * Tune the bandwidth parameter {@code beta} of one row by bisection until
     * the entropy of the row is within tolerance of {@code log_perp}, or the
     * iteration limit is reached. As a side effect, {@code pij_i} holds the
     * normalized affinities of the last evaluated {@code beta}.
     *
     * @param i row index (currently unused)
     * @param pij_row neighbor distances of this row (despite the name, these
     *        are the input distances, not affinities)
     * @param perplexity desired perplexity; selects the distance used for the
     *        initial estimate
     * @param log_perp logarithm of the desired perplexity
     * @param pij_i output: affinities of this row, same length as the distances
     * @return the final {@code beta}, or 0 if the row has no neighbors
     */
    protected static double computeSigma(int i, DoubleArray pij_row, double perplexity, double log_perp, double[] pij_i) {
        final int len = pij_row.size();
        if (len == 0) {
            return 0.0; // no neighbors, nothing to optimize
        }
        // Clamp the reference position: there may be fewer neighbors than
        // ceil(perplexity) + 1, e.g., with a small explicit neighbor count.
        final int ref = Math.max(0, Math.min((int) FastMath.ceil(perplexity), len - 1));
        final double max = pij_row.get(ref) / Math.E;
        double beta = 1.0 / max;
        // A zero reference distance (duplicate points) gives an infinite beta,
        // which turns the whole row into NaN; start from 1 instead.
        if (!(beta > 0.0 && beta < Double.POSITIVE_INFINITY)) {
            beta = 1.0;
        }
        double diff = computeH(pij_row, pij_i, -beta) - log_perp;
        double betaMin = 0.0;
        double betaMax = Double.POSITIVE_INFINITY;
        for (int tries = 0; tries < MAX_BISECTION_STEPS && Math.abs(diff) > ENTROPY_TOLERANCE; ++tries) {
            if (diff > 0.0) {
                // Entropy too large: increase beta; double while unbounded.
                betaMin = beta;
                beta += betaMax == Double.POSITIVE_INFINITY ? beta : (betaMax - beta) * 0.5;
            } else {
                // Entropy too small: decrease beta towards the lower bound.
                betaMax = beta;
                beta -= (beta - betaMin) * 0.5;
            }
            diff = computeH(pij_row, pij_i, -beta) - log_perp;
        }
        return beta;
    }

    /**
     * Compute the affinities {@code exp(dist * mbeta)} of one row, normalize
     * them to sum 1, and return the entropy of the resulting distribution,
     * computed as {@code log(sumP) - mbeta * sum(dist * p)}.
     *
     * @param dist_i neighbor distances of the row
     * @param pij_row output: normalized affinities; left unnormalized if their
     *        sum is not positive
     * @param mbeta negated bandwidth parameter, {@code -beta}
     * @return entropy, or negative infinity if the affinities do not sum to a
     *         positive value
     */
    protected static double computeH(DoubleArray dist_i, double[] pij_row, double mbeta) {
        final int len = dist_i.size();
        assert (pij_row.length == len);
        double sumP = 0.0;
        for (int j = 0; j < len; ++j) {
            final double p = FastMath.exp(dist_i.get(j) * mbeta);
            pij_row[j] = p;
            sumP += p;
        }
        // Negated comparison so that NaN is rejected as well.
        if (!(sumP > 0.0)) {
            return Double.NEGATIVE_INFINITY;
        }
        final double norm = 1.0 / sumP;
        double weighted = 0.0;
        for (int j = 0; j < len; ++j) {
            final double p = pij_row[j] * norm;
            pij_row[j] = p;
            weighted += dist_i.get(j) * p;
        }
        return FastMath.log(sumP) - mbeta * weighted;
    }

    /**
     * Linear search for a value in an (unsorted) int array.
     *
     * @param is array to search
     * @param i value to find
     * @return position of the first occurrence, or -1 if not contained
     */
    protected static int containsIndex(int[] is, int i) {
        for (int j = 0; j < is.length; ++j) {
            if (is[j] == i) {
                return j;
            }
        }
        return -1;
    }

    /**
     * Parameterization class.
     *
     * @param <O> object type accepted by the distance function
     */
    public static class Par<O>
    extends PerplexityAffinityMatrixBuilder.Par<O> {
        /**
         * Create the builder from the configured distance and perplexity,
         * using the default neighbor count.
         *
         * @return new builder instance
         */
        @Override
        public NearestNeighborAffinityMatrixBuilder<O> make() {
            return new NearestNeighborAffinityMatrixBuilder<>(this.distance, this.perplexity);
        }
    }
}

