/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.ops.impl.layers.recurrent;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMLayerWeights;
import org.nd4j.shade.guava.primitives.Booleans;

/**
 * Gradient (backward-pass) op for an LSTM layer, mapped to the native custom op {@code lstmLayer_bp}.
 *
 * <p>Inputs are handed to the native op in this fixed order:
 * {@code x, weights.getWeights(), weights.getRWeights(), weights.getBias(), maxTSLength, yLast, cLast,
 * weights.getPeepholeWeights(), dLdh, dLdhL, dLdcL}, passed through {@code wrapFilterNull}
 * (presumably dropping absent/null optional entries, whose presence is instead signalled by the
 * boolean arguments — TODO confirm).
 *
 * <p>The op has three mandatory outputs plus one additional output for each optional input that is
 * present (bias, maxTSLength, yLast, cLast, peephole weights); see {@link #getNumOutputs()}.
 */
public class LSTMLayerBp extends DynamicCustomOp {
    private LSTMLayerConfig configuration;
    private LSTMLayerWeights weights;
    private SDVariable cLast;
    private SDVariable yLast;
    private SDVariable maxTSLength;

    /**
     * Creates the backward LSTM-layer op in the given SameDiff instance.
     *
     * @param sameDiff      the SameDiff instance this op belongs to; must not be null
     * @param x             layer input; must not be null
     * @param cLast         optional cLast state input (may be null; presence is encoded in the boolean args)
     * @param yLast         optional yLast state input (may be null; presence is encoded in the boolean args)
     * @param maxTSLength   optional sequence-length input (may be null; presence is encoded in the boolean args)
     * @param weights       layer weights (input, recurrent, optional bias and peephole); must not be null
     * @param configuration layer configuration; must request at least one of full sequence, last H or last C
     * @param dLdh          presumably the loss gradient w.r.t. the full output sequence (TODO confirm; may be null)
     * @param dLdhL         presumably the loss gradient w.r.t. the last output (TODO confirm; may be null)
     * @param dLdcL         presumably the loss gradient w.r.t. the last cell state (TODO confirm; may be null)
     * @throws NullPointerException  if {@code sameDiff}, {@code x}, {@code weights} or {@code configuration} is null
     * @throws IllegalStateException if the configuration requests none of the possible outputs
     */
    public LSTMLayerBp(@NonNull SameDiff sameDiff, @NonNull SDVariable x, SDVariable cLast, SDVariable yLast, SDVariable maxTSLength, @NonNull LSTMLayerWeights weights, @NonNull LSTMLayerConfig configuration, SDVariable dLdh, SDVariable dLdhL, SDVariable dLdcL) {
        // Validation happens inside the super(...) argument list: it must run before `weights` is
        // dereferenced and before the superclass constructor receives the graph and inputs.
        super("lstmLayer_bp", sameDiff, validateAndCollectInputs(sameDiff, x, cLast, yLast, maxTSLength, weights, configuration, dLdh, dLdhL, dLdcL));
        this.configuration = configuration;
        this.weights = weights;
        this.cLast = cLast;
        this.yLast = yLast;
        this.maxTSLength = maxTSLength;
        this.addIArgument(this.iArgs());
        this.addTArgument(this.tArgs());
        this.addBArgument(this.bArgs(weights, maxTSLength, yLast, cLast));
    }

    /**
     * Checks the constructor arguments and builds the native op's input array.
     *
     * <p>Static so it can be invoked from the {@code super(...)} call. It throws the same exception
     * types and messages that the constructor documents.
     */
    private static SDVariable[] validateAndCollectInputs(SameDiff sameDiff, SDVariable x, SDVariable cLast, SDVariable yLast, SDVariable maxTSLength, LSTMLayerWeights weights, LSTMLayerConfig configuration, SDVariable dLdh, SDVariable dLdhL, SDVariable dLdcL) {
        if (sameDiff == null) {
            throw new NullPointerException("sameDiff is marked non-null but is null");
        }
        if (x == null) {
            throw new NullPointerException("x is marked non-null but is null");
        }
        if (weights == null) {
            throw new NullPointerException("weights is marked non-null but is null");
        }
        if (configuration == null) {
            throw new NullPointerException("configuration is marked non-null but is null");
        }
        Preconditions.checkState(configuration.isRetLastH() || configuration.isRetLastC() || configuration.isRetFullSequence(),
                "You have to specify at least one output you want to return. Use isRetLastC, isRetLastH and isRetFullSequence methods in LSTMLayerConfig builder to specify them");
        return wrapFilterNull(x, weights.getWeights(), weights.getRWeights(), weights.getBias(), maxTSLength, yLast, cLast, weights.getPeepholeWeights(), dLdh, dLdhL, dLdcL);
    }

