/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.hellbender.tools.walkers.genotyper;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.primitives.Doubles;
import com.google.common.primitives.Ints;
import htsjdk.variant.variantcontext.Allele;
import htsjdk.variant.variantcontext.Genotype;
import htsjdk.variant.variantcontext.GenotypeBuilder;
import htsjdk.variant.variantcontext.GenotypeLikelihoods;
import htsjdk.variant.variantcontext.GenotypesContext;
import htsjdk.variant.variantcontext.VariantContext;
import htsjdk.variant.variantcontext.VariantContextBuilder;
import htsjdk.variant.variantcontext.VariantContextUtils;
import htsjdk.variant.vcf.VCFFormatHeaderLine;
import htsjdk.variant.vcf.VCFHeader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Set;
import java.util.TreeMap;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.broadinstitute.hellbender.exceptions.GATKException;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.tools.walkers.ReferenceConfidenceVariantContextMerger;
import org.broadinstitute.hellbender.tools.walkers.genotyper.GenotypeAlleleCounts;
import org.broadinstitute.hellbender.tools.walkers.genotyper.GenotypeAssignmentMethod;
import org.broadinstitute.hellbender.tools.walkers.genotyper.GenotypeLikelihoodCalculator;
import org.broadinstitute.hellbender.tools.walkers.genotyper.GenotypeLikelihoodCalculators;
import org.broadinstitute.hellbender.utils.MathUtils;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.genotyper.AlleleListPermutation;
import org.broadinstitute.hellbender.utils.genotyper.GenotypePriorCalculator;
import org.broadinstitute.hellbender.utils.genotyper.IndexedAlleleList;
import org.broadinstitute.hellbender.utils.variant.GATKVariantContextUtils;
import org.broadinstitute.hellbender.utils.variant.VariantContextGetters;

/**
 * Static helpers for restricting genotypes, INFO chromosome-count annotations and per-allele value lists to a
 * subset of a site's alleles.
 */
public final class AlleleSubsettingUtils {
    private static final int PL_INDEX_OF_HOM_REF = 0;
    public static final int NUM_OF_STRANDS = 2;
    private static final GenotypeLikelihoodCalculators GL_CALCS = new GenotypeLikelihoodCalculators();

    // Attribute keys read/written by this class (same literal values the code has always used).
    private static final String SAC_KEY = "SAC";
    private static final String PP_KEY = "PP";
    private static final String AC_KEY = "AC";
    private static final String AF_KEY = "AF";
    private static final String AN_KEY = "AN";
    private static final String AC_ORIG_KEY = "AC_Orig";
    private static final String AF_ORIG_KEY = "AF_Orig";
    private static final String AN_ORIG_KEY = "AN_Orig";

    private AlleleSubsettingUtils() {
    }

