/*
 * Decompiled with CFR 0.152.
 */
package lphy.graphicalModel;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import lphy.core.narrative.Narrative;
import lphy.graphicalModel.Citation;
import lphy.graphicalModel.Generator;
import lphy.graphicalModel.GraphicalModelNode;
import lphy.graphicalModel.GraphicalModelNodeVisitor;
import lphy.graphicalModel.NarrativeUtils;
import lphy.graphicalModel.RandomVariable;
import lphy.graphicalModel.Value;
import lphy.parser.functions.ExpressionNode;
import lphy.parser.functions.ExpressionNodeWrapper;

public interface GraphicalModel {

    /**
     * Returns the id-to-value map of the data block. {@link #put} writes into this map, so
     * implementations must return a mutable map.
     */
    Map<String, Value<?>> getDataDictionary();

    /**
     * Returns the id-to-value map of the model block. {@link #put} writes into this map, so
     * implementations must return a mutable map.
     */
    Map<String, Value<?>> getModelDictionary();

    /** Returns the set of values placed in the data block; {@link #put} adds to it. */
    Set<Value> getDataValues();

    /** Returns the set of values placed in the model block; {@link #put} adds to it. */
    Set<Value> getModelValues();

    /**
     * Looks up a value by id.
     *
     * @param id      the value id
     * @param context {@link Context#data} consults only the data dictionary; any other context
     *                consults the model dictionary first and falls back to the data dictionary
     * @return the value found, or whatever the dictionary's {@code get} returns for a missing id
     */
    default Value getValue(String id, Context context) {
        switch (context) {
            case data: {
                return getDataDictionary().get(id);
            }
            default: {
                Map<String, Value<?>> data = getDataDictionary();
                Map<String, Value<?>> model = getModelDictionary();
                // A model-block definition shadows a data-block one with the same id.
                return model.containsKey(id) ? model.get(id) : data.get(id);
            }
        }
    }

    /** Returns true if {@link #getValue(String, Context)} yields a non-null value. */
    default boolean hasValue(String id, Context context) {
        return getValue(id, context) != null;
    }

    /**
     * Returns true if {@code id} is defined in the data block and the model block binds the
     * same id to a {@link RandomVariable}.
     */
    default boolean isClamped(String id) {
        return id != null
                && getDataDictionary().containsKey(id)
                && getModelDictionary().containsKey(id)
                && getModelDictionary().get(id) instanceof RandomVariable;
    }

    /** Returns true if {@code value} is a random variable whose id is clamped (see {@link #isClamped}). */
    default boolean isClampedVariable(Value value) {
        return value instanceof RandomVariable && isClamped(value.getId());
    }

    /**
     * Collects the named values with no outputs from both the data and the model dictionary.
     *
     * @return a new list of those values, sorted by id
     */
    default List<Value<?>> getModelSinks() {
        List<Value<?>> sinks = new ArrayList<>();
        for (Value<?> val : getDataDictionary().values()) {
            if (!val.isAnonymous() && val.getOutputs().isEmpty()) {
                sinks.add(val);
            }
        }
        for (Value<?> val : getModelDictionary().values()) {
            if (!val.isAnonymous() && val.getOutputs().isEmpty()) {
                sinks.add(val);
            }
        }
        sinks.sort(Comparator.comparing(Value::getId));
        return sinks;
    }

    /** Returns every {@link RandomVariable} reachable (through inputs) from the model sinks. */
    default List<RandomVariable<?>> getAllVariablesFromSinks() {
        List<RandomVariable<?>> variables = new ArrayList<>();
        for (Value<?> value : Utils.getAllValuesFromSinks(this)) {
            if (value instanceof RandomVariable) {
                variables.add((RandomVariable<?>) value);
            }
        }
        return variables;
    }

    /** Returns true if {@code value} is named, not random, and its id resolves in the data context. */
    default boolean isNamedDataValue(Value value) {
        return !value.isAnonymous()
                && !(value instanceof RandomVariable)
                && hasValue(value.getId(), Context.data);
    }

    /** Returns true if {@code value} is a member of {@link #getDataValues()}. */
    default boolean inDataBlock(Value value) {
        return getDataValues().contains(value);
    }

    /**
     * Sums the log density of every random variable reachable from the model sinks. Clamped
     * variables are scored against the data-block value bound to the same id.
     */
    default double computeLogPosterior() {
        List<RandomVariable<?>> variables = getAllVariablesFromSinks();
        double logPosterior = 0.0;
        for (RandomVariable<?> variable : variables) {
            if (!isClampedVariable(variable)) {
                logPosterior += variable.getGenerativeDistribution().logDensity(variable.value);
            } else {
                // NOTE(review): this branch passes the Value object itself, while the branch
                // above passes the wrapped field `variable.value`; confirm logDensity accepts both.
                logPosterior += variable.getGenerativeDistribution()
                        .logDensity(getDataDictionary().get(variable.getId()));
            }
        }
        return logPosterior;
    }

