/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.hellbender.tools.copynumber.utils.segmentation;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.DefaultRealMatrixChangingVisitor;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealMatrixChangingVisitor;
import org.apache.commons.math3.linear.SingularValueDecomposition;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.random.RandomGeneratorFactory;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.broadinstitute.hellbender.tools.copynumber.utils.optimization.PersistenceOptimizer;
import org.broadinstitute.hellbender.utils.IndexRange;
import org.broadinstitute.hellbender.utils.MathUtils;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.param.ParamUtils;

/**
 * Segments an ordered list of data points using kernel-based changepoint detection.
 * <ol>
 *     <li>A Nyström-style low-rank approximation to the kernel matrix is built. A random subsample of the data is
 *     drawn, the SVD of its kernel matrix is taken, and every data point is projected onto it (the "reduced
 *     observation matrix").</li>
 *     <li>For each window size, a local changepoint cost is computed at every point. This cost is
 *     the sum of the costs of the adjacent left and right windows minus the cost of their union. Local minima of
 *     this cost become changepoint candidates.</li>
 *     <li>Candidates are pruned by greedy backward selection: the adjacent segment pair whose merge increases the
 *     total cost the least is merged repeatedly. The number of changepoints is then chosen by minimizing the
 *     total segment cost plus a penalty that is linear and log-linear in the number of changepoints.</li>
 * </ol>
 *
 * @param <DATA> type of the data points; they are only accessed through the user-supplied kernel
 */
public final class KernelSegmenter<DATA> {
    private static final Logger logger = LogManager.getLogger(KernelSegmenter.class);

    //fixed seed so that the random subsample used for the kernel approximation (and hence the output) is reproducible
    private static final int RANDOM_SEED = 1216;
    //guards against division by zero for vanishing singular values and for zero changepoints in the penalty
    private static final double EPSILON = 1.0E-10;

    //unmodifiable copy of the data to segment
    private final List<DATA> data;

    /**
     * @param data  ordered data points to segment; must be non-null.  The list is copied, so later changes to it
     *              do not affect this segmenter.
     */
    public KernelSegmenter(final List<DATA> data) {
        this.data = Collections.unmodifiableList(new ArrayList<>(Utils.nonNull(data)));
    }

