/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.llm.clients;

import ai.vespa.llm.InferenceParameters;
import ai.vespa.llm.LanguageModel;
import ai.vespa.llm.completion.Completion;
import ai.vespa.llm.completion.Prompt;
import ai.vespa.llm.generation.OnnxEncoderDecoderConfig;
import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import com.yahoo.api.annotations.Beta;
import com.yahoo.component.AbstractComponent;
import com.yahoo.component.annotation.Inject;
import com.yahoo.language.process.Embedder;
import com.yahoo.language.sentencepiece.SentencePieceEmbedder;
import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.PartialAddress;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.function.Consumer;

@Beta
public class OnnxEncoderDecoder
extends AbstractComponent
implements LanguageModel {
    private static final int TOKEN_EOS = 1;
    private static final String BATCH_DIMENSION = "d0";
    private static final String SEQUENCE_DIMENSION = "d1";
    private final int tokenizerMaxTokens;
    private final String encoderInputIdsName;
    private final String encoderAttentionMaskName;
    private final String encoderOutputName;
    private final String decoderInputIdsName;
    private final String decoderAttentionMaskName;
    private final String decoderEncoderHiddenStateName;
    private final String decoderOutputName;
    private final SentencePieceEmbedder tokenizer;
    private final OnnxEvaluator encoder;
    private final OnnxEvaluator decoder;

    @Inject
    public OnnxEncoderDecoder(OnnxRuntime onnx, OnnxEncoderDecoderConfig config) {
        this.tokenizer = new SentencePieceEmbedder.Builder(config.tokenizerModel().toString()).build();
        this.tokenizerMaxTokens = config.tokenizerMaxTokens();
        this.encoderInputIdsName = config.encoderModelInputIdsName();
        this.encoderAttentionMaskName = config.encoderModelAttentionMaskName();
        this.encoderOutputName = config.encoderModelOutputName();
        OnnxEvaluatorOptions encoderOptions = new OnnxEvaluatorOptions.Builder().setExecutionMode(config.encoderOnnxExecutionMode().toString()).setThreads(config.encoderOnnxInterOpThreads(), config.encoderOnnxIntraOpThreads()).build();
        this.encoder = onnx.evaluatorOf(config.encoderModel().toString(), encoderOptions);
        this.decoderInputIdsName = config.decoderModelInputIdsName();
        this.decoderAttentionMaskName = config.decoderModelAttentionMaskName();
        this.decoderEncoderHiddenStateName = config.decoderModelEncoderHiddenStateName();
        this.decoderOutputName = config.decoderModelOutputName();
        OnnxEvaluatorOptions decoderOptions = new OnnxEvaluatorOptions.Builder().setExecutionMode(config.decoderOnnxExecutionMode().toString()).setThreads(config.decoderOnnxInterOpThreads(), config.decoderOnnxIntraOpThreads()).build();
        this.decoder = onnx.evaluatorOf(config.decoderModel().toString(), decoderOptions);
        this.validateModels();
    }

    public String generate(String prompt, DecoderOptions options) {
        return switch (options.getSearchMethod()) {
            case DecoderOptions.SearchMethod.GREEDY -> this.generateGreedy(prompt, options);
            default -> this.generateNotImplemented(options);
        };
    }

    public void deconstruct() {
        this.encoder.close();
        this.decoder.close();
    }

    private String generateNotImplemented(DecoderOptions options) {
        throw new UnsupportedOperationException("Search method '" + String.valueOf((Object)options.getSearchMethod()) + "' is currently not implemented");
    }

    private String generateGreedy(String prompt, DecoderOptions options) {
        ArrayList<Integer> generatedTokens = new ArrayList<Integer>();
        generatedTokens.add(0);
        List<Integer> inputTokens = this.tokenize(prompt);
        Tensor encoderInput = OnnxEncoderDecoder.createTensorRepresentation(inputTokens, SEQUENCE_DIMENSION);
        Tensor encoderMask = OnnxEncoderDecoder.createAttentionMask(encoderInput).expand(BATCH_DIMENSION);
        Tensor encoderOutput = this.evaluateEncoder(encoderInput.expand(BATCH_DIMENSION), encoderMask);
        while (generatedTokens.size() < options.getMaxLength()) {
            Tensor decoderInput = OnnxEncoderDecoder.createTensorRepresentation(generatedTokens, SEQUENCE_DIMENSION).expand(BATCH_DIMENSION);
            IndexedTensor logits = this.evaluateDecoder(decoderInput, encoderMask, encoderOutput);
            int nextToken = OnnxEncoderDecoder.findMostProbableToken(logits, generatedTokens.size() - 1, BATCH_DIMENSION, SEQUENCE_DIMENSION);
            generatedTokens.add(nextToken);
        }
        return this.detokenize(generatedTokens);
    }

    private Tensor evaluateEncoder(Tensor input, Tensor mask) {
        Map<String, Tensor> encoderInputs = Map.of(this.encoderInputIdsName, input, this.encoderAttentionMaskName, mask);
        return this.encoder.evaluate(encoderInputs, this.encoderOutputName);
    }

    private IndexedTensor evaluateDecoder(Tensor input, Tensor encoderMask, Tensor encoderOutput) {
        Map<String, Tensor> inputs = Map.of(this.decoderInputIdsName, input, this.decoderAttentionMaskName, encoderMask, this.decoderEncoderHiddenStateName, encoderOutput);
        Tensor output = this.decoder.evaluate(inputs, this.decoderOutputName);
        if (!(output instanceof IndexedTensor)) {
            throw new IllegalArgumentException("Output of decoder model is not an 'IndexedTensor'");
        }
        IndexedTensor indexedTensor = (IndexedTensor)output;
        return indexedTensor;
    }

    private static int findMostProbableToken(IndexedTensor logits, int seqIndex, String batchDim, String seqDim) {
        if (logits.type().rank() != 3) {
            throw new IllegalArgumentException("Expected a tensor with rank 3: batch, sequence, and vocabulary size. Got: " + String.valueOf(logits.type()));
        }
        IndexedTensor.SubspaceIterator iterator = logits.cellIterator(new PartialAddress.Builder(2).add(batchDim, 0L).add(seqDim, (long)seqIndex).build(), DimensionSizes.of((TensorType)logits.type()));
        Double maxVal = iterator.next().getValue();
        int maxIndex = 0;
        int i = 1;
        while (iterator.hasNext()) {
            Double val = iterator.next().getValue();
            if (val >= maxVal && i != 1) {
                maxVal = val;
                maxIndex = i;
            }
            ++i;
        }
        return maxIndex;
    }

    private List<Integer> tokenize(String text) {
        List<Integer> tokens = this.tokenizer.embed(text, new Embedder.Context("tokenizer"));
        tokens = tokens.size() >= this.tokenizerMaxTokens ? tokens.subList(0, this.tokenizerMaxTokens - 1) : tokens;
        tokens.add(1);
        return tokens;
    }

    private String detokenize(List<Integer> tokens) {
        return this.tokenizer.decode(tokens, new Embedder.Context("tokenizer"), true);
    }

    private static Tensor createTensorRepresentation(List<Integer> tokens, String dimension) {
        int size = tokens.size();
        TensorType type = new TensorType.Builder(TensorType.Value.FLOAT).indexed(dimension, (long)size).build();
        IndexedTensor.Builder builder = IndexedTensor.Builder.of((TensorType)type);
        for (int i = 0; i < size; ++i) {
            builder.cell((float)tokens.get(i).intValue(), new long[]{i});
        }
        return builder.build();
    }

    private static Tensor createAttentionMask(Tensor d) {
        return d.map(x -> x > 0.0 ? 1.0 : 0.0);
    }

    private void validateModels() {
        Map<String, TensorType> inputs = this.encoder.getInputInfo();
        this.validateName(inputs, this.encoderInputIdsName, "input");
        this.validateName(inputs, this.encoderAttentionMaskName, "input");
        Map<String, TensorType> outputs = this.encoder.getOutputInfo();
        this.validateName(outputs, this.encoderOutputName, "output");
        inputs = this.decoder.getInputInfo();
        this.validateName(inputs, this.decoderInputIdsName, "input");
        this.validateName(inputs, this.decoderAttentionMaskName, "input");
        this.validateName(inputs, this.decoderEncoderHiddenStateName, "input");
        outputs = this.decoder.getOutputInfo();
        this.validateName(outputs, this.decoderOutputName, "output");
    }

    private void validateName(Map<String, TensorType> types, String name, String type) {
        if (!types.containsKey(name)) {
            throw new IllegalArgumentException("Model does not contain required " + type + ": '" + name + "'. Model contains: " + String.join((CharSequence)",", types.keySet()));
        }
    }

    public List<Completion> complete(Prompt prompt, InferenceParameters options) {
        String completionText = this.generate(prompt.asString(), new DecoderOptions());
        Completion completion = new Completion(completionText, Completion.FinishReason.stop);
        return List.of(completion);
    }

    public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, InferenceParameters options, Consumer<Completion> consumer) {
        throw new UnsupportedOperationException("Asynchronous completion is not supported");
    }

    public static class DecoderOptions {
        private SearchMethod searchMethod = SearchMethod.GREEDY;
        private int maxLength = 20;

        public SearchMethod getSearchMethod() {
            return this.searchMethod;
        }

        public DecoderOptions setSearchMethod(SearchMethod searchMethod) {
            this.searchMethod = searchMethod;
            return this;
        }

        public int getMaxLength() {
            return this.maxLength;
        }

        public DecoderOptions setMaxLength(int maxLength) {
            this.maxLength = maxLength;
            return this;
        }

        public static enum SearchMethod {
            GREEDY,
            CONTRASTIVE,
            BEAM,
            SAMPLE;

        }
    }
}

