/*
 * Decompiled with CFR 0.152.
 */
package lphy.evolution.coalescent;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import lphy.core.distributions.Utils;
import lphy.evolution.Taxa;
import lphy.evolution.tree.TaxaConditionedTreeGenerator;
import lphy.evolution.tree.TimeTree;
import lphy.evolution.tree.TimeTreeNode;
import lphy.graphicalModel.Citation;
import lphy.graphicalModel.GeneratorInfo;
import lphy.graphicalModel.ParameterInfo;
import lphy.graphicalModel.RandomVariable;
import lphy.graphicalModel.Value;

/**
 * Generator of time trees under the skyline coalescent.
 *
 * <p>Coalescent intervals are partitioned into consecutive groups, ordered from present to past.
 * Group {@code i} contains {@code groupSizes[i]} coalescent events (one each when no group sizes
 * are given) and uses the population size {@code theta[i]}. The number of taxa is therefore
 * {@code sum(groupSizes) + 1}, or {@code theta.length + 1} when no group sizes are given.
 * Leaves may have non-zero ages (serially sampled taxa).
 */
@Citation(value="Drummond, A. J., Rambaut, A., Shapiro, B, & Pybus, O. G. (2005).\nBayesian coalescent inference of past population dynamics from molecular sequences.\nMolecular biology and evolution, 22(5), 1185-1192.", title="Bayesian coalescent inference of past population dynamics from molecular sequences", year=2005, authors={"Drummond", "Rambaut", "Shapiro", "Pybus"}, DOI="10.1093/molbev/msi103")
public class SkylineCoalescent
extends TaxaConditionedTreeGenerator {
    public static final String groupSizesParamName = "groupSizes";
    private Value<Double[]> theta;
    private Value<Integer[]> groupSizes;

    /**
     * Creates a skyline coalescent generator.
     *
     * @param theta      population sizes, one per group of coalescent intervals, present to past
     * @param groupSizes optional number of coalescent events per group; must match theta in length
     * @param n          optional number of taxa
     * @param taxa       optional taxa object
     * @param ages       optional leaf ages
     * @throws IllegalArgumentException if more than one of n/taxa/ages is given, if groupSizes
     *                                  and theta differ in length, if a group size is not a
     *                                  positive integer, or if n/ages disagree with the number
     *                                  of taxa implied by theta/groupSizes
     */
    public SkylineCoalescent(@ParameterInfo(name="theta", narrativeName="population sizes", description="effective population size, one value for each group of coalescent intervals, ordered from present to past. Possibly scaled to mutations or calendar units. If no groupSizes are specified, then the number of coalescent intervals will be equal to the number of population size parameters.") Value<Double[]> theta, @ParameterInfo(name="groupSizes", narrativeName="group sizes", description="A tuple of group sizes. The sum of this tuple determines the number of coalescent events in the tree and thus the number of taxa. By default all group sizes are 1 which is equivalent to the classic skyline coalescent.", optional=true) Value<Integer[]> groupSizes, @ParameterInfo(name="n", description="number of taxa.", optional=true) Value<Integer> n, @ParameterInfo(name="taxa", description="Taxa object, (e.g. Taxa or Object[])", optional=true) Value<Taxa> taxa, @ParameterInfo(name="ages", description="an array of leaf node ages.", optional=true) Value<Double[]> ages) {
        super(n, taxa, ages);
        this.theta = theta;
        this.groupSizes = groupSizes;
        this.random = Utils.getRandom();
        int c = (ages == null ? 0 : 1) + (taxa == null ? 0 : 1) + (n == null ? 0 : 1);
        if (c > 1) {
            throw new IllegalArgumentException("Only one of n, ages and taxa may be specified in " + this.getName());
        }
        this.checkThetaDimensions();
        this.checkGroupSizes();
        super.checkTaxaParameters(false);
        this.checkDimensions();
    }

    /** Ensures there is exactly one theta value per group when group sizes are given. */
    private void checkThetaDimensions() {
        if (this.groupSizes != null && this.theta.value().length != this.groupSizes.value().length) {
            throw new IllegalArgumentException("groupSizes and theta arrays must be the same dimension.");
        }
    }

    /**
     * Rejects null or non-positive group sizes. A group with no coalescent events can never be
     * completed by {@link #sample()}, so such input would otherwise fail only at simulation time.
     */
    private void checkGroupSizes() {
        if (this.groupSizes == null) {
            return;
        }
        for (Integer size : this.groupSizes.value()) {
            if (size == null || size < 1) {
                throw new IllegalArgumentException("All groupSizes must be positive integers, but got " + Arrays.toString(this.groupSizes.value()));
            }
        }
    }

    /**
     * Ensures an explicitly supplied {@code n} or {@code ages} array agrees with the number of
     * taxa implied by theta / groupSizes (see {@link #n()}).
     */
    private void checkDimensions() {
        int expectedTaxa = this.n();
        boolean nMismatch = this.n != null && ((Integer)this.n.value()).intValue() != expectedTaxa;
        boolean agesMismatch = this.ages != null && ((Double[])this.ages.value()).length != expectedTaxa;
        if (nMismatch || agesMismatch) {
            throw new IllegalArgumentException("The number of taxa must be exactly one more than the number of coalescent events "
                    + "(the number of theta values, or the sum of groupSizes if specified); expected " + expectedTaxa + " taxa!");
        }
    }

    /**
     * Returns the number of taxa implied by the parameters: the number of coalescent events
     * plus one. The event count is the sum of group sizes if given, else the number of thetas.
     */
    @Override
    protected int n() {
        if (this.groupSizes == null) {
            return this.theta.value().length + 1;
        }
        int events = 0;
        for (Integer groupSize : this.groupSizes.value()) {
            events += groupSize;
        }
        return events + 1;
    }

    /**
     * Simulates a tree backwards in time from the present.
     *
     * <p>While at least two lineages are active, the waiting time to the next coalescence is
     * exponential with rate {@code k(k-1) / (2 * theta)} for {@code k} active lineages and the
     * current group's theta. If that time passes the next (youngest pending) leaf age, the clock
     * stops at that age and the leaf is added instead. With fewer than two active lineages the
     * clock jumps straight to the next leaf age.
     *
     * @throws IllegalArgumentException if the number of leaves differs from {@link #n()}
     */
    @Override
    @GeneratorInfo(name="SkylineCoalescent", verbClause="has", narrativeName="skyline coalescent prior", description="The skyline coalescent distribution over tip-labelled time trees. If no group sizes are specified, then there is one population parameter per coalescent event (as per classic skyline coalescent of Pybus, Rambaut and Harvey 2000)")
    public RandomVariable<TimeTree> sample() {
        TimeTree tree = new TimeTree(this.getTaxa());
        List<TimeTreeNode> leafNodes = this.createLeafTaxa(tree);
        int expectedTaxa = this.n();
        if (leafNodes.size() != expectedTaxa) {
            // Every tree has (leaves - 1) coalescences, which must consume exactly all thetas/groups.
            throw new IllegalArgumentException("SkylineCoalescent requires exactly " + expectedTaxa
                    + " taxa (one more than the number of coalescent events), but " + leafNodes.size() + " were given.");
        }
        List<TimeTreeNode> activeNodes = new ArrayList<>();
        List<TimeTreeNode> leavesToBeAdded = new ArrayList<>();
        double time = 0.0;
        for (TimeTreeNode leaf : leafNodes) {
            if (leaf.getAge() <= time) {
                activeNodes.add(leaf);
            } else {
                leavesToBeAdded.add(leaf);
            }
        }
        // Oldest first, so the youngest pending leaf sits at the end and is cheap to remove.
        leavesToBeAdded.sort((o1, o2) -> Double.compare(o2.getAge(), o1.getAge()));
        Double[] theta = this.theta.value();
        int thetaIndex = 0;
        int groupIndex = 0;
        int countWithinGroup = 0;
        while (activeNodes.size() + leavesToBeAdded.size() > 1) {
            int k = activeNodes.size();
            if (k < 2) {
                // No coalescence is possible with 0 or 1 lineages. k == 0 occurs when no leaf is
                // sampled at time 0. The loop guard ensures a pending leaf exists here.
                time = leavesToBeAdded.get(leavesToBeAdded.size() - 1).getAge();
            } else {
                double rate = (double)k * ((double)k - 1.0) / (theta[thetaIndex] * 2.0);
                double x = -Math.log(this.random.nextDouble()) / rate;
                time += x;
                if (!leavesToBeAdded.isEmpty() && time > leavesToBeAdded.get(leavesToBeAdded.size() - 1).getAge()) {
                    // A sampling event happens before the proposed coalescence.
                    time = leavesToBeAdded.get(leavesToBeAdded.size() - 1).getAge();
                } else {
                    TimeTreeNode a = activeNodes.remove(this.random.nextInt(activeNodes.size()));
                    TimeTreeNode b = activeNodes.remove(this.random.nextInt(activeNodes.size()));
                    TimeTreeNode parent = new TimeTreeNode(time, new TimeTreeNode[]{a, b});
                    activeNodes.add(parent);
                    // Advance to the next theta once the current group's events are used up.
                    if (this.groupSizes != null) {
                        int groupSize = this.groupSizes.value()[groupIndex];
                        if (countWithinGroup == groupSize - 1) {
                            ++groupIndex;
                            countWithinGroup = 0;
                            ++thetaIndex;
                        } else {
                            ++countWithinGroup;
                        }
                    } else {
                        ++thetaIndex;
                    }
                }
            }
            // Activate every pending leaf whose age equals the current time.
            while (!leavesToBeAdded.isEmpty() && leavesToBeAdded.get(leavesToBeAdded.size() - 1).getAge() == time) {
                activeNodes.add(leavesToBeAdded.remove(leavesToBeAdded.size() - 1));
            }
        }
        tree.setRoot(activeNodes.get(0));
        if (thetaIndex != theta.length) {
            throw new AssertionError((Object)("Programmer error in indexing " + thetaIndex + " the theta array " + theta.length + " during simulation!"));
        }
        if (this.groupSizes != null && (countWithinGroup != 0 || groupIndex != this.groupSizes.value().length)) {
            throw new AssertionError((Object)("Programmer error in indexing the groupSizes array during simulation." + countWithinGroup + " " + groupIndex + Arrays.toString(this.groupSizes.value())));
        }
        return new RandomVariable<TimeTree>("\u03c8", tree, this);
    }

    /**
     * NOTE(review): the skyline coalescent density is not implemented here; this always
     * returns 0.0 regardless of the tree.
     */
    @Override
    public double logDensity(TimeTree timeTree) {
        return 0.0;
    }

    /** Returns the superclass parameters plus theta, and groupSizes, n and ages when set. */
    @Override
    public Map<String, Value> getParams() {
        Map<String, Value> map = super.getParams();
        map.put("theta", this.theta);
        if (this.groupSizes != null) {
            map.put(groupSizesParamName, this.groupSizes);
        }
        if (this.n != null) {
            map.put("n", this.n);
        }
        if (this.ages != null) {
            map.put("ages", this.ages);
        }
        return map;
    }

    /**
     * Replaces the named parameter. Names other than theta, groupSizes and ages are delegated
     * to the superclass. No dimension re-validation is performed here.
     */
    @Override
    public void setParam(String paramName, Value value) {
        switch (paramName) {
            case "theta": {
                this.theta = value;
                break;
            }
            case "groupSizes": {
                this.groupSizes = value;
                break;
            }
            case "ages": {
                this.ages = value;
                break;
            }
            default: {
                super.setParam(paramName, value);
            }
        }
    }

    /** @return the population-size parameter (one value per group). */
    public Value<Double[]> getTheta() {
        return this.theta;
    }

    /** @return the group-size parameter, or {@code null} if not specified. */
    public Value<Integer[]> getGroupSizes() {
        return this.groupSizes;
    }
}

