/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.hellbender.tools.walkers.vqsr;

import htsjdk.samtools.SAMSequenceDictionary;
import htsjdk.variant.variantcontext.Allele;
import htsjdk.variant.variantcontext.VariantContext;
import htsjdk.variant.variantcontext.VariantContextBuilder;
import htsjdk.variant.variantcontext.writer.VariantContextWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.commons.lang.ArrayUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.broadinstitute.hellbender.engine.FeatureContext;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.tools.walkers.vqsr.TrainingSet;
import org.broadinstitute.hellbender.tools.walkers.vqsr.VariantDatum;
import org.broadinstitute.hellbender.tools.walkers.vqsr.VariantRecalibratorArgumentCollection;
import org.broadinstitute.hellbender.utils.MathUtils;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.variant.GATKVariantContextUtils;

/**
 * Holds the {@link VariantDatum} records used during variant recalibration together with the
 * per-annotation normalization statistics.
 *
 * <p>The class decodes annotation values from a {@link VariantContext}, flags each datum against
 * the registered {@link TrainingSet}s, normalizes and reorders the annotation vectors, selects
 * training / anti-training / evaluation subsets and writes the recalibration table.
 *
 * <p>All randomized steps (jitter, imputation of missing annotations, shuffling) draw from
 * {@link Utils#getRandomGenerator()}, so the order of those draws is part of the observable behavior.
 */
public class VariantDataManager {
    private List<VariantDatum> data = Collections.emptyList();
    private double[] meanVector;
    // NOTE: despite its name this vector stores standard deviations (see normalizeData / setNormalization).
    private double[] varianceVector;
    public List<String> annotationKeys;
    private final VariantRecalibratorArgumentCollection VRAC;
    protected static final Logger logger = LogManager.getLogger(VariantDataManager.class);
    protected final List<TrainingSet> trainingSets;
    /** Offset that keeps the logit transform of MQ finite at 0 and at the MQ cap. */
    private static final double SAFETY_OFFSET = 0.01;
    /** Tolerance used when testing whether an annotation sits on a value that gets jittered. */
    private static final double PRECISION = 0.01;
    /** Value of SOR / AS_SOR around which jitter is applied. */
    private static final double LOG_OF_TWO = 0.6931472;
    /** Standard deviation of the Gaussian jitter added to the fixed-scale annotations. */
    private static final double JITTER_SCALE = 0.01;
    /** Standard deviation of the Gaussian value substituted for a missing annotation after normalization. */
    private static final double NULL_IMPUTATION_SCALE = 0.1;
    /** Annotations whose training standard deviation is below this are reported as zero-variance. */
    private static final double MIN_STANDARD_DEVIATION = 1.0E-5;

    /**
     * Creates a manager for the given annotations.
     *
     * @param annotationKeys annotation names to use; duplicates are dropped (first occurrence wins) with a warning
     * @param VRAC           argument collection supplying thresholds and jitter settings
     */
    public VariantDataManager(List<String> annotationKeys, VariantRecalibratorArgumentCollection VRAC) {
        List<String> uniqueAnnotations = annotationKeys.stream().distinct().collect(Collectors.toList());
        if (annotationKeys.size() != uniqueAnnotations.size()) {
            // log4j2 substitutes "{}" placeholders; a printf-style "%s" would be logged verbatim.
            logger.warn("Ignoring duplicate annotations for recalibration {}.", Utils.getDuplicatedItems(annotationKeys));
        }
        this.annotationKeys = new ArrayList<>(uniqueAnnotations);
        this.VRAC = VRAC;
        this.meanVector = new double[this.annotationKeys.size()];
        this.varianceVector = new double[this.annotationKeys.size()];
        this.trainingSets = new ArrayList<>();
    }

    /**
     * Replaces the managed data. The list is stored by reference, not copied.
     *
     * @param data the data to manage
     */
    public void setData(List<VariantDatum> data) {
        this.data = data;
    }