    /**
     * Builds new genotypes restricted to {@code allelesToKeep}.
     * <p>
     * For every input genotype:
     * <ul>
     *     <li>its likelihoods are subset to the kept alleles and rescaled, and GQ is recomputed from them;</li>
     *     <li>PP is removed;</li>
     *     <li>the genotype call is re-made via {@link GATKVariantContextUtils#makeGenotypeCall};</li>
     *     <li>SAC and AD are subset to the kept alleles, with the AD of a kept {@code <NON_REF>} allele zeroed.</li>
     * </ul>
     * If the likelihoods are missing, have an unexpected length, or are uninformative while {@code depth == 0},
     * PL and GQ are cleared instead.
     *
     * @param originalGs       genotypes to subset; must not be null
     * @param defaultPloidy    ploidy used for genotypes reporting a non-positive ploidy
     * @param originalAlleles  alleles that the input genotypes' PL/AD/SAC values are indexed by; must not be null
     * @param allelesToKeep    alleles to retain; must be non-empty and start with the reference allele
     * @param gpc              genotype prior calculator forwarded to {@code makeGenotypeCall}
     * @param assignmentMethod how the subset genotype calls are assigned
     * @param depth            when non-zero, subset likelihoods are kept even if they are uninformative
     * @return a new {@link GenotypesContext} holding one subset genotype per input genotype
     */
    public static GenotypesContext subsetAlleles(final GenotypesContext originalGs, final int defaultPloidy,
                                                 final List<Allele> originalAlleles, final List<Allele> allelesToKeep,
                                                 final GenotypePriorCalculator gpc,
                                                 final GenotypeAssignmentMethod assignmentMethod, final int depth) {
        Utils.nonNull(originalGs, "original GenotypesContext must not be null");
        Utils.nonNull(originalAlleles, "originalAlleles is null");
        Utils.nonNull(allelesToKeep, "allelesToKeep is null");
        Utils.nonEmpty(allelesToKeep, "must keep at least one allele");
        Utils.validateArg(allelesToKeep.get(0).isReference(), "First allele must be the reference allele");

        final GenotypesContext newGTs = GenotypesContext.create(originalGs.size());
        final AlleleListPermutation<Allele> allelePermutation =
                new IndexedAlleleList<>(originalAlleles).permutation(new IndexedAlleleList<>(allelesToKeep));
        // The allele lists are fixed for this call, so the PL index mapping only varies with ploidy: cache it.
        final HashMap<Integer, int[]> subsettedLikelihoodIndicesByPloidy = new HashMap<>();
        final int nonRefIndex = allelesToKeep.indexOf(Allele.NON_REF_ALLELE);

        for (final Genotype g : originalGs) {
            final int ploidy = g.getPloidy() > 0 ? g.getPloidy() : defaultPloidy;
            final int[] subsettedLikelihoodIndices = subsettedLikelihoodIndicesByPloidy.computeIfAbsent(
                    ploidy, p -> subsettedPLIndices(p, originalAlleles, allelesToKeep));
            final int expectedNumLikelihoods = GenotypeLikelihoods.numLikelihoods(originalAlleles.size(), ploidy);

            double[] newLikelihoods = null;
            double newLog10GQ = -1.0;
            if (g.hasLikelihoods()) {
                final double[] originalLikelihoods = g.getLikelihoods().getAsVector();
                // Likelihood vectors whose length doesn't match this ploidy/allele count can't be subset.
                if (originalLikelihoods.length == expectedNumLikelihoods) {
                    newLikelihoods = MathUtils.scaleLogSpaceArrayForNumericalStability(
                            Arrays.stream(subsettedLikelihoodIndices).mapToDouble(idx -> originalLikelihoods[idx]).toArray());
                    newLog10GQ = GenotypeLikelihoods.getGQLog10FromLikelihoods(
                            MathUtils.maxElementIndex(newLikelihoods), newLikelihoods);
                }
            }

            final boolean useNewLikelihoods = newLikelihoods != null
                    && (depth != 0 || GATKVariantContextUtils.isInformative(newLikelihoods));
            final GenotypeBuilder gb = new GenotypeBuilder(g);
            if (useNewLikelihoods) {
                final HashMap<String, Object> attributes = new HashMap<>(g.getExtendedAttributes());
                gb.PL(newLikelihoods).log10PError(newLog10GQ);
                // PP is dropped rather than subset (presumably it is indexed like PL over the original alleles).
                attributes.remove(PP_KEY);
                gb.noAttributes().attributes(attributes);
            } else {
                gb.noPL().noGQ();
            }

            // NOTE(review): this passes g.getPloidy() rather than the defaulted `ploidy` used to index the PLs
            // above; the two differ for genotypes with a non-positive ploidy — confirm this is intended.
            GATKVariantContextUtils.makeGenotypeCall(g.getPloidy(), gb, assignmentMethod, newLikelihoods,
                    allelesToKeep, g.getAlleles(), gpc);

            if (g.hasExtendedAttribute(SAC_KEY)) {
                gb.attribute(SAC_KEY, subsetSACAlleles(g, originalAlleles, allelesToKeep));
            }
            if (g.hasAD()) {
                final int[] oldAD = g.getAD();
                final int[] newAD = IntStream.range(0, allelesToKeep.size())
                        .map(n -> oldAD[allelePermutation.fromIndex(n)])
                        .toArray();
                if (nonRefIndex != -1 && nonRefIndex < newAD.length) {
                    // <NON_REF> reads are zeroed rather than carried over to the subset allele list.
                    newAD[nonRefIndex] = 0;
                }
                gb.AD(newAD);
            }
            newGTs.add(gb.make());
        }
        return newGTs;
    }

