/*
 * Decompiled with CFR 0.152.
 */
package de.bioforscher.singa.structure.algorithms.superimposition.fit3d;

import de.bioforscher.singa.core.utility.Pair;
import de.bioforscher.singa.mathematics.algorithms.optimization.KuhnMunkres;
import de.bioforscher.singa.mathematics.combinatorics.StreamPermutations;
import de.bioforscher.singa.mathematics.matrices.LabeledMatrix;
import de.bioforscher.singa.mathematics.matrices.LabeledRegularMatrix;
import de.bioforscher.singa.mathematics.matrices.Matrices;
import de.bioforscher.singa.structure.algorithms.superimposition.SubstructureSuperimposer;
import de.bioforscher.singa.structure.algorithms.superimposition.SubstructureSuperimposition;
import de.bioforscher.singa.structure.algorithms.superimposition.SubstructureSuperimpositionException;
import de.bioforscher.singa.structure.algorithms.superimposition.fit3d.Fit3D;
import de.bioforscher.singa.structure.algorithms.superimposition.fit3d.Fit3DBuilder;
import de.bioforscher.singa.structure.algorithms.superimposition.fit3d.Fit3DException;
import de.bioforscher.singa.structure.algorithms.superimposition.fit3d.Fit3DMatch;
import de.bioforscher.singa.structure.algorithms.superimposition.fit3d.representations.RepresentationScheme;
import de.bioforscher.singa.structure.algorithms.superimposition.scores.PsScore;
import de.bioforscher.singa.structure.algorithms.superimposition.scores.SubstitutionMatrix;
import de.bioforscher.singa.structure.algorithms.superimposition.scores.XieScore;
import de.bioforscher.singa.structure.model.families.AminoAcidFamily;
import de.bioforscher.singa.structure.model.interfaces.Atom;
import de.bioforscher.singa.structure.model.interfaces.LeafSubstructure;
import de.bioforscher.singa.structure.model.oak.StructuralMotif;
import de.bioforscher.singa.structure.parser.pdb.structures.StructureWriter;
import java.io.IOException;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.StringJoiner;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Aligns two structural motifs ("sites") against each other.
 * <p>
 * Two strategies are supported, selected by the builder's {@code kuhnMunkres} flag:
 * <ul>
 * <li><b>combinatorial extension</b>: starting from all 2-residue partitions of both sites, the pair of partitions
 * with the lowest RMSD is extended by one residue per iteration until the cutoff score is exceeded or one site is
 * fully aligned;</li>
 * <li><b>Kuhn-Munkres</b>: residues are assigned optimally on a cost matrix derived from the substitution matrix,
 * and the assigned residues are superimposed once.</li>
 * </ul>
 * Xie and PS-scores are only computed when both sites consist of amino acids exclusively.
 */
public class Fit3DSiteAlignment implements Fit3D {

    private static final Logger logger = LoggerFactory.getLogger(Fit3DSiteAlignment.class);

    /**
     * In non-exhaustive mode, partitions up to this size are expanded into all their permutations; larger partitions
     * are kept only in the order in which they were extended.
     */
    private static final int PERMUTATION_CUTOFF = 3;

    /** Width used to left-align every cell of the alignment summary. */
    private static final String CELL_FORMAT = "%-7s";

