/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.samediff.ops;

import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.ops.SDOps;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.SRU;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.SRUCell;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.GRUCellOutputs;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.LSTMCellOutputs;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.LSTMLayerOutputs;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.SRUCellOutputs;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.SRULayerOutputs;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.GRUWeights;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.SRUWeights;

/**
 * Factory methods for recurrent-network ops (GRU, LSTM, SRU) on a {@link SameDiff} graph.
 *
 * <p>Each method validates its required arguments, constructs the corresponding op against
 * {@code this.sd}, and wraps the op's output variables in the matching {@code *Outputs} holder.
 * Overloads taking a {@code baseName} pass it to {@code outputVariables(baseName)} so the
 * outputs are named after it.
 */
public class SDRNN extends SDOps {

    /** Creates the RNN op namespace bound to the given graph. */
    public SDRNN(SameDiff sameDiff) {
        super(sameDiff);
    }

    /**
     * Throws a {@link NullPointerException} when {@code value} is null, using the same message
     * format for every parameter in this class.
     *
     * @param value the argument to check
     * @param name  the parameter name used in the exception message
     */
    private static void checkNotNull(Object value, String name) {
        if (value == null) {
            throw new NullPointerException(name + " is marked @NonNull but is null");
        }
    }

    /**
     * Adds a {@link GRUCell} op to the graph.
     *
     * @param x       the cell input
     * @param hLast   the previous hidden state (presumably; TODO confirm against GRUCell)
     * @param weights the GRU weights
     * @return the op's outputs wrapped in {@link GRUCellOutputs}
     * @throws NullPointerException if any argument is null
     */
    public GRUCellOutputs gru(@NonNull SDVariable x, @NonNull SDVariable hLast, @NonNull GRUWeights weights) {
        checkNotNull(x, "x");
        checkNotNull(hLast, "hLast");
        checkNotNull(weights, "weights");
        GRUCell c = new GRUCell(this.sd, x, hLast, weights);
        return new GRUCellOutputs(c.outputVariables());
    }

    /**
     * Adds a {@link GRUCell} op to the graph, naming its outputs after {@code baseName}.
     *
     * @param baseName the base name passed to {@code outputVariables}
     * @param x        the cell input
     * @param hLast    the previous hidden state (presumably; TODO confirm against GRUCell)
     * @param weights  the GRU weights
     * @return the op's outputs wrapped in {@link GRUCellOutputs}
     * @throws NullPointerException if {@code x}, {@code hLast} or {@code weights} is null
     */
    public GRUCellOutputs gru(String baseName, @NonNull SDVariable x, @NonNull SDVariable hLast, @NonNull GRUWeights weights) {
        checkNotNull(x, "x");
        checkNotNull(hLast, "hLast");
        checkNotNull(weights, "weights");
        GRUCell c = new GRUCell(this.sd, x, hLast, weights);
        return new GRUCellOutputs(c.outputVariables(baseName));
    }

    /**
     * Adds an {@link LSTMBlockCell} op to the graph.
     *
     * <p>{@code weights} and {@code config} are now validated up front, matching the
     * {@code baseName} overload, instead of being passed through to the op unchecked.
     *
     * @param x       the cell input
     * @param cLast   the previous cell state (presumably; TODO confirm against LSTMBlockCell)
     * @param yLast   the previous output (presumably; TODO confirm against LSTMBlockCell)
     * @param weights the LSTM weights
     * @param config  the LSTM configuration
     * @return the op's outputs wrapped in {@link LSTMCellOutputs}
     * @throws NullPointerException if any argument is null
     */
    public LSTMCellOutputs lstmCell(@NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast, @NonNull LSTMWeights weights, @NonNull LSTMConfiguration config) {
        checkNotNull(x, "x");
        checkNotNull(cLast, "cLast");
        checkNotNull(yLast, "yLast");
        checkNotNull(weights, "weights");
        checkNotNull(config, "config");
        LSTMBlockCell c = new LSTMBlockCell(this.sd, x, cLast, yLast, weights, config);
        return new LSTMCellOutputs(c.outputVariables());
    }