    /**
     * Registers {@code value} under {@code id} in the data block ({@link Context#data}) or in the
     * model block (any other context), updating both the dictionary and the value set.
     */
    default void put(String id, Value value, Context context) {
        switch (context) {
            case data: {
                getDataDictionary().put(id, value);
                getDataValues().add(value);
                break;
            }
            default: {
                getModelDictionary().put(id, value);
                getModelValues().add(value);
            }
        }
    }

    /** The block a value belongs to. */
    enum Context {
        data,
        model;
    }

    /** Static helpers for traversing and describing a {@link GraphicalModel}. */
    class Utils {
        /** Accumulated statement length after which the inference statement is wrapped. */
        private static final int WRAP_LENGTH = 80;

        /** Returns true if {@code line} contains a '~' anywhere after its first character. */
        public static boolean isRandomVariableLine(String line) {
            return line.indexOf('~') > 0;
        }

        /**
         * Repeatedly wraps multi-node expression subtrees feeding the model's values into a single
         * {@link ExpressionNodeWrapper}. Passes continue until one full pass over all sinks wraps
         * nothing.
         */
        public static void wrapExpressionNodes(GraphicalModel model) {
            boolean found;
            do {
                found = false;
                for (Value<?> value : model.getModelSinks()) {
                    // Accumulate across sinks: any sink that wrapped a node warrants another pass.
                    if (wrapExpressionNodes(value)) {
                        found = true;
                    }
                }
            } while (found);
        }

        /**
         * Wraps at most one expression subtree in the DAG above {@code value}.
         *
         * @return true if a wrapper was installed somewhere
         */
        private static boolean wrapExpressionNodes(Value<?> value) {
            // First try this value's own inputs.
            for (GraphicalModelNode<?> node : value.getInputs()) {
                if (node instanceof ExpressionNode
                        && ExpressionNodeWrapper.expressionSubtreeSize((ExpressionNode) node) > 1) {
                    value.setFunction(new ExpressionNodeWrapper((ExpressionNode) node));
                    return true;
                }
            }
            // Then recurse through generator inputs. Every branch is explored,
            // not only the first Value input encountered.
            for (GraphicalModelNode<?> node : value.getInputs()) {
                if (!(node instanceof Generator)) {
                    continue;
                }
                Generator<?> generator = (Generator<?>) node;
                for (GraphicalModelNode<?> input : generator.getInputs()) {
                    if (input instanceof Value && wrapExpressionNodes((Value<?>) input)) {
                        return true;
                    }
                }
            }
            return false;
        }

        /**
         * Builds the "Data" and/or "Model" narrative sections for all values reachable from the
         * model sinks.
         *
         * @param data              whether to emit the data section
         * @param includeModelBlock whether to emit the model section
         */
        public static String getNarrative(final GraphicalModel model, Narrative narrative, boolean data, boolean includeModelBlock) {
            final Map<String, Integer> nameCounts = new HashMap<>();
            final List<Value<?>> dataVisited = new ArrayList<>();
            final List<Value<?>> modelVisited = new ArrayList<>();
            for (Value<?> sink : model.getModelSinks()) {
                Value.traverseGraphicalModel(sink, new GraphicalModelNodeVisitor() {

                    @Override
                    public void visitValue(Value value) {
                        if (model.inDataBlock(value)) {
                            if (!dataVisited.contains(value)) {
                                dataVisited.add(value);
                                // Anonymous and clamped data values are not counted. They then have
                                // no count (unless another value shares the name) and are skipped below.
                                if (!value.isAnonymous() && !model.isClamped(value.getId())) {
                                    nameCounts.merge(NarrativeUtils.getName(value), 1, Integer::sum);
                                }
                            }
                        } else if (!modelVisited.contains(value)) {
                            modelVisited.add(value);
                            nameCounts.merge(NarrativeUtils.getName(value), 1, Integer::sum);
                        }
                    }

                    @Override
                    public void visitGenerator(Generator generator) {
                    }
                }, false);
            }

            StringBuilder builder = new StringBuilder();
            if (!dataVisited.isEmpty() && data) {
                builder.append(narrative.section("Data"));
                appendValueNarratives(builder, dataVisited, nameCounts, narrative);
                builder.append("\n\n");
            }
            if (!modelVisited.isEmpty() && includeModelBlock) {
                builder.append(narrative.section("Model"));
                appendValueNarratives(builder, modelVisited, nameCounts, narrative);
                builder.append("\n");
            }
            return builder.toString();
        }

        /**
         * Appends each counted value's narrative, one per line. Values whose name has no entry in
         * {@code nameCounts} are skipped. A value is told it is unique when its name count is 1.
         */
        private static void appendValueNarratives(StringBuilder builder, List<Value<?>> values,
                                                  Map<String, Integer> nameCounts, Narrative narrative) {
            for (Value<?> value : values) {
                Integer count = nameCounts.get(NarrativeUtils.getName(value));
                if (count == null) {
                    continue;
                }
                String valueNarrative = value.getNarrative(count == 1, narrative);
                builder.append(valueNarrative);
                if (valueNarrative.length() > 0) {
                    builder.append("\n");
                }
            }
        }

