/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.hellbender.tools.spark.sv.evidence;

import com.google.common.annotations.VisibleForTesting;
import htsjdk.samtools.util.SequenceUtil;
import java.io.BufferedWriter;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.io.Writer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.broadinstitute.hellbender.exceptions.GATKException;
import org.broadinstitute.hellbender.tools.spark.sv.evidence.AlignedAssemblyOrExcuse;
import org.broadinstitute.hellbender.tools.spark.sv.evidence.FindBreakpointEvidenceSpark;
import org.broadinstitute.hellbender.tools.spark.sv.utils.SVFastqUtils;
import org.broadinstitute.hellbender.tools.spark.sv.utils.SVKmer;
import org.broadinstitute.hellbender.tools.spark.sv.utils.SVKmerShort;
import org.broadinstitute.hellbender.tools.spark.sv.utils.SVKmerizer;
import org.broadinstitute.hellbender.tools.spark.sv.utils.SVUtils;
import org.broadinstitute.hellbender.tools.spark.utils.HopscotchMultiMap;
import org.broadinstitute.hellbender.utils.BaseUtils;
import org.broadinstitute.hellbender.utils.bwa.BwaMemAligner;
import org.broadinstitute.hellbender.utils.bwa.BwaMemIndexCache;
import org.broadinstitute.hellbender.utils.fermi.FermiLiteAssembler;
import org.broadinstitute.hellbender.utils.fermi.FermiLiteAssembly;
import org.broadinstitute.hellbender.utils.gcs.BucketUtils;
import scala.Tuple2;

/**
 * Local-assembly handler that assembles the reads of one interval with fermi-lite, optionally
 * revises the resulting assembly graph (removing shadowed contigs and/or expanding the graph into
 * paths), and aligns the final contigs with BWA-MEM.
 */
