/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.hellbender.tools.spark.pathseq;

import com.esotericsoftware.kryo.Kryo;
import com.esotericsoftware.kryo.io.Input;
import htsjdk.samtools.Cigar;
import htsjdk.samtools.SAMFileHeader;
import htsjdk.samtools.SAMSequenceRecord;
import htsjdk.samtools.SAMTag;
import htsjdk.samtools.TextCigarCodec;
import java.io.IOException;
import java.io.PrintStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.broadcast.Broadcast;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.tools.spark.pathseq.PSPathogenAlignmentHit;
import org.broadinstitute.hellbender.tools.spark.pathseq.PSPathogenTaxonScore;
import org.broadinstitute.hellbender.tools.spark.pathseq.PSScoreArgumentCollection;
import org.broadinstitute.hellbender.tools.spark.pathseq.PSTaxonomyDatabase;
import org.broadinstitute.hellbender.tools.spark.pathseq.PSTree;
import org.broadinstitute.hellbender.tools.spark.pathseq.PSUtils;
import org.broadinstitute.hellbender.tools.spark.sv.utils.SVUtils;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.gcs.BucketUtils;
import org.broadinstitute.hellbender.utils.read.GATKRead;
import scala.Tuple2;

/**
 * Scores PathSeq reads against a taxonomy database.
 *
 * <p>Reads are grouped into mate pairs (or singletons), each group is mapped to the set of
 * taxonomic IDs it hits, the hits are tagged onto the reads, and per-taxon scores are aggregated,
 * normalized and written to a tab-delimited scores file.
 */
public final class PSScorer {

    /** SAM attribute tag under which a read's comma-separated list of hit taxonomic IDs is stored. */
    public static final String HITS_TAG = "YP";
    /** Scale factor applied to scores when they are divided by reference genome length. */
    public static final double SCORE_GENOME_LENGTH_UNITS = 1000000.0;
    private static final Logger logger = LogManager.getLogger(PSScorer.class);
    /** Taxonomic ID used as the single normalization bucket and as the path terminator in {@link #assignKingdoms}. */
    private static final int ROOT_TAX_ID = 1;

    private final PSScoreArgumentCollection scoreArgs;

    public PSScorer(PSScoreArgumentCollection scoreArgs) {
        this.scoreArgs = scoreArgs;
    }

    /**
     * Runs the full scoring pipeline and writes the scores file to {@code scoreArgs.scoresPath}.
     *
     * @param ctx           Spark context used to broadcast the taxonomy database
     * @param pairedReads   paired reads (mates adjacent within each partition), or null
     * @param unpairedReads unpaired reads, or null
     * @param header        header whose sequence dictionary is checked against the database; may be null
     * @return all input reads, with {@link #HITS_TAG} set on reads that had at least one valid hit
     * @throws UserException.BadInput if both read sets are null
     */
    public JavaRDD<GATKRead> scoreReads(JavaSparkContext ctx, JavaRDD<GATKRead> pairedReads, JavaRDD<GATKRead> unpairedReads, SAMFileHeader header) {
        final JavaRDD<Iterable<GATKRead>> groupedReads =
                groupReadsIntoPairs(pairedReads, unpairedReads, scoreArgs.readsPerPartitionEstimate);

        final PSTaxonomyDatabase taxDB = readTaxonomyDatabase(scoreArgs.taxonomyDatabasePath);
        final Broadcast<PSTaxonomyDatabase> taxonomyDatabaseBroadcast = ctx.broadcast(taxDB);

        if (scoreArgs.headerWarningFile != null) {
            writeMissingReferenceAccessions(scoreArgs.headerWarningFile, header, taxDB, logger);
        }

        // NOTE(review): readHits is consumed by two lineages below and is not persisted here.
        final JavaRDD<Tuple2<Iterable<GATKRead>, PSPathogenAlignmentHit>> readHits = mapGroupedReadsToTax(
                groupedReads, scoreArgs.minIdentity, scoreArgs.identityMargin, taxonomyDatabaseBroadcast);
        final JavaRDD<GATKRead> readsFinal = flattenIterableKeys(readHits);

        final JavaRDD<PSPathogenAlignmentHit> alignmentHits = readHits.map(Tuple2::_2);
        final boolean divideByGenomeLength = scoreArgs.divideByGenomeLength;
        final JavaPairRDD<Integer, PSPathogenTaxonScore> taxScoresRdd = alignmentHits.mapPartitionsToPair(
                iter -> computeTaxScores(iter, taxonomyDatabaseBroadcast.value(), divideByGenomeLength));

        Map<Integer, PSPathogenTaxonScore> taxScoresMap =
                new HashMap<>(taxScoresRdd.reduceByKey(PSPathogenTaxonScore::add).collectAsMap());
        taxScoresMap = computeNormalizedScores(taxScoresMap, taxDB.tree, scoreArgs.notNormalizedByKingdom);
        writeScoresFile(taxScoresMap, taxDB.tree, scoreArgs.scoresPath);
        return readsFinal;
    }

