/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.serving.http;

import ai.djl.ModelException;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.serving.http.BadRequestException;
import ai.djl.serving.http.DescribeModelResponse;
import ai.djl.serving.http.HttpRequestHandler;
import ai.djl.serving.http.ListModelsResponse;
import ai.djl.serving.http.MethodNotAllowedException;
import ai.djl.serving.http.StatusResponse;
import ai.djl.serving.util.NettyUtils;
import ai.djl.serving.wlm.ModelInfo;
import ai.djl.serving.wlm.ModelManager;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.QueryStringDecoder;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.regex.Pattern;

/**
 * HTTP handler for the model management endpoints rooted at {@code /models}.
 *
 * <p>Dispatching, as implemented in {@link #handleRequest}:
 *
 * <ul>
 *   <li>{@code GET /models} - list registered models (paginated)
 *   <li>{@code POST /models} - register a model
 *   <li>{@code GET /models/{name}} - describe a model
 *   <li>{@code PUT /models/{name}} - change the worker count of a model
 *   <li>{@code DELETE /models/{name}} - unregister a model
 * </ul>
 *
 * Any other HTTP method results in a {@link MethodNotAllowedException}.
 */
public class ManagementRequestHandler extends HttpRequestHandler {

    /** Matches {@code /models}, optionally followed by {@code /...} or {@code ?...}. */
    private static final Pattern PATTERN = Pattern.compile("^/models([/?].*)?");

    /** Upper bound (and fallback value) for the {@code limit} query parameter. */
    private static final int MAX_PAGE_SIZE = 100;

    /**
     * Accepts a message only if the parent handler accepts it and its URI targets
     * {@code /models}.
     *
     * @param msg the inbound message
     * @return {@code true} if this handler should process the message
     * @throws Exception if the parent handler's check fails
     */
    public boolean acceptInboundMessage(Object msg) throws Exception {
        if (!super.acceptInboundMessage(msg)) {
            return false;
        }
        // The cast relies on the parent check having accepted the message;
        // presumably that guarantees a FullHttpRequest - confirm against HttpRequestHandler.
        FullHttpRequest req = (FullHttpRequest) msg;
        return PATTERN.matcher(req.uri()).matches();
    }

    /**
     * Routes a management request to the matching operation.
     *
     * <p>Requests with fewer than three path segments are treated as operations on the model
     * collection; otherwise {@code segments[2]} is used as the model name.
     *
     * @param ctx the channel context used to write the response
     * @param req the HTTP request
     * @param decoder the decoder giving access to the query parameters
     * @param segments the request path split into segments
     * @throws ModelException if the requested model cannot be found
     */
    @Override
    protected void handleRequest(ChannelHandlerContext ctx, FullHttpRequest req, QueryStringDecoder decoder, String[] segments) throws ModelException {
        HttpMethod method = req.method();
        if (segments.length < 3) {
            if (HttpMethod.GET.equals(method)) {
                handleListModels(ctx, decoder);
                return;
            }
            if (HttpMethod.POST.equals(method)) {
                handleRegisterModel(ctx, decoder);
                return;
            }
            throw new MethodNotAllowedException();
        }

        String modelName = segments[2];
        if (HttpMethod.GET.equals(method)) {
            handleDescribeModel(ctx, modelName);
        } else if (HttpMethod.PUT.equals(method)) {
            handleScaleModel(ctx, decoder, modelName);
        } else if (HttpMethod.DELETE.equals(method)) {
            handleUnregisterModel(ctx, modelName);
        } else {
            throw new MethodNotAllowedException();
        }
    }

    /**
     * Sends one page of the registered models, ordered by model name.
     *
     * <p>Query parameters: {@code limit} (falls back to 100 when absent, negative or above 100)
     * and {@code next_page_token} (falls back to 0 when absent or negative), which is the index
     * of the first model to return.
     *
     * @param ctx the channel context used to write the response
     * @param decoder the decoder giving access to the query parameters
     */
    private void handleListModels(ChannelHandlerContext ctx, QueryStringDecoder decoder) {
        int limit = NettyUtils.getIntParameter(decoder, "limit", MAX_PAGE_SIZE);
        int pageToken = NettyUtils.getIntParameter(decoder, "next_page_token", 0);
        if (limit > MAX_PAGE_SIZE || limit < 0) {
            limit = MAX_PAGE_SIZE;
        }
        if (pageToken < 0) {
            pageToken = 0;
        }

        ModelManager modelManager = ModelManager.getInstance();
        Map<String, ModelInfo> models = modelManager.getModels();
        ArrayList<String> keys = new ArrayList<>(models.keySet());
        Collections.sort(keys);

        ListModelsResponse list = new ListModelsResponse();

        // Computed as long: a very large next_page_token would otherwise overflow int,
        // wrap to a negative value, bypass the clamp below and be echoed back to the
        // client as a negative next page token.
        long end = (long) pageToken + limit;
        int last;
        if (end > keys.size()) {
            last = keys.size();
        } else {
            last = (int) end;
            list.setNextPageToken(String.valueOf(last));
        }

        for (int i = pageToken; i < last; ++i) {
            String modelName = keys.get(i);
            ModelInfo model = models.get(modelName);
            list.addModel(modelName, model.getModelUrl());
        }
        NettyUtils.sendJsonResponse(ctx, list);
    }