    private final StructuralMotif site1;
    private final StructuralMotif site2;
    private final LinkedHashSet<List<LeafSubstructure<?>>> site1Partitions;
    private final LinkedHashSet<List<LeafSubstructure<?>>> site2Partitions;
    private final RepresentationScheme representationScheme;
    private final Predicate<Atom> atomFilter;
    private final double rmsdCutoff;
    private final double distanceTolerance;
    private final boolean exhaustive;
    private final boolean restrictToExchanges;
    private final SubstitutionMatrix substitutionMatrix;
    private final boolean containsNonAminoAcids;
    private final boolean kuhnMunkres;
    private final double cutoffScore;
    private final List<Fit3DMatch> matches;
    private int currentAlignmentSize;
    private LabeledRegularMatrix<List<LeafSubstructure<?>>> currentSimilarityMatrix;
    private Pair<List<LeafSubstructure<?>>> currentBestMatchingPair;
    private double currentBestScore;
    private SubstructureSuperimposition currentBestSuperimposition;
    private String alignmentString;
    private boolean cutoffScoreReached;
    /** Set when an extension step produced no comparable partitions; terminates the extension loop. */
    private boolean extensionExhausted;
    private XieScore xieScore;
    private PsScore psScore;
    private List<Pair<LeafSubstructure<?>>> assignment;

    /**
     * Creates the alignment and immediately computes it with the settings of the given builder.
     *
     * @param builder the configured builder providing both sites and all alignment parameters
     * @throws SubstructureSuperimpositionException if a superimposition of partitions fails
     */
    public Fit3DSiteAlignment(Fit3DBuilder.Builder builder) throws SubstructureSuperimpositionException {
        this.site1 = builder.site1.getCopy();
        this.site2 = builder.site2.getCopy();
        // scores are only defined for amino acids, so BOTH sites have to be inspected
        this.containsNonAminoAcids = containsNonAminoAcid(this.site1) || containsNonAminoAcid(this.site2);
        if (this.containsNonAminoAcids) {
            logger.info("sites contain non-amino acid residues, no Xie and PS-scores can be calculated");
        }
        this.exhaustive = builder.exhaustive;
        this.kuhnMunkres = builder.kuhnMunkres;
        this.restrictToExchanges = builder.restrictToExchanges;
        if (!this.restrictToExchanges) {
            logger.info("specified exchanges will be ignored for the Fit3DSite alignment and matched types will be arbitrary");
        }
        this.currentAlignmentSize = 2;
        this.currentBestScore = Double.MAX_VALUE;
        logger.debug("calculating initial 2-partitions");
        // relies on this.exhaustive, which is assigned above
        this.site1Partitions = this.createInitialPartitions(this.site1);
        this.site2Partitions = this.createInitialPartitions(this.site2);
        this.cutoffScore = builder.cutoffScore;
        this.atomFilter = builder.atomFilter;
        this.representationScheme = builder.representationScheme;
        this.rmsdCutoff = builder.rmsdCutoff;
        this.distanceTolerance = builder.distanceTolerance;
        this.substitutionMatrix = builder.substitutionMatrix;
        this.matches = new ArrayList<>();
        logger.info("computing Fit3DSite alignment for {} (size: {}) against {} (size: {}) with cutoff score {}",
                this.site1, this.site1.size(), this.site2, this.site2.size(), this.cutoffScore);
        if (!this.kuhnMunkres) {
            logger.info("using combinatorial extension to find alignment");
            this.calculateSimilarities();
            this.extendAlignment();
        } else {
            logger.info("using Kuhn-Munkres optimization with substitution matrix {} to find alignment", this.substitutionMatrix);
            this.calculateAssignment();
            this.calculateAlignment();
        }
        Collections.sort(this.matches);
    }

    /**
     * Returns whether the given site contains at least one residue that is not an amino acid.
     */
    private static boolean containsNonAminoAcid(StructuralMotif site) {
        return site.getAllLeafSubstructures().stream()
                .anyMatch(leafSubstructure -> !(leafSubstructure.getFamily() instanceof AminoAcidFamily));
    }

    @Override
    public PsScore getPsScore() {
        return this.psScore;
    }

    @Override
    public XieScore getXieScore() {
        return this.xieScore;
    }

    @Override
    public String getAlignmentString() {
        return this.alignmentString;
    }

