/*
 * Decompiled with CFR 0.152.
 */
package de.bioforscher.singa.structure.algorithms.superimposition.consensus;

import de.bioforscher.singa.core.utility.Pair;
import de.bioforscher.singa.mathematics.graphs.trees.BinaryTree;
import de.bioforscher.singa.mathematics.graphs.trees.BinaryTreeNode;
import de.bioforscher.singa.mathematics.matrices.LabeledSymmetricMatrix;
import de.bioforscher.singa.structure.algorithms.superimposition.SubstructureSuperimposer;
import de.bioforscher.singa.structure.algorithms.superimposition.SubstructureSuperimposition;
import de.bioforscher.singa.structure.algorithms.superimposition.consensus.ConsensusBuilder;
import de.bioforscher.singa.structure.algorithms.superimposition.consensus.ConsensusContainer;
import de.bioforscher.singa.structure.algorithms.superimposition.consensus.ConsensusException;
import de.bioforscher.singa.structure.algorithms.superimposition.fit3d.representations.RepresentationScheme;
import de.bioforscher.singa.structure.algorithms.superimposition.fit3d.representations.RepresentationSchemeFactory;
import de.bioforscher.singa.structure.algorithms.superimposition.fit3d.representations.RepresentationSchemeType;
import de.bioforscher.singa.structure.model.families.AminoAcidFamily;
import de.bioforscher.singa.structure.model.identifiers.LeafIdentifier;
import de.bioforscher.singa.structure.model.interfaces.Atom;
import de.bioforscher.singa.structure.model.interfaces.LeafSubstructure;
import de.bioforscher.singa.structure.model.interfaces.LeafSubstructureContainer;
import de.bioforscher.singa.structure.model.oak.LeafSubstructureFactory;
import de.bioforscher.singa.structure.model.oak.OakAtom;
import de.bioforscher.singa.structure.model.oak.OakLeafSubstructure;
import de.bioforscher.singa.structure.model.oak.StructuralMotif;
import de.bioforscher.singa.structure.parser.pdb.structures.StructureWriter;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.attribute.FileAttribute;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ConsensusAlignment {
    private static final Logger logger = LoggerFactory.getLogger(ConsensusAlignment.class);
    private final List<ConsensusContainer> input;
    private final boolean idealSuperimposition;
    private final List<BinaryTree<ConsensusContainer>> consensusTrees;
    private final List<Double> alignmentTrace;
    private final List<Integer> alignmentCounts;
    private final Predicate<Atom> atomFilter;
    private final boolean alignWithinClusters;
    private final double clusterCutoff;
    private RepresentationScheme representationScheme;
    private double consensusScore;
    private int iterationCounter;
    private TreeMap<SubstructureSuperimposition, Pair<ConsensusContainer>> alignments;
    private LabeledSymmetricMatrix<ConsensusContainer> distanceMatrix;
    private List<BinaryTreeNode<ConsensusContainer>> leaves;
    private ConsensusContainer currentConsensus;
    private List<BinaryTree<ConsensusContainer>> clusters;

    ConsensusAlignment(ConsensusBuilder.Builder builder) {
        this.input = builder.structuralMotifs.stream().map(ConsensusAlignment::toContainer).collect(Collectors.toList());
        logger.info("consensus alignment initialized with {} structures", (Object)this.input.size());
        this.clusterCutoff = builder.clusterCutoff;
        this.alignWithinClusters = builder.alignWithinClusters;
        this.atomFilter = builder.atomFilter;
        RepresentationSchemeType representationSchemeType = builder.representationSchemeType;
        if (representationSchemeType != null) {
            logger.info("using representation scheme {}", (Object)representationSchemeType);
            this.representationScheme = RepresentationSchemeFactory.createRepresentationScheme(representationSchemeType);
        }
        this.idealSuperimposition = builder.idealSuperimposition;
        if (this.input.stream().map(ConsensusContainer::getStructuralMotif).map(StructuralMotif::getAllLeafSubstructures).map(List::size).collect(Collectors.toSet()).size() != 1) {
            throw new ConsensusException("all substructures must contain the same number of leaf structures to calculate a consensus alignment");
        }
        this.iterationCounter = 0;
        this.alignmentTrace = new ArrayList<Double>();
        this.alignmentCounts = new ArrayList<Integer>();
        this.consensusTrees = new ArrayList<BinaryTree<ConsensusContainer>>();
        this.calculateInitialAlignments();
        logger.info("{} initial alignment pairs were computed, in total we have to compute {} alignments", (Object)this.alignments.size(), (Object)(this.alignments.size() * (this.input.size() - 1)));
        this.createTreeLeaves();
        this.calculateConsensusAlignment();
        this.splitTopLevelTree();
        if (this.alignWithinClusters) {
            this.alignWithinClusters();
        }
    }

    private static ConsensusContainer toContainer(StructuralMotif structuralMotif) {
        return new ConsensusContainer(structuralMotif.getCopy(), false);
    }

    public List<Double> getAlignmentTrace() {
        return this.alignmentTrace;
    }

    public List<BinaryTree<ConsensusContainer>> getClusters() {
        return this.clusters;
    }

    public void writeClusters(Path outputPath) throws IOException {
        logger.info("writing {} clusters to {}", (Object)this.clusters.size(), (Object)outputPath);
        Files.createDirectories(outputPath, new FileAttribute[0]);
        for (int i = 0; i < this.clusters.size(); ++i) {
            String clusterBaseLocation = "cluster_" + (i + 1) + "/";
            BinaryTree<ConsensusContainer> currentCluster = this.clusters.get(i);
            if (currentCluster.getLeafNodes().size() > 1) {
                StructureWriter.writeLeafSubstructures(((ConsensusContainer)currentCluster.getRoot().getData()).getStructuralMotif().getAllLeafSubstructures(), outputPath.resolve(clusterBaseLocation + "consensus_" + (i + 1) + ".pdb"));
            }
            for (BinaryTreeNode leafNode : currentCluster.getLeafNodes()) {
                if (((ConsensusContainer)leafNode.getData()).getSuperimposition() != null) {
                    StructureWriter.writeLeafSubstructures(((ConsensusContainer)leafNode.getData()).getSuperimposition().getMappedFullCandidate(), outputPath.resolve(clusterBaseLocation + ((ConsensusContainer)leafNode.getData()).toString() + ".pdb"));
                    continue;
                }
                StructureWriter.writeLeafSubstructures(((ConsensusContainer)leafNode.getData()).getStructuralMotif().getAllLeafSubstructures(), outputPath.resolve(clusterBaseLocation + ((ConsensusContainer)leafNode.getData()).toString() + ".pdb"));
            }
        }
    }

    private void alignWithinClusters() {
        this.clusters.stream().filter(cluster -> cluster.size() > 1).forEach(cluster -> {
            ConsensusContainer reference = (ConsensusContainer)cluster.getRoot().getData();
            cluster.getLeafNodes().stream().map(BinaryTreeNode::getData).forEach(consensusContainer -> {
                SubstructureSuperimposition superimposition = this.representationScheme == null ? (this.idealSuperimposition ? SubstructureSuperimposer.calculateIdealSubstructureSuperimposition((LeafSubstructureContainer)reference.getStructuralMotif(), (LeafSubstructureContainer)consensusContainer.getStructuralMotif(), this.atomFilter) : SubstructureSuperimposer.calculateSubstructureSuperimposition(reference.getStructuralMotif().getAllLeafSubstructures(), consensusContainer.getStructuralMotif().getAllLeafSubstructures(), this.atomFilter)) : (this.idealSuperimposition ? SubstructureSuperimposer.calculateIdealSubstructureSuperimposition((LeafSubstructureContainer)reference.getStructuralMotif(), (LeafSubstructureContainer)consensusContainer.getStructuralMotif(), this.representationScheme) : SubstructureSuperimposer.calculateSubstructureSuperimposition(reference.getStructuralMotif().getAllLeafSubstructures(), consensusContainer.getStructuralMotif().getAllLeafSubstructures(), this.representationScheme));
                consensusContainer.setSuperimposition(superimposition);
            });
        });
    }

    private void splitTopLevelTree() {
        this.clusters = new ArrayList<BinaryTree<ConsensusContainer>>();
        this.clusters.add(this.getTopConsensusTree());
        ListIterator<BinaryTree<ConsensusContainer>> clustersIterator = this.clusters.listIterator();
        while (clustersIterator.hasNext()) {
            BinaryTreeNode currentNode = clustersIterator.next().getRoot();
            BinaryTreeNode leftNode = currentNode.getLeft();
            BinaryTreeNode rightNode = currentNode.getRight();
            double leftDistance = leftNode != null ? ((ConsensusContainer)leftNode.getData()).getConsensusDistance() : 0.0;
            double rightDistance = rightNode != null ? ((ConsensusContainer)rightNode.getData()).getConsensusDistance() : 0.0;
            if (!(leftDistance > this.clusterCutoff) && !(rightDistance > this.clusterCutoff)) continue;
            clustersIterator.remove();
            clustersIterator.add((BinaryTree<ConsensusContainer>)new BinaryTree(currentNode.getLeft()));
            clustersIterator.previous();
            clustersIterator.add((BinaryTree<ConsensusContainer>)new BinaryTree(currentNode.getRight()));
            clustersIterator.previous();
        }
    }

    public BinaryTree<ConsensusContainer> getTopConsensusTree() {
        return this.consensusTrees.get(this.consensusTrees.size() - 1);
    }

    public List<BinaryTree<ConsensusContainer>> getConsensusTrees() {
        return this.consensusTrees;
    }

    private void calculateConsensusAlignment() {
        while (!this.alignments.isEmpty()) {
            this.findAndMergeClosestPair();
        }
    }

    public double getConsensusScore() {
        return this.consensusScore;
    }

    public double getNormalizedConsensusScore() {
        return this.consensusScore / (double)(this.iterationCounter * this.input.get(0).getStructuralMotif().size());
    }

    private void findAndMergeClosestPair() {
        ++this.iterationCounter;
        Pair<ConsensusContainer> closestPair = this.alignments.firstEntry().getValue();
        SubstructureSuperimposition closestPairSuperimposition = this.alignments.firstKey();
        double closestPairRmsd = closestPairSuperimposition.getRmsd();
        this.alignmentTrace.add(closestPairRmsd);
        this.alignmentCounts.add(this.input.size());
        logger.debug("closest pair for iteration {} is {} with RMSD {}", new Object[]{this.iterationCounter, closestPair, closestPairRmsd});
        this.consensusScore += closestPairRmsd;
        this.createConsensus(this.alignments.firstEntry());
        this.updateAlignments(this.alignments.firstEntry());
    }

    private void updateAlignments(Map.Entry<SubstructureSuperimposition, Pair<ConsensusContainer>> substructurePair) {
        Iterator<Map.Entry<SubstructureSuperimposition, Pair<ConsensusContainer>>> alignmentsIterator = this.alignments.entrySet().iterator();
        while (alignmentsIterator.hasNext()) {
            boolean candidateObservationMatches;
            Map.Entry<SubstructureSuperimposition, Pair<ConsensusContainer>> currentAlignment = alignmentsIterator.next();
            boolean referenceObservationMatches = ((ConsensusContainer)currentAlignment.getValue().getFirst()).equals(substructurePair.getValue().getFirst()) || ((ConsensusContainer)currentAlignment.getValue().getFirst()).equals(substructurePair.getValue().getSecond());
            boolean bl = candidateObservationMatches = ((ConsensusContainer)currentAlignment.getValue().getSecond()).equals(substructurePair.getValue().getSecond()) || ((ConsensusContainer)currentAlignment.getValue().getSecond()).equals(substructurePair.getValue().getFirst());
            if (!referenceObservationMatches && !candidateObservationMatches) continue;
            alignmentsIterator.remove();
        }
        this.input.removeIf(inputStructure -> inputStructure.equals(((Pair)substructurePair.getValue()).getFirst()));
        this.input.removeIf(inputStructure -> inputStructure.equals(((Pair)substructurePair.getValue()).getSecond()));
        for (ConsensusContainer inputStructure2 : this.input) {
            SubstructureSuperimposition superimposition = this.representationScheme == null ? (this.idealSuperimposition ? SubstructureSuperimposer.calculateIdealSubstructureSuperimposition((LeafSubstructureContainer)this.currentConsensus.getStructuralMotif(), (LeafSubstructureContainer)inputStructure2.getStructuralMotif(), this.atomFilter) : SubstructureSuperimposer.calculateSubstructureSuperimposition(this.currentConsensus.getStructuralMotif().getAllLeafSubstructures(), inputStructure2.getStructuralMotif().getAllLeafSubstructures(), this.atomFilter)) : (this.idealSuperimposition ? SubstructureSuperimposer.calculateIdealSubstructureSuperimposition((LeafSubstructureContainer)this.currentConsensus.getStructuralMotif(), (LeafSubstructureContainer)inputStructure2.getStructuralMotif(), this.representationScheme) : SubstructureSuperimposer.calculateSubstructureSuperimposition(this.currentConsensus.getStructuralMotif().getAllLeafSubstructures(), inputStructure2.getStructuralMotif().getAllLeafSubstructures(), this.representationScheme));
            Pair alignmentPair = new Pair((Object)this.currentConsensus, (Object)inputStructure2);
            this.alignments.put(superimposition, (Pair<ConsensusContainer>)alignmentPair);
        }
        this.input.add(this.currentConsensus);
    }

    private void createConsensus(Map.Entry<SubstructureSuperimposition, Pair<ConsensusContainer>> substructurePair) {
        BinaryTreeNode consensusNode;
        BinaryTreeNode<ConsensusContainer> rightNode;
        List candidateAtoms;
        List referenceAtoms;
        List<LeafSubstructure<?>> reference = ((ConsensusContainer)substructurePair.getValue().getFirst()).getStructuralMotif().getAllLeafSubstructures();
        List<LeafSubstructure<?>> candidate = substructurePair.getKey().getMappedFullCandidate();
        LinkedHashMap perAtomAlignment = new LinkedHashMap();
        IntStream.range(0, reference.size()).forEach(i -> {
            Set cfr_ignored_0 = perAtomAlignment.put(new Pair(reference.get(i), candidate.get(i)), new HashSet());
        });
        perAtomAlignment.entrySet().forEach(this::defineIntersectingAtoms);
        if (this.representationScheme == null) {
            referenceAtoms = perAtomAlignment.entrySet().stream().map(pairSetEntry -> ((LeafSubstructure)((Pair)pairSetEntry.getKey()).getFirst()).getAllAtoms().stream().filter(this.atomFilter).filter(atom -> ((Set)pairSetEntry.getValue()).contains(atom.getAtomName())).sorted(Comparator.comparing(Atom::getAtomName)).collect(Collectors.toList())).collect(Collectors.toList());
            candidateAtoms = perAtomAlignment.entrySet().stream().map(pairSetEntry -> ((LeafSubstructure)((Pair)pairSetEntry.getKey()).getSecond()).getAllAtoms().stream().filter(this.atomFilter).filter(atom -> ((Set)pairSetEntry.getValue()).contains(atom.getAtomName())).sorted(Comparator.comparing(Atom::getAtomName)).collect(Collectors.toList())).collect(Collectors.toList());
        } else {
            referenceAtoms = perAtomAlignment.entrySet().stream().map(pairSetEntry -> {
                ArrayList<Atom> atomList = new ArrayList<Atom>();
                atomList.add(this.representationScheme.determineRepresentingAtom((LeafSubstructure)((Pair)pairSetEntry.getKey()).getFirst()));
                return atomList;
            }).collect(Collectors.toList());
            candidateAtoms = perAtomAlignment.entrySet().stream().map(pairSetEntry -> {
                ArrayList<Atom> atomList = new ArrayList<Atom>();
                atomList.add(this.representationScheme.determineRepresentingAtom((LeafSubstructure)((Pair)pairSetEntry.getKey()).getSecond()));
                return atomList;
            }).collect(Collectors.toList());
        }
        ArrayList consensusLeaveSubstructures = new ArrayList();
        int atomCounter = 1;
        int leafCounter = 1;
        for (int i2 = 0; i2 < referenceAtoms.size(); ++i2) {
            List currentReferenceAtoms = (List)referenceAtoms.get(i2);
            List currentCandidateAtoms = (List)candidateAtoms.get(i2);
            ArrayList<OakAtom> averagedAtoms = new ArrayList<OakAtom>();
            for (int j = 0; j < currentReferenceAtoms.size(); ++j) {
                Atom referenceAtom = (Atom)currentReferenceAtoms.get(j);
                Atom candidateAtom = (Atom)currentCandidateAtoms.get(j);
                averagedAtoms.add(new OakAtom(atomCounter, referenceAtom.getElement(), referenceAtom.getAtomName(), referenceAtom.getPosition().add(candidateAtom.getPosition()).divide(2.0)));
                ++atomCounter;
            }
            AminoAcidFamily family = null;
            if (reference.get(i2).getFamily().equals(candidate.get(i2).getFamily())) {
                family = candidate.get(i2).getFamily();
            }
            if (family == null) {
                family = AminoAcidFamily.UNKNOWN;
            }
            OakLeafSubstructure<?> leafSubstructure = LeafSubstructureFactory.createLeafSubstructure(new LeafIdentifier(leafCounter), family);
            averagedAtoms.forEach(leafSubstructure::addAtom);
            consensusLeaveSubstructures.add(leafSubstructure);
            ++leafCounter;
        }
        this.currentConsensus = new ConsensusContainer(StructuralMotif.fromLeafSubstructures(consensusLeaveSubstructures), true);
        if (this.iterationCounter == 1) {
            BinaryTreeNode<ConsensusContainer> leftNode = this.findLeave((ConsensusContainer)substructurePair.getValue().getFirst());
            rightNode = this.findLeave((ConsensusContainer)substructurePair.getValue().getSecond());
            consensusNode = new BinaryTreeNode((Object)this.currentConsensus, leftNode, rightNode);
        } else {
            BinaryTreeNode<ConsensusContainer> leftNode = this.findNode((ConsensusContainer)substructurePair.getValue().getFirst());
            if (leftNode == null) {
                leftNode = this.findLeave((ConsensusContainer)substructurePair.getValue().getFirst());
            }
            if ((rightNode = this.findNode((ConsensusContainer)substructurePair.getValue().getSecond())) == null) {
                rightNode = this.findLeave((ConsensusContainer)substructurePair.getValue().getSecond());
            }
            consensusNode = new BinaryTreeNode((Object)this.currentConsensus, leftNode, rightNode);
        }
        BinaryTree consensusTree = new BinaryTree(consensusNode);
        this.currentConsensus.setConsensusTree((BinaryTree<ConsensusContainer>)consensusTree);
        this.consensusTrees.add((BinaryTree<ConsensusContainer>)consensusTree);
        ((ConsensusContainer)consensusTree.getRoot().getLeft().getData()).addToConsensusDistance(substructurePair.getKey().getRmsd() / 2.0);
        ((ConsensusContainer)consensusTree.getRoot().getRight().getData()).addToConsensusDistance(substructurePair.getKey().getRmsd() / 2.0);
    }

    private BinaryTreeNode<ConsensusContainer> findLeave(ConsensusContainer consensusContainer) {
        return this.leaves.stream().filter(leave -> ((ConsensusContainer)leave.getData()).equals(consensusContainer)).findFirst().orElseThrow(() -> new ConsensusException("failed during tree construction"));
    }

    private BinaryTreeNode<ConsensusContainer> findNode(ConsensusContainer consensusContainer) {
        BinaryTree<ConsensusContainer> tree;
        BinaryTreeNode nodeForObservation = null;
        Iterator<BinaryTree<ConsensusContainer>> iterator = this.consensusTrees.iterator();
        while (iterator.hasNext() && (nodeForObservation = (tree = iterator.next()).findNode((Object)consensusContainer)) == null) {
        }
        return nodeForObservation;
    }

    private void defineIntersectingAtoms(Map.Entry<Pair<LeafSubstructure>, Set<String>> pairListEntry) {
        if (this.representationScheme == null) {
            pairListEntry.getValue().addAll(((LeafSubstructure)pairListEntry.getKey().getFirst()).getAllAtoms().stream().filter(this.atomFilter).map(Atom::getAtomName).collect(Collectors.toSet()));
            pairListEntry.getValue().retainAll(((LeafSubstructure)pairListEntry.getKey().getSecond()).getAllAtoms().stream().filter(this.atomFilter).map(Atom::getAtomName).collect(Collectors.toSet()));
        } else {
            pairListEntry.getValue().add(this.representationScheme.determineRepresentingAtom((LeafSubstructure)pairListEntry.getKey().getFirst()).getAtomName());
            pairListEntry.getValue().add(this.representationScheme.determineRepresentingAtom((LeafSubstructure)pairListEntry.getKey().getSecond()).getAtomName());
        }
    }

    private void createTreeLeaves() {
        this.leaves = this.input.stream().map(BinaryTreeNode::new).collect(Collectors.toList());
    }

    private void calculateInitialAlignments() {
        this.alignments = new TreeMap(Comparator.comparing(SubstructureSuperimposition::getRmsd));
        double[][] temporaryDistanceMatrix = new double[this.input.size()][this.input.size()];
        ArrayList<ConsensusContainer> distanceMatrixLabels = new ArrayList<ConsensusContainer>();
        distanceMatrixLabels.add(this.input.get(0));
        int alignmentCounter = 0;
        for (int i = 0; i < this.input.size() - 1; ++i) {
            for (int j = i + 1; j < this.input.size(); ++j) {
                StructuralMotif reference = this.input.get(i).getStructuralMotif();
                StructuralMotif candidate = this.input.get(j).getStructuralMotif();
                SubstructureSuperimposition superimposition = this.representationScheme == null ? (this.idealSuperimposition ? SubstructureSuperimposer.calculateIdealSubstructureSuperimposition((LeafSubstructureContainer)reference, (LeafSubstructureContainer)candidate, this.atomFilter) : SubstructureSuperimposer.calculateSubstructureSuperimposition(reference.getAllLeafSubstructures(), candidate.getAllLeafSubstructures(), this.atomFilter)) : (this.idealSuperimposition ? SubstructureSuperimposer.calculateIdealSubstructureSuperimposition((LeafSubstructureContainer)reference, (LeafSubstructureContainer)candidate, this.representationScheme) : SubstructureSuperimposer.calculateSubstructureSuperimposition(reference.getAllLeafSubstructures(), candidate.getAllLeafSubstructures(), this.representationScheme));
                Pair alignmentPair = new Pair((Object)new ConsensusContainer(reference, false), (Object)new ConsensusContainer(candidate, false));
                this.alignments.put(superimposition, (Pair<ConsensusContainer>)alignmentPair);
                temporaryDistanceMatrix[i][j] = superimposition.getRmsd();
                temporaryDistanceMatrix[j][i] = superimposition.getRmsd();
                if (++alignmentCounter % 1000 != 0) continue;
                logger.info("computed {} of {} initial alignments ", (Object)alignmentCounter, (Object)(this.input.size() * ((this.input.size() - 1) / 2)));
            }
            distanceMatrixLabels.add(this.input.get(i + 1));
        }
        this.distanceMatrix = new LabeledSymmetricMatrix(temporaryDistanceMatrix);
        this.distanceMatrix.setColumnLabels(distanceMatrixLabels);
    }
}