public final class FermiLiteAssemblyHandler
implements FindBreakpointEvidenceSpark.LocalAssemblyHandler {
    private static final long serialVersionUID = 1L;

    /** fermi-lite graph-cleaning flag bits named after fermi-lite's MAG_F_* constants; 0x20 | 0x40 == 96. */
    private static final int MAG_F_AGGRESSIVE = 0x20;
    private static final int MAG_F_POPOPEN = 0x40;

    /** K-mer size used to seed candidate placements when searching for shadowed contigs. */
    private static final int SHADOW_KMER_SIZE = 31;
    /** Fraction of a contig's length (truncated to an int) allowed to mismatch for it to count as shadowed. */
    private static final double SHADOW_MAX_MISMATCH_RATE = 0.01;

    private final String alignerIndexFile;
    private final int maxFastqSize;
    private final String fastqDir;
    private final boolean writeGFAs;
    private final boolean popVariantBubbles;
    private final boolean removeShadowedContigs;
    private final boolean expandAssemblyGraph;
    private final int zDropoff;

    /**
     * @param alignerIndexFile      BWA-MEM index image, obtained through {@code BwaMemIndexCache.getInstance}
     * @param maxFastqSize          intervals whose reads hold more bases than this are not assembled
     * @param fastqDir              if non-null, directory into which each interval's reads (and optionally GFA) are written
     * @param writeGFAs             write the revised assembly graph as GFA (only when {@code fastqDir} is non-null)
     * @param popVariantBubbles     ask fermi-lite to use the MAG_F_AGGRESSIVE | MAG_F_POPOPEN cleaning flags
     * @param removeShadowedContigs drop contigs fully contained (within 1% mismatches) in another contig
     * @param expandAssemblyGraph   merge unbranched connections and expand the graph into path contigs
     * @param zDropoff              value passed to the aligner's z-drop option
     */
    public FermiLiteAssemblyHandler(String alignerIndexFile, int maxFastqSize, String fastqDir, boolean writeGFAs, boolean popVariantBubbles, boolean removeShadowedContigs, boolean expandAssemblyGraph, int zDropoff) {
        this.alignerIndexFile = alignerIndexFile;
        this.maxFastqSize = maxFastqSize;
        this.fastqDir = fastqDir;
        this.writeGFAs = writeGFAs;
        this.popVariantBubbles = popVariantBubbles;
        this.removeShadowedContigs = removeShadowedContigs;
        this.expandAssemblyGraph = expandAssemblyGraph;
        this.zDropoff = zDropoff;
    }

    /**
     * Assembles, revises and aligns the reads of a single interval.
     *
     * @param intervalAndReads the interval ID paired with the reads to assemble
     * @return the aligned assembly, or an excuse when the reads are too large or fermi-lite produced no contigs
     * @throws GATKException if a requested GFA file cannot be written
     */
    @Override
    public AlignedAssemblyOrExcuse apply(final Tuple2<Integer, List<SVFastqUtils.FastqRead>> intervalAndReads) {
        final int intervalID = intervalAndReads._1();
        final String assemblyName = AlignedAssemblyOrExcuse.formatAssemblyID(intervalID);
        final List<SVFastqUtils.FastqRead> readsList = intervalAndReads._2();

        // Total number of read bases; intervals above the limit are skipped rather than assembled.
        final int fastqSize = readsList.stream().mapToInt(read -> read.getBases().length).sum();
        if (fastqSize > this.maxFastqSize) {
            return new AlignedAssemblyOrExcuse(intervalID, "no assembly -- too big (" + fastqSize + " bytes).");
        }

        if (this.fastqDir != null) {
            // Sort by header so the written FASTQ is deterministic regardless of input order.
            final String fastqName = String.format("%s/%s.fastq", this.fastqDir, assemblyName);
            final List<SVFastqUtils.FastqRead> sortedReads = new ArrayList<>(readsList);
            sortedReads.sort(Comparator.comparing(SVFastqUtils.FastqRead::getHeader));
            SVFastqUtils.writeFastqFile(fastqName, sortedReads.iterator());
        }

        final FermiLiteAssembler assembler = new FermiLiteAssembler();
        if (this.popVariantBubbles) {
            assembler.setCleaningFlag(MAG_F_AGGRESSIVE | MAG_F_POPOPEN);
        }
        final long timeStart = System.currentTimeMillis();
        final FermiLiteAssembly initialAssembly = assembler.createAssembly(readsList);
        // Elapsed wall-clock time, rounded to the nearest second.
        final int secondsInAssembly = (int) ((System.currentTimeMillis() - timeStart + 500L) / 1000L);
        if (initialAssembly.getNContigs() == 0) {
            return new AlignedAssemblyOrExcuse(intervalID, "no assembly -- no contigs produced by assembler.");
        }

        final FermiLiteAssembly assembly =
                FermiLiteAssemblyHandler.reviseAssembly(initialAssembly, this.removeShadowedContigs, this.expandAssemblyGraph);

        if (this.fastqDir != null && this.writeGFAs) {
            final String gfaName = String.format("%s/%s.gfa", this.fastqDir, assemblyName);
            try (final BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(BucketUtils.createFile(gfaName)))) {
                assembly.writeGFA(writer);
            } catch (final IOException ioe) {
                throw new GATKException("Can't write " + gfaName, ioe);
            }
        }

        try (final BwaMemAligner aligner = new BwaMemAligner(BwaMemIndexCache.getInstance(this.alignerIndexFile))) {
            aligner.setIntraCtgOptions();
            aligner.setZDropOption(this.zDropoff);
            final List<byte[]> sequences = assembly.getContigs().stream()
                    .map(FermiLiteAssembly.Contig::getSequence)
                    .collect(SVUtils.arrayListCollector(assembly.getNContigs()));
            return new AlignedAssemblyOrExcuse(intervalID, assembly, secondsInAssembly, aligner.alignSeqs(sequences));
        }
    }

    /**
     * Applies the optional graph revisions in a fixed order. Shadowed-contig removal runs first.
     * Graph expansion then runs on the result, after unbranched connections are merged.
     */
    @VisibleForTesting
    static FermiLiteAssembly reviseAssembly(final FermiLiteAssembly initialAssembly, final boolean removeShadowedContigs, final boolean expandAssemblyGraph) {
        final FermiLiteAssembly unshadowedAssembly =
                removeShadowedContigs ? FermiLiteAssemblyHandler.removeShadowedContigs(initialAssembly) : initialAssembly;
        if (!expandAssemblyGraph) {
            return unshadowedAssembly;
        }
        return FermiLiteAssemblyHandler.expandAssemblyGraph(FermiLiteAssemblyHandler.removeUnbranchedConnections(unshadowedAssembly));
    }

    /**
     * Removes every contig whose entire sequence matches some window of another contig on either strand.
     * At most {@code (int)(length * 0.01)} mismatches are tolerated. Candidate placements are seeded by
     * shared canonical 31-mers.
     *
     * Contigs are examined in list order. A contig already slated for removal cannot shadow another, so
     * of two identical contigs only the first is removed. Surviving contigs that were connected to a
     * removed contig have those connections dropped, and their connection lists are replaced in place.
     */
    @VisibleForTesting
    static FermiLiteAssembly removeShadowedContigs(final FermiLiteAssembly assembly) {
        final List<FermiLiteAssembly.Contig> contigs = assembly.getContigs();

        // A contig shorter than the k-mer size yields no k-mers, so it must not reduce the capacity.
        final int capacity = contigs.stream()
                .mapToInt(tig -> Math.max(0, tig.getSequence().length - SHADOW_KMER_SIZE + 1))
                .sum();
        final HopscotchMultiMap<SVKmerShort, ContigLocation, KmerLocation> kmerMap = new HopscotchMultiMap<>(capacity);
        for (final FermiLiteAssembly.Contig tig : contigs) {
            // NOTE(review): offsets assume SVKmerizer yields one k-mer per consecutive position
            // (no skipped bases) -- confirm for sequences that may contain N.
            int contigOffset = 0;
            final SVKmerizer kmerItr = new SVKmerizer(tig.getSequence(), SHADOW_KMER_SIZE, new SVKmerShort());
            while (kmerItr.hasNext()) {
                final SVKmerShort kmer = (SVKmerShort) kmerItr.next();
                final SVKmerShort canonicalKmer = kmer.canonical(SHADOW_KMER_SIZE);
                final ContigLocation location = new ContigLocation(tig, contigOffset++, kmer.equals(canonicalKmer));
                kmerMap.add(new KmerLocation(canonicalKmer, location));
            }
        }

        final Set<FermiLiteAssembly.Contig> contigsToRemove = new HashSet<>();
        for (final FermiLiteAssembly.Contig tig : contigs) {
            if (isShadowed(tig, kmerMap, contigsToRemove)) {
                contigsToRemove.add(tig);
            }
        }

        final List<FermiLiteAssembly.Contig> contigList = new ArrayList<>(contigs.size() - contigsToRemove.size());
        for (final FermiLiteAssembly.Contig tig : contigs) {
            if (!contigsToRemove.contains(tig)) {
                contigList.add(tig);
            }
        }

        // Every contig reachable from a removed contig may hold a connection back to it; prune those.
        final Set<FermiLiteAssembly.Contig> staleConnectionContigs = new HashSet<>(SVUtils.hashMapCapacity(contigsToRemove.size()));
        for (final FermiLiteAssembly.Contig removed : contigsToRemove) {
            for (final FermiLiteAssembly.Connection conn : removed.getConnections()) {
                staleConnectionContigs.add(conn.getTarget());
            }
        }
        for (final FermiLiteAssembly.Contig tig : staleConnectionContigs) {
            final List<FermiLiteAssembly.Connection> oldConnections = tig.getConnections();
            // Sized to the current list: a contig with no back-connection must not yield a negative capacity.
            final List<FermiLiteAssembly.Connection> connections = new ArrayList<>(oldConnections.size());
            for (final FermiLiteAssembly.Connection conn : oldConnections) {
                if (!contigsToRemove.contains(conn.getTarget())) {
                    connections.add(conn);
                }
            }
            tig.setConnections(connections);
        }
        return new FermiLiteAssembly(contigList);
    }

    /**
     * Decides whether {@code tig} is contained in another contig that is not already slated for removal.
     * Every shared canonical k-mer proposes a placement of {@code tig} inside the other contig. The first
     * placement with at most the tolerated number of mismatches returns {@code true} right away.
     */
    private static boolean isShadowed(final FermiLiteAssembly.Contig tig,
                                      final HopscotchMultiMap<SVKmerShort, ContigLocation, KmerLocation> kmerMap,
                                      final Set<FermiLiteAssembly.Contig> contigsToRemove) {
        final byte[] tigBases = tig.getSequence();
        final int maxMismatches = (int) ((double) tigBases.length * SHADOW_MAX_MISMATCH_RATE);
        final Set<ContigLocation> testedLocations = new HashSet<>();
        final SVKmerizer kmerItr = new SVKmerizer(tigBases, SHADOW_KMER_SIZE, new SVKmerShort(SHADOW_KMER_SIZE));
        int tigOffset = 0;
        while (kmerItr.hasNext()) {
            final SVKmerShort contigKmer = (SVKmerShort) kmerItr.next();
            final SVKmerShort canonicalContigKmer = contigKmer.canonical(SHADOW_KMER_SIZE);
            final boolean contigKmerIsCanonical = contigKmer.equals(canonicalContigKmer);
            final Iterator<KmerLocation> locItr = kmerMap.findEach(canonicalContigKmer);
            while (locItr.hasNext()) {
                final ContigLocation tig2Location = locItr.next().getLocation();
                final FermiLiteAssembly.Contig tig2 = tig2Location.getContig();
                if (tig == tig2 || contigsToRemove.contains(tig2)) {
                    continue;
                }
                final byte[] tig2Bases = tig2.getSequence();
                // The k-mers match on opposite strands when exactly one of them is in canonical orientation.
                final boolean isRC = contigKmerIsCanonical != tig2Location.isCanonical();
                // Offset of the shared k-mer in tig2, measured on the strand that aligns with tig.
                final int tig2Offset = isRC ? tig2Bases.length - tig2Location.getOffset() - SHADOW_KMER_SIZE : tig2Location.getOffset();
                // tig must fit entirely inside tig2 when the shared k-mers are lined up.
                if (tigOffset > tig2Offset || tigBases.length - tigOffset > tig2Bases.length - tig2Offset) {
                    continue;
                }
                final int tig2Start = tig2Offset - tigOffset;
                // Many shared k-mers imply the same placement; compare each placement only once.
                if (!testedLocations.add(new ContigLocation(tig2, tig2Start, isRC))) {
                    continue;
                }
                if (countMismatches(tigBases, tig2Bases, tig2Start, isRC, maxMismatches) <= maxMismatches) {
                    return true;
                }
            }
            ++tigOffset;
        }
        return false;
    }

    /**
     * Counts mismatches between {@code tigBases} and the window of {@code tig2Bases} starting at
     * {@code tig2Start}. That start is measured on tig2's reverse-complement strand when {@code isRC}.
     * Counting stops as soon as the count exceeds {@code maxMismatches}, so any value above the limit
     * only means "too many".
     */
    private static int countMismatches(final byte[] tigBases, final byte[] tig2Bases, final int tig2Start,
                                       final boolean isRC, final int maxMismatches) {
        final int tig2RCOffset = tig2Bases.length - tig2Start - 1;
        int nMismatches = 0;
        for (int idx = 0; idx != tigBases.length && nMismatches <= maxMismatches; ++idx) {
            final byte tig2Base = isRC
                    ? BaseUtils.simpleComplement(tig2Bases[tig2RCOffset - idx])
                    : tig2Bases[tig2Start + idx];
            if (tigBases[idx] != tig2Base) {
                ++nMismatches;
            }
        }
        return nMismatches;
    }

    /**
     * Merges chains of contigs linked by unbranched connections.
     *
     * For each contig not yet examined, the contig first absorbs its sole predecessor repeatedly, then
     * its sole successor repeatedly (per {@code getSolePredecessor}/{@code getSoleSuccessor}). This
     * continues while the neighbour is unexamined and has a single connection back, as reported by
     * {@code getSingletonConnection}. The connection lists of neighbouring contigs are rewritten in place
     * so that they point at the merged contig.
     */
    @VisibleForTesting
    static FermiLiteAssembly removeUnbranchedConnections(final FermiLiteAssembly assembly) {
        final int nContigs = assembly.getNContigs();
        final List<FermiLiteAssembly.Contig> contigList = new ArrayList<>(nContigs);
        final Set<FermiLiteAssembly.Contig> examined = new HashSet<>(SVUtils.hashMapCapacity(nContigs));
        for (final FermiLiteAssembly.Contig contig : assembly.getContigs()) {
            if (!examined.add(contig)) {
                continue;
            }
            FermiLiteAssembly.Contig tig = contig;
            FermiLiteAssembly.Connection conn;
            FermiLiteAssembly.Connection rcConn;
            while ((conn = tig.getSolePredecessor()) != null
                    && !examined.contains(conn.getTarget())
                    && (rcConn = conn.getTarget().getSingletonConnection(!conn.isTargetRC())) != null) {
                examined.add(conn.getTarget());
                tig = FermiLiteAssemblyHandler.joinContigsWithConnections(tig, conn, rcConn);
            }
            while ((conn = tig.getSoleSuccessor()) != null
                    && !examined.contains(conn.getTarget())
                    && (rcConn = conn.getTarget().getSingletonConnection(!conn.isTargetRC())) != null) {
                examined.add(conn.getTarget());
                tig = FermiLiteAssemblyHandler.joinContigsWithConnections(tig, conn, rcConn);
            }
            contigList.add(tig);
        }
        return new FermiLiteAssembly(contigList);
    }

    /**
     * Joins {@code firstContig} with the target of {@code connection} into a new contig, and moves the
     * remaining connections of both contigs onto it.
     *
     * Connections of {@code firstContig}, other than {@code connection}, become predecessors (isRC = true).
     * Connections of the target, other than {@code rcConnection}, become successors (isRC = false). The
     * matching reverse connection on each neighbour is replaced so that it points at the joined contig.
     */
    private static FermiLiteAssembly.Contig joinContigsWithConnections(final FermiLiteAssembly.Contig firstContig,
                                                                       final FermiLiteAssembly.Connection connection,
                                                                       final FermiLiteAssembly.Connection rcConnection) {
        final FermiLiteAssembly.Contig joinedContig =
                FermiLiteAssemblyHandler.joinContigs(firstContig, Collections.singletonList(connection));
        final FermiLiteAssembly.Contig lastContig = connection.getTarget();
        final int capacity = firstContig.getConnections().size() + lastContig.getConnections().size() - 2;
        final List<FermiLiteAssembly.Connection> connections = new ArrayList<>(capacity);
        for (final FermiLiteAssembly.Connection conn : firstContig.getConnections()) {
            if (conn == connection) {
                continue;
            }
            final FermiLiteAssembly.Connection newConnection =
                    new FermiLiteAssembly.Connection(conn.getTarget(), conn.getOverlapLen(), true, conn.isTargetRC());
            FermiLiteAssemblyHandler.replaceConnection(conn.getTarget(), conn.rcConnection(firstContig), newConnection.rcConnection(joinedContig));
            connections.add(newConnection);
        }
        for (final FermiLiteAssembly.Connection conn : lastContig.getConnections()) {
            if (conn == rcConnection) {
                continue;
            }
            final FermiLiteAssembly.Connection newConnection =
                    new FermiLiteAssembly.Connection(conn.getTarget(), conn.getOverlapLen(), false, conn.isTargetRC());
            FermiLiteAssemblyHandler.replaceConnection(conn.getTarget(), conn.rcConnection(lastContig), newConnection.rcConnection(joinedContig));
            connections.add(newConnection);
        }
        joinedContig.setConnections(connections);
        return joinedContig;
    }

    /**
     * Rebuilds {@code contig}'s connection list, swapping in {@code newConnection} for every connection
     * that matches {@code oldConnection}. Matching uses target identity and both strand flags; overlap
     * length is not compared.
     */
    private static void replaceConnection(final FermiLiteAssembly.Contig contig,
                                          final FermiLiteAssembly.Connection oldConnection,
                                          final FermiLiteAssembly.Connection newConnection) {
        final List<FermiLiteAssembly.Connection> oldConnections = contig.getConnections();
        final List<FermiLiteAssembly.Connection> newConnections = new ArrayList<>(oldConnections.size());
        for (final FermiLiteAssembly.Connection conn : oldConnections) {
            final boolean matches = conn.getTarget() == oldConnection.getTarget()
                    && conn.isRC() == oldConnection.isRC()
                    && conn.isTargetRC() == oldConnection.isTargetRC();
            newConnections.add(matches ? newConnection : conn);
        }
        contig.setConnections(newConnections);
    }

    /**
     * Builds a new contig by concatenating {@code firstContig} with each target along {@code path}.
     *
     * Each target's overlap with its predecessor is trimmed off. If the first connection is an RC
     * connection, {@code firstContig} is reverse-complemented first, and targets reached on the RC
     * strand are reverse-complemented as well. The new contig's supporting-read count is the sum over
     * all pieces, and it carries no per-base coverage. An empty path returns {@code firstContig} itself.
     */
    private static FermiLiteAssembly.Contig joinContigs(final FermiLiteAssembly.Contig firstContig,
                                                        final List<FermiLiteAssembly.Connection> path) {
        if (path.isEmpty()) {
            return firstContig;
        }
        final int nSupportingReads = path.stream()
                .mapToInt(conn -> conn.getTarget().getNSupportingReads())
                .reduce(firstContig.getNSupportingReads(), Integer::sum);
        final int newContigLen = path.stream()
                .mapToInt(conn -> conn.getTarget().getSequence().length - conn.getOverlapLen())
                .reduce(firstContig.getSequence().length, Integer::sum);
        final byte[] sequence = new byte[newContigLen];
        int destinationOffset = firstContig.getSequence().length;
        System.arraycopy(firstContig.getSequence(), 0, sequence, 0, destinationOffset);
        if (path.get(0).isRC()) {
            SequenceUtil.reverseComplement(sequence, 0, destinationOffset);
        }
        for (final FermiLiteAssembly.Connection conn : path) {
            final byte[] contigSequence = conn.getTarget().getSequence();
            final int len = contigSequence.length - conn.getOverlapLen();
            if (!conn.isTargetRC()) {
                // Forward strand: the overlap sits at the start of the target, so skip it.
                System.arraycopy(contigSequence, conn.getOverlapLen(), sequence, destinationOffset, len);
            } else {
                // RC strand: the overlap sits at the end of the forward sequence; copy the rest, then flip it.
                System.arraycopy(contigSequence, 0, sequence, destinationOffset, len);
                SequenceUtil.reverseComplement(sequence, destinationOffset, len);
            }
            destinationOffset += len;
        }
        return new FermiLiteAssembly.Contig(sequence, null, nSupportingReads);
    }

    /**
     * Replaces the assembly with the contigs spelled out by walking paths through its graph.
     *
     * Contigs with no connections are kept as they are. Walks start forward from contigs that have no
     * predecessors, and on the RC strand from contigs that have no successors. Any contig still
     * unexamined after that, for example one reachable only within a cycle, then starts a forward walk.
     */
    @VisibleForTesting
    static FermiLiteAssembly expandAssemblyGraph(final FermiLiteAssembly assembly) {
        final int nContigs = assembly.getNContigs();
        final List<FermiLiteAssembly.Contig> contigList = new ArrayList<>(nContigs);
        final Set<ContigStrand> visited = new HashSet<>();
        final Set<FermiLiteAssembly.Contig> examined = new HashSet<>(SVUtils.hashMapCapacity(nContigs));
        for (final FermiLiteAssembly.Contig tig : assembly.getContigs()) {
            if (examined.contains(tig)) {
                continue;
            }
            if (tig.getConnections().isEmpty()) {
                contigList.add(tig);
                examined.add(tig);
            } else {
                final int nPredecessors = FermiLiteAssemblyHandler.countPredecessors(tig);
                final int nSuccessors = tig.getConnections().size() - nPredecessors;
                if (nPredecessors == 0) {
                    FermiLiteAssemblyHandler.tracePaths(tig, false, contigList, examined, visited);
                } else if (nSuccessors == 0) {
                    FermiLiteAssemblyHandler.tracePaths(tig, true, contigList, examined, visited);
                }
            }
        }
        for (final FermiLiteAssembly.Contig tig : assembly.getContigs()) {
            if (!examined.contains(tig)) {
                FermiLiteAssemblyHandler.tracePaths(tig, false, contigList, examined, visited);
            }
        }
        return new FermiLiteAssembly(contigList);
    }

    /** Number of connections leaving {@code contig}'s RC strand, which this class treats as predecessors. */
    private static int countPredecessors(final FermiLiteAssembly.Contig contig) {
        return contig.getConnections().stream().mapToInt(conn -> conn.isRC() ? 1 : 0).sum();
    }

    /**
     * Emits one joined contig per path that starts at {@code contig} on the given strand. Only
     * connections leaving that strand with a non-negative overlap are followed. {@code visited} holds
     * the contig strands on the current walk, which is how cycles are detected.
     */
    private static void tracePaths(final FermiLiteAssembly.Contig contig, final boolean isRC,
                                   final List<FermiLiteAssembly.Contig> contigList,
                                   final Set<FermiLiteAssembly.Contig> examined,
                                   final Set<ContigStrand> visited) {
        examined.add(contig);
        final LinkedList<FermiLiteAssembly.Connection> path = new LinkedList<>();
        final ContigStrand contigStrand = new ContigStrand(contig, isRC);
        final boolean isCycle = !visited.add(contigStrand);
        for (final FermiLiteAssembly.Connection connection : contig.getConnections()) {
            if (connection.isRC() != isRC || connection.getOverlapLen() < 0) {
                continue;
            }
            FermiLiteAssemblyHandler.extendPath(contig, connection, path, contigList, examined, visited);
        }
        if (!isCycle) {
            visited.remove(contigStrand);
        }
    }

    /**
     * Depth-first extension of {@code path} by {@code connection}. The joined path is emitted when the
     * walk stops. That happens in three cases:
     * <ul>
     *   <li>the target strand is already on the current walk (a cycle);</li>
     *   <li>the target has no eligible onward connection;</li>
     *   <li>the target has more than one predecessor and more than one successor. The path then ends at
     *       the target, and fresh paths are traced from it if it had not been examined.</li>
     * </ul>
     */
    private static void extendPath(final FermiLiteAssembly.Contig firstContig,
                                   final FermiLiteAssembly.Connection connection,
                                   final LinkedList<FermiLiteAssembly.Connection> path,
                                   final List<FermiLiteAssembly.Contig> contigList,
                                   final Set<FermiLiteAssembly.Contig> examined,
                                   final Set<ContigStrand> visited) {
        path.addLast(connection);
        final ContigStrand contigStrand = new ContigStrand(connection.getTarget(), connection.isTargetRC());
        final boolean isCycle = !visited.add(contigStrand);
        boolean atEndOfPath = true;
        if (!isCycle) {
            final FermiLiteAssembly.Contig target = connection.getTarget();
            final int nPredecessors = FermiLiteAssemblyHandler.countPredecessors(target);
            final int nSuccessors = target.getConnections().size() - nPredecessors;
            final boolean needsPhasing = nPredecessors > 1 && nSuccessors > 1;
            if (needsPhasing) {
                if (examined.add(target)) {
                    FermiLiteAssemblyHandler.tracePaths(target, connection.isTargetRC(), contigList, examined, visited);
                }
            } else {
                examined.add(target);
                for (final FermiLiteAssembly.Connection conn : target.getConnections()) {
                    if (conn.isRC() != connection.isTargetRC() || conn.getOverlapLen() < 0) {
                        continue;
                    }
                    FermiLiteAssemblyHandler.extendPath(firstContig, conn, path, contigList, examined, visited);
                    atEndOfPath = false;
                }
            }
        }
        if (atEndOfPath) {
            contigList.add(FermiLiteAssemblyHandler.joinContigs(firstContig, path));
        }
        if (!isCycle) {
            visited.remove(contigStrand);
        }
        path.removeLast();
    }

    /** A contig paired with a strand; equality uses contig identity plus the strand flag. */
    private static final class ContigStrand {
        private final FermiLiteAssembly.Contig contig;
        private final boolean isRC;

        public ContigStrand(final FermiLiteAssembly.Contig contig, final boolean isRC) {
            this.contig = contig;
            this.isRC = isRC;
        }

        public FermiLiteAssembly.Contig getContig() {
            return this.contig;
        }

        public boolean isRC() {
            return this.isRC;
        }

        public ContigStrand rc() {
            return new ContigStrand(this.contig, !this.isRC);
        }

        @Override
        public boolean equals(final Object obj) {
            return obj instanceof ContigStrand && this.equals((ContigStrand) obj);
        }

        public boolean equals(final ContigStrand that) {
            if (this == that) {
                return true;
            }
            return this.contig == that.contig && this.isRC == that.isRC;
        }

        @Override
        public int hashCode() {
            return this.isRC ? -this.contig.hashCode() : this.contig.hashCode();
        }
    }

    /** Map entry from a canonical k-mer to one place where it occurs; used as the element type of the k-mer multimap. */
    private static final class KmerLocation
    implements Map.Entry<SVKmerShort, ContigLocation> {
        private final SVKmerShort kmer;
        private final ContigLocation location;

        public KmerLocation(final SVKmerShort kmer, final ContigLocation location) {
            this.kmer = kmer;
            this.location = location;
        }

        public SVKmerShort getKmer() {
            return this.kmer;
        }

        public ContigLocation getLocation() {
            return this.location;
        }

        @Override
        public SVKmerShort getKey() {
            return this.kmer;
        }

        @Override
        public ContigLocation getValue() {
            return this.location;
        }

        @Override
        public ContigLocation setValue(final ContigLocation value) {
            throw new UnsupportedOperationException("KmerLocation is immutable");
        }
    }

    /** A contig, an offset, and a strand/canonical flag; equality uses contig identity plus both values. */
    private static final class ContigLocation {
        private final FermiLiteAssembly.Contig contig;
        private final int offset;
        private final boolean canonical;

        public ContigLocation(final FermiLiteAssembly.Contig contig, final int offset, final boolean canonical) {
            this.contig = contig;
            this.offset = offset;
            this.canonical = canonical;
        }

        public FermiLiteAssembly.Contig getContig() {
            return this.contig;
        }

        public int getOffset() {
            return this.offset;
        }

        public boolean isCanonical() {
            return this.canonical;
        }

        @Override
        public boolean equals(final Object obj) {
            return obj instanceof ContigLocation && this.equals((ContigLocation) obj);
        }

        public boolean equals(final ContigLocation that) {
            if (this == that) {
                return true;
            }
            return this.contig == that.contig && this.offset == that.offset && this.canonical == that.canonical;
        }

        @Override
        public int hashCode() {
            return 47 * (47 * (this.contig.hashCode() + 47 * this.offset) + (this.canonical ? 31 : 5));
        }
    }
}

