/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.math.bayesianinference;

import ai.libs.jaicore.basic.sets.SetUtil;
import ai.libs.jaicore.graph.Graph;
import ai.libs.jaicore.math.bayesianinference.ABayesianInferenceAlgorithm;
import ai.libs.jaicore.math.bayesianinference.BayesianInferenceProblem;
import ai.libs.jaicore.math.bayesianinference.DiscreteProbabilityDistribution;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.api4.java.algorithm.events.IAlgorithmEvent;
import org.api4.java.algorithm.exceptions.AlgorithmException;
import org.api4.java.algorithm.exceptions.AlgorithmExecutionCanceledException;
import org.api4.java.algorithm.exceptions.AlgorithmTimeoutedException;

/**
 * Inference over a Bayesian network of boolean variables by variable elimination.
 *
 * <p>Each relevant variable contributes one factor (a table over its non-evidence
 * parents and, unless it is observed, over the variable itself). Hidden variables
 * are summed out as soon as their factor has been added; the point-wise product of
 * the remaining factors is normalized and handed to {@code setDistribution}.
 *
 * <p>The fields {@code net}, {@code queryVariables}, {@code evidence} and
 * {@code hiddenVariables} are inherited from {@link ABayesianInferenceAlgorithm};
 * their exact declarations are not visible here, they are used as a network wrapper,
 * two collections of variable names and a variable-to-observed-value map.
 *
 * <p>Events are represented throughout as the set of variables that are {@code true};
 * every variable of a factor that is absent from an event is {@code false}.
 */
public class VariableElimination extends ABayesianInferenceAlgorithm {

    /** Factors of the current run; rebuilt from scratch on every call to {@link #nextWithException()}. */
    private List<Factor> factors = new ArrayList<>();

    /**
     * Creates the algorithm for the given inference problem.
     *
     * @param input the problem, passed unchanged to the superclass constructor
     */
    public VariableElimination(BayesianInferenceProblem input) {
        super(input);
    }

    /**
     * Determines which variables matter for the query and in which order they are processed.
     *
     * <p>Works on a copy of the network graph. First, sinks that are neither query nor
     * evidence variables are pruned repeatedly until no such sink is left. The remaining
     * graph is then dismantled sink by sink, so every variable appears in the result
     * before any of its predecessors.
     *
     * @return the relevant variables, successors before predecessors
     * @throws IllegalStateException if the remaining graph has nodes but no sink, which
     *         would otherwise make the ordering loop spin forever (e.g. a cyclic graph)
     */
    public List<String> preprocessVariables() {
        Graph<String> reducedGraph = new Graph<>(this.net.getNet());

        /* prune irrelevant sinks until a fixed point is reached */
        boolean variableRemoved;
        do {
            variableRemoved = false;
            // Iterate over a snapshot: nodes are removed inside the loop, and it is not
            // visible here whether getSinks() returns a live view of the graph.
            List<String> currentSinks = new ArrayList<>(reducedGraph.getSinks());
            for (String sink : currentSinks) {
                if (this.queryVariables.contains(sink) || this.evidence.containsKey(sink)) {
                    continue;
                }
                reducedGraph.removeItem(sink);
                variableRemoved = true;
            }
        } while (variableRemoved);

        /* peel off the remaining graph layer by layer, sinks first */
        List<String> vars = new ArrayList<>();
        while (!reducedGraph.isEmpty()) {
            List<String> currentSinks = new ArrayList<>(reducedGraph.getSinks());
            if (currentSinks.isEmpty()) {
                throw new IllegalStateException("Cannot order variables: the remaining graph has nodes but no sinks (is the network cyclic?).");
            }
            for (String var : currentSinks) {
                vars.add(var);
                reducedGraph.removeItem(var);
            }
        }
        return vars;
    }

