/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.hellbender.tools.walkers.variantutils;

import com.google.common.annotations.VisibleForTesting;
import htsjdk.samtools.util.Locatable;
import htsjdk.variant.variantcontext.Allele;
import htsjdk.variant.variantcontext.Genotype;
import htsjdk.variant.variantcontext.GenotypeBuilder;
import htsjdk.variant.variantcontext.GenotypeLikelihoods;
import htsjdk.variant.variantcontext.GenotypesContext;
import htsjdk.variant.variantcontext.VariantContext;
import htsjdk.variant.variantcontext.VariantContextBuilder;
import htsjdk.variant.variantcontext.VariantContextUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.commons.collections.ListUtils;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.tools.walkers.ReferenceConfidenceVariantContextMerger;
import org.broadinstitute.hellbender.tools.walkers.genotyper.GenotypeAssignmentMethod;
import org.broadinstitute.hellbender.utils.MathUtils;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.variant.GATKVariantContextUtils;

/**
 * Static helpers for computing diploid genotype posterior probabilities.
 *
 * <p>Posteriors combine each sample's genotype likelihoods (or previously computed "PP" values) with a
 * Dirichlet-multinomial genotype prior. The per-allele pseudocounts of that prior come from allele counts in
 * supporting resource records and, optionally, from the input samples themselves.
 */
public final class PosteriorProbabilitiesUtils {
    /**
     * When there are no resource records, the input samples' own allele counts are used as a prior only if at
     * least this many samples are present, or if reference samples are substituted for missing resources.
     */
    private static final int MIN_SAMPLES_TO_USE_INPUTS = 10;

    private PosteriorProbabilitiesUtils() {
    }