    /**
     * Finds up to {@code maxNumChangepoints} changepoints in the data.
     *
     * <p>A changepoint value {@code i} means that one segment ends at index {@code i} and the next starts at
     * {@code i + 1}.  Argument validation is delegated to {@link ParamUtils} and {@link Utils}.</p>
     *
     * @param maxNumChangepoints                    maximum number of changepoints to return; must be non-negative.
     *                                              If zero, an empty list is returned.
     * @param kernel                                kernel function evaluated on pairs of data points; must be non-null
     * @param kernelApproximationDimension          number of points subsampled (with replacement) to build the
     *                                              low-rank kernel approximation; must be positive.  If it exceeds
     *                                              the number of data points, all data points are used instead.
     * @param windowSizes                           non-empty list of distinct, positive window sizes used to find
     *                                              changepoint candidates.  Window sizes with
     *                                              {@code 2 * windowSize > number of data points} are skipped.
     * @param numChangepointsPenaltyLinearFactor    non-negative factor {@code a} in the penalty
     *                                              {@code a * c + b * c * log(N / c)} for {@code c} changepoints
     * @param numChangepointsPenaltyLogLinearFactor non-negative factor {@code b} in the same penalty
     * @param changepointSortOrder                  {@link ChangepointSortOrder#INDEX} to sort changepoints by index;
     *                                              {@link ChangepointSortOrder#BACKWARD_SELECTION} to keep the order
     *                                              produced by backward selection.  Must be non-null.
     * @return a new mutable list of changepoint indices; empty if no changepoints were requested, there is no data,
     *         or none survive the penalty
     */
    public List<Integer> findChangepoints(final int maxNumChangepoints,
                                          final BiFunction<DATA, DATA, Double> kernel,
                                          final int kernelApproximationDimension,
                                          final List<Integer> windowSizes,
                                          final double numChangepointsPenaltyLinearFactor,
                                          final double numChangepointsPenaltyLogLinearFactor,
                                          final ChangepointSortOrder changepointSortOrder) {
        ParamUtils.isPositiveOrZero(maxNumChangepoints, "Maximum number of changepoints must be non-negative.");
        Utils.nonNull(kernel, "Kernel must be non-null.");
        ParamUtils.isPositive(kernelApproximationDimension, "Dimension of kernel approximation must be positive.");
        Utils.nonNull(windowSizes, "Window sizes must be non-null.");
        Utils.validateArg(!windowSizes.isEmpty(), "At least one window size must be provided.");
        Utils.validateArg(windowSizes.stream().allMatch(ws -> ws > 0), "Window sizes must all be positive.");
        Utils.validateArg(windowSizes.stream().distinct().count() == (long) windowSizes.size(), "Window sizes must all be unique.");
        ParamUtils.isPositiveOrZero(numChangepointsPenaltyLinearFactor, "Linear factor for the penalty on the number of changepoints per chromosome must be non-negative.");
        ParamUtils.isPositiveOrZero(numChangepointsPenaltyLogLinearFactor, "Log-linear factor for the penalty on the number of changepoints per chromosome must be non-negative.");
        Utils.nonNull(changepointSortOrder, "Changepoint sort order must be non-null.");

        if (maxNumChangepoints == 0) {
            logger.warn("No changepoints were requested, returning an empty list...");
            return Collections.emptyList();
        }
        if (data.isEmpty()) {
            logger.warn("No data points were provided, returning an empty list...");
            return Collections.emptyList();
        }

        logger.debug(String.format("Finding up to %d changepoints in %d data points...", maxNumChangepoints, data.size()));
        final RandomGenerator rng = RandomGeneratorFactory.createRandomGenerator(new Random(RANDOM_SEED));

        logger.debug("Calculating low-rank approximation to kernel matrix...");
        final RealMatrix reducedObservationMatrix = calculateReducedObservationMatrix(rng, data, kernel, kernelApproximationDimension);
        final double[] kernelApproximationDiagonal = calculateKernelApproximationDiagonal(reducedObservationMatrix);

        logger.debug(String.format("Finding changepoint candidates for all window sizes %s...", windowSizes.toString()));
        final List<Integer> changepointCandidates = findChangepointCandidates(
                data, reducedObservationMatrix, kernelApproximationDiagonal, maxNumChangepoints, windowSizes);

        logger.debug("Performing backward model selection on changepoint candidates...");
        final List<Integer> changepoints = new ArrayList<>(selectChangepoints(
                changepointCandidates, maxNumChangepoints, numChangepointsPenaltyLinearFactor,
                numChangepointsPenaltyLogLinearFactor, reducedObservationMatrix, kernelApproximationDiagonal));
        //Collections.sort is stable; BACKWARD_SELECTION simply keeps the selection order
        if (changepointSortOrder == ChangepointSortOrder.INDEX) {
            Collections.sort(changepoints);
        }
        return changepoints;
    }

