/*
 * Decompiled with CFR 0.152.
 */
package lphy.evolution.continuous;

import java.util.Map;
import java.util.SortedMap;
import java.util.TreeMap;
import lphy.core.StringDoubleArrayMap;
import lphy.core.distributions.Utils;
import lphy.evolution.alignment.ContinuousCharacterData;
import lphy.evolution.tree.TimeTree;
import lphy.evolution.tree.TimeTreeNode;
import lphy.graphicalModel.GenerativeDistribution;
import lphy.graphicalModel.ParameterInfo;
import lphy.graphicalModel.RandomVariable;
import lphy.graphicalModel.Value;
import lphy.graphicalModel.types.DoubleArrayValue;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.math3.distribution.MultivariateNormalDistribution;
import org.apache.commons.math3.random.RandomGenerator;

public class PhyloMultivariateBrownian
implements GenerativeDistribution<ContinuousCharacterData> {
    Value<TimeTree> tree;
    Value<Double[][]> diffusionMatrix;
    Value<Double[]> y0;
    RandomGenerator random;
    public static final String treeParamName = "tree";
    public static final String diffusionMatrixParamName = "diffusionMatrix";
    public static final String y0ParamName = "y0";

    public PhyloMultivariateBrownian(@ParameterInfo(name="tree", description="the time tree.") Value<TimeTree> tree, @ParameterInfo(name="diffusionMatrix", description="the multivariate diffusion rates.") Value<Double[][]> diffusionRate, @ParameterInfo(name="y0", description="the value of multivariate traits at the root.") Value<Double[]> y0) {
        this.tree = tree;
        this.diffusionMatrix = diffusionRate;
        this.y0 = y0;
        this.random = Utils.getRandom();
    }

    @Override
    public RandomVariable<ContinuousCharacterData> sample() {
        TreeMap<String, Integer> idMap = new TreeMap<String, Integer>();
        this.fillIdMap(this.tree.value().getRoot(), idMap);
        StringDoubleArrayMap tipValues = new StringDoubleArrayMap();
        this.fillValuesTraversingTree(this.tree.value().getRoot(), this.y0, tipValues, this.diffusionMatrix.value(), idMap);
        Double[][] contData = new Double[this.tree.value().n()][this.y0.value().length];
        for (Map.Entry entry : tipValues.entrySet()) {
            contData[this.tree.value().getTaxa().indexOfTaxon((String)((String)entry.getKey()))] = (Double[])entry.getValue();
        }
        return new RandomVariable<ContinuousCharacterData>("x", new ContinuousCharacterData(this.tree.value().getTaxa(), contData), this);
    }

    private void fillIdMap(TimeTreeNode node, SortedMap<String, Integer> idMap) {
        if (node.isLeaf()) {
            Integer i = (Integer)idMap.get(node.getId());
            if (i == null) {
                int nextValue = 0;
                for (Integer j : idMap.values()) {
                    if (j < nextValue) continue;
                    nextValue = j + 1;
                }
                idMap.put(node.getId(), nextValue);
            }
        } else {
            for (TimeTreeNode child : node.getChildren()) {
                this.fillIdMap(child, idMap);
            }
        }
    }

    private void fillValuesTraversingTree(TimeTreeNode node, Value<Double[]> nodeState, Map<String, Double[]> tipValues, Double[][] diffusionMatrix, Map<String, Integer> idMap) {
        if (node.isLeaf()) {
            tipValues.put(node.getId(), nodeState.value());
        } else {
            for (TimeTreeNode child : node.getChildren()) {
                double branchLength = node.getAge() - child.getAge();
                Double[] newIntNodeState = this.getSampleFromNewMVN(nodeState.value(), diffusionMatrix, branchLength);
                DoubleArrayValue newIntNodeStateValue = new DoubleArrayValue(null, newIntNodeState);
                this.fillValuesTraversingTree(child, newIntNodeStateValue, tipValues, diffusionMatrix, idMap);
            }
        }
    }

    protected Double[] handleBoundaries(double[] rawValues) {
        return ArrayUtils.toObject((double[])rawValues);
    }

    Double[] getSampleFromNewMVN(Double[] oldValue, Double[][] diffusionMatrix, double branchLength) {
        double[] means = new double[oldValue.length];
        double[][] covariances = new double[diffusionMatrix.length][diffusionMatrix[0].length];
        for (int i = 0; i < covariances.length; ++i) {
            means[i] = oldValue[i];
            for (int j = 0; j < covariances.length; ++j) {
                covariances[i][j] = diffusionMatrix[i][j] * branchLength;
            }
        }
        MultivariateNormalDistribution mvn = new MultivariateNormalDistribution(means, covariances);
        return this.handleBoundaries(mvn.sample());
    }

    @Override
    public Map<String, Value> getParams() {
        return new TreeMap<String, Value>(){
            {
                this.put(PhyloMultivariateBrownian.treeParamName, PhyloMultivariateBrownian.this.tree);
                this.put(PhyloMultivariateBrownian.diffusionMatrixParamName, PhyloMultivariateBrownian.this.diffusionMatrix);
                this.put(PhyloMultivariateBrownian.y0ParamName, PhyloMultivariateBrownian.this.y0);
            }
        };
    }

    @Override
    public void setParam(String paramName, Value value) {
        switch (paramName) {
            case "tree": {
                this.tree = value;
                break;
            }
            case "diffusionMatrix": {
                this.diffusionMatrix = value;
                break;
            }
            case "y0": {
                this.y0 = value;
                break;
            }
            default: {
                throw new RuntimeException("Unrecognised parameter name: " + paramName);
            }
        }
    }
}