    /** Collects the second element of every tuple in the RDD to the driver. */
    static <T, C> Iterable<C> collectValues(JavaRDD<Tuple2<T, C>> tupleRdd) {
        return tupleRdd.map(tuple -> tuple._2).collect();
    }

    /** Flattens the iterable first elements of the tuples into a single RDD, dropping the second elements. */
    static <T, C> JavaRDD<T> flattenIterableKeys(JavaRDD<Tuple2<Iterable<T>, C>> tupleRdd) {
        return tupleRdd.flatMap(tuple -> tuple._1.iterator());
    }

    /**
     * Groups reads into mate pairs (from {@code pairedReads}) and singletons (from {@code unpairedReads}).
     *
     * @param pairedReads            paired reads, or null
     * @param unpairedReads          unpaired reads, or null
     * @param readsPerPartitionGuess estimated reads per partition, used only to presize buffers
     * @return an RDD with one element per pair/singleton
     * @throws UserException.BadInput if both inputs are null
     */
    static JavaRDD<Iterable<GATKRead>> groupReadsIntoPairs(JavaRDD<GATKRead> pairedReads, JavaRDD<GATKRead> unpairedReads, int readsPerPartitionGuess) {
        JavaRDD<Iterable<GATKRead>> groupedReads;
        if (pairedReads != null) {
            groupedReads = pairedReads.mapPartitions(iter -> groupPairedReadsPartition(iter, readsPerPartitionGuess));
            if (unpairedReads != null) {
                groupedReads = groupedReads.union(unpairedReads.map(Collections::singletonList));
            }
        } else if (unpairedReads != null) {
            groupedReads = unpairedReads.map(Collections::singletonList);
        } else {
            throw new UserException.BadInput("No reads were loaded. Ensure --paired-input and/or --unpaired-input are set and valid.");
        }
        return groupedReads;
    }

    /**
     * Pairs consecutive reads of a partition. Each pair must be two adjacent reads with the same name.
     *
     * @throws UserException.BadInput if a read has no adjacent mate with a matching name
     */
    private static Iterator<Iterable<GATKRead>> groupPairedReadsPartition(Iterator<GATKRead> iter, int readsPerPartitionGuess) {
        // Clamp so a negative estimate cannot make the ArrayList constructor throw.
        final ArrayList<Iterable<GATKRead>> newPartitionList = new ArrayList<>(Math.max(readsPerPartitionGuess / 2, 0));
        while (iter.hasNext()) {
            final GATKRead read1 = iter.next();
            final GATKRead read2 = iter.hasNext() ? iter.next() : null;
            if (read2 == null || !read1.getName().equals(read2.getName())) {
                throw new UserException.BadInput("Found an unpaired read but expected all reads to be paired: " + read1.getName());
            }
            final List<GATKRead> pair = new ArrayList<>(2);
            pair.add(read1);
            pair.add(read2);
            newPartitionList.add(pair);
        }
        newPartitionList.trimToSize();
        return newPartitionList.iterator();
    }

