/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.conf.layers.variational;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.deeplearning4j.nn.conf.layers.variational.LossFunctionWrapper;
import org.deeplearning4j.nn.conf.layers.variational.ReconstructionDistribution;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationIdentity;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.shade.jackson.annotation.JsonProperty;

/**
 * A {@link ReconstructionDistribution} composed of several component distributions, each responsible for a
 * contiguous block of data columns.
 * <p>
 * Component {@code i} owns {@code distributionSizes[i]} data columns and
 * {@code reconstructionDistributions[i].distributionInputSize(distributionSizes[i])} distribution-parameter
 * columns. Data and parameter arrays are sliced column-wise in component order, keeping all rows; the
 * per-component results are then summed (scores, log probabilities) or written back into the matching
 * columns (gradients, samples).
 */
public class CompositeReconstructionDistribution
implements ReconstructionDistribution {
    private final int[] distributionSizes;
    private final ReconstructionDistribution[] reconstructionDistributions;
    private final int totalSize;

    /**
     * Creates a composite distribution directly from its fields; the parameters are annotated for Jackson
     * JSON deserialization.
     * <p>
     * No validation is performed here. Callers are expected to pass arrays of equal length and a
     * {@code totalSize} equal to the sum of {@code distributionSizes}; in code, prefer {@link Builder},
     * which computes {@code totalSize} itself.
     *
     * @param distributionSizes           number of data columns owned by each component
     * @param reconstructionDistributions the component distributions, in column order
     * @param totalSize                   total number of data columns across all components
     */
    public CompositeReconstructionDistribution(@JsonProperty(value="distributionSizes") int[] distributionSizes, @JsonProperty(value="reconstructionDistributions") ReconstructionDistribution[] reconstructionDistributions, @JsonProperty(value="totalSize") int totalSize) {
        this.distributionSizes = distributionSizes;
        this.reconstructionDistributions = reconstructionDistributions;
        this.totalSize = totalSize;
    }

    /**
     * Copies the builder's components into arrays and sums their sizes into {@code totalSize}.
     */
    private CompositeReconstructionDistribution(Builder builder) {
        int count = builder.distributionSizes.size();
        this.distributionSizes = new int[count];
        this.reconstructionDistributions = new ReconstructionDistribution[count];
        int sizeCount = 0;
        for (int i = 0; i < count; ++i) {
            this.distributionSizes[i] = builder.distributionSizes.get(i);
            this.reconstructionDistributions[i] = builder.reconstructionDistributions.get(i);
            sizeCount += this.distributionSizes[i];
        }
        this.totalSize = sizeCount;
    }

    /**
     * Builds an index that selects every row and the column range {@code [start, start + count)}.
     */
    private static INDArrayIndex[] columnRange(int start, int count) {
        return new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(start, start + count)};
    }

    /**
     * Computes the loss-function score array by summing the score arrays of every component.
     *
     * @param data           data array; component {@code i} reads its {@code distributionSizes[i]} columns
     * @param reconstruction reconstruction array, sliced per component using each component's
     *                       {@code distributionInputSize}
     * @return the element-wise sum of the component score arrays, or {@code null} if there are no components
     * @throws IllegalStateException         if {@link #hasLossFunction()} is {@code false}
     * @throws UnsupportedOperationException if a component is neither a {@link LossFunctionWrapper} nor a
     *                                       {@code CompositeReconstructionDistribution}
     */
    public INDArray computeLossFunctionScoreArray(INDArray data, INDArray reconstruction) {
        if (!this.hasLossFunction()) {
            throw new IllegalStateException("Cannot compute score array unless hasLossFunction() == true");
        }
        int inputSoFar = 0;
        int paramsSoFar = 0;
        INDArray reconstructionScores = null;
        for (int i = 0; i < this.distributionSizes.length; ++i) {
            int thisInputSize = this.distributionSizes[i];
            int thisParamsSize = this.reconstructionDistributions[i].distributionInputSize(thisInputSize);
            INDArray dataSubset = data.get(columnRange(inputSoFar, thisInputSize));
            INDArray reconstructionSubset = reconstruction.get(columnRange(paramsSoFar, thisParamsSize));
            INDArray componentScores = this.getScoreArray(this.reconstructionDistributions[i], dataSubset, reconstructionSubset);
            if (i == 0) {
                reconstructionScores = componentScores;
            } else {
                // In-place accumulation into the first component's score array.
                reconstructionScores.addi(componentScores);
            }
            inputSoFar += thisInputSize;
            paramsSoFar += thisParamsSize;
        }
        return reconstructionScores;
    }

    /**
     * Computes the score array for a single component: loss-function wrappers delegate to their loss
     * function; nested composites recurse.
     */
    private INDArray getScoreArray(ReconstructionDistribution reconstructionDistribution, INDArray dataSubset, INDArray reconstructionSubset) {
        if (reconstructionDistribution instanceof LossFunctionWrapper) {
            ILossFunction lossFunction = ((LossFunctionWrapper) reconstructionDistribution).getLossFunction();
            // Identity activation: the reconstruction subset is passed to the loss function unchanged.
            return lossFunction.computeScoreArray(dataSubset, reconstructionSubset, new ActivationIdentity(), null);
        }
        if (reconstructionDistribution instanceof CompositeReconstructionDistribution) {
            return ((CompositeReconstructionDistribution) reconstructionDistribution).computeLossFunctionScoreArray(dataSubset, reconstructionSubset);
        }
        throw new UnsupportedOperationException("Cannot calculate composite reconstruction distribution");
    }

    /**
     * Reports whether every component has a loss function.
     *
     * @return {@code true} only if every component's {@code hasLossFunction()} is {@code true}
     *         (vacuously {@code true} when there are no components)
     */
    @Override
    public boolean hasLossFunction() {
        for (ReconstructionDistribution rd : this.reconstructionDistributions) {
            if (!rd.hasLossFunction()) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns the total number of distribution-parameter columns required by all components.
     *
     * @param dataSize number of data columns; must equal {@link #getTotalSize()}
     * @return sum of each component's {@code distributionInputSize} for its own data size
     * @throws IllegalStateException if {@code dataSize} differs from the total size of all components
     */
    @Override
    public int distributionInputSize(int dataSize) {
        if (dataSize != this.totalSize) {
            throw new IllegalStateException("Invalid input size: Got input size " + dataSize + " for data, but expected input size for all distributions is " + this.totalSize + ". Distribution sizes: " + Arrays.toString(this.distributionSizes));
        }
        int sum = 0;
        for (int i = 0; i < this.distributionSizes.length; ++i) {
            sum += this.reconstructionDistributions[i].distributionInputSize(this.distributionSizes[i]);
        }
        return sum;
    }

    /**
     * Sums each component's {@code negLogProbability} over its own column slices.
     *
     * @param x                        data array
     * @param preOutDistributionParams distribution-parameter array
     * @param average                  forwarded unchanged to every component
     * @return the sum of the component values
     */
    @Override
    public double negLogProbability(INDArray x, INDArray preOutDistributionParams, boolean average) {
        int inputSoFar = 0;
        int paramsSoFar = 0;
        double logProbSum = 0.0;
        for (int i = 0; i < this.distributionSizes.length; ++i) {
            int thisInputSize = this.distributionSizes[i];
            int thisParamsSize = this.reconstructionDistributions[i].distributionInputSize(thisInputSize);
            INDArray inputSubset = x.get(columnRange(inputSoFar, thisInputSize));
            INDArray paramsSubset = preOutDistributionParams.get(columnRange(paramsSoFar, thisParamsSize));
            logProbSum += this.reconstructionDistributions[i].negLogProbability(inputSubset, paramsSubset, average);
            inputSoFar += thisInputSize;
            paramsSoFar += thisParamsSize;
        }
        return logProbSum;
    }

    /**
     * Sums each component's {@code exampleNegLogProbability} array element-wise.
     *
     * @param x                        data array
     * @param preOutDistributionParams distribution-parameter array
     * @return the element-wise sum of the component arrays, or {@code null} if there are no components
     */
    @Override
    public INDArray exampleNegLogProbability(INDArray x, INDArray preOutDistributionParams) {
        int inputSoFar = 0;
        int paramsSoFar = 0;
        INDArray exampleLogProbSum = null;
        for (int i = 0; i < this.distributionSizes.length; ++i) {
            int thisInputSize = this.distributionSizes[i];
            int thisParamsSize = this.reconstructionDistributions[i].distributionInputSize(thisInputSize);
            INDArray inputSubset = x.get(columnRange(inputSoFar, thisInputSize));
            INDArray paramsSubset = preOutDistributionParams.get(columnRange(paramsSoFar, thisParamsSize));
            INDArray componentResult = this.reconstructionDistributions[i].exampleNegLogProbability(inputSubset, paramsSubset);
            if (i == 0) {
                exampleLogProbSum = componentResult;
            } else {
                // In-place accumulation into the first component's result array.
                exampleLogProbSum.addi(componentResult);
            }
            inputSoFar += thisInputSize;
            paramsSoFar += thisParamsSize;
        }
        return exampleLogProbSum;
    }

    /**
     * Computes the gradient with respect to the distribution parameters.
     * <p>
     * Each component's gradient is written into its own parameter columns of a new array created with
     * {@code preOutDistributionParams.ulike()}.
     *
     * @param x                        data array
     * @param preOutDistributionParams distribution-parameter array
     * @return the assembled gradient array
     */
    @Override
    public INDArray gradient(INDArray x, INDArray preOutDistributionParams) {
        int inputSoFar = 0;
        int paramsSoFar = 0;
        INDArray gradient = preOutDistributionParams.ulike();
        for (int i = 0; i < this.distributionSizes.length; ++i) {
            int thisInputSize = this.distributionSizes[i];
            int thisParamsSize = this.reconstructionDistributions[i].distributionInputSize(thisInputSize);
            INDArrayIndex[] paramsColumns = columnRange(paramsSoFar, thisParamsSize);
            INDArray inputSubset = x.get(columnRange(inputSoFar, thisInputSize));
            INDArray paramsSubset = preOutDistributionParams.get(paramsColumns);
            INDArray grad = this.reconstructionDistributions[i].gradient(inputSubset, paramsSubset);
            gradient.put(paramsColumns, grad);
            inputSoFar += thisInputSize;
            paramsSoFar += thisParamsSize;
        }
        return gradient;
    }

    /**
     * Draws a random sample from every component and concatenates the results column-wise.
     *
     * @param preOutDistributionParams distribution-parameter array
     * @return array with {@code preOutDistributionParams.size(0)} rows and {@link #getTotalSize()} columns
     */
    @Override
    public INDArray generateRandom(INDArray preOutDistributionParams) {
        return this.randomSample(preOutDistributionParams, false);
    }

    /**
     * Generates every component's output at its mean and concatenates the results column-wise.
     *
     * @param preOutDistributionParams distribution-parameter array
     * @return array with {@code preOutDistributionParams.size(0)} rows and {@link #getTotalSize()} columns
     */
    @Override
    public INDArray generateAtMean(INDArray preOutDistributionParams) {
        return this.randomSample(preOutDistributionParams, true);
    }

    /**
     * Shared implementation of {@link #generateRandom} ({@code isMean == false}) and
     * {@link #generateAtMean} ({@code isMean == true}).
     */
    private INDArray randomSample(INDArray preOutDistributionParams, boolean isMean) {
        int inputSoFar = 0;
        int paramsSoFar = 0;
        // Uninitialized: every column is expected to be overwritten below, assuming the component sizes
        // sum to totalSize (guaranteed for Builder-created instances).
        INDArray out = Nd4j.createUninitialized(preOutDistributionParams.dataType(), new long[]{preOutDistributionParams.size(0), this.totalSize});
        for (int i = 0; i < this.distributionSizes.length; ++i) {
            int thisDataSize = this.distributionSizes[i];
            ReconstructionDistribution distribution = this.reconstructionDistributions[i];
            int thisParamsSize = distribution.distributionInputSize(thisDataSize);
            INDArray paramsSubset = preOutDistributionParams.get(columnRange(paramsSoFar, thisParamsSize));
            INDArray thisRandomSample = isMean ? distribution.generateAtMean(paramsSubset) : distribution.generateRandom(paramsSubset);
            out.put(columnRange(inputSoFar, thisDataSize), thisRandomSample);
            inputSoFar += thisDataSize;
            paramsSoFar += thisParamsSize;
        }
        return out;
    }

    /** @return the number of data columns owned by each component (the internal array, not a copy) */
    public int[] getDistributionSizes() {
        return this.distributionSizes;
    }

    /** @return the component distributions in column order (the internal array, not a copy) */
    public ReconstructionDistribution[] getReconstructionDistributions() {
        return this.reconstructionDistributions;
    }

    /** @return the total number of data columns across all components */
    public int getTotalSize() {
        return this.totalSize;
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof CompositeReconstructionDistribution)) {
            return false;
        }
        CompositeReconstructionDistribution other = (CompositeReconstructionDistribution) o;
        if (!other.canEqual(this)) {
            return false;
        }
        if (this.getTotalSize() != other.getTotalSize()) {
            return false;
        }
        if (!Arrays.equals(this.getDistributionSizes(), other.getDistributionSizes())) {
            return false;
        }
        return Arrays.deepEquals(this.getReconstructionDistributions(), other.getReconstructionDistributions());
    }

    /**
     * Symmetry hook for {@link #equals(Object)}: subclasses that add state should override this.
     */
    protected boolean canEqual(Object other) {
        return other instanceof CompositeReconstructionDistribution;
    }

    @Override
    public int hashCode() {
        final int PRIME = 59;
        int result = 1;
        result = result * PRIME + this.getTotalSize();
        result = result * PRIME + Arrays.hashCode(this.getDistributionSizes());
        result = result * PRIME + Arrays.deepHashCode(this.getReconstructionDistributions());
        return result;
    }

    @Override
    public String toString() {
        return "CompositeReconstructionDistribution(distributionSizes=" + Arrays.toString(this.getDistributionSizes()) + ", reconstructionDistributions=" + Arrays.deepToString(this.getReconstructionDistributions()) + ", totalSize=" + this.getTotalSize() + ")";
    }

    /**
     * Builder that accumulates (size, distribution) pairs in column order.
     */
    public static class Builder {
        private final List<Integer> distributionSizes = new ArrayList<Integer>();
        private final List<ReconstructionDistribution> reconstructionDistributions = new ArrayList<ReconstructionDistribution>();

        /**
         * Appends a component that owns the next {@code distributionSize} data columns.
         *
         * @param distributionSize number of data columns for this component; must be positive
         * @param distribution     the component distribution; must not be {@code null}
         * @return this builder
         * @throws IllegalArgumentException if {@code distributionSize <= 0} or {@code distribution} is null
         */
        public Builder addDistribution(int distributionSize, ReconstructionDistribution distribution) {
            if (distributionSize <= 0) {
                throw new IllegalArgumentException("Distribution size must be positive, got " + distributionSize);
            }
            if (distribution == null) {
                throw new IllegalArgumentException("Reconstruction distribution must not be null");
            }
            this.distributionSizes.add(distributionSize);
            this.reconstructionDistributions.add(distribution);
            return this;
        }

        /**
         * Creates the composite distribution from the components added so far.
         *
         * @return the new composite distribution
         * @throws IllegalStateException if no component has been added (an empty composite would make
         *                               {@code computeLossFunctionScoreArray} and
         *                               {@code exampleNegLogProbability} return {@code null})
         */
        public CompositeReconstructionDistribution build() {
            if (this.distributionSizes.isEmpty()) {
                throw new IllegalStateException("At least one distribution must be added before calling build()");
            }
            return new CompositeReconstructionDistribution(this);
        }
    }
}