    /**
     * Computes genotype posteriors for every sample in {@code vc1} and returns an annotated copy of it.
     *
     * @param vc1 the variant whose genotypes are refined; must not be null
     * @param resources supporting records; only those starting at the same position as {@code vc1} contribute
     *                  allele counts
     * @param numRefSamplesFromMissingResources when {@code resources} is empty, this many diploid hom-ref samples
     *                                          are added to the reference allele count
     * @param opts prior pseudocounts and flags controlling which allele counts are used
     * @return a new VariantContext. Each genotype with usable likelihoods is re-called from its posteriors and
     *         gets a "PP" attribute (Phred-scaled posteriors). Unless the record is a reference block (its only
     *         alternate allele is {@code <NON_REF>}), the site also gets a "PG" attribute (Phred-scaled priors)
     *         and recomputed chromosome counts.
     * @throws IllegalStateException if the maximum ploidy is not 2 or a likelihood vector has the wrong length
     */
    public static VariantContext calculatePosteriorProbs(final VariantContext vc1, final List<VariantContext> resources,
                                                         final int numRefSamplesFromMissingResources,
                                                         final PosteriorProbabilitiesOptions opts) {
        Utils.nonNull(vc1, "VariantContext vc1 is null");
        final Map<Allele, Integer> totalAlleleCounts = new HashMap<>();

        // The input samples' own counts can stand in for missing resources only with enough samples,
        // or when reference samples are explicitly substituted for the missing resources.
        final boolean useDiscoveredACForMissing = !opts.ignoreInputSamplesForMissingResources
                && (vc1.getNSamples() >= MIN_SAMPLES_TO_USE_INPUTS || numRefSamplesFromMissingResources != 0);

        // Find one reference allele spanning vc1 and every resource, so all alleles can be compared directly.
        final List<VariantContext> allContexts = new ArrayList<>(resources);
        allContexts.add(vc1);
        final Allele commonRef = GATKVariantContextUtils.determineReferenceAllele(allContexts, new SimpleInterval(vc1));
        final List<Allele> origAllelesRemapped = ReferenceConfidenceVariantContextMerger.remapAlleles(vc1, commonRef);

        for (final VariantContext resource : resources) {
            if (resource.getStart() == vc1.getStart()) {
                final List<Allele> remappedAlleles = ReferenceConfidenceVariantContextMerger.remapAlleles(resource, commonRef);
                addAlleleCounts(totalAlleleCounts, resource, remappedAlleles, !opts.useMLEAC);
            }
        }

        final boolean addInputSampleCounts = resources.isEmpty() ? useDiscoveredACForMissing : opts.useInputSamplesAlleleCounts;
        if (addInputSampleCounts) {
            addAlleleCounts(totalAlleleCounts, vc1, origAllelesRemapped, !opts.useMLEAC);
        }

        // With no resources at all, pad the reference count with the requested number of diploid hom-ref samples.
        // merge() also guarantees commonRef is present as a key, even if its count is 0.
        final int referenceAlleleCountForMissing = resources.isEmpty() ? 2 * numRefSamplesFromMissingResources : 0;
        totalAlleleCounts.merge(commonRef, referenceAlleleCountForMissing, Integer::sum);

        // Alleles counted in the resources but absent from vc1. Their counts are credited to <NON_REF> below.
        final Set<Allele> resourceOnlyAlleles = new HashSet<>(totalAlleleCounts.keySet());
        resourceOnlyAlleles.removeAll(origAllelesRemapped);

        final double[] alleleCounts = new double[origAllelesRemapped.size()];
        for (int i = 0; i < alleleCounts.length; i++) {
            final Allele allele = origAllelesRemapped.get(i);
            alleleCounts[i] = priorPseudocount(allele, commonRef, opts) + totalAlleleCounts.getOrDefault(allele, 0);
        }
        final int nonRefIndex = vc1.getAlleleIndex(Allele.NON_REF_ALLELE);
        if (nonRefIndex != -1) {
            alleleCounts[nonRefIndex] = Math.max(opts.snpPriorDirichlet, opts.indelPriorDirichlet)
                    + resourceOnlyAlleles.stream().mapToDouble(a -> totalAlleleCounts.get(a)).sum();
        }

        final List<double[]> likelihoods = vc1.getGenotypes().stream()
                .map(PosteriorProbabilitiesUtils::parsePosteriorsIntoProbSpace)
                .collect(Collectors.toList());
        final boolean useFlatPriors = (!vc1.isSNP() && opts.useFlatPriorsForIndels)
                || (resources.isEmpty() && !useDiscoveredACForMissing && numRefSamplesFromMissingResources == 0);
        final int ploidy = vc1.getMaxPloidy(2);
        final List<double[]> posteriors = calculatePosteriorProbs(likelihoods, alleleCounts, ploidy, useFlatPriors);

        final GenotypesContext newContext = GenotypesContext.create();
        for (int genoIdx = 0; genoIdx < vc1.getNSamples(); genoIdx++) {
            final Genotype genotype = vc1.getGenotype(genoIdx);
            final GenotypeBuilder genotypeBuilder = new GenotypeBuilder(genotype);
            genotypeBuilder.phased(genotype.isPhased());
            final double[] posterior = posteriors.get(genoIdx);
            if (posterior != null) {
                GATKVariantContextUtils.makeGenotypeCall(ploidy, genotypeBuilder, GenotypeAssignmentMethod.USE_PLS_TO_ASSIGN, posterior, vc1.getAlleles(), null);
                genotypeBuilder.attribute("PP", Utils.listFromPrimitives(GenotypeLikelihoods.fromLog10Likelihoods(posterior).getAsPLs()));
            }
            newContext.add(genotypeBuilder.make());
        }

        final List<Integer> priors = Utils.listFromPrimitives(
                GenotypeLikelihoods.fromLog10Likelihoods(getDirichletPrior(alleleCounts, ploidy, useFlatPriors)).getAsPLs());
        final VariantContextBuilder vcBuilder = new VariantContextBuilder(vc1).genotypes(newContext);
        // A reference block (<NON_REF> is the only alternate allele) gets neither PG nor recomputed chromosome counts.
        final boolean isHomRefBlock = vc1.getAlternateAlleles().size() == 1 && vc1.getAlleles().contains(Allele.NON_REF_ALLELE);
        if (!isHomRefBlock) {
            VariantContextUtils.calculateChromosomeCounts(vcBuilder.attribute("PG", priors), true);
        }
        return vcBuilder.make();
    }

    /**
     * Returns the Dirichlet pseudocount used for one allele.
     *
     * <p>Alleles with the same length as the common reference use the SNP prior. Symbolic alleles use the larger
     * of the SNP and indel priors. All other alleles use the indel prior.
     */
    private static double priorPseudocount(final Allele allele, final Allele commonRef, final PosteriorProbabilitiesOptions opts) {
        if (allele.length() == commonRef.length()) {
            return opts.snpPriorDirichlet;
        }
        return allele.isSymbolic() ? Math.max(opts.snpPriorDirichlet, opts.indelPriorDirichlet) : opts.indelPriorDirichlet;
    }

