/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.hellbender.tools.walkers.haplotypecaller.graphs;

import com.google.common.collect.Sets;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.PrintStream;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import org.apache.commons.lang3.ArrayUtils;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.graphs.BaseEdge;
import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.graphs.BaseVertex;
import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.graphs.SeqGraph;
import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.graphs.SeqVertex;
import org.broadinstitute.hellbender.utils.Utils;
import org.jgrapht.DirectedGraph;
import org.jgrapht.EdgeFactory;
import org.jgrapht.alg.CycleDetector;
import org.jgrapht.graph.DefaultDirectedGraph;

public abstract class BaseGraph<V extends BaseVertex, E extends BaseEdge>
extends DefaultDirectedGraph<V, E> {
    private static final long serialVersionUID = 1L;
    protected final int kmerSize;

    protected BaseGraph(int kmerSize, EdgeFactory<V, E> edgeFactory) {
        super(edgeFactory);
        Utils.validateArg(kmerSize > 0, () -> "kmerSize must be > 0 but got " + kmerSize);
        this.kmerSize = kmerSize;
    }

    public final int getKmerSize() {
        return this.kmerSize;
    }

    public final boolean isReferenceNode(V v) {
        Utils.nonNull(v, "Attempting to test a null vertex.");
        if (this.edgesOf(v).stream().anyMatch(e -> e.isRef())) {
            return true;
        }
        return this.vertexSet().size() == 1;
    }

    public final boolean isSource(V v) {
        Utils.nonNull(v, "Attempting to test a null vertex.");
        return this.inDegreeOf(v) == 0;
    }

    public final boolean isSink(V v) {
        Utils.nonNull(v, "Attempting to test a null vertex.");
        return this.outDegreeOf(v) == 0;
    }

    public final LinkedHashSet<V> getSources() {
        return this.vertexSet().stream().filter(v -> this.isSource(v)).collect(Collectors.toCollection(LinkedHashSet::new));
    }

    public final LinkedHashSet<V> getSinks() {
        return this.vertexSet().stream().filter(v -> this.isSink(v)).collect(Collectors.toCollection(LinkedHashSet::new));
    }

    public SeqGraph toSequenceGraph() {
        SeqGraph seqGraph = new SeqGraph(this.kmerSize);
        HashMap<BaseVertex, SeqVertex> vertexMap = new HashMap<BaseVertex, SeqVertex>();
        for (BaseVertex dv : this.vertexSet()) {
            SeqVertex sv = new SeqVertex(dv.getAdditionalSequence(this.isSource(dv)));
            sv.setAdditionalInfo(dv.getAdditionalInfo());
            vertexMap.put(dv, sv);
            seqGraph.addVertex(sv);
        }
        for (BaseEdge e : this.edgeSet()) {
            SeqVertex seqInV = (SeqVertex)vertexMap.get(this.getEdgeSource(e));
            SeqVertex seqOutV = (SeqVertex)vertexMap.get(this.getEdgeTarget(e));
            seqGraph.addEdge(seqInV, seqOutV, new BaseEdge(e.isRef(), e.getMultiplicity()));
        }
        return seqGraph;
    }

    public final byte[] getAdditionalSequence(V v) {
        Utils.nonNull(v, "Attempting to pull sequence from a null vertex.");
        return ((BaseVertex)v).getAdditionalSequence(this.isSource(v));
    }

    public static final byte[] getAdditionalSequence(BaseVertex v, boolean isSource) {
        Utils.nonNull(v, "Attempting to pull sequence from a null vertex.");
        return v.getAdditionalSequence(isSource);
    }

    public final boolean isRefSource(V v) {
        Utils.nonNull(v, "Attempting to pull sequence from a null vertex.");
        if (this.incomingEdgesOf(v).stream().anyMatch(e -> e.isRef())) {
            return false;
        }
        if (this.outgoingEdgesOf(v).stream().anyMatch(e -> e.isRef())) {
            return true;
        }
        return this.vertexSet().size() == 1;
    }

    public final boolean isRefSink(V v) {
        Utils.nonNull(v, "Attempting to pull sequence from a null vertex.");
        if (this.outgoingEdgesOf(v).stream().anyMatch(e -> e.isRef())) {
            return false;
        }
        if (this.incomingEdgesOf(v).stream().anyMatch(e -> e.isRef())) {
            return true;
        }
        return this.vertexSet().size() == 1;
    }

    public V getReferenceSourceVertex() {
        return (V)((BaseVertex)this.vertexSet().stream().filter(v -> this.isRefSource(v)).findFirst().orElse(null));
    }

    public V getReferenceSinkVertex() {
        return (V)((BaseVertex)this.vertexSet().stream().filter(v -> this.isRefSink(v)).findFirst().orElse(null));
    }

    public final V getNextReferenceVertex(V v) {
        return this.getNextReferenceVertex(v, false, Optional.empty());
    }

    public final V getNextReferenceVertex(V v, boolean allowNonRefPaths, Optional<E> blacklistedEdge) {
        if (v == null) {
            return null;
        }
        Set outgoingEdges = this.outgoingEdgesOf(v);
        if (outgoingEdges.isEmpty()) {
            return null;
        }
        for (BaseEdge edgeToTest : outgoingEdges) {
            if (!edgeToTest.isRef()) continue;
            return (V)((BaseVertex)this.getEdgeTarget(edgeToTest));
        }
        if (!allowNonRefPaths) {
            return null;
        }
        Set blacklistedEdgeSet = blacklistedEdge.isPresent() ? Collections.singleton(blacklistedEdge.get()) : Collections.emptySet();
        List edges = outgoingEdges.stream().filter(e -> !blacklistedEdgeSet.contains(e)).limit(2L).collect(Collectors.toList());
        return (V)(edges.size() == 1 ? (BaseVertex)this.getEdgeTarget(edges.get(0)) : null);
    }

    public final V getPrevReferenceVertex(V v) {
        if (v == null) {
            return null;
        }
        return (V)((BaseVertex)this.incomingEdgesOf(v).stream().map(e -> (BaseVertex)this.getEdgeSource(e)).filter(vrtx -> this.isReferenceNode(vrtx)).findFirst().orElse(null));
    }

    public byte[] getReferenceBytes(V fromVertex, V toVertex, boolean includeStart, boolean includeStop) {
        Utils.nonNull(fromVertex, "Starting vertex in requested path cannot be null.");
        Utils.nonNull(toVertex, "From vertex in requested path cannot be null.");
        byte[] bytes = null;
        V v = fromVertex;
        if (includeStart) {
            bytes = ArrayUtils.addAll(bytes, (byte[])this.getAdditionalSequence(v));
        }
        v = this.getNextReferenceVertex(v);
        while (v != null && !((BaseVertex)v).equals(toVertex)) {
            bytes = ArrayUtils.addAll((byte[])bytes, (byte[])this.getAdditionalSequence(v));
            v = this.getNextReferenceVertex(v);
        }
        if (includeStop && v != null && ((BaseVertex)v).equals(toVertex)) {
            bytes = ArrayUtils.addAll((byte[])bytes, (byte[])this.getAdditionalSequence(v));
        }
        return bytes;
    }

    @SafeVarargs
    public final void addVertices(V ... vertices) {
        Utils.nonNull(vertices);
        this.addVertices((Collection<V>)Arrays.asList(vertices));
    }

    public final void addVertices(Collection<V> vertices) {
        Utils.nonNull(vertices);
        vertices.forEach(v -> this.addVertex(v));
    }

    @SafeVarargs
    public final void addEdges(V start, V ... remaining) {
        Utils.nonNull(start, "start vertex");
        if (remaining == null || remaining.length == 0) {
            return;
        }
        V prev = start;
        for (V next : remaining) {
            Utils.nonNull(next, "null vertex");
            this.addEdge(prev, next);
            prev = next;
        }
    }

    @SafeVarargs
    public final void addEdges(Supplier<E> template, V start, V ... remaining) {
        Utils.nonNull(template, "template edge");
        Utils.nonNull(start, "start vertex");
        V prev = start;
        for (V next : remaining) {
            Utils.nonNull(next, "null vertex");
            this.addEdge(prev, next, template.get());
            prev = next;
        }
    }

    public final Set<V> outgoingVerticesOf(V v) {
        Utils.nonNull(v);
        return this.outgoingEdgesOf(v).stream().map(e -> (BaseVertex)this.getEdgeTarget(e)).collect(Collectors.toCollection(LinkedHashSet::new));
    }

    public final Set<V> incomingVerticesOf(V v) {
        Utils.nonNull(v);
        return this.incomingEdgesOf(v).stream().map(e -> (BaseVertex)this.getEdgeSource(e)).collect(Collectors.toCollection(LinkedHashSet::new));
    }

    public final Set<V> neighboringVerticesOf(V v) {
        Utils.nonNull(v);
        return Sets.union(this.incomingVerticesOf(v), this.outgoingVerticesOf(v));
    }

    public final void printGraph(File destination, int pruneFactor) {
        try (PrintStream stream = new PrintStream(new FileOutputStream(destination));){
            this.printGraph(stream, true, pruneFactor);
        }
        catch (FileNotFoundException e) {
            throw new UserException.CouldNotReadInputFile(destination.getAbsolutePath(), (Exception)e);
        }
    }

    public final void printGraph(PrintStream graphWriter, boolean writeHeader, int pruneFactor) {
        if (writeHeader) {
            graphWriter.println("digraph assemblyGraphs {");
        }
        for (BaseEdge edge : this.edgeSet()) {
            String edgeString = String.format("\t%s -> %s ", ((BaseVertex)this.getEdgeSource(edge)).toString(), ((BaseVertex)this.getEdgeTarget(edge)).toString());
            String edgeLabelString = edge.getMultiplicity() > 0 && edge.getMultiplicity() < pruneFactor ? String.format("[style=dotted,color=grey,label=\"%s\"];", edge.getDotLabel()) : String.format("[label=\"%s\"];", edge.getDotLabel());
            graphWriter.print(edgeString);
            graphWriter.print(edgeLabelString);
            if (!edge.isRef()) continue;
            graphWriter.println(edgeString + " [color=red];");
        }
        for (BaseVertex v : this.vertexSet()) {
            graphWriter.println(String.format("\t%s [label=\"%s\",shape=box]", v.toString(), new String(this.getAdditionalSequence(v)) + v.getAdditionalInfo()));
        }
        this.getExtraGraphFileLines().forEach(graphWriter::println);
        if (writeHeader) {
            graphWriter.println("}");
        }
    }

    public List<String> getExtraGraphFileLines() {
        return Collections.emptyList();
    }

    public final void cleanNonRefPaths() {
        BaseEdge e;
        if (this.getReferenceSourceVertex() == null || this.getReferenceSinkVertex() == null) {
            return;
        }
        HashSet edgesToCheck = new HashSet();
        edgesToCheck.addAll(this.incomingEdgesOf(this.getReferenceSourceVertex()));
        while (!edgesToCheck.isEmpty()) {
            e = (BaseEdge)edgesToCheck.iterator().next();
            if (!e.isRef()) {
                edgesToCheck.addAll(this.incomingEdgesOf(this.getEdgeSource(e)));
                this.removeEdge(e);
            }
            edgesToCheck.remove(e);
        }
        edgesToCheck.addAll(this.outgoingEdgesOf(this.getReferenceSinkVertex()));
        while (!edgesToCheck.isEmpty()) {
            e = (BaseEdge)edgesToCheck.iterator().next();
            if (!e.isRef()) {
                edgesToCheck.addAll(this.outgoingEdgesOf(this.getEdgeTarget(e)));
                this.removeEdge(e);
            }
            edgesToCheck.remove(e);
        }
        this.removeSingletonOrphanVertices();
    }

    public void removeSingletonOrphanVertices() {
        List toRemove = this.vertexSet().stream().filter(v -> this.isSingletonOrphan(v)).collect(Collectors.toList());
        this.removeAllVertices(toRemove);
    }

    private boolean isSingletonOrphan(V v) {
        Utils.nonNull(v);
        return this.inDegreeOf(v) == 0 && this.outDegreeOf(v) == 0 && !this.isRefSource(v);
    }

    public final void removeVerticesNotConnectedToRefRegardlessOfEdgeDirection() {
        HashSet toRemove = new HashSet(this.vertexSet());
        V refV = this.getReferenceSourceVertex();
        if (refV != null) {
            for (BaseVertex v : new BaseGraphIterator(this, (BaseVertex)refV, true, true, null)) {
                toRemove.remove(v);
            }
        }
        this.removeAllVertices(toRemove);
    }

    public final void removePathsNotConnectedToRef() {
        if (this.getReferenceSourceVertex() == null || this.getReferenceSinkVertex() == null) {
            throw new IllegalStateException("Graph must have ref source and sink vertices");
        }
        HashSet<BaseVertex> onPathFromRefSource = new HashSet<BaseVertex>(this.vertexSet().size());
        for (Object v : new BaseGraphIterator(this, (BaseVertex)this.getReferenceSourceVertex(), false, true, null)) {
            onPathFromRefSource.add((BaseVertex)v);
        }
        HashSet<BaseVertex> onPathFromRefSink = new HashSet<BaseVertex>(this.vertexSet().size());
        for (BaseVertex v : new BaseGraphIterator(this, (BaseVertex)this.getReferenceSinkVertex(), true, false, null)) {
            onPathFromRefSink.add(v);
        }
        HashSet verticesToRemove = new HashSet(this.vertexSet());
        onPathFromRefSource.retainAll(onPathFromRefSink);
        verticesToRemove.removeAll(onPathFromRefSource);
        this.removeAllVertices(verticesToRemove);
        if (this.getSinks().size() > 1) {
            throw new IllegalStateException("Should have eliminated all but the reference sink, but found " + this.getSinks());
        }
        if (this.getSources().size() > 1) {
            throw new IllegalStateException("Should have eliminated all but the reference source, but found " + this.getSources());
        }
    }

    public static <T extends BaseVertex, E extends BaseEdge> boolean graphEquals(BaseGraph<T, E> g1, BaseGraph<T, E> g2) {
        Utils.nonNull(g1, "g1");
        Utils.nonNull(g2, "g2");
        Set vertices1 = g1.vertexSet();
        Set vertices2 = g2.vertexSet();
        Set edges1 = g1.edgeSet();
        Set edges2 = g2.edgeSet();
        if (vertices1.size() != vertices2.size() || edges1.size() != edges2.size()) {
            return false;
        }
        boolean ok = vertices1.stream().map(v1 -> v1.getSequenceString()).allMatch(v1seqString -> vertices2.stream().anyMatch(v2 -> v1seqString.equals(v2.getSequenceString())));
        if (!ok) {
            return false;
        }
        boolean okG1 = edges1.stream().allMatch(e1 -> edges2.stream().anyMatch(e2 -> g1.seqEquals(e1, e2, g2)));
        if (!okG1) {
            return false;
        }
        return edges2.stream().allMatch(e2 -> edges1.stream().anyMatch(e1 -> g2.seqEquals(e2, e1, g1)));
    }

    private boolean seqEquals(E edge1, E edge2, BaseGraph<V, E> graph2) {
        return ((BaseVertex)this.getEdgeSource(edge1)).seqEquals((BaseVertex)graph2.getEdgeSource(edge2)) && ((BaseVertex)this.getEdgeTarget(edge1)).seqEquals((BaseVertex)graph2.getEdgeTarget(edge2));
    }

    public final E incomingEdgeOf(V v) {
        Utils.nonNull(v);
        return this.getSingletonEdge(this.incomingEdgesOf(v));
    }

    public final E outgoingEdgeOf(V v) {
        Utils.nonNull(v);
        return this.getSingletonEdge(this.outgoingEdgesOf(v));
    }

    private E getSingletonEdge(Collection<E> edges) {
        Utils.validateArg(edges.size() <= 1, () -> "Cannot get a single incoming edge for a vertex with multiple incoming edges " + edges);
        return (E)(edges.isEmpty() ? null : (BaseEdge)edges.iterator().next());
    }

    public final void addOrUpdateEdge(V source, V target, E e) {
        Utils.nonNull(source, "source");
        Utils.nonNull(target, "target");
        Utils.nonNull(e, "edge");
        BaseEdge prev = (BaseEdge)this.getEdge(source, target);
        if (prev != null) {
            prev.add((BaseEdge)e);
        } else {
            this.addEdge(source, target, e);
        }
    }

    public String toString() {
        return "BaseGraph{kmerSize=" + this.kmerSize + '}';
    }

    private Set<V> verticesWithinDistance(V source, int distance) {
        if (distance == 0) {
            return Collections.singleton(source);
        }
        HashSet<Object> found = new HashSet<Object>();
        found.add(source);
        for (BaseVertex v : this.neighboringVerticesOf(source)) {
            found.addAll(this.verticesWithinDistance(v, distance - 1));
        }
        return found;
    }

    public final BaseGraph<V, E> subsetToNeighbors(V target, int distance) {
        Utils.nonNull(target, "Target cannot be null");
        Utils.validateArg(this.containsVertex(target), () -> "Graph doesn't contain vertex " + target);
        Utils.validateArg(distance >= 0, () -> "Distance must be >= 0 but got " + distance);
        Set<V> toKeep = this.verticesWithinDistance(target, distance);
        HashSet toRemove = new HashSet(this.vertexSet());
        toRemove.removeAll(toKeep);
        BaseGraph<V, E> result = this.clone();
        result.removeAllVertices(toRemove);
        return result;
    }

    public final BaseGraph<V, E> subsetToRefSource(int refSourceNeighborhood) {
        Utils.validateArg(refSourceNeighborhood > 0, () -> "refSourceNeighborhood needs to be positive but was " + refSourceNeighborhood);
        return this.subsetToNeighbors(this.getReferenceSourceVertex(), refSourceNeighborhood);
    }

    public final boolean containsAllVertices(Collection<? extends V> vertices) {
        Utils.nonNull(vertices, "the input vertices collection cannot be null");
        Utils.containsNoNull(vertices, "null vertex");
        return vertices.stream().allMatch(v -> this.containsVertex(v));
    }

    public final boolean hasCycles() {
        return new CycleDetector((DirectedGraph)this).detectCycles();
    }

    public BaseGraph<V, E> clone() {
        return (BaseGraph)((Object)super.clone());
    }

    private static final class BaseGraphIterator<T extends BaseVertex, E extends BaseEdge>
    implements Iterator<T>,
    Iterable<T> {
        final Collection<T> visited = new HashSet<T>();
        final Deque<T> toVisit = new LinkedList<T>();
        final BaseGraph<T, E> graph;
        final boolean followIncomingEdges;
        final boolean followOutgoingEdges;

        private BaseGraphIterator(BaseGraph<T, E> graph, T start, boolean followIncomingEdges, boolean followOutgoingEdges) {
            Utils.nonNull(graph, "graph cannot be null");
            Utils.nonNull(start, "start cannot be null");
            Utils.validateArg(graph.containsVertex(start), () -> "start " + start + " must be in graph but it isn't");
            this.graph = graph;
            this.followIncomingEdges = followIncomingEdges;
            this.followOutgoingEdges = followOutgoingEdges;
            this.toVisit.add(start);
        }

        @Override
        public Iterator<T> iterator() {
            return this;
        }

        @Override
        public boolean hasNext() {
            return !this.toVisit.isEmpty();
        }

        @Override
        public T next() {
            BaseVertex v = (BaseVertex)this.toVisit.pop();
            if (!this.visited.contains(v)) {
                this.visited.add(v);
                if (this.followIncomingEdges) {
                    this.toVisit.addAll(this.graph.incomingVerticesOf(v));
                }
                if (this.followOutgoingEdges) {
                    this.toVisit.addAll(this.graph.outgoingVerticesOf(v));
                }
            }
            return (T)v;
        }

        @Override
        public void remove() {
            throw new UnsupportedOperationException("Doesn't implement remove");
        }

        /* synthetic */ BaseGraphIterator(BaseGraph x0, BaseVertex x1, boolean x2, boolean x3, 1 x4) {
            this(x0, x1, x2, x3);
        }
    }
}