    /**
     * Runs the complete elimination and publishes the result via {@code setDistribution}.
     *
     * <p>The factor list is reset at the start so that a repeated invocation does not
     * multiply in factors left over from an earlier run.
     *
     * @return always {@code null}; the result is only available through the distribution
     *         set on this algorithm
     * @throws IllegalStateException if no variable is relevant, i.e. there is no factor
     *         from which a distribution could be built
     * @throws InterruptedException propagated from the set utilities used to enumerate events
     */
    public IAlgorithmEvent nextWithException() throws InterruptedException, AlgorithmExecutionCanceledException, AlgorithmTimeoutedException, AlgorithmException {
        this.factors = new ArrayList<>();
        for (String var : this.preprocessVariables()) {
            this.factors.add(this.makeFactor(var, this.evidence));
            if (this.hiddenVariables.contains(var)) {
                this.factors = this.sumOut(var, this.factors);
            }
        }

        DiscreteProbabilityDistribution joint = this.multiply(this.factors);
        if (joint == null) {
            throw new IllegalStateException("No relevant variables: neither query nor evidence variables induce any factor.");
        }
        this.setDistribution(joint.getNormalizedCopy());
        return null;
    }

    /**
     * Builds the factor for one variable given the evidence.
     *
     * <p>The factor ranges over the parents of {@code var} that are not fixed by evidence.
     * Parents observed as {@code true} are added to every event before the network is
     * asked for the conditional probability. If {@code var} itself is not observed, the
     * factor additionally branches over {@code var}; otherwise only the probability of
     * the observed value is stored.
     *
     * @param var the variable whose factor is built
     * @param evidence observed variables and their values
     * @return the factor for {@code var}
     * @throws InterruptedException propagated from {@code SetUtil.powerset}
     */
    private Factor makeFactor(String var, Map<String, Boolean> evidence) throws InterruptedException {
        Collection<String> parents = this.net.getNet().getPredecessors(var);
        Collection<String> inputVariables = SetUtil.difference(parents, evidence.keySet());
        Set<String> trueEvidenceVariables = parents.stream()
                .filter(parent -> Boolean.TRUE.equals(evidence.get(parent)))
                .collect(Collectors.toSet());

        boolean branchOverQueryVar = !evidence.containsKey(var);
        // Only read the observed value when there is one; short-circuiting avoids unboxing null.
        boolean observedValue = !branchOverQueryVar && evidence.get(var);

        DiscreteProbabilityDistribution factorDistribution = new DiscreteProbabilityDistribution();
        for (Collection<String> event : SetUtil.powerset(inputVariables)) {
            Set<String> eventWithEvidence = new HashSet<>(event);
            eventWithEvidence.addAll(trueEvidenceVariables);
            double probWithPosVal = this.net.getProbabilityOfPositiveEvent(var, eventWithEvidence);
            double probWithNegVal = 1.0 - probWithPosVal;

            if (branchOverQueryVar) {
                /* two entries per parent assignment: var = false and var = true */
                factorDistribution.addProbability(event, probWithNegVal);
                Set<String> eventWithPositiveVar = new HashSet<>(event);
                eventWithPositiveVar.add(var);
                factorDistribution.addProbability(eventWithPositiveVar, probWithPosVal);
            } else {
                /* var is observed: keep only the probability of the observed value */
                factorDistribution.addProbability(event, observedValue ? probWithPosVal : probWithNegVal);
            }
        }
        return new Factor(factorDistribution);
    }

    /**
     * Eliminates a variable by summing it out of all factors that mention it.
     *
     * <p>Factors mentioning {@code var} are multiplied, and for every assignment of the
     * remaining variables the entries for {@code var = false} and {@code var = true} are
     * added up. The resulting factor is appended to the factors that do not mention
     * {@code var}.
     *
     * @param var the variable to eliminate
     * @param factors the current factors; this list is not modified
     * @return a new list with the untouched factors followed by the summed-out factor, or
     *         just the untouched factors if no factor mentions {@code var}
     * @throws InterruptedException propagated from {@code SetUtil.powerset}
     */
    private List<Factor> sumOut(String var, List<Factor> factors) throws InterruptedException {
        List<Factor> newFactors = new ArrayList<>();
        List<Factor> eliminatedFactors = new ArrayList<>();
        for (Factor f : factors) {
            if (f.subDistribution.getVariables().contains(var)) {
                eliminatedFactors.add(f);
            } else {
                newFactors.add(f);
            }
        }

        /* nothing mentions the variable, so there is nothing to sum out */
        if (eliminatedFactors.isEmpty()) {
            return newFactors;
        }

        /* point-wise product of everything that mentions var (a single factor is returned as is) */
        DiscreteProbabilityDistribution productDistribution = this.multiply(eliminatedFactors);

        // Copy before removing: the list returned by getVariables() may be the
        // distribution's own state and must not be altered from here.
        List<String> remainingVariablesInFactor = new ArrayList<>(productDistribution.getVariables());
        remainingVariablesInFactor.remove(var);

        DiscreteProbabilityDistribution distOfNewFactor = new DiscreteProbabilityDistribution();
        for (Collection<String> entry : SetUtil.powerset(remainingVariablesInFactor)) {
            Set<String> event = new HashSet<>(entry);
            double probForEventWithVariableIsNegative = productDistribution.getProbabilities().get(event);
            event.add(var);
            double probForEventWithVariableIsPositive = productDistribution.getProbabilities().get(event);
            event.remove(var);
            distOfNewFactor.addProbability(event, probForEventWithVariableIsNegative + probForEventWithVariableIsPositive);
        }
        newFactors.add(new Factor(distOfNewFactor));
        return newFactors;
    }