    /**
     * Subsets genotypes for a somatic merge using a precomputed allele index mapping.
     * <p>
     * Each genotype's called alleles that are not in {@code allelesToKeep} are replaced by no-calls. AD, when
     * present, is remapped with {@link #generateAD}. Every extended attribute is remapped according to the count
     * type declared for it in {@code outputHeader}.
     *
     * @param outputHeader    header supplying the FORMAT line (count type) for every genotype attribute
     * @param originalGs      genotypes to subset
     * @param allelesToKeep   alleles that may remain in the genotype calls
     * @param relevantIndices for each output allele, its index in the genotypes' original allele ordering
     * @return a new {@link GenotypesContext} with one subset genotype per input genotype
     * @throws GATKException if a genotype attribute has no FORMAT line in {@code outputHeader}
     */
    public static GenotypesContext subsetSomaticAlleles(final VCFHeader outputHeader, final GenotypesContext originalGs,
                                                        final List<Allele> allelesToKeep, final int[] relevantIndices) {
        final GenotypesContext newGTs = GenotypesContext.create(originalGs.size());
        for (final Genotype g : originalGs) {
            final GenotypeBuilder gb = new GenotypeBuilder(g);
            gb.noAttributes();

            // Keep the genotype's ploidy, but no-call any allele that is not being kept.
            final List<Allele> keepGTAlleles = new ArrayList<>(g.getAlleles());
            keepGTAlleles.replaceAll(a -> allelesToKeep.contains(a) ? a : Allele.NO_CALL);
            gb.alleles(keepGTAlleles);

            // Genotypes without AD previously crashed here (generateAD(null, ...)); leave AD absent instead.
            if (g.hasAD()) {
                gb.AD(generateAD(g.getAD(), relevantIndices));
            }

            final Set<String> keys = g.getExtendedAttributes().keySet();
            for (final String key : keys) {
                final VCFFormatHeaderLine headerLine = outputHeader.getFormatHeaderLine(key);
                if (headerLine == null) {
                    throw new GATKException("Output header has no FORMAT line for genotype attribute " + key
                            + "; cannot subset it");
                }
                gb.attribute(key, ReferenceConfidenceVariantContextMerger.generateAnnotationValueVector(
                        headerLine.getCountType(), VariantContextGetters.attributeToList(g.getAnyAttribute(key)),
                        relevantIndices));
            }
            newGTs.add(gb.make());
        }
        return newGTs;
    }

    /**
     * Recomputes chromosome counts (via {@link VariantContextUtils#calculateChromosomeCounts}) on {@code builder}.
     * Optionally, it first records the original AC/AF/AN of {@code vc} as AC_Orig/AF_Orig/AN_Orig.
     * <p>
     * When the builder has the same number of alleles as {@code vc}, AC/AF are copied verbatim. Otherwise only
     * the value belonging to the first of the builder's alleles that is an alt allele in {@code vc} is recorded.
     *
     * @param vc                    the original variant context
     * @param builder               builder holding the subset alleles (at least 2)
     * @param keepOriginalChrCounts whether to record the *_Orig attributes
     * @throws IllegalArgumentException if the builder holds fewer than 2 alleles
     */
    public static void addInfoFieldAnnotations(final VariantContext vc, final VariantContextBuilder builder,
                                               final boolean keepOriginalChrCounts) {
        Utils.nonNull(vc);
        Utils.nonNull(builder);
        Utils.nonNull(builder.getAlleles());
        final List<Allele> alleles = builder.getAlleles();
        if (alleles.size() < 2) {
            throw new IllegalArgumentException("the variant context builder must contain at least 2 alleles");
        }
        final boolean keepOriginal = vc.getAlleles().size() == alleles.size();
        final List<Integer> alleleIndices = alleles.stream()
                .map(a -> vc.getAlleleIndex(a))
                .collect(Collectors.toList());
        if (keepOriginalChrCounts) {
            if (vc.hasAttribute(AC_KEY)) {
                builder.attribute(AC_ORIG_KEY,
                        keepOriginal ? vc.getAttribute(AC_KEY) : firstKeptAltValue(vc, AC_KEY, alleleIndices));
            }
            if (vc.hasAttribute(AF_KEY)) {
                builder.attribute(AF_ORIG_KEY,
                        keepOriginal ? vc.getAttribute(AF_KEY) : firstKeptAltValue(vc, AF_KEY, alleleIndices));
            }
            if (vc.hasAttribute(AN_KEY)) {
                builder.attribute(AN_ORIG_KEY, vc.getAttribute(AN_KEY));
            }
        }
        VariantContextUtils.calculateChromosomeCounts(builder, true);
    }