    /**
     * Writes the header sequence names that have no entry in the taxonomy database, one per line.
     * Does nothing when the header or its sequence dictionary is absent.
     *
     * @param path   output file path
     * @param header SAM header to inspect; may be null
     * @param taxDB  taxonomy database providing the accession-to-taxID map
     * @param logger receives a warning if the output stream reports an error
     */
    public static void writeMissingReferenceAccessions(String path, SAMFileHeader header, PSTaxonomyDatabase taxDB, Logger logger) {
        if (header == null || header.getSequenceDictionary() == null || header.getSequenceDictionary().getSequences() == null) {
            return;
        }
        final Set<String> unknownSequences = header.getSequenceDictionary().getSequences().stream()
                .map(SAMSequenceRecord::getSequenceName)
                .filter(name -> !taxDB.accessionToTaxId.containsKey(name))
                .collect(Collectors.toSet());
        try (final PrintStream file = new PrintStream(BucketUtils.createFile(path))) {
            // println (not print) so the accessions are separated rather than concatenated.
            unknownSequences.forEach(file::println);
            if (file.checkError()) {
                logger.warn("Error writing to header warnings file");
            }
        }
    }

    /**
     * Maps each group of reads to the taxonomic IDs it hits and tags the reads with {@link #HITS_TAG}.
     *
     * <p>For groups with more than one read, only taxa hit by every read in the group are kept.
     *
     * @return tuples of (read group, hit info); the read group is the same object as the input
     */
    static JavaRDD<Tuple2<Iterable<GATKRead>, PSPathogenAlignmentHit>> mapGroupedReadsToTax(JavaRDD<Iterable<GATKRead>> pairs, double minIdentity, double identityMargin, Broadcast<PSTaxonomyDatabase> taxonomyDatabaseBroadcast) {
        return pairs.map(readIter -> {
            final PSTaxonomyDatabase taxDB = taxonomyDatabaseBroadcast.value();
            final int numReads = (int) Utils.stream(readIter).count();
            final Stream<Integer> taxIds = Utils.stream(readIter)
                    .flatMap(read -> getValidHits(read, taxDB, minIdentity, identityMargin).stream());
            final List<Integer> hitTaxIds;
            if (numReads > 1) {
                // getValidHits returns a Set, so each read counts a taxon at most once;
                // a count equal to numReads means every read in the group hit that taxon.
                final Map<Integer, Long> taxIdCounts = taxIds.collect(Collectors.groupingBy(e -> e, Collectors.counting()));
                hitTaxIds = taxIdCounts.entrySet().stream()
                        .filter(entry -> entry.getValue().longValue() == numReads)
                        .map(Map.Entry::getKey)
                        .collect(Collectors.toList());
            } else {
                hitTaxIds = taxIds.collect(Collectors.toList());
            }
            final PSPathogenAlignmentHit info = new PSPathogenAlignmentHit(hitTaxIds, numReads);
            if (!hitTaxIds.isEmpty()) {
                final String hitString = hitTaxIds.stream().map(String::valueOf).collect(Collectors.joining(","));
                Utils.stream(readIter).forEach(read -> read.setAttribute(HITS_TAG, hitString));
            }
            return new Tuple2<>(readIter, info);
        });
    }

    /**
     * Returns the taxonomic IDs of a read's best alignments.
     *
     * <p>Considers the primary alignment plus the alternate alignments in the XA and SA tags. It keeps those
     * with at least {@code minIdentity * readLength} matches. It then keeps those within {@code identityMargin}
     * (a fraction) of the best match count. Alignments whose contig has no taxID in the database are dropped.
     *
     * @throws UserException.BadInput if the read is mapped but lacks the NM tag
     */
    private static Set<Integer> getValidHits(GATKRead read, PSTaxonomyDatabase taxonomyDatabase, double minIdentity, double identityMargin) {
        if (read.isUnmapped()) {
            return Collections.emptySet();
        }
        if (!read.hasAttribute(SAMTag.NM.name())) {
            throw new UserException.BadInput("SAM flag indicates a read is mapped, but the NM tag is absent");
        }
        final List<PSPathogenHitAlignment> hits = new ArrayList<>();
        final double minIdentityBases = minIdentity * read.getLength();

        final int numMismatches = read.getAttributeAsInteger(SAMTag.NM.name());
        final int numMatches = PSUtils.getMatchesLessDeletions(read.getCigar(), numMismatches);
        if (numMatches >= minIdentityBases) {
            final String recordName = SAMSequenceRecord.truncateSequenceName(read.getAssignedContig());
            hits.add(new PSPathogenHitAlignment(numMatches, recordName, read.getCigar()));
        }

        // Token indices: XA entries are "contig,pos,CIGAR,NM"; SA entries are "contig,pos,strand,CIGAR,mapQ,NM".
        hits.addAll(getValidAlternateHits(read, "XA", 0, 2, 3, minIdentityBases));
        hits.addAll(getValidAlternateHits(read, "SA", 0, 3, 5, minIdentityBases));
        if (hits.isEmpty()) {
            return Collections.emptySet();
        }

        final int maxMatches = hits.stream().mapToInt(PSPathogenHitAlignment::getNumMatches).max().getAsInt();
        final double minIdentityBasesMargin = (1.0 - identityMargin) * maxMatches;
        return hits.stream()
                .filter(hit -> hit.getNumMatches() >= minIdentityBasesMargin)
                .map(hit -> taxonomyDatabase.accessionToTaxId.get(hit.getAccession()))
                .filter(Objects::nonNull)
                .collect(Collectors.toSet());
    }

