/*
 * Decompiled with CFR 0.152.
 */
package lphy.evolution.substitutionmodel;

import lphy.evolution.substitutionmodel.RateMatrix;
import lphy.graphicalModel.DeterministicFunction;
import lphy.graphicalModel.GeneratorInfo;
import lphy.graphicalModel.ParameterInfo;
import lphy.graphicalModel.Value;
import lphy.graphicalModel.types.DoubleArray2DValue;

/**
 * Deterministic function producing the general time reversible (GTR) instantaneous rate matrix.
 *
 * <p>For an {@code n}-state model it takes {@code n*(n-1)/2} relative rates and {@code n} base
 * frequencies. The relative rates are consumed in row-major order over the strict upper triangle,
 * i.e. for the state pairs (0,1), (0,2), ..., (0,n-1), (1,2), ..., (n-2,n-1).
 */
public class GeneralTimeReversible
extends RateMatrix {
    public static final String ratesParamName = "rates";
    public static final String freqParamName = "freq";
    // Number of states; taken from the length of the frequency vector in update().
    int numStates;
    // Expected number of relative rates: numStates * (numStates - 1) / 2.
    int ratesDim;

    /**
     * Creates a GTR rate-matrix function.
     *
     * @param rates    the relative rates, one per unordered pair of distinct states
     * @param freq     the base (equilibrium) frequencies, one per state
     * @param meanRate optional mean rate, passed to the {@link RateMatrix} super constructor
     * @throws IllegalArgumentException if {@code rates} does not have length {@code n*(n-1)/2},
     *                                  where {@code n} is the length of {@code freq}
     */
    public GeneralTimeReversible(
            @ParameterInfo(name = ratesParamName, description = "the relative rates of the GTR process.") Value<Double[]> rates,
            @ParameterInfo(name = freqParamName, description = "the base frequencies.") Value<Double[]> freq,
            @ParameterInfo(name = "meanRate", description = "the mean rate of the process.", optional = true) Value<Number> meanRate) {
        super(meanRate);
        this.setParam(ratesParamName, rates);
        this.setParam(freqParamName, freq);
        // Validate eagerly so a bad dimension is reported at construction time.
        this.update(rates, freq);
    }

    /**
     * Builds the GTR rate matrix from the current parameter values.
     *
     * @return the rate matrix wrapped in a value whose generator is this function
     * @throws IllegalArgumentException if the current rates have the wrong dimension
     */
    @Override
    @GeneratorInfo(name = "generalTimeReversible", description = "The general time reversible instantaneous rate matrix. Takes relative rates and base frequencies and produces a general time reversible rate matrix.")
    public Value<Double[][]> apply() {
        Value<Double[]> rates = getRates();
        Value<Double[]> freq = getFreq();
        // Parameters are read from the params map on every call, so re-derive and re-check the dimensions.
        update(rates, freq);
        return new DoubleArray2DValue(generalTimeReversible(rates.value(), freq.value()), this);
    }

    /**
     * Recomputes {@link #numStates} and {@link #ratesDim} from the frequencies and checks the rate count.
     *
     * @throws IllegalArgumentException if the number of rates is not {@code numStates*(numStates-1)/2}
     */
    private void update(Value<Double[]> rates, Value<Double[]> freq) {
        numStates = freq.value().length;
        ratesDim = numStates * (numStates - 1) / 2;
        int actualDim = rates.value().length;
        if (actualDim != ratesDim) {
            throw new IllegalArgumentException("Expected dimension of " + ratesDim + " for the rates of a "
                    + numStates + " state model, but got " + actualDim + ".");
        }
    }

    /**
     * Fills the GTR matrix and then normalizes it.
     *
     * <p>Off-diagonal entries are {@code Q[i][j] = rate(i,j) * freqs[j]}. Each diagonal entry is the
     * negated sum of its row's off-diagonal entries, so every row sums to zero before normalization.
     * Normalization is delegated to the inherited {@code normalize(freqs, Q, totalRateDefault1())}.
     * NOTE(review): the exact scaling is defined in RateMatrix, which is not shown here.
     *
     * @param rates relative rates in upper-triangle row-major order; length must equal {@link #ratesDim}
     * @param freqs base frequencies; length must equal {@link #numStates}
     * @return the {@code numStates x numStates} rate matrix
     */
    private Double[][] generalTimeReversible(Double[] rates, Double[] freqs) {
        Double[][] Q = new Double[numStates][numStates];
        // Each relative rate is shared by the symmetric pair (i,j)/(j,i); only the target frequency differs.
        int k = 0;
        for (int i = 0; i < numStates; i++) {
            for (int j = i + 1; j < numStates; j++) {
                Q[i][j] = rates[k] * freqs[j];
                Q[j][i] = rates[k] * freqs[i];
                k++;
            }
        }
        // Diagonal entries make each row sum to zero (summation order j = 0..n-1 as before).
        for (int i = 0; i < numStates; i++) {
            double rowSum = 0.0;
            for (int j = 0; j < numStates; j++) {
                if (j != i) {
                    rowSum += Q[i][j];
                }
            }
            Q[i][i] = -rowSum;
        }
        normalize(freqs, Q, totalRateDefault1());
        return Q;
    }

    /** Returns the relative-rates parameter value. */
    public Value<Double[]> getRates() {
        return this.getParams().get(ratesParamName);
    }

    /** Returns the base-frequencies parameter value. */
    public Value<Double[]> getFreq() {
        return this.getParams().get(freqParamName);
    }
}