    /**
     * Adds an {@link LSTMBlockCell} op to the graph, naming its outputs after {@code baseName}.
     *
     * @param baseName the base name passed to {@code outputVariables}
     * @param x        the cell input
     * @param cLast    the previous cell state (presumably; TODO confirm against LSTMBlockCell)
     * @param yLast    the previous output (presumably; TODO confirm against LSTMBlockCell)
     * @param weights  the LSTM weights
     * @param config   the LSTM configuration
     * @return the op's outputs wrapped in {@link LSTMCellOutputs}
     * @throws NullPointerException if any argument other than {@code baseName} is null
     */
    public LSTMCellOutputs lstmCell(String baseName, @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast, @NonNull LSTMWeights weights, @NonNull LSTMConfiguration config) {
        checkNotNull(x, "x");
        checkNotNull(cLast, "cLast");
        checkNotNull(yLast, "yLast");
        checkNotNull(weights, "weights");
        checkNotNull(config, "config");
        LSTMBlockCell c = new LSTMBlockCell(this.sd, x, cLast, yLast, weights, config);
        return new LSTMCellOutputs(c.outputVariables(baseName));
    }

    /**
     * Adds an {@link LSTMLayer} op to the graph.
     *
     * @param maxTSLength graph variable passed to the op as its max-length input
     *                    (presumably the maximum time-series length; TODO confirm)
     * @param x           the layer input
     * @param cLast       the initial cell state (presumably; TODO confirm against LSTMLayer)
     * @param yLast       the initial output (presumably; TODO confirm against LSTMLayer)
     * @param weights     the LSTM weights
     * @param config      the LSTM configuration; its data format is passed to the outputs holder
     * @return the op's outputs wrapped in {@link LSTMLayerOutputs}
     * @throws NullPointerException if any argument is null
     */
    public LSTMLayerOutputs lstmLayer(@NonNull SDVariable maxTSLength, @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast, @NonNull LSTMWeights weights, @NonNull LSTMConfiguration config) {
        checkNotNull(maxTSLength, "maxTSLength");
        checkNotNull(x, "x");
        checkNotNull(cLast, "cLast");
        checkNotNull(yLast, "yLast");
        checkNotNull(weights, "weights");
        checkNotNull(config, "config");
        LSTMLayer c = new LSTMLayer(this.sd, maxTSLength, x, cLast, yLast, weights, config);
        return new LSTMLayerOutputs(c.outputVariables(), config.getDataFormat());
    }

    /**
     * Adds an {@link LSTMLayer} op, first creating a scalar graph variable holding
     * {@code maxTSLength}.
     *
     * @param maxTSLength value stored in the generated scalar variable
     * @param x           the layer input
     * @param cLast       the initial cell state (presumably; TODO confirm against LSTMLayer)
     * @param yLast       the initial output (presumably; TODO confirm against LSTMLayer)
     * @param weights     the LSTM weights
     * @param config      the LSTM configuration
     * @return the op's outputs wrapped in {@link LSTMLayerOutputs}
     * @throws NullPointerException if any object argument is null
     */
    public LSTMLayerOutputs lstmLayer(int maxTSLength, @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast, @NonNull LSTMWeights weights, @NonNull LSTMConfiguration config) {
        // Validate before creating the scalar so a failed call leaves no stray variable behind.
        checkNotNull(x, "x");
        checkNotNull(cLast, "cLast");
        checkNotNull(yLast, "yLast");
        checkNotNull(weights, "weights");
        checkNotNull(config, "config");
        // A fixed name would collide on the second call within the same graph; request a
        // distinct name, as the baseName overload already does.
        String scalarName = this.sd.generateDistinctCustomVariableName("lstm_max_ts_length");
        return this.lstmLayer(this.sd.scalar(scalarName, maxTSLength), x, cLast, yLast, weights, config);
    }

