/*
 * Decompiled with CFR 0.152.
 */
package lphy.evolution.likelihood;

import java.util.List;
import java.util.SortedMap;
import java.util.TreeMap;
import jebl.evolution.sequences.SequenceType;
import lphy.core.distributions.Categorical;
import lphy.core.distributions.Utils;
import lphy.evolution.alignment.Alignment;
import lphy.evolution.alignment.SimpleAlignment;
import lphy.evolution.tree.TimeTree;
import lphy.evolution.tree.TimeTreeNode;
import lphy.graphicalModel.Citation;
import lphy.graphicalModel.GenerativeDistribution;
import lphy.graphicalModel.GeneratorInfo;
import lphy.graphicalModel.ParameterInfo;
import lphy.graphicalModel.RandomVariable;
import lphy.graphicalModel.Value;
import lphy.graphicalModel.ValueUtils;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.EigenDecomposition;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.util.FastMath;

/**
 * Generative distribution that simulates an {@link Alignment} along a time tree under a
 * continuous-time Markov chain with instantaneous rate matrix {@code Q}.
 *
 * <p>For every site a root state is drawn from the root frequencies and then propagated down the
 * tree; the state of a child is drawn from the row of the transition probability matrix of its
 * parent's state, where the matrix is computed for the branch length scaled by the clock rate,
 * the site rate and (optionally) the branch rate.
 *
 * <p>Transition probabilities are obtained from an eigen decomposition of {@code Q} using only
 * the real eigenvalues reported by {@link EigenDecomposition#getRealEigenvalues()}.
 * NOTE(review): a rate matrix with complex eigenvalues is presumably not supported - confirm.
 */