    /**
     * Sends the description of a single model.
     *
     * @param ctx the channel context used to write the response
     * @param modelName the name of the model to describe
     * @throws ModelNotFoundException if the model manager does not know the model
     */
    private void handleDescribeModel(ChannelHandlerContext ctx, String modelName) throws ModelNotFoundException {
        ModelManager modelManager = ModelManager.getInstance();
        DescribeModelResponse resp = modelManager.describeModel(modelName);
        NettyUtils.sendJsonResponse(ctx, resp);
    }

    /**
     * Registers a model and starts its initial workers.
     *
     * <p>Query parameters: {@code url} (required), {@code model_name}, {@code batch_size}
     * (fallback 1), {@code max_batch_delay} (fallback 100), {@code initial_workers} (fallback 1,
     * used as both the minimum and maximum worker count) and {@code synchronous} (fallback
     * {@code true}). In synchronous mode the response is sent once registration and worker
     * start-up have completed; otherwise a "scheduled" response is sent immediately.
     *
     * @param ctx the channel context used to write the response
     * @param decoder the decoder giving access to the query parameters
     * @throws BadRequestException if the {@code url} parameter is missing
     */
    private void handleRegisterModel(ChannelHandlerContext ctx, QueryStringDecoder decoder) {
        String modelUrl = NettyUtils.getParameter(decoder, "url", null);
        if (modelUrl == null) {
            throw new BadRequestException("Parameter url is required.");
        }

        // NOTE(review): model_name may be null here; it is then used verbatim in the
        // status message ("Model \"null\" ...") and passed to updateModel. Whether
        // ModelManager derives a name for a null argument is not visible here - confirm.
        String modelName = NettyUtils.getParameter(decoder, "model_name", null);
        int batchSize = NettyUtils.getIntParameter(decoder, "batch_size", 1);
        int maxBatchDelay = NettyUtils.getIntParameter(decoder, "max_batch_delay", 100);
        int initialWorkers = NettyUtils.getIntParameter(decoder, "initial_workers", 1);
        boolean synchronous = Boolean.parseBoolean(NettyUtils.getParameter(decoder, "synchronous", "true"));

        ModelManager modelManager = ModelManager.getInstance();
        CompletableFuture<ModelInfo> future = modelManager.registerModel(modelName, modelUrl, batchSize, maxBatchDelay);
        CompletableFuture<Void> f = future.thenAccept(m -> modelManager.updateModel(modelName, initialWorkers, initialWorkers));

        if (synchronous) {
            String msg = "Model \"" + modelName + "\" registered.";
            f = f.thenAccept(v -> NettyUtils.sendJsonResponse(ctx, new StatusResponse(msg)));
        } else {
            String msg = "Model \"" + modelName + "\" registration scheduled.";
            NettyUtils.sendJsonResponse(ctx, new StatusResponse(msg));
        }

        // NOTE(review): in asynchronous mode the "scheduled" response has already been
        // written, so a later failure produces a second write on the same channel.
        // Kept as in the original; confirm this is the intended behaviour.
        f.exceptionally(t -> {
            // Failures of a dependent stage normally arrive wrapped (with a cause);
            // fall back to the throwable itself so sendError never receives null.
            Throwable cause = t.getCause() != null ? t.getCause() : t;
            NettyUtils.sendError(ctx, cause);
            return null;
        });
    }

    /**
     * Unregisters a model and confirms with a status response.
     *
     * @param ctx the channel context used to write the response
     * @param modelName the name of the model to remove
     * @throws ModelNotFoundException if the model manager reports that nothing was unregistered
     */
    private void handleUnregisterModel(ChannelHandlerContext ctx, String modelName) throws ModelNotFoundException {
        ModelManager modelManager = ModelManager.getInstance();
        if (!modelManager.unregisterModel(modelName)) {
            throw new ModelNotFoundException("Model not found: " + modelName);
        }
        String msg = "Model \"" + modelName + "\" unregistered";
        NettyUtils.sendJsonResponse(ctx, new StatusResponse(msg));
    }

    /**
     * Updates the worker bounds of a registered model.
     *
     * <p>Query parameters: {@code min_worker} (fallback 1) and {@code max_worker} (falls back to
     * the value of {@code min_worker}).
     *
     * @param ctx the channel context used to write the response
     * @param decoder the decoder giving access to the query parameters
     * @param modelName the name of the model to scale
     * @throws BadRequestException if {@code max_worker} is less than {@code min_worker}
     * @throws ModelNotFoundException if no model with that name is registered
     */
    private void handleScaleModel(ChannelHandlerContext ctx, QueryStringDecoder decoder, String modelName) throws ModelNotFoundException {
        int minWorkers = NettyUtils.getIntParameter(decoder, "min_worker", 1);
        int maxWorkers = NettyUtils.getIntParameter(decoder, "max_worker", minWorkers);
        if (maxWorkers < minWorkers) {
            throw new BadRequestException("max_worker cannot be less than min_worker.");
        }

        ModelManager modelManager = ModelManager.getInstance();
        // Checked explicitly so an unknown model is reported as "not found" before scaling.
        if (modelManager.getModels().get(modelName) == null) {
            throw new ModelNotFoundException("Model not found: " + modelName);
        }
        modelManager.updateModel(modelName, minWorkers, maxWorkers);

        String msg = "Model \"" + modelName + "\" worker scaled.";
        NettyUtils.sendJsonResponse(ctx, new StatusResponse(msg));
    }
}

