/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.lossfunctions.impl;

import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.LossUtil;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonProperty;

/**
 * Mixture density loss: the negative log-likelihood of the labels under a mixture of
 * {@code mixtures} Gaussian components whose parameters are produced directly by the network.
 *
 * <p>For a label width {@code L} and {@code M} mixtures, the (post-activation) network output
 * must have {@code (L + 2) * M} columns, laid out as:
 * <ul>
 *   <li>{@code [0, M)}: mixing-coefficient logits, softmax-normalised into {@code alpha};</li>
 *   <li>{@code [M, 2M)}: log standard deviations, exponentiated into {@code sigma};</li>
 *   <li>{@code [2M, (L + 2) * M)}: component means, reshaped to {@code [n, M, L]} as {@code mu}.</li>
 * </ul>
 * Each component uses one standard deviation shared by all {@code L} label dimensions. Labels
 * are {@code [n, L]}.
 *
 * <p>Instances are normally created through {@link #builder()}.
 */
@JsonInclude(JsonInclude.Include.NON_NULL)
public class LossMixtureDensity implements ILossFunction {
    private static final double SQRT_TWO_PI = Math.sqrt(Math.PI * 2);

    /** Number of Gaussian components ({@code M}). */
    private int mMixtures;
    /** Number of label columns, i.e. the dimensionality of each Gaussian ({@code L}). */
    private int mLabelWidth;

    /**
     * Creates an instance with both sizes left at 0. Prefer {@link #builder()}; this constructor
     * presumably exists for serialization frameworks.
     */
    public LossMixtureDensity() {
    }

    private LossMixtureDensity(@JsonProperty(value = "mixtures") int mixtures,
                               @JsonProperty(value = "labelWidth") int labelWidth) {
        this.mMixtures = mixtures;
        this.mLabelWidth = labelWidth;
    }

    /**
     * Splits a (post-activation) network output into the mixture parameters.
     *
     * <p>The input array is not modified. The alpha columns are copied before the softmax is
     * applied to them.
     *
     * @param output network output of shape {@code [n, (labelWidth + 2) * mixtures]}
     * @return {@code alpha} ({@code [n, mixtures]}, softmax over each row), {@code sigma}
     *         ({@code [n, mixtures]}, exponentiated) and {@code mu} ({@code [n, mixtures, labelWidth]})
     * @throws IllegalArgumentException if {@code output} does not have
     *         {@code (labelWidth + 2) * mixtures} columns
     */
    public MixtureDensityComponents extractComponents(INDArray output) {
        long outputSize = output.size(1);
        if (outputSize != (long) ((this.mLabelWidth + 2) * this.mMixtures)) {
            throw new IllegalArgumentException("Network output size " + outputSize + " must be (labels+2)*mixtures where labels = " + this.mLabelWidth + " and mixtures = " + this.mMixtures);
        }
        MixtureDensityComponents mdc = new MixtureDensityComponents();

        // get(...) returns a view of `output`; dup() it so the in-place softmax below does not
        // overwrite the caller's array.
        INDArray alphaLogits = output.get(NDArrayIndex.all(), NDArrayIndex.interval(0, this.mMixtures)).dup();
        mdc.alpha = Nd4j.exec(new SoftMax(alphaLogits, alphaLogits, -1))[0];

        INDArray logSigma = output.get(NDArrayIndex.all(), NDArrayIndex.interval(this.mMixtures, 2 * this.mMixtures));
        mdc.sigma = Transforms.exp(logSigma);

        mdc.mu = output.get(NDArrayIndex.all(), NDArrayIndex.interval(2 * this.mMixtures, (this.mLabelWidth + 2) * this.mMixtures))
                .reshape(output.size(0), this.mMixtures, this.mLabelWidth);
        return mdc;
    }

