/*
 * Decompiled with CFR 0.152.
 */
package lphy.evolution.coalescent;

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import lphy.core.distributions.Exp;
import lphy.core.distributions.Utils;
import lphy.evolution.tree.TaxaConditionedTreeGenerator;
import lphy.evolution.tree.TimeTree;
import lphy.evolution.tree.TimeTreeNode;
import lphy.graphicalModel.Citation;
import lphy.graphicalModel.GeneratorInfo;
import lphy.graphicalModel.ParameterInfo;
import lphy.graphicalModel.RandomVariable;
import lphy.graphicalModel.Value;

/**
 * Kingman's coalescent: a generator of, and density over, tip-labelled time trees.
 * <p>
 * With {@code k} active lineages the waiting time to the next coalescence is exponential with
 * rate {@code k(k-1)/(2*theta)}. {@link #sample()} simulates from this process and
 * {@link #logDensity(TimeTree)} evaluates the matching log density of a given tree.
 */
@Citation(value="Kingman JFC. The Coalescent. Stochastic Processes and their Applications 13, 235-248 (1982)", title="The Coalescent", year=1982, authors={"Kingman"}, DOI="https://doi.org/10.1016/0304-4149(82)90011-4")
public class Coalescent
extends TaxaConditionedTreeGenerator {
    /** Population-size parameter; not final because {@link #setParam} may replace it. */
    private Value<Double> theta;

    /**
     * Creates a Kingman coalescent generator.
     *
     * @param theta effective population size (possibly scaled); it is the denominator of the
     *              pairwise coalescence rate
     * @param n     number of taxa (optional; provide this or {@code taxa})
     * @param taxa  taxa specification (optional; provide this or {@code n})
     */
    public Coalescent(@ParameterInfo(name="theta", description="effective population size, possibly scaled to mutations or calendar units.") Value<Double> theta, @ParameterInfo(name="n", description="the number of taxa. Provide this or taxa.", optional=true) Value<Integer> n, @ParameterInfo(name="taxa", description="a string array of taxa id or a taxa object (e.g. dataframe, alignment or tree). Provide this or n.", optional=true) Value taxa) {
        super(n, taxa, null);
        this.theta = theta;
        this.random = Utils.getRandom();
        // NOTE(review): presumably validates the n/taxa combination — confirm in TaxaConditionedTreeGenerator.
        this.checkTaxaParameters(true);
    }

    /**
     * Simulates a tree under Kingman's coalescent.
     * <p>
     * Starting from the leaves created by {@code createLeafTaxa}, the loop repeats until a single
     * lineage (the root) remains. On each pass it:
     * <ol>
     *   <li>draws two active lineages;</li>
     *   <li>waits an exponential time with rate {@code k(k-1)/(2*theta)}, where {@code k} is the
     *       current number of active lineages;</li>
     *   <li>joins the two lineages under a new parent node at the accumulated time.</li>
     * </ol>
     *
     * @return the sampled tree wrapped in a random variable with id psi
     */
    @Override
    @GeneratorInfo(name="Coalescent", narrativeName="Kingman's coalescent tree prior", description="The Kingman coalescent distribution over tip-labelled time trees.")
    public RandomVariable<TimeTree> sample() {
        TimeTree tree = new TimeTree();
        List<TimeTreeNode> activeNodes = this.createLeafTaxa(tree);
        double time = 0.0;
        double theta = this.theta.value();
        while (activeNodes.size() > 1) {
            // k must be read before drawing. drawRandomNode presumably removes the drawn node
            // from activeNodes (otherwise this loop could not terminate) — TODO confirm.
            int k = activeNodes.size();
            TimeTreeNode a = this.drawRandomNode(activeNodes);
            TimeTreeNode b = this.drawRandomNode(activeNodes);
            double rate = (double)k * ((double)k - 1.0) / (theta * 2.0);
            // Inverse-CDF exponential draw. NOTE(review): if random is a java.util.Random,
            // nextDouble() may return 0.0, making x infinite. This is left as-is so seeded runs
            // stay reproducible.
            double x = -Math.log(this.random.nextDouble()) / rate;
            time += x;
            TimeTreeNode parent = new TimeTreeNode(time, new TimeTreeNode[]{a, b});
            activeNodes.add(parent);
        }
        tree.setRoot(activeNodes.get(0));
        return new RandomVariable<TimeTree>("\u03c8", tree, this);
    }

    /**
     * Log density of a time tree under Kingman's coalescent.
     * <p>
     * The sorted internal-node ages split time into intervals.
     * <ul>
     *   <li>Each interval with {@code k} lineages contributes {@code -k(k-1)/(2*theta) * length}.</li>
     *   <li>Each of the {@code n-1} coalescences contributes {@code -log(theta)}. This is the event
     *       rate {@code k(k-1)/(2*theta)} times the probability {@code 2/(k(k-1))} of picking that
     *       particular pair.</li>
     * </ul>
     * Intervals are measured from age 0, so this assumes all tips have age 0 — TODO confirm for
     * serially sampled trees.
     *
     * @param timeTree tree whose {@code n()} leaves are assumed to be joined by {@code n()-1}
     *                 internal nodes
     * @return the log density (NaN if theta is negative, since {@code Math.log} is applied to it)
     */
    @Override
    public double logDensity(TimeTree timeTree) {
        double[] ages = this.getInternalNodeAges(timeTree, null);
        Arrays.sort(ages);
        double theta = this.theta.value();
        int k = timeTree.n();
        double previousAge = 0.0;
        double logDensity = 0.0;
        for (double age : ages) {
            // Widen before multiplying so k*(k-1) cannot overflow int for very large trees.
            double lineagePairs = (double)k * (k - 1);
            double interval = age - previousAge;
            logDensity -= lineagePairs * interval / (2.0 * theta);
            previousAge = age;
            --k;
        }
        logDensity -= (double)(timeTree.n() - 1) * Math.log(theta);
        return logDensity;
    }

    /**
     * Returns the parameters of this generator: the parent's parameters plus {@code theta}.
     *
     * @return the map returned by the superclass with the {@code "theta"} entry added
     */
    @Override
    public Map<String, Value> getParams() {
        Map<String, Value> map = super.getParams();
        map.put("theta", this.theta);
        return map;
    }

    /**
     * Replaces a parameter by name.
     * <p>
     * {@code "theta"} is handled here; every other name is delegated to the superclass.
     *
     * @param paramName parameter name
     * @param value     new value; for {@code "theta"} it is assumed to hold a Double
     */
    @Override
    @SuppressWarnings("unchecked") // raw Value from the generic setParam API; theta is expected to be Double-valued
    public void setParam(String paramName, Value value) {
        if (paramName.equals("theta")) {
            this.theta = value;
        } else {
            super.setParam(paramName, value);
        }
    }

    /**
     * Collects the ages of all non-leaf nodes of {@code timeTree}, in {@code getNodes()} order.
     *
     * @param timeTree tree to read
     * @param ages     destination array of length {@code timeTree.n() - 1}, or {@code null} to
     *                 allocate one
     * @return the filled array
     * @throws IllegalArgumentException if {@code ages} is non-null and its length is not
     *                                  {@code timeTree.n() - 1}
     */
    private double[] getInternalNodeAges(TimeTree timeTree, double[] ages) {
        // A binary tree with n leaves has exactly n - 1 internal nodes.
        int expectedInternalNodes = timeTree.n() - 1;
        if (ages == null) {
            ages = new double[expectedInternalNodes];
        }
        if (ages.length != expectedInternalNodes) {
            throw new IllegalArgumentException("Ages array size must be one less than the number of taxa in the tree (the number of internal nodes): expected " + expectedInternalNodes + " but was " + ages.length + ".");
        }
        // NOTE(review): a tree with more internal nodes than n-1 would overflow this array.
        int i = 0;
        for (TimeTreeNode node : timeTree.getNodes()) {
            if (node.isLeaf()) continue;
            ages[i] = node.getAge();
            ++i;
        }
        return ages;
    }

    /**
     * Small manual demo.
     * <p>
     * Draws theta from an exponential prior with rate 20, then samples a 20-taxon coalescent tree.
     */
    public static void main(String[] args) {
        Value<Double> thetaExpPriorRate = new Value<Double>("r", 20.0);
        Exp exp = new Exp(thetaExpPriorRate);
        RandomVariable<Double> theta = exp.sample("\u0398");
        Value<Integer> n = new Value<Integer>("n", 20);
        Coalescent coalescent = new Coalescent(theta, n, null);
        coalescent.sample();
    }

    /**
     * Returns the current theta parameter value object.
     *
     * @return the theta value (may have been replaced via {@link #setParam})
     */
    public Value<Double> getTheta() {
        return this.theta;
    }
}