    /**
     * Repeatedly extends the currently best matching partitions by one residue until the cutoff score is exceeded,
     * one site is fully aligned, or no further extension is possible, and records the best alignment found.
     */
    private void extendAlignment() {
        while (this.currentBestScore <= this.cutoffScore) {
            if (this.site1.size() == this.currentAlignmentSize || this.site2.size() == this.currentAlignmentSize) {
                logger.info("alignment fully terminated after {} iterations", this.currentAlignmentSize);
                break;
            }
            this.extendPartitions();
            this.calculateSimilarities();
            if (this.cutoffScoreReached) {
                logger.info("alignment reached cutoff score of {}", this.cutoffScore);
                break;
            }
            if (this.extensionExhausted) {
                // the previous iteration's result is kept, including its score
                break;
            }
        }
        if (this.currentBestSuperimposition != null) {
            this.recordBestMatch();
        } else {
            logger.info("no suitable alignment could be found");
        }
    }

    /**
     * Stores the current best superimposition as a match, computes the scores if applicable and builds the summary.
     */
    private void recordBestMatch() {
        this.matches.add(Fit3DMatch.of(this.currentBestScore, this.currentBestSuperimposition));
        if (!this.containsNonAminoAcids) {
            this.calculateXieScore();
            this.calculatePsScore();
        }
        this.outputSummary();
    }

    private void calculatePsScore() {
        this.psScore = PsScore.of(this.currentBestSuperimposition, this.site1.getNumberOfLeafSubstructures(), this.site2.getNumberOfLeafSubstructures());
    }

    private void calculateXieScore() {
        this.xieScore = XieScore.of(this.substitutionMatrix, this.currentBestSuperimposition);
    }

    /**
     * Builds the human-readable alignment summary returned by {@link #getAlignmentString()} and logs it.
     */
    private void outputSummary() {
        StringJoiner site1Joiner = new StringJoiner("|", "|", "|");
        StringJoiner site2Joiner = new StringJoiner("|", "|", "|");
        for (int i = 0; i < this.currentAlignmentSize; ++i) {
            site1Joiner.add(pad(this.currentBestSuperimposition.getReference().get(i).toString()));
            site2Joiner.add(pad(this.currentBestSuperimposition.getCandidate().get(i).toString()));
        }
        this.alignmentString = pad("s1size") + "|" + this.site1.size() + "\n"
                + pad("s2size") + "|" + this.site2.size() + "\n"
                + formatSite("s1", this.site1) + "\n"
                + formatSite("s2", this.site2) + "\n"
                + pad("RMSD") + "|" + this.currentBestSuperimposition.getRmsd() + "\n"
                + pad("frac") + "|" + this.getAlignedResidueFraction() + "\n"
                + pad("XieS") + "|" + (this.containsNonAminoAcids ? "NaN" : String.valueOf(this.xieScore.getScore())) + "\n"
                + pad("XieExp") + "|" + (this.containsNonAminoAcids ? "NaN" : String.valueOf(this.xieScore.getSignificance())) + "\n"
                + pad("PsS") + "|" + (this.containsNonAminoAcids ? "NaN" : String.valueOf(this.psScore.getScore())) + "\n"
                + pad("PsExp") + "|" + (this.containsNonAminoAcids ? "NaN" : String.valueOf(this.psScore.getSignificance())) + "\n"
                + pad("s1algn") + site1Joiner + "\n"
                + pad("s2algn") + site2Joiner;
        logger.info("aligned {} residues (site 1 contains {} residues and site 2 contains {} residues)\n{}",
                this.currentAlignmentSize, this.site1.size(), this.site2.size(), this.alignmentString);
    }

    /** Left-aligns a summary cell to the fixed cell width. */
    private static String pad(String value) {
        return String.format(CELL_FORMAT, value);
    }

    /** Formats one summary row listing every residue of the given site. */
    private static String formatSite(String label, StructuralMotif site) {
        return site.getAllLeafSubstructures().stream()
                .map(Object::toString)
                .map(Fit3DSiteAlignment::pad)
                .collect(Collectors.joining("|", pad(label) + "|", "|"));
    }