    /**
     * Builds the reduced observation matrix {@code K(data, subsample) * U * diag(1 / (sqrt(s) + EPSILON))}.
     * Here {@code U} and {@code s} come from the SVD of the kernel matrix of a random subsample of the data.
     * Row {@code i} is the low-rank feature vector of data point {@code i}.
     *
     * @return a (number of data points) x (subsample size) matrix
     */
    private static <DATA> RealMatrix calculateReducedObservationMatrix(final RandomGenerator rng,
                                                                      final List<DATA> data,
                                                                      final BiFunction<DATA, DATA, Double> kernel,
                                                                      final int kernelApproximationDimension) {
        if (kernelApproximationDimension > data.size()) {
            logger.warn(String.format("Specified dimension of the kernel approximation (%d) exceeds the number of data points (%d) to segment; using all data points to calculate kernel matrix.", kernelApproximationDimension, data.size()));
        }

        //subsample with replacement, unless all data points are to be used
        final int numSubsample = Math.min(kernelApproximationDimension, data.size());
        logger.debug(String.format("Subsampling %d points from data to find kernel approximation...", numSubsample));
        final List<DATA> dataSubsample = numSubsample == data.size()
                ? data
                : IntStream.range(0, numSubsample)
                        .mapToObj(i -> data.get(rng.nextInt(data.size())))
                        .collect(Collectors.toList());

        logger.debug(String.format("Calculating kernel matrix of subsampled data (%d x %d)...", numSubsample, numSubsample));
        final RealMatrix subKernelMatrix = new Array2DRowRealMatrix(numSubsample, numSubsample);
        for (int i = 0; i < numSubsample; i++) {
            //kernel is assumed symmetric: evaluate the lower triangle and mirror it
            for (int j = 0; j < i; j++) {
                final double value = kernel.apply(dataSubsample.get(i), dataSubsample.get(j));
                subKernelMatrix.setEntry(i, j, value);
                subKernelMatrix.setEntry(j, i, value);
            }
            subKernelMatrix.setEntry(i, i, kernel.apply(dataSubsample.get(i), dataSubsample.get(i)));
        }

        logger.debug(String.format("Performing SVD of kernel matrix of subsampled data (%d x %d)...", numSubsample, numSubsample));
        final SingularValueDecomposition svd = new SingularValueDecomposition(subKernelMatrix);

        logger.debug(String.format("Calculating reduced observation matrix (%d x %d)...", data.size(), numSubsample));
        //EPSILON keeps (near-)zero singular values, e.g. from duplicated subsample points, from blowing up
        final double[] invSqrtSingularValues = Arrays.stream(svd.getSingularValues())
                .map(Math::sqrt)
                .map(x -> 1.0 / (x + EPSILON))
                .toArray();
        final RealMatrix subKernelU = svd.getU();
        final RealMatrix subKernelUMatrix = new Array2DRowRealMatrix(numSubsample, numSubsample);
        subKernelUMatrix.walkInOptimizedOrder(new DefaultRealMatrixChangingVisitor() {
            @Override
            public double visit(final int i, final int j, final double value) {
                return subKernelU.getEntry(i, j) * invSqrtSingularValues[j];
            }
        });
        final RealMatrix reducedKernelMatrix = new Array2DRowRealMatrix(data.size(), numSubsample);
        reducedKernelMatrix.walkInOptimizedOrder(new DefaultRealMatrixChangingVisitor() {
            @Override
            public double visit(final int i, final int j, final double value) {
                return kernel.apply(data.get(i), dataSubsample.get(j));
            }
        });
        return reducedKernelMatrix.multiply(subKernelUMatrix);
    }

    /**
     * @return the squared Euclidean norm of each row of the reduced observation matrix, i.e. the diagonal of the
     *         approximate kernel matrix
     */
    private static double[] calculateKernelApproximationDiagonal(final RealMatrix reducedObservationMatrix) {
        return new IndexRange(0, reducedObservationMatrix.getRowDimension())
                .mapToDouble(i -> MathUtils.square(reducedObservationMatrix.getRowVector(i).getNorm()));
    }