    /**
     * Multiplies the distributions of all given factors, left to right.
     *
     * @param factors the factors to multiply
     * @return the product; the distribution of the only factor if there is exactly one;
     *         {@code null} if {@code factors} is empty
     * @throws InterruptedException propagated from the pairwise multiplication
     */
    public DiscreteProbabilityDistribution multiply(Collection<Factor> factors) throws InterruptedException {
        DiscreteProbabilityDistribution current = null;
        for (Factor f : factors) {
            current = (current == null) ? f.subDistribution : this.multiply(current, f.subDistribution);
        }
        return current;
    }

    /**
     * Computes the point-wise product of two distributions.
     *
     * <p>The result ranges over the union of both variable sets. For every joint event the
     * probability is the product of the entries obtained by restricting that event to the
     * variables of {@code f1} and of {@code f2}, respectively.
     *
     * @param f1 first operand
     * @param f2 second operand
     * @return a new distribution over the union of the variables of both operands
     * @throws InterruptedException propagated from {@code SetUtil.powerset}
     */
    public DiscreteProbabilityDistribution multiply(DiscreteProbabilityDistribution f1, DiscreteProbabilityDistribution f2) throws InterruptedException {
        /* computed once: membership tests are needed for every joint event below */
        Set<String> variablesOfFirst = new HashSet<>(f1.getVariables());
        Set<String> variablesOfSecond = new HashSet<>(f2.getVariables());
        Set<String> variables = new HashSet<>(variablesOfFirst);
        variables.addAll(variablesOfSecond);

        Collection<String> intersection = SetUtil.intersection(f1.getVariables(), f2.getVariables());
        List<String> intersectionVariables = new ArrayList<>(intersection);
        Collection<String> otherVariables = SetUtil.difference(variables, intersectionVariables);

        Collection<Collection<String>> commonVariableCombinations = SetUtil.powerset(intersectionVariables);
        Collection<Collection<String>> disjointVariableCombinations = SetUtil.powerset(otherVariables);

        DiscreteProbabilityDistribution newDist = new DiscreteProbabilityDistribution();
        for (Collection<String> intersectionVarCombo : commonVariableCombinations) {
            for (Collection<String> differenceVarCombo : disjointVariableCombinations) {
                /* restrict the joint event to the variables each operand knows about */
                Set<String> eventInFirst = new HashSet<>(intersectionVarCombo);
                Set<String> eventInSecond = new HashSet<>(intersectionVarCombo);
                for (String variable : differenceVarCombo) {
                    if (variablesOfFirst.contains(variable)) {
                        eventInFirst.add(variable);
                    }
                    if (variablesOfSecond.contains(variable)) {
                        eventInSecond.add(variable);
                    }
                }
                double p1 = f1.getProbabilities().get(eventInFirst);
                double p2 = f2.getProbabilities().get(eventInSecond);

                Set<String> jointEvent = new HashSet<>(intersectionVarCombo);
                jointEvent.addAll(differenceVarCombo);
                newDist.addProbability(jointEvent, p1 * p2);
            }
        }
        return newDist;
    }

    /**
     * A factor of the elimination procedure: a thin holder around a distribution table.
     * Static because it never touches the enclosing algorithm instance.
     */
    private static final class Factor {
        private final DiscreteProbabilityDistribution subDistribution;

        public Factor(DiscreteProbabilityDistribution subDistribution) {
            this.subDistribution = subDistribution;
        }
    }
}