    /**
     * Returns the number of aligned residues relative to the size of the smaller site.
     */
    private double getAlignedResidueFraction() {
        return (double) this.currentAlignmentSize / (double) Math.min(this.site1.size(), this.site2.size());
    }

    /**
     * Increases the alignment size by one and replaces both partition sets by all extensions of the current best
     * matching pair with exactly one additional residue of the respective site.
     */
    private void extendPartitions() {
        ++this.currentAlignmentSize;
        this.site1Partitions.clear();
        this.site2Partitions.clear();
        this.addExtendedPartitions(this.site1, this.currentBestMatchingPair.getFirst(), this.site1Partitions);
        this.addExtendedPartitions(this.site2, this.currentBestMatchingPair.getSecond(), this.site2Partitions);
    }

    /**
     * Adds every extension of {@code bestPartition} by one residue of {@code site} to {@code partitions}; in
     * non-exhaustive mode, small partitions are added in all their permutations.
     */
    private void addExtendedPartitions(StructuralMotif site, List<LeafSubstructure<?>> bestPartition,
                                       Collection<List<LeafSubstructure<?>>> partitions) {
        for (LeafSubstructure<?> leafSubstructure : site.getAllLeafSubstructures()) {
            List<LeafSubstructure<?>> extendedPartition = new ArrayList<>(bestPartition);
            if (!extendedPartition.contains(leafSubstructure)) {
                extendedPartition.add(leafSubstructure);
            }
            // residues already part of the partition do not yield a partition of the new size
            if (extendedPartition.size() != this.currentAlignmentSize) {
                continue;
            }
            if (this.currentAlignmentSize <= PERMUTATION_CUTOFF && !this.exhaustive) {
                StreamPermutations.of(extendedPartition.toArray(new LeafSubstructure<?>[extendedPartition.size()]))
                        .map(permutation -> permutation.collect(Collectors.toList()))
                        .forEach(partitions::add);
            } else {
                partitions.add(extendedPartition);
            }
        }
    }