        /**
         * Visits every generator reachable from the model sinks and then returns
         * {@code narrative.referenceSection()}.
         */
        public static String getReferences(GraphicalModel model, Narrative narrative) {
            // NOTE(review): refs is collected but never used; the returned text comes solely from
            // narrative.referenceSection(). Confirm whether the citations should be passed on.
            final List<Citation> refs = new ArrayList<>();
            for (Value<?> sink : model.getModelSinks()) {
                Value.traverseGraphicalModel(sink, new GraphicalModelNodeVisitor() {

                    @Override
                    public void visitValue(Value value) {
                    }

                    @Override
                    public void visitGenerator(Generator generator) {
                        Citation citation = generator.getCitation();
                        if (citation != null && !refs.contains(citation)) {
                            refs.add(citation);
                        }
                    }
                }, false);
            }
            return narrative.referenceSection();
        }

        /**
         * Builds a math-mode statement of the form {@code P(params | data) ∝ ...}, followed by the
         * product of each random variable's inference statement. The product is wrapped onto a
         * new aligned line once the accumulated statement length exceeds {@link #WRAP_LENGTH}.
         *
         * @return the statement, or an empty string if no non-data values are reachable
         */
        public static String getInferenceStatement(final GraphicalModel model, Narrative narrative) {
            final List<Value<?>> modelVisited = new ArrayList<>();
            final List<Value<?>> dataValues = new ArrayList<>();
            for (Value<?> sink : model.getModelSinks()) {
                Value.traverseGraphicalModel(sink, new GraphicalModelNodeVisitor() {

                    @Override
                    public void visitValue(Value value) {
                        if (!model.isNamedDataValue(value) && !modelVisited.contains(value)) {
                            modelVisited.add(value);
                            // Clamped values and sinks go on the right-hand side of the "|".
                            if (model.isClamped(value.getId()) || value.getOutputs().isEmpty()) {
                                dataValues.add(value);
                            }
                        }
                    }

                    @Override
                    public void visitGenerator(Generator generator) {
                    }
                }, false);
            }
            if (modelVisited.isEmpty()) {
                return "";
            }

            StringBuilder builder = new StringBuilder();
            builder.append(narrative.startMathMode(false, true));
            builder.append("P(");

            List<String> parameterNames = new ArrayList<>();
            for (Value<?> value : modelVisited) {
                if (value instanceof RandomVariable && !dataValues.contains(value)) {
                    parameterNames.add(narrative.getId(value, false));
                }
            }
            builder.append(String.join(", ", parameterNames));

            if (!dataValues.isEmpty()) {
                builder.append(" | ");
            }
            // Only non-null names are joined, so no stray separators appear.
            List<String> conditionedNames = new ArrayList<>();
            for (Value<?> value : dataValues) {
                String name = narrative.getId(value, false);
                if (name != null) {
                    conditionedNames.add(name);
                }
            }
            builder.append(String.join(", ", conditionedNames));

            builder.append(") ");
            builder.append(narrative.symbol("\u221d"));
            builder.append(" ");
            builder.append(narrative.mathAlign());

            List<RandomVariable<?>> randomVariables = new ArrayList<>();
            for (Value<?> value : modelVisited) {
                if (value instanceof RandomVariable) {
                    randomVariables.add((RandomVariable<?>) value);
                }
            }
            int lineLength = 0;
            for (int i = 0; i < randomVariables.size(); i++) {
                // Raw type so the generator and the variable share one (erased) type argument.
                RandomVariable variable = randomVariables.get(i);
                String statement = variable.getGenerator().getInferenceStatement(variable, narrative);
                builder.append(statement);
                lineLength += statement.length();
                // Wrap only between statements, never after the last random variable.
                boolean isLast = i == randomVariables.size() - 1;
                if (lineLength > WRAP_LENGTH && !isLast) {
                    builder.append(narrative.mathNewLine());
                    builder.append(narrative.mathAlign());
                    builder.append(" ");
                    lineLength = 0;
                }
            }
            builder.append(narrative.endMathMode());
            builder.append("\n");
            return builder.toString();
        }

        /**
         * Returns every value reachable through inputs from the model sinks, without duplicates.
         * The order is depth-first, in encounter order.
         */
        public static List<Value<?>> getAllValuesFromSinks(GraphicalModel model) {
            List<Value<?>> values = new ArrayList<>();
            for (Value<?> sink : model.getModelSinks()) {
                getAllValues(sink, values);
            }
            return values;
        }

        /** Depth-first collection of the values at and above {@code node} into {@code values}. */
        private static void getAllValues(GraphicalModelNode<?> node, List<Value<?>> values) {
            if (node instanceof Value) {
                if (values.contains(node)) {
                    // Already collected. Its whole input subtree was collected on the first visit,
                    // so re-walking it (exponential for shared sub-DAGs) adds nothing.
                    return;
                }
                values.add((Value<?>) node);
            }
            for (GraphicalModelNode<?> childNode : node.getInputs()) {
                getAllValues(childNode, values);
            }
        }
    }
}

