/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.hellbender.tools.walkers.vqsr;

import htsjdk.variant.variantcontext.Allele;
import htsjdk.variant.variantcontext.VariantContext;
import htsjdk.variant.variantcontext.VariantContextBuilder;
import htsjdk.variant.variantcontext.writer.VariantContextWriter;
import htsjdk.variant.vcf.VCFFilterHeaderLine;
import htsjdk.variant.vcf.VCFHeader;
import htsjdk.variant.vcf.VCFHeaderLine;
import htsjdk.variant.vcf.VCFInfoHeaderLine;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import java.util.TreeSet;
import java.util.stream.Collectors;
import org.broadinstitute.barclay.argparser.Argument;
import org.broadinstitute.barclay.argparser.CommandLineException;
import org.broadinstitute.barclay.argparser.CommandLineProgramProperties;
import org.broadinstitute.barclay.help.DocumentedFeature;
import org.broadinstitute.hellbender.engine.FeatureContext;
import org.broadinstitute.hellbender.engine.FeatureInput;
import org.broadinstitute.hellbender.engine.GATKPath;
import org.broadinstitute.hellbender.engine.ReadsContext;
import org.broadinstitute.hellbender.engine.ReferenceContext;
import org.broadinstitute.hellbender.engine.TwoPassVariantWalker;
import org.broadinstitute.hellbender.exceptions.GATKException;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.utils.variant.GATKVariantContextUtils;
import picard.cmdline.programgroups.VariantFilteringProgramGroup;

/**
 * Applies tranche filtering to SNPs and indels using a score stored in an INFO field of the input VCF.
 *
 * <p>First pass: the scores of input SNPs and indels that match a position/allele in one of the
 * {@code --resource} VCFs are collected into a SNP and an indel calibration set.
 *
 * <p>Between passes: for every requested tranche (a sensitivity percentage) a score cutoff is picked
 * from the descending-sorted calibration scores, so that approximately that percentage of the matched
 * resource variants score at or above the cutoff.
 *
 * <p>Second pass: a SNP/indel whose score is at or below the cutoff of the lowest tranche receives a
 * filter naming the tranche interval its score falls into. Any variant left without filters is
 * marked PASS.
 */
