/*
 * Decompiled with CFR 0.152.
 */
package lphy.core.distributions;

import java.util.Map;
import java.util.TreeMap;
import lphy.graphicalModel.GenerativeDistribution;
import lphy.graphicalModel.GeneratorInfo;
import lphy.graphicalModel.ParameterInfo;
import lphy.graphicalModel.RandomVariable;
import lphy.graphicalModel.Value;
import org.apache.commons.math3.distribution.MultivariateNormalDistribution;

/**
 * The multivariate normal distribution, parameterised by a mean vector and a
 * variance-covariance matrix. Sampling and density evaluation are delegated to
 * Apache Commons Math's {@link MultivariateNormalDistribution}.
 */
public class MVN
implements GenerativeDistribution<Double[]> {
    private static final String meanParamName = "mean";
    private static final String covariancesParamName = "covariances";
    private Value<Double[]> mean;
    private Value<Double[][]> covariances;
    /** Backing distribution built from the current {@code mean} and {@code covariances} values. */
    MultivariateNormalDistribution multivariateNormalDistribution;
    /**
     * Set when a parameter is replaced via {@link #setParam}. The backing distribution is then
     * rebuilt lazily on next use, so callers can replace the mean and the covariances one at a
     * time (even changing the dimension) without tripping the dimension check in between.
     */
    private boolean distributionStale = false;

    /**
     * Creates a multivariate normal distribution.
     *
     * @param mean        the mean vector; must not be null
     * @param covariances the variance-covariance matrix; must not be null and must be
     *                    {@code d x d}, where {@code d} is the length of the mean vector
     * @throws IllegalArgumentException if either argument is null or the dimensions disagree
     */
    public MVN(@ParameterInfo(name="mean", description="the mean of the distribution.") Value<Double[]> mean, @ParameterInfo(name="covariances", description="the variance-covariance matrix of the distribution.") Value<Double[][]> covariances) {
        // Validate before touching any state.
        if (mean == null) {
            throw new IllegalArgumentException("The means can't be null!");
        }
        if (covariances == null) {
            throw new IllegalArgumentException("The covariances can't be null!");
        }
        this.mean = mean;
        this.covariances = covariances;
        this.constructDistribution();
    }

    /**
     * Copies the current parameter values into primitive arrays and rebuilds the backing
     * distribution.
     *
     * @throws IllegalArgumentException if the covariance matrix is not {@code d x d}, where
     *                                  {@code d} is the length of the mean vector
     */
    private void constructDistribution() {
        Double[] meanValue = this.mean.value();
        Double[][] covValue = this.covariances.value();
        int dim = meanValue.length;
        if (covValue.length != dim) {
            throw new IllegalArgumentException("The covariance matrix has " + covValue.length
                    + " rows but the mean has dimension " + dim + "!");
        }
        double[] means = new double[dim];
        double[][] cv = new double[dim][dim];
        for (int i = 0; i < dim; ++i) {
            // Reject ragged/wider rows instead of silently truncating them.
            if (covValue[i].length != dim) {
                throw new IllegalArgumentException("Row " + i + " of the covariance matrix has "
                        + covValue[i].length + " entries but the mean has dimension " + dim + "!");
            }
            means[i] = meanValue[i];
            for (int j = 0; j < dim; ++j) {
                cv[i][j] = covValue[i][j];
            }
        }
        this.multivariateNormalDistribution = new MultivariateNormalDistribution(means, cv);
        this.distributionStale = false;
    }

    /**
     * Returns the backing distribution, rebuilding it first if a parameter has been replaced
     * since it was last built.
     */
    private MultivariateNormalDistribution distribution() {
        if (this.distributionStale) {
            this.constructDistribution();
        }
        return this.multivariateNormalDistribution;
    }

    /**
     * Draws one vector from the distribution using the current parameter values.
     *
     * @return a random variable named {@code "X"} holding the sampled vector
     */
    @Override
    @GeneratorInfo(name="MVN", description="The multivariate normal probability distribution.")
    public RandomVariable<Double[]> sample() {
        double[] sample = this.distribution().sample();
        Double[] result = new Double[sample.length];
        for (int i = 0; i < sample.length; ++i) {
            result[i] = sample[i];
        }
        return new RandomVariable<Double[]>("X", result, this);
    }

    /**
     * Evaluates the probability density at {@code x} using the current parameter values.
     *
     * @param x the point at which to evaluate the density
     * @return the density at {@code x}
     * @throws IllegalArgumentException if {@code x} does not have the distribution's dimension
     */
    @Override
    public double density(Double[] x) {
        MultivariateNormalDistribution dist = this.distribution();
        int dim = dist.getDimension();
        // A shorter x would otherwise be zero-padded silently; a longer one would overflow.
        if (x.length != dim) {
            throw new IllegalArgumentException("Expected a vector of dimension " + dim
                    + " but got " + x.length + "!");
        }
        double[] xx = new double[dim];
        for (int i = 0; i < dim; ++i) {
            xx[i] = x[i];
        }
        return dist.density(xx);
    }

    /**
     * Returns the parameters of this distribution, keyed by parameter name.
     *
     * @return a new sorted map with entries for {@code "mean"} and {@code "covariances"}
     */
    @Override
    public Map<String, Value> getParams() {
        // Plain map rather than double-brace initialisation, which would create an anonymous
        // subclass capturing a reference to this MVN.
        Map<String, Value> params = new TreeMap<>();
        params.put(meanParamName, this.mean);
        params.put(covariancesParamName, this.covariances);
        return params;
    }

    /**
     * Replaces a parameter value. The backing distribution is rebuilt on the next call to
     * {@link #sample()} or {@link #density(Double[])}.
     *
     * @param paramName either {@code "mean"} or {@code "covariances"}
     * @param value     the new parameter value
     * @throws RuntimeException if {@code paramName} is not recognised
     */
    @Override
    public void setParam(String paramName, Value value) {
        switch (paramName) {
            case meanParamName: {
                this.mean = value;
                break;
            }
            case covariancesParamName: {
                this.covariances = value;
                break;
            }
            default: {
                throw new RuntimeException("Unrecognised parameter name: " + paramName);
            }
        }
        this.distributionStale = true;
    }

    @Override
    public String toString() {
        return this.getName();
    }
}

