/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.serving.wlm;

import ai.djl.ModelException;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.serving.http.BadRequestException;
import ai.djl.serving.http.DescribeModelResponse;
import ai.djl.serving.http.StatusResponse;
import ai.djl.serving.util.ConfigManager;
import ai.djl.serving.util.NettyUtils;
import ai.djl.serving.wlm.Job;
import ai.djl.serving.wlm.ModelInfo;
import ai.djl.serving.wlm.WorkLoadManager;
import ai.djl.serving.wlm.WorkerThread;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.HttpResponseStatus;
import java.io.IOException;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.regex.Pattern;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Process-wide registry of served models.
 *
 * <p>Keeps the registered {@link ModelInfo} instances keyed by model name and tracks which model
 * names were loaded at startup. It forwards worker-count changes to the {@link WorkLoadManager}
 * created from the {@link ConfigManager} passed to {@link #init(ConfigManager)}.
 */
public final class ModelManager {
    private static final Logger logger = LoggerFactory.getLogger(ModelManager.class);

    /**
     * Matches characters that are replaced with {@code '_'} when a model name is derived from the
     * loaded model: any non-word character (a leading {@code '_'} also matches, but is replaced by
     * itself).
     */
    private static final Pattern INVALID_NAME_CHARS = Pattern.compile("(\\W|^_)");

    private static ModelManager modelManager;

    private final ConfigManager configManager;
    private final WorkLoadManager wlm;
    private final ConcurrentHashMap<String, ModelInfo> models;
    // Concurrent, like `models`: it is written by unregisterModel and read by describeModel on
    // this shared singleton (presumably from concurrent request handlers — confirm with callers).
    private final Set<String> startupModels;

    private ModelManager(ConfigManager configManager) {
        this.configManager = configManager;
        this.wlm = new WorkLoadManager(configManager);
        this.models = new ConcurrentHashMap<>();
        this.startupModels = ConcurrentHashMap.newKeySet();
    }

    /**
     * Creates the singleton instance, replacing any previous one.
     *
     * @param configManager the configuration used for job queue sizes and the workload manager
     */
    public static void init(ConfigManager configManager) {
        modelManager = new ModelManager(configManager);
    }

    /**
     * Returns the singleton created by {@link #init(ConfigManager)}.
     *
     * @return the singleton instance, or {@code null} if {@code init} has not been called
     */
    public static ModelManager getInstance() {
        return modelManager;
    }

    /**
     * Asynchronously loads a model from {@code modelUrl} and registers it.
     *
     * <p>The loaded model is closed again if registration does not complete. This covers a
     * duplicate name and any unchecked exception thrown while building the {@link ModelInfo}.
     *
     * @param modelName the name to register under; when {@code null} or empty, the loaded model's
     *     own name is used with non-word characters replaced by {@code '_'}
     * @param modelUrl the model URL used to build the loading criteria
     * @param batchSize the batch size set on the new {@link ModelInfo}
     * @param maxBatchDelay the maximum batch delay set on the new {@link ModelInfo}
     * @return a future that completes with the registered {@link ModelInfo}. It completes
     *     exceptionally if loading fails with a {@link ModelException} or {@link IOException},
     *     which is wrapped in a {@link CompletionException}. It also completes exceptionally with
     *     a {@link BadRequestException} if a model with the resolved name is already registered.
     */
    public CompletableFuture<ModelInfo> registerModel(
            String modelName, String modelUrl, int batchSize, int maxBatchDelay) {
        return CompletableFuture.supplyAsync(() -> {
            try {
                Criteria<Input, Output> criteria = Criteria.builder()
                        .setTypes(Input.class, Output.class)
                        .optModelUrls(modelUrl)
                        .build();
                ZooModel<Input, Output> model = ModelZoo.loadModel(criteria);
                boolean registered = false;
                try {
                    String actualModelName = resolveModelName(modelName, model);
                    ModelInfo modelInfo = new ModelInfo(
                            actualModelName, modelUrl, model, configManager.getJobQueueSize());
                    modelInfo.setBatchSize(batchSize);
                    modelInfo.setMaxBatchDelay(maxBatchDelay);
                    if (models.putIfAbsent(actualModelName, modelInfo) != null) {
                        throw new BadRequestException(
                                "Model " + actualModelName + " is already registered.");
                    }
                    registered = true;
                    logger.info("Model {} loaded.", modelInfo.getModelName());
                    return modelInfo;
                } finally {
                    if (!registered) {
                        // Nothing else references the model yet; release it to avoid a leak.
                        model.close();
                    }
                }
            } catch (ModelException | IOException e) {
                throw new CompletionException(e);
            }
        });
    }

    /**
     * Picks the registration name for a freshly loaded model.
     *
     * @param requested the caller-supplied name, possibly {@code null} or empty
     * @param model the loaded model whose own name is the fallback
     * @return {@code requested} if non-empty, otherwise the sanitised model name
     */
    private static String resolveModelName(String requested, ZooModel<?, ?> model) {
        if (requested != null && !requested.isEmpty()) {
            return requested;
        }
        return INVALID_NAME_CHARS.matcher(model.getName()).replaceAll("_");
    }

    /**
     * Removes a model, scales its workers to zero, and closes it.
     *
     * @param modelName the registered model name
     * @return {@code true} if the model was registered and has been removed, {@code false} if no
     *     such model exists
     */
    public boolean unregisterModel(String modelName) {
        ModelInfo model = models.remove(modelName);
        if (model == null) {
            logger.warn("Model not found: {}", modelName);
            return false;
        }
        model.setMinWorkers(0);
        model.setMaxWorkers(0);
        wlm.modelChanged(model);
        startupModels.remove(modelName);
        model.close();
        logger.info("Model {} unregistered.", modelName);
        return true;
    }

    /**
     * Updates the worker bounds of a registered model and notifies the workload manager.
     *
     * @param modelName the registered model name
     * @param minWorkers the new minimum worker count
     * @param maxWorkers the new maximum worker count
     * @throws AssertionError if no model with {@code modelName} is registered
     */
    public void updateModel(String modelName, int minWorkers, int maxWorkers) {
        ModelInfo model = models.get(modelName);
        if (model == null) {
            throw new AssertionError("Model not found: " + modelName);
        }
        model.setMinWorkers(minWorkers);
        model.setMaxWorkers(maxWorkers);
        logger.debug("updateModel: {}, count: {}", modelName, minWorkers);
        wlm.modelChanged(model);
    }

    /**
     * Returns the live, mutable map of registered models keyed by name.
     *
     * @return the internal registry map (not a copy)
     */
    public Map<String, ModelInfo> getModels() {
        return models;
    }

    /**
     * Returns the live, mutable set of model names loaded at startup.
     *
     * @return the internal startup-model set (not a copy)
     */
    public Set<String> getStartupModels() {
        return startupModels;
    }

    /**
     * Queues a job on its target model if that model currently has a worker.
     *
     * @param job the job to enqueue
     * @return the result of {@link ModelInfo#addJob(Job)} if the model has a worker, otherwise
     *     {@code false}
     * @throws ModelNotFoundException if the job's model is not registered
     */
    public boolean addJob(Job job) throws ModelNotFoundException {
        String modelName = job.getModelName();
        ModelInfo model = models.get(modelName);
        if (model == null) {
            throw new ModelNotFoundException("Model not found: " + modelName);
        }
        if (wlm.hasWorker(modelName)) {
            return model.addJob(job);
        }
        return false;
    }

    /**
     * Builds a description of a registered model and its workers.
     *
     * <p>The status is {@code "Healthy"} when the number of running workers is at least the
     * model's minimum worker count, otherwise {@code "Unhealthy"}.
     *
     * @param modelName the registered model name
     * @return the populated response
     * @throws ModelNotFoundException if the model is not registered
     */
    public DescribeModelResponse describeModel(String modelName) throws ModelNotFoundException {
        ModelInfo model = models.get(modelName);
        if (model == null) {
            throw new ModelNotFoundException("Model not found: " + modelName);
        }
        DescribeModelResponse resp = new DescribeModelResponse();
        resp.setModelName(modelName);
        resp.setModelUrl(model.getModelUrl());
        resp.setBatchSize(model.getBatchSize());
        resp.setMaxBatchDelay(model.getMaxBatchDelay());
        resp.setMaxWorkers(model.getMaxWorkers());
        resp.setMinWorkers(model.getMinWorkers());
        resp.setLoadedAtStartup(startupModels.contains(modelName));

        int activeWorker = wlm.getNumRunningWorkers(modelName);
        int targetWorker = model.getMinWorkers();
        resp.setStatus(activeWorker >= targetWorker ? "Healthy" : "Unhealthy");

        List<WorkerThread> workers = wlm.getWorkers(modelName);
        for (WorkerThread worker : workers) {
            resp.addWorker(
                    worker.getWorkerId(),
                    worker.getStartTime(),
                    worker.isRunning(),
                    worker.getGpuId());
        }
        return resp;
    }

    /**
     * Schedules an aggregate health check across all models and writes the result as a JSON
     * {@link StatusResponse} with HTTP 200 to {@code ctx}.
     *
     * <p>The result is {@code "Unhealthy"} when no workers are running but some are expected, and
     * {@code "Partial Healthy"} when some but fewer than expected are running. In every other
     * case it is {@code "Healthy"}.
     *
     * @param ctx the channel to write the response to
     */
    public void workerStatus(ChannelHandlerContext ctx) {
        Runnable r = () -> {
            String response = "Healthy";
            int numWorking = 0;
            int numScaled = 0;
            for (ModelInfo info : models.values()) {
                numScaled += info.getMinWorkers();
                numWorking += wlm.getNumRunningWorkers(info.getModelName());
            }
            if (numWorking > 0 && numWorking < numScaled) {
                response = "Partial Healthy";
            } else if (numWorking == 0 && numScaled > 0) {
                response = "Unhealthy";
            }
            NettyUtils.sendJsonResponse(ctx, new StatusResponse(response), HttpResponseStatus.OK);
        };
        wlm.scheduleAsync(r);
    }
}

