/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.hellbender.tools.walkers.mutect.filtering;

import com.netflix.servo.util.VisibleForTesting;
import htsjdk.variant.variantcontext.Allele;
import htsjdk.variant.variantcontext.Genotype;
import htsjdk.variant.variantcontext.VariantContext;
import java.io.File;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.IntStream;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.broadinstitute.hellbender.engine.ReferenceContext;
import org.broadinstitute.hellbender.tools.walkers.mutect.filtering.ErrorType;
import org.broadinstitute.hellbender.tools.walkers.mutect.filtering.Mutect2FilteringEngine;
import org.broadinstitute.hellbender.tools.walkers.mutect.filtering.Mutect2VariantFilter;
import org.broadinstitute.hellbender.tools.walkers.readorientation.ArtifactPrior;
import org.broadinstitute.hellbender.tools.walkers.readorientation.ArtifactPriorCollection;
import org.broadinstitute.hellbender.tools.walkers.readorientation.ArtifactState;
import org.broadinstitute.hellbender.tools.walkers.readorientation.F1R2FilterUtils;
import org.broadinstitute.hellbender.tools.walkers.readorientation.LearnReadOrientationModelEngine;
import org.broadinstitute.hellbender.utils.MathUtils;
import org.broadinstitute.hellbender.utils.Nucleotide;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.variant.VariantContextGetters;

/**
 * Mutect2 filter that estimates, for each SNV/MNV, the posterior probability that the variant is a
 * read-orientation artifact (an F1R2- or F2R1-oriented artifact state), using per-sample artifact priors
 * loaded from read-orientation prior tables together with the genotype's F1R2/F2R1 allele counts.
 *
 * <p>Indels and other non-SNP/MNP variants always receive probability 0.
 */
