/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.nn.recurrent;

import ai.djl.ndarray.NDList;
import ai.djl.nn.ParameterBlock;
import ai.djl.nn.recurrent.RNN;

/**
 * Base class for recurrent blocks, holding the configuration shared by recurrent cell types.
 *
 * <p>Configuration is copied from a {@link BaseBuilder} at construction time.
 */
public abstract class RecurrentCell
extends ParameterBlock {
    /** Hidden-state size copied from the builder ({@code -1} if the builder never set it). */
    protected long stateSize;
    /** Dropout rate copied from the builder ({@code 0} if the builder never set it). */
    protected float dropRate;
    /** Number of stacked layers copied from the builder ({@code -1} if the builder never set it). */
    protected int numStackedLayers;
    /**
     * Operator mode string. It is not assigned anywhere in this class.
     * NOTE(review): presumably each concrete subclass sets it — TODO confirm.
     */
    protected String mode;
    /** Whether inputs carry an extra sequence-length array (see {@link #validateInputSize}). */
    protected boolean useSequenceLength;
    /** Whether the cell is configured as bidirectional. */
    protected boolean useBidirectional;
    /** Whether state outputs were requested from the builder. */
    protected boolean stateOutputs;

    /**
     * Copies the shared recurrent configuration out of {@code builder}.
     *
     * <p>Only the fields declared on this class are copied. The builder's activation and
     * LSTM clipping settings are left for subclasses to read.
     *
     * @param builder the builder holding the configuration
     */
    public RecurrentCell(BaseBuilder<?> builder) {
        this.stateSize = builder.stateSize;
        this.dropRate = builder.dropRate;
        this.numStackedLayers = builder.numStackedLayers;
        this.useSequenceLength = builder.useSequenceLength;
        this.useBidirectional = builder.useBidirectional;
        this.stateOutputs = builder.stateOutputs;
    }

    /**
     * Checks that {@code inputs} holds exactly the number of arrays this cell expects.
     *
     * <p>That is one array normally, or two when {@code useSequenceLength} is enabled.
     *
     * @param inputs the block inputs to check
     * @throws IllegalArgumentException if the input count does not match
     */
    protected void validateInputSize(NDList inputs) {
        // With sequence lengths enabled, a second array accompanies the data input.
        int numberofInputsRequired = this.useSequenceLength ? 2 : 1;
        if (inputs.size() != numberofInputsRequired) {
            throw new IllegalArgumentException("Invalid number of inputs for RNN. Size of input NDList must be " + numberofInputsRequired + " when useSequenceLength is " + this.useSequenceLength);
        }
    }

    /**
     * Shared fluent builder for {@link RecurrentCell} subclasses.
     *
     * @param <T> the concrete builder type, returned from every setter for chaining
     */
    public static abstract class BaseBuilder<T extends BaseBuilder<T>> {
        /** Dropout rate; Java default {@code 0} until set. */
        protected float dropRate;
        /** Hidden-state size; {@code -1} until {@link #setStateSize(int)} is called. */
        protected long stateSize = -1L;
        /** Number of stacked layers; {@code -1} until {@link #setNumStackedLayers(int)} is called. */
        protected int numStackedLayers = -1;
        /** Lower LSTM state clip bound; only meaningful when {@link #clipLstmState} is true. */
        protected double lstmStateClipMin;
        /** Upper LSTM state clip bound; only meaningful when {@link #clipLstmState} is true. */
        protected double lstmStateClipMax;
        /** Set to true by {@link #optLstmStateClipMin(float, float)}. */
        protected boolean clipLstmState;
        /** Whether a sequence-length input accompanies the data. */
        protected boolean useSequenceLength;
        /** Whether the cell is bidirectional. */
        protected boolean useBidirectional;
        /** Whether state outputs are requested. */
        protected boolean stateOutputs;
        /** Activation choice; not consumed by {@link RecurrentCell} itself. */
        protected RNN.Activation activation;

        /**
         * Sets the dropout rate.
         *
         * @param dropRate the dropout rate
         * @return this builder
         */
        public T optDropRate(float dropRate) {
            this.dropRate = dropRate;
            return this.self();
        }

        /**
         * Enables LSTM state clipping and sets both bounds.
         *
         * <p>Despite the name, this sets both the minimum and the maximum bound.
         *
         * @param lstmStateClipMin the lower clip bound
         * @param lstmStateClipMax the upper clip bound
         * @return this builder
         */
        public T optLstmStateClipMin(float lstmStateClipMin, float lstmStateClipMax) {
            this.lstmStateClipMin = lstmStateClipMin;
            this.lstmStateClipMax = lstmStateClipMax;
            this.clipLstmState = true;
            return this.self();
        }

        /**
         * Sets the hidden-state size.
         *
         * @param stateSize the state size
         * @return this builder
         */
        public T setStateSize(int stateSize) {
            this.stateSize = stateSize;
            return this.self();
        }

        /**
         * Sets the number of stacked layers.
         *
         * @param numStackedLayers the number of layers
         * @return this builder
         */
        public T setNumStackedLayers(int numStackedLayers) {
            this.numStackedLayers = numStackedLayers;
            return this.self();
        }

        /**
         * Sets the activation.
         *
         * @param activation the activation choice
         * @return this builder
         */
        public T setActivation(RNN.Activation activation) {
            this.activation = activation;
            return this.self();
        }

        /**
         * Sets whether a sequence-length array is supplied alongside the data input.
         *
         * @param useSequenceLength true if a sequence-length input is supplied
         * @return this builder
         */
        public T setSequenceLength(boolean useSequenceLength) {
            this.useSequenceLength = useSequenceLength;
            return this.self();
        }

        /**
         * Sets whether the cell is bidirectional.
         *
         * <p>The name is misspelled but kept for compatibility. Prefer
         * {@link #optBidirectional(boolean)}.
         *
         * @param useBidirectional true for a bidirectional cell
         * @return this builder
         */
        public T optBidrectional(boolean useBidirectional) {
            this.useBidirectional = useBidirectional;
            return this.self();
        }

        /**
         * Sets whether the cell is bidirectional.
         *
         * <p>Correctly spelled equivalent of {@link #optBidrectional(boolean)}.
         *
         * @param useBidirectional true for a bidirectional cell
         * @return this builder
         */
        public T optBidirectional(boolean useBidirectional) {
            return this.optBidrectional(useBidirectional);
        }

        /**
         * Sets whether state outputs are requested.
         *
         * @param stateOutputs true to request state outputs
         * @return this builder
         */
        public T optStateOutput(boolean stateOutputs) {
            this.stateOutputs = stateOutputs;
            return this.self();
        }

        /**
         * Returns this builder typed as the concrete subclass, enabling fluent chaining.
         *
         * @return this builder
         */
        protected abstract T self();
    }
}

