/*
 * Decompiled with CFR 0.152.
 */
package com.linecorp.armeria.server.thrift;

import com.linecorp.armeria.common.AggregatedHttpRequest;
import com.linecorp.armeria.common.AggregationOptions;
import com.linecorp.armeria.common.ExchangeType;
import com.linecorp.armeria.common.HttpData;
import com.linecorp.armeria.common.HttpMethod;
import com.linecorp.armeria.common.HttpRequest;
import com.linecorp.armeria.common.HttpResponse;
import com.linecorp.armeria.common.HttpStatus;
import com.linecorp.armeria.common.MediaType;
import com.linecorp.armeria.common.Request;
import com.linecorp.armeria.common.RequestHeaders;
import com.linecorp.armeria.common.RpcRequest;
import com.linecorp.armeria.common.RpcResponse;
import com.linecorp.armeria.common.SerializationFormat;
import com.linecorp.armeria.common.annotation.Nullable;
import com.linecorp.armeria.common.logging.RequestLogProperty;
import com.linecorp.armeria.common.thrift.ThriftCall;
import com.linecorp.armeria.common.thrift.ThriftReply;
import com.linecorp.armeria.common.thrift.ThriftSerializationFormats;
import com.linecorp.armeria.common.util.CompletionActions;
import com.linecorp.armeria.common.util.Exceptions;
import com.linecorp.armeria.common.util.SafeCloseable;
import com.linecorp.armeria.internal.common.thrift.TByteBufTransport;
import com.linecorp.armeria.internal.common.thrift.ThriftFieldAccess;
import com.linecorp.armeria.internal.common.thrift.ThriftFunction;
import com.linecorp.armeria.internal.common.thrift.ThriftProtocolUtil;
import com.linecorp.armeria.internal.shaded.guava.base.Preconditions;
import com.linecorp.armeria.internal.shaded.guava.collect.ImmutableMap;
import com.linecorp.armeria.internal.shaded.guava.collect.ImmutableSet;
import com.linecorp.armeria.internal.shaded.guava.primitives.Ints;
import com.linecorp.armeria.server.DecoratingService;
import com.linecorp.armeria.server.HttpResponseException;
import com.linecorp.armeria.server.HttpService;
import com.linecorp.armeria.server.HttpStatusException;
import com.linecorp.armeria.server.RoutingContext;
import com.linecorp.armeria.server.RpcService;
import com.linecorp.armeria.server.Service;
import com.linecorp.armeria.server.ServiceConfig;
import com.linecorp.armeria.server.ServiceRequestContext;
import com.linecorp.armeria.server.thrift.THttpServiceBuilder;
import com.linecorp.armeria.server.thrift.ThriftCallService;
import com.linecorp.armeria.server.thrift.ThriftServiceEntry;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.util.concurrent.EventExecutor;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.function.BiFunction;
import java.util.function.Function;
import org.apache.thrift.TApplicationException;
import org.apache.thrift.TBase;
import org.apache.thrift.TException;
import org.apache.thrift.TFieldIdEnum;
import org.apache.thrift.meta_data.FieldMetaData;
import org.apache.thrift.protocol.TMessage;
import org.apache.thrift.protocol.TMessageType;
import org.apache.thrift.protocol.TProtocol;
import org.apache.thrift.protocol.TProtocolException;
import org.apache.thrift.protocol.TProtocolFactory;
import org.apache.thrift.transport.TTransport;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public final class THttpService
extends DecoratingService<RpcRequest, RpcResponse, HttpRequest, HttpResponse>
implements HttpService {
    private static final Logger logger = LoggerFactory.getLogger(THttpService.class);
    private static final String PROTOCOL_NOT_SUPPORTED = "Specified content-type not supported";
    private static final String ACCEPT_THRIFT_PROTOCOL_MUST_MATCH_CONTENT_TYPE = "Thrift protocol specified in Accept header must match the one specified in the content-type header";
    private final ThriftCallService thriftService;
    private final SerializationFormat defaultSerializationFormat;
    private final Set<SerializationFormat> supportedSerializationFormats;
    private final BiFunction<? super ServiceRequestContext, ? super Throwable, ? extends RpcResponse> exceptionHandler;
    private int maxRequestStringLength;
    private int maxRequestContainerLength;
    private final Map<SerializationFormat, TProtocolFactory> responseProtocolFactories;
    private Map<SerializationFormat, TProtocolFactory> requestProtocolFactories;

    public static THttpServiceBuilder builder() {
        return new THttpServiceBuilder();
    }

    public static THttpService of(Object implementation) {
        return THttpService.of(implementation, ThriftSerializationFormats.BINARY);
    }

    public static THttpService of(Object implementation, SerializationFormat defaultSerializationFormat) {
        return THttpService.builder().addService(implementation).defaultSerializationFormat(defaultSerializationFormat).build();
    }

    public static THttpService ofFormats(Object implementation, SerializationFormat defaultSerializationFormat, SerializationFormat ... otherSupportedSerializationFormats) {
        Objects.requireNonNull(otherSupportedSerializationFormats, "otherSupportedSerializationFormats");
        return THttpService.ofFormats(implementation, defaultSerializationFormat, Arrays.asList(otherSupportedSerializationFormats));
    }

    public static THttpService ofFormats(Object implementation, SerializationFormat defaultSerializationFormat, Iterable<SerializationFormat> otherSupportedSerializationFormats) {
        return THttpService.builder().addService(implementation).defaultSerializationFormat(defaultSerializationFormat).otherSerializationFormats(otherSupportedSerializationFormats).build();
    }

    public static Function<? super RpcService, THttpService> newDecorator() {
        return THttpService.newDecorator(ThriftSerializationFormats.BINARY);
    }

    public static Function<? super RpcService, THttpService> newDecorator(SerializationFormat defaultSerializationFormat) {
        return THttpService.builder().defaultSerializationFormat(defaultSerializationFormat).newDecorator();
    }

    public static Function<? super RpcService, THttpService> newDecorator(SerializationFormat defaultSerializationFormat, SerializationFormat ... otherSupportedSerializationFormats) {
        Objects.requireNonNull(otherSupportedSerializationFormats, "otherSupportedSerializationFormats");
        return THttpService.newDecorator(defaultSerializationFormat, (Iterable<SerializationFormat>)ImmutableSet.copyOf((Object[])otherSupportedSerializationFormats));
    }

    public static Function<? super RpcService, THttpService> newDecorator(SerializationFormat defaultSerializationFormat, Iterable<SerializationFormat> otherSupportedSerializationFormats) {
        return THttpService.builder().defaultSerializationFormat(defaultSerializationFormat).otherSerializationFormats(otherSupportedSerializationFormats).newDecorator();
    }

    THttpService(RpcService delegate, SerializationFormat defaultSerializationFormat, Set<SerializationFormat> supportedSerializationFormats, int maxRequestStringLength, int maxRequestContainerLength, BiFunction<? super ServiceRequestContext, ? super Throwable, ? extends RpcResponse> exceptionHandler) {
        super((Service)delegate);
        this.thriftService = THttpService.findThriftService(delegate);
        this.defaultSerializationFormat = defaultSerializationFormat;
        this.supportedSerializationFormats = ImmutableSet.copyOf(supportedSerializationFormats);
        this.maxRequestStringLength = maxRequestStringLength;
        this.maxRequestContainerLength = maxRequestContainerLength;
        this.exceptionHandler = exceptionHandler;
        this.responseProtocolFactories = (Map)supportedSerializationFormats.stream().collect(ImmutableMap.toImmutableMap(Function.identity(), format -> ThriftSerializationFormats.protocolFactory(format, 0, 0)));
        this.requestProtocolFactories = this.responseProtocolFactories;
    }

    private static ThriftCallService findThriftService(Service<?, ?> delegate) {
        ThriftCallService thriftService = (ThriftCallService)delegate.as(ThriftCallService.class);
        Preconditions.checkState((thriftService != null ? 1 : 0) != 0, (String)"service being decorated is not a ThriftCallService: %s", delegate);
        return thriftService;
    }

    public Map<String, ThriftServiceEntry> entries() {
        return this.thriftService.entries();
    }

    public Set<SerializationFormat> supportedSerializationFormats() {
        return this.supportedSerializationFormats;
    }

    public SerializationFormat defaultSerializationFormat() {
        return this.defaultSerializationFormat;
    }

    public void serviceAdded(ServiceConfig cfg) throws Exception {
        if (this.maxRequestStringLength == -1) {
            this.maxRequestStringLength = Ints.saturatedCast((long)cfg.maxRequestLength());
        }
        if (this.maxRequestContainerLength == -1) {
            this.maxRequestContainerLength = Ints.saturatedCast((long)cfg.maxRequestLength());
        }
        this.requestProtocolFactories = (Map)this.supportedSerializationFormats.stream().collect(ImmutableMap.toImmutableMap(Function.identity(), format -> ThriftSerializationFormats.protocolFactory(format, this.maxRequestStringLength, this.maxRequestContainerLength)));
        super.serviceAdded(cfg);
    }

    public HttpResponse serve(ServiceRequestContext ctx, HttpRequest req) throws Exception {
        if (req.method() != HttpMethod.POST) {
            return HttpResponse.of((HttpStatus)HttpStatus.METHOD_NOT_ALLOWED);
        }
        SerializationFormat serializationFormat = this.determineSerializationFormat(req);
        if (serializationFormat == null) {
            return HttpResponse.of((HttpStatus)HttpStatus.UNSUPPORTED_MEDIA_TYPE, (MediaType)MediaType.PLAIN_TEXT_UTF_8, (String)PROTOCOL_NOT_SUPPORTED);
        }
        if (!THttpService.validateAcceptHeaders(req, serializationFormat)) {
            return HttpResponse.of((HttpStatus)HttpStatus.NOT_ACCEPTABLE, (MediaType)MediaType.PLAIN_TEXT_UTF_8, (String)ACCEPT_THRIFT_PROTOCOL_MUST_MATCH_CONTENT_TYPE);
        }
        CompletableFuture responseFuture = new CompletableFuture();
        HttpResponse res = HttpResponse.from(responseFuture);
        ctx.logBuilder().serializationFormat(serializationFormat);
        ctx.logBuilder().defer(RequestLogProperty.REQUEST_CONTENT);
        ((CompletableFuture)req.aggregate(AggregationOptions.usePooledObjects((ByteBufAllocator)ctx.alloc(), (EventExecutor)ctx.eventLoop())).handle((aReq, cause) -> {
            if (cause != null) {
                HttpResponse errorRes = ctx.config().verboseResponses() ? HttpResponse.of((HttpStatus)HttpStatus.INTERNAL_SERVER_ERROR, (MediaType)MediaType.PLAIN_TEXT_UTF_8, (String)Exceptions.traceText((Throwable)cause)) : HttpResponse.of((HttpStatus)HttpStatus.INTERNAL_SERVER_ERROR);
                responseFuture.complete(errorRes);
                return null;
            }
            this.decodeAndInvoke(ctx, (AggregatedHttpRequest)aReq, serializationFormat, responseFuture);
            return null;
        })).exceptionally(CompletionActions::log);
        return res;
    }

    public ExchangeType exchangeType(RoutingContext routingContext) {
        return ExchangeType.UNARY;
    }

    @Nullable
    private SerializationFormat determineSerializationFormat(HttpRequest req) {
        RequestHeaders headers = req.headers();
        MediaType contentType = headers.contentType();
        if (contentType != null) {
            SerializationFormat serializationFormat = this.findSerializationFormat(contentType);
            if (serializationFormat == null) {
                if (!("text".equals(contentType.type()) && "plain".equals(contentType.subtype()) || "application".equals(contentType.type()) && "octet-stream".equals(contentType.subtype()))) {
                    return null;
                }
            } else {
                return serializationFormat;
            }
        }
        return this.defaultSerializationFormat();
    }

    private static boolean validateAcceptHeaders(HttpRequest req, SerializationFormat serializationFormat) {
        List acceptTypes = req.headers().accept();
        return acceptTypes.isEmpty() || serializationFormat.mediaTypes().match((Iterable)acceptTypes) != null;
    }

    @Nullable
    private SerializationFormat findSerializationFormat(MediaType contentType) {
        for (SerializationFormat format : this.supportedSerializationFormats) {
            if (!format.isAccepted(contentType)) continue;
            return format;
        }
        return null;
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     * Enabled aggressive block sorting
     * Enabled unnecessary exception pruning
     * Enabled aggressive exception aggregation
     */
    private void decodeAndInvoke(ServiceRequestContext ctx, AggregatedHttpRequest req, SerializationFormat serializationFormat, CompletableFuture<HttpResponse> httpRes) {
        RpcRequest decodedReq;
        ThriftFunction f;
        int seqId;
        try {
            HttpData content = req.content();
            try {
                String methodName;
                String serviceName;
                TMessage header;
                ByteBuf buf = content.byteBuf();
                TByteBufTransport inTransport = new TByteBufTransport(buf);
                TProtocol inProto = this.requestProtocolFactories.get(serializationFormat).getProtocol((TTransport)inTransport);
                try {
                    ThriftProtocolUtil.maybeCheckMessageLength(serializationFormat, buf, this.maxRequestStringLength);
                    header = inProto.readMessageBegin();
                }
                catch (Exception e) {
                    String message;
                    HttpStatus httpStatus;
                    logger.debug("{} Failed to decode a {} header:", new Object[]{ctx, serializationFormat, e});
                    if (e instanceof TProtocolException && ((TProtocolException)e).getType() == 3) {
                        httpStatus = HttpStatus.REQUEST_ENTITY_TOO_LARGE;
                        message = e.getMessage();
                    } else {
                        httpStatus = HttpStatus.BAD_REQUEST;
                        message = "Failed to decode a " + serializationFormat + " header";
                    }
                    if (ctx.config().verboseResponses()) {
                        message = message + '\n' + Exceptions.traceText((Throwable)e);
                    }
                    httpRes.complete(HttpResponse.of((HttpStatus)httpStatus, (MediaType)MediaType.PLAIN_TEXT_UTF_8, (String)message));
                    if (content != null) {
                        content.close();
                    }
                    ctx.logBuilder().requestContent(null, null);
                    return;
                }
                seqId = header.seqid;
                byte typeValue = header.type;
                int colonIdx = header.name.indexOf(58);
                if (colonIdx < 0) {
                    serviceName = "";
                    methodName = header.name;
                } else {
                    serviceName = header.name.substring(0, colonIdx);
                    methodName = header.name.substring(colonIdx + 1);
                }
                if (typeValue != 1 && typeValue != 4) {
                    TApplicationException cause = new TApplicationException(2, "unexpected TMessageType: " + THttpService.typeString(typeValue));
                    this.handlePreDecodeException(ctx, httpRes, (Throwable)cause, serializationFormat, seqId, methodName);
                    return;
                }
                ThriftServiceEntry entry = this.entries().get(serviceName);
                ThriftFunction thriftFunction = f = entry != null ? entry.metadata.function(methodName) : null;
                if (f == null) {
                    TApplicationException cause = new TApplicationException(1, "unknown method: " + header.name);
                    this.handlePreDecodeException(ctx, httpRes, (Throwable)cause, serializationFormat, seqId, methodName);
                    return;
                }
                try {
                    TBase<?, ?> args = f.newArgs();
                    args.read(inProto);
                    inProto.readMessageEnd();
                    decodedReq = THttpService.toRpcRequest(f.serviceType(), header.name, args);
                    ctx.logBuilder().requestContent((Object)decodedReq, (Object)new ThriftCall(header, args));
                }
                catch (Exception e) {
                    logger.debug("{} Failed to decode Thrift arguments:", (Object)ctx, (Object)e);
                    TApplicationException cause = new TApplicationException(7, "failed to decode arguments: " + e);
                    this.handlePreDecodeException(ctx, httpRes, (Throwable)cause, serializationFormat, seqId, methodName);
                    if (content != null) {
                        content.close();
                    }
                    ctx.logBuilder().requestContent(null, null);
                    return;
                }
            }
            finally {
                if (content != null) {
                    try {
                        content.close();
                    }
                    catch (Throwable throwable) {
                        Throwable throwable2;
                        throwable2.addSuppressed(throwable);
                    }
                }
            }
        }
        finally {
            ctx.logBuilder().requestContent(null, null);
        }
        this.invoke(ctx, serializationFormat, seqId, f, decodedReq, httpRes);
    }

    private static String typeString(byte typeValue) {
        switch (typeValue) {
            case 1: {
                return "CALL";
            }
            case 2: {
                return "REPLY";
            }
            case 3: {
                return "EXCEPTION";
            }
            case 4: {
                return "ONEWAY";
            }
        }
        return "UNKNOWN(" + (typeValue & 0xFF) + ')';
    }

    private void invoke(ServiceRequestContext ctx, SerializationFormat serializationFormat, int seqId, ThriftFunction func, RpcRequest call, CompletableFuture<HttpResponse> res) {
        RpcResponse reply;
        try (SafeCloseable ignored = ctx.push();){
            reply = (RpcResponse)((Service)this.unwrap()).serve(ctx, (Request)call);
        }
        catch (Throwable cause2) {
            this.handleException(ctx, res, serializationFormat, seqId, func, cause2);
            return;
        }
        reply.handle((result, cause) -> {
            if (func.isOneWay()) {
                THttpService.handleOneWaySuccess(ctx, reply, res, serializationFormat);
                return null;
            }
            if (cause != null) {
                this.handleException(ctx, res, serializationFormat, seqId, func, (Throwable)cause);
                return null;
            }
            try {
                this.handleSuccess(ctx, reply, res, serializationFormat, seqId, func, result);
            }
            catch (Throwable t) {
                this.handleException(ctx, res, serializationFormat, seqId, func, t);
                return null;
            }
            return null;
        }).exceptionally(CompletionActions::log);
    }

    private static RpcRequest toRpcRequest(Class<?> serviceType, String method, TBase<?, ?> thriftArgs) {
        Objects.requireNonNull(thriftArgs, "thriftArgs");
        Set fields = FieldMetaData.getStructMetaDataMap(thriftArgs.getClass()).keySet();
        int numFields = fields.size();
        switch (numFields) {
            case 0: {
                return RpcRequest.of(serviceType, (String)method);
            }
            case 1: {
                return RpcRequest.of(serviceType, (String)method, (Object)ThriftFieldAccess.get(thriftArgs, (TFieldIdEnum)fields.iterator().next()));
            }
        }
        ArrayList<Object> list = new ArrayList<Object>(numFields);
        for (TFieldIdEnum field : fields) {
            list.add(ThriftFieldAccess.get(thriftArgs, field));
        }
        return RpcRequest.of(serviceType, (String)method, list);
    }

    private void handleSuccess(ServiceRequestContext ctx, RpcResponse rpcRes, CompletableFuture<HttpResponse> httpRes, SerializationFormat serializationFormat, int seqId, ThriftFunction func, Object returnValue) {
        TBase<?, ?> wrappedResult = func.newResult();
        func.setSuccess(wrappedResult, returnValue);
        THttpService.respond(serializationFormat, this.encodeSuccess(ctx, rpcRes, serializationFormat, func.name(), seqId, wrappedResult), httpRes);
    }

    private static void handleOneWaySuccess(ServiceRequestContext ctx, RpcResponse rpcRes, CompletableFuture<HttpResponse> httpRes, SerializationFormat serializationFormat) {
        ctx.logBuilder().responseContent((Object)rpcRes, null);
        THttpService.respond(serializationFormat, HttpData.empty(), httpRes);
    }

    private void handleException(ServiceRequestContext ctx, CompletableFuture<HttpResponse> res, SerializationFormat serializationFormat, int seqId, ThriftFunction func, Throwable cause) {
        RpcResponse response = this.handleException(ctx, Exceptions.peel((Throwable)cause));
        response.handle((result, convertedCause) -> {
            if (convertedCause != null) {
                this.handleException(ctx, response, res, serializationFormat, seqId, func, (Throwable)convertedCause);
            } else {
                this.handleSuccess(ctx, response, res, serializationFormat, seqId, func, result);
            }
            return null;
        });
    }

    private RpcResponse handleException(ServiceRequestContext ctx, Throwable cause) {
        RpcResponse res = this.exceptionHandler.apply((ServiceRequestContext)ctx, cause);
        if (res == null) {
            logger.warn("exceptionHandler.apply() returned null.");
            return RpcResponse.ofFailure((Throwable)cause);
        }
        return res;
    }

    private void handleException(ServiceRequestContext ctx, RpcResponse rpcRes, CompletableFuture<HttpResponse> httpRes, SerializationFormat serializationFormat, int seqId, ThriftFunction func, Throwable cause) {
        if (cause instanceof HttpStatusException || cause instanceof HttpResponseException) {
            httpRes.complete(HttpResponse.ofFailure((Throwable)cause));
            return;
        }
        TBase<?, ?> result = func.newResult();
        HttpData content = func.setException(result, cause) ? this.encodeSuccess(ctx, rpcRes, serializationFormat, func.name(), seqId, result) : this.encodeException(ctx, rpcRes, serializationFormat, seqId, func.name(), cause);
        THttpService.respond(serializationFormat, content, httpRes);
    }

    private void handlePreDecodeException(ServiceRequestContext ctx, CompletableFuture<HttpResponse> httpRes, Throwable cause, SerializationFormat serializationFormat, int seqId, String methodName) {
        HttpData content = this.encodeException(ctx, RpcResponse.ofFailure((Throwable)cause), serializationFormat, seqId, methodName, cause);
        THttpService.respond(serializationFormat, content, httpRes);
    }

    private static void respond(SerializationFormat serializationFormat, HttpData content, CompletableFuture<HttpResponse> res) {
        res.complete(HttpResponse.of((HttpStatus)HttpStatus.OK, (MediaType)serializationFormat.mediaType(), (HttpData)content));
    }

    private HttpData encodeSuccess(ServiceRequestContext ctx, RpcResponse reply, SerializationFormat serializationFormat, String methodName, int seqId, TBase<?, ?> result) {
        ByteBuf buf = ctx.alloc().buffer(128);
        boolean success = false;
        try {
            TByteBufTransport transport = new TByteBufTransport(buf);
            TProtocol outProto = this.responseProtocolFactories.get(serializationFormat).getProtocol((TTransport)transport);
            TMessage header = new TMessage(methodName, 2, seqId);
            outProto.writeMessageBegin(header);
            result.write(outProto);
            outProto.writeMessageEnd();
            ctx.logBuilder().responseContent((Object)reply, (Object)new ThriftReply(header, result));
            HttpData encoded = HttpData.wrap((ByteBuf)buf);
            success = true;
            HttpData httpData = encoded;
            return httpData;
        }
        catch (TException e) {
            throw new Error(e);
        }
        finally {
            if (!success) {
                buf.release();
            }
        }
    }

    private HttpData encodeException(ServiceRequestContext ctx, RpcResponse reply, SerializationFormat serializationFormat, int seqId, String methodName, Throwable cause) {
        TApplicationException appException;
        if (cause instanceof TApplicationException) {
            appException = (TApplicationException)cause;
        } else {
            appException = ctx.config().verboseResponses() ? new TApplicationException(6, "\n---- BEGIN server-side trace ----\n" + Exceptions.traceText((Throwable)cause) + "---- END server-side trace ----") : new TApplicationException(6);
            appException.initCause(cause);
        }
        ByteBuf buf = ctx.alloc().buffer(128);
        boolean success = false;
        try {
            TByteBufTransport transport = new TByteBufTransport(buf);
            TProtocol outProto = this.responseProtocolFactories.get(serializationFormat).getProtocol((TTransport)transport);
            TMessage header = new TMessage(methodName, 3, seqId);
            outProto.writeMessageBegin(header);
            appException.write(outProto);
            outProto.writeMessageEnd();
            ctx.logBuilder().responseContent((Object)reply, (Object)new ThriftReply(header, appException));
            HttpData encoded = HttpData.wrap((ByteBuf)buf);
            success = true;
            HttpData httpData = encoded;
            return httpData;
        }
        catch (TException e) {
            throw new Error(e);
        }
        finally {
            if (!success) {
                buf.release();
            }
        }
    }
}