    /**
     * Collects changepoint candidates from every window size.  Each window size contributes up to
     * {@code maxNumChangepoints} local minima of its window-cost track, in the order returned by
     * {@link PersistenceOptimizer#getMinimaIndices()}.  Minima at the first and last data point are discarded.
     * Window sizes with {@code 2 * windowSize > data.size()} are skipped with a warning.
     *
     * @return candidates from all window sizes concatenated; may contain duplicates and may be empty
     */
    private static <DATA> List<Integer> findChangepointCandidates(final List<DATA> data,
                                                                  final RealMatrix reducedObservationMatrix,
                                                                  final double[] kernelApproximationDiagonal,
                                                                  final int maxNumChangepoints,
                                                                  final List<Integer> windowSizes) {
        final List<Integer> changepointCandidates = new ArrayList<>(windowSizes.size() * maxNumChangepoints);
        for (final int windowSize : windowSizes) {
            logger.debug(String.format("Calculating local changepoints costs for window size %d...", windowSize));
            if (2 * windowSize > data.size()) {
                logger.warn(String.format("Number of points needed to calculate local changepoint costs (2 * window size = %d) exceeds number of data points (%d).  Local changepoint costs will not be calculated for this window size.", 2 * windowSize, data.size()));
                continue;
            }
            final double[] windowCosts = calculateWindowCosts(reducedObservationMatrix, kernelApproximationDiagonal, windowSize);

            logger.debug(String.format("Finding local minima of local changepoint costs for window size %d...", windowSize));
            final List<Integer> windowCostLocalMinima =
                    new ArrayList<>(new PersistenceOptimizer(windowCosts).getMinimaIndices());
            //window costs are computed circularly, so minima at the ends are not meaningful changepoints
            windowCostLocalMinima.remove(Integer.valueOf(0));
            windowCostLocalMinima.remove(Integer.valueOf(data.size() - 1));
            changepointCandidates.addAll(
                    windowCostLocalMinima.subList(0, Math.min(maxNumChangepoints, windowCostLocalMinima.size())));
        }
        if (changepointCandidates.isEmpty()) {
            logger.warn("No changepoint candidates were found.  The specified window sizes may be inappropriate, or there may be insufficient data points.");
        }
        return changepointCandidates;
    }