    /**
     * Returns the genotype's posteriors as Phred-scaled integers.
     *
     * <p>The values are read from the "PP" attribute. When PP is absent, or is a String that is empty or starts
     * with the VCF missing marker '.', the genotype's PL array ({@link Genotype#getPL()}) is returned instead.
     */
    public static int[] parsePosteriorsIntoPhredSpace(final Genotype genotype) {
        final Object ppFromVcf = genotype.getExtendedAttribute("PP");
        if (ppFromVcf == null) {
            return genotype.getPL();
        }
        if (ppFromVcf instanceof String) {
            final String ppString = (String) ppFromVcf;
            return isMissingString(ppString)
                    ? genotype.getPL()
                    : Arrays.stream(ppString.split(",")).mapToInt(Integer::parseInt).toArray();
        }
        return extractInts(ppFromVcf);
    }

    /**
     * Returns the genotype's posteriors in log10 space (PP / -10).
     *
     * <p>When PP is absent or missing, this falls back to the genotype's log10 likelihoods. It returns null when
     * the genotype has no informative likelihoods.
     */
    public static double[] parsePosteriorsIntoProbSpace(final Genotype genotype) {
        final Object ppFromVcf = genotype.getExtendedAttribute("PP");
        if (ppFromVcf == null) {
            return getLikelihoodsVector(genotype);
        }
        if (ppFromVcf instanceof String) {
            final String ppString = (String) ppFromVcf;
            return isMissingString(ppString)
                    ? getLikelihoodsVector(genotype)
                    : Arrays.stream(ppString.split(",")).mapToDouble(s -> Double.parseDouble(s) / -10.0).toArray();
        }
        return Arrays.stream(extractInts(ppFromVcf)).mapToDouble(i -> i / -10.0).toArray();
    }

    /**
     * Returns true for a String attribute value that should be treated as missing: empty, or starting with '.'.
     * An empty value previously crashed {@code charAt(0)}.
     */
    private static boolean isMissingString(final String value) {
        return value.isEmpty() || value.charAt(0) == '.';
    }

    /** Returns the genotype's log10 likelihood vector, or null when it has no informative likelihoods. */
    private static double[] getLikelihoodsVector(final Genotype genotype) {
        return hasRealLikelihoods(genotype) ? genotype.getLikelihoods().getAsVector() : null;
    }

    /**
     * Returns true when the genotype has likelihoods worth using.
     *
     * <p>A genotype with DP == 0 whose PLs have no positive value carries no information and is rejected.
     */
    private static boolean hasRealLikelihoods(final Genotype genotype) {
        if (!genotype.hasLikelihoods()) {
            return false;
        }
        if (genotype.hasDP() && genotype.getDP() == 0) {
            return MathUtils.arrayMax(genotype.getPL()) > 0;
        }
        return true;
    }

    /**
     * Adds the log10 Dirichlet-multinomial prior to each likelihood vector and normalizes the result.
     *
     * @param genotypeLikelihoods per-sample log10 likelihoods; null entries yield null posteriors
     * @param knownAlleleCountsByAllele Dirichlet pseudocounts, one per allele
     * @param ploidy must be 2
     * @param useFlatPriors if true, every genotype gets the same prior
     * @return per-sample normalized log10 posteriors, parallel to {@code genotypeLikelihoods}
     * @throws IllegalStateException if ploidy != 2 or a likelihood vector has the wrong length
     */
    protected static List<double[]> calculatePosteriorProbs(final List<double[]> genotypeLikelihoods, final double[] knownAlleleCountsByAllele, final int ploidy, final boolean useFlatPriors) {
        if (ploidy != 2) {
            throw new IllegalStateException("Genotype posteriors not yet implemented for ploidy != 2");
        }
        final double[] genotypePriorByAllele = getDirichletPrior(knownAlleleCountsByAllele, ploidy, useFlatPriors);
        final List<double[]> posteriors = new ArrayList<>(genotypeLikelihoods.size());
        for (final double[] likelihoods : genotypeLikelihoods) {
            if (likelihoods == null) {
                posteriors.add(null);
                continue;
            }
            if (likelihoods.length != genotypePriorByAllele.length) {
                throw new IllegalStateException(String.format("Likelihoods not of correct size: expected %d, observed %d", genotypePriorByAllele.length, likelihoods.length));
            }
            final double[] unnormalized = new double[likelihoods.length];
            for (int genoIdx = 0; genoIdx < likelihoods.length; genoIdx++) {
                unnormalized[genoIdx] = likelihoods[genoIdx] + genotypePriorByAllele[genoIdx];
            }
            posteriors.add(MathUtils.normalizeLog10(unnormalized));
        }
        return posteriors;
    }