@Citation(value="Felsenstein, J. (1981). Evolutionary trees from DNA sequences: a maximum likelihood approach. Journal of molecular evolution, 17(6), 368-376.", title="Evolutionary trees from DNA sequences: a maximum likelihood approach", year=1981, authors={"Felsenstein"}, DOI="https://doi.org/10.1007/BF01734359")
public class PhyloCTMC
implements GenerativeDistribution<Alignment> {
    Value<TimeTree> tree;
    Value<Number> clockRate;
    Value<Double[]> freq;
    Value<Double[][]> Q;
    Value<Double[]> siteRates;
    Value<Double[]> branchRates;
    Value<Integer> L;
    Value<SequenceType> dataType;
    RandomGenerator random;
    public static final String treeParamName = "tree";
    public static final String muParamName = "mu";
    public static final String rootFreqParamName = "freq";
    public static final String QParamName = "Q";
    public static final String siteRatesParamName = "siteRates";
    public static final String branchRatesParamName = "branchRates";
    public static final String LParamName = "L";
    public static final String dataTypeParamName = "dataType";
    /** Number of states, fixed to the number of rows of Q at construction time. */
    final int numStates;
    private EigenDecomposition decomposition;
    /** Inverse of the eigenvector matrix {@link #Evec}. */
    private double[][] Ievc;
    /** Eigenvectors of Q stored as columns. */
    private double[][] Evec;
    /** Root frequencies actually used for sampling: {@link #freq} if given, otherwise computed. */
    private Value<Double[]> rootFreqs;
    /** Maps a node id to its row index in the simulated alignment. */
    private SortedMap<String, Integer> idMap = new TreeMap<String, Integer>();
    /** Scratch matrix that is overwritten with the transition probabilities of each branch. */
    private double[][] transProb;
    /** Scratch matrix holding diag(exp(t * Eval)) * Ievc. */
    private double[][] iexp;
    /** Real eigenvalues of Q. */
    private double[] Eval;
    /** Replacement for an exactly-zero pivot in {@link #luinverse}. */
    private static final double EPSILON = 2.220446049250313E-16;
    /** Largest shortfall of a probability row sum from 1.0 that is treated as rounding error. */
    private static final double SUM_TOLERANCE = 1.0E-8;

    /**
     * Creates the distribution and validates that the alignment length can be determined.
     *
     * @param tree the time tree to simulate on
     * @param mu optional clock rate; 1.0 is used when {@code null}
     * @param rootFreq optional root frequencies; computed from Q when {@code null}
     * @param Q the instantaneous rate matrix; its number of rows defines the number of states
     * @param siteRates optional per-site rates; 1.0 is used when {@code null}
     * @param branchRates optional per-branch rates; 1.0 is used when {@code null}
     * @param L optional alignment length
     * @param dataType optional sequence type; nucleotide is used when {@code null}
     * @throws RuntimeException if neither {@code L} nor {@code siteRates} is given, or if both are
     *     given and disagree
     */
    public PhyloCTMC(
            @ParameterInfo(name="tree", verb="on", narrativeName="phylogenetic time tree", description="the time tree.") Value<TimeTree> tree,
            @ParameterInfo(name="mu", narrativeName="molecular clock rate", description="the clock rate. Default value is 1.0.", optional=true) Value<Number> mu,
            @ParameterInfo(name="freq", description="the root probabilities. Optional parameter. If not specified then first row of e^{100*Q) is used.", optional=true) Value<Double[]> rootFreq,
            @ParameterInfo(name="Q", narrativeName="instantaneous rate matrix", description="the instantaneous rate matrix.") Value<Double[][]> Q,
            @ParameterInfo(name="siteRates", description="a rate for each site in the alignment. Site rates are assumed to be 1.0 otherwise.", optional=true) Value<Double[]> siteRates,
            @ParameterInfo(name="branchRates", description="a rate for each branch in the tree. Branch rates are assumed to be 1.0 otherwise.", optional=true) Value<Double[]> branchRates,
            @ParameterInfo(name="L", narrativeName="length", description="length of the alignment", optional=true) Value<Integer> L,
            @ParameterInfo(name="dataType", description="the data type used for simulations, default to nucleotide", optional=true) Value<SequenceType> dataType) {
        this.tree = tree;
        this.Q = Q;
        this.freq = rootFreq;
        this.clockRate = mu;
        this.siteRates = siteRates;
        this.branchRates = branchRates;
        this.L = L;
        // Previously the dataType argument was dropped, so a data type passed to the
        // constructor was silently replaced by the nucleotide default.
        this.dataType = dataType;
        this.numStates = Q.value().length;
        this.random = Utils.getRandom();
        this.iexp = new double[this.numStates][this.numStates];
        this.checkCompatibilities();
    }

    /**
     * Determines the alignment length from {@code L} and/or {@code siteRates}.
     *
     * @return the value of {@code L} if specified, otherwise the number of site rates
     * @throws RuntimeException if both are specified with different lengths, or neither is
     */
    private int checkCompatibilities() {
        if (this.L != null && this.siteRates != null && this.L.value() != this.siteRates.value().length) {
            throw new RuntimeException("L and siteRates have incompatible values!");
        }
        if (this.L != null) {
            return this.L.value();
        }
        if (this.siteRates != null) {
            return this.siteRates.value().length;
        }
        throw new RuntimeException("One of L or siteRates must be specified.");
    }

    /**
     * Returns the parameters of this distribution keyed by parameter name. The tree and Q are
     * always present; optional parameters are included only when they are non-null.
     */
    public SortedMap<String, Value> getParams() {
        TreeMap<String, Value> map = new TreeMap<String, Value>();
        map.put(treeParamName, this.tree);
        if (this.clockRate != null) {
            map.put(muParamName, this.clockRate);
        }
        if (this.freq != null) {
            map.put(rootFreqParamName, this.freq);
        }
        map.put(QParamName, this.Q);
        if (this.siteRates != null) {
            map.put(siteRatesParamName, this.siteRates);
        }
        if (this.branchRates != null) {
            map.put(branchRatesParamName, this.branchRates);
        }
        if (this.L != null) {
            map.put(LParamName, this.L);
        }
        if (this.dataType != null) {
            map.put(dataTypeParamName, this.dataType);
        }
        return map;
    }

    /**
     * Replaces the parameter with the given name.
     *
     * @throws RuntimeException if {@code paramName} is not one of the parameter name constants
     */
    @Override
    public void setParam(String paramName, Value value) {
        if (paramName.equals(treeParamName)) {
            this.tree = value;
        } else if (paramName.equals(muParamName)) {
            this.clockRate = value;
        } else if (paramName.equals(rootFreqParamName)) {
            this.freq = value;
        } else if (paramName.equals(QParamName)) {
            this.Q = value;
        } else if (paramName.equals(siteRatesParamName)) {
            this.siteRates = value;
        } else if (paramName.equals(branchRatesParamName)) {
            this.branchRates = value;
        } else if (paramName.equals(LParamName)) {
            this.L = value;
        } else if (paramName.equals(dataTypeParamName)) {
            this.dataType = value;
        } else {
            throw new RuntimeException("Unrecognised parameter name: " + paramName);
        }
    }

    /**
     * Prepares the state needed by {@link #sample()}: rebuilds the id map from the current tree,
     * eigen-decomposes the current Q, inverts the eigenvector matrix and resolves the root
     * frequencies.
     *
     * @throws RuntimeException if the current Q does not have {@link #numStates} rows
     */
    private void setup() {
        Double[][] q = this.Q.value();
        // numStates (and the iexp scratch matrix) were sized at construction; a Q of a different
        // size installed through setParam would otherwise be silently truncated or fail with an
        // ArrayIndexOutOfBoundsException further down.
        if (q.length != this.numStates) {
            throw new RuntimeException("Q has " + q.length + " states but this PhyloCTMC was constructed with " + this.numStates + " states.");
        }
        this.idMap.clear();
        this.fillIdMap(this.tree.value().getRoot(), this.idMap);
        this.transProb = new double[this.numStates][this.numStates];
        double[][] primitive = new double[this.numStates][this.numStates];
        for (int i = 0; i < this.numStates; ++i) {
            for (int j = 0; j < this.numStates; ++j) {
                primitive[i][j] = q[i][j];
            }
        }
        RealMatrix qMatrix = new Array2DRowRealMatrix(primitive);
        this.decomposition = new EigenDecomposition(qMatrix);
        this.Eval = this.decomposition.getRealEigenvalues();
        this.Ievc = new double[this.numStates][this.numStates];
        this.Evec = new double[this.numStates][this.numStates];
        for (int i = 0; i < this.numStates; ++i) {
            // The i-th eigenvector becomes the i-th column of Evec.
            RealVector evec = this.decomposition.getEigenvector(i);
            for (int j = 0; j < this.numStates; ++j) {
                this.Evec[j][i] = evec.getEntry(j);
            }
        }
        PhyloCTMC.luinverse(this.Evec, this.Ievc, this.numStates);
        this.rootFreqs = this.freq;
        if (this.rootFreqs == null) {
            this.rootFreqs = this.computeEquilibrium(this.transProb);
        }
    }

    /**
     * Simulates one alignment: for each site a root state is drawn from the root frequencies and
     * evolved down the tree.
     *
     * @return a random variable named "D" wrapping the simulated alignment
     */
    @Override
    @GeneratorInfo(name="PhyloCTMC", verbClause="is assumed to have evolved under", narrativeName="phylogenetic continuous time Markov process", description="The phylogenetic continuous-time Markov chain distribution. A sequence is simulated for every leaf node, and every direct ancestor node with an id.(The sampling distribution that the phylogenetic likelihood is derived from.)")
    public RandomVariable<Alignment> sample() {
        this.setup();
        SequenceType dt = SequenceType.NUCLEOTIDE;
        if (this.dataType != null) {
            dt = this.dataType.value();
        }
        int length = this.checkCompatibilities();
        SimpleAlignment a = new SimpleAlignment(this.idMap, length, dt);
        double mu = this.clockRate == null ? 1.0 : ValueUtils.doubleValue(this.clockRate);
        for (int i = 0; i < length; ++i) {
            int rootState = Categorical.sample(this.rootFreqs.value(), this.random);
            double siteRate = this.siteRates == null ? 1.0 : this.siteRates.value()[i];
            this.traverseTree(this.tree.value().getRoot(), rootState, a, i, this.transProb, mu, siteRate);
        }
        return new RandomVariable<Alignment>("D", a, this);
    }

    /** Returns the site rates value, or {@code null} if none was specified. */
    public Value<Double[]> getSiteRates() {
        return this.siteRates;
    }

    /** Returns the branch rates value, or {@code null} if none was specified. */
    public Value<Double[]> getBranchRates() {
        return this.branchRates;
    }

    /** Returns the clock rate value, or {@code null} if none was specified. */
    public Value<Number> getClockRate() {
        return this.clockRate;
    }

    /** Returns the instantaneous rate matrix value. */
    public Value<Double[][]> getQ() {
        return this.Q;
    }

    /** Returns the time tree value. */
    public Value<TimeTree> getTree() {
        return this.tree;
    }

    /** Returns the configured sequence type, or {@link SequenceType#NUCLEOTIDE} if none is set. */
    public SequenceType getDataType() {
        if (this.dataType == null) {
            return SequenceType.NUCLEOTIDE;
        }
        return this.dataType.value();
    }

    /**
     * Approximates the equilibrium distribution as the first row of the transition probability
     * matrix for a branch of length 100.0. A warning is printed for every entry of another row
     * that differs from the first row by more than 1e-5.
     *
     * @param transProb scratch matrix that is overwritten with the transition probabilities
     * @return a value named {@code "freq"} holding the first row
     */
    private Value<Double[]> computeEquilibrium(double[][] transProb) {
        this.getTransitionProbabilities(100.0, transProb);
        Double[] freqs = new Double[transProb.length];
        for (int i = 0; i < freqs.length; ++i) {
            freqs[i] = transProb[0][i];
            for (int j = 1; j < freqs.length; ++j) {
                if (Math.abs(transProb[0][i] - transProb[j][i]) > 1.0E-5) {
                    System.out.println("WARNING: branch length used to get equilibrium distribution was not long enough!");
                }
            }
        }
        return new Value<Double[]>(rootFreqParamName, freqs);
    }

    /**
     * Recursively assigns an alignment row index to every leaf and every node that has an id,
     * recording it both in {@code idMap} and on the node via {@code setLeafIndex}.
     *
     * <p>NOTE(review): this indexes any node with a non-null id, whereas {@link #traverseTree}
     * only writes states for leaves and single-child non-origin nodes with an id - confirm that
     * other internal nodes never carry an id.
     */
    private void fillIdMap(TimeTreeNode node, SortedMap<String, Integer> idMap) {
        if (node.isLeaf() || node.getId() != null) {
            Integer existing = idMap.get(node.getId());
            if (existing == null) {
                // Next free index is one past the largest index handed out so far.
                int nextValue = 0;
                for (Integer used : idMap.values()) {
                    if (used >= nextValue) {
                        nextValue = used + 1;
                    }
                }
                idMap.put(node.getId(), nextValue);
                node.setLeafIndex(nextValue);
            } else {
                node.setLeafIndex(existing);
            }
        }
        for (TimeTreeNode child : node.getChildren()) {
            this.fillIdMap(child, idMap);
        }
    }

    /**
     * Records the state of {@code node} at site {@code pos} (for leaves and for single-child
     * non-origin nodes that have an id) and recursively draws states for its children.
     *
     * @param node the current node
     * @param nodeState the state already drawn for {@code node}
     * @param alignment the alignment being filled
     * @param pos the site index
     * @param transProb scratch matrix overwritten for every branch
     * @param clockRate the clock rate multiplier
     * @param siteRate the rate multiplier of this site
     */
    private void traverseTree(TimeTreeNode node, int nodeState, SimpleAlignment alignment, int pos, double[][] transProb, double clockRate, double siteRate) {
        if (node.isLeaf() || (node.isSingleChildNonOrigin() && node.getId() != null)) {
            alignment.setState(node.getLeafIndex(), pos, nodeState);
        }
        List<TimeTreeNode> children = node.getChildren();
        for (TimeTreeNode child : children) {
            double branchLength = siteRate * clockRate * (node.getAge() - child.getAge());
            if (this.branchRates != null) {
                branchLength *= this.branchRates.value()[child.getIndex()].doubleValue();
            }
            // transProb is shared scratch space: the row must be consumed (drawState) before
            // recursing, because the recursive call overwrites the matrix.
            this.getTransitionProbabilities(branchLength, transProb);
            int state = this.drawState(transProb[nodeState]);
            this.traverseTree(child, state, alignment, pos, transProb, clockRate, siteRate);
        }
    }

    /**
     * Draws an index from the discrete distribution {@code p} by inverting its cumulative sum.
     *
     * <p>If the uniform draw exceeds the cumulative sum only because the probabilities fall short
     * of 1.0 by at most {@link #SUM_TOLERANCE} (rounding in the numerically computed transition
     * probabilities), the last state with positive probability is returned instead of failing.
     *
     * @param p the probability of each state
     * @return the drawn state index
     * @throws RuntimeException if the probabilities fall short of 1.0 by more than the tolerance
     */
    private int drawState(double[] p) {
        double U = this.random.nextDouble();
        double totalP = 0.0;
        for (int i = 0; i < p.length; ++i) {
            totalP += p[i];
            if (U <= totalP) {
                return i;
            }
        }
        if (1.0 - totalP <= SUM_TOLERANCE) {
            for (int i = p.length - 1; i >= 0; --i) {
                if (p[i] > 0.0) {
                    return i;
                }
            }
        }
        throw new RuntimeException("p vector doesn't add to 1.0!");
    }

    /**
     * Fills {@code transProbs} with |Evec * diag(exp(branchLength * Eval)) * Ievc|, i.e. the
     * transition probabilities for the given branch length reconstructed from the eigen
     * decomposition computed in {@link #setup()}.
     *
     * @param branchLength the (already rate-scaled) branch length
     * @param transProbs output matrix of size numStates x numStates
     */
    private void getTransitionProbabilities(double branchLength, double[][] transProbs) {
        int j;
        double temp;
        int i;
        // iexp = diag(exp(t * lambda)) * Ievc
        for (i = 0; i < this.numStates; ++i) {
            temp = FastMath.exp(branchLength * this.Eval[i]);
            for (j = 0; j < this.numStates; ++j) {
                this.iexp[i][j] = this.Ievc[i][j] * temp;
            }
        }
        // transProbs = Evec * iexp; the absolute value removes tiny negative rounding results.
        for (i = 0; i < this.numStates; ++i) {
            for (j = 0; j < this.numStates; ++j) {
                temp = 0.0;
                for (int k = 0; k < this.numStates; ++k) {
                    temp += this.Evec[i][k] * this.iexp[k][j];
                }
                transProbs[i][j] = FastMath.abs(temp);
            }
        }
    }

    /**
     * Inverts a square matrix by LU decomposition with row-scaled partial pivoting, followed by
     * forward and back substitution for each unit vector.
     *
     * @param inmat the matrix to invert; not modified
     * @param imtrx receives the inverse
     * @param size the dimension of both matrices
     * @throws IllegalArgumentException if a row of {@code inmat} consists only of zeros
     */
    private static void luinverse(double[][] inmat, double[][] imtrx, int size) throws IllegalArgumentException {
        double sum;
        double maxb;
        int j;
        int i;
        int maxi = 0;
        int[] index = new int[size];
        // Work on a copy so that the input matrix is left untouched.
        double[][] omtrx = new double[size][size];
        for (i = 0; i < size; ++i) {
            for (j = 0; j < size; ++j) {
                omtrx[i][j] = inmat[i][j];
            }
        }
        // wk[i] first holds the scaling factor 1 / max|row i| used when choosing pivots.
        double[] wk = new double[size];
        for (i = 0; i < size; ++i) {
            maxb = 0.0;
            for (j = 0; j < size; ++j) {
                if (!(Math.abs(omtrx[i][j]) > maxb)) continue;
                maxb = Math.abs(omtrx[i][j]);
            }
            if (maxb == 0.0) {
                System.err.println("Singular matrix encountered");
                throw new IllegalArgumentException("Singular matrix");
            }
            wk[i] = 1.0 / maxb;
        }
        // Column-by-column decomposition; L and U are stored in place in omtrx.
        for (j = 0; j < size; ++j) {
            double tmp;
            int k;
            for (i = 0; i < j; ++i) {
                sum = omtrx[i][j];
                for (k = 0; k < i; ++k) {
                    sum -= omtrx[i][k] * omtrx[k][j];
                }
                omtrx[i][j] = sum;
            }
            maxb = 0.0;
            for (i = j; i < size; ++i) {
                sum = omtrx[i][j];
                for (k = 0; k < j; ++k) {
                    sum -= omtrx[i][k] * omtrx[k][j];
                }
                omtrx[i][j] = sum;
                tmp = wk[i] * Math.abs(sum);
                if (!(tmp >= maxb)) continue;
                maxb = tmp;
                maxi = i;
            }
            if (j != maxi) {
                // Swap the pivot row into place and carry its scale factor along.
                for (k = 0; k < size; ++k) {
                    tmp = omtrx[maxi][k];
                    omtrx[maxi][k] = omtrx[j][k];
                    omtrx[j][k] = tmp;
                }
                wk[maxi] = wk[j];
            }
            index[j] = maxi;
            if (omtrx[j][j] == 0.0) {
                // Avoid division by an exactly-zero pivot.
                omtrx[j][j] = EPSILON;
            }
            if (j == size - 1) continue;
            tmp = 1.0 / omtrx[j][j];
            for (i = j + 1; i < size; ++i) {
                omtrx[i][j] = omtrx[i][j] * tmp;
            }
        }
        // Solve for each column of the inverse; wk is reused as the right-hand side / solution.
        for (int jx = 0; jx < size; ++jx) {
            int ix;
            for (ix = 0; ix < size; ++ix) {
                wk[ix] = 0.0;
            }
            wk[jx] = 1.0;
            int l = -1;
            // Forward substitution, applying the recorded row permutation.
            for (i = 0; i < size; ++i) {
                int idx = index[i];
                sum = wk[idx];
                wk[idx] = wk[i];
                if (l != -1) {
                    for (j = l; j < i; ++j) {
                        sum -= omtrx[i][j] * wk[j];
                    }
                } else if (sum != 0.0) {
                    l = i;
                }
                wk[i] = sum;
            }
            // Back substitution.
            for (i = size - 1; i >= 0; --i) {
                sum = wk[i];
                for (j = i + 1; j < size; ++j) {
                    sum -= omtrx[i][j] * wk[j];
                }
                wk[i] = sum / omtrx[i][i];
            }
            for (ix = 0; ix < size; ++ix) {
                imtrx[ix][jx] = wk[ix];
            }
        }
    }
}