public class ReadOrientationFilter
extends Mutect2VariantFilter {
    /** Genotype (FORMAT) attribute holding per-allele counts of F1R2-oriented reads, ref first. */
    private static final String F1R2_KEY = "F1R2";
    /** Genotype (FORMAT) attribute holding per-allele counts of F2R1-oriented reads, ref first. */
    private static final String F2R1_KEY = "F2R1";
    /** Variant (INFO) attribute holding one tumor log-odds value per alternate allele. */
    private static final String TUMOR_LOD_KEY = "TLOD";
    /** Number of reference bases taken on each side of a position to build its reference context. */
    private static final int REF_CONTEXT_PADDING = 1;
    /** Expected length of the reference-context kmer (padding on both sides plus the middle base). */
    private static final int REF_CONTEXT_LENGTH = 2 * REF_CONTEXT_PADDING + 1;

    /** Artifact priors keyed by sample name. */
    private final Map<String, ArtifactPriorCollection> artifactPriorCollections = new HashMap<>();

    /**
     * Loads one artifact-prior collection per table and indexes it by the sample it describes.
     * If two tables describe the same sample, the later one in the list replaces the earlier one.
     *
     * @param readOrientationPriorTables prior tables to read
     */
    public ReadOrientationFilter(List<File> readOrientationPriorTables) {
        for (final File table : readOrientationPriorTables) {
            final ArtifactPriorCollection priors = ArtifactPriorCollection.readArtifactPriors(table);
            artifactPriorCollections.put(priors.getSample(), priors);
        }
    }

    /**
     * Returns the genotype's F1R2 per-allele read counts.
     *
     * @param g genotype to read
     * @return the F1R2 counts, or {@code null} if the attribute is absent
     */
    public static int[] getF1R2(Genotype g) {
        return VariantContextGetters.getAttributeAsIntArray(g, F1R2_KEY, () -> null, 0);
    }

    /**
     * Returns the genotype's F2R1 per-allele read counts.
     *
     * @param g genotype to read
     * @return the F2R1 counts, or {@code null} if the attribute is absent
     */
    public static int[] getF2R1(Genotype g) {
        return VariantContextGetters.getAttributeAsIntArray(g, F2R1_KEY, () -> null, 0);
    }

    @Override
    public ErrorType errorType() {
        return ErrorType.ARTIFACT;
    }

    /**
     * Computes the artifact probability of a variant as the weighted median, over tumor genotypes,
     * of each genotype's artifact posterior, weighted by that genotype's alt read depth.
     *
     * @return 0 for non-SNP/MNP variants, otherwise the weighted median artifact posterior
     */
    @Override
    public double calculateErrorProbability(VariantContext vc, Mutect2FilteringEngine filteringEngine, ReferenceContext referenceContext) {
        if (!vc.isSNP() && !vc.isMNP()) {
            return 0.0;
        }
        final List<ImmutablePair<Integer, Double>> depthsAndPosteriors = new ArrayList<>();
        for (final Genotype g : vc.getGenotypes()) {
            if (!filteringEngine.isTumor(g)) {
                continue;
            }
            final double artifactPosterior = artifactProbability(referenceContext, vc, g);
            // Pass primitives so the pair autoboxes to ImmutablePair<Integer, Double>;
            // casting to Object would produce ImmutablePair<Object, Object>, which does not fit the list.
            depthsAndPosteriors.add(ImmutablePair.of(altDepth(g), artifactPosterior));
        }
        return weightedMedianPosteriorProbability(depthsAndPosteriors);
    }

    /** Sums the AD field and subtracts the reference entry, giving the total alt-supporting depth. */
    private static int altDepth(final Genotype g) {
        final int[] ads = g.getAD();
        return (int) MathUtils.sum(ads) - ads[0];
    }

    @Override
    public String filterName() {
        return "orientation";
    }

    @Override
    public Optional<String> phredScaledPosteriorAnnotationName() {
        return Optional.of("ROQ");
    }

    @Override
    protected List<String> requiredInfoAnnotations() {
        return Collections.emptyList();
    }

    /**
     * Computes the read-orientation artifact posterior for one genotype at one variant.
     *
     * <p>The alt allele with the highest tumor LOD is examined. Each of its bases is treated as an
     * independent substitution at the corresponding reference position, and the maximum per-base
     * posterior is returned.
     *
     * @return 0 if the genotype is hom-ref, the variant is not an SNP/MNP, no priors exist for the
     *         sample, or the variant lacks a TLOD attribute; otherwise the maximum per-base posterior
     */
    // NOTE(review): this file imports VisibleForTesting from com.netflix.servo.util; Guava's
    // com.google.common.annotations.VisibleForTesting is the more usual choice — confirm.
    @VisibleForTesting
    double artifactProbability(ReferenceContext referenceContext, VariantContext vc, Genotype g) {
        if (g.isHomRef() || !vc.isSNP() && !vc.isMNP()) {
            return 0.0;
        }
        if (!artifactPriorCollections.containsKey(g.getSampleName())) {
            return 0.0;
        }
        final double[] tumorLods = VariantContextGetters.getAttributeAsDoubleArray(vc, TUMOR_LOD_KEY, () -> null, -1.0);
        // TLOD is not a required INFO annotation for this filter, so it may be absent; without it
        // there is no way to choose an alt allele, so report no artifact evidence instead of NPE-ing.
        if (tumorLods == null || tumorLods.length == 0) {
            return 0.0;
        }
        final int indexOfMaxTumorLod = MathUtils.maxElementIndex(tumorLods);
        final Allele altAllele = vc.getAlternateAllele(indexOfMaxTumorLod);
        final byte[] altBases = altAllele.getBases();
        return IntStream.range(0, altBases.length).mapToDouble(n -> {
            // Charset-independent single-base conversion (new String(byte[]) uses the platform charset).
            final Nucleotide altBase = Nucleotide.valueOf(String.valueOf((char) altBases[n]));
            return artifactProbability(referenceContext, vc.getStart() + n, g, indexOfMaxTumorLod, altBase);
        }).max().orElse(0.0);
    }

    /**
     * Computes the artifact posterior for a single substituted base at {@code refPosition}.
     *
     * @return 0 if the reference context is unavailable or contains N, the genotype lacks F1R2/F2R1
     *         counts, or no prior exists for the context; otherwise the larger of the F1R2 and F2R1
     *         artifact-state posteriors
     */
    private double artifactProbability(ReferenceContext referenceContext, int refPosition, Genotype g, int indexOfMaxTumorLod, Nucleotide altBase) {
        final String refContext = referenceContext.getKmerAround(refPosition, REF_CONTEXT_PADDING);
        if (refContext == null || refContext.contains("N")) {
            return 0.0;
        }
        Utils.validate(refContext.length() == REF_CONTEXT_LENGTH,
                String.format("kmer must have length %d but got %d", REF_CONTEXT_LENGTH, refContext.length()));
        final Nucleotide refAllele = F1R2FilterUtils.getMiddleBase(refContext);
        if (!g.hasExtendedAttribute(F1R2_KEY) || !g.hasExtendedAttribute(F2R1_KEY)) {
            return 0.0;
        }
        final int[] f1r2 = getF1R2(g);
        final int[] f2r1 = getF2R1(g);
        // Index 0 is the reference allele; alt allele i lives at index i + 1.
        final int refCount = f1r2[0] + f2r1[0];
        final int altF1R2 = f1r2[indexOfMaxTumorLod + 1];
        final int altF2R1 = f2r1[indexOfMaxTumorLod + 1];
        final int altCount = altF1R2 + altF2R1;
        final Optional<ArtifactPrior> artifactPrior = artifactPriorCollections.get(g.getSampleName()).get(refContext);
        if (!artifactPrior.isPresent()) {
            return 0.0;
        }
        final int depth = refCount + altCount;
        // NOTE(review): the meaning of the trailing boolean is defined by
        // LearnReadOrientationModelEngine.computeResponsibilities — confirm before changing.
        final double[] posterior = LearnReadOrientationModelEngine.computeResponsibilities(
                refAllele, altBase, altCount, altF1R2, depth, artifactPrior.get().getPi(), true);
        final double posteriorOfF1R2 = posterior[ArtifactState.getF1R2StateForAlt(altBase).ordinal()];
        final double posteriorOfF2R1 = posterior[ArtifactState.getF2R1StateForAlt(altBase).ordinal()];
        return Math.max(posteriorOfF1R2, posteriorOfF2R1);
    }
}