    /** Single-sample convenience overload of {@link #calculatePosteriorProbs(List, double[], int, boolean)}. */
    @VisibleForTesting
    static double[] calculatePosteriorProbs(final double[] genotypeLikelihoods, final double[] knownAlleleCountsByAllele, final int ploidy, final boolean useFlatPriors) {
        return calculatePosteriorProbs(Arrays.asList(new double[][]{genotypeLikelihoods}), knownAlleleCountsByAllele, ploidy, useFlatPriors).get(0);
    }

    /**
     * Returns the diploid genotype prior for each genotype, given per-allele Dirichlet pseudocounts.
     *
     * <p>Genotypes are enumerated with the second allele index in the outer loop and the first (&lt;= second) in
     * the inner loop, giving the order 0/0, 0/1, 1/1, 0/2, ...
     *
     * @param knownCountsByAllele Dirichlet pseudocounts, one per allele
     * @param ploidy must be 2
     * @param useFlatPrior if true, every entry is 1.0; otherwise each entry is the value of
     *                     {@code MathUtils.dirichletMultinomial} for that genotype's allele counts
     * @return an array of length n(n+1)/2, where n is the number of alleles
     * @throws IllegalStateException if ploidy != 2
     */
    @VisibleForTesting
    static double[] getDirichletPrior(final double[] knownCountsByAllele, final int ploidy, final boolean useFlatPrior) {
        if (ploidy != 2) {
            throw new IllegalStateException("Genotype priors not yet implemented for ploidy != 2");
        }
        final int numAlleles = knownCountsByAllele.length;
        final double[] priors = new double[numAlleles * (numAlleles + 1) / 2];
        if (useFlatPrior) {
            Arrays.fill(priors, 1.0);
            return priors;
        }
        int priorIndex = 0;
        for (int allele2 = 0; allele2 < numAlleles; allele2++) {
            for (int allele1 = 0; allele1 <= allele2; allele1++) {
                final int[] genotypeAlleleCounts = new int[numAlleles];
                genotypeAlleleCounts[allele1]++;
                genotypeAlleleCounts[allele2]++;
                priors[priorIndex++] = MathUtils.dirichletMultinomial(knownCountsByAllele, genotypeAlleleCounts);
            }
        }
        return priors;
    }

    /**
     * Adds one record's allele counts to {@code counts}, keyed by its remapped alleles.
     *
     * <p>Alternate-allele counts are read from MLEAC (when {@code useAC} is false and MLEAC is present), otherwise
     * from AC, otherwise from the called chromosome counts. The reference count is AN (or the called chromosome
     * count when AN is absent) minus the alternate counts, floored at 0.
     *
     * @param remappedAlleles alleles of {@code context} remapped to the common reference, parallel to {@code context.getAlleles()}
     */
    private static void addAlleleCounts(final Map<Allele, Integer> counts, final VariantContext context, final List<Allele> remappedAlleles, final boolean useAC) {
        final List<Allele> altAlleles = context.getAlternateAlleles();
        final int[] ac;
        if (context.hasAttribute("MLEAC") && !useAC) {
            ac = getAlleleCounts("MLEAC", context);
        } else if (context.hasAttribute("AC")) {
            ac = getAlleleCounts("AC", context);
        } else {
            ac = new int[altAlleles.size()];
            for (int i = 0; i < ac.length; i++) {
                ac[i] = context.getCalledChrCount(altAlleles.get(i));
            }
        }

        final List<Allele> originalAlleles = context.getAlleles();
        for (int i = 0; i < originalAlleles.size(); i++) {
            final Allele allele = remappedAlleles.get(i);
            final int count = allele.isReference()
                    ? referenceAlleleCount(context, ac)
                    : ac[altAlleles.indexOf(originalAlleles.get(i))];
            counts.merge(allele, count, Integer::sum);
        }
    }