    /**
     * Adds an {@link LSTMLayer} op, first creating a scalar graph variable holding
     * {@code maxTSLength}.
     *
     * <p>When {@code baseName} is non-null, the scalar is named from
     * {@code baseName + "_max_ts_length"} and the outputs are named after {@code baseName}.
     * When it is null, this delegates to the unnamed {@code int} overload.
     *
     * @param baseName    optional base name for the generated scalar and the outputs
     * @param maxTSLength value stored in the generated scalar variable
     * @param x           the layer input
     * @param cLast       the initial cell state (presumably; TODO confirm against LSTMLayer)
     * @param yLast       the initial output (presumably; TODO confirm against LSTMLayer)
     * @param weights     the LSTM weights
     * @param config      the LSTM configuration
     * @return the op's outputs wrapped in {@link LSTMLayerOutputs}
     * @throws NullPointerException if any object argument other than {@code baseName} is null
     */
    public LSTMLayerOutputs lstmLayer(String baseName, int maxTSLength, @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast, @NonNull LSTMWeights weights, @NonNull LSTMConfiguration config) {
        checkNotNull(x, "x");
        checkNotNull(cLast, "cLast");
        checkNotNull(yLast, "yLast");
        checkNotNull(weights, "weights");
        checkNotNull(config, "config");
        if (baseName == null) {
            return this.lstmLayer(maxTSLength, x, cLast, yLast, weights, config);
        }
        String scalarName = this.sd.generateDistinctCustomVariableName(baseName + "_max_ts_length");
        return this.lstmLayer(baseName, this.sd.scalar(scalarName, maxTSLength), x, cLast, yLast, weights, config);
    }

    /**
     * Adds an {@link LSTMLayer} op to the graph, naming its outputs after {@code baseName}.
     *
     * @param baseName    the base name passed to {@code outputVariables}
     * @param maxTSLength graph variable passed to the op as its max-length input
     *                    (presumably the maximum time-series length; TODO confirm)
     * @param x           the layer input
     * @param cLast       the initial cell state (presumably; TODO confirm against LSTMLayer)
     * @param yLast       the initial output (presumably; TODO confirm against LSTMLayer)
     * @param weights     the LSTM weights
     * @param config      the LSTM configuration; its data format is passed to the outputs holder
     * @return the op's outputs wrapped in {@link LSTMLayerOutputs}
     * @throws NullPointerException if any argument other than {@code baseName} is null
     */
    public LSTMLayerOutputs lstmLayer(String baseName, @NonNull SDVariable maxTSLength, @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast, @NonNull LSTMWeights weights, @NonNull LSTMConfiguration config) {
        checkNotNull(maxTSLength, "maxTSLength");
        checkNotNull(x, "x");
        checkNotNull(cLast, "cLast");
        checkNotNull(yLast, "yLast");
        checkNotNull(weights, "weights");
        checkNotNull(config, "config");
        LSTMLayer c = new LSTMLayer(this.sd, maxTSLength, x, cLast, yLast, weights, config);
        return new LSTMLayerOutputs(c.outputVariables(baseName), config.getDataFormat());
    }

    /**
     * Adds an {@link SRUCell} op to the graph.
     *
     * @param x       the cell input
     * @param cLast   the previous cell state (presumably; TODO confirm against SRUCell)
     * @param weights the SRU weights
     * @return the op's outputs wrapped in {@link SRUCellOutputs}
     * @throws NullPointerException if any argument is null
     */
    public SRUCellOutputs sruCell(@NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SRUWeights weights) {
        checkNotNull(x, "x");
        checkNotNull(cLast, "cLast");
        checkNotNull(weights, "weights");
        return new SRUCellOutputs(new SRUCell(this.sd, x, cLast, weights).outputVariables());
    }