    /**
     * Returns the {@code key} attribute value of the first index in {@code alleleIndices} that refers to an alt
     * allele of {@code vc} (index &gt; 0). Values are looked up at {@code index - 1}, i.e. the attribute is
     * assumed to hold one value per alt allele.
     *
     * @throws IndexOutOfBoundsException if no index is &gt; 0
     */
    private static Object firstKeptAltValue(final VariantContext vc, final String key,
                                            final List<Integer> alleleIndices) {
        final List<Object> values = vc.getAttributeAsList(key);
        return alleleIndices.stream()
                .filter(i -> i > 0)
                .map(i -> values.get(i - 1))
                .collect(Collectors.toList())
                .get(0);
    }

    /**
     * Returns the genotype's SAC values restricted to {@code allelesToUse}.
     * If no alleles are dropped, all of the SAC values are returned unchanged.
     */
    private static int[] subsetSACAlleles(final Genotype g, final List<Allele> originalAlleles,
                                          final List<Allele> allelesToUse) {
        return originalAlleles.size() == allelesToUse.size()
                ? getSACs(g)
                : makeNewSACs(g, originalAlleles, allelesToUse);
    }

    /**
     * Copies the {@link #NUM_OF_STRANDS} SAC values of every original allele that is in {@code allelesToUse}.
     * <p>
     * NOTE(review): output follows the order of {@code originalAlleles}, whereas AD in {@link #subsetAlleles}
     * follows the order of {@code allelesToKeep}. These only agree when the kept alleles preserve the original
     * relative order.
     */
    private static int[] makeNewSACs(final Genotype g, final List<Allele> originalAlleles,
                                     final List<Allele> allelesToUse) {
        final int[] oldSACs = getSACs(g);
        final int[] newSACs = new int[NUM_OF_STRANDS * allelesToUse.size()];
        int newIndex = 0;
        for (int oldIndex = 0; oldIndex < originalAlleles.size(); oldIndex++) {
            if (allelesToUse.contains(originalAlleles.get(oldIndex))) {
                System.arraycopy(oldSACs, NUM_OF_STRANDS * oldIndex, newSACs, NUM_OF_STRANDS * newIndex, NUM_OF_STRANDS);
                newIndex++;
            }
        }
        return newSACs;
    }

    /**
     * Reads the genotype's SAC attribute, accepting either a comma-separated string or an {@code int[]}.
     *
     * @return a fresh array of SAC values (never the genotype's own array)
     * @throws IllegalArgumentException if the genotype has no SAC attribute
     * @throws GATKException            if the SAC value is of any other type (including null)
     */
    private static int[] getSACs(final Genotype g) {
        if (!g.hasExtendedAttribute(SAC_KEY)) {
            throw new IllegalArgumentException("Genotype must have SAC");
        }
        final Object sacs = g.getExtendedAttributes().get(SAC_KEY);
        if (sacs instanceof String) {
            return Arrays.stream(((String) sacs).split(",")).mapToInt(Integer::parseInt).toArray();
        }
        if (sacs instanceof int[]) {
            // Copy so the rebuilt genotype does not share a mutable array with the original one.
            return ((int[]) sacs).clone();
        }
        throw new GATKException("Unexpected SAC type");
    }

