/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.conf.layers.variational;

import java.util.Objects;

import org.deeplearning4j.nn.conf.layers.variational.ReconstructionDistribution;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationHardSigmoid;
import org.nd4j.linalg.activations.impl.ActivationSigmoid;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.OldLessThan;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Reconstruction distribution that models every element of the data as an independent
 * Bernoulli random variable.
 *
 * <p>The distribution parameters are pre-activation values: {@link #activationFn} maps them to
 * the Bernoulli probabilities {@code p}. Sigmoid or hard sigmoid is recommended so that
 * {@code p} stays within [0, 1]. Any other activation is accepted, but a warning is logged at
 * construction time.
 */
public class BernoulliReconstructionDistribution implements ReconstructionDistribution {
    private static final Logger log = LoggerFactory.getLogger(BernoulliReconstructionDistribution.class);

    /** Activation mapping pre-output parameters to Bernoulli probabilities. */
    private final IActivation activationFn;

    /** Creates a distribution that uses a sigmoid activation. */
    public BernoulliReconstructionDistribution() {
        this(Activation.SIGMOID);
    }

    /**
     * Creates a distribution using the given built-in activation.
     *
     * @param activationFn activation used to map pre-output values to probabilities
     */
    public BernoulliReconstructionDistribution(Activation activationFn) {
        this(activationFn.getActivationFunction());
    }

    /**
     * Creates a distribution using the given activation implementation.
     *
     * @param activationFn activation used to map pre-output values to probabilities; a warning
     *     is logged unless it is a sigmoid or hard sigmoid
     */
    public BernoulliReconstructionDistribution(IActivation activationFn) {
        this.activationFn = activationFn;
        boolean recommendedActivation = activationFn instanceof ActivationSigmoid
                || activationFn instanceof ActivationHardSigmoid;
        if (!recommendedActivation) {
            log.warn("Using BernoulliReconstructionDistribution with activation function \"{}\". "
                    + "Using sigmoid/hard sigmoid is recommended to bound probabilities in range 0 to 1",
                    activationFn);
        }
    }

    /** This distribution supplies its own log-probability, not a separate loss function. */
    @Override
    public boolean hasLossFunction() {
        return false;
    }

    /** One Bernoulli parameter per data element, so the parameter count equals the data size. */
    @Override
    public int distributionInputSize(int dataSize) {
        return dataSize;
    }

    /**
     * Computes the negative log-probability of {@code x}, summed over all elements.
     *
     * @param x data values
     * @param preOutDistributionParams pre-activation distribution parameters
     * @param average if true, divide the total by {@code x.size(0)} (the number of rows)
     * @return the summed (or per-row averaged) negative log-probability
     */
    @Override
    public double negLogProbability(INDArray x, INDArray preOutDistributionParams, boolean average) {
        double negLogProb = -calcLogProbArray(x, preOutDistributionParams).sumNumber().doubleValue();
        return average ? negLogProb / x.size(0) : negLogProb;
    }

    /**
     * Computes the negative log-probability for each row (example).
     *
     * @return the element-wise log-probabilities summed along dimension 1, then negated
     */
    @Override
    public INDArray exampleNegLogProbability(INDArray x, INDArray preOutDistributionParams) {
        return calcLogProbArray(x, preOutDistributionParams).sum(1).negi();
    }

    /**
     * Computes the element-wise Bernoulli log-probability
     * {@code x * log(p) + (1 - x) * log(1 - p)}, where {@code p} is the activated output.
     */
    private INDArray calcLogProbArray(INDArray x, INDArray preOutDistributionParams) {
        INDArray output = preOutDistributionParams.dup();
        activationFn.getActivation(output, false);
        // log(p) is computed on a copy because output is overwritten in place with 1 - p next.
        INDArray logOutput = Transforms.log(output, true);
        INDArray log1SubOut = Transforms.log(output.rsubi(1.0), false);
        // log(0) = -Infinity and 0 * -Infinity = NaN. Zeroing infinities keeps terms whose weight
        // (x or 1 - x) is zero finite.
        // NOTE(review): this also turns a hard mismatch (p == 0 with x == 1, or p == 1 with
        // x == 0) into a zero contribution instead of an infinite penalty.
        BooleanIndexing.replaceWhere(logOutput, 0.0, Conditions.isInfinite());
        BooleanIndexing.replaceWhere(log1SubOut, 0.0, Conditions.isInfinite());
        return logOutput.muli(x).addi(x.rsub(1.0).muli(log1SubOut));
    }

    /**
     * Computes the gradient of the negative log-probability with respect to the pre-activation
     * distribution parameters.
     */
    @Override
    public INDArray gradient(INDArray x, INDArray preOutDistributionParams) {
        INDArray output = preOutDistributionParams.dup();
        activationFn.getActivation(output, true);
        // d/dp [x log p + (1 - x) log(1 - p)] = (x - p) / (p (1 - p))
        INDArray dLogProbDOut = x.sub(output).divi(output.rsub(1.0).muli(output));
        // Chain rule through the activation. A copy is passed so the caller's array is never
        // handed to backprop directly.
        INDArray grad = activationFn.backprop(preOutDistributionParams.dup(), dLogProbDOut).getFirst();
        // Saturated outputs (p == 0 or p == 1) make the division above produce NaN/Infinity.
        // NaN entries are zeroed here.
        // NOTE(review): Infinity entries, if any survive backprop, are left as-is.
        BooleanIndexing.replaceWhere(grad, 0.0, Conditions.isNan());
        // Negate: the result is the gradient of the NEGATIVE log-probability.
        return grad.negi();
    }

    /**
     * Draws a sample from the Bernoulli distribution defined by the activated parameters.
     *
     * @return an array with the same shape as the parameters, produced by comparing random
     *     values against the probabilities element-wise
     */
    @Override
    public INDArray generateRandom(INDArray preOutDistributionParams) {
        INDArray p = preOutDistributionParams.dup();
        activationFn.getActivation(p, false);
        // Random values from Nd4j.rand (uniform on [0, 1) per ND4J docs)
        INDArray rand = Nd4j.rand(p.shape());
        INDArray out = Nd4j.createUninitialized(p.shape());
        // Presumably writes 1 where rand < p and 0 elsewhere, i.e. a Bernoulli(p) draw
        Nd4j.getExecutioner().execAndReturn((TransformOp) new OldLessThan(rand, p, out, p.length()));
        return out;
    }

    /** Returns the Bernoulli mean, i.e. the activated probabilities {@code p}. */
    @Override
    public INDArray generateAtMean(INDArray preOutDistributionParams) {
        INDArray p = preOutDistributionParams.dup();
        activationFn.getActivation(p, false);
        return p;
    }

    @Override
    public String toString() {
        return "BernoulliReconstructionDistribution(afn=" + activationFn + ")";
    }

    /** @return the activation mapping pre-output parameters to probabilities */
    public IActivation getActivationFn() {
        return this.activationFn;
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof BernoulliReconstructionDistribution)) {
            return false;
        }
        BernoulliReconstructionDistribution other = (BernoulliReconstructionDistribution) o;
        // canEqual keeps equality symmetric when subclasses override it
        return other.canEqual(this) && Objects.equals(getActivationFn(), other.getActivationFn());
    }

    /** Returns true if {@code other} may be considered equal to an instance of this class. */
    protected boolean canEqual(Object other) {
        return other instanceof BernoulliReconstructionDistribution;
    }

    @Override
    public int hashCode() {
        // Constants 59 (multiplier) and 43 (null marker) are kept so hash values stay unchanged.
        final int prime = 59;
        IActivation fn = getActivationFn();
        return prime + (fn == null ? 43 : fn.hashCode());
    }
}

