/*
 * Decompiled with CFR 0.152.
 */
package lphy.evolution.branchrates;

import java.util.ArrayDeque;
import java.util.Deque;
import java.util.Map;
import lphy.evolution.tree.TimeTree;
import lphy.evolution.tree.TimeTreeNode;
import lphy.graphicalModel.DeterministicFunction;
import lphy.graphicalModel.GeneratorInfo;
import lphy.graphicalModel.ParameterInfo;
import lphy.graphicalModel.Value;

public class LocalBranchRates
extends DeterministicFunction<Double[]> {
    public static final String treeParamName = "tree";
    public static final String indicatorsParamName = "indicators";
    public static final String ratesParamName = "rates";

    public LocalBranchRates(@ParameterInfo(name="tree", description="the tree.") Value<TimeTree> tree, @ParameterInfo(name="indicators", description="a boolean indicator for each node except the root. True if there is a new rate on the branch above this node, false if the rate is inherited from the parent node.") Value<Boolean[]> indicators, @ParameterInfo(name="rates", description="A rate for each node in the tree (except root). Only those with a corresponding indicator are used.") Value<Double[]> rates) {
        this.setParam(treeParamName, (Value)tree);
        this.setParam(indicatorsParamName, (Value)indicators);
        this.setParam(ratesParamName, (Value)rates);
    }

    @Override
    @GeneratorInfo(name="localBranchRates", description="A function that returns branch rates for the given tree, indicator mask and raw rates. Each branch takes on the rate of its node index if the indicator is true, or inherits the rate of its parent branch otherwise.")
    public Value<Double[]> apply() {
        Map<String, Value> params = this.getParams();
        Double[] rawRates = (Double[])params.get(ratesParamName).value();
        Boolean[] indicators = (Boolean[])params.get(indicatorsParamName).value();
        TimeTree tree = (TimeTree)params.get(treeParamName).value();
        Double[] branchRates = new Double[rawRates.length];
        this.traverseTree(tree.getRoot(), branchRates, rawRates, indicators);
        return new Value<Double[]>(branchRates, this);
    }

    private void traverseTree(TimeTreeNode node, Double[] branchRates, Double[] rawRates, Boolean[] indicators) {
        int nodeNumber = node.getIndex();
        branchRates[nodeNumber] = node.isRoot() || indicators[nodeNumber] != false ? rawRates[nodeNumber] : branchRates[node.getParent().getIndex()];
        for (TimeTreeNode child : node.getChildren()) {
            this.traverseTree(child, branchRates, rawRates, indicators);
        }
    }

    public Value<TimeTree> getTree() {
        return this.getParams().get(treeParamName);
    }

    public Value<Double[]> getRates() {
        return this.getParams().get(ratesParamName);
    }

    public Value<Boolean[]> getIndicators() {
        return this.getParams().get(indicatorsParamName);
    }
}