    /**
     * Chooses the alleles to keep at a site: the reference, {@code <NON_REF>} (if present), and the
     * {@code numAltAllelesToKeep} proper alt alleles with the highest {@link #calculateLikelihoodSums} scores.
     * If the site has no more proper alts than that, its allele list is returned as is.
     *
     * @param vc                  the site
     * @param defaultPloidy       ploidy used for genotypes reporting a non-positive ploidy; must be &gt; 0
     * @param numAltAllelesToKeep maximum number of proper alt alleles to keep; must be &gt; 0
     * @return the alleles to keep, in their original order
     */
    public static List<Allele> calculateMostLikelyAlleles(final VariantContext vc, final int defaultPloidy,
                                                          final int numAltAllelesToKeep) {
        Utils.nonNull(vc, "vc is null");
        Utils.validateArg(defaultPloidy > 0, () -> "default ploidy must be > 0 but defaultPloidy=" + defaultPloidy);
        Utils.validateArg(numAltAllelesToKeep > 0, () -> "numAltAllelesToKeep must be > 0, but numAltAllelesToKeep=" + numAltAllelesToKeep);
        final boolean hasSymbolicNonRef = vc.hasAllele(Allele.NON_REF_ALLELE);
        // The reference (and <NON_REF>, if present) are not "proper" alt alleles.
        final int numberOfAllelesThatArentProperAlts = hasSymbolicNonRef ? 2 : 1;
        final int numberOfProperAltAlleles = vc.getNAlleles() - numberOfAllelesThatArentProperAlts;
        if (numAltAllelesToKeep >= numberOfProperAltAlleles) {
            return vc.getAlleles();
        }
        final double[] likelihoodSums = calculateLikelihoodSums(vc, defaultPloidy);
        return filterToMaxNumberOfAltAllelesBasedOnScores(numAltAllelesToKeep, vc.getAlleles(), likelihoodSums);
    }

    /**
     * Keeps the reference (index 0), {@code <NON_REF>} (if present), and the {@code numAltAllelesToKeep} other
     * alleles with the highest scores. Ties are broken in favor of the lower index.
     *
     * @param numAltAllelesToKeep maximum number of non-ref, non-{@code <NON_REF>} alleles to keep
     * @param alleles             alleles to filter; index 0 is always kept
     * @param likelihoodSums      score per allele, indexed like {@code alleles}
     * @return the kept alleles, in their original order
     */
    public static List<Allele> filterToMaxNumberOfAltAllelesBasedOnScores(final int numAltAllelesToKeep,
                                                                        final List<Allele> alleles,
                                                                        final double[] likelihoodSums) {
        final int nonRefAltAlleleIndex = alleles.indexOf(Allele.NON_REF_ALLELE);
        final int numAlleles = alleles.size();
        final Set<Integer> properAltIndexesToKeep = IntStream.range(1, numAlleles)
                .filter(n -> n != nonRefAltAlleleIndex)
                .boxed()
                .sorted(Comparator.comparingDouble((Integer n) -> likelihoodSums[n]).reversed())
                .limit(numAltAllelesToKeep)
                .collect(Collectors.toSet());
        return IntStream.range(0, numAlleles)
                .filter(i -> i == 0 || i == nonRefAltAlleleIndex || properAltIndexesToKeep.contains(i))
                .mapToObj(alleles::get)
                .collect(Collectors.toList());
    }

