/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.mxnet.zoo.nlp.qa;

import ai.djl.Model;
import ai.djl.modality.nlp.DefaultVocabulary;
import ai.djl.modality.nlp.Vocabulary;
import ai.djl.modality.nlp.bert.BertToken;
import ai.djl.modality.nlp.bert.BertTokenizer;
import ai.djl.modality.nlp.qa.QAInput;
import ai.djl.modality.nlp.translator.QATranslator;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.translate.Batchifier;
import ai.djl.translate.TranslatorContext;
import ai.djl.util.Utils;
import com.google.gson.Gson;
import com.google.gson.annotations.SerializedName;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Locale;
import java.util.stream.Collectors;

public class MxBertQATranslator
extends QATranslator {
    private List<String> tokens;
    private Vocabulary vocabulary;
    private BertTokenizer tokenizer;
    private int seqLength;

    MxBertQATranslator(Builder builder) {
        super((QATranslator.BaseBuilder)builder);
        this.seqLength = builder.seqLength;
    }

    public void prepare(TranslatorContext ctx) throws IOException {
        Model model = ctx.getModel();
        this.vocabulary = DefaultVocabulary.builder().addFromCustomizedFile(model.getArtifact("vocab.json"), VocabParser::parseToken).optUnknownToken("[UNK]").build();
        this.tokenizer = new BertTokenizer();
    }

    public Batchifier getBatchifier() {
        return null;
    }

    public NDList processInput(TranslatorContext ctx, QAInput input) {
        BertToken token = this.tokenizer.encode(input.getQuestion().toLowerCase(), input.getParagraph().toLowerCase(), this.seqLength);
        this.tokens = token.getTokens();
        List indices = token.getTokens().stream().map(arg_0 -> ((Vocabulary)this.vocabulary).getIndex(arg_0)).collect(Collectors.toList());
        float[] indexesFloat = Utils.toFloatArray(indices);
        float[] types = Utils.toFloatArray((List)token.getTokenTypes());
        int validLength = token.getValidLength();
        NDManager manager = ctx.getNDManager();
        NDArray data0 = manager.create(indexesFloat);
        data0.setName("data0");
        NDArray data1 = manager.create(types);
        data1.setName("data1");
        NDArray data2 = manager.create(new float[]{validLength});
        data2.setName("data2");
        return new NDList(new NDArray[]{data0, data1, data2});
    }

    public String processOutput(TranslatorContext ctx, NDList list) {
        NDArray array = list.singletonOrThrow();
        NDList output = array.split(2L, 2);
        NDArray startLogits = ((NDArray)output.get(0)).reshape(new Shape(new long[]{1L, -1L}));
        NDArray endLogits = ((NDArray)output.get(1)).reshape(new Shape(new long[]{1L, -1L}));
        int startIdx = (int)startLogits.argMax(1).getLong(new long[0]);
        int endIdx = (int)endLogits.argMax(1).getLong(new long[0]);
        return this.tokens.subList(startIdx, endIdx + 1).toString();
    }

    public static Builder builder() {
        return new Builder();
    }

    private static final class VocabParser {
        @SerializedName(value="idx_to_token")
        List<String> idx2token;

        private VocabParser() {
        }

        /*
         * Exception decompiling
         */
        public static List<String> parseToken(URL url) {
            /*
             * This method has failed to decompile.  When submitting a bug report, please provide this stack trace, and (if you hold appropriate legal rights) the relevant class file.
             * 
             * org.benf.cfr.reader.util.ConfusedCFRException: Started 3 blocks at once
             *     at org.benf.cfr.reader.bytecode.analysis.opgraph.Op04StructuredStatement.getStartingBlocks(Op04StructuredStatement.java:412)
             *     at org.benf.cfr.reader.bytecode.analysis.opgraph.Op04StructuredStatement.buildNestedBlocks(Op04StructuredStatement.java:487)
             *     at org.benf.cfr.reader.bytecode.analysis.opgraph.Op03SimpleStatement.createInitialStructuredBlock(Op03SimpleStatement.java:736)
             *     at org.benf.cfr.reader.bytecode.CodeAnalyser.getAnalysisInner(CodeAnalyser.java:850)
             *     at org.benf.cfr.reader.bytecode.CodeAnalyser.getAnalysisOrWrapFail(CodeAnalyser.java:278)
             *     at org.benf.cfr.reader.bytecode.CodeAnalyser.getAnalysis(CodeAnalyser.java:201)
             *     at org.benf.cfr.reader.entities.attributes.AttributeCode.analyse(AttributeCode.java:94)
             *     at org.benf.cfr.reader.entities.Method.analyse(Method.java:531)
             *     at org.benf.cfr.reader.entities.ClassFile.analyseMid(ClassFile.java:1055)
             *     at org.benf.cfr.reader.entities.ClassFile.analyseInnerClassesPass1(ClassFile.java:923)
             *     at org.benf.cfr.reader.entities.ClassFile.analyseMid(ClassFile.java:1035)
             *     at org.benf.cfr.reader.entities.ClassFile.analyseTop(ClassFile.java:942)
             *     at org.benf.cfr.reader.Driver.doJarVersionTypes(Driver.java:257)
             *     at org.benf.cfr.reader.Driver.doJar(Driver.java:139)
             *     at org.benf.cfr.reader.CfrDriverImpl.analyse(CfrDriverImpl.java:76)
             *     at org.benf.cfr.reader.Main.main(Main.java:54)
             */
            throw new IllegalStateException("Decompilation failed");
        }
    }

    public static class Builder
    extends QATranslator.BaseBuilder<Builder> {
        private int seqLength;

        public Builder setSeqLength(int seqLength) {
            this.seqLength = seqLength;
            return this.self();
        }

        protected Builder self() {
            return this;
        }

        protected MxBertQATranslator build() {
            if (this.seqLength == 0) {
                throw new IllegalArgumentException("You must specify a seqLength with value > 0");
            }
            return new MxBertQATranslator(this);
        }
    }
}

