/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.modality.nlp;

import ai.djl.MalformedModelException;
import ai.djl.modality.nlp.Decoder;
import ai.djl.modality.nlp.Encoder;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Arrays;
import java.util.Objects;

/**
 * A block that chains an {@link Encoder} and a {@link Decoder} as two child blocks.
 *
 * <p>The training forward pass works in four steps:
 * <ol>
 *   <li>run the encoder on the source data;</li>
 *   <li>obtain its states through {@code Encoder#getStates};</li>
 *   <li>append those states after the labels to form the decoder input;</li>
 *   <li>return the decoder output.</li>
 * </ol>
 * Forwarding without labels (prediction) is not implemented and throws.
 */
public class EncoderDecoder extends AbstractBlock {
    /** Version of this block, handed to the {@link AbstractBlock} constructor. */
    private static final byte VERSION = 1;
    /** Child block run on the encoder input; its states are forwarded to the decoder. */
    protected Encoder encoder;
    /** Child block run on the labels followed by the encoder states. */
    protected Decoder decoder;

    /**
     * Creates an encoder-decoder block and registers both parts as child blocks.
     *
     * @param encoder the encoder, registered under the child name {@code "Encoder"}
     * @param decoder the decoder, registered under the child name {@code "Decoder"}
     */
    public EncoderDecoder(Encoder encoder, Decoder decoder) {
        super(VERSION);
        this.encoder = this.addChildBlock("Encoder", encoder);
        this.decoder = this.addChildBlock("Decoder", decoder);
        this.inputNames = Arrays.asList("encoderInput", "decoderInput");
    }

    /**
     * Pairs the input names ({@code encoderInput}, {@code decoderInput}) with the input shapes
     * recorded during initialization.
     *
     * @return the input names paired with their shapes
     * @throws IllegalStateException if the block has not been initialized yet
     */
    @Override
    public PairList<String, Shape> describeInput() {
        if (!this.isInitialized()) {
            throw new IllegalStateException("Parameter of this block are not initialised");
        }
        return new PairList<String, Shape>(this.inputNames, Arrays.asList(this.inputShapes));
    }

    /**
     * Forward pass without labels. It is not supported yet.
     *
     * @throws IllegalArgumentException if {@code training} is true, because training needs the
     *     labels overload
     * @throws UnsupportedOperationException otherwise, because prediction is not implemented
     */
    @Override
    protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, PairList<String, Object> params) {
        if (training) {
            throw new IllegalArgumentException("You must use forward with labels when training");
        }
        throw new UnsupportedOperationException("EncoderDecoder prediction has not been implemented yet");
    }

    /**
     * Delegates to the four-argument label-less overload with {@code params = null}.
     *
     * <p>NOTE(review): that overload is inherited from {@link AbstractBlock}. It presumably
     * dispatches to {@link #forwardInternal}, which throws; confirm against AbstractBlock.
     */
    @Override
    public NDList forward(ParameterStore parameterStore, NDList inputs, boolean training) {
        return this.forward(parameterStore, inputs, training, null);
    }

    /**
     * Runs the training forward pass.
     *
     * <p>Steps:
     * <ol>
     *   <li>The encoder is run on {@code data}.</li>
     *   <li>The encoder's states are appended after the elements of {@code labels} in a new list.</li>
     *   <li>That list is passed to the decoder.</li>
     * </ol>
     * Both child blocks are invoked with {@code training = true}. The caller's {@code labels}
     * list is left unmodified.
     *
     * @param parameterStore the parameter store passed to both child blocks
     * @param data the encoder input
     * @param labels the leading part of the decoder input; must not be null
     * @param params additional parameters passed to both child blocks
     * @return the decoder output
     * @throws NullPointerException if {@code labels} is null
     */
    @Override
    public NDList forward(ParameterStore parameterStore, NDList data, NDList labels, PairList<String, Object> params) {
        Objects.requireNonNull(labels, "labels");
        NDList encoderOutputs = this.encoder.forward(parameterStore, data, true, params);
        // Build a fresh decoder input instead of appending to the caller's labels in place.
        // Appending in place would leak encoder states into the caller's list, and the states
        // would accumulate if that list were forwarded again.
        NDList decoderInputs = new NDList();
        decoderInputs.addAll(labels);
        decoderInputs.addAll(this.encoder.getStates(encoderOutputs));
        return this.decoder.forward(parameterStore, decoderInputs, true, params);
    }

    /**
     * Initializes both child blocks.
     *
     * <p>The encoder is initialized with {@code inputShapes[0]} and the decoder with
     * {@code inputShapes[1]}.
     *
     * @param manager the manager used for parameter initialization
     * @param dataType the data type of the parameters
     * @param inputShapes the encoder input shape followed by the decoder input shape
     * @return the shapes returned by the decoder's initialization
     * @throws IllegalArgumentException if fewer than two input shapes are supplied
     */
    @Override
    public Shape[] initialize(NDManager manager, DataType dataType, Shape ... inputShapes) {
        // Validate before beforeInitialize so a bad call does not leave partially recorded state.
        checkInputShapes(inputShapes);
        this.beforeInitialize(inputShapes);
        this.encoder.initialize(manager, dataType, inputShapes[0]);
        return this.decoder.initialize(manager, dataType, inputShapes[1]);
    }

    /**
     * Returns the decoder's output shapes computed from the decoder input shape alone.
     *
     * @param manager the manager passed to the decoder
     * @param inputShapes the encoder input shape followed by the decoder input shape
     * @return the decoder's output shapes for {@code inputShapes[1]}
     * @throws IllegalArgumentException if fewer than two input shapes are supplied
     */
    @Override
    public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
        checkInputShapes(inputShapes);
        return this.decoder.getOutputShapes(manager, new Shape[]{inputShapes[1]});
    }

    /**
     * Writes the encoder's parameters, then the decoder's.
     *
     * <p>Nothing else is written by this method.
     *
     * @param os the stream to write to
     * @throws IOException if writing fails
     */
    @Override
    public void saveParameters(DataOutputStream os) throws IOException {
        this.encoder.saveParameters(os);
        this.decoder.saveParameters(os);
    }

    /**
     * Reads the encoder's parameters, then the decoder's.
     *
     * <p>This is the same order used by {@link #saveParameters}.
     *
     * @param manager the manager used to allocate loaded parameters
     * @param is the stream to read from
     * @throws IOException if reading fails
     * @throws MalformedModelException if a child block rejects the stored data
     */
    @Override
    public void loadParameters(NDManager manager, DataInputStream is) throws IOException, MalformedModelException {
        this.encoder.loadParameters(manager, is);
        this.decoder.loadParameters(manager, is);
    }

    /** Ensures an encoder shape and a decoder shape are both present. */
    private static void checkInputShapes(Shape[] inputShapes) {
        int count = inputShapes == null ? 0 : inputShapes.length;
        if (count < 2) {
            throw new IllegalArgumentException(
                    "EncoderDecoder requires two input shapes (encoder, decoder), got " + count);
        }
    }
}