    /**
     * Scores each allele of {@code vc}. For every genotype with likelihoods, the difference between its best
     * likelihood and its hom-ref likelihood is added to every alt allele present in the most likely genotype.
     * Index 0 (the reference) is never incremented.
     *
     * @param vc            the site
     * @param defaultPloidy ploidy used for genotypes reporting a non-positive ploidy
     * @return one score per allele of {@code vc}
     */
    @VisibleForTesting
    static double[] calculateLikelihoodSums(final VariantContext vc, final int defaultPloidy) {
        final int numAlleles = vc.getNAlleles();
        final double[] likelihoodSums = new double[numAlleles];
        for (final Genotype genotype : vc.getGenotypes().iterateInSampleNameOrder()) {
            final GenotypeLikelihoods gls = genotype.getLikelihoods();
            if (gls == null) {
                continue;
            }
            final double[] glsVector = gls.getAsVector();
            final int indexOfMostLikelyGenotype = MathUtils.maxElementIndex(glsVector);
            final double glDiffBetweenRefAndBest = glsVector[indexOfMostLikelyGenotype] - glsVector[PL_INDEX_OF_HOM_REF];
            final int ploidy = genotype.getPloidy() > 0 ? genotype.getPloidy() : defaultPloidy;
            // Reuse the shared calculator tables instead of rebuilding them for every genotype.
            final int[] alleleCounts = GL_CALCS.getInstance(ploidy, numAlleles)
                    .genotypeAlleleCountsAt(indexOfMostLikelyGenotype)
                    .alleleCountsByIndex(numAlleles - 1);
            for (int allele = 1; allele < alleleCounts.length; allele++) {
                if (alleleCounts[allele] > 0) {
                    likelihoodSums[allele] += glDiffBetweenRefAndBest;
                }
            }
        }
        return likelihoodSums;
    }

    /**
     * Maps each PL index of the subset allele list to the PL index of the same genotype in the original list.
     *
     * @param ploidy          genotype ploidy
     * @param originalAlleles alleles the original PLs are indexed by
     * @param newAlleles      subset of {@code originalAlleles}
     * @return {@code result[newPLIndex] == oldPLIndex}, sized for {@code newAlleles} at this ploidy
     */
    public static int[] subsettedPLIndices(final int ploidy, final List<Allele> originalAlleles,
                                           final List<Allele> newAlleles) {
        final int[] result = new int[GenotypeLikelihoods.numLikelihoods(newAlleles.size(), ploidy)];
        final AlleleListPermutation<Allele> allelePermutation =
                new IndexedAlleleList<>(originalAlleles).permutation(new IndexedAlleleList<>(newAlleles));
        final GenotypeLikelihoodCalculator glCalc = GL_CALCS.getInstance(ploidy, originalAlleles.size());
        for (int oldPLIndex = 0; oldPLIndex < glCalc.genotypeCount(); oldPLIndex++) {
            final GenotypeAlleleCounts oldAlleleCounts = glCalc.genotypeAlleleCountsAt(oldPLIndex);
            final boolean containsOnlyNewAlleles = IntStream.range(0, oldAlleleCounts.distinctAlleleCount())
                    .map(oldAlleleCounts::alleleIndexAt)
                    .allMatch(allelePermutation::isKept);
            if (!containsOnlyNewAlleles) {
                continue;
            }
            // Flattened (new allele index, count) pairs describing the same genotype in the new allele ordering.
            final int[] newAlleleCounts = IntStream.range(0, newAlleles.size())
                    .flatMap(newAlleleIndex -> IntStream.of(newAlleleIndex,
                            oldAlleleCounts.alleleCountFor(allelePermutation.fromIndex(newAlleleIndex))))
                    .toArray();
            // NOTE(review): indexes via the calculator built for the ORIGINAL allele count; this relies on the
            // genotype index not depending on the total allele count — confirm against GenotypeLikelihoodCalculator.
            final int newPLIndex = glCalc.alleleCountsToIndex(newAlleleCounts);
            result[newPLIndex] = oldPLIndex;
        }
        return result;
    }