    /**
     * Computes the pairwise similarity (RMSD) matrix between all current partitions of both sites and updates the
     * best matching pair, or flags termination if the cutoff is exceeded or no minimum can be determined.
     *
     * @throws SubstructureSuperimpositionException if a superimposition fails
     * @throws Fit3DException if no minimal element exists in the very first iteration
     */
    private void calculateSimilarities() throws SubstructureSuperimpositionException {
        double localBestScore = Double.MAX_VALUE;
        SubstructureSuperimposition localBestSuperimposition = null;
        List<List<LeafSubstructure<?>>> rowLabels = new ArrayList<>(this.site1Partitions);
        List<List<LeafSubstructure<?>>> columnLabels = new ArrayList<>(this.site2Partitions);
        double[][] temporarySimilarityMatrix = new double[rowLabels.size()][columnLabels.size()];
        for (int i = 0; i < rowLabels.size(); ++i) {
            List<LeafSubstructure<?>> site1Partition = rowLabels.get(i);
            for (int j = 0; j < columnLabels.size(); ++j) {
                List<LeafSubstructure<?>> site2Partition = columnLabels.get(j);
                double score;
                SubstructureSuperimposition superimposition;
                if (this.restrictToExchanges) {
                    Fit3DMatch bestMatch = this.findBestRestrictedMatch(site1Partition, site2Partition);
                    if (bestMatch == null) {
                        // incompatible partitions can never be the best pair
                        temporarySimilarityMatrix[i][j] = Double.MAX_VALUE;
                        continue;
                    }
                    score = bestMatch.getRmsd();
                    superimposition = bestMatch.getSubstructureSuperimposition();
                } else {
                    superimposition = this.superimpose(site1Partition, site2Partition);
                    score = superimposition.getRmsd();
                }
                temporarySimilarityMatrix[i][j] = score;
                if (score < localBestScore) {
                    localBestSuperimposition = superimposition;
                    localBestScore = score;
                }
            }
        }
        this.currentSimilarityMatrix = new LabeledRegularMatrix<>(temporarySimilarityMatrix);
        this.currentSimilarityMatrix.setRowLabels(rowLabels);
        this.currentSimilarityMatrix.setColumnLabels(columnLabels);
        logger.debug("current similarity matrix is \n{}", this.currentSimilarityMatrix.getStringRepresentation());
        List<Pair<Integer>> minimalScores = Matrices.getPositionsOfMinimalElement(this.currentSimilarityMatrix);
        if (minimalScores.isEmpty()) {
            if (this.currentAlignmentSize == 2) {
                throw new Fit3DException("could not find minimal agreement of partitions in first iteration");
            }
            logger.info("no suitable alignment found in iteration {}", this.currentAlignmentSize);
            --this.currentAlignmentSize;
            // keep currentBestScore of the previous iteration so the recorded match stays consistent
            this.extensionExhausted = true;
            return;
        }
        Pair<Integer> bestPosition = minimalScores.get(0);
        double scoreValue = this.currentSimilarityMatrix.getValueFromPosition(bestPosition);
        if (scoreValue > this.cutoffScore) {
            logger.info("cutoff score exceeded");
            --this.currentAlignmentSize;
            this.cutoffScoreReached = true;
            return;
        }
        List<LeafSubstructure<?>> first = this.currentSimilarityMatrix.getRowLabel(bestPosition.getFirst());
        List<LeafSubstructure<?>> second = this.currentSimilarityMatrix.getColumnLabel(bestPosition.getSecond());
        this.currentBestMatchingPair = new Pair<>(first, second);
        this.currentBestScore = scoreValue;
        this.currentBestSuperimposition = localBestSuperimposition;
        logger.info("current best matching pair of size {} is {} with RMSD {}",
                this.currentAlignmentSize, this.currentBestMatchingPair, this.currentBestScore);
    }

    /**
     * Superimposes two partitions without exchange restrictions, using the representation scheme if one is set and
     * the atom filter otherwise; exhaustive mode uses the ideal (order-optimizing) superimposition.
     */
    private SubstructureSuperimposition superimpose(List<LeafSubstructure<?>> reference, List<LeafSubstructure<?>> candidate)
            throws SubstructureSuperimpositionException {
        if (this.representationScheme != null) {
            return this.exhaustive
                    ? SubstructureSuperimposer.calculateIdealSubstructureSuperimposition(reference, candidate, this.representationScheme)
                    : SubstructureSuperimposer.calculateSubstructureSuperimposition(reference, candidate, this.representationScheme);
        }
        return this.exhaustive
                ? SubstructureSuperimposer.calculateIdealSubstructureSuperimposition(reference, candidate, this.atomFilter)
                : SubstructureSuperimposer.calculateSubstructureSuperimposition(reference, candidate, this.atomFilter);
    }

    /**
     * Runs a Fit3D search of the site-1 partition against the site-2 partition, which honors the exchanges defined
     * on the residues.
     *
     * @return the first (best-ranked) match, or {@code null} if the search found none
     */
    private Fit3DMatch findBestRestrictedMatch(List<LeafSubstructure<?>> site1Partition, List<LeafSubstructure<?>> site2Partition) {
        StructuralMotif query = restrictToPartition(this.site1, site1Partition);
        StructuralMotif target = restrictToPartition(this.site2, site2Partition);
        Fit3D fit3d = this.representationScheme != null
                ? Fit3DBuilder.create().query(query).target(target).representationScheme(this.representationScheme.getType()).rmsdCutoff(this.rmsdCutoff).distanceTolerance(this.distanceTolerance).run()
                : Fit3DBuilder.create().query(query).target(target).atomFilter(this.atomFilter).rmsdCutoff(this.rmsdCutoff).distanceTolerance(this.distanceTolerance).run();
        List<Fit3DMatch> fit3dMatches = fit3d.getMatches();
        return fit3dMatches.isEmpty() ? null : fit3dMatches.get(0);
    }

