/*
 * Decompiled with CFR 0.152.
 */
package lphy.evolution.coalescent;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.SortedMap;
import java.util.TreeMap;
import lphy.core.distributions.Utils;
import lphy.evolution.Taxa;
import lphy.evolution.Taxon;
import lphy.evolution.tree.TimeTree;
import lphy.evolution.tree.TimeTreeNode;
import lphy.graphicalModel.GenerativeDistribution;
import lphy.graphicalModel.GeneratorInfo;
import lphy.graphicalModel.ParameterInfo;
import lphy.graphicalModel.RandomVariable;
import lphy.graphicalModel.Value;
import org.apache.commons.math3.random.RandomGenerator;

/**
 * Generative distribution that simulates a gene tree inside a species tree.
 *
 * <p>Within every branch of the species tree {@code S}, active gene lineages are
 * merged pairwise after exponentially distributed waiting times whose rate is
 * {@code k (k - 1) / (2 * theta)}, where {@code k} is the current number of
 * lineages in the branch and {@code theta} is the entry of the population size
 * array at the species node's index. Lineages that have not merged when the
 * parent species node's age is reached are handed on to the parent branch.
 *
 * <p>The gene tree tips are taken from, in order of precedence: the per-species
 * sample counts {@code n}, the explicit {@code taxa} value, or (when neither is
 * given) one tip per extant taxon of the species tree.
 */