    /**
     * For each target allele, finds its index in {@code remappedAlleles}. Alleles that are not found map to the
     * index of {@code <NON_REF>}. Target index 0 always maps to 0.
     * <p>
     * A spanning deletion that occurs more than once in {@code remappedAlleles} needs special handling when the
     * genotype has PLs and this is not a somatic merge. In that case it maps to the spanning-deletion copy whose
     * hom-alt genotype has the lowest PL.
     *
     * @param remappedAlleles alleles of the input record; must contain {@code <NON_REF>}
     * @param targetAlleles   alleles of the output record
     * @param position        site position, used in error messages
     * @param g               genotype whose PLs pick among duplicate spanning deletions
     * @param doSomaticMerge  when true, duplicate spanning deletions are not disambiguated by PL
     * @return mapping from target allele index to remapped allele index
     * @throws UserException if {@code remappedAlleles} lacks {@code <NON_REF>}
     */
    public static int[] getIndexesOfRelevantAllelesForGVCF(final List<Allele> remappedAlleles,
                                                           final List<Allele> targetAlleles, final int position,
                                                           final Genotype g, final boolean doSomaticMerge) {
        Utils.nonEmpty(remappedAlleles);
        Utils.nonEmpty(targetAlleles);
        if (!remappedAlleles.contains(Allele.NON_REF_ALLELE)) {
            throw new UserException("The list of input alleles must contain " + Allele.NON_REF_ALLELE + " as an allele but that is not the case at position " + position + "; please use the Haplotype Caller with gVCF output to generate appropriate records");
        }
        final int indexOfNonRef = remappedAlleles.indexOf(Allele.NON_REF_ALLELE);
        final int[] indexMapping = new int[targetAlleles.size()];
        indexMapping[0] = 0;
        for (int i = 1; i < targetAlleles.size(); i++) {
            final Allele target = targetAlleles.get(i);
            final int mappedIndex;
            if (target == Allele.SPAN_DEL && !doSomaticMerge && hasMultipleSpanningDeletions(remappedAlleles, g)) {
                mappedIndex = indexOfBestDel(remappedAlleles, g.getPL(), g.getPloidy());
            } else {
                mappedIndex = remappedAlleles.indexOf(target);
            }
            indexMapping[i] = mappedIndex == -1 ? indexOfNonRef : mappedIndex;
        }
        return indexMapping;
    }

    /**
     * Like {@link #getIndexesOfRelevantAllelesForGVCF}, but with no {@code <NON_REF>} fallback. Every target
     * allele must be present in {@code remappedAlleles}.
     *
     * @throws IllegalArgumentException if a target allele cannot be found in {@code remappedAlleles}
     */
    public static int[] getIndexesOfRelevantAlleles(final List<Allele> remappedAlleles,
                                                    final List<Allele> targetAlleles, final int position,
                                                    final Genotype g) {
        Utils.nonEmpty(remappedAlleles);
        Utils.nonEmpty(targetAlleles);
        final int[] indexMapping = new int[targetAlleles.size()];
        indexMapping[0] = 0;
        for (int i = 1; i < targetAlleles.size(); i++) {
            final Allele target = targetAlleles.get(i);
            if (target == Allele.SPAN_DEL && hasMultipleSpanningDeletions(remappedAlleles, g)) {
                final int indexOfBestDel = indexOfBestDel(remappedAlleles, g.getPL(), g.getPloidy());
                if (indexOfBestDel == -1) {
                    throw new IllegalArgumentException("At position " + position + " targetAlleles contains a spanning deletion, but remappedAlleles does not.");
                }
                indexMapping[i] = indexOfBestDel;
            } else {
                final int indexOfRemappedAllele = remappedAlleles.indexOf(target);
                if (indexOfRemappedAllele == -1) {
                    throw new IllegalArgumentException("At position " + position + " targetAlleles contains a " + target + " allele, but remappedAlleles does not.");
                }
                indexMapping[i] = indexOfRemappedAllele;
            }
        }
        return indexMapping;
    }

    /** True when the genotype has PLs and {@code alleles} contains the spanning deletion more than once. */
    private static boolean hasMultipleSpanningDeletions(final List<Allele> alleles, final Genotype g) {
        return g.hasPL() && Collections.frequency(alleles, Allele.SPAN_DEL) > 1;
    }

