/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.triton;

import ai.vespa.llm.clients.TritonConfig;
import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import ai.vespa.modelintegration.utils.ModelPathOrData;
import ai.vespa.triton.TritonOnnxClient;
import ai.vespa.triton.TritonOnnxEvaluator;
import com.yahoo.component.AbstractComponent;
import com.yahoo.component.annotation.Inject;
import com.yahoo.vespa.defaults.Defaults;
import inference.ModelConfigOuterClass;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.nio.file.attribute.PosixFilePermission;
import java.nio.file.attribute.PosixFilePermissions;
import java.util.Set;

/**
 * {@link OnnxRuntime} implementation that hands model evaluation over to a Triton server
 * through a {@link TritonOnnxClient}.
 *
 * <p>When the configured model control mode is {@code EXPLICIT}, the model file and a model
 * configuration are first written into the configured model repository directory before an
 * evaluator is returned.
 */
public class TritonOnnxRuntime
extends AbstractComponent
implements OnnxRuntime {
    private final TritonConfig config;
    private final TritonOnnxClient client;

    /** Creates a runtime using a {@link TritonConfig} built with default builder values. */
    public TritonOnnxRuntime() {
        this(new TritonConfig.Builder().build());
    }

    /**
     * Creates a runtime and its client from the given configuration.
     *
     * @param config the Triton configuration used for the client, the control mode and the
     *               model repository path
     */
    @Inject
    public TritonOnnxRuntime(TritonConfig config) {
        this.config = config;
        this.client = new TritonOnnxClient(config);
    }

    /**
     * Returns an evaluator for the model at {@code modelPath}.
     *
     * <p>The model name is derived from the file name plus a hash of the model and options
     * (see {@link #createModelName}). In {@code EXPLICIT} control mode the model file and its
     * configuration are copied into the model repository before the evaluator is created.
     *
     * @param modelPath path to the ONNX model file
     * @param options   evaluator options; also contribute to the generated model name and config
     * @return an evaluator bound to this runtime's client and the derived model name
     * @throws IllegalStateException if the client reports that the server is not healthy
     * @throws UncheckedIOException  if copying the model into the repository fails
     */
    @Override
    public OnnxEvaluator evaluatorOf(String modelPath, OnnxEvaluatorOptions options) {
        if (!this.client.isHealthy()) {
            throw new IllegalStateException("Triton server is not healthy! (target=%s)".formatted(this.config.target()));
        }
        String modelName = TritonOnnxRuntime.createModelName(modelPath, options);
        boolean isExplicitControlMode = this.config.modelControlMode() == TritonConfig.ModelControlMode.EXPLICIT;
        if (isExplicitControlMode) {
            this.copyModelToRepository(modelName, modelPath, options);
        }
        return new TritonOnnxEvaluator(this.client, modelName, isExplicitControlMode);
    }

    /** Closes the underlying client when this component is deconstructed. */
    @Override
    public void deconstruct() {
        this.client.close();
    }

    /**
     * Writes the model file and its configuration into the model repository, laid out as
     * {@code <repository>/<modelName>/1/model.onnx} and {@code <repository>/<modelName>/config.pbtxt}.
     *
     * @param modelName         directory name to use for the model inside the repository
     * @param externalModelPath path of the model file to copy
     * @param options           evaluator options; supply the raw config, or the values a config
     *                          is generated from when no raw config is present
     * @throws UncheckedIOException if any file system operation fails
     */
    private void copyModelToRepository(String modelName, String externalModelPath, OnnxEvaluatorOptions options) {
        String modelRepositoryPath = Defaults.getDefaults().underVespaHome(this.config.modelRepositoryPath());
        Path modelBasePath = Paths.get(modelRepositoryPath, modelName);
        Path modelVersionPath = modelBasePath.resolve("1");
        Path modelFilePath = modelVersionPath.resolve("model.onnx");
        Path modelConfigPath = modelBasePath.resolve("config.pbtxt");
        try {
            Files.createDirectories(modelVersionPath, PosixFilePermissions.asFileAttribute(PosixFilePermissions.fromString("rwxrwxr-x")));
            Files.copy(Paths.get(externalModelPath), modelFilePath, StandardCopyOption.REPLACE_EXISTING);
            // An explicitly supplied raw config wins; otherwise one is derived from the options.
            String modelConfig = options.rawConfig().orElseGet(() -> TritonOnnxRuntime.generateConfigFromEvaluatorOptions(modelName, options).toString());
            Files.writeString(modelConfigPath, modelConfig);
            TritonOnnxRuntime.addReadPermissions(modelFilePath);
            TritonOnnxRuntime.addReadPermissions(modelConfigPath);
        }
        catch (IOException e) {
            throw new UncheckedIOException("Failed to copy model file to repository", e);
        }
    }

    /**
     * Adds group-read and others-read to the existing POSIX permissions of {@code path}.
     *
     * @param path the file whose permissions are widened
     * @throws IOException if the permissions cannot be read or written
     */
    private static void addReadPermissions(Path path) throws IOException {
        Set<PosixFilePermission> permissions = Files.getPosixFilePermissions(path);
        permissions.add(PosixFilePermission.GROUP_READ);
        permissions.add(PosixFilePermission.OTHERS_READ);
        Files.setPosixFilePermissions(path, permissions);
    }

    /**
     * Builds a model name of the form {@code <baseName>_<hexHash>}.
     *
     * <p>{@code baseName} is the file name of {@code modelPath} up to its last {@code '.'};
     * a file name without any {@code '.'} is used unchanged. The hash combines the model hash
     * and the options hash as {@code 31 * modelHash + optionsHash}, rendered in hexadecimal.
     *
     * @param modelPath path to the model file
     * @param options   evaluator options contributing to the hash
     * @return the derived model name
     */
    static String createModelName(String modelPath, OnnxEvaluatorOptions options) {
        String fileName = Paths.get(modelPath).getFileName().toString();
        // lastIndexOf returns -1 when there is no extension; substring(0, -1) would throw.
        int extensionStart = fileName.lastIndexOf('.');
        String baseName = extensionStart >= 0 ? fileName.substring(0, extensionStart) : fileName;
        long modelHash = ModelPathOrData.of(modelPath).calculateHash();
        long optionsHash = options.calculateHash();
        String combinedHash = Long.toHexString(31L * modelHash + optionsHash);
        return baseName + "_" + combinedHash;
    }

    /**
     * Generates a model configuration from the evaluator options.
     *
     * <p>The instance group kind is {@code KIND_GPU} when a GPU is required, {@code KIND_AUTO}
     * when a non-negative GPU device number is set without being required, and {@code KIND_CPU}
     * otherwise. Thread counts are taken from the options.
     *
     * @param modelName name to set in the configuration
     * @param options   evaluator options supplying GPU and thread settings
     * @return the generated model configuration
     */
    private static ModelConfigOuterClass.ModelConfig generateConfigFromEvaluatorOptions(String modelName, OnnxEvaluatorOptions options) {
        ModelConfigOuterClass.ModelInstanceGroup.Kind kind;
        if (options.gpuDeviceRequired()) {
            kind = ModelConfigOuterClass.ModelInstanceGroup.Kind.KIND_GPU;
        } else if (options.gpuDeviceNumber() >= 0) {
            kind = ModelConfigOuterClass.ModelInstanceGroup.Kind.KIND_AUTO;
        } else {
            kind = ModelConfigOuterClass.ModelInstanceGroup.Kind.KIND_CPU;
        }
        ModelConfigOuterClass.ModelInstanceGroup instanceGroup = ModelConfigOuterClass.ModelInstanceGroup.newBuilder()
                .setCount(1)
                .setKind(kind)
                .build();
        return ModelConfigOuterClass.ModelConfig.newBuilder()
                .setName(modelName)
                .addInstanceGroup(instanceGroup)
                .setPlatform("onnxruntime_onnx")
                .setMaxBatchSize(1)
                .putParameters("enable_mem_area", TritonOnnxRuntime.stringParameter("0"))
                .putParameters("enable_mem_pattern", TritonOnnxRuntime.stringParameter("0"))
                .putParameters("intra_op_thread_count", TritonOnnxRuntime.stringParameter(Integer.toString(options.intraOpThreads())))
                .putParameters("inter_op_thread_count", TritonOnnxRuntime.stringParameter(Integer.toString(options.interOpThreads())))
                .build();
    }

    /** Wraps a string value in a {@code ModelParameter} message. */
    private static ModelConfigOuterClass.ModelParameter stringParameter(String value) {
        return ModelConfigOuterClass.ModelParameter.newBuilder().setStringValue(value).build();
    }
}