public class MultispeciesCoalescent
implements GenerativeDistribution {
    public static final String SParamName = "S";
    private Value<Double[]> theta;
    private Value<Integer[]> n;
    // NOTE(review): never assigned or read anywhere in this class; kept only so the
    // declared members stay as they were.
    private Value<Integer> numLoci;
    private Value<Taxa> taxa;
    private Value<TimeTree> S;
    RandomGenerator random;
    public static final String separator = "_";
    Taxa geneTreeTaxa = null;

    /**
     * Creates the distribution.
     *
     * @param theta population sizes, looked up by species tree node index
     * @param n     optional number of gene tree tips per extant species; may be {@code null}
     * @param taxa  optional gene tree taxa carrying their species assignment; may be {@code null}
     * @param S     the species tree
     * @throws IllegalArgumentException if {@code n} is given and its length differs from
     *                                  the number of extant nodes in {@code S}
     */
    public MultispeciesCoalescent(@ParameterInfo(name="theta", description="effective population sizes, one for each species (both extant and ancestral).") Value<Double[]> theta, @ParameterInfo(name="n", description="the number of sampled taxa in the gene tree for each extant species.", optional=true) Value<Integer[]> n, @ParameterInfo(name="taxa", description="the taxa for the gene tree, with species to define the mapping.", optional=true) Value<Taxa> taxa, @ParameterInfo(name="S", description="the species tree. ") Value<TimeTree> S) {
        this.theta = theta;
        this.n = n;
        this.taxa = taxa;
        this.S = S;
        this.random = Utils.getRandom();
        if (n != null) {
            int extantCount = S.value().getExtantNodes().size();
            if (extantCount != n.value().length) {
                throw new IllegalArgumentException("Length of n must be equal to the number of extant taxa in species tree provided");
            }
        }
    }

    /**
     * Rebuilds the gene tree taxa from the current parameters and simulates one gene tree.
     *
     * @return a random variable (created with a {@code null} first argument) wrapping the
     *         simulated gene tree, with this distribution as its generator
     */
    @GeneratorInfo(name="MultispeciesCoalescent", description="The Kingman coalescent distribution within each branch of species tree gives rise to a distribution over gene trees conditional on the species tree. The (optional) taxa object provides for non-trivial mappings from individuals to species, and not all species have to have representatives. The (optional) numLoci parameter can be used to produce more than one gene tree from this distribution.")
    public RandomVariable sample() {
        this.geneTreeTaxa = this.createGeneTreeTaxa();
        return new RandomVariable<TimeTree>(null, this.simulateGeneTree(), this);
    }

    /**
     * Simulates a single gene tree over {@link #geneTreeTaxa}, starting the recursion
     * at the root of the species tree.
     *
     * @throws RuntimeException if the recursion does not end with exactly one lineage
     */
    private TimeTree simulateGeneTree() {
        TimeTree geneTree = new TimeTree(this.geneTreeTaxa);
        Map<String, List<TimeTreeNode>> activeNodes = new TreeMap<String, List<TimeTreeNode>>();
        this.createActiveNodes(activeNodes, geneTree);
        TimeTreeNode speciesRoot = this.S.value().getRoot();
        List<TimeTreeNode> root = this.doSpeciesTreeBranch(speciesRoot, activeNodes, this.theta.value());
        if (root.size() != 1) {
            throw new RuntimeException("Returned multiple gene roots from " + this.S.value().getRoot());
        }
        geneTree.setRoot(root.get(0));
        return geneTree;
    }

    /**
     * @return the gene tree taxa built by the most recent {@link #sample()} call,
     *         or {@code null} if {@code sample()} has not been called yet
     */
    public Taxa getGeneTreeTaxa() {
        return this.geneTreeTaxa;
    }

    /**
     * Builds the taxa (tips) of the gene tree.
     *
     * <ul>
     *   <li>If {@code n} is set: the extant species nodes are sorted by id (numerically when
     *       every id parses as an integer, lexicographically otherwise) and the i-th species
     *       in that order contributes {@code n[i]} taxa named {@code <speciesId>_<k>}.</li>
     *   <li>Else if {@code taxa} is set: its taxon array is used as is.</li>
     *   <li>Otherwise: every extant taxon of the species tree contributes one taxon named
     *       {@code <speciesName>_0}.</li>
     * </ul>
     *
     * @return a new taxa object holding the gene tree tips
     * @throws IllegalArgumentException if {@code n} is set and its length differs from the
     *                                  number of extant nodes in the species tree
     */
    public Taxa createGeneTreeTaxa() {
        List<Taxon> taxonList = new ArrayList<Taxon>();
        Taxon[] taxonArray;
        if (this.n != null) {
            List<TimeTreeNode> extant = this.S.value().getExtantNodes();
            Integer[] counts = this.n.value();
            // n or S may have been replaced through setParam since the constructor check.
            if (counts.length != extant.size()) {
                throw new IllegalArgumentException("Length of n must be equal to the number of extant taxa in species tree provided");
            }
            if (this.numericIds(extant)) {
                extant.sort(Comparator.comparingInt(o -> Integer.parseInt(o.getId())));
            } else {
                extant.sort(Comparator.comparing(TimeTreeNode::getId));
            }
            // Each species uses its own count; the index must advance with the species.
            for (int i = 0; i < extant.size(); ++i) {
                TimeTreeNode node = extant.get(i);
                int sampleCount = counts[i];
                for (int k = 0; k < sampleCount; ++k) {
                    taxonList.add(new Taxon(node.getId() + separator + k, node.getId(), node.getAge()));
                }
            }
            taxonArray = taxonList.toArray(new Taxon[0]);
        } else if (this.taxa != null) {
            taxonArray = this.taxa.value().getTaxonArray();
        } else {
            Taxa speciesTreeTaxa = this.S.value().getTaxa();
            for (Taxon speciesTaxon : speciesTreeTaxa.getTaxonArray()) {
                if (!speciesTaxon.isExtant()) {
                    continue;
                }
                taxonList.add(new Taxon(speciesTaxon.getName() + "_0", speciesTaxon.getName(), speciesTaxon.getAge()));
            }
            taxonArray = taxonList.toArray(new Taxon[0]);
        }
        return new Taxa.Simple(taxonArray);
    }

    /**
     * @return {@code true} if the id of every given node parses as an {@code int}
     *         (vacuously {@code true} for an empty list)
     */
    private boolean numericIds(List<TimeTreeNode> nodes) {
        for (TimeTreeNode node : nodes) {
            try {
                Integer.parseInt(node.getId());
            }
            catch (NumberFormatException nfe) {
                // A single non-numeric id means ids must be compared as strings.
                return false;
            }
        }
        return true;
    }

    /**
     * Creates one gene tree leaf node per gene tree taxon and groups the leaves by the
     * species name of their taxon.
     *
     * @param activeNodes map from species name to gene tree leaves; filled by this method
     * @param geneTree    the gene tree the new leaf nodes are created for
     */
    private void createActiveNodes(Map<String, List<TimeTreeNode>> activeNodes, TimeTree geneTree) {
        for (Taxon taxon : this.geneTreeTaxa.getTaxonArray()) {
            List<TimeTreeNode> taxaInSp = activeNodes.computeIfAbsent(taxon.getSpecies(), k -> new ArrayList<TimeTreeNode>());
            taxaInSp.add(new TimeTreeNode(taxon, geneTree));
        }
    }

    /**
     * Runs the coalescent process inside the branch above {@code spNode}, after first
     * recursing into its child branches.
     *
     * @param spNode             species tree node at the bottom of the branch
     * @param allLeafActiveNodes gene tree leaves keyed by species id; the list of a leaf
     *                           species is used (and modified) directly
     * @param allThetas          population sizes indexed by species node index
     * @return the gene lineages still active at the top of the branch; for the root
     *         branch the process continues until a single lineage is left
     */
    private List<TimeTreeNode> doSpeciesTreeBranch(TimeTreeNode spNode, Map<String, List<TimeTreeNode>> allLeafActiveNodes, Double[] allThetas) {
        List<TimeTreeNode> activeNodes;
        if (spNode.isLeaf()) {
            activeNodes = allLeafActiveNodes.get(spNode.getId());
            if (activeNodes == null) {
                // Species without sampled individuals contributes no lineages.
                activeNodes = new ArrayList<TimeTreeNode>();
            }
        } else {
            activeNodes = new ArrayList<TimeTreeNode>();
            for (TimeTreeNode child : spNode.getChildren()) {
                activeNodes.addAll(this.doSpeciesTreeBranch(child, allLeafActiveNodes, allThetas));
            }
        }

        double theta = allThetas[spNode.getIndex()];
        TimeTreeNode spParent = spNode.getParent();
        double time = spNode.getAge();

        while (activeNodes.size() > 1 && (spParent == null || time < spParent.getAge())) {
            int k = activeNodes.size();
            double rate = (double)k * ((double)k - 1.0) / (theta * 2.0);
            // Exponential waiting time until the next coalescent event.
            double x = -Math.log(this.random.nextDouble()) / rate;
            // Advance the clock BEFORE placing the event, so the new node sits at the
            // end of the waiting time rather than at its start.
            time += x;
            // An event that would fall beyond the top of the branch does not happen
            // here; the loop condition then ends this branch.
            if (spParent == null || time < spParent.getAge()) {
                TimeTreeNode a = activeNodes.remove(this.random.nextInt(activeNodes.size()));
                TimeTreeNode b = activeNodes.remove(this.random.nextInt(activeNodes.size()));
                activeNodes.add(new TimeTreeNode(time, new TimeTreeNode[]{a, b}));
            }
        }
        return activeNodes;
    }

    /**
     * Not implemented as a real density: this method ignores its argument.
     *
     * @param timeTreeObject ignored
     * @return always {@code 0.0}
     */
    public double logDensity(Object timeTreeObject) {
        return 0.0;
    }

    /**
     * @return the current parameters keyed by name; {@code n} and {@code taxa} are only
     *         included when they are set
     */
    public SortedMap<String, Value> getParams() {
        TreeMap<String, Value> map = new TreeMap<String, Value>();
        map.put("theta", this.theta);
        if (this.n != null) {
            map.put("n", this.n);
        }
        if (this.taxa != null) {
            map.put("taxa", this.taxa);
        }
        map.put(SParamName, this.S);
        return map;
    }

    /**
     * Replaces one of the parameters {@code theta}, {@code n}, {@code taxa} or {@code S}.
     * The value's element type is not checked here.
     *
     * @param paramName name of the parameter to replace
     * @param value     the new value
     * @throws RuntimeException if {@code paramName} is not one of the recognised names
     */
    @Override
    @SuppressWarnings({"unchecked", "rawtypes"})
    public void setParam(String paramName, Value value) {
        switch (paramName) {
            case "theta": {
                this.theta = value;
                break;
            }
            case "n": {
                this.n = value;
                break;
            }
            case "taxa": {
                this.taxa = value;
                break;
            }
            case "S": {
                this.S = value;
                break;
            }
            default: {
                throw new RuntimeException("Unrecognised parameter name: " + paramName);
            }
        }
    }

    /** @return the result of {@code getName()} */
    public String toString() {
        return this.getName();
    }

    /** @return the species tree value */
    public Value<TimeTree> getSpeciesTree() {
        return this.S;
    }

    /** @return the population sizes value */
    public Value<Double[]> getPopulationSizes() {
        return this.theta;
    }

    /** @return the per-species sample counts value, or {@code null} if not set */
    public Value<Integer[]> getN() {
        return this.n;
    }
}