    /**
     * Parses a ';'-separated alternate-alignment tag and returns the alignments with enough matches.
     *
     * @param tag                SAM tag to parse (e.g. "XA" or "SA")
     * @param contigIndex        index of the contig name within each ','-separated entry
     * @param cigarIndex         index of the CIGAR string within each entry
     * @param numMismatchesIndex index of the mismatch count within each entry
     * @param minIdentityBases   minimum number of matches required to keep an alignment
     * @return the qualifying alignments; empty if the tag is absent
     * @throws UserException.BadInput if an entry has too few fields
     */
    private static List<PSPathogenHitAlignment> getValidAlternateHits(GATKRead read, String tag, int contigIndex, int cigarIndex, int numMismatchesIndex, double minIdentityBases) {
        final List<PSPathogenHitAlignment> alternateHits = new ArrayList<>();
        if (!read.hasAttribute(tag)) {
            return alternateHits;
        }
        final int expectedTokens = Math.max(contigIndex, Math.max(cigarIndex, numMismatchesIndex)) + 1;
        for (final String tok : read.getAttributeAsString(tag).split(";")) {
            final String[] subtokens = tok.split(",");
            if (subtokens.length < expectedTokens) {
                throw new UserException.BadInput("Error parsing " + tag + " tag: expected at least " + expectedTokens + " values per alignment but found " + subtokens.length);
            }
            final int numMismatches = Integer.parseInt(subtokens[numMismatchesIndex]);
            final Cigar cigar = TextCigarCodec.decode(subtokens[cigarIndex]);
            final int numMatches = PSUtils.getMatchesLessDeletions(cigar, numMismatches);
            if (numMatches >= minIdentityBases) {
                final String recordName = SAMSequenceRecord.truncateSequenceName(subtokens[contigIndex]);
                alternateHits.add(new PSPathogenHitAlignment(numMatches, recordName, cigar));
            }
        }
        return alternateHits;
    }