    /**
     * Sets the normalization statistics explicitly instead of computing them from the data.
     *
     * @param anMeans   mean per annotation name
     * @param anStdDevs standard deviation per annotation name
     * @throws IllegalArgumentException if either map lacks an entry for one of the annotation keys
     */
    public void setNormalization(Map<String, Double> anMeans, Map<String, Double> anStdDevs) {
        for (int i = 0; i < annotationKeys.size(); ++i) {
            String key = annotationKeys.get(i);
            Double mean = anMeans.get(key);
            Double stdDev = anStdDevs.get(key);
            if (mean == null || stdDev == null) {
                // Fail with the annotation name rather than an anonymous unboxing NullPointerException.
                throw new IllegalArgumentException("No mean and/or standard deviation supplied for annotation " + key);
            }
            meanVector[i] = mean;
            varianceVector[i] = stdDev;
        }
    }

    /**
     * @return the managed data list itself (not a copy)
     */
    public List<VariantDatum> getData() {
        return data;
    }

    /**
     * Normalizes every annotation of every datum in place, flags data exceeding the configured
     * standard-deviation threshold and reorders the annotation dimensions.
     *
     * <p>Non-missing values become {@code (value - mean) / std}; missing values are replaced by a
     * small Gaussian draw. Afterwards the annotation keys, the statistics vectors and each datum's
     * {@code annotations} / {@code isNull} arrays are permuted according to the order.
     *
     * @param calculateMeans if {@code true}, mean and standard deviation are computed from the training
     *                       data and stored; otherwise the previously stored statistics are used
     * @param theOrder       permutation to apply (element {@code i} is the old index of new position {@code i}),
     *                       or {@code null} to derive it with {@link #calculateSortOrder(double[])}
     * @throws UserException.BadInput   if an annotation has no value at any training site, or has (near) zero variance
     * @throws IllegalArgumentException if {@code theOrder} does not contain one entry per annotation
     */
    public void normalizeData(boolean calculateMeans, List<Integer> theOrder) {
        if (theOrder != null && theOrder.size() != meanVector.length) {
            throw new IllegalArgumentException("Annotation order has " + theOrder.size() + " entries but there are " + meanVector.length + " annotations.");
        }

        boolean foundZeroVarianceAnnotation = false;
        for (int iii = 0; iii < meanVector.length; ++iii) {
            final double theMean;
            final double theSTD;
            if (calculateMeans) {
                theMean = mean(iii, true);
                theSTD = standardDeviation(theMean, iii, true);
                if (Double.isNaN(theMean)) {
                    throw new UserException.BadInput("Values for " + annotationKeys.get(iii) + " annotation not detected for ANY training variant in the input callset. VariantAnnotator may be used to add these annotations.");
                }
                foundZeroVarianceAnnotation = foundZeroVarianceAnnotation || theSTD < MIN_STANDARD_DEVIATION;
                meanVector[iii] = theMean;
                varianceVector[iii] = theSTD;
            } else {
                theMean = meanVector[iii];
                theSTD = varianceVector[iii];
            }
            logger.info(annotationKeys.get(iii) + String.format(": \t mean = %.2f\t standard deviation = %.2f", theMean, theSTD));
            for (VariantDatum datum : data) {
                // The random generator is consulted only for missing values, in data order.
                datum.annotations[iii] = datum.isNull[iii]
                        ? NULL_IMPUTATION_SCALE * Utils.getRandomGenerator().nextGaussian()
                        : (datum.annotations[iii] - theMean) / theSTD;
            }
        }
        if (foundZeroVarianceAnnotation) {
            throw new UserException.BadInput("Found annotations with zero variance. They must be excluded before proceeding.");
        }

        // A datum fails as soon as any normalized annotation is further than the threshold from zero.
        for (VariantDatum datum : data) {
            boolean outsideThreshold = false;
            for (double val : datum.annotations) {
                if (Math.abs(val) > VRAC.STD_THRESHOLD) {
                    outsideThreshold = true;
                    break;
                }
            }
            datum.failingSTDThreshold = outsideThreshold;
        }

        final List<Integer> order = theOrder != null ? theOrder : calculateSortOrder(meanVector);
        annotationKeys = reorderList(annotationKeys, order);
        varianceVector = reorder(varianceVector, order);
        meanVector = reorder(meanVector, order);
        for (VariantDatum datum : data) {
            datum.annotations = reorder(datum.annotations, order);
            datum.isNull = reorder(datum.isNull, order);
        }
        logger.info("Annotation order is: " + annotationKeys.toString());
    }