    /**
     * Performs greedy backward selection on the (sorted, de-duplicated) candidates, then applies the
     * changepoint-number penalty.
     *
     * <p>Starting from the segmentation induced by all candidates, the adjacent segment pair whose merge increases
     * the total cost the least is merged repeatedly until one segment remains.  Each removed boundary is prepended
     * to the changepoint list, so the boundary removed last comes first.  The optimal number {@code c} of
     * changepoints minimizes (total cost with {@code c} changepoints) + penalty({@code c}), where {@code c} is
     * capped at {@code maxNumChangepoints}.</p>
     *
     * @return a view of the first {@code c} selected changepoints
     */
    private static List<Integer> selectChangepoints(final List<Integer> changepointCandidates,
                                                    final int maxNumChangepoints,
                                                    final double numChangepointsPenaltyLinearFactor,
                                                    final double numChangepointsPenaltyLogLinearFactor,
                                                    final RealMatrix reducedObservationMatrix,
                                                    final double[] kernelApproximationDiagonal) {
        final List<Integer> changepoints = new ArrayList<>(changepointCandidates.size());

        final int numData = reducedObservationMatrix.getRowDimension();
        final List<Double> changepointPenalties = IntStream.range(0, maxNumChangepoints + 1)
                .mapToObj(numChangepoints -> calculateChangepointPenalty(
                        numChangepoints, numChangepointsPenaltyLinearFactor, numChangepointsPenaltyLogLinearFactor, numData))
                .collect(Collectors.toList());

        //initial segmentation: segment k spans [candidateStarts[k], candidateEnds[k]]
        final List<Integer> sortedDistinctCandidates = changepointCandidates.stream()
                .sorted()
                .distinct()
                .collect(Collectors.toList());
        final List<Integer> candidateStarts = sortedDistinctCandidates.stream()
                .map(i -> Math.min(i + 1, numData - 1))
                .collect(Collectors.toCollection(ArrayList::new));
        candidateStarts.add(0, 0);
        final List<Integer> candidateEnds = new ArrayList<>(sortedDistinctCandidates);
        candidateEnds.add(numData - 1);

        final int numSegments = candidateStarts.size();
        final List<Segment> segments = IntStream.range(0, numSegments)
                .mapToObj(i -> new Segment(candidateStarts.get(i), candidateEnds.get(i), reducedObservationMatrix, kernelApproximationDiagonal))
                .collect(Collectors.toCollection(ArrayList::new));
        //totalSegmentationCosts.get(c) will be the total cost of the segmentation with c changepoints
        final List<Double> totalSegmentationCosts = new ArrayList<>(Collections.singletonList(sumOfSegmentCosts(segments)));

        //entry k of each list refers to the pair (segments[k], segments[k + 1])
        final List<Double> costsForSegmentPairs = IntStream.range(0, numSegments - 1)
                .mapToObj(i -> segments.get(i).cost + segments.get(i + 1).cost)
                .collect(Collectors.toCollection(ArrayList::new));
        final List<Double> costsForMergedSegmentPairs = IntStream.range(0, numSegments - 1)
                .mapToObj(i -> new Segment(candidateStarts.get(i), candidateEnds.get(i + 1), reducedObservationMatrix, kernelApproximationDiagonal).cost)
                .collect(Collectors.toCollection(ArrayList::new));
        final List<Double> costsForMergingSegmentPairs = IntStream.range(0, numSegments - 1)
                .mapToObj(i -> costsForSegmentPairs.get(i) - costsForMergedSegmentPairs.get(i))
                .collect(Collectors.toCollection(ArrayList::new));

        for (int numMerges = 0; numMerges < numSegments - 1; numMerges++) {
            //merge the pair whose merge increases the total cost the least
            final int left = costsForMergingSegmentPairs.indexOf(Collections.max(costsForMergingSegmentPairs));
            final double newCost = costsForMergedSegmentPairs.get(left);
            final int newStart = segments.get(left).start;
            final int mergepoint = segments.get(left).end;
            final int newEnd = segments.get(left + 1).end;

            segments.remove(left);
            segments.remove(left);
            segments.add(left, new Segment(newStart, newEnd, newCost));

            //int arguments: these are index-based removals
            costsForSegmentPairs.remove(left);
            costsForMergedSegmentPairs.remove(left);
            costsForMergingSegmentPairs.remove(left);

            //refresh the pair formed with the segment to the left of the merged segment, if any
            if (left > 0) {
                costsForSegmentPairs.set(left - 1, segments.get(left - 1).cost + segments.get(left).cost);
                costsForMergedSegmentPairs.set(left - 1,
                        new Segment(segments.get(left - 1).start, newEnd, reducedObservationMatrix, kernelApproximationDiagonal).cost);
                costsForMergingSegmentPairs.set(left - 1, costsForSegmentPairs.get(left - 1) - costsForMergedSegmentPairs.get(left - 1));
            }
            //refresh the pair formed with the segment to the right of the merged segment, if any
            if (left < segments.size() - 1) {
                costsForSegmentPairs.set(left, segments.get(left).cost + segments.get(left + 1).cost);
                costsForMergedSegmentPairs.set(left,
                        new Segment(newStart, segments.get(left + 1).end, reducedObservationMatrix, kernelApproximationDiagonal).cost);
                costsForMergingSegmentPairs.set(left, costsForSegmentPairs.get(left) - costsForMergedSegmentPairs.get(left));
            }

            totalSegmentationCosts.add(0, sumOfSegmentCosts(segments));
            changepoints.add(0, mergepoint);
        }

        final int effectiveMaxNumChangepoints = Math.min(maxNumChangepoints, changepoints.size());
        final List<Double> totalSegmentationCostsPlusPenalties = IntStream.range(0, effectiveMaxNumChangepoints + 1)
                .mapToObj(i -> totalSegmentationCosts.get(i) + changepointPenalties.get(i))
                .collect(Collectors.toList());
        final int numChangepointsOptimal =
                totalSegmentationCostsPlusPenalties.indexOf(Collections.min(totalSegmentationCostsPlusPenalties));
        logger.info(String.format("Found %d changepoints after applying the changepoint penalty.", numChangepointsOptimal));
        return changepoints.subList(0, numChangepointsOptimal);
    }

    /** Sums the costs of all segments. */
    private static double sumOfSegmentCosts(final List<Segment> segments) {
        return segments.stream().mapToDouble(s -> s.cost).sum();
    }

    /**
     * @return {@code linearFactor * c + logLinearFactor * c * log(numData / (c + EPSILON))} for {@code c}
     *         changepoints; EPSILON makes the {@code c = 0} case evaluate to 0 instead of NaN
     */
    private static double calculateChangepointPenalty(final int numChangepoints,
                                                      final double numChangepointsPenaltyLinearFactor,
                                                      final double numChangepointsPenaltyLogLinearFactor,
                                                      final int numData) {
        return numChangepointsPenaltyLinearFactor * numChangepoints
                + numChangepointsPenaltyLogLinearFactor * numChangepoints * Math.log((double) numData / (numChangepoints + EPSILON));
    }