    /**
     * Accumulates per-taxon scores for a partition of alignment hits.
     *
     * <p>Taxa missing from the tree or with reference length 0 are ignored, and a single warning is logged for them.
     * For each remaining hit:
     * <ul>
     *   <li>every node on the path of the hits' lowest common ancestor receives the mates as unambiguous reads;</li>
     *   <li>each hit taxon receives {@code numMates / numHits} as self score (optionally scaled by genome length),
     *       and the other nodes on its path receive it as descendant score;</li>
     *   <li>every distinct node on any hit path receives the mates as total reads.</li>
     * </ul>
     *
     * @return an iterator over (taxID, score) pairs for this partition
     */
    public static Iterator<Tuple2<Integer, PSPathogenTaxonScore>> computeTaxScores(Iterator<PSPathogenAlignmentHit> taxonHits, PSTaxonomyDatabase taxonomyDatabase, boolean divideByGenomeLength) {
        final PSTree tree = taxonomyDatabase.tree;
        final Map<Integer, PSPathogenTaxonScore> taxIdsToScores = new HashMap<>();
        final Set<Integer> invalidIds = new HashSet<>();
        while (taxonHits.hasNext()) {
            final PSPathogenAlignmentHit hit = taxonHits.next();
            final Set<Integer> hitTaxIds = new HashSet<>(hit.taxIDs);

            final Set<Integer> hitInvalidTaxIds = new HashSet<>(SVUtils.hashMapCapacity(hitTaxIds.size()));
            for (final int taxId : hitTaxIds) {
                if (!tree.hasNode(taxId) || tree.getLengthOf(taxId) == 0L) {
                    hitInvalidTaxIds.add(taxId);
                }
            }
            hitTaxIds.removeAll(hitInvalidTaxIds);
            invalidIds.addAll(hitInvalidTaxIds);

            final int numHits = hitTaxIds.size();
            if (numHits == 0) {
                continue;
            }

            final int lowestCommonAncestor = tree.getLCA(hitTaxIds);
            for (final int taxId : tree.getPathOf(lowestCommonAncestor)) {
                getOrAddScoreInfo(taxId, taxIdsToScores, tree).addUnambiguousReads(hit.numMates);
            }

            final Set<Integer> hitPathNodes = new HashSet<>();
            for (final int taxId : hitTaxIds) {
                double score = (double) hit.numMates / numHits;
                if (divideByGenomeLength) {
                    score *= SCORE_GENOME_LENGTH_UNITS / tree.getLengthOf(taxId);
                }
                final List<Integer> path = tree.getPathOf(taxId);
                hitPathNodes.addAll(path);
                for (final int pathTaxId : path) {
                    // getOrAddScoreInfo already stores the entry, so no explicit put is needed.
                    final PSPathogenTaxonScore info = getOrAddScoreInfo(pathTaxId, taxIdsToScores, tree);
                    if (pathTaxId == taxId) {
                        info.addSelfScore(score);
                    } else {
                        info.addDescendentScore(score);
                    }
                }
            }
            for (final int taxId : hitPathNodes) {
                getOrAddScoreInfo(taxId, taxIdsToScores, tree).addTotalReads(hit.numMates);
            }
        }
        PSUtils.logItemizedWarning(logger, invalidIds, "The following taxonomic ID hits were ignored because they either could not be found in the tree or had a reference length of 0 (this may happen when the catalog file, taxdump file, and/or pathogen reference are inconsistent)");
        return taxIdsToScores.entrySet().stream()
                .map(entry -> new Tuple2<Integer, PSPathogenTaxonScore>(entry.getKey(), entry.getValue()))
                .iterator();
    }

    /**
     * Adds normalized scores to each taxon in place.
     *
     * <p>Each taxon's self score is expressed as a percentage of its kingdom's summed self scores (or of
     * the root's, when {@code notNormalizedByKingdom} is set). That percentage is added to every node on
     * the taxon's path.
     *
     * @return the same map instance that was passed in, after mutation
     */
    static final Map<Integer, PSPathogenTaxonScore> computeNormalizedScores(Map<Integer, PSPathogenTaxonScore> taxIdsToScores, PSTree tree, boolean notNormalizedByKingdom) {
        final Map<Integer, Double> normalizationSums = new HashMap<>();
        assignKingdoms(taxIdsToScores, normalizationSums, tree, notNormalizedByKingdom);
        for (final Map.Entry<Integer, PSPathogenTaxonScore> entry : taxIdsToScores.entrySet()) {
            final int taxId = entry.getKey();
            final PSPathogenTaxonScore score = entry.getValue();
            // NOTE(review): unboxing throws if assignKingdoms found no kingdom/root on this taxon's path.
            final double kingdomSum = normalizationSums.get(score.getKingdomTaxonId());
            final double normalizedScore = kingdomSum == 0.0 ? 0.0 : 100.0 * score.getSelfScore() / kingdomSum;
            for (final int pathTaxId : tree.getPathOf(taxId)) {
                taxIdsToScores.get(pathTaxId).addScoreNormalized(normalizedScore);
            }
        }
        return taxIdsToScores;
    }