    /**
     * @return the internal per-annotation mean array (not a copy)
     */
    public double[] getMeanVector() {
        return meanVector;
    }

    /**
     * @return the internal per-annotation standard-deviation array (not a copy)
     */
    public double[] getVarianceVector() {
        return varianceVector;
    }

    /**
     * Computes an ordering of annotation indices by decreasing absolute difference between
     * {@code inputVector[i]} and the current mean of annotation {@code i} over the non-training data.
     *
     * @param inputVector one reference value per annotation
     * @return annotation indices, largest difference first; ties keep their original relative order
     */
    protected List<Integer> calculateSortOrder(double[] inputVector) {
        List<MyDoubleForSorting> toBeSorted = new ArrayList<>(inputVector.length);
        for (int iii = 0; iii < inputVector.length; ++iii) {
            // Negated so that the natural (ascending) sort yields the largest difference first.
            toBeSorted.add(new MyDoubleForSorting(-1.0 * Math.abs(inputVector[iii] - mean(iii, false)), iii));
        }
        Collections.sort(toBeSorted);

        List<Integer> theOrder = new ArrayList<>(inputVector.length);
        for (MyDoubleForSorting d : toBeSorted) {
            theOrder.add(d.originalIndex);
        }
        return theOrder;
    }

    /** Returns a new array whose element {@code i} is {@code values[order.get(i)]}. */
    private static double[] reorder(double[] values, List<Integer> order) {
        double[] reordered = new double[order.size()];
        int position = 0;
        for (int index : order) {
            reordered[position++] = values[index];
        }
        return reordered;
    }

    /** Returns a new array whose element {@code i} is {@code flags[order.get(i)]}. */
    private static boolean[] reorder(boolean[] flags, List<Integer> order) {
        boolean[] reordered = new boolean[order.size()];
        int position = 0;
        for (int index : order) {
            reordered[position++] = flags[index];
        }
        return reordered;
    }

    /** Returns a new list whose element {@code i} is {@code data.get(order.get(i))}. */
    private <T> List<T> reorderList(List<T> data, List<Integer> order) {
        List<T> returnList = new ArrayList<>(data.size());
        for (int index : order) {
            returnList.add(data.get(index));
        }
        return returnList;
    }

    /**
     * Maps a normalized annotation value back to its original scale.
     *
     * @param normalizedValue value in normalized units
     * @param annI            annotation index (in the current annotation order)
     * @return {@code normalizedValue * std + mean} for that annotation
     */
    public double denormalizeDatum(double normalizedValue, int annI) {
        return normalizedValue * varianceVector[annI] + meanVector[annI];
    }

    /**
     * Registers a training set to be consulted by {@link #parseTrainingSets}.
     *
     * @param trainingSet the set to add
     */
    public void addTrainingSet(TrainingSet trainingSet) {
        trainingSets.add(trainingSet);
    }

    /**
     * @return the current annotation key list (not a copy); its order changes after {@link #normalizeData}
     */
    public List<String> getAnnotationKeys() {
        return annotationKeys;
    }

    /**
     * @return {@code true} if at least one registered set is marked as a training set
     */
    public boolean checkHasTrainingSet() {
        return trainingSets.stream().anyMatch(trainingSet -> trainingSet.isTraining);
    }