    /**
     * Adds an {@link SRUCell} op to the graph, naming its outputs after {@code baseName}.
     *
     * @param baseName the base name passed to {@code outputVariables}
     * @param x        the cell input
     * @param cLast    the previous cell state (presumably; TODO confirm against SRUCell)
     * @param weights  the SRU weights
     * @return the op's outputs wrapped in {@link SRUCellOutputs}
     * @throws NullPointerException if {@code x}, {@code cLast} or {@code weights} is null
     */
    public SRUCellOutputs sruCell(String baseName, @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SRUWeights weights) {
        checkNotNull(x, "x");
        checkNotNull(cLast, "cLast");
        checkNotNull(weights, "weights");
        return new SRUCellOutputs(new SRUCell(this.sd, x, cLast, weights).outputVariables(baseName));
    }

    /**
     * Adds an {@link SRU} layer op to the graph with no mask.
     *
     * @param x        the layer input
     * @param initialC the initial cell state
     * @param weights  the SRU weights
     * @return the op's outputs wrapped in {@link SRULayerOutputs}
     * @throws NullPointerException if any argument is null
     */
    public SRULayerOutputs sru(@NonNull SDVariable x, @NonNull SDVariable initialC, @NonNull SRUWeights weights) {
        checkNotNull(x, "x");
        checkNotNull(initialC, "initialC");
        checkNotNull(weights, "weights");
        return this.sru(x, initialC, null, weights);
    }

    /**
     * Adds an {@link SRU} layer op to the graph with no mask, naming its outputs after
     * {@code baseName}.
     *
     * @param baseName the base name passed to {@code outputVariables}
     * @param x        the layer input
     * @param initialC the initial cell state
     * @param weights  the SRU weights
     * @return the op's outputs wrapped in {@link SRULayerOutputs}
     * @throws NullPointerException if {@code x}, {@code initialC} or {@code weights} is null
     */
    public SRULayerOutputs sru(String baseName, @NonNull SDVariable x, @NonNull SDVariable initialC, @NonNull SRUWeights weights) {
        checkNotNull(x, "x");
        checkNotNull(initialC, "initialC");
        checkNotNull(weights, "weights");
        return this.sru(baseName, x, initialC, null, weights);
    }

    /**
     * Adds an {@link SRU} layer op to the graph.
     *
     * @param x        the layer input
     * @param initialC the initial cell state
     * @param mask     optional mask; may be null
     * @param weights  the SRU weights
     * @return the op's outputs wrapped in {@link SRULayerOutputs}
     * @throws NullPointerException if {@code x}, {@code initialC} or {@code weights} is null
     */
    public SRULayerOutputs sru(@NonNull SDVariable x, @NonNull SDVariable initialC, SDVariable mask, @NonNull SRUWeights weights) {
        checkNotNull(x, "x");
        checkNotNull(initialC, "initialC");
        checkNotNull(weights, "weights");
        return new SRULayerOutputs(new SRU(this.sd, x, initialC, mask, weights).outputVariables());
    }

    /**
     * Adds an {@link SRU} layer op to the graph, naming its outputs after {@code baseName}.
     *
     * @param baseName the base name passed to {@code outputVariables}
     * @param x        the layer input
     * @param initialC the initial cell state
     * @param mask     optional mask; may be null
     * @param weights  the SRU weights
     * @return the op's outputs wrapped in {@link SRULayerOutputs}
     * @throws NullPointerException if {@code x}, {@code initialC} or {@code weights} is null
     */
    public SRULayerOutputs sru(String baseName, @NonNull SDVariable x, @NonNull SDVariable initialC, SDVariable mask, @NonNull SRUWeights weights) {
        checkNotNull(x, "x");
        checkNotNull(initialC, "initialC");
        checkNotNull(weights, "weights");
        return new SRULayerOutputs(new SRU(this.sd, x, initialC, mask, weights).outputVariables(baseName));
    }
}

