/*
 * Decompiled with CFR 0.152.
 */
package lphy.core.distributions;

import java.util.Map;
import java.util.TreeMap;
import lphy.core.distributions.Utils;
import lphy.graphicalModel.GenerativeDistribution;
import lphy.graphicalModel.GeneratorInfo;
import lphy.graphicalModel.ParameterInfo;
import lphy.graphicalModel.RandomVariable;
import lphy.graphicalModel.Value;
import lphy.graphicalModel.ValueUtils;
import org.apache.commons.math3.distribution.GammaDistribution;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.random.RandomGenerator;

public class NormalGamma
implements GenerativeDistribution<Double[]> {
    private Value<Number> shape;
    private Value<Number> scale;
    private Value<Number> mean;
    private Value<Number> precision;
    private RandomGenerator random;
    NormalDistribution normalDistribution;
    static final String precisionParamName = "precision";

    public NormalGamma(@ParameterInfo(name="shape", description="the shape of the distribution.") Value<Number> shape, @ParameterInfo(name="scale", description="the scale of the distribution.") Value<Number> scale, @ParameterInfo(name="mean", description="the mean of the distribution.") Value<Number> mean, @ParameterInfo(name="precision", narrativeName="precision", description="the standard deviation of the distribution.") Value<Number> precision) {
        this.shape = shape;
        this.scale = scale;
        this.mean = mean;
        if (mean == null) {
            throw new IllegalArgumentException("The mean value can't be null!");
        }
        this.precision = precision;
        if (precision == null) {
            throw new IllegalArgumentException("The precision value can't be null!");
        }
        this.random = Utils.getRandom();
    }

    @Override
    @GeneratorInfo(name="NormalGamma", verbClause="has", narrativeName="normal-gamma prior", description="The normal-gamma probability distribution.")
    public RandomVariable<Double[]> sample() {
        double m = ValueUtils.doubleValue(this.mean);
        double sh = ValueUtils.doubleValue(this.shape);
        double sc = ValueUtils.doubleValue(this.scale);
        double lambda = ValueUtils.doubleValue(this.precision);
        GammaDistribution gammaDistribution = new GammaDistribution(this.random, sh, sc);
        double T = gammaDistribution.sample();
        this.normalDistribution = new NormalDistribution(this.random, m, lambda * T);
        double x = this.normalDistribution.sample();
        return new RandomVariable<Double[]>(null, new Double[]{x, T}, this);
    }

    @Override
    public double density(Double[] x) {
        double m = ValueUtils.doubleValue(this.mean);
        double sh = ValueUtils.doubleValue(this.shape);
        double sc = ValueUtils.doubleValue(this.scale);
        double lambda = ValueUtils.doubleValue(this.precision);
        GammaDistribution gammaDistribution = new GammaDistribution(this.random, sh, sc);
        this.normalDistribution = new NormalDistribution(this.random, m, lambda * x[0]);
        return gammaDistribution.density(x[0].doubleValue()) * this.normalDistribution.density(x[1].doubleValue());
    }

    @Override
    public Map<String, Value> getParams() {
        return new TreeMap<String, Value>(){
            {
                this.put("shape", NormalGamma.this.shape);
                this.put("scale", NormalGamma.this.scale);
                this.put("mean", NormalGamma.this.mean);
                this.put(NormalGamma.precisionParamName, NormalGamma.this.precision);
            }
        };
    }

    @Override
    public void setParam(String paramName, Value value) {
        switch (paramName) {
            case "mean": {
                this.mean = value;
                break;
            }
            case "precision": {
                this.precision = value;
                break;
            }
            case "shape": {
                this.shape = value;
                break;
            }
            case "scale": {
                this.scale = value;
                break;
            }
            default: {
                throw new RuntimeException("Unrecognised parameter name: " + paramName);
            }
        }
    }

    public String toString() {
        return this.getName();
    }

    public Value<Number> getMean() {
        return this.mean;
    }

    public Value<Number> getPrecision() {
        return this.precision;
    }
}