    /**
     * @return {@code true} if at least one registered set is marked as a truth set
     */
    public boolean checkHasTruthSet() {
        return trainingSets.stream().anyMatch(trainingSet -> trainingSet.isTruth);
    }

    /**
     * Collects the data at training sites that pass the standard-deviation threshold.
     *
     * <p>If more than {@code MAX_NUM_TRAINING_DATA} qualify, the collected list is shuffled and a
     * view of its first {@code MAX_NUM_TRAINING_DATA} elements is returned.
     *
     * @return the training data
     */
    public List<VariantDatum> getTrainingData() {
        List<VariantDatum> trainingData = new ArrayList<>();
        for (VariantDatum datum : data) {
            if (datum.atTrainingSite && !datum.failingSTDThreshold) {
                trainingData.add(datum);
            } else if (datum.failingSTDThreshold && VRAC.debugStdevThresholding) {
                logger.warn("Datum at " + datum.loc + " with ref " + datum.referenceAllele + " and alt " + datum.alternateAllele + " failing std thresholding: " + Arrays.toString(datum.annotations));
            }
        }
        logger.info("Training with " + trainingData.size() + " variants after standard deviation thresholding.");
        if (trainingData.size() < VRAC.MIN_NUM_BAD_VARIANTS) {
            logger.warn("WARNING: Training with very few variant sites! Please check the model reporting PDF to ensure the quality of the model is reliable.");
        } else if (trainingData.size() > VRAC.MAX_NUM_TRAINING_DATA) {
            logger.warn("WARNING: Very large training set detected. Downsampling to " + VRAC.MAX_NUM_TRAINING_DATA + " training variants.");
            Collections.shuffle(trainingData, Utils.getRandomGenerator());
            return trainingData.subList(0, VRAC.MAX_NUM_TRAINING_DATA);
        }
        return trainingData;
    }

    /**
     * Selects the data with a finite LOD strictly below {@code BAD_LOD_CUTOFF} that pass the
     * standard-deviation threshold, and marks each of them as an anti-training site.
     *
     * @return the selected data
     */
    public List<VariantDatum> selectWorstVariants() {
        List<VariantDatum> trainingData = new ArrayList<>();
        for (VariantDatum datum : data) {
            boolean selected = datum != null
                    && !datum.failingSTDThreshold
                    && !Double.isInfinite(datum.lod)
                    && datum.lod < VRAC.BAD_LOD_CUTOFF; // false for NaN, so NaN LODs are never selected
            if (selected) {
                datum.atAntiTrainingSite = true;
                trainingData.add(datum);
            }
        }
        logger.info("Selected worst " + trainingData.size() + " scoring variants --> variants with LOD <= " + String.format("%.4f", VRAC.BAD_LOD_CUTOFF) + ".");
        return trainingData;
    }

    /**
     * @return the data that pass the standard-deviation threshold and are at neither a training nor
     *         an anti-training site
     */
    public List<VariantDatum> getEvaluationData() {
        List<VariantDatum> evaluationData = new ArrayList<>();
        for (VariantDatum datum : data) {
            if (datum != null && !datum.failingSTDThreshold && !datum.atTrainingSite && !datum.atAntiTrainingSite) {
                evaluationData.add(datum);
            }
        }
        return evaluationData;
    }

    /**
     * Removes every datum flagged {@code isAggregate} from the managed list.
     */
    public void dropAggregateData() {
        data.removeIf(datum -> datum.isAggregate);
    }