    /**
     * Computes the cost of the segment from {@code start} to {@code end} (both inclusive).  If
     * {@code start > end}, the segment wraps around the end of the data (indices taken modulo the number of rows).
     *
     * @return the running quantities (D, W, V) and the cost {@code C = D - V / n} for the {@code n} points in the
     *         segment
     */
    private static Cost calculateSegmentCost(final int start,
                                             final int end,
                                             final RealMatrix reducedObservationMatrix,
                                             final double[] kernelApproximationDiagonal) {
        final int N = reducedObservationMatrix.getRowDimension();
        final int p = reducedObservationMatrix.getColumnDimension();

        double D = kernelApproximationDiagonal[start];
        final double[] W = Arrays.copyOf(reducedObservationMatrix.getRow(start), p);
        double V = Arrays.stream(W).map(w -> w * w).sum();

        //remaining indices of the segment after start, wrapping around if necessary
        final int[] indices = start <= end
                ? IntStream.range(start + 1, end + 1).toArray()
                : IntStream.concat(IntStream.range(start + 1, N), IntStream.range(0, end + 1)).toArray();
        for (final int tauPrime : indices) {
            D += kernelApproximationDiagonal[tauPrime];
            //||W + z||^2 = ||W||^2 + 2 z.W + ||z||^2
            V += 2.0 * addRow(reducedObservationMatrix, tauPrime, W) + kernelApproximationDiagonal[tauPrime];
        }
        final double C = D - V / (indices.length + 1);
        return new Cost(D, W, V, C);
    }

    /**
     * Computes, for every point {@code center}, the local changepoint cost
     * {@code cost(left window) + cost(right window) - cost(union)}.  The left window is
     * {@code [center - windowSize + 1, center]} and the right window is {@code [center + 1, center + windowSize]}.
     * Windows are treated circularly (indices modulo N), so costs near both ends mix points from the other end.
     * Windows are slid one point at a time with O(p) rank-one updates instead of being recomputed.
     */
    private static double[] calculateWindowCosts(final RealMatrix reducedObservationMatrix,
                                                 final double[] kernelApproximationDiagonal,
                                                 final int windowSize) {
        final int N = reducedObservationMatrix.getRowDimension();
        final int p = reducedObservationMatrix.getColumnDimension();

        int center = 0;
        int start = (center - windowSize + 1 + N) % N;
        int end = (center + windowSize) % N;

        final Cost leftCost = calculateSegmentCost(start, center, reducedObservationMatrix, kernelApproximationDiagonal);
        final Cost rightCost = calculateSegmentCost(center + 1, end, reducedObservationMatrix, kernelApproximationDiagonal);
        final Cost totalCost = calculateSegmentCost(start, end, reducedObservationMatrix, kernelApproximationDiagonal);

        double leftD = leftCost.D;
        final double[] leftW = Arrays.copyOf(leftCost.W, p);
        double leftV = leftCost.V;
        double leftC = leftCost.C;

        double rightD = rightCost.D;
        final double[] rightW = Arrays.copyOf(rightCost.W, p);
        double rightV = rightCost.V;
        double rightC = rightCost.C;

        double totalD = totalCost.D;
        final double[] totalW = Arrays.copyOf(totalCost.W, p);
        double totalV = totalCost.V;
        double totalC = totalCost.C;

        final double[] windowCosts = new double[N];
        windowCosts[center] = leftC + rightC - totalC;

        final double windowSizeReciprocal = 1.0 / windowSize;
        //N - 1 slides fill windowCosts[1..N-1]; a further slide would only recompute windowCosts[0] with extra drift
        for (center = 0; center < N - 1; center++) {
            final int centerNext = center + 1;
            final int endNext = (end + 1) % N;

            //left window: drop start, add centerNext
            leftD -= kernelApproximationDiagonal[start];
            leftV += -2.0 * subtractRow(reducedObservationMatrix, start, leftW) + kernelApproximationDiagonal[start];
            leftD += kernelApproximationDiagonal[centerNext];
            leftV += 2.0 * addRow(reducedObservationMatrix, centerNext, leftW) + kernelApproximationDiagonal[centerNext];
            leftC = leftD - leftV * windowSizeReciprocal;

            //right window: drop centerNext, add endNext
            rightD -= kernelApproximationDiagonal[centerNext];
            rightV += -2.0 * subtractRow(reducedObservationMatrix, centerNext, rightW) + kernelApproximationDiagonal[centerNext];
            rightD += kernelApproximationDiagonal[endNext];
            rightV += 2.0 * addRow(reducedObservationMatrix, endNext, rightW) + kernelApproximationDiagonal[endNext];
            rightC = rightD - rightV * windowSizeReciprocal;

            //union (2 * windowSize points): drop start, add endNext
            totalD -= kernelApproximationDiagonal[start];
            totalV += -2.0 * subtractRow(reducedObservationMatrix, start, totalW) + kernelApproximationDiagonal[start];
            totalD += kernelApproximationDiagonal[endNext];
            totalV += 2.0 * addRow(reducedObservationMatrix, endNext, totalW) + kernelApproximationDiagonal[endNext];
            totalC = totalD - 0.5 * totalV * windowSizeReciprocal;

            windowCosts[centerNext] = leftC + rightC - totalC;
            start = (start + 1) % N;
            end = endNext;
        }
        return windowCosts;
    }