    /**
     * Computes the total mixture-density loss over the minibatch.
     *
     * @param labels       labels, {@code [n, labelWidth]}
     * @param preOutput    network pre-activations, {@code [n, (labelWidth + 2) * mixtures]}
     * @param activationFn activation applied to a copy of {@code preOutput} before extraction
     * @param mask         optional mask applied to the per-example scores; may be {@code null}
     * @param average      if {@code true}, the summed loss is divided by the number of rows
     * @return the (optionally averaged) negative log-likelihood
     */
    @Override
    public double computeScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask, boolean average) {
        INDArray scoreArr = this.computeScoreArray(labels, preOutput, activationFn, mask);
        double score = scoreArr.sumNumber().doubleValue();
        if (average) {
            score /= (double) scoreArr.size(0);
        }
        return score;
    }

    /**
     * Computes the negative log-likelihood of each example.
     *
     * @param labels       labels, {@code [n, labelWidth]}; cast to {@code preOutput}'s dtype
     * @param preOutput    network pre-activations (not modified; a copy is activated)
     * @param activationFn activation applied before the mixture parameters are extracted
     * @param mask         optional mask applied to the result; may be {@code null}
     * @return per-example scores of shape {@code [n, 1]}
     */
    @Override
    public INDArray computeScoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
        labels = labels.castTo(preOutput.dataType());
        INDArray output = activationFn.getActivation(preOutput.dup(), false);
        MixtureDensityComponents mdc = this.extractComponents(output);
        INDArray scoreArr = this.negativeLogLikelihood(labels, mdc.alpha, mdc.mu, mdc.sigma);
        if (mask != null) {
            LossUtil.applyMask(scoreArr, mask);
        }
        return scoreArr;
    }

    /**
     * Computes the gradient of the loss with respect to {@code preOutput}.
     *
     * <p>The gradient is assembled with respect to the raw output columns (alpha logits,
     * log-sigma and mu) and then passed through {@code activationFn.backprop}.
     *
     * @param labels       labels, {@code [n, labelWidth]}; cast to {@code preOutput}'s dtype
     * @param preOutput    network pre-activations, {@code [n, (labelWidth + 2) * mixtures]}
     * @param activationFn output activation function
     * @param mask         optional mask applied to the gradient; may be {@code null}
     * @return gradient with the same shape as {@code preOutput}
     */
    @Override
    public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
        labels = labels.castTo(preOutput.dataType());
        long nSamples = labels.size(0);
        INDArray output = activationFn.getActivation(preOutput.dup(), false);
        MixtureDensityComponents mdc = this.extractComponents(output);

        // Allocate in preOutput's dtype, consistent with the label cast above.
        INDArray gradient = Nd4j.zeros(preOutput.dataType(), nSamples, (long) preOutput.columns());

        INDArray labelsMinusMu = this.labelsMinusMu(labels, mdc.mu);
        INDArray labelsMinusMuSquared = labelsMinusMu.mul(labelsMinusMu).sum(2);

        // Posterior weight of each component for each sample:
        //   pi_i = alpha_i * N_i(y) / sum_j alpha_j * N_j(y)
        // The per-sample maximum exponent is subtracted before exp() so the terms do not overflow
        // or all underflow to zero. The shift cancels in the normalisation.
        INDArray variance = mdc.sigma.mul(mdc.sigma);
        INDArray minusTwoVariance = variance.mul(2).negi();
        INDArray normalPart = mdc.alpha.div(Transforms.pow(mdc.sigma.mul(SQRT_TWO_PI), this.mLabelWidth));
        INDArray exponent = labelsMinusMuSquared.div(minusTwoVariance);
        INDArray exponentMax = exponent.max(1);
        exponent.subiColumnVector(exponentMax);
        INDArray pi = Transforms.exp(exponent).muli(normalPart);
        INDArray piDivisor = pi.sum(true, 1);
        pi.diviColumnVector(piDivisor);

        // dL/d(alpha logits) = alpha - pi
        INDArray dLdZAlpha = mdc.alpha.sub(pi);
        // dL/d(log sigma) = -pi * (||y - mu||^2 / sigma^2 - labelWidth)
        INDArray dLdZSigma = labelsMinusMuSquared.div(variance).subi(this.mLabelWidth).muli(-1).muli(pi);
        // dL/d(mu_k) = -pi * (y_k - mu_k) / sigma^2, filled one label dimension at a time.
        INDArray dLdZMu = Nd4j.create(preOutput.dataType(), nSamples, this.mMixtures, this.mLabelWidth);
        for (int k = 0; k < this.mLabelWidth; ++k) {
            // muli() writes into slice k of labelsMinusMu. This is safe because each slice is read
            // exactly once here and labelsMinusMu is not used after the loop.
            INDArray labelsMinusMuK = labelsMinusMu.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(k));
            dLdZMu.put(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(k)},
                    labelsMinusMuK.muli(pi).divi(variance).negi());
        }
        dLdZMu = dLdZMu.reshape(nSamples, (long) (this.mMixtures * this.mLabelWidth));

        // Write each component's gradient into its column range, matching extractComponents' layout.
        gradient.put(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(0, this.mMixtures)}, dLdZAlpha);
        gradient.put(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(this.mMixtures, this.mMixtures * 2)}, dLdZSigma);
        gradient.put(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(this.mMixtures * 2, (this.mLabelWidth + 2) * this.mMixtures)}, dLdZMu);

        INDArray gradients = activationFn.backprop(preOutput, gradient).getFirst();
        if (mask != null) {
            LossUtil.applyMask(gradients, mask);
        }
        return gradients;
    }

    /**
     * Computes the score and the gradient in one call.
     *
     * @return a pair of (score, gradient); see {@link #computeScore} and {@link #computeGradient}
     */
    @Override
    public Pair<Double, INDArray> computeGradientAndScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask, boolean average) {
        // NOTE(review): the score is computed before the gradient. computeGradient passes preOutput
        // to activationFn.backprop, which may work in place. Confirm before reordering these calls.
        double score = this.computeScore(labels, preOutput, activationFn, mask, average);
        INDArray gradient = this.computeGradient(labels, preOutput, activationFn, mask);
        return new Pair<>(score, gradient);
    }

    @Override
    public String name() {
        return "lossmixturedensity";
    }

    /**
     * Per-example negative log-likelihood {@code -log(sum_i alpha_i * N(y | mu_i, sigma_i^2 I))}.
     *
     * <p>The sum is evaluated with the log-sum-exp trick, using the same max-shift as
     * {@link #computeGradient}. Without it, when every component's density underflows to zero the
     * score becomes {@code +Infinity} even though the gradient stays finite.
     *
     * @return scores of shape {@code [n, 1]}
     */
    private INDArray negativeLogLikelihood(INDArray labels, INDArray alpha, INDArray mu, INDArray sigma) {
        INDArray labelsMinusMu = this.labelsMinusMu(labels, mu);
        INDArray diffSquared = labelsMinusMu.mul(labelsMinusMu).sum(2);

        INDArray minusTwoVariance = sigma.mul(sigma).muli(2).negi();
        // diffSquared is a fresh local array, so dividing it in place is safe.
        INDArray exponent = diffSquared.divi(minusTwoVariance);
        INDArray exponentMax = exponent.max(1);
        exponent.subiColumnVector(exponentMax);

        INDArray normalPart = alpha.div(Transforms.pow(sigma.mul(SQRT_TWO_PI), this.mLabelWidth));
        INDArray shiftedSum = Transforms.exp(exponent).muli(normalPart).sum(true, 1);

        // log(sum) + max restores the shift removed above.
        INDArray logLikelihood = Transforms.log(shiftedSum).addiColumnVector(exponentMax);
        return logLikelihood.negi();
    }

    /**
     * Broadcasts {@code labels} ({@code [n, L]}) against every component mean and returns
     * {@code labels - mu} with shape {@code [n, mixtures, L]}.
     */
    private INDArray labelsMinusMu(INDArray labels, INDArray mu) {
        long nSamples = labels.size(0);
        long labelsPerSample = labels.size(1);
        INDArray labelMinusMu = Nd4j.zeros(labels.dataType(), nSamples, this.mMixtures, labelsPerSample);
        for (int k = 0; k < this.mMixtures; ++k) {
            labelMinusMu.put(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.point(k), NDArrayIndex.all()}, labels);
        }
        labelMinusMu.subi(mu);
        return labelMinusMu;
    }

    /** @return the number of Gaussian components */
    @JsonProperty(value = "mixtures")
    public int getNMixtures() {
        return this.mMixtures;
    }

    /** @return the number of label columns */
    @JsonProperty(value = "labelWidth")
    public int getLabelWidth() {
        return this.mLabelWidth;
    }

    @Override
    public String toString() {
        return "LossMixtureDensity(mixtures=" + this.mMixtures + ", labels=" + this.mLabelWidth + ")";
    }

    /** @return a new {@link Builder} */
    public static Builder builder() {
        return new Builder();
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof LossMixtureDensity)) {
            return false;
        }
        LossMixtureDensity other = (LossMixtureDensity) o;
        if (!other.canEqual(this)) {
            return false;
        }
        return this.mMixtures == other.mMixtures && this.mLabelWidth == other.mLabelWidth;
    }

    protected boolean canEqual(Object other) {
        return other instanceof LossMixtureDensity;
    }

    @Override
    public int hashCode() {
        int result = 1;
        result = result * 59 + this.mMixtures;
        result = result * 59 + this.mLabelWidth;
        return result;
    }

    /** Builder for {@link LossMixtureDensity}; both sizes must be set to positive values. */
    public static class Builder {
        private int mGaussians = 0;
        private int mLabelWidth = 0;

        private Builder() {
        }

        /** Sets the number of Gaussian components in the mixture. */
        public Builder gaussians(int aGaussians) {
            this.mGaussians = aGaussians;
            return this;
        }

        /** Sets the number of label columns. */
        public Builder labelWidth(int aLabelWidth) {
            this.mLabelWidth = aLabelWidth;
            return this;
        }

        /**
         * @return the configured loss function
         * @throws IllegalArgumentException if either size is not positive
         */
        public LossMixtureDensity build() {
            if (this.mGaussians <= 0) {
                throw new IllegalArgumentException("Mixture density cost function must specify the number of mixtures to fit");
            }
            if (this.mLabelWidth <= 0) {
                throw new IllegalArgumentException("Mixture density cost function must specify the size of the labels vectors");
            }
            return new LossMixtureDensity(this.mGaussians, this.mLabelWidth);
        }
    }

    /** Holder for the mixture parameters produced by {@link #extractComponents(INDArray)}. */
    public static class MixtureDensityComponents {
        private INDArray alpha;
        private INDArray mu;
        private INDArray sigma;

        public INDArray getAlpha() {
            return this.alpha;
        }

        public INDArray getMu() {
            return this.mu;
        }

        public INDArray getSigma() {
            return this.sigma;
        }

        public void setAlpha(INDArray alpha) {
            this.alpha = alpha;
        }

        public void setMu(INDArray mu) {
            this.mu = mu;
        }

        public void setSigma(INDArray sigma) {
            this.sigma = sigma;
        }

        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof MixtureDensityComponents)) {
                return false;
            }
            MixtureDensityComponents other = (MixtureDensityComponents) o;
            if (!other.canEqual(this)) {
                return false;
            }
            return nullSafeEquals(this.getAlpha(), other.getAlpha())
                    && nullSafeEquals(this.getMu(), other.getMu())
                    && nullSafeEquals(this.getSigma(), other.getSigma());
        }

        /** Null-tolerant equality check for the array fields. */
        private static boolean nullSafeEquals(INDArray a, INDArray b) {
            return a == null ? b == null : a.equals(b);
        }

        protected boolean canEqual(Object other) {
            return other instanceof MixtureDensityComponents;
        }

        @Override
        public int hashCode() {
            int result = 1;
            INDArray alphaValue = this.getAlpha();
            result = result * 59 + (alphaValue == null ? 43 : alphaValue.hashCode());
            INDArray muValue = this.getMu();
            result = result * 59 + (muValue == null ? 43 : muValue.hashCode());
            INDArray sigmaValue = this.getSigma();
            result = result * 59 + (sigmaValue == null ? 43 : sigmaValue.hashCode());
            return result;
        }

        @Override
        public String toString() {
            return "LossMixtureDensity.MixtureDensityComponents(alpha=" + this.getAlpha() + ", mu=" + this.getMu() + ", sigma=" + this.getSigma() + ")";
        }
    }
}