    /**
     * Returns a copy of {@code site} from which every residue not contained in {@code partition} is removed.
     */
    private static StructuralMotif restrictToPartition(StructuralMotif site, List<LeafSubstructure<?>> partition) {
        StructuralMotif restricted = site.getCopy();
        // collect first so that removal happens only after the filtering pass over the original site
        List<LeafSubstructure<?>> leavesToBeRemoved = site.getAllLeafSubstructures().stream()
                .filter(leafSubstructure -> !partition.contains(leafSubstructure))
                .collect(Collectors.toList());
        leavesToBeRemoved.forEach(leafSubstructure -> restricted.removeLeafSubstructure(leafSubstructure.getIdentifier()));
        return restricted;
    }

    /**
     * Creates all 2-residue partitions of the given motif. In non-exhaustive mode both orderings of every pair are
     * included; in exhaustive mode only one ordering is kept.
     */
    private LinkedHashSet<List<LeafSubstructure<?>>> createInitialPartitions(StructuralMotif structuralMotif) {
        LinkedHashSet<List<LeafSubstructure<?>>> partitions = new LinkedHashSet<>();
        List<LeafSubstructure<?>> leafSubstructures = structuralMotif.getAllLeafSubstructures();
        for (int i = 0; i < leafSubstructures.size() - 1; ++i) {
            for (int j = i + 1; j < leafSubstructures.size(); ++j) {
                List<LeafSubstructure<?>> partition = new ArrayList<>();
                partition.add(leafSubstructures.get(i));
                partition.add(leafSubstructures.get(j));
                partitions.add(partition);
                if (!this.exhaustive) {
                    List<LeafSubstructure<?>> reversedPartition = new ArrayList<>();
                    reversedPartition.add(leafSubstructures.get(j));
                    reversedPartition.add(leafSubstructures.get(i));
                    partitions.add(reversedPartition);
                }
            }
        }
        return partitions;
    }

    /**
     * Computes the optimal residue assignment between both sites with the Kuhn-Munkres algorithm on a cost matrix
     * taken from the substitution matrix; with exchange restriction, non-exchangeable pairs get maximal cost.
     */
    private void calculateAssignment() {
        List<LeafSubstructure<?>> site1Leaves = this.site1.getAllLeafSubstructures();
        List<LeafSubstructure<?>> site2Leaves = this.site2.getAllLeafSubstructures();
        double[][] costValues = new double[this.site1.size()][this.site2.size()];
        for (int i = 0; i < this.site1.getNumberOfLeafSubstructures(); ++i) {
            LeafSubstructure<?> residue1 = site1Leaves.get(i);
            for (int j = 0; j < this.site2.getNumberOfLeafSubstructures(); ++j) {
                LeafSubstructure<?> residue2 = site2Leaves.get(j);
                if (this.restrictToExchanges && residue1.getFamily() != residue2.getFamily()) {
                    boolean exchangeable = residue1.getExchangeableFamilies().contains(residue2.getFamily())
                            || residue2.getExchangeableFamilies().contains(residue1.getFamily());
                    if (!exchangeable) {
                        costValues[i][j] = Double.MAX_VALUE;
                    }
                    // NOTE(review): exchangeable pairs keep the default cost 0.0 instead of the substitution value — confirm intended
                    continue;
                }
                costValues[i][j] = this.substitutionMatrix.getMatrix().getValueForLabel(residue1.getFamily(), residue2.getFamily());
            }
        }
        LabeledRegularMatrix<LeafSubstructure<?>> costMatrix = new LabeledRegularMatrix<>(costValues);
        costMatrix.setRowLabels(site1Leaves);
        costMatrix.setColumnLabels(site2Leaves);
        KuhnMunkres<LeafSubstructure<?>> kuhnMunkresOptimization = new KuhnMunkres<>(costMatrix);
        this.assignment = kuhnMunkresOptimization.getAssignedPairs();
        if (this.restrictToExchanges && !this.assignment.isEmpty()) {
            // NOTE(review): unconditionally drops the last assigned pair; presumably meant to discard an
            // unassignable (maximal cost) pair — confirm against KuhnMunkres#getAssignedPairs
            this.assignment.remove(this.assignment.size() - 1);
        }
        String assignmentString = kuhnMunkresOptimization.getAssignedPairs().stream()
                .map(pair -> pair.getFirst() + "+" + pair.getSecond() + ":" + costMatrix.getValueForLabel(pair.getFirst(), pair.getSecond()))
                .collect(Collectors.joining("\n"));
        logger.debug("optimal assignment of binding sites is:\n{}", assignmentString);
    }