    /**
     * Adds row {@code row} of {@code matrix} to {@code W} in place.
     *
     * @return the dot product of that row with {@code W} as it was <i>before</i> the addition
     */
    private static double addRow(final RealMatrix matrix, final int row, final double[] W) {
        double ZdotW = 0.0;
        for (int j = 0; j < W.length; j++) {
            final double z = matrix.getEntry(row, j);
            ZdotW += z * W[j];
            W[j] += z;
        }
        return ZdotW;
    }

    /**
     * Subtracts row {@code row} of {@code matrix} from {@code W} in place.
     *
     * @return the dot product of that row with {@code W} as it was <i>before</i> the subtraction
     */
    private static double subtractRow(final RealMatrix matrix, final int row, final double[] W) {
        double ZdotW = 0.0;
        for (int j = 0; j < W.length; j++) {
            final double z = matrix.getEntry(row, j);
            ZdotW += z * W[j];
            W[j] -= z;
        }
        return ZdotW;
    }

    /**
     * Running quantities for a segment of {@code n} points with reduced observations {@code z_i}:
     * <ul>
     *     <li>{@code D}: sum of {@code ||z_i||^2} (the approximate kernel diagonal);</li>
     *     <li>{@code W}: sum of the {@code z_i};</li>
     *     <li>{@code V}: {@code ||W||^2};</li>
     *     <li>{@code C = D - V / n}: the segment cost, equal to the sum of squared distances of the
     *     {@code z_i} from their mean.</li>
     * </ul>
     */
    private static final class Cost {
        private final double D;
        private final double[] W;
        private final double V;
        private final double C;

        private Cost(final double D, final double[] W, final double V, final double C) {
            this.D = D;
            this.W = W;
            this.V = V;
            this.C = C;
        }
    }

    /** A segment spanning {@code [start, end]} (inclusive) together with its cost. */
    private static final class Segment {
        private final int start;
        private final int end;
        private final double cost;

        private Segment(final int start, final int end, final double cost) {
            this.start = start;
            this.end = end;
            this.cost = cost;
        }

        private Segment(final int start, final int end,
                        final RealMatrix reducedObservationMatrix, final double[] kernelApproximationDiagonal) {
            this(start, end, calculateSegmentCost(start, end, reducedObservationMatrix, kernelApproximationDiagonal).C);
        }
    }

    /** Order in which {@link #findChangepoints} returns changepoints. */
    public enum ChangepointSortOrder {
        /** Order produced by backward selection: the changepoint that survived merging longest comes first. */
        BACKWARD_SELECTION,
        /** Ascending order of changepoint index. */
        INDEX
    }
}

