/*
 * Decompiled with CFR 0.152.
 */
package lphy.evolution.birthdeath;

import java.util.List;
import java.util.Map;
import lphy.core.distributions.Utils;
import lphy.evolution.Taxa;
import lphy.evolution.tree.TaxaConditionedTreeGenerator;
import lphy.evolution.tree.TimeTree;
import lphy.evolution.tree.TimeTreeNode;
import lphy.graphicalModel.Citation;
import lphy.graphicalModel.GeneratorInfo;
import lphy.graphicalModel.ParameterInfo;
import lphy.graphicalModel.RandomVariable;
import lphy.graphicalModel.Value;
import lphy.graphicalModel.ValueUtils;
import org.apache.commons.math3.util.FastMath;

@Citation(value="Tanja Stadler, Ziheng Yang (2013) Dating Phylogenies with Sequentially Sampled Tips, Systematic Biology, 62(5):674\u2013688", title="Dating Phylogenies with Sequentially Sampled Tips", DOI="10.1093/sysbio/syt030", authors={"Stadler", "Yang"}, year=2013)
/**
 * Generates a serially sampled birth-death time tree conditioned on the root age and on the taxa
 * and their leaf ages (Stadler and Yang, 2013).
 *
 * <p>Sampling proceeds in three steps:
 * <ol>
 *   <li>a random topology is built by repeatedly joining two randomly drawn active nodes;</li>
 *   <li>every non-root internal node, visited in in-order, receives a divergence time {@code x}
 *       drawn by inverting {@code y = (g(x) - g(z*)) / (g(rootAge) - g(z*))} for a uniform
 *       {@code y}, where {@code z*} is the older of the two leaves adjacent to the node in the
 *       in-order sequence;</li>
 *   <li>the tree is rebuilt from the resulting in-order sequence of node heights.</li>
 * </ol>
 * Here {@code g(t) = 1 / (exp(-c1 t) (1 - c2) + (1 + c2))} with
 * {@code c1 = sqrt((lambda - mu - psi)^2 + 4 lambda psi)} and
 * {@code c2 = -(lambda - mu - 2 lambda rho - psi) / c1}.
 */