    /**
     * Draws up to {@code numToAdd} data from each of the three lists and returns them in random order.
     *
     * <p>Side effect: the three input lists are shuffled in place.
     *
     * @param numToAdd         maximum number of data taken from each list
     * @param trainingData     training data (shuffled in place)
     * @param antiTrainingData anti-training data (shuffled in place)
     * @param evaluationData   evaluation data (shuffled in place)
     * @return a new shuffled list holding the sampled data
     */
    public List<VariantDatum> getRandomDataForPlotting(int numToAdd, List<VariantDatum> trainingData, List<VariantDatum> antiTrainingData, List<VariantDatum> evaluationData) {
        // Shuffle all three inputs first, then sample, so the sequence of random draws is fixed.
        Collections.shuffle(trainingData, Utils.getRandomGenerator());
        Collections.shuffle(antiTrainingData, Utils.getRandomGenerator());
        Collections.shuffle(evaluationData, Utils.getRandomGenerator());

        List<VariantDatum> returnData = new ArrayList<>();
        returnData.addAll(trainingData.subList(0, Math.min(numToAdd, trainingData.size())));
        returnData.addAll(antiTrainingData.subList(0, Math.min(numToAdd, antiTrainingData.size())));
        returnData.addAll(evaluationData.subList(0, Math.min(numToAdd, evaluationData.size())));
        Collections.shuffle(returnData, Utils.getRandomGenerator());
        return returnData;
    }

    /**
     * Mean of one annotation over the non-missing data whose {@code atTrainingSite} flag equals
     * {@code trainingData}.
     *
     * @param index        annotation index
     * @param trainingData which value of {@code atTrainingSite} to include
     * @return the mean, or {@code NaN} when no datum qualifies
     */
    protected double mean(int index, boolean trainingData) {
        double sum = 0.0;
        int numNonNull = 0;
        for (VariantDatum datum : data) {
            if (trainingData == datum.atTrainingSite && !datum.isNull[index]) {
                sum += datum.annotations[index];
                ++numNonNull;
            }
        }
        return sum / (double) numNonNull;
    }

    /**
     * Population standard deviation (divides by the count, not count - 1) of one annotation around
     * {@code mean}, over the same subset of data as {@link #mean(int, boolean)}.
     *
     * @param mean         the mean to measure deviations from
     * @param index        annotation index
     * @param trainingData which value of {@code atTrainingSite} to include
     * @return the standard deviation, or {@code NaN} when no datum qualifies
     */
    protected double standardDeviation(double mean, int index, boolean trainingData) {
        double sumOfSquares = 0.0;
        int numNonNull = 0;
        for (VariantDatum datum : data) {
            if (trainingData == datum.atTrainingSite && !datum.isNull[index]) {
                double deviation = datum.annotations[index] - mean;
                sumOfSquares += deviation * deviation;
                ++numNonNull;
            }
        }
        return Math.sqrt(sumOfSquares / (double) numNonNull);
    }

    /**
     * Decodes every configured annotation of {@code vc} into fresh arrays and stores them on the datum.
     * An annotation that decodes to {@code NaN} is marked as missing in {@code datum.isNull}.
     *
     * @param datum  datum to populate ({@code annotations} and {@code isNull} are replaced)
     * @param vc     variant context to read the annotations from
     * @param jitter whether to apply jitter to the annotations that support it
     */
    public void decodeAnnotations(VariantDatum datum, VariantContext vc, boolean jitter) {
        final int numAnnotations = annotationKeys.size();
        double[] annotations = new double[numAnnotations];
        boolean[] isNull = new boolean[numAnnotations];
        for (int iii = 0; iii < numAnnotations; ++iii) {
            annotations[iii] = decodeAnnotation(annotationKeys.get(iii), vc, jitter, VRAC, datum);
            isNull[iii] = Double.isNaN(annotations[iii]);
        }
        datum.annotations = annotations;
        datum.isNull = isNull;
    }

    /** Computes {@code ln((x - xmin) / (xmax - x))}. */
    private static double logitTransform(double x, double xmin, double xmax) {
        return Math.log((x - xmin) / (xmax - x));
    }

    /**
     * Returns {@code value} plus Gaussian noise of the given scale when {@code value} is within
     * {@link #PRECISION} of {@code target} (as judged by {@code MathUtils.compareDoubles}); otherwise
     * returns {@code value} unchanged without consuming a random draw.
     */
    private static double jitterIfNear(double value, double target, double scale) {
        if (MathUtils.compareDoubles(value, target, PRECISION) == 0) {
            return value + scale * Utils.getRandomGenerator().nextGaussian();
        }
        return value;
    }