    /**
     * Returns the index of the spanning-deletion allele (identity comparison) whose hom-alt genotype has the
     * lowest PL, or -1 if none is found.
     */
    private static int indexOfBestDel(final List<Allele> alleles, final int[] PLs, final int ploidy) {
        int bestIndex = -1;
        int bestPL = Integer.MAX_VALUE;
        for (int i = 0; i < alleles.size(); i++) {
            if (alleles.get(i) != Allele.SPAN_DEL) {
                continue;
            }
            final int homAltIndex = findHomIndex(GL_CALCS.getInstance(ploidy, alleles.size()), i, ploidy);
            final int pl = PLs[homAltIndex];
            if (pl < bestPL) {
                bestIndex = i;
                bestPL = pl;
            }
        }
        return bestIndex;
    }

    /**
     * Returns the PL index of the genotype made of {@code ploidy} copies of allele {@code i}. Ploidy 1 and 2
     * are computed directly; other ploidies use {@code calculator}.
     */
    private static int findHomIndex(final GenotypeLikelihoodCalculator calculator, final int i, final int ploidy) {
        if (ploidy == 2) {
            return GenotypeLikelihoods.calculatePLindex(i, i);
        }
        if (ploidy == 1) {
            return i;
        }
        final int[] alleleIndexes = new int[ploidy];
        Arrays.fill(alleleIndexes, i);
        return calculator.allelesToIndex(alleleIndexes);
    }

    /**
     * Remaps per-allele depths (one value per allele, reference included) to a new allele ordering. Indices
     * past the end of {@code originalAD} yield 0.
     */
    public static int[] generateAD(final int[] originalAD, final int[] indexesOfRelevantAlleles) {
        final List<Integer> adList = remapRLengthList(
                Arrays.stream(originalAD).boxed().collect(Collectors.toList()), indexesOfRelevantAlleles, 0);
        return Ints.toArray(adList);
    }

    /**
     * Remaps per-alt-allele values (no reference entry) to a new allele ordering. Indices past the end of
     * {@code originalAF} yield 0.0.
     */
    public static double[] generateAF(final double[] originalAF, final int[] indexesOfRelevantAlleles) {
        final List<Double> afList = remapALengthList(
                Arrays.stream(originalAF).boxed().collect(Collectors.toList()), indexesOfRelevantAlleles, 0.0);
        return Doubles.toArray(afList);
    }

    /**
     * Remaps a list holding one value per allele, reference included. The output holds one value per entry of
     * {@code indexesOfRelevantAlleles}, and indices past the end of {@code originalList} yield {@code filler}.
     */
    public static <T> List<T> remapRLengthList(final List<T> originalList, final int[] indexesOfRelevantAlleles,
                                               final T filler) {
        Utils.nonNull(originalList);
        Utils.nonNull(indexesOfRelevantAlleles);
        return remapList(originalList, indexesOfRelevantAlleles, 0, filler);
    }

    /**
     * Remaps a list holding one value per alt allele (no reference entry). Entry 0 of
     * {@code indexesOfRelevantAlleles} (the reference) is skipped. Indices past the end of {@code originalList}
     * yield {@code filler}.
     */
    public static <T> List<T> remapALengthList(final List<T> originalList, final int[] indexesOfRelevantAlleles,
                                               final T filler) {
        Utils.nonNull(originalList);
        Utils.nonNull(indexesOfRelevantAlleles);
        return remapList(originalList, indexesOfRelevantAlleles, 1, filler);
    }

    /**
     * Builds a new list by reading {@code originalList[index - offset]} for each
     * {@code index = indexesOfRelevantAlleles[i]}, with {@code i >= offset}. Indices that land past the end of
     * {@code originalList} produce {@code filler}.
     */
    private static <T> List<T> remapList(final List<T> originalList, final int[] indexesOfRelevantAlleles,
                                         final int offset, final T filler) {
        final int numValues = indexesOfRelevantAlleles.length - offset;
        // numValues is negative when the index array is shorter than the offset; the result is then empty.
        final List<T> newValues = new ArrayList<>(Math.max(numValues, 0));
        for (int i = offset; i < indexesOfRelevantAlleles.length; i++) {
            final int oldIndex = indexesOfRelevantAlleles[i];
            newValues.add(oldIndex >= originalList.size() + offset ? filler : originalList.get(oldIndex - offset));
        }
        return newValues;
    }
}