public class BirthDeathSerialSamplingTree
extends TaxaConditionedTreeGenerator {
    private Value<Number> birthRate;
    private Value<Number> deathRate;
    private Value<Number> psiVal;
    private Value<Number> rhoVal;
    private Value<Number> rootAge;
    // Constants of g(t); recomputed at the start of every sample() call.
    private double c1;
    private double c2;
    // g(rootAge) for the current sample.
    private double gt;

    /**
     * Creates the generator.
     *
     * @param birthRate per-lineage birth rate (lambda)
     * @param deathRate per-lineage death rate (mu)
     * @param rhoVal    proportion of extant taxa sampled (rho)
     * @param psiVal    per-lineage sampling-through-time rate (psi)
     * @param n         optional number of taxa
     * @param taxa      optional taxa object
     * @param ages      optional leaf ages
     * @param rootAge   age of the root
     */
    public BirthDeathSerialSamplingTree(
            @ParameterInfo(name="lambda", description="per-lineage birth rate.") Value<Number> birthRate,
            @ParameterInfo(name="mu", description="per-lineage death rate.") Value<Number> deathRate,
            @ParameterInfo(name="rho", description="proportion of extant taxa sampled.") Value<Number> rhoVal,
            @ParameterInfo(name="psi", description="per-lineage sampling-through-time rate.") Value<Number> psiVal,
            @ParameterInfo(name="n", description="the number of taxa. optional.", optional=true) Value<Integer> n,
            @ParameterInfo(name="taxa", description="Taxa object", optional=true) Value<Taxa> taxa,
            @ParameterInfo(name="ages", description="an array of leaf node ages.", optional=true) Value<Double[]> ages,
            @ParameterInfo(name="rootAge", description="the age of the root.") Value<Number> rootAge) {
        super(n, taxa, ages);
        this.birthRate = birthRate;
        this.deathRate = deathRate;
        this.rhoVal = rhoVal;
        this.psiVal = psiVal;
        this.rootAge = rootAge;
        // NOTE(review): possibly redundant with the superclass constructor; kept as in the original.
        this.ages = ages;
        this.random = Utils.getRandom();
        // Presumably validates the n / taxa / ages combination (superclass helper) — confirm.
        this.checkTaxaParameters(false);
    }

    /**
     * Draws a tree: random topology, then divergence times for the non-root internal nodes, then
     * a rebuild of the tree from the in-order height sequence.
     *
     * @return the sampled tree wrapped in a random variable named "\u03c8"
     * @throws IllegalArgumentException if c1 is not positive (e.g. psi == 0 and lambda == mu), if
     *         fewer than two taxa are present, or if rootAge is not strictly older than every leaf
     */
    @Override
    @GeneratorInfo(name="BirthDeathSerialSampling", description="A tree of extant species and those sampled through time, which is conceptually embedded in a full species tree produced by a speciation-extinction (birth-death) branching process.<br>Conditioned on root age and on number of taxa and their ages (Stadler and Yang, 2013).")
    public RandomVariable<TimeTree> sample() {
        double lambda = ValueUtils.doubleValue(this.birthRate);
        double mu = ValueUtils.doubleValue(this.deathRate);
        double rho = ValueUtils.doubleValue(this.rhoVal);
        double psi = ValueUtils.doubleValue(this.psiVal);
        double tmrca = ValueUtils.doubleValue(this.rootAge);

        this.c1 = Math.sqrt(Math.pow(lambda - mu - psi, 2.0) + 4.0 * lambda * psi);
        // c1 == 0 (or NaN) would divide by zero below and silently produce NaN node ages.
        if (!(this.c1 > 0.0)) {
            throw new IllegalArgumentException("BirthDeathSerialSampling requires (lambda - mu - psi)^2 + 4 * lambda * psi > 0, but lambda = "
                    + lambda + ", mu = " + mu + ", psi = " + psi);
        }
        this.c2 = -(lambda - mu - 2.0 * lambda * rho - psi) / this.c1;
        this.gt = this.gFunction(tmrca);

        TimeTree tree = this.randomTreeTopology();
        // Every parent created by randomTreeTopology takes the older of its children's ages,
        // so the provisional root age equals the age of the oldest leaf.
        double oldestLeafAge = tree.getRoot().getAge();
        if (!(tmrca > oldestLeafAge)) {
            throw new IllegalArgumentException("BirthDeathSerialSampling requires rootAge to be older than every leaf, but rootAge = "
                    + tmrca + " and the oldest leaf age = " + oldestLeafAge);
        }
        tree.getRoot().setAge(tmrca);
        this.drawDivTimes(tree);
        this.constructTree(tree);
        return new RandomVariable<TimeTree>("\u03c8", tree, this);
    }

    /**
     * Evaluates g(t) = 1 / (exp(-c1 t) (1 - c2) + (1 + c2)) using the constants set by sample().
     */
    private double gFunction(double t) {
        return 1.0 / (FastMath.exp(-this.c1 * t) * (1.0 - this.c2) + (1.0 + this.c2));
    }

    /**
     * Builds a random topology over the leaf taxa. Each new parent gets the older of its two
     * children's ages, which only serves as a provisional value.
     *
     * <p>{@code drawRandomNode} is assumed to remove the drawn node from the list; otherwise the
     * loop, which adds one parent per iteration, would never terminate.
     *
     * @throws IllegalArgumentException if fewer than two leaves are available
     */
    private TimeTree randomTreeTopology() {
        TimeTree tree = new TimeTree(this.getTaxa());
        List<TimeTreeNode> activeNodes = this.createLeafTaxa(tree);
        // With a single leaf the later reconstruction finds no internal node and would leave
        // the tree with a null root.
        if (activeNodes.size() < 2) {
            throw new IllegalArgumentException("BirthDeathSerialSampling requires at least 2 taxa, but got " + activeNodes.size());
        }
        while (activeNodes.size() > 1) {
            TimeTreeNode a = this.drawRandomNode(activeNodes);
            TimeTreeNode b = this.drawRandomNode(activeNodes);
            TimeTreeNode parent = new TimeTreeNode(Math.max(a.getAge(), b.getAge()), new TimeTreeNode[]{a, b});
            activeNodes.add(parent);
        }
        tree.setRoot(activeNodes.get(0));
        return tree;
    }

    /**
     * Writes the indices of the internal nodes below {@code node} into {@code index} in in-order
     * (left subtree, node, right subtree), starting at slot {@code i}.
     *
     * @return the next free slot in {@code index}
     */
    private int traverseTree(TimeTreeNode node, int i, int[] index) {
        if (!node.isLeaf()) {
            i = this.traverseTree(node.getChild(0), i, index);
            index[i] = node.getIndex();
            ++i;
            i = this.traverseTree(node.getChild(1), i, index);
        }
        return i;
    }

    /**
     * Draws a new age for every non-root internal node. The age is sampled between the older of
     * the node's two in-order neighbouring leaves (z*) and the root age, by inverting
     * y = (g(x) - g(z*)) / (g(rootAge) - g(z*)) for a uniform y.
     *
     * <p>Assumes leaves are exactly the nodes with index below {@code tree.n()}.
     */
    private void drawDivTimes(TimeTree tree) {
        int[] inOrderInternal = new int[tree.n() - 1];
        this.traverseTree(tree.getRoot(), 0, inOrderInternal);
        int rootIndex = tree.getRoot().getIndex();
        for (int j : inOrderInternal) {
            if (j == rootIndex) {
                continue;
            }
            TimeTreeNode node = tree.getNodeByIndex(j);
            // Leftmost leaf of the right subtree: the in-order leaf just after this node.
            int k = node.getChild(1).getIndex();
            while (k >= tree.n()) {
                k = tree.getNodeByIndex(k).getChild(0).getIndex();
            }
            double z0 = tree.getNodeByIndex(k).getAge();
            // Rightmost leaf of the left subtree: the in-order leaf just before this node.
            k = node.getChild(0).getIndex();
            while (k >= tree.n()) {
                k = tree.getNodeByIndex(k).getChild(1).getIndex();
            }
            double z1 = tree.getNodeByIndex(k).getAge();
            double zstar = Math.max(z0, z1);
            double gzstar = this.gFunction(zstar);
            double a2 = this.gt - gzstar;
            double y = this.random.nextDouble();
            // Solve g(x) = g(z*) + a2 * y for x.
            double gx = gzstar + a2 * y;
            double x = -Math.log(1.0 / (gx * (1.0 - this.c2)) - (1.0 + this.c2) / (1.0 - this.c2)) / this.c1;
            node.setAge(x);
        }
    }

    /**
     * Replaces the tree's topology with one rebuilt from its in-order sequence of node heights.
     */
    private void constructTree(TimeTree tree) {
        List<TimeTreeNode> nodes = tree.getNodes();
        double[] heights = new double[nodes.size()];
        int[] reverseOrder = new int[nodes.size()];
        this.collectHeights(tree.getRoot(), heights, reverseOrder, 0);
        TimeTreeNode root = this.constructTree(nodes, heights, reverseOrder, 0, heights.length, new boolean[heights.length]);
        tree.setRoot(root);
    }

    /**
     * Recursively rebuilds the subtree spanning in-order positions {@code [from, to)}. The
     * oldest remaining internal node in the range becomes the subtree root. Its left and right
     * children are the oldest still-parentless entries on either side of it. The two sub-ranges
     * are then processed recursively.
     *
     * <p>Consumed positions have their height set to negative infinity so they are never chosen
     * again. {@code hasParent} marks positions already attached to a parent.
     *
     * @return the subtree root, or {@code null} if the range holds no remaining internal node
     */
    private TimeTreeNode constructTree(List<TimeTreeNode> nodes, double[] heights, int[] reverseOrder, int from, int to, boolean[] hasParent) {
        int nodeIndex = -1;
        double max = Double.NEGATIVE_INFINITY;
        for (int j = from; j < to; ++j) {
            if (heights[j] > max && !nodes.get(reverseOrder[j]).isLeaf()) {
                max = heights[j];
                nodeIndex = j;
            }
        }
        if (nodeIndex < 0) {
            return null;
        }
        TimeTreeNode node = nodes.get(reverseOrder[nodeIndex]);
        int left = -1;
        max = Double.NEGATIVE_INFINITY;
        for (int j = from; j < nodeIndex; ++j) {
            if (heights[j] > max && !hasParent[j]) {
                max = heights[j];
                left = j;
            }
        }
        int right = -1;
        max = Double.NEGATIVE_INFINITY;
        for (int j = nodeIndex + 1; j < to; ++j) {
            if (heights[j] > max && !hasParent[j]) {
                max = heights[j];
                right = j;
            }
        }
        node.setLeft(nodes.get(reverseOrder[left]));
        node.setRight(nodes.get(reverseOrder[right]));
        if (node.getLeft().isLeaf()) {
            heights[left] = Double.NEGATIVE_INFINITY;
        }
        if (node.getRight().isLeaf()) {
            heights[right] = Double.NEGATIVE_INFINITY;
        }
        hasParent[left] = true;
        hasParent[right] = true;
        heights[nodeIndex] = Double.NEGATIVE_INFINITY;
        this.constructTree(nodes, heights, reverseOrder, from, nodeIndex, hasParent);
        // The right range includes nodeIndex itself. That is harmless because its height is now
        // negative infinity, so it can never be selected there.
        this.constructTree(nodes, heights, reverseOrder, nodeIndex, to, hasParent);
        return node;
    }

    /**
     * Records every node's age and index in in-order (left subtree, node, right subtree),
     * starting at position {@code current}.
     *
     * @return the next free position
     */
    private int collectHeights(TimeTreeNode node, double[] heights, int[] reverseOrder, int current) {
        if (node.isLeaf()) {
            heights[current] = node.getAge();
            reverseOrder[current] = node.getIndex();
            ++current;
        } else {
            current = this.collectHeights(node.getLeft(), heights, reverseOrder, current);
            heights[current] = node.getAge();
            reverseOrder[current] = node.getIndex();
            ++current;
            current = this.collectHeights(node.getRight(), heights, reverseOrder, current);
        }
        return current;
    }

    /**
     * Not implemented.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public double logDensity(TimeTree timeTree) {
        throw new UnsupportedOperationException("Not implemented!");
    }

    /**
     * Returns the superclass parameters plus lambda, mu, rho, psi and rootAge.
     */
    @Override
    public Map<String, Value> getParams() {
        Map<String, Value> map = super.getParams();
        map.put("lambda", this.birthRate);
        map.put("mu", this.deathRate);
        map.put("rho", this.rhoVal);
        map.put("psi", this.psiVal);
        map.put("rootAge", this.rootAge);
        return map;
    }

    /**
     * Sets one of this generator's parameters by name. Unknown names are delegated to the
     * superclass.
     */
    @Override
    public void setParam(String paramName, Value value) {
        switch (paramName) {
            case "lambda": {
                this.birthRate = value;
                break;
            }
            case "mu": {
                this.deathRate = value;
                break;
            }
            case "rho": {
                this.rhoVal = value;
                break;
            }
            case "psi": {
                this.psiVal = value;
                break;
            }
            case "rootAge": {
                this.rootAge = value;
                break;
            }
            default: {
                super.setParam(paramName, value);
            }
        }
    }

    /** @return the per-lineage birth rate (lambda) */
    public Value<Number> getBirthRate() {
        return this.birthRate;
    }

    /** @return the per-lineage death rate (mu) */
    public Value<Number> getDeathRate() {
        return this.deathRate;
    }

    /** @return the proportion of extant taxa sampled (rho) */
    public Value<Number> getRho() {
        return this.rhoVal;
    }

    /** @return the per-lineage sampling-through-time rate (psi) */
    public Value<Number> getPsi() {
        return this.psiVal;
    }

    /** @return the root age */
    public Value<Number> getRootAge() {
        return this.rootAge;
    }
}

