/*
 * Decompiled with CFR 0.152.
 */
package elki.datasource.filter.transform;

import elki.data.DoubleVector;
import elki.data.NumberVector;
import elki.data.type.SimpleTypeInformation;
import elki.data.type.TypeInformation;
import elki.data.type.VectorFieldTypeInformation;
import elki.datasource.bundle.MultipleObjectsBundle;
import elki.datasource.filter.ObjectFilter;
import elki.distance.PrimitiveDistance;
import elki.distance.minkowski.SquaredEuclideanDistance;
import elki.logging.Logging;
import elki.logging.progress.FiniteProgress;
import elki.logging.progress.StepProgress;
import elki.math.linearalgebra.SingularValueDecomposition;
import elki.utilities.Alias;
import elki.utilities.optionhandling.OptionID;
import elki.utilities.optionhandling.Parameterizer;
import elki.utilities.optionhandling.parameterization.Parameterization;
import elki.utilities.optionhandling.parameters.IntParameter;
import elki.utilities.optionhandling.parameters.ObjectParameter;
import java.util.List;

@Alias(value={"mds"})
public class ClassicMultidimensionalScalingTransform<I, O extends NumberVector>
implements ObjectFilter {
    private static final Logging LOG = Logging.getLogger(ClassicMultidimensionalScalingTransform.class);
    PrimitiveDistance<? super I> dist = null;
    int tdim;
    NumberVector.Factory<O> factory;

    public ClassicMultidimensionalScalingTransform(int tdim, PrimitiveDistance<? super I> dist, NumberVector.Factory<O> factory) {
        this.tdim = tdim;
        this.dist = dist;
        this.factory = factory;
    }

    public MultipleObjectsBundle filter(MultipleObjectsBundle objects) {
        int size = objects.dataLength();
        if (size == 0) {
            return objects;
        }
        MultipleObjectsBundle bundle = new MultipleObjectsBundle();
        for (int r = 0; r < objects.metaLength(); ++r) {
            SimpleTypeInformation type = objects.meta(r);
            List column = objects.getColumn(r);
            if (!this.dist.getInputTypeRestriction().isAssignableFromType((TypeInformation)type)) {
                bundle.appendColumn(type, column);
                continue;
            }
            List castColumn = column;
            bundle.appendColumn((SimpleTypeInformation)new VectorFieldTypeInformation(this.factory, this.tdim), castColumn);
            StepProgress prog = LOG.isVerbose() ? new StepProgress("Classic MDS", 2) : null;
            LOG.beginStep(prog, 1, "Computing distance matrix");
            double[][] mat = ClassicMultidimensionalScalingTransform.computeSquaredDistanceMatrix(castColumn, this.dist);
            ClassicMultidimensionalScalingTransform.doubleCenterSymmetric(mat);
            LOG.beginStep(prog, 2, "Computing singular value decomposition");
            SingularValueDecomposition svd = new SingularValueDecomposition(mat);
            double[][] u = svd.getU();
            double[] lambda = svd.getSingularValues();
            if (!this.dist.isSquared()) {
                for (int i = 0; i < this.tdim; ++i) {
                    lambda[i] = Math.sqrt(Math.abs(lambda[i]));
                }
            }
            double[] buf = new double[this.tdim];
            for (int i = 0; i < size; ++i) {
                double[] row = u[i];
                for (int x = 0; x < buf.length; ++x) {
                    buf[x] = lambda[x] * row[x];
                }
                column.set(i, this.factory.newNumberVector(buf));
            }
            LOG.setCompleted(prog);
        }
        return bundle;
    }

    protected static <I> double[][] computeSquaredDistanceMatrix(List<I> col, PrimitiveDistance<? super I> dist) {
        int size = col.size();
        double[][] imat = new double[size][size];
        boolean squared = dist.isSquared();
        FiniteProgress dprog = LOG.isVerbose() ? new FiniteProgress("Computing distance matrix", size * (size - 1) >>> 1, LOG) : null;
        for (int x = 0; x < size; ++x) {
            I ox = col.get(x);
            for (int y = x + 1; y < size; ++y) {
                I oy = col.get(y);
                double distance = dist.distance(ox, oy);
                double d = distance *= squared ? -0.5 : -0.5 * distance;
                imat[y][x] = d;
                imat[x][y] = d;
            }
            if (dprog == null) continue;
            dprog.setProcessed(dprog.getProcessed() + size - x - 1, LOG);
        }
        LOG.ensureCompleted(dprog);
        return imat;
    }

    public static void doubleCenterSymmetric(double[][] m) {
        int x;
        int size = m.length;
        double[] means = new double[size];
        for (int x2 = 0; x2 < m.length; ++x2) {
            double[] rowx = m[x2];
            double rmean = means[x2] - means[x2] / (double)(x2 + 1);
            int y = x2 + 1;
            while (y < rowx.length) {
                double nv = rowx[y];
                double dx = nv - rmean;
                double dy = nv - means[y];
                rmean += dx / (double)(y + 1);
                int n = y++;
                means[n] = means[n] + dy / (double)(x2 + 1);
            }
            means[x2] = rmean;
        }
        double mean = means[0];
        for (x = 1; x < size; ++x) {
            double dm = means[x] - mean;
            mean += dm / (double)(x + 1);
        }
        for (x = 0; x < size; ++x) {
            m[x][x] = -2.0 * means[x] + mean;
            for (int y = x + 1; y < size; ++y) {
                double nv;
                m[x][y] = nv = m[x][y] - means[x] - means[y] + mean;
                m[y][x] = nv;
            }
        }
    }

    public static class Par<I, O extends NumberVector>
    implements Parameterizer {
        public static final OptionID DIM_ID = new OptionID("mds.dim", "Output dimensionality.");
        public static final OptionID DISTANCE_ID = new OptionID("mds.distance", "Distance function to use.");
        public static final OptionID VECTOR_TYPE_ID = new OptionID("mds.vector-type", "The type of vectors to create.");
        int tdim;
        PrimitiveDistance<? super I> dist = null;
        NumberVector.Factory<O> factory;

        public void configure(Parameterization config) {
            new IntParameter(DIM_ID).grab(config, x -> {
                this.tdim = x;
            });
            new ObjectParameter(DISTANCE_ID, PrimitiveDistance.class, SquaredEuclideanDistance.class).grab(config, x -> {
                this.dist = x;
            });
            new ObjectParameter(VECTOR_TYPE_ID, NumberVector.Factory.class, DoubleVector.Factory.class).grab(config, x -> {
                this.factory = x;
            });
        }

        public ClassicMultidimensionalScalingTransform<I, O> make() {
            return new ClassicMultidimensionalScalingTransform<I, O>(this.tdim, this.dist, this.factory);
        }
    }
}

