/*
 * Decompiled with CFR 0.152.
 */
package lphy.evolution.coalescent;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.stream.Stream;
import lphy.core.distributions.Utils;
import lphy.evolution.Taxa;
import lphy.evolution.tree.TaxaConditionedTreeGenerator;
import lphy.evolution.tree.TimeTree;
import lphy.evolution.tree.TimeTreeNode;
import lphy.graphicalModel.GeneratorInfo;
import lphy.graphicalModel.ParameterInfo;
import lphy.graphicalModel.RandomVariable;
import lphy.graphicalModel.Value;
import lphy.graphicalModel.types.DoubleArray2DValue;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.util.CombinatoricsUtils;

public class StructuredCoalescent
extends TaxaConditionedTreeGenerator {
    // Parameter names used as keys by getParams() and setParam().
    public static final String MParamName = "M";
    public static final String kParamName = "k";
    public static final String demesParamName = "demes";
    public static final String sortParamName = "sort";
    // The "M" parameter: the diagonal is read as per-deme population sizes and the
    // off-diagonal entries as migration rates (see populateRateMatrix).
    private Value<Double[][]> theta;
    // Number of taxa in each population; the constructor requires exactly one of k and demes.
    private Value<Integer[]> k;
    // Deme of each taxon, read in parallel with the taxa when building leaves in sample().
    private Value<Object[]> demes;
    // Optional flag; when true the unique deme labels are sorted before being mapped to indices.
    private Value<Boolean> sort;
    // Random source obtained from Utils.getRandom() in the constructor.
    RandomGenerator random;
    // Distinct deme labels rendered as strings; a label's position is its deme index.
    private List<String> uniqueDemes;
    // Deme index -> deme label, filled in by initDemes(); not read anywhere else in this class.
    private Map<Integer, String> reverseDemeToIndex;
    // Metadata key under which each tree node's deme is stored.
    public static final String populationLabel = "deme";

    /**
     * Counts the migration events recorded in a tree.
     *
     * A migration is a node with exactly one child whose {@code "deme"} metadata
     * differs from that of its child.
     *
     * @param timeTree the tree to inspect
     * @return the number of single-child nodes whose deme differs from their child's deme
     */
    public static int countMigrations(TimeTree timeTree) {
        int total = 0;
        for (TimeTreeNode candidate : timeTree.getNodes()) {
            if (candidate.getChildCount() == 1) {
                Object ownDeme = candidate.getMetaData(populationLabel);
                Object childDeme = candidate.getChildren().get(0).getMetaData(populationLabel);
                if (!ownDeme.equals(childDeme)) {
                    total++;
                }
            }
        }
        return total;
    }

    public StructuredCoalescent(@ParameterInfo(name="M", description="The population process rate matrix which contains the effective population sizes and migration rates. Off-diagonal migration rates are in units of expected migrants per *generation* backwards in time.") Value<Double[][]> theta, @ParameterInfo(name="k", description="the number of taxa in each population. provide either this or a demes argument.", optional=true) Value<Integer[]> k, @ParameterInfo(name="taxa", description="the taxa.", optional=true) Value<Taxa> taxa, @ParameterInfo(name="demes", description="the deme array, which runs parallel to the taxonArray in the taxa object.", optional=true) Value<Object[]> demes, @ParameterInfo(name="sort", description="whether to sort the deme array, before mapping them to the indices of the effective population sizes and migration rates. If not, as default, the pop size indices are determined by the natural order of the deme array, if true, then the indices are the order of sorted deme array.", optional=true) Value<Boolean> sort) {
        super(null, taxa, null);
        this.theta = theta;
        this.k = k;
        this.demes = demes;
        this.sort = sort;
        if (taxa == null && k == null) {
            throw new IllegalArgumentException("One of taxa and k must be specified!");
        }
        int count = (k != null ? 1 : 0) + (demes != null ? 1 : 0);
        if (count != 1) {
            throw new IllegalArgumentException("Exactly one of demes and k must be specified!");
        }
        this.random = Utils.getRandom();
        this.initDemes();
    }

    /**
     * Returns the number of taxa.
     *
     * @return the sum of the entries of {@code k} when {@code k} is set, otherwise the
     *         value reported by the superclass
     */
    @Override
    public int n() {
        if (this.k == null) {
            return super.n();
        }
        int total = 0;
        for (Integer demeSize : this.k.value()) {
            total += demeSize;
        }
        return total;
    }

    /**
     * Draws one tree from the structured coalescent.
     *
     * Leaves are created first, each tagged with the integer index of its deme under
     * the {@code "deme"} metadata key. Leaves whose age is at or before time zero start
     * as active lineages; the rest are queued and added as the simulation reaches their
     * age. After simulation the integer deme indices are replaced by readable names.
     *
     * When {@code k} is supplied the deme of a leaf is the position of its count in
     * {@code k}, whatever the {@code sort} flag says: the entries of {@code k} are
     * counts, not labels, so there is nothing to sort.
     *
     * @return a random variable wrapping the sampled tree
     * @throws RuntimeException if the number of unique demes differs from the dimension of M
     * @throws IllegalArgumentException if a deme label is not among the known unique demes
     */
    @Override
    @GeneratorInfo(name="StructuredCoalescent", description="The structured coalescent distribution over tip-labelled time trees.")
    public RandomVariable<TimeTree> sample() {
        Taxa taxa = this.getTaxa();
        TimeTree tree = new TimeTree(taxa);
        List<TimeTreeNode> leavesToBeAdded = new ArrayList<TimeTreeNode>();
        List<List<TimeTreeNode>> activeNodes = new ArrayList<List<TimeTreeNode>>();
        double time = 0.0;
        if (this.k != null) {
            // Index-based path: k[i] contemporaneous leaves are placed in deme i.
            // Previously sort=true diverted k into the label-based path below, where the
            // counts were misread as deme labels and the wrong number of leaves was built.
            Integer[] demeSizes = this.k.value();
            int count = 0;
            for (int i = 0; i < demeSizes.length; ++i) {
                List<TimeTreeNode> demeNodes = new ArrayList<TimeTreeNode>();
                activeNodes.add(demeNodes);
                for (int j = 0; j < demeSizes[i]; ++j) {
                    TimeTreeNode node = new TimeTreeNode("" + count, tree);
                    node.setIndex(count);
                    node.setMetaData(populationLabel, i);
                    node.setAge(0.0);
                    demeNodes.add(node);
                    ++count;
                }
            }
        } else {
            // Label-based path: one leaf per entry of the deme array, parallel to the taxa.
            List<String> uniqueDemes = this.getUniqueDemes();
            if (uniqueDemes.size() != this.theta.value().length) {
                throw new RuntimeException("The number of unique demes " + uniqueDemes.size() + " does not match the dimension of theta " + this.theta.value().length + " !");
            }
            for (int i = 0; i < uniqueDemes.size(); ++i) {
                activeNodes.add(new ArrayList<TimeTreeNode>());
            }
            Object[] demesVal = this.demes.value();
            for (int i = 0; i < demesVal.length; ++i) {
                String deme = String.valueOf(demesVal[i]);
                int demeIndex = uniqueDemes.indexOf(deme);
                if (demeIndex < 0) {
                    throw new IllegalArgumentException("Deme '" + deme + "' at position " + i + " is not one of the unique demes " + uniqueDemes);
                }
                TimeTreeNode node = new TimeTreeNode(taxa.getTaxon(i), tree);
                node.setIndex(i);
                node.setMetaData(populationLabel, demeIndex);
                if (node.getAge() <= time) {
                    activeNodes.get(demeIndex).add(node);
                } else {
                    leavesToBeAdded.add(node);
                }
            }
        }
        // Oldest first, so the youngest pending leaf sits at the end of the list.
        leavesToBeAdded.sort((o1, o2) -> Double.compare(o2.getAge(), o1.getAge()));
        TimeTreeNode root = this.simulateStructuredCoalescentForest(tree, activeNodes, leavesToBeAdded, this.theta.value(), Double.POSITIVE_INFINITY).get(0);
        tree.setRoot(root);
        this.sanitiseIntegerNames(tree);
        return new RandomVariable<TimeTree>("\u03c8", tree, this);
    }

    /**
     * Builds the list of unique deme labels and the index-to-label map.
     *
     * The labels come from {@code k} when it is set, otherwise from {@code demes}.
     * Duplicates are removed keeping first-occurrence order, each label is converted
     * with {@link String#valueOf(Object)}, and the resulting strings are sorted when
     * {@link #isSort()} is true.
     */
    private void initDemes() {
        this.uniqueDemes = new ArrayList<String>();
        this.reverseDemeToIndex = new HashMap<Integer, String>();

        Object[] labels;
        if (this.k == null) {
            labels = this.demes.value();
        } else {
            labels = this.k.value();
        }

        // LinkedHashSet drops repeats while preserving the order of first appearance.
        for (Object distinctLabel : new LinkedHashSet<Object>(Arrays.asList(labels))) {
            this.uniqueDemes.add(String.valueOf(distinctLabel));
        }
        if (this.isSort()) {
            Collections.sort(this.uniqueDemes);
        }

        for (Object label : labels) {
            String name = String.valueOf(label);
            int position = this.uniqueDemes.indexOf(name);
            if (position < 0) {
                throw new IllegalArgumentException();
            }
            this.reverseDemeToIndex.put(position, name);
        }
    }

    /**
     * Replaces the integer deme index stored on every node with a string label.
     *
     * With {@code k} set the label is {@code "deme_" + index}. Otherwise it is the
     * unique deme name at that index, prefixed with {@code "deme_"} when the name
     * parses as an integer.
     *
     * @param tree the tree whose nodes carry integer deme metadata
     */
    private void sanitiseIntegerNames(TimeTree tree) {
        boolean indexedByCount = this.k != null;
        List<String> names = indexedByCount ? null : this.getUniqueDemes();
        for (TimeTreeNode node : tree.getNodes()) {
            Integer demeIndex = this.getDemeIndex(node);
            String label;
            if (indexedByCount) {
                label = "deme_" + demeIndex;
            } else {
                label = names.get(demeIndex);
                try {
                    Integer.parseInt(label);
                    label = "deme_" + label;
                }
                catch (NumberFormatException ignored) {
                    // not an integer: the name is kept as it is
                }
            }
            node.setMetaData(populationLabel, label);
        }
    }

    /**
     * Reads the integer deme index stored in a node's {@code "deme"} metadata.
     *
     * @param node the node to read
     * @return the deme index
     * @throws IllegalArgumentException if the metadata value is not an {@code Integer}
     */
    private Integer getDemeIndex(TimeTreeNode node) {
        Object stored = node.getMetaData(populationLabel);
        if (stored instanceof Integer) {
            return (Integer)stored;
        }
        throw new IllegalArgumentException("Metadata name should be Integer before this process !");
    }

    /**
     * Simulates coalescent and migration events until a single lineage remains or
     * {@code stopTime} is reached.
     *
     * @param tree                   the tree that new internal nodes are created for
     * @param activeNodes            current lineages, one list per deme; modified in place
     * @param leavesToBeAdded        pending leaves with the youngest last; modified in place
     * @param popSizesMigrationRates the rate matrix passed on to {@code populateRateMatrix}
     * @param stopTime               time at which the simulation stops
     * @return all lineages still active when the loop ends, across all demes
     */
    private List<TimeTreeNode> simulateStructuredCoalescentForest(TimeTree tree, List<List<TimeTreeNode>> activeNodes, List<TimeTreeNode> leavesToBeAdded, Double[][] popSizesMigrationRates, double stopTime) {
        int demeCount = activeNodes.size();
        double[][] rates = new double[demeCount][demeCount];
        double totalRate = StructuredCoalescent.populateRateMatrix(activeNodes, popSizesMigrationRates, rates);
        double time = 0.0;
        // Internal nodes are numbered after the lineages that are active at the start.
        int nextNodeIndex = this.getTotalNodeCount(activeNodes);

        while (time < stopTime) {
            int lineageCount = this.getTotalNodeCount(activeNodes);
            if (lineageCount + leavesToBeAdded.size() <= 1) {
                break;
            }

            if (lineageCount == 1) {
                // A lone lineage cannot do anything: jump straight to the next leaf.
                time = leavesToBeAdded.get(leavesToBeAdded.size() - 1).getAge();
            } else {
                SCEvent event = this.selectRandomEvent(rates, totalRate, time);
                boolean leafArrivesFirst = !leavesToBeAdded.isEmpty()
                        && event.time > leavesToBeAdded.get(leavesToBeAdded.size() - 1).getAge();
                if (leafArrivesFirst) {
                    // Discard the event and advance only as far as the next leaf.
                    time = leavesToBeAdded.get(leavesToBeAdded.size() - 1).getAge();
                } else {
                    boolean isCoalescence = event.type == EventType.coalescent;
                    List<TimeTreeNode> children = new ArrayList<TimeTreeNode>(2);
                    if (isCoalescence) {
                        children.add(this.selectRandomNode(activeNodes.get(event.pop)));
                        children.add(this.selectRandomNode(activeNodes.get(event.pop)));
                    } else {
                        if (event.pop == event.toPop) {
                            throw new RuntimeException("migration must be between distinct populations");
                        }
                        children.add(this.selectRandomNode(activeNodes.get(event.pop)));
                    }

                    // The new node lives in the event's own deme for a coalescence and
                    // in the destination deme for a migration.
                    int parentDeme = isCoalescence ? event.pop : event.toPop;
                    TimeTreeNode parent = new TimeTreeNode((String)null, tree);
                    parent.setIndex(nextNodeIndex);
                    parent.setAge(event.time);
                    parent.setMetaData(populationLabel, parentDeme);
                    for (TimeTreeNode child : children) {
                        parent.addChild(child);
                    }
                    time = event.time;
                    activeNodes.get(parentDeme).add(parent);
                    ++nextNodeIndex;
                }
            }

            // Activate every pending leaf whose age equals the current time.
            while (!leavesToBeAdded.isEmpty() && leavesToBeAdded.get(leavesToBeAdded.size() - 1).getAge() == time) {
                TimeTreeNode youngest = leavesToBeAdded.remove(leavesToBeAdded.size() - 1);
                activeNodes.get((Integer)youngest.getMetaData(populationLabel)).add(youngest);
            }
            totalRate = StructuredCoalescent.populateRateMatrix(activeNodes, popSizesMigrationRates, rates);
        }

        List<TimeTreeNode> rootNodes = new ArrayList<TimeTreeNode>();
        for (List<TimeTreeNode> demeNodes : activeNodes) {
            rootNodes.addAll(demeNodes);
        }
        return rootNodes;
    }

    /**
     * Returns the total number of nodes across all the given per-deme lists.
     *
     * @param nodes one list of nodes per deme
     * @return the sum of the list sizes
     */
    private int getTotalNodeCount(List<List<TimeTreeNode>> nodes) {
        return nodes.stream().mapToInt(List::size).sum();
    }

    /**
     * Removes and returns a uniformly chosen node from the list.
     *
     * @param nodes the list to draw from; the chosen node is removed from it
     * @return the removed node
     */
    private TimeTreeNode selectRandomNode(List<TimeTreeNode> nodes) {
        return nodes.remove(Utils.getRandom().nextInt(nodes.size()));
    }

    /**
     * Picks the next event and its time.
     *
     * A cell of {@code rates} is chosen by walking the matrix row by row and subtracting
     * each rate from a uniform draw scaled by {@code totalRate}. The event time is
     * {@code time} plus {@code -log(V) / totalRate} for a second uniform draw {@code V}.
     *
     * @param rates     the event rate matrix
     * @param totalRate the value the first uniform draw is scaled by
     * @param time      the current time
     * @return the selected event, whose row and column are the indices of the chosen cell
     * @throws RuntimeException if the walk runs past the last cell without selecting one
     */
    SCEvent selectRandomEvent(double[][] rates, double totalRate, double time) {
        double remaining = this.random.nextDouble() * totalRate;
        for (int from = 0; from < rates.length; ++from) {
            for (int to = 0; to < rates.length; ++to) {
                double cellRate = rates[from][to];
                if (remaining > cellRate) {
                    remaining -= cellRate;
                    continue;
                }
                double waitingTime = -Math.log(this.random.nextDouble()) / totalRate;
                return new SCEvent(from, to, time + waitingTime);
            }
        }
        throw new RuntimeException();
    }

    /**
     * Fills {@code rates} from the current lineage counts and returns the sum of all entries.
     *
     * Writing {@code n_i} for the number of lineages in deme {@code i} and {@code N_i}
     * for {@code popSizesMigrationRates[i][i]}: the diagonal entry is
     * {@code C(n_i, 2) / N_i} (zero when {@code n_i < 2}), and the off-diagonal entry
     * {@code [i][j]} is {@code n_i * (popSizesMigrationRates[i][j] * N_j) / N_i}.
     *
     * @param nodes                  current lineages, one list per deme
     * @param popSizesMigrationRates matrix with population sizes on the diagonal
     * @param rates                  output matrix, overwritten in place
     * @return the sum of all entries written to {@code rates}
     */
    static double populateRateMatrix(List<List<TimeTreeNode>> nodes, Double[][] popSizesMigrationRates, double[][] rates) {
        double sum = 0.0;
        for (int from = 0; from < rates.length; ++from) {
            double fromSize = popSizesMigrationRates[from][from];
            int lineages = nodes.get(from).size();
            if (lineages < 2) {
                rates[from][from] = 0.0;
            } else {
                rates[from][from] = (double)CombinatoricsUtils.binomialCoefficient(lineages, 2) / fromSize;
            }
            for (int to = 0; to < rates[from].length; ++to) {
                double toSize = popSizesMigrationRates[to][to];
                if (from != to) {
                    rates[from][to] = (double)lineages * (popSizesMigrationRates[from][to] * toSize) / fromSize;
                }
                sum += rates[from][to];
            }
        }
        return sum;
    }

    /**
     * Log density of a tree under this distribution.
     *
     * @param timeTree the tree to evaluate; not used
     * @return always {@code Double.NaN}; no density is computed here
     */
    @Override
    public double logDensity(TimeTree timeTree) {
        return Double.NaN;
    }

    /**
     * Returns the superclass parameters with this generator's own added.
     *
     * {@code M} is always included; {@code k}, {@code demes} and {@code sort} are
     * included only when they are set.
     *
     * @return the parameter map, keyed by parameter name
     */
    @Override
    public Map<String, Value> getParams() {
        Map<String, Value> params = super.getParams();
        params.put(MParamName, theta);
        if (k != null) {
            params.put(kParamName, k);
        }
        if (demes != null) {
            params.put(demesParamName, demes);
        }
        if (sort != null) {
            params.put(sortParamName, sort);
        }
        return params;
    }

    /**
     * Replaces one parameter by name.
     *
     * Names other than {@code M}, {@code k}, {@code demes} and {@code sort} are
     * forwarded to the superclass.
     *
     * @param paramName the name of the parameter to replace
     * @param value     the new value
     */
    @Override
    public void setParam(String paramName, Value value) {
        switch (paramName) {
            case MParamName:
                this.theta = value;
                break;
            case kParamName:
                this.k = value;
                break;
            case demesParamName:
                this.demes = value;
                break;
            case sortParamName:
                this.sort = value;
                break;
            default:
                super.setParam(paramName, value);
                break;
        }
    }

    /**
     * @return the value holding the {@code M} rate matrix
     */
    public Value<Double[][]> getM() {
        return theta;
    }

    /**
     * @return the metadata key under which each node's deme is stored
     */
    public String getPopulationLabel() {
        return StructuredCoalescent.populationLabel;
    }

    /**
     * @return {@code true} only when the {@code sort} parameter is set and its value is true
     */
    public boolean isSort() {
        return sort != null && sort.value();
    }

    /**
     * Returns the unique deme labels; a label's position in the list is its deme index.
     *
     * @return the list of unique deme labels
     * @throws IllegalArgumentException if the list has not been initialised
     */
    public List<String> getUniqueDemes() {
        if (uniqueDemes != null) {
            return uniqueDemes;
        }
        throw new IllegalArgumentException();
    }

    /**
     * @return the name reported by {@code getName()}
     */
    @Override
    public String toString() {
        return getName();
    }

    /**
     * Ad-hoc driver.
     *
     * Prints C(n, 2) for n = 2..9. Then, for a grid of migration rates and population
     * sizes, it samples 1000 two-deme trees with two taxa per deme and prints the
     * fraction whose root is labelled with deme 0.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        for (int n = 2; n < 10; ++n) {
            System.out.println(CombinatoricsUtils.binomialCoefficient(n, 2));
        }
        final long reps = 1000L;
        double[] popSize1 = {1.0, 1.0, 1.0, 1.0, 1.0};
        double[] popSize2 = {1.0, 2.0, 4.0, 8.0, 16.0};
        System.out.println("pop0.leaf, pop1.leaf, pop0.mig, pop1.mig, pop0.coal, pop1.coal");
        for (double m = 0.125; m < 32.0; m *= 2.0) {
            for (int i = 0; i < popSize1.length; ++i) {
                long rootInDemeZero = 0L;
                for (long rep = 0L; rep < reps; ++rep) {
                    DoubleArray2DValue theta = new DoubleArray2DValue("theta", new Double[][]{{popSize1[i], m}, {m, popSize2[i]}});
                    Value<Integer[]> k = new Value<Integer[]>(kParamName, new Integer[]{2, 2});
                    StructuredCoalescent coalescent = new StructuredCoalescent(theta, k, null, null, null);
                    RandomVariable<TimeTree> sampled = coalescent.sample();
                    // The root label looks like "deme_<index>"; keep what follows the last underscore.
                    String rootLabel = String.valueOf(sampled.value().getRoot().getMetaData(populationLabel));
                    int rootDeme = Integer.parseInt(rootLabel.substring(rootLabel.lastIndexOf("_") + 1));
                    if (rootDeme == 0) {
                        ++rootInDemeZero;
                    }
                }
                System.out.println(popSize1[i] + "\t" + popSize2[i] + "\t" + m + "\t" + (double)rootInDemeZero / (double)reps);
            }
        }
    }

    /**
     * A single simulated event: a coalescence within {@code pop}, or a migration
     * recorded between {@code pop} and {@code toPop}.
     */
    class SCEvent {
        int pop;
        int toPop;
        double time;
        EventType type;

        /**
         * @param pop1 first deme index
         * @param pop2 second deme index; equal to {@code pop1} for a coalescence
         * @param time the time of the event
         */
        public SCEvent(int pop1, int pop2, double time) {
            this.pop = pop1;
            this.toPop = pop2;
            this.time = time;
            if (pop1 == pop2) {
                this.type = EventType.coalescent;
            } else {
                this.type = EventType.migration;
            }
        }
    }

    /** The two kinds of event the simulation can produce. */
    enum EventType {
        /** Both deme indices of the event are equal. */
        coalescent,
        /** The two deme indices of the event differ. */
        migration
    }
}

