/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.serving.http;

import ai.djl.ModelException;
import ai.djl.modality.Input;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.serving.http.BadRequestException;
import ai.djl.serving.http.HttpRequestHandler;
import ai.djl.serving.http.ResourceNotFoundException;
import ai.djl.serving.http.ServiceUnavailableException;
import ai.djl.serving.util.ConfigManager;
import ai.djl.serving.util.NettyUtils;
import ai.djl.serving.wlm.Job;
import ai.djl.serving.wlm.ModelInfo;
import ai.djl.serving.wlm.ModelManager;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.codec.http.HttpMessage;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpUtil;
import io.netty.handler.codec.http.QueryStringDecoder;
import io.netty.handler.codec.http.multipart.DefaultHttpDataFactory;
import io.netty.handler.codec.http.multipart.HttpDataFactory;
import io.netty.handler.codec.http.multipart.HttpPostRequestDecoder;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.regex.Pattern;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class InferenceRequestHandler
extends HttpRequestHandler {
    private static final Logger logger = LoggerFactory.getLogger(InferenceRequestHandler.class);
    private static final Pattern PATTERN = Pattern.compile("^/(ping|invocations|predictions)([/?].*)?");

    public boolean acceptInboundMessage(Object msg) throws Exception {
        if (super.acceptInboundMessage(msg)) {
            FullHttpRequest req = (FullHttpRequest)msg;
            return PATTERN.matcher(req.uri()).matches();
        }
        return false;
    }

    @Override
    protected void handleRequest(ChannelHandlerContext ctx, FullHttpRequest req, QueryStringDecoder decoder, String[] segments) throws ModelException {
        switch (segments[1]) {
            case "ping": {
                ModelManager.getInstance().workerStatus(ctx);
                break;
            }
            case "invocations": {
                this.handleInvocations(ctx, req, decoder);
                break;
            }
            case "predictions": {
                this.handlePredictions(ctx, req, decoder, segments);
                break;
            }
            default: {
                throw new AssertionError((Object)("Invalid request uri: " + req.uri()));
            }
        }
    }

    private void handlePredictions(ChannelHandlerContext ctx, FullHttpRequest req, QueryStringDecoder decoder, String[] segments) throws ModelNotFoundException {
        if (segments.length < 3) {
            throw new ResourceNotFoundException();
        }
        Input input = InferenceRequestHandler.parseRequest(ctx, req, decoder);
        this.predict(ctx, req, input, segments[2]);
    }

    private void handleInvocations(ChannelHandlerContext ctx, FullHttpRequest req, QueryStringDecoder decoder) throws ModelNotFoundException {
        byte[] buf;
        Input input = InferenceRequestHandler.parseRequest(ctx, req, decoder);
        String modelName = NettyUtils.getParameter(decoder, "model_name", null);
        if ((modelName == null || modelName.isEmpty()) && (modelName = input.getProperty("model_name", null)) == null && (buf = (byte[])input.getContent().get((Object)"model_name")) != null) {
            modelName = new String(buf, StandardCharsets.UTF_8);
        }
        if (modelName == null) {
            if (ModelManager.getInstance().getStartupModels().size() == 1) {
                modelName = ModelManager.getInstance().getStartupModels().iterator().next();
            }
            if (modelName == null) {
                throw new BadRequestException("Parameter model_name is required.");
            }
        }
        this.predict(ctx, req, input, modelName);
    }

    private void predict(ChannelHandlerContext ctx, FullHttpRequest req, Input input, String modelName) throws ModelNotFoundException {
        ModelManager modelManager = ModelManager.getInstance();
        ModelInfo model = modelManager.getModels().get(modelName);
        if (model == null) {
            String regex = ConfigManager.getInstance().getModelUrlPattern();
            if (regex == null) {
                throw new ModelNotFoundException("Model not found: " + modelName);
            }
            String modelUrl = input.getProperty("model_url", null);
            if (modelUrl == null) {
                byte[] buf = (byte[])input.getContent().get((Object)"model_url");
                if (buf == null) {
                    throw new ModelNotFoundException("Parameter model_url is required.");
                }
                modelUrl = new String(buf, StandardCharsets.UTF_8);
                if (!modelUrl.matches(regex)) {
                    throw new ModelNotFoundException("Permission denied: " + modelUrl);
                }
            }
            logger.info("Loading model {} from: {}", (Object)modelName, (Object)modelUrl);
            ((CompletableFuture)((CompletableFuture)modelManager.registerModel(modelName, modelUrl, 1, 0).thenAccept(m -> modelManager.updateModel(modelName, 1, 1))).thenAccept(p -> {
                try {
                    modelManager.addJob(new Job(ctx, modelName, input));
                }
                catch (ModelNotFoundException e) {
                    logger.warn("Unexpected error", (Throwable)e);
                    NettyUtils.sendError(ctx, e);
                }
            })).exceptionally(t -> {
                logger.warn("Unexpected error", t);
                NettyUtils.sendError(ctx, t);
                return null;
            });
            return;
        }
        if (HttpMethod.OPTIONS.equals((Object)req.method())) {
            NettyUtils.sendJsonResponse(ctx, "{}");
            return;
        }
        Job job = new Job(ctx, modelName, input);
        if (!ModelManager.getInstance().addJob(job)) {
            throw new ServiceUnavailableException("No worker is available to serve request: " + modelName);
        }
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    private static Input parseRequest(ChannelHandlerContext ctx, FullHttpRequest req, QueryStringDecoder decoder) {
        String requestId = NettyUtils.getRequestId(ctx.channel());
        Input input = new Input(requestId);
        if (decoder != null) {
            for (Map.Entry entry : decoder.parameters().entrySet()) {
                String key = (String)entry.getKey();
                for (String value : (List)entry.getValue()) {
                    input.addData(key, value.getBytes(StandardCharsets.UTF_8));
                }
            }
        }
        CharSequence contentType = HttpUtil.getMimeType((HttpMessage)req);
        for (Map.Entry entry : req.headers().entries()) {
            input.addProperty((String)entry.getKey(), (String)entry.getValue());
        }
        if (HttpPostRequestDecoder.isMultipart((HttpRequest)req) || HttpHeaderValues.APPLICATION_X_WWW_FORM_URLENCODED.contentEqualsIgnoreCase(contentType)) {
            DefaultHttpDataFactory defaultHttpDataFactory = new DefaultHttpDataFactory(6553500L);
            HttpPostRequestDecoder form = new HttpPostRequestDecoder((HttpDataFactory)defaultHttpDataFactory, (HttpRequest)req);
            try {
                while (form.hasNext()) {
                    NettyUtils.addFormData(form.next(), input);
                }
            }
            catch (HttpPostRequestDecoder.EndOfDataDecoderException ignore) {
                logger.trace("End of multipart items.");
            }
            finally {
                form.cleanFiles();
                form.destroy();
            }
        } else {
            byte[] byArray = NettyUtils.getBytes(req.content());
            input.addData("body", byArray);
        }
        return input;
    }
}

