/*
 * Decompiled with CFR 0.152.
 */
package com.linecorp.armeria.server.thrift;

import com.linecorp.armeria.common.AggregatedHttpRequest;
import com.linecorp.armeria.common.HttpData;
import com.linecorp.armeria.common.HttpHeaderNames;
import com.linecorp.armeria.common.HttpMethod;
import com.linecorp.armeria.common.HttpRequest;
import com.linecorp.armeria.common.HttpResponse;
import com.linecorp.armeria.common.HttpStatus;
import com.linecorp.armeria.common.MediaType;
import com.linecorp.armeria.common.Request;
import com.linecorp.armeria.common.RequestHeaders;
import com.linecorp.armeria.common.RpcRequest;
import com.linecorp.armeria.common.RpcResponse;
import com.linecorp.armeria.common.SerializationFormat;
import com.linecorp.armeria.common.thrift.ThriftCall;
import com.linecorp.armeria.common.thrift.ThriftProtocolFactories;
import com.linecorp.armeria.common.thrift.ThriftReply;
import com.linecorp.armeria.common.thrift.ThriftSerializationFormats;
import com.linecorp.armeria.common.unsafe.PooledHttpData;
import com.linecorp.armeria.common.unsafe.PooledHttpRequest;
import com.linecorp.armeria.common.util.CompletionActions;
import com.linecorp.armeria.common.util.Exceptions;
import com.linecorp.armeria.common.util.SafeCloseable;
import com.linecorp.armeria.internal.common.thrift.TByteBufTransport;
import com.linecorp.armeria.internal.common.thrift.ThriftFieldAccess;
import com.linecorp.armeria.internal.common.thrift.ThriftFunction;
import com.linecorp.armeria.internal.shaded.guava.base.Preconditions;
import com.linecorp.armeria.internal.shaded.guava.collect.ImmutableSet;
import com.linecorp.armeria.internal.shaded.guava.collect.Iterables;
import com.linecorp.armeria.server.DecoratingService;
import com.linecorp.armeria.server.HttpResponseException;
import com.linecorp.armeria.server.HttpService;
import com.linecorp.armeria.server.HttpStatusException;
import com.linecorp.armeria.server.RpcService;
import com.linecorp.armeria.server.Service;
import com.linecorp.armeria.server.ServiceRequestContext;
import com.linecorp.armeria.server.thrift.THttpServiceBuilder;
import com.linecorp.armeria.server.thrift.ThriftCallService;
import com.linecorp.armeria.server.thrift.ThriftServiceEntry;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufHolder;
import io.netty.util.concurrent.EventExecutor;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.function.Function;
import javax.annotation.Nullable;
import org.apache.thrift.TApplicationException;
import org.apache.thrift.TBase;
import org.apache.thrift.TException;
import org.apache.thrift.TFieldIdEnum;
import org.apache.thrift.meta_data.FieldMetaData;
import org.apache.thrift.protocol.TMessage;
import org.apache.thrift.protocol.TMessageType;
import org.apache.thrift.protocol.TProtocol;
import org.apache.thrift.transport.TTransport;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * An {@link HttpService} that serves Thrift RPC over HTTP.
 *
 * <p>The body of each HTTP {@code POST} request is decoded as a single Thrift message. That message is
 * converted into an {@link RpcRequest} and passed to the decorated {@link RpcService}. The resulting
 * {@link RpcResponse} is encoded as a Thrift {@code REPLY} or {@code EXCEPTION} message and normally
 * sent back with {@code 200 OK}. The decorated service must be, or wrap, a {@link ThriftCallService}.
 */