    /**
     * Assigns each taxon's normalization bucket and accumulates the bucket's self-score sum.
     *
     * <p>The bucket is the first node on the taxon's path that has rank "kingdom" or "superkingdom", or is the root.
     * When {@code notNormalizedByKingdom} is set, the root is always used.
     */
    private static void assignKingdoms(Map<Integer, PSPathogenTaxonScore> taxIdsToScores, Map<Integer, Double> normalizationSums, PSTree tree, boolean notNormalizedByKingdom) {
        for (final Map.Entry<Integer, PSPathogenTaxonScore> entry : taxIdsToScores.entrySet()) {
            final PSPathogenTaxonScore score = entry.getValue();
            if (notNormalizedByKingdom) {
                normalizationSums.merge(ROOT_TAX_ID, score.getSelfScore(), Double::sum);
                score.setKingdomTaxonId(ROOT_TAX_ID);
                continue;
            }
            for (final int nodeId : tree.getPathOf(entry.getKey())) {
                final String rank = tree.getRankOf(nodeId);
                if ("kingdom".equals(rank) || "superkingdom".equals(rank) || nodeId == ROOT_TAX_ID) {
                    normalizationSums.merge(nodeId, score.getSelfScore(), Double::sum);
                    score.setKingdomTaxonId(nodeId);
                    break;
                }
            }
        }
    }

    /** Returns the score for {@code taxId}, creating and storing one with the tree's reference length if absent. */
    private static PSPathogenTaxonScore getOrAddScoreInfo(int taxId, Map<Integer, PSPathogenTaxonScore> taxScores, PSTree tree) {
        return taxScores.computeIfAbsent(taxId, id -> {
            final PSPathogenTaxonScore score = new PSPathogenTaxonScore();
            score.setReferenceLength(tree.getLengthOf(id));
            return score;
        });
    }

    /**
     * Deserializes a Kryo-encoded {@link PSTaxonomyDatabase} from the given path.
     * The input is closed even if deserialization fails.
     */
    public static PSTaxonomyDatabase readTaxonomyDatabase(String filePath) {
        final Kryo kryo = new Kryo();
        kryo.setReferences(false);
        try (final Input input = new Input(BucketUtils.openFile(filePath))) {
            return kryo.readObject(input, PSTaxonomyDatabase.class);
        }
    }

    /**
     * Writes a tab-delimited scores table: tax_id, '|'-joined taxonomy (root first), rank, name, then score columns.
     * Spaces in each line are replaced with underscores.
     *
     * @throws UserException.CouldNotCreateOutputFile if the output stream reports an error
     */
    public static void writeScoresFile(Map<Integer, PSPathogenTaxonScore> scores, PSTree tree, String filePath) {
        final String header = "tax_id\ttaxonomy\ttype\tname\t" + PSPathogenTaxonScore.outputHeader;
        try (final PrintStream printStream = new PrintStream(BucketUtils.createFile(filePath))) {
            printStream.println(header);
            for (final Map.Entry<Integer, PSPathogenTaxonScore> entry : scores.entrySet()) {
                final int key = entry.getKey();
                final String name = tree.getNameOf(key);
                final String rank = tree.getRankOf(key);
                final List<String> path = tree.getPathOf(key).stream().map(tree::getNameOf).collect(Collectors.toList());
                Collections.reverse(path);
                final String taxonomy = String.join("|", path);
                final String line = key + "\t" + taxonomy + "\t" + rank + "\t" + name + "\t" + entry.getValue().toString(tree);
                printStream.println(line.replace(" ", "_"));
            }
            if (printStream.checkError()) {
                throw new UserException.CouldNotCreateOutputFile(filePath, new IOException("Error writing to scores file"));
            }
        }
    }

    /** A single candidate alignment: its match count, reference accession and CIGAR. */
    private static final class PSPathogenHitAlignment {
        private final int numMatches;
        private final String accession;
        private final Cigar cigar;

        public PSPathogenHitAlignment(int numMatches, String accession, Cigar cigar) {
            this.numMatches = numMatches;
            this.accession = accession;
            this.cigar = cigar;
        }

        public int getNumMatches() {
            return this.numMatches;
        }

        public String getAccession() {
            return this.accession;
        }

        public Cigar getCigar() {
            return this.cigar;
        }
    }
}

