/*
 * Decompiled with CFR 0.152.
 */
package lphy.evolution.birthdeath;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import lphy.core.distributions.Utils;
import lphy.evolution.birthdeath.FullBirthDeathTree;
import lphy.evolution.birthdeath.SimFossilsPoisson;
import lphy.evolution.tree.PruneTree;
import lphy.evolution.tree.TimeTree;
import lphy.evolution.tree.TimeTreeNode;
import lphy.graphicalModel.GenerativeDistribution;
import lphy.graphicalModel.GeneratorInfo;
import lphy.graphicalModel.ParameterInfo;
import lphy.graphicalModel.RandomVariable;
import lphy.graphicalModel.Value;
import org.apache.commons.math3.random.RandomGenerator;

public class SimFBDAge
implements GenerativeDistribution<TimeTree> {
    private Value<Number> birthRate;
    private Value<Number> deathRate;
    private Value<Number> psiVal;
    private Value<Double> fracVal;
    private Value<Number> originAge;
    RandomGenerator random;
    private static final int MAX_ATTEMPTS = 1000;

    public SimFBDAge(@ParameterInfo(name="lambda", description="per-lineage birth rate.") Value<Number> birthRate, @ParameterInfo(name="mu", description="per-lineage death rate.") Value<Number> deathRate, @ParameterInfo(name="frac", description="fraction of extant taxa sampled.") Value<Double> fracVal, @ParameterInfo(name="psi", description="per-lineage sampling-through-time rate.") Value<Number> psiVal, @ParameterInfo(name="originAge", description="the age of the origin.") Value<Number> originAge) {
        this.birthRate = birthRate;
        this.deathRate = deathRate;
        this.fracVal = fracVal;
        this.psiVal = psiVal;
        this.originAge = originAge;
        this.random = Utils.getRandom();
    }

    @Override
    @GeneratorInfo(name="SimFBDAge", description="A tree of extant species and those sampled through time, which is conceptually embedded in a full species tree produced by a speciation-extinction (birth-death) branching process.<br>Conditioned on origin age.")
    public RandomVariable<TimeTree> sample() {
        int attempts;
        int nonNullLeafCount = 0;
        TimeTree sampleTree = null;
        for (attempts = 0; nonNullLeafCount < 1 && attempts < 1000; ++attempts) {
            FullBirthDeathTree birthDeathTree = new FullBirthDeathTree(this.birthRate, this.deathRate, null, this.originAge);
            RandomVariable<TimeTree> fullTree = birthDeathTree.sample();
            SimFossilsPoisson simFossilsPoisson = new SimFossilsPoisson(fullTree, this.psiVal);
            RandomVariable<TimeTree> fullTreeWithFossils = simFossilsPoisson.sample();
            sampleTree = new TimeTree((TimeTree)fullTreeWithFossils.value());
            ArrayList<TimeTreeNode> leafNodes = new ArrayList<TimeTreeNode>();
            for (TimeTreeNode node : sampleTree.getNodes()) {
                if (!node.isLeaf() || node.getAge() != 0.0) continue;
                leafNodes.add(node);
            }
            int toNull = (int)Math.round((double)leafNodes.size() * (1.0 - this.fracVal.value()));
            ArrayList<TimeTreeNode> nullList = new ArrayList<TimeTreeNode>();
            for (int i = 0; i < toNull; ++i) {
                nullList.add((TimeTreeNode)leafNodes.remove(this.random.nextInt(leafNodes.size())));
            }
            for (TimeTreeNode node : nullList) {
                node.setId(null);
            }
            nonNullLeafCount = leafNodes.size();
        }
        if (attempts == 1000) {
            throw new RuntimeException("Failed to simulate SimFBDAge after 1000 attempts.");
        }
        PruneTree pruneTree = new PruneTree(new Value<Object>(null, sampleTree));
        TimeTree tree = pruneTree.apply().value();
        return new RandomVariable<TimeTree>(null, tree, this);
    }

    @Override
    public double logDensity(TimeTree timeTree) {
        throw new UnsupportedOperationException("Not implemented!");
    }

    @Override
    public Map<String, Value> getParams() {
        return new TreeMap<String, Value>(){
            {
                this.put("lambda", SimFBDAge.this.birthRate);
                this.put("mu", SimFBDAge.this.deathRate);
                this.put("frac", SimFBDAge.this.fracVal);
                this.put("psi", SimFBDAge.this.psiVal);
                this.put("originAge", SimFBDAge.this.originAge);
            }
        };
    }

    @Override
    public void setParam(String paramName, Value value) {
        switch (paramName) {
            case "lambda": {
                this.birthRate = value;
                break;
            }
            case "mu": {
                this.deathRate = value;
                break;
            }
            case "frac": {
                this.fracVal = value;
                break;
            }
            case "psi": {
                this.psiVal = value;
                break;
            }
            case "originAge": {
                this.originAge = value;
                break;
            }
            default: {
                throw new RuntimeException("Unexpected parameter " + paramName);
            }
        }
    }

    public Value<Number> getBirthRate() {
        return this.birthRate;
    }

    public Value<Number> getDeathRate() {
        return this.deathRate;
    }

    public Value<Double> getRho() {
        return this.fracVal;
    }

    public Value<Number> getPsi() {
        return this.psiVal;
    }
}