public final class THttpService
        extends DecoratingService<RpcRequest, RpcResponse, HttpRequest, HttpResponse>
        implements HttpService {

    private static final Logger logger = LoggerFactory.getLogger(THttpService.class);

    private static final String PROTOCOL_NOT_SUPPORTED = "Specified content-type not supported";

    private static final String ACCEPT_THRIFT_PROTOCOL_MUST_MATCH_CONTENT_TYPE =
            "Thrift protocol specified in Accept header must match the one specified in the content-type header";

    private static final SerializationFormat[] EMPTY_FORMATS = new SerializationFormat[0];

    /** Supported formats without duplicates; element 0 is always the default format. */
    private final SerializationFormat[] supportedSerializationFormatArray;

    /** Immutable copy of {@link #supportedSerializationFormatArray}, in the same order. */
    private final Set<SerializationFormat> supportedSerializationFormats;

    /** The {@link ThriftCallService} found in the delegate chain; backs {@link #entries()}. */
    private final ThriftCallService thriftService;

    /**
     * Returns a new {@link THttpServiceBuilder}.
     */
    public static THttpServiceBuilder builder() {
        return new THttpServiceBuilder();
    }

    /**
     * Creates a new {@link THttpService} for the given Thrift service implementation.
     *
     * <p>{@link ThriftSerializationFormats#BINARY} is the default format. Every format in
     * {@link ThriftSerializationFormats#values()} is supported.
     */
    public static THttpService of(Object implementation) {
        return of(implementation, ThriftSerializationFormats.BINARY);
    }

    /**
     * Creates a new {@link THttpService} for the given Thrift service implementation.
     *
     * <p>Uses {@code defaultSerializationFormat} as the default format and supports every format in
     * {@link ThriftSerializationFormats#values()}.
     *
     * @throws NullPointerException if {@code defaultSerializationFormat} is {@code null}
     */
    public static THttpService of(Object implementation, SerializationFormat defaultSerializationFormat) {
        return new THttpService(ThriftCallService.of(implementation),
                                newSupportedSerializationFormats(defaultSerializationFormat,
                                                                 ThriftSerializationFormats.values()));
    }

    /**
     * Creates a new {@link THttpService} for the given Thrift service implementation. The service
     * supports only {@code defaultSerializationFormat} and {@code otherSupportedSerializationFormats}.
     *
     * @throws NullPointerException if {@code defaultSerializationFormat} or
     *                              {@code otherSupportedSerializationFormats} is {@code null}
     */
    public static THttpService ofFormats(Object implementation,
                                         SerializationFormat defaultSerializationFormat,
                                         SerializationFormat... otherSupportedSerializationFormats) {
        Objects.requireNonNull(otherSupportedSerializationFormats, "otherSupportedSerializationFormats");
        return ofFormats(implementation, defaultSerializationFormat,
                         Arrays.asList(otherSupportedSerializationFormats));
    }

    /**
     * Creates a new {@link THttpService} for the given Thrift service implementation. The service
     * supports only {@code defaultSerializationFormat} and {@code otherSupportedSerializationFormats}.
     *
     * @throws NullPointerException if {@code defaultSerializationFormat} or
     *                              {@code otherSupportedSerializationFormats} is {@code null}
     */
    public static THttpService ofFormats(Object implementation,
                                         SerializationFormat defaultSerializationFormat,
                                         Iterable<SerializationFormat> otherSupportedSerializationFormats) {
        return new THttpService(ThriftCallService.of(implementation),
                                newSupportedSerializationFormats(defaultSerializationFormat,
                                                                 otherSupportedSerializationFormats));
    }

    /**
     * Returns a decorator that wraps an {@link RpcService} into a {@link THttpService}.
     *
     * <p>The decorator uses {@link ThriftSerializationFormats#BINARY} as the default format. Applying it
     * throws {@link IllegalStateException} if the service is not, and does not wrap, a
     * {@link ThriftCallService}.
     */
    public static Function<? super RpcService, THttpService> newDecorator() {
        return newDecorator(ThriftSerializationFormats.BINARY);
    }

    /**
     * Returns a decorator that wraps an {@link RpcService} into a {@link THttpService}.
     *
     * <p>The decorator uses the given default format and supports every format in
     * {@link ThriftSerializationFormats#values()}.
     *
     * @throws NullPointerException if {@code defaultSerializationFormat} is {@code null}
     */
    public static Function<? super RpcService, THttpService> newDecorator(
            SerializationFormat defaultSerializationFormat) {
        final SerializationFormat[] supportedSerializationFormatArray =
                newSupportedSerializationFormats(defaultSerializationFormat, ThriftSerializationFormats.values());
        return delegate -> new THttpService(delegate, supportedSerializationFormatArray);
    }

    /**
     * Returns a decorator that wraps an {@link RpcService} into a {@link THttpService}. The resulting
     * service supports only the given formats.
     *
     * @throws NullPointerException if any argument is {@code null}
     */
    public static Function<? super RpcService, THttpService> newDecorator(
            SerializationFormat defaultSerializationFormat,
            SerializationFormat... otherSupportedSerializationFormats) {
        Objects.requireNonNull(otherSupportedSerializationFormats, "otherSupportedSerializationFormats");
        return newDecorator(defaultSerializationFormat, Arrays.asList(otherSupportedSerializationFormats));
    }

    /**
     * Returns a decorator that wraps an {@link RpcService} into a {@link THttpService}. The resulting
     * service supports only the given formats.
     *
     * @throws NullPointerException if any argument is {@code null}
     */
    public static Function<? super RpcService, THttpService> newDecorator(
            SerializationFormat defaultSerializationFormat,
            Iterable<SerializationFormat> otherSupportedSerializationFormats) {
        final SerializationFormat[] supportedSerializationFormatArray =
                newSupportedSerializationFormats(defaultSerializationFormat, otherSupportedSerializationFormats);
        return delegate -> new THttpService(delegate, supportedSerializationFormatArray);
    }

    /**
     * Builds the duplicate-free, ordered array of supported formats, with the default format first.
     */
    private static SerializationFormat[] newSupportedSerializationFormats(
            SerializationFormat defaultSerializationFormat,
            Iterable<SerializationFormat> otherSupportedSerializationFormats) {
        Objects.requireNonNull(defaultSerializationFormat, "defaultSerializationFormat");
        Objects.requireNonNull(otherSupportedSerializationFormats, "otherSupportedSerializationFormats");

        // LinkedHashSet keeps the default first and drops duplicates while preserving order.
        final Set<SerializationFormat> set = new LinkedHashSet<>();
        set.add(defaultSerializationFormat);
        Iterables.addAll(set, otherSupportedSerializationFormats);
        return set.toArray(EMPTY_FORMATS);
    }

    THttpService(RpcService delegate, SerializationFormat[] supportedSerializationFormatArray) {
        super(delegate);
        thriftService = findThriftService(delegate);
        this.supportedSerializationFormatArray = supportedSerializationFormatArray;
        supportedSerializationFormats = ImmutableSet.copyOf(supportedSerializationFormatArray);
    }

    /**
     * Finds the {@link ThriftCallService} in the delegate chain.
     *
     * @throws IllegalStateException if {@code delegate} is not, and does not wrap, a ThriftCallService
     */
    private static ThriftCallService findThriftService(Service<?, ?> delegate) {
        final ThriftCallService thriftService = delegate.as(ThriftCallService.class);
        Preconditions.checkState(thriftService != null,
                                 "service being decorated is not a ThriftCallService: %s", delegate);
        return thriftService;
    }

    /**
     * Returns the served Thrift services, keyed by service name. A message whose name has no
     * {@code "service:"} prefix is looked up under the empty string.
     */
    public Map<String, ThriftServiceEntry> entries() {
        return thriftService.entries();
    }

    /**
     * Returns the immutable set of supported formats. The default format is the first element.
     */
    public Set<SerializationFormat> supportedSerializationFormats() {
        return supportedSerializationFormats;
    }

    /**
     * Returns the format used when a request does not specify a Thrift-specific {@code content-type}.
     */
    public SerializationFormat defaultSerializationFormat() {
        return supportedSerializationFormatArray[0];
    }

    /**
     * Serves one Thrift call.
     *
     * <p>The following requests are rejected before the body is read:
     * <ul>
     *   <li>non-{@code POST} requests: {@code 405 Method Not Allowed}</li>
     *   <li>an unsupported {@code content-type}: {@code 415 Unsupported Media Type}</li>
     *   <li>an {@code Accept} header that does not match the chosen format: {@code 406 Not Acceptable}</li>
     * </ul>
     *
     * <p>Otherwise the body is aggregated and decoded asynchronously, and the returned response
     * completes once the call has been handled.
     */
    @Override
    public HttpResponse serve(ServiceRequestContext ctx, HttpRequest req) throws Exception {
        if (req.method() != HttpMethod.POST) {
            return HttpResponse.of(HttpStatus.METHOD_NOT_ALLOWED);
        }

        final SerializationFormat serializationFormat = determineSerializationFormat(req);
        if (serializationFormat == null) {
            return HttpResponse.of(HttpStatus.UNSUPPORTED_MEDIA_TYPE, MediaType.PLAIN_TEXT_UTF_8,
                                   PROTOCOL_NOT_SUPPORTED);
        }

        if (!validateAcceptHeaders(req, serializationFormat)) {
            return HttpResponse.of(HttpStatus.NOT_ACCEPTABLE, MediaType.PLAIN_TEXT_UTF_8,
                                   ACCEPT_THRIFT_PROTOCOL_MUST_MATCH_CONTENT_TYPE);
        }

        final CompletableFuture<HttpResponse> responseFuture = new CompletableFuture<>();
        final HttpResponse res = HttpResponse.from(responseFuture);
        ctx.logBuilder().serializationFormat(serializationFormat);
        // Request content is recorded later by decodeAndInvoke(), once the Thrift message is decoded.
        ctx.logBuilder().deferRequestContent();

        PooledHttpRequest.of(req).aggregateWithPooledObjects(ctx.eventLoop(), ctx.alloc())
                         .handle((aReq, cause) -> {
                             if (cause != null) {
                                 final HttpResponse errorRes;
                                 if (ctx.config().verboseResponses()) {
                                     errorRes = HttpResponse.of(HttpStatus.INTERNAL_SERVER_ERROR,
                                                                MediaType.PLAIN_TEXT_UTF_8,
                                                                Exceptions.traceText(cause));
                                 } else {
                                     errorRes = HttpResponse.of(HttpStatus.INTERNAL_SERVER_ERROR);
                                 }
                                 responseFuture.complete(errorRes);
                                 return null;
                             }

                             try {
                                 decodeAndInvoke(ctx, aReq, serializationFormat, responseFuture);
                             } catch (Throwable t) {
                                 // Never leave the response pending. This is a no-op if a response
                                 // was already produced. Rethrow so the failure is still logged below.
                                 responseFuture.completeExceptionally(t);
                                 throw t;
                             }
                             return null;
                         })
                         .exceptionally(CompletionActions::log);
        return res;
    }

    /**
     * Chooses the serialization format for {@code req}.
     *
     * <p>The selection works as follows:
     * <ul>
     *   <li>If the {@code content-type} matches a supported format, that format is used.</li>
     *   <li>If there is no {@code content-type}, or it is {@code text/plain} or
     *       {@code application/octet-stream}, the default format is used.</li>
     *   <li>Otherwise {@code null} is returned, meaning the content type is unsupported.</li>
     * </ul>
     */
    @Nullable
    private SerializationFormat determineSerializationFormat(HttpRequest req) {
        final RequestHeaders headers = req.headers();
        final MediaType contentType = headers.contentType();
        if (contentType == null) {
            return defaultSerializationFormat();
        }

        final SerializationFormat serializationFormat = findSerializationFormat(contentType);
        if (serializationFormat != null) {
            return serializationFormat;
        }

        // Generic content types carry no Thrift protocol hint, so fall back to the default.
        final String type = contentType.type();
        final String subtype = contentType.subtype();
        final boolean generic = "text".equals(type) && "plain".equals(subtype) ||
                                "application".equals(type) && "octet-stream".equals(subtype);
        return generic ? defaultSerializationFormat() : null;
    }

    /**
     * Returns {@code true} in either of two cases: the request has no {@code Accept} header, or one of
     * its {@code Accept} values matches a media type of {@code serializationFormat}.
     */
    private static boolean validateAcceptHeaders(HttpRequest req, SerializationFormat serializationFormat) {
        final List<String> acceptHeaders = req.headers().getAll(HttpHeaderNames.ACCEPT);
        return acceptHeaders.isEmpty() ||
               serializationFormat.mediaTypes().matchHeaders(acceptHeaders) != null;
    }

    /**
     * Returns the first supported format that accepts {@code contentType}, or {@code null} if none does.
     */
    @Nullable
    private SerializationFormat findSerializationFormat(MediaType contentType) {
        for (SerializationFormat format : supportedSerializationFormatArray) {
            if (format.isAccepted(contentType)) {
                return format;
            }
        }
        return null;
    }

    /**
     * Decodes the Thrift message in {@code req}'s body and invokes the matching function.
     *
     * <p>Decoding failures complete {@code httpRes} with an error right away:
     * <ul>
     *   <li>an unreadable message header: {@code 400 Bad Request}</li>
     *   <li>an unexpected message type, unknown method or undecodable arguments: a Thrift
     *       {@code EXCEPTION} message</li>
     * </ul>
     *
     * <p>The request body buffer is owned by this method and is released exactly once, in the
     * {@code finally} block, on every exit path.
     */
    private void decodeAndInvoke(ServiceRequestContext ctx, AggregatedHttpRequest req,
                                 SerializationFormat serializationFormat,
                                 CompletableFuture<HttpResponse> httpRes) {
        final HttpData content = req.content();
        final ByteBuf buf;
        if (content instanceof ByteBufHolder) {
            // Pooled content: decode directly from the underlying buffer.
            buf = ((ByteBufHolder) content).content();
        } else {
            buf = ctx.alloc().buffer(content.length());
            buf.writeBytes(content.array());
        }

        final TByteBufTransport inTransport = new TByteBufTransport(buf);
        final TProtocol inProto = ThriftProtocolFactories.get(serializationFormat).getProtocol(inTransport);

        final int seqId;
        final ThriftFunction f;
        final RpcRequest decodedReq;

        try {
            final TMessage header;
            try {
                header = inProto.readMessageBegin();
            } catch (Exception e) {
                logger.debug("{} Failed to decode a {} header:", ctx, serializationFormat, e);

                final HttpResponse errorRes;
                if (ctx.config().verboseResponses()) {
                    errorRes = HttpResponse.of(HttpStatus.BAD_REQUEST, MediaType.PLAIN_TEXT_UTF_8,
                                               "Failed to decode a %s header: %s", serializationFormat,
                                               Exceptions.traceText(e));
                } else {
                    errorRes = HttpResponse.of(HttpStatus.BAD_REQUEST, MediaType.PLAIN_TEXT_UTF_8,
                                               "Failed to decode a %s header", serializationFormat);
                }
                httpRes.complete(errorRes);
                return; // 'buf' is released by the finally block.
            }

            seqId = header.seqid;

            // A message name may be prefixed with "serviceName:" for multiplexed services.
            final byte typeValue = header.type;
            final int colonIdx = header.name.indexOf(':');
            final String serviceName;
            final String methodName;
            if (colonIdx < 0) {
                serviceName = "";
                methodName = header.name;
            } else {
                serviceName = header.name.substring(0, colonIdx);
                methodName = header.name.substring(colonIdx + 1);
            }

            // Only calls (including one-way calls) can be served.
            if (typeValue != TMessageType.CALL && typeValue != TMessageType.ONEWAY) {
                final TApplicationException cause = new TApplicationException(
                        TApplicationException.INVALID_MESSAGE_TYPE,
                        "unexpected TMessageType: " + typeString(typeValue));
                handlePreDecodeException(ctx, httpRes, cause, serializationFormat, seqId, methodName);
                return;
            }

            // Ensure that such a method exists.
            final ThriftServiceEntry entry = entries().get(serviceName);
            f = entry != null ? entry.metadata.function(methodName) : null;
            if (f == null) {
                final TApplicationException cause = new TApplicationException(
                        TApplicationException.UNKNOWN_METHOD, "unknown method: " + header.name);
                handlePreDecodeException(ctx, httpRes, cause, serializationFormat, seqId, methodName);
                return;
            }

            ctx.logBuilder().name(f.serviceType().getName(), methodName);

            // Decode the invocation arguments.
            try {
                final TBase<?, ?> args = f.newArgs();
                args.read(inProto);
                inProto.readMessageEnd();

                decodedReq = toRpcRequest(f.serviceType(), header.name, args);
                ctx.logBuilder().requestContent(decodedReq, new ThriftCall(header, args));
            } catch (Exception e) {
                logger.debug("{} Failed to decode Thrift arguments:", ctx, e);

                final TApplicationException cause = new TApplicationException(
                        TApplicationException.PROTOCOL_ERROR, "failed to decode arguments: " + e);
                handlePreDecodeException(ctx, httpRes, cause, serializationFormat, seqId, methodName);
                return; // 'buf' is released by the finally block.
            }
        } finally {
            buf.release();
            // Completes the request content deferred in serve() on the error paths.
            // NOTE(review): on the success path, content was already recorded above. This relies on the
            // log builder ignoring a second requestContent() call; confirm with RequestLogBuilder.
            ctx.logBuilder().requestContent(null, null);
        }

        invoke(ctx, serializationFormat, seqId, f, decodedReq, httpRes);
    }

    /**
     * Returns a human-readable name for a Thrift message type, for error messages.
     */
    private static String typeString(byte typeValue) {
        switch (typeValue) {
            case TMessageType.CALL:
                return "CALL";
            case TMessageType.REPLY:
                return "REPLY";
            case TMessageType.EXCEPTION:
                return "EXCEPTION";
            case TMessageType.ONEWAY:
                return "ONEWAY";
            default:
                return "UNKNOWN(" + (typeValue & 0xFF) + ')';
        }
    }

    /**
     * Invokes the decorated service with {@code call} under {@code ctx}. When the reply completes,
     * this method encodes it into {@code res}.
     */
    private void invoke(ServiceRequestContext ctx, SerializationFormat serializationFormat, int seqId,
                        ThriftFunction func, RpcRequest call, CompletableFuture<HttpResponse> res) {
        final RpcResponse reply;
        try (SafeCloseable ignored = ctx.push()) {
            reply = unwrap().serve(ctx, call);
        } catch (Throwable cause) {
            handleException(ctx, RpcResponse.ofFailure(cause), res, serializationFormat, seqId, func, cause);
            return;
        }

        reply.handle((result, cause) -> {
            if (func.isOneWay()) {
                // One-way calls always get an empty response, even if the invocation failed.
                handleOneWaySuccess(ctx, reply, res, serializationFormat);
                return null;
            }

            if (cause != null) {
                handleException(ctx, reply, res, serializationFormat, seqId, func, cause);
                return null;
            }

            try {
                handleSuccess(ctx, reply, res, serializationFormat, seqId, func, result);
            } catch (Throwable t) {
                // Encoding the result failed; report the failure to the client instead.
                handleException(ctx, RpcResponse.ofFailure(t), res, serializationFormat, seqId, func, t);
            }
            return null;
        }).exceptionally(CompletionActions::log);
    }

    /**
     * Converts a decoded Thrift argument struct into an {@link RpcRequest}.
     *
     * <p>The number of fields in the struct's metadata determines the form:
     * <ul>
     *   <li>no fields: a request with no arguments</li>
     *   <li>one field: a request with that single argument</li>
     *   <li>more fields: a request with the argument values in metadata iteration order</li>
     * </ul>
     */
    private static RpcRequest toRpcRequest(Class<?> serviceType, String method, TBase<?, ?> thriftArgs) {
        Objects.requireNonNull(thriftArgs, "thriftArgs");

        final Set<? extends TFieldIdEnum> fields =
                FieldMetaData.getStructMetaDataMap(thriftArgs.getClass()).keySet();
        final int numFields = fields.size();
        switch (numFields) {
            case 0:
                return RpcRequest.of(serviceType, method);
            case 1: {
                final Object arg = ThriftFieldAccess.get(thriftArgs, fields.iterator().next());
                return RpcRequest.of(serviceType, method, arg);
            }
            default: {
                final List<Object> list = new ArrayList<>(numFields);
                for (TFieldIdEnum field : fields) {
                    list.add(ThriftFieldAccess.get(thriftArgs, field));
                }
                return RpcRequest.of(serviceType, method, list);
            }
        }
    }

    /**
     * Wraps {@code returnValue} in the function's result struct and responds with a {@code REPLY}.
     */
    private static void handleSuccess(ServiceRequestContext ctx, RpcResponse rpcRes,
                                      CompletableFuture<HttpResponse> httpRes,
                                      SerializationFormat serializationFormat, int seqId,
                                      ThriftFunction func, Object returnValue) {
        final TBase<?, ?> wrappedResult = func.newResult();
        func.setSuccess(wrappedResult, returnValue);
        respond(serializationFormat,
                encodeSuccess(ctx, rpcRes, serializationFormat, func.name(), seqId, wrappedResult),
                httpRes);
    }

    /**
     * Records the response and replies with an empty body, as one-way calls have no Thrift reply.
     */
    private static void handleOneWaySuccess(ServiceRequestContext ctx, RpcResponse rpcRes,
                                            CompletableFuture<HttpResponse> httpRes,
                                            SerializationFormat serializationFormat) {
        ctx.logBuilder().responseContent(rpcRes, null);
        respond(serializationFormat, HttpData.empty(), httpRes);
    }

    /**
     * Converts {@code cause} into an HTTP response, using the first of these that applies:
     * <ul>
     *   <li>an {@link HttpStatusException} becomes a bare status response</li>
     *   <li>an {@link HttpResponseException} supplies its own response</li>
     *   <li>an exception declared by the function is encoded as a {@code REPLY} carrying it</li>
     *   <li>anything else is encoded as a Thrift {@code EXCEPTION} message</li>
     * </ul>
     */
    private static void handleException(ServiceRequestContext ctx, RpcResponse rpcRes,
                                        CompletableFuture<HttpResponse> httpRes,
                                        SerializationFormat serializationFormat, int seqId,
                                        ThriftFunction func, Throwable cause) {
        if (cause instanceof HttpStatusException) {
            httpRes.complete(HttpResponse.of(((HttpStatusException) cause).httpStatus()));
            return;
        }

        if (cause instanceof HttpResponseException) {
            httpRes.complete(((HttpResponseException) cause).httpResponse());
            return;
        }

        final TBase<?, ?> result = func.newResult();
        final HttpData content;
        if (func.setException(result, cause)) {
            content = encodeSuccess(ctx, rpcRes, serializationFormat, func.name(), seqId, result);
        } else {
            content = encodeException(ctx, rpcRes, serializationFormat, seqId, func.name(), cause);
        }
        respond(serializationFormat, content, httpRes);
    }

    /**
     * Responds with a Thrift {@code EXCEPTION} message for a failure before the call was dispatched.
     */
    private static void handlePreDecodeException(ServiceRequestContext ctx,
                                                 CompletableFuture<HttpResponse> httpRes, Throwable cause,
                                                 SerializationFormat serializationFormat, int seqId,
                                                 String methodName) {
        final HttpData content = encodeException(ctx, RpcResponse.ofFailure(cause), serializationFormat,
                                                 seqId, methodName, cause);
        respond(serializationFormat, content, httpRes);
    }

    /**
     * Completes {@code res} with a {@code 200 OK} response carrying {@code content}.
     */
    private static void respond(SerializationFormat serializationFormat, HttpData content,
                                CompletableFuture<HttpResponse> res) {
        res.complete(HttpResponse.of(HttpStatus.OK, serializationFormat.mediaType(), content));
    }

    /**
     * Encodes {@code result} as a Thrift {@code REPLY} message and records it as the response content.
     * On failure, the allocated buffer is released before the exception propagates.
     */
    private static HttpData encodeSuccess(ServiceRequestContext ctx, RpcResponse reply,
                                          SerializationFormat serializationFormat, String methodName,
                                          int seqId, TBase<?, ?> result) {
        final ByteBuf buf = ctx.alloc().buffer(128);
        boolean success = false;
        try {
            final TByteBufTransport transport = new TByteBufTransport(buf);
            final TProtocol outProto = ThriftProtocolFactories.get(serializationFormat).getProtocol(transport);
            final TMessage header = new TMessage(methodName, TMessageType.REPLY, seqId);
            outProto.writeMessageBegin(header);
            result.write(outProto);
            outProto.writeMessageEnd();

            ctx.logBuilder().responseContent(reply, new ThriftReply(header, result));

            final HttpData encoded = PooledHttpData.wrap(buf);
            success = true;
            return encoded;
        } catch (TException e) {
            throw new Error(e); // Not expected to happen when encoding into a local buffer.
        } finally {
            if (!success) {
                buf.release();
            }
        }
    }

    /**
     * Encodes {@code cause} as a Thrift {@code EXCEPTION} message and records it as the response content.
     *
     * <p>A {@link TApplicationException} is sent as-is. Any other exception becomes an
     * {@code INTERNAL_ERROR}; when verbose responses are enabled, that error includes the server-side
     * stack trace.
     */
    private static HttpData encodeException(ServiceRequestContext ctx, RpcResponse reply,
                                            SerializationFormat serializationFormat, int seqId,
                                            String methodName, Throwable cause) {
        final TApplicationException appException;
        if (cause instanceof TApplicationException) {
            appException = (TApplicationException) cause;
        } else {
            if (ctx.config().verboseResponses()) {
                appException = new TApplicationException(
                        TApplicationException.INTERNAL_ERROR,
                        "\n---- BEGIN server-side trace ----\n" + Exceptions.traceText(cause) +
                        "---- END server-side trace ----");
            } else {
                appException = new TApplicationException(TApplicationException.INTERNAL_ERROR);
            }
            appException.initCause(cause);
        }

        final ByteBuf buf = ctx.alloc().buffer(128);
        boolean success = false;
        try {
            final TByteBufTransport transport = new TByteBufTransport(buf);
            final TProtocol outProto = ThriftProtocolFactories.get(serializationFormat).getProtocol(transport);
            final TMessage header = new TMessage(methodName, TMessageType.EXCEPTION, seqId);
            outProto.writeMessageBegin(header);
            appException.write(outProto);
            outProto.writeMessageEnd();

            ctx.logBuilder().responseContent(reply, new ThriftReply(header, appException));

            final HttpData encoded = PooledHttpData.wrap(buf);
            success = true;
            return encoded;
        } catch (TException e) {
            throw new Error(e); // Not expected to happen when encoding into a local buffer.
        } finally {
            if (!success) {
                buf.release();
            }
        }
    }
}