    /**
     * Decodes one annotation value.
     *
     * <p>With allele-specific annotations enabled and an {@code AS_}-prefixed key, the value is taken
     * from the per-alternate-allele list at the position of the datum's alternate allele; otherwise
     * the attribute is read as a single double. Infinite, unparseable or absent values yield {@code NaN}.
     * When {@code jitter} is set, selected annotations receive Gaussian noise; note that the logit
     * transform of MQ is applied only on that jitter path.
     *
     * @throws IllegalStateException if an allele-specific value is requested but the datum's alternate
     *                               allele is not present in {@code vc}
     */
    private static double decodeAnnotation(String annotationKey, VariantContext vc, boolean jitter, VariantRecalibratorArgumentCollection vrac, VariantDatum datum) {
        double value;
        try {
            if (vrac.useASannotations && annotationKey.startsWith("AS_")) {
                List<?> valueList = vc.getAttributeAsList(annotationKey);
                if (!vc.hasAllele(datum.alternateAllele)) {
                    throw new IllegalStateException("VariantDatum allele " + datum.alternateAllele + " is not contained in the input VariantContext.");
                }
                // -1 converts the index among all alleles (reference first) to an index among alternates.
                int altIndex = vc.getAlleleIndex(datum.alternateAllele) - 1;
                if (altIndex < 0 || altIndex >= valueList.size()) {
                    // No value recorded for this allele: treat as missing instead of failing on the index.
                    return Double.NaN;
                }
                // String.valueOf tolerates elements that are not Strings (e.g. already-parsed numbers).
                value = Double.parseDouble(String.valueOf(valueList.get(altIndex)));
            } else {
                value = vc.getAttributeAsDouble(annotationKey, Double.NaN);
            }
        } catch (NumberFormatException e) {
            return Double.NaN;
        }

        if (Double.isInfinite(value)) {
            value = Double.NaN;
        }
        if (!jitter) {
            return value;
        }

        // The keys below are mutually exclusive, so at most one jitter rule fires per annotation.
        if (annotationKey.equalsIgnoreCase("HaplotypeScore")
                || annotationKey.equalsIgnoreCase("FS")
                || annotationKey.equalsIgnoreCase("AS_FilterStatus")
                || annotationKey.equalsIgnoreCase("InbreedingCoeff")) {
            value = jitterIfNear(value, 0.0, JITTER_SCALE);
        }
        if (annotationKey.equalsIgnoreCase("SOR") || annotationKey.equalsIgnoreCase("AS_SOR")) {
            value = jitterIfNear(value, LOG_OF_TWO, JITTER_SCALE);
        }
        if (annotationKey.equalsIgnoreCase("MQ")) {
            if (vrac.MQ_CAP > 0) {
                final double upperBound = (double) vrac.MQ_CAP + SAFETY_OFFSET;
                value = logitTransform(value, -SAFETY_OFFSET, upperBound);
                final double transformedCap = logitTransform(vrac.MQ_CAP, -SAFETY_OFFSET, upperBound);
                value = jitterIfNear(value, transformedCap, vrac.MQ_JITTER);
            } else {
                value = jitterIfNear(value, vrac.MQ_CAP, vrac.MQ_JITTER);
            }
        }
        if (annotationKey.equalsIgnoreCase("AS_MQ")) {
            // AS_MQ is always jittered, regardless of its value.
            value += vrac.MQ_JITTER * Utils.getRandomGenerator().nextGaussian();
        }
        return value;
    }