    /**
     * Returns one datatype per output, all equal to the datatype of input 1.
     *
     * <p>Input 1 is presumably the input weights (input 0 is {@code x}). The number of entries always
     * matches {@link #getNumOutputs()}.
     *
     * @throws IllegalStateException if fewer than two input datatypes are supplied, or input 1 is not floating point
     */
    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
        Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() > 1,
                "Expected at least 2 input datatypes for %s, got %s", getClass(), inputDataTypes);
        DataType dt = inputDataTypes.get(1);
        Preconditions.checkState(dt.isFPType(), "Input type 1 must be a floating point type, got %s", dt);
        // Derive the count from getNumOutputs() so the two can never disagree.
        int numOutputs = getNumOutputs();
        List<DataType> list = new ArrayList<>(numOutputs);
        for (int i = 0; i < numOutputs; i++) {
            list.add(dt);
        }
        return list;
    }

    /** Returns the native op name, {@code "lstmLayer_bp"}. */
    @Override
    public String opName() {
        return "lstmLayer_bp";
    }

    /** Returns the layer configuration as a property map via {@code LSTMLayerConfig.toProperties(true, true)}. */
    @Override
    public Map<String, Object> propertiesForFunction() {
        return this.configuration.toProperties(true, true);
    }

    /**
     * Integer arguments: ordinals of data format, direction mode, gate activation, output activation
     * and cell activation, in that order.
     *
     * <p>These are enum ordinals, so the declaration order of those enums presumably has to match what
     * the native op expects (TODO confirm before reordering any enum constants).
     */
    @Override
    public long[] iArgs() {
        return new long[]{this.configuration.getLstmdataformat().ordinal(), this.configuration.getDirectionMode().ordinal(), this.configuration.getGateAct().ordinal(), this.configuration.getOutAct().ordinal(), this.configuration.getCellAct().ordinal()};
    }

    /** Floating-point arguments: the configured cell clip value. */
    @Override
    public double[] tArgs() {
        return new double[]{this.configuration.getCellClip()};
    }

    /**
     * Boolean arguments, in order: has bias, has maxTSLength, has yLast, has cLast, has peephole
     * weights, return full sequence, return last H, return last C.
     */
    protected <T> boolean[] bArgs(LSTMLayerWeights weights, T maxTSLength, T yLast, T cLast) {
        return new boolean[]{weights.hasBias(), maxTSLength != null, yLast != null, cLast != null, weights.hasPH(), this.configuration.isRetFullSequence(), this.configuration.isRetLastH(), this.configuration.isRetLastC()};
    }

    /** Indicates that this op's properties are held in a single config object (see {@link #configFieldName()}). */
    @Override
    public boolean isConfigProperties() {
        return true;
    }

    /** Name of the field holding the config object: {@code "configuration"}. */
    @Override
    public String configFieldName() {
        return "configuration";
    }

    /** Three mandatory outputs plus one per present optional input (bias, maxTSLength, yLast, cLast, peephole). */
    @Override
    public int getNumOutputs() {
        return Booleans.countTrue(true, true, true, this.weights.hasBias(), this.maxTSLength != null, this.yLast != null, this.cLast != null, this.weights.hasPH());
    }

    /**
     * No-argument constructor, presumably for reflective instantiation/deserialization (TODO confirm).
     * All fields remain null, so methods that read {@code weights} or {@code configuration} fail until
     * those fields are populated.
     */
    public LSTMLayerBp() {
    }

    /** Returns the layer configuration (null if created via the no-argument constructor). */
    public LSTMLayerConfig getConfiguration() {
        return this.configuration;
    }

    /** Returns the layer weights (null if created via the no-argument constructor). */
    public LSTMLayerWeights getWeights() {
        return this.weights;
    }
}