@DocumentedFeature
@CommandLineProgramProperties(summary="Apply tranche filtering based on a truth VCF of known common sites of variation and a score from VCF INFO field", oneLineSummary="Apply tranche filtering", programGroup=VariantFilteringProgramGroup.class)
public class FilterVariantTranches
extends TwoPassVariantWalker {
    @Argument(fullName="output", shortName="O", doc="Output VCF file")
    private GATKPath outputVcf = null;
    @Argument(fullName="snp-tranche", shortName="snp-tranche", doc="The level(s) of sensitivity to SNPs in the resource VCFs at which to filter SNPs. Higher numbers mean more desired sensitivity and thus less stringent filtering.Specified in percents, i.e. 99.9 for 99.9 percent and 1.0 for 1 percent.", optional=true)
    private List<Double> snpTranches = new ArrayList<Double>(Arrays.asList(99.95));
    @Argument(fullName="indel-tranche", shortName="indel-tranche", doc="The level(s) of sensitivity to indels in the resource VCFs at which to filter indels. Higher numbers mean more desired sensitivity and thus less stringent filtering.Specified in percents, i.e. 99.9 for 99.9 percent and 1.0 for 1 percent.", optional=true)
    private List<Double> indelTranches = new ArrayList<Double>(Arrays.asList(99.4));
    @Argument(fullName="resource", doc="A list of validated VCFs with known sites of common variation", optional=false)
    private List<FeatureInput<VariantContext>> resources = new ArrayList<FeatureInput<VariantContext>>();
    @Argument(fullName="info-key", shortName="info-key", doc="The key must be in the INFO field of the input VCF.")
    private String infoKey = "CNN_2D";
    @Argument(fullName="invalidate-previous-filters", doc="Remove all filters that already exist in the VCF.", optional=true)
    private boolean removeOldFilters = false;

    private VariantContextWriter vcfWriter;
    // Scores of input variants that matched a resource site; sorted in place after the first pass.
    private final List<Double> resourceSNPScores = new ArrayList<>();
    private final List<Double> resourceIndelScores = new ArrayList<>();
    // cutoffs.get(i) is the score cutoff for tranches.get(i); tranches ascend, so cutoffs do not increase.
    private final List<Double> snpCutoffs = new ArrayList<>();
    private final List<Double> indelCutoffs = new ArrayList<>();
    private int scoredSnps = 0;
    private int filteredSnps = 0;
    private int scoredIndels = 0;
    private int filteredIndels = 0;

    /** Variant-type label embedded in SNP filter keys and descriptions. */
    private static final String SNP_STRING = "SNP";
    /** Variant-type label embedded in indel filter keys and descriptions. */
    private static final String INDEL_STRING = "INDEL";
    /** Upper bound (in percent) of the last, least stringent tranche interval. */
    private static final double MAX_TRANCHE = 100.0;

    @Override
    public void onTraversalStart() {
        snpTranches = validateTranches(snpTranches);
        indelTranches = validateTranches(indelTranches);
        vcfWriter = createVCFWriter(outputVcf);
        writeVCFHeader(vcfWriter);
    }

    /**
     * Counts scored SNPs/indels and records the score of each one that matches a resource variant.
     * Variants that are neither SNPs nor indels are ignored: they are never filtered in the second
     * pass, so they must not influence either calibration set.
     */
    @Override
    public void firstPassApply(final VariantContext variant, final ReadsContext readsContext, final ReferenceContext referenceContext, final FeatureContext featureContext) {
        if (!variant.hasAttribute(infoKey)) {
            return;
        }
        final boolean isSnp = variant.isSNP();
        if (isSnp) {
            ++scoredSnps;
        } else if (variant.isIndel()) {
            ++scoredIndels;
        } else {
            return;
        }
        if (overlapsResource(variant, featureContext)) {
            final double score = parseScore(variant);
            if (isSnp) {
                resourceSNPScores.add(score);
            } else {
                resourceIndelScores.add(score);
            }
        }
    }

    /**
     * Returns true if some resource variant starts at the same position as {@code variant} and shares
     * one of its alternate alleles (as decided by {@link GATKVariantContextUtils#isAlleleInList}).
     *
     * @throws UserException.BadInput if the allele comparison reports inconsistent references
     */
    private boolean overlapsResource(final VariantContext variant, final FeatureContext featureContext) {
        for (final FeatureInput<VariantContext> featureSource : resources) {
            for (final VariantContext v : featureContext.getValues(featureSource)) {
                if (variant.getStart() != v.getStart()) {
                    continue;
                }
                for (final Allele a : variant.getAlternateAlleles()) {
                    try {
                        if (GATKVariantContextUtils.isAlleleInList(variant.getReference(), a, v.getReference(), v.getAlternateAlleles())) {
                            return true;
                        }
                    }
                    catch (IllegalStateException e) {
                        throw new UserException.BadInput(String.format("The provided variant file(s) have inconsistent references for the same position(s) at %s:%d, %s in input vs. %s in resource", v.getContig(), v.getStart(), variant.getReference(), v.getReference()));
                    }
                }
            }
        }
        return false;
    }

    /**
     * Validates that the first pass produced usable calibration data and derives the per-tranche cutoffs.
     *
     * @throws UserException.BadInput if nothing was scored, or a scored variant type has no resource matches
     */
    @Override
    public void afterFirstPass() {
        logger.info(String.format("Found %d SNPs and %d indels with INFO score key:%s.", scoredSnps, scoredIndels, infoKey));
        logger.info(String.format("Found %d SNPs and %d indels in the resources.", resourceSNPScores.size(), resourceIndelScores.size()));
        if (scoredSnps == 0 && scoredIndels == 0) {
            throw new UserException.BadInput("VCF contains no variants or no variants with INFO score key \"" + infoKey + "\"");
        }
        if (resourceSNPScores.isEmpty() && resourceIndelScores.isEmpty()) {
            throw new UserException.BadInput("Neither SNP nor indel resource contains variants overlapping input.  Filtering cannot be performed.");
        }
        if (scoredSnps > 0 && resourceSNPScores.isEmpty()) {
            throw new UserException.BadInput("SNPs are present in input VCF, but cannot be filtered because no overlapping SNPs were found in the resources.");
        }
        if (scoredIndels > 0 && resourceIndelScores.isEmpty()) {
            throw new UserException.BadInput("Indels are present in input VCF, but cannot be filtered because no overlapping indels were found in the resources.");
        }
        snpCutoffs.addAll(computeCutoffs(resourceSNPScores, snpTranches));
        indelCutoffs.addAll(computeCutoffs(resourceIndelScores, indelTranches));
    }

    /**
     * Sorts {@code resourceScores} in place (descending) and returns one cutoff per tranche, in tranche order.
     * Returns an empty list when there are no resource scores.
     */
    private static List<Double> computeCutoffs(final List<Double> resourceScores, final List<Double> tranches) {
        final List<Double> cutoffs = new ArrayList<>(tranches.size());
        if (resourceScores.isEmpty()) {
            return cutoffs;
        }
        resourceScores.sort(Collections.reverseOrder());
        final int lastIndex = resourceScores.size() - 1;
        for (final double t : tranches) {
            // validateTranches guarantees 0 < t < 100, so the index stays within [0, lastIndex].
            cutoffs.add(resourceScores.get((int) (t / MAX_TRANCHE * lastIndex)));
        }
        return cutoffs;
    }

    /** Optionally clears old filters, applies a tranche filter when the score warrants one, and writes the variant. */
    @Override
    protected void secondPassApply(final VariantContext variant, final ReadsContext readsContext, final ReferenceContext referenceContext, final FeatureContext featureContext) {
        final VariantContextBuilder builder = new VariantContextBuilder(variant);
        if (removeOldFilters) {
            builder.unfiltered();
        }
        if (variant.hasAttribute(infoKey)) {
            final double score = parseScore(variant);
            if (variant.isSNP() && !snpCutoffs.isEmpty() && isTrancheFiltered(score, snpCutoffs)) {
                builder.filter(filterStringFromScore(SNP_STRING, score, snpTranches, snpCutoffs));
                ++filteredSnps;
            } else if (variant.isIndel() && !indelCutoffs.isEmpty() && isTrancheFiltered(score, indelCutoffs)) {
                builder.filter(filterStringFromScore(INDEL_STRING, score, indelTranches, indelCutoffs));
                ++filteredIndels;
            }
        }
        if (builder.getFilters() == null || builder.getFilters().isEmpty()) {
            builder.passFilters();
        }
        vcfWriter.add(builder.make());
    }

    @Override
    public void closeTool() {
        logger.info(String.format("Filtered %d SNPs out of %d and filtered %d indels out of %d with INFO score: %s.", filteredSnps, scoredSnps, filteredIndels, scoredIndels, infoKey));
        if (vcfWriter != null) {
            vcfWriter.close();
        }
    }

    /**
     * Reads the {@code infoKey} attribute of {@code variant} as a double. Accepts values already decoded
     * to a {@link Number} as well as their string form.
     *
     * @throws UserException.BadInput if the value cannot be parsed as a number
     */
    private double parseScore(final VariantContext variant) {
        final Object value = variant.getAttribute(infoKey);
        if (value instanceof Number) {
            return ((Number) value).doubleValue();
        }
        try {
            return Double.parseDouble(String.valueOf(value));
        } catch (NumberFormatException e) {
            throw new UserException.BadInput(String.format("Could not parse INFO field %s value \"%s\" as a number at %s:%d", infoKey, value, variant.getContig(), variant.getStart()));
        }
    }

    /**
     * Writes the output header. It contains the input header lines (minus FILTER lines when old filters
     * are invalidated), one FILTER line per tranche interval, the tool's default header lines, and the
     * input samples in sorted order.
     *
     * @throws UserException if the input header does not declare {@code infoKey} as an INFO field
     */
    private void writeVCFHeader(final VariantContextWriter vcfWriter) {
        final VCFHeader inputHeader = getHeaderForVariants();
        final Set<VCFHeaderLine> hInfo = new LinkedHashSet<>(inputHeader.getMetaDataInSortedOrder());
        final boolean hasInfoKey = hInfo.stream().anyMatch(x -> x instanceof VCFInfoHeaderLine && ((VCFInfoHeaderLine) x).getID().equals(infoKey));
        if (!hasInfoKey) {
            throw new UserException(String.format("Input VCF does not contain a header line for specified info key:%s", infoKey));
        }
        if (removeOldFilters) {
            hInfo.removeIf(x -> x instanceof VCFFilterHeaderLine);
        }
        addTrancheHeaderFields(SNP_STRING, snpTranches, hInfo);
        addTrancheHeaderFields(INDEL_STRING, indelTranches, hInfo);
        hInfo.addAll(getDefaultToolVCFHeaderLines());
        final Set<String> samples = new TreeSet<>(inputHeader.getGenotypeSamples());
        vcfWriter.writeHeader(new VCFHeader(hInfo, samples));
    }

    /**
     * Adds one FILTER header line per interval [tranches[i], tranches[i+1]]. The last interval runs
     * from the largest tranche to 100.
     */
    private void addTrancheHeaderFields(final String type, final List<Double> tranches, final Set<VCFHeaderLine> hInfo) {
        for (int i = 0; i < tranches.size(); ++i) {
            final double lower = tranches.get(i);
            final double upper = i + 1 < tranches.size() ? tranches.get(i + 1) : MAX_TRANCHE;
            hInfo.add(new VCFFilterHeaderLine(filterKeyFromTranches(type, infoKey, lower, upper), filterDescriptionFromTranches(type, infoKey, lower, upper)));
        }
    }

    /**
     * Builds a FILTER ID. Formatted with {@link Locale#ROOT} so the decimal separator is always '.';
     * a locale comma would corrupt the {@code ##FILTER=<ID=...>} header line.
     */
    private String filterKeyFromTranches(final String type, final String infoKey, final double t1, final double t2) {
        return String.format(Locale.ROOT, "%s_%s_Tranche_%.2f_%.2f", infoKey, type, t1, t2);
    }

    /** Builds a FILTER description; locale-independent for the same reason as the filter key. */
    private String filterDescriptionFromTranches(final String type, final String infoKey, final double t1, final double t2) {
        return String.format(Locale.ROOT, "%s truth resource sensitivity between %.2f and %.2f for info key %s", type, t1, t2, infoKey);
    }

    /** A variant is filtered when its score does not exceed the cutoff of the lowest (most stringent) tranche. */
    private boolean isTrancheFiltered(final double score, final List<Double> cutoffs) {
        return score <= cutoffs.get(0);
    }

    /**
     * Returns the filter key of the tranche interval containing {@code score}. That is the first interval
     * whose upper cutoff the score exceeds, or the last interval (up to 100) otherwise.
     *
     * @throws GATKException if the score is above the first cutoff, i.e. the variant should pass
     */
    private String filterStringFromScore(final String type, final double score, final List<Double> tranches, final List<Double> cutoffs) {
        if (score > cutoffs.get(0)) {
            throw new GATKException("Trying to add a filter to a passing variant.");
        }
        for (int i = 1; i < cutoffs.size(); ++i) {
            if (score > cutoffs.get(i)) {
                return filterKeyFromTranches(type, infoKey, tranches.get(i - 1), tranches.get(i));
            }
        }
        return filterKeyFromTranches(type, infoKey, tranches.get(tranches.size() - 1), MAX_TRANCHE);
    }

    /**
     * Returns the distinct tranches sorted ascending.
     *
     * @throws CommandLineException if no tranche is given, or any value is not strictly between 0 and
     *         100 (NaN is rejected too)
     */
    private List<Double> validateTranches(final List<Double> tranches) {
        if (tranches.size() < 1 || tranches.stream().anyMatch(d -> !(d > 0.0 && d < MAX_TRANCHE))) {
            throw new CommandLineException("At least 1 tranche value must be given and all tranches must be greater than 0 and less than 100.");
        }
        final List<Double> newTranches = tranches.stream().distinct().collect(Collectors.toList());
        newTranches.sort(Double::compareTo);
        return newTranches;
    }
}