    /**
     * Resets the datum's training-related flags and prior, then sets them from the variants that
     * the registered training sets provide at the start of the current interval.
     *
     * <p>With allele-specific annotations enabled, training variants whose alleles do not match the
     * datum are ignored. {@code null} training variants are skipped.
     *
     * @param featureContext        source of training-set variants
     * @param evalVC                the variant being evaluated
     * @param datum                 datum to update
     * @param TRUST_ALL_POLYMORPHIC if {@code true}, training variants are not required to be
     *                              polymorphic in their samples
     */
    public void parseTrainingSets(FeatureContext featureContext, VariantContext evalVC, VariantDatum datum, boolean TRUST_ALL_POLYMORPHIC) {
        datum.isKnown = false;
        datum.atTruthSite = false;
        datum.atTrainingSite = false;
        datum.atAntiTrainingSite = false;
        datum.prior = 2.0;
        for (TrainingSet trainingSet : trainingSets) {
            List<VariantContext> vcs = featureContext.getValues(trainingSet.variantSource, featureContext.getInterval().getStart());
            for (VariantContext trainVC : vcs) {
                // Check for null first: the allele match below dereferences the training variant.
                if (trainVC == null) {
                    continue;
                }
                if (VRAC.useASannotations && !doAllelesMatch(trainVC, datum)) {
                    continue;
                }
                if (isValidVariant(evalVC, trainVC, TRUST_ALL_POLYMORPHIC)) {
                    datum.isKnown = datum.isKnown || trainingSet.isKnown;
                    datum.atTruthSite = datum.atTruthSite || trainingSet.isTruth;
                    datum.atTrainingSite = datum.atTrainingSite || trainingSet.isTraining;
                    datum.prior = Math.max(datum.prior, trainingSet.prior);
                }
                // Anti-training membership does not require the training variant to be "valid".
                datum.atAntiTrainingSite = datum.atAntiTrainingSite || trainingSet.isAntiTraining;
            }
        }
    }

    /**
     * A training variant counts when it is non-null, unfiltered, variant, of the same variation
     * class as {@code evalVC}, and (unless trusted, or lacking genotypes) polymorphic in its samples.
     */
    private boolean isValidVariant(VariantContext evalVC, VariantContext trainVC, boolean TRUST_ALL_POLYMORPHIC) {
        if (trainVC == null || !trainVC.isNotFiltered() || !trainVC.isVariant()) {
            return false;
        }
        if (!checkVariationClass(evalVC, trainVC)) {
            return false;
        }
        return TRUST_ALL_POLYMORPHIC || !trainVC.hasGenotypes() || trainVC.isPolymorphicInSamples();
    }

    /**
     * Tests whether the datum's ref/alt allele pair is among the training variant's alleles.
     * A datum without an alternate allele always matches.
     *
     * @throws IllegalStateException rethrown with the training variant's position when the allele
     *                               comparison itself throws {@code IllegalStateException}
     */
    private boolean doAllelesMatch(VariantContext trainVC, VariantDatum datum) {
        if (datum.alternateAllele == null) {
            return true;
        }
        try {
            return GATKVariantContextUtils.isAlleleInList(datum.referenceAllele, datum.alternateAllele, trainVC.getReference(), trainVC.getAlternateAlleles());
        } catch (IllegalStateException e) {
            throw new IllegalStateException("Reference allele mismatch at position " + trainVC.getContig() + ":" + trainVC.getStart() + " : ", e);
        }
    }

    /**
     * Tests whether {@code evalVC} belongs to the recalibration mode implied by the type of
     * {@code trainVC}: SNP mode for SNP/MNP, INDEL mode for INDEL/MIXED/SYMBOLIC.
     *
     * @return {@code false} for any other training-variant type
     */
    protected static boolean checkVariationClass(VariantContext evalVC, VariantContext trainVC) {
        switch (trainVC.getType()) {
            case SNP:
            case MNP:
                return checkVariationClass(evalVC, VariantRecalibratorArgumentCollection.Mode.SNP);
            case INDEL:
            case MIXED:
            case SYMBOLIC:
                return checkVariationClass(evalVC, VariantRecalibratorArgumentCollection.Mode.INDEL);
            default:
                return false;
        }
    }

