/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.vespa.model.container.component;

import com.yahoo.config.ModelReference;
import com.yahoo.config.model.builder.xml.XmlHelper;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig;
import com.yahoo.text.XML;
import com.yahoo.vespa.model.container.component.TypedComponent;
import com.yahoo.vespa.model.container.xml.ModelIdResolver;
import java.util.Optional;
import org.w3c.dom.Element;

public class HuggingFaceEmbedder
extends TypedComponent
implements HuggingFaceEmbedderConfig.Producer {
    private final ModelReference model;
    private final ModelReference vocab;
    private final Integer maxTokens;
    private final String transformerInputIds;
    private final String transformerAttentionMask;
    private final String transformerTokenTypeIds;
    private final String transformerOutput;
    private final Boolean normalize;
    private final String onnxExecutionMode;
    private final Integer onnxInteropThreads;
    private final Integer onnxIntraopThreads;
    private final Integer onnxGpuDevice;
    private final String poolingStrategy;

    public HuggingFaceEmbedder(Element xml, DeployState state) {
        super("ai.vespa.embedding.huggingface.HuggingFaceEmbedder", "model-integration", xml);
        Element transformerModelElem = XmlHelper.getOptionalChild(xml, "transformer-model").orElseThrow();
        this.model = ModelIdResolver.resolveToModelReference(transformerModelElem, state);
        this.vocab = XmlHelper.getOptionalChild(xml, "tokenizer-model").map(elem -> ModelIdResolver.resolveToModelReference(elem, state)).orElseGet(() -> HuggingFaceEmbedder.resolveDefaultVocab(transformerModelElem, state));
        this.maxTokens = XML.getChildValue((Element)xml, (String)"max-tokens").map(Integer::parseInt).orElse(null);
        this.transformerInputIds = XML.getChildValue((Element)xml, (String)"transformer-input-ids").orElse(null);
        this.transformerAttentionMask = XML.getChildValue((Element)xml, (String)"transformer-attention-mask").orElse(null);
        this.transformerTokenTypeIds = XML.getChildValue((Element)xml, (String)"transformer-token-type-ids").orElse(null);
        this.transformerOutput = XML.getChildValue((Element)xml, (String)"transformer-output").orElse(null);
        this.normalize = XML.getChildValue((Element)xml, (String)"normalize").map(Boolean::parseBoolean).orElse(null);
        this.onnxExecutionMode = XML.getChildValue((Element)xml, (String)"onnx-execution-mode").orElse(null);
        this.onnxInteropThreads = XML.getChildValue((Element)xml, (String)"onnx-interop-threads").map(Integer::parseInt).orElse(null);
        this.onnxIntraopThreads = XML.getChildValue((Element)xml, (String)"onnx-intraop-threads").map(Integer::parseInt).orElse(null);
        this.onnxGpuDevice = XML.getChildValue((Element)xml, (String)"onnx-gpu-device").map(Integer::parseInt).orElse(null);
        this.poolingStrategy = XML.getChildValue((Element)xml, (String)"pooling-strategy").orElse(null);
    }

    private static ModelReference resolveDefaultVocab(Element model, DeployState state) {
        if (state.isHosted() && model.hasAttribute("model-id")) {
            String implicitVocabId = model.getAttribute("model-id") + "-vocab";
            return ModelIdResolver.resolveToModelReference("tokenizer-model", Optional.of(implicitVocabId), Optional.empty(), Optional.empty(), state);
        }
        throw new IllegalArgumentException("'tokenizer-model' must be specified");
    }

    public void getConfig(HuggingFaceEmbedderConfig.Builder b) {
        b.transformerModel(this.model).tokenizerPath(this.vocab);
        if (this.maxTokens != null) {
            b.transformerMaxTokens(this.maxTokens.intValue());
        }
        if (this.transformerInputIds != null) {
            b.transformerInputIds(this.transformerInputIds);
        }
        if (this.transformerAttentionMask != null) {
            b.transformerAttentionMask(this.transformerAttentionMask);
        }
        if (this.transformerTokenTypeIds != null) {
            b.transformerTokenTypeIds(this.transformerTokenTypeIds);
        }
        if (this.transformerOutput != null) {
            b.transformerOutput(this.transformerOutput);
        }
        if (this.normalize != null) {
            b.normalize(this.normalize.booleanValue());
        }
        if (this.onnxExecutionMode != null) {
            b.transformerExecutionMode(HuggingFaceEmbedderConfig.TransformerExecutionMode.Enum.valueOf((String)this.onnxExecutionMode));
        }
        if (this.onnxInteropThreads != null) {
            b.transformerInterOpThreads(this.onnxInteropThreads.intValue());
        }
        if (this.onnxIntraopThreads != null) {
            b.transformerIntraOpThreads(this.onnxIntraopThreads.intValue());
        }
        if (this.onnxGpuDevice != null) {
            b.transformerGpuDevice(this.onnxGpuDevice.intValue());
        }
        if (this.poolingStrategy != null) {
            b.poolingStrategy(HuggingFaceEmbedderConfig.PoolingStrategy.Enum.valueOf((String)this.poolingStrategy));
        }
    }
}