    /**
     * Superimposes the residues of the Kuhn-Munkres assignment and records the result as the single match.
     */
    private void calculateAlignment() {
        List<LeafSubstructure<?>> reference = this.assignment.stream().map(Pair::getFirst).collect(Collectors.toList());
        List<LeafSubstructure<?>> candidate = this.assignment.stream().map(Pair::getSecond).collect(Collectors.toList());
        this.currentAlignmentSize = reference.size();
        this.currentBestSuperimposition = this.representationScheme != null
                ? SubstructureSuperimposer.calculateSubstructureSuperimposition(reference, candidate, this.representationScheme)
                : SubstructureSuperimposer.calculateSubstructureSuperimposition(reference, candidate, this.atomFilter);
        this.currentBestScore = this.currentBestSuperimposition.getRmsd();
        this.recordBestMatch();
    }

    /**
     * Writes site 1 and site 2 (transformed onto site 1 by the best superimposition) as PDB files to the given
     * directory. I/O errors are logged, not rethrown.
     *
     * @throws Fit3DException if no match is available
     */
    @Override
    public void writeMatches(Path outputDirectory) {
        if (this.matches.isEmpty()) {
            throw new Fit3DException("cannot write matches as they are currently empty");
        }
        SubstructureSuperimposition bestSuperimposition = this.matches.get(0).getSubstructureSuperimposition();
        List<LeafSubstructure<?>> mappedSite2 = bestSuperimposition.applyTo(this.site2.getCopy().getAllLeafSubstructures());
        try {
            StructureWriter.writeLeafSubstructures(this.site1.getAllLeafSubstructures(),
                    outputDirectory.resolve(matchFileName(this.site1, bestSuperimposition, "_site1.pdb")));
            StructureWriter.writeLeafSubstructures(mappedSite2,
                    outputDirectory.resolve(matchFileName(this.site2, bestSuperimposition, "_site2.pdb")));
        } catch (IOException e) {
            logger.error("error writing Fit3DSite results", e);
        }
    }

    /**
     * Builds an output file name of the form {@code <rmsd>_<pdbId>_<chain>-<serial>_..._<suffix>} with residues
     * sorted by identifier.
     */
    private static String matchFileName(StructuralMotif site, SubstructureSuperimposition superimposition, String suffix) {
        List<LeafSubstructure<?>> leaves = site.getAllLeafSubstructures();
        String prefix = superimposition.getFormattedRmsd() + "_" + leaves.get(0).getPdbIdentifier() + "_";
        return leaves.stream()
                .sorted(Comparator.comparing(LeafSubstructure::getIdentifier))
                .map(leafSubstructure -> leafSubstructure.getChainIdentifier() + "-" + leafSubstructure.getIdentifier().getSerial())
                .collect(Collectors.joining("_", prefix, "")) + suffix;
    }

    @Override
    public List<Fit3DMatch> getMatches() {
        return this.matches;
    }

    @Override
    public double getFraction() {
        return this.getAlignedResidueFraction();
    }
}