    /**
     * Tests whether {@code evalVC} belongs to the given recalibration mode.
     *
     * @throws IllegalStateException for a mode other than SNP, INDEL or BOTH
     */
    protected static boolean checkVariationClass(VariantContext evalVC, VariantRecalibratorArgumentCollection.Mode mode) {
        switch (mode) {
            case SNP:
                return evalVC.isSNP() || evalVC.isMNP();
            case INDEL:
                return evalVC.isStructuralIndel() || evalVC.isIndel() || evalVC.isMixed() || evalVC.isSymbolic();
            case BOTH:
                return true;
            default:
                throw new IllegalStateException("Encountered unknown recal mode: " + mode);
        }
    }

    /**
     * Allele-level variant of the mode check: an allele is SNP-like when it has the same length as
     * the reference allele, and INDEL-like when the lengths differ or the allele is symbolic.
     *
     * @throws IllegalStateException for a mode other than SNP, INDEL or BOTH
     */
    protected static boolean checkVariationClass(VariantContext evalVC, Allele allele, VariantRecalibratorArgumentCollection.Mode mode) {
        switch (mode) {
            case SNP:
                return evalVC.getReference().length() == allele.length();
            case INDEL:
                return evalVC.getReference().length() != allele.length() || allele.isSymbolic();
            case BOTH:
                return true;
            default:
                throw new IllegalStateException("Encountered unknown recal mode: " + mode);
        }
    }

    /**
     * Sorts the managed data in place with the dictionary-based comparator and writes one record per
     * datum carrying END, VQSLOD, culprit and the positive / negative training-site flags.
     *
     * <p>Records use the datum's own ref/alt alleles when allele-specific annotations are enabled,
     * and a placeholder {@code N} / {@code <VQSR>} allele pair otherwise.
     *
     * @param recalWriter   destination writer
     * @param seqDictionary sequence dictionary defining the sort order
     */
    public void writeOutRecalibrationTable(VariantContextWriter recalWriter, SAMSequenceDictionary seqDictionary) {
        Collections.sort(data, VariantDatum.getComparator(seqDictionary));

        final List<Allele> placeholderAlleles = Arrays.asList(Allele.create("N", true), Allele.create("<VQSR>", false));
        for (VariantDatum datum : data) {
            List<Allele> alleles = VRAC.useASannotations
                    ? Arrays.asList(datum.referenceAllele, datum.alternateAllele)
                    : placeholderAlleles;
            VariantContextBuilder builder = new VariantContextBuilder("VQSR", datum.loc.getContig(), (long) datum.loc.getStart(), (long) datum.loc.getEnd(), alleles);
            builder.attribute("END", datum.loc.getEnd());
            builder.attribute("VQSLOD", String.format("%.4f", datum.lod));
            // worstAnnotation == -1 means no culprit annotation was recorded for this datum.
            builder.attribute("culprit", datum.worstAnnotation != -1 ? annotationKeys.get(datum.worstAnnotation) : "NULL");
            if (datum.atTrainingSite) {
                builder.attribute("POSITIVE_TRAIN_SITE", true);
            }
            if (datum.atAntiTrainingSite) {
                builder.attribute("NEGATIVE_TRAIN_SITE", true);
            }
            recalWriter.add(builder.make());
        }
    }

    /**
     * Pairs a sort key with the index it came from so the permutation can be recovered after sorting.
     * Static because it needs nothing from the enclosing instance.
     */
    private static class MyDoubleForSorting implements Comparable<MyDoubleForSorting> {
        final double myData;
        final int originalIndex;

        public MyDoubleForSorting(double myData, int originalIndex) {
            this.myData = myData;
            this.originalIndex = originalIndex;
        }

        @Override
        public int compareTo(MyDoubleForSorting other) {
            // Double.compare gives the same total order as Double#compareTo (NaN sorts last).
            return Double.compare(myData, other.myData);
        }
    }
}