    /** Reference allele count: AN (or the called chromosome count when AN is absent) minus the alt counts, floored at 0. */
    private static int referenceAlleleCount(final VariantContext context, final int[] altCounts) {
        final int totalChromosomes = context.hasAttribute("AN")
                ? context.getAttributeAsInt("AN", -1)
                : context.getCalledChrCount();
        return Math.max(totalChromosomes - (int) MathUtils.sum(altCounts), 0);
    }

    /**
     * Reads an allele-count attribute and checks that it has one value per alternate allele.
     *
     * @throws UserException if the number of values does not match the number of alternate alleles
     */
    private static int[] getAlleleCounts(final String VCFkey, final VariantContext context) {
        final Object alleleCountsFromVCF = context.getAttribute(VCFkey);
        final int numAltAlleles = context.getAlternateAlleles().size();
        final boolean sizeMismatch;
        if (alleleCountsFromVCF instanceof List) {
            sizeMismatch = ((List<?>) alleleCountsFromVCF).size() != numAltAlleles;
        } else if (alleleCountsFromVCF instanceof String || alleleCountsFromVCF instanceof Integer) {
            sizeMismatch = numAltAlleles != 1;
        } else {
            sizeMismatch = false;
        }
        if (sizeMismatch) {
            throw new UserException(String.format("Variant does not contain the same number of MLE allele counts as alternate alleles for record at %s:%d", context.getContig(), context.getStart()));
        }
        return extractInts(alleleCountsFromVCF);
    }

    /**
     * Converts a VCF integer field to an {@code int[]}.
     *
     * <p>The field may be a single Integer, a single String, or a List whose elements are Integers or Strings.
     * An empty list yields an empty array.
     *
     * @throws IllegalArgumentException (via {@code Utils.nonNull}) if the value is none of the supported shapes
     * @throws IllegalStateException if a list element is neither an Integer nor a String
     * @throws NumberFormatException if a String element is not an integer
     */
    public static int[] extractInts(final Object integerListContainingVCField) {
        final List<?> values;
        if (integerListContainingVCField instanceof List) {
            values = (List<?>) integerListContainingVCField;
        } else if (integerListContainingVCField instanceof Integer || integerListContainingVCField instanceof String) {
            values = Arrays.asList(integerListContainingVCField);
        } else {
            values = null;
        }
        Utils.nonNull(values, () -> String.format("VCF does not have properly formatted %s or %s.", "MLEAC", "AC"));
        final int[] result = new int[values.size()];
        for (int idx = 0; idx < result.length; idx++) {
            result[idx] = toInt(values.get(idx));
        }
        return result;
    }

    /** Converts one VCF field element (an Integer, or a String holding an integer) to an int. */
    private static int toInt(final Object value) {
        if (value instanceof Integer) {
            return (Integer) value;
        }
        if (value instanceof String) {
            return Integer.parseInt((String) value);
        }
        throw new IllegalStateException("BUG: The AC values should be an Integer, but was " + (value == null ? "null" : value.getClass().getCanonicalName()));
    }

    /** Settings for {@link #calculatePosteriorProbs(VariantContext, List, int, PosteriorProbabilitiesOptions)}. */
    public static final class PosteriorProbabilitiesOptions {
        /** Dirichlet pseudocount for alleles with the same length as the reference. */
        double snpPriorDirichlet;
        /** Dirichlet pseudocount for length-changing alleles. */
        double indelPriorDirichlet;
        /** Whether to add the input samples' counts to the prior when resources are present. */
        boolean useInputSamplesAlleleCounts;
        /** Whether to prefer MLEAC over AC when reading allele counts. */
        boolean useMLEAC;
        /** If true, input samples' counts are never used in place of missing resources. */
        boolean ignoreInputSamplesForMissingResources;
        /** If true, non-SNP sites get a flat genotype prior. */
        boolean useFlatPriorsForIndels;

        public PosteriorProbabilitiesOptions(final double snpPriorDirichlet, final double indelPriorDirichlet, final boolean useInputSamplesAlleleCounts, final boolean useMLEAC, final boolean ignoreInputSamplesForMissingResources, final boolean useFlatPriorsForIndels) {
            this.snpPriorDirichlet = snpPriorDirichlet;
            this.indelPriorDirichlet = indelPriorDirichlet;
            this.useInputSamplesAlleleCounts = useInputSamplesAlleleCounts;
            this.useMLEAC = useMLEAC;
            this.ignoreInputSamplesForMissingResources = ignoreInputSamplesForMissingResources;
            this.useFlatPriorsForIndels = useFlatPriorsForIndels;
        }
    }
}

