/*
 * Decompiled with CFR 0.152.
 */
package com.linecorp.armeria.server.thrift;

import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.linecorp.armeria.common.DefaultRpcResponse;
import com.linecorp.armeria.common.MediaType;
import com.linecorp.armeria.common.Request;
import com.linecorp.armeria.common.RequestContext;
import com.linecorp.armeria.common.RpcRequest;
import com.linecorp.armeria.common.RpcResponse;
import com.linecorp.armeria.common.SerializationFormat;
import com.linecorp.armeria.common.http.AggregatedHttpMessage;
import com.linecorp.armeria.common.http.HttpData;
import com.linecorp.armeria.common.http.HttpHeaderNames;
import com.linecorp.armeria.common.http.HttpHeaders;
import com.linecorp.armeria.common.http.HttpRequest;
import com.linecorp.armeria.common.http.HttpResponseWriter;
import com.linecorp.armeria.common.http.HttpStatus;
import com.linecorp.armeria.common.thrift.ThriftCall;
import com.linecorp.armeria.common.thrift.ThriftProtocolFactories;
import com.linecorp.armeria.common.thrift.ThriftReply;
import com.linecorp.armeria.common.thrift.ThriftSerializationFormats;
import com.linecorp.armeria.common.util.CompletionActions;
import com.linecorp.armeria.common.util.Functions;
import com.linecorp.armeria.common.util.SafeCloseable;
import com.linecorp.armeria.internal.http.ByteBufHttpData;
import com.linecorp.armeria.internal.thrift.ThriftFieldAccess;
import com.linecorp.armeria.internal.thrift.ThriftFunction;
import com.linecorp.armeria.server.Service;
import com.linecorp.armeria.server.ServiceRequestContext;
import com.linecorp.armeria.server.http.AbstractHttpService;
import com.linecorp.armeria.server.thrift.TByteBufTransport;
import com.linecorp.armeria.server.thrift.ThriftCallService;
import com.linecorp.armeria.server.thrift.ThriftServiceEntry;
import io.netty.buffer.ByteBuf;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.function.Function;
import org.apache.thrift.TApplicationException;
import org.apache.thrift.TBase;
import org.apache.thrift.TException;
import org.apache.thrift.TFieldIdEnum;
import org.apache.thrift.meta_data.FieldMetaData;
import org.apache.thrift.protocol.TMessage;
import org.apache.thrift.protocol.TMessageType;
import org.apache.thrift.protocol.TProtocol;
import org.apache.thrift.protocol.TProtocolFactory;
import org.apache.thrift.transport.TMemoryInputTransport;
import org.apache.thrift.transport.TTransport;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * An {@link AbstractHttpService} that serves Thrift RPC over HTTP {@code POST}.
 *
 * <p>The serialization format for each request comes from its {@code content-type} header. When that
 * header is absent, the default format is used. The format must be one of the formats allowed at
 * construction time.
 *
 * <p>The request body is decoded into an {@link RpcRequest} and passed to the delegate
 * {@link Service}. The resulting {@link RpcResponse} is encoded back into a Thrift reply or exception
 * message. The delegate must be, or wrap, a {@link ThriftCallService}.
 */
public final class THttpService extends AbstractHttpService {

    private static final Logger logger = LoggerFactory.getLogger(THttpService.class);

    private static final String PROTOCOL_NOT_SUPPORTED = "Specified content-type not supported";

    private static final String ACCEPT_THRIFT_PROTOCOL_MUST_MATCH_CONTENT_TYPE =
            "Thrift protocol specified in Accept header must match the one specified in the content-type header";

    /**
     * One per-thread input protocol for each Thrift serialization format. Each protocol is reused
     * across requests on the same thread, so it is reset before use and cleared afterwards.
     */
    private static final Map<SerializationFormat, ThreadLocalTProtocol> FORMAT_TO_THREAD_LOCAL_INPUT_PROTOCOL =
            createFormatToThreadLocalTProtocolMap();

    private final Service<RpcRequest, RpcResponse> delegate;

    /** The allowed formats; element {@code 0} is always the default format. */
    private final SerializationFormat[] allowedSerializationFormatArray;

    /** Immutable view of {@link #allowedSerializationFormatArray}, in the same order. */
    private final Set<SerializationFormat> allowedSerializationFormats;

    /** The {@link ThriftCallService} found in (or equal to) {@link #delegate}. */
    private final ThriftCallService thriftService;

    /**
     * Creates a {@link THttpService} for the given Thrift service implementation.
     *
     * <p>Every format in {@link ThriftSerializationFormats#values()} is allowed, and
     * {@link ThriftSerializationFormats#BINARY} is the default.
     *
     * @param implementation the service implementation, passed to {@link ThriftCallService#of(Object)}
     */
    public static THttpService of(Object implementation) {
        return of(implementation, ThriftSerializationFormats.BINARY);
    }

    /**
     * Creates a {@link THttpService} for the given Thrift service implementations.
     *
     * <p>Every format in {@link ThriftSerializationFormats#values()} is allowed, and
     * {@link ThriftSerializationFormats#BINARY} is the default.
     *
     * @param implementations the service implementations, passed to {@link ThriftCallService#of(Map)}
     */
    public static THttpService of(Map<String, ?> implementations) {
        return of(implementations, ThriftSerializationFormats.BINARY);
    }

    /**
     * Creates a {@link THttpService} for the given Thrift service implementation.
     *
     * <p>Every format in {@link ThriftSerializationFormats#values()} is allowed.
     *
     * @param implementation the service implementation, passed to {@link ThriftCallService#of(Object)}
     * @param defaultSerializationFormat the format used when a request has no {@code content-type}
     */
    public static THttpService of(Object implementation, SerializationFormat defaultSerializationFormat) {
        return new THttpService(ThriftCallService.of(implementation),
                                newAllowedSerializationFormats(defaultSerializationFormat,
                                                               ThriftSerializationFormats.values()));
    }

    /**
     * Creates a {@link THttpService} for the given Thrift service implementations.
     *
     * <p>Every format in {@link ThriftSerializationFormats#values()} is allowed.
     *
     * @param implementations the service implementations, passed to {@link ThriftCallService#of(Map)}
     * @param defaultSerializationFormat the format used when a request has no {@code content-type}
     */
    public static THttpService of(Map<String, ?> implementations,
                                  SerializationFormat defaultSerializationFormat) {
        return new THttpService(ThriftCallService.of(implementations),
                                newAllowedSerializationFormats(defaultSerializationFormat,
                                                               ThriftSerializationFormats.values()));
    }

    /**
     * Creates a {@link THttpService} that allows only the given serialization formats.
     *
     * @param implementation the service implementation, passed to {@link ThriftCallService#of(Object)}
     * @param defaultSerializationFormat the format used when a request has no {@code content-type}
     * @param otherAllowedSerializationFormats additional allowed formats
     */
    public static THttpService ofFormats(Object implementation,
                                         SerializationFormat defaultSerializationFormat,
                                         SerializationFormat... otherAllowedSerializationFormats) {
        Objects.requireNonNull(otherAllowedSerializationFormats, "otherAllowedSerializationFormats");
        return ofFormats(implementation, defaultSerializationFormat,
                         Arrays.asList(otherAllowedSerializationFormats));
    }

    /**
     * Creates a {@link THttpService} that allows only the given serialization formats.
     *
     * @param implementations the service implementations, passed to {@link ThriftCallService#of(Map)}
     * @param defaultSerializationFormat the format used when a request has no {@code content-type}
     * @param otherAllowedSerializationFormats additional allowed formats
     */
    public static THttpService ofFormats(Map<String, ?> implementations,
                                         SerializationFormat defaultSerializationFormat,
                                         SerializationFormat... otherAllowedSerializationFormats) {
        Objects.requireNonNull(otherAllowedSerializationFormats, "otherAllowedSerializationFormats");
        return ofFormats(implementations, defaultSerializationFormat,
                         Arrays.asList(otherAllowedSerializationFormats));
    }

    /**
     * Creates a {@link THttpService} that allows only the given serialization formats.
     *
     * @param implementation the service implementation, passed to {@link ThriftCallService#of(Object)}
     * @param defaultSerializationFormat the format used when a request has no {@code content-type}
     * @param otherAllowedSerializationFormats additional allowed formats
     */
    public static THttpService ofFormats(Object implementation,
                                         SerializationFormat defaultSerializationFormat,
                                         Iterable<SerializationFormat> otherAllowedSerializationFormats) {
        return new THttpService(ThriftCallService.of(implementation),
                                newAllowedSerializationFormats(defaultSerializationFormat,
                                                               otherAllowedSerializationFormats));
    }

    /**
     * Creates a {@link THttpService} that allows only the given serialization formats.
     *
     * @param implementations the service implementations, passed to {@link ThriftCallService#of(Map)}
     * @param defaultSerializationFormat the format used when a request has no {@code content-type}
     * @param otherAllowedSerializationFormats additional allowed formats
     */
    public static THttpService ofFormats(Map<String, ?> implementations,
                                         SerializationFormat defaultSerializationFormat,
                                         Iterable<SerializationFormat> otherAllowedSerializationFormats) {
        return new THttpService(ThriftCallService.of(implementations),
                                newAllowedSerializationFormats(defaultSerializationFormat,
                                                               otherAllowedSerializationFormats));
    }

    /**
     * Returns a decorator that wraps an RPC service, which must be or wrap a {@link ThriftCallService},
     * into a {@link THttpService}.
     *
     * <p>Every Thrift format is allowed, with {@link ThriftSerializationFormats#BINARY} as the default.
     */
    public static Function<Service<RpcRequest, RpcResponse>, THttpService> newDecorator() {
        return newDecorator(ThriftSerializationFormats.BINARY);
    }

    /**
     * Returns a decorator that wraps an RPC service into a {@link THttpService}.
     *
     * <p>Every Thrift format is allowed.
     *
     * @param defaultSerializationFormat the format used when a request has no {@code content-type}
     */
    public static Function<Service<RpcRequest, RpcResponse>, THttpService> newDecorator(
            SerializationFormat defaultSerializationFormat) {
        final SerializationFormat[] allowedSerializationFormatArray =
                newAllowedSerializationFormats(defaultSerializationFormat, ThriftSerializationFormats.values());
        return delegate -> new THttpService(delegate, allowedSerializationFormatArray);
    }

    /**
     * Returns a decorator that wraps an RPC service into a {@link THttpService} that allows only the
     * given serialization formats.
     *
     * @param defaultSerializationFormat the format used when a request has no {@code content-type}
     * @param otherAllowedSerializationFormats additional allowed formats
     */
    public static Function<Service<RpcRequest, RpcResponse>, THttpService> newDecorator(
            SerializationFormat defaultSerializationFormat,
            SerializationFormat... otherAllowedSerializationFormats) {
        Objects.requireNonNull(otherAllowedSerializationFormats, "otherAllowedSerializationFormats");
        return newDecorator(defaultSerializationFormat, Arrays.asList(otherAllowedSerializationFormats));
    }

    /**
     * Returns a decorator that wraps an RPC service into a {@link THttpService} that allows only the
     * given serialization formats.
     *
     * @param defaultSerializationFormat the format used when a request has no {@code content-type}
     * @param otherAllowedSerializationFormats additional allowed formats
     */
    public static Function<Service<RpcRequest, RpcResponse>, THttpService> newDecorator(
            SerializationFormat defaultSerializationFormat,
            Iterable<SerializationFormat> otherAllowedSerializationFormats) {
        final SerializationFormat[] allowedSerializationFormatArray =
                newAllowedSerializationFormats(defaultSerializationFormat, otherAllowedSerializationFormats);
        return delegate -> new THttpService(delegate, allowedSerializationFormatArray);
    }

    /**
     * Builds the array of allowed formats.
     *
     * <p>{@code defaultSerializationFormat} comes first. It is followed by
     * {@code otherAllowedSerializationFormats} in iteration order, with duplicates dropped.
     *
     * @throws NullPointerException if either argument is {@code null}
     */
    private static SerializationFormat[] newAllowedSerializationFormats(
            SerializationFormat defaultSerializationFormat,
            Iterable<SerializationFormat> otherAllowedSerializationFormats) {
        Objects.requireNonNull(defaultSerializationFormat, "defaultSerializationFormat");
        Objects.requireNonNull(otherAllowedSerializationFormats, "otherAllowedSerializationFormats");

        // Insertion order matters: index 0 is treated as the default by defaultSerializationFormat().
        final Set<SerializationFormat> set = new LinkedHashSet<>();
        set.add(defaultSerializationFormat);
        Iterables.addAll(set, otherAllowedSerializationFormats);
        return set.toArray(new SerializationFormat[0]);
    }

    private THttpService(Service<RpcRequest, RpcResponse> delegate,
                         SerializationFormat[] allowedSerializationFormatArray) {
        Objects.requireNonNull(delegate, "delegate");
        this.delegate = delegate;
        thriftService = findThriftService(delegate);
        this.allowedSerializationFormatArray = allowedSerializationFormatArray;
        allowedSerializationFormats = ImmutableSet.copyOf(allowedSerializationFormatArray);
    }

    /**
     * Locates the {@link ThriftCallService} inside {@code delegate}.
     *
     * @throws IllegalStateException if {@code delegate} does not expose a {@link ThriftCallService}
     */
    private static ThriftCallService findThriftService(Service<?, ?> delegate) {
        return delegate.as(ThriftCallService.class).orElseThrow(
                () -> new IllegalStateException(
                        "service being decorated is not a ThriftCallService: " + delegate));
    }

    /**
     * Returns the Thrift service entries of the underlying {@link ThriftCallService}.
     *
     * <p>The map is keyed by the service-name part of incoming message names: the text before
     * {@code ':'}, or {@code ""} when there is no prefix.
     */
    public Map<String, ThriftServiceEntry> entries() {
        return thriftService.entries();
    }

    /** Returns the immutable set of allowed serialization formats, with the default format first. */
    public Set<SerializationFormat> allowedSerializationFormats() {
        return allowedSerializationFormats;
    }

    /** Returns the serialization format used when a request has no {@code content-type} header. */
    public SerializationFormat defaultSerializationFormat() {
        return allowedSerializationFormatArray[0];
    }

    /**
     * Validates the request headers, aggregates the request body, and then decodes and invokes the call.
     *
     * <p>If the body cannot be aggregated, a {@code 500 Internal Server Error} with the stack trace is
     * returned.
     */
    @Override
    protected void doPost(ServiceRequestContext ctx, HttpRequest req, HttpResponseWriter res) {
        final SerializationFormat serializationFormat = validateRequestAndDetermineSerializationFormat(req, res);
        if (serializationFormat == null) {
            // An error response has already been sent.
            return;
        }

        ctx.logBuilder().serializationFormat(serializationFormat);
        // The request content is supplied later, once the body has been decoded (or has failed to decode).
        ctx.logBuilder().deferRequestContent();

        req.aggregate().handle(Functions.voidFunction((aReq, cause) -> {
            if (cause != null) {
                res.respond(HttpStatus.INTERNAL_SERVER_ERROR, MediaType.PLAIN_TEXT_UTF_8,
                            Throwables.getStackTraceAsString(cause));
                // Settle the deferred request content here as well, matching the failure paths of
                // decodeAndInvoke(), which always record an empty entry when nothing was decoded.
                ctx.logBuilder().requestContent(null, null);
                return;
            }
            decodeAndInvoke(ctx, aReq, serializationFormat, res);
        })).exceptionally(CompletionActions::log);
    }

    /**
     * Determines the serialization format of {@code req}, or sends an error response.
     *
     * <p>If the {@code content-type} header is present but not one of the allowed formats, a
     * {@code 415 Unsupported Media Type} is sent. If {@code accept} headers are present and none of
     * them matches the selected format, a {@code 406 Not Acceptable} is sent.
     *
     * @return the selected format, or {@code null} if an error response has already been sent
     */
    private SerializationFormat validateRequestAndDetermineSerializationFormat(HttpRequest req,
                                                                               HttpResponseWriter res) {
        final HttpHeaders headers = req.headers();
        final String contentType = headers.get(HttpHeaderNames.CONTENT_TYPE);

        final SerializationFormat serializationFormat;
        if (contentType != null) {
            serializationFormat = findSerializationFormat(contentType);
            if (serializationFormat == null) {
                res.respond(HttpStatus.UNSUPPORTED_MEDIA_TYPE, MediaType.PLAIN_TEXT_UTF_8,
                            PROTOCOL_NOT_SUPPORTED);
                return null;
            }
        } else {
            serializationFormat = defaultSerializationFormat();
        }

        final List<String> acceptHeaders = headers.getAll(HttpHeaderNames.ACCEPT);
        if (!acceptHeaders.isEmpty() &&
            !serializationFormat.mediaTypes().matchHeaders(acceptHeaders).isPresent()) {
            res.respond(HttpStatus.NOT_ACCEPTABLE, MediaType.PLAIN_TEXT_UTF_8,
                        ACCEPT_THRIFT_PROTOCOL_MUST_MATCH_CONTENT_TYPE);
            return null;
        }

        return serializationFormat;
    }

    /**
     * Returns the first allowed format that accepts the given {@code content-type}.
     *
     * @return the matching format, or {@code null} if the header cannot be parsed or no allowed
     *         format accepts it
     */
    private SerializationFormat findSerializationFormat(String contentType) {
        final MediaType mediaType;
        try {
            mediaType = MediaType.parse(contentType);
        } catch (IllegalArgumentException e) {
            logger.debug("Failed to parse the 'content-type' header: {}", contentType, e);
            return null;
        }

        for (SerializationFormat format : allowedSerializationFormatArray) {
            if (format.isAccepted(mediaType)) {
                return format;
            }
        }
        return null;
    }

    /**
     * Decodes the Thrift message in {@code req}'s body and, if that succeeds, invokes the delegate.
     *
     * <p>Failure handling:
     * <ul>
     *   <li>If the message header cannot be read, a {@code 400 Bad Request} is sent.</li>
     *   <li>An unexpected message type, an unknown method, or undecodable arguments produce a
     *       {@link TApplicationException} reply.</li>
     * </ul>
     *
     * <p>On every path the thread-local input transport is cleared, and the request content is
     * recorded in the log. It holds the decoded call on success, or {@code null} otherwise.
     */
    private void decodeAndInvoke(ServiceRequestContext ctx, AggregatedHttpMessage req,
                                 SerializationFormat serializationFormat, HttpResponseWriter res) {
        // Reuse this thread's input protocol for the format, pointing its transport at the request body.
        final TProtocol inProto = FORMAT_TO_THREAD_LOCAL_INPUT_PROTOCOL.get(serializationFormat).get();
        inProto.reset();
        final TMemoryInputTransport inTransport = (TMemoryInputTransport) inProto.getTransport();
        final HttpData content = req.content();
        inTransport.reset(content.array(), content.offset(), content.length());

        final int seqId;
        final ThriftFunction f;
        final RpcRequest decodedReq;
        // Set once the decoded call is logged. The cleanup below must then not overwrite it with an
        // empty entry; the decompiled version did so unconditionally.
        boolean requestContentLogged = false;

        try {
            final TMessage header;
            try {
                header = inProto.readMessageBegin();
            } catch (Exception e) {
                logger.debug("{} Failed to decode Thrift header:", ctx, e);
                res.respond(HttpStatus.BAD_REQUEST, MediaType.PLAIN_TEXT_UTF_8,
                            "Failed to decode Thrift header: " + Throwables.getStackTraceAsString(e));
                return;
            }

            seqId = header.seqid;

            // A message name is either "method" or "serviceName:method".
            final int colonIdx = header.name.indexOf(':');
            final String serviceName;
            final String methodName;
            if (colonIdx < 0) {
                serviceName = "";
                methodName = header.name;
            } else {
                serviceName = header.name.substring(0, colonIdx);
                methodName = header.name.substring(colonIdx + 1);
            }

            // Reject anything other than CALL or ONEWAY.
            final byte typeValue = header.type;
            if (typeValue != TMessageType.CALL && typeValue != TMessageType.ONEWAY) {
                final TApplicationException cause = new TApplicationException(
                        TApplicationException.INVALID_MESSAGE_TYPE,
                        "unexpected TMessageType: " + typeString(typeValue));
                handlePreDecodeException(ctx, res, cause, serializationFormat, seqId, methodName);
                return;
            }

            // Ensure that such a method exists.
            final ThriftServiceEntry entry = entries().get(serviceName);
            f = entry != null ? entry.metadata.function(methodName) : null;
            if (f == null) {
                final TApplicationException cause = new TApplicationException(
                        TApplicationException.UNKNOWN_METHOD, "unknown method: " + header.name);
                handlePreDecodeException(ctx, res, cause, serializationFormat, seqId, methodName);
                return;
            }

            // Decode the invocation arguments.
            try {
                final TBase<?, ?> args = f.newArgs();
                args.read(inProto);
                inProto.readMessageEnd();

                decodedReq = toRpcRequest(f.serviceType(), header.name, args);
                ctx.logBuilder().requestContent(decodedReq, new ThriftCall(header, args));
                requestContentLogged = true;
            } catch (Exception e) {
                logger.debug("{} Failed to decode Thrift arguments:", ctx, e);
                final TApplicationException cause = new TApplicationException(
                        TApplicationException.PROTOCOL_ERROR, "failed to decode arguments: " + e);
                handlePreDecodeException(ctx, res, cause, serializationFormat, seqId, methodName);
                return;
            }
        } finally {
            // Drop the reference to the request body from the reused thread-local transport.
            inTransport.clear();
            if (!requestContentLogged) {
                // Nothing was decoded; settle the deferred request content with an empty entry.
                ctx.logBuilder().requestContent(null, null);
            }
        }

        invoke(ctx, serializationFormat, seqId, f, decodedReq, res);
    }

    /** Returns a human-readable name for a Thrift message type byte, used in error messages. */
    private static String typeString(byte typeValue) {
        switch (typeValue) {
            case TMessageType.CALL:
                return "CALL";
            case TMessageType.REPLY:
                return "REPLY";
            case TMessageType.EXCEPTION:
                return "EXCEPTION";
            case TMessageType.ONEWAY:
                return "ONEWAY";
            default:
                return "UNKNOWN(" + (typeValue & 0xFF) + ')';
        }
    }

    /**
     * Invokes the delegate with {@code ctx} pushed as the current {@link RequestContext}.
     *
     * <p>When the resulting {@link RpcResponse} completes, it is encoded into the HTTP response.
     * Exceptions thrown by {@code serve()}, or raised while encoding a success, are turned into Thrift
     * exception replies.
     */
    private void invoke(ServiceRequestContext ctx, SerializationFormat serializationFormat, int seqId,
                        ThriftFunction func, RpcRequest call, HttpResponseWriter res) {
        final RpcResponse reply;
        try (SafeCloseable ignored = RequestContext.push(ctx)) {
            reply = delegate.serve(ctx, call);
        } catch (Throwable cause) {
            handleException(ctx, new DefaultRpcResponse(cause), res, serializationFormat, seqId, func, cause);
            return;
        }

        reply.handle(Functions.voidFunction((result, cause) -> {
            if (cause != null) {
                handleException(ctx, reply, res, serializationFormat, seqId, func, cause);
                return;
            }

            if (func.isOneWay()) {
                handleOneWaySuccess(ctx, reply, res, serializationFormat);
                return;
            }

            try {
                handleSuccess(ctx, reply, res, serializationFormat, seqId, func, result);
            } catch (Throwable t) {
                handleException(ctx, new DefaultRpcResponse(t), res, serializationFormat, seqId, func, t);
            }
        })).exceptionally(CompletionActions::log);
    }

    /**
     * Converts decoded Thrift arguments into an {@link RpcRequest}.
     *
     * <p>A struct with no fields yields a parameterless request, and a struct with one field passes that
     * single value. Otherwise the field values are passed as a list, in the iteration order of the
     * struct's metadata map.
     */
    private static RpcRequest toRpcRequest(Class<?> serviceType, String method, TBase<?, ?> thriftArgs) {
        Objects.requireNonNull(thriftArgs, "thriftArgs");

        final Set<? extends TFieldIdEnum> fields =
                FieldMetaData.getStructMetaDataMap(thriftArgs.getClass()).keySet();
        final int numFields = fields.size();

        if (numFields == 0) {
            return RpcRequest.of(serviceType, method);
        }
        if (numFields == 1) {
            // Typed as Object so that the single-parameter overload is selected even if the value is a list.
            final Object arg = ThriftFieldAccess.get(thriftArgs, fields.iterator().next());
            return RpcRequest.of(serviceType, method, arg);
        }

        final List<Object> list = new ArrayList<>(numFields);
        for (TFieldIdEnum field : fields) {
            list.add(ThriftFieldAccess.get(thriftArgs, field));
        }
        return RpcRequest.of(serviceType, method, list);
    }

    /** Wraps {@code returnValue} in {@code func}'s result struct and sends it as a Thrift reply. */
    private static void handleSuccess(ServiceRequestContext ctx, RpcResponse rpcRes,
                                      HttpResponseWriter httpRes, SerializationFormat serializationFormat,
                                      int seqId, ThriftFunction func, Object returnValue) {
        final TBase<?, ?> wrappedResult = func.newResult();
        func.setSuccess(wrappedResult, returnValue);
        respond(serializationFormat,
                encodeSuccess(ctx, rpcRes, serializationFormat, func.name(), seqId, wrappedResult),
                httpRes);
    }

    /** Logs the response and sends an empty {@code 200 OK} body for a completed one-way call. */
    private static void handleOneWaySuccess(ServiceRequestContext ctx, RpcResponse rpcRes,
                                            HttpResponseWriter httpRes,
                                            SerializationFormat serializationFormat) {
        ctx.logBuilder().responseContent(rpcRes, null);
        respond(serializationFormat, HttpData.EMPTY_DATA, httpRes);
    }

    /**
     * Sends {@code cause} back to the client.
     *
     * <p>If {@code func} declares the exception (i.e. {@code setException} accepts it), it is sent
     * inside a regular reply. Otherwise it is sent as a Thrift {@code EXCEPTION} message.
     */
    private static void handleException(ServiceRequestContext ctx, RpcResponse rpcRes,
                                        HttpResponseWriter httpRes, SerializationFormat serializationFormat,
                                        int seqId, ThriftFunction func, Throwable cause) {
        final TBase<?, ?> result = func.newResult();
        final HttpData content;
        if (func.setException(result, cause)) {
            content = encodeSuccess(ctx, rpcRes, serializationFormat, func.name(), seqId, result);
        } else {
            content = encodeException(ctx, rpcRes, serializationFormat, seqId, func.name(), cause);
        }
        respond(serializationFormat, content, httpRes);
    }

    /** Sends {@code cause} as a Thrift {@code EXCEPTION} message for a call that was never invoked. */
    private static void handlePreDecodeException(ServiceRequestContext ctx, HttpResponseWriter httpRes,
                                                 Throwable cause, SerializationFormat serializationFormat,
                                                 int seqId, String methodName) {
        final HttpData content = encodeException(ctx, new DefaultRpcResponse(cause), serializationFormat,
                                                 seqId, methodName, cause);
        respond(serializationFormat, content, httpRes);
    }

    /**
     * Sends {@code content} with {@code 200 OK} and the media type of {@code serializationFormat}.
     *
     * <p>Thrift-level errors are carried in the body, so the status is always {@code 200 OK}.
     */
    private static void respond(SerializationFormat serializationFormat, HttpData content,
                                HttpResponseWriter res) {
        res.respond(HttpStatus.OK, serializationFormat.mediaType(), content);
    }

    /**
     * Encodes {@code result} as a Thrift {@code REPLY} message and logs it as the response content.
     *
     * <p>The buffer is released if encoding or logging fails.
     */
    private static HttpData encodeSuccess(ServiceRequestContext ctx, RpcResponse reply,
                                          SerializationFormat serializationFormat, String methodName,
                                          int seqId, TBase<?, ?> result) {
        final ByteBuf buf = ctx.alloc().buffer(128);
        boolean success = false;
        try {
            final TByteBufTransport transport = new TByteBufTransport(buf);
            final TProtocol outProto = ThriftProtocolFactories.get(serializationFormat).getProtocol(transport);
            final TMessage header = new TMessage(methodName, TMessageType.REPLY, seqId);
            outProto.writeMessageBegin(header);
            result.write(outProto);
            outProto.writeMessageEnd();

            ctx.logBuilder().responseContent(reply, new ThriftReply(header, result));

            final HttpData encoded = new ByteBufHttpData(buf, false);
            success = true;
            return encoded;
        } catch (TException e) {
            // Not expected when writing into an in-memory buffer; kept as an Error as before.
            throw new Error(e);
        } finally {
            if (!success) {
                buf.release();
            }
        }
    }

    /**
     * Encodes {@code cause} as a Thrift {@code EXCEPTION} message and logs it as the response content.
     *
     * <p>A {@link TApplicationException} is sent as-is. Any other cause is wrapped in an
     * {@code INTERNAL_ERROR} {@link TApplicationException} whose message contains the server-side stack
     * trace. The buffer is released if encoding or logging fails.
     */
    private static HttpData encodeException(ServiceRequestContext ctx, RpcResponse reply,
                                            SerializationFormat serializationFormat, int seqId,
                                            String methodName, Throwable cause) {
        final TApplicationException appException;
        if (cause instanceof TApplicationException) {
            appException = (TApplicationException) cause;
        } else {
            appException = new TApplicationException(
                    TApplicationException.INTERNAL_ERROR,
                    "internal server error:" + System.lineSeparator() +
                    "---- BEGIN server-side trace ----" + System.lineSeparator() +
                    Throwables.getStackTraceAsString(cause) +
                    "---- END server-side trace ----");
        }

        final ByteBuf buf = ctx.alloc().buffer(128);
        boolean success = false;
        try {
            final TByteBufTransport transport = new TByteBufTransport(buf);
            final TProtocol outProto = ThriftProtocolFactories.get(serializationFormat).getProtocol(transport);
            final TMessage header = new TMessage(methodName, TMessageType.EXCEPTION, seqId);
            outProto.writeMessageBegin(header);
            appException.write(outProto);
            outProto.writeMessageEnd();

            ctx.logBuilder().responseContent(reply, new ThriftReply(header, appException));

            final HttpData encoded = new ByteBufHttpData(buf, false);
            success = true;
            return encoded;
        } catch (TException e) {
            // Not expected when writing into an in-memory buffer; kept as an Error as before.
            throw new Error(e);
        } finally {
            if (!success) {
                buf.release();
            }
        }
    }

    /** Builds one {@link ThreadLocalTProtocol} for each format in {@link ThriftSerializationFormats#values()}. */
    private static Map<SerializationFormat, ThreadLocalTProtocol> createFormatToThreadLocalTProtocolMap() {
        return ThriftSerializationFormats.values().stream().collect(
                ImmutableMap.toImmutableMap(Function.identity(),
                                            f -> new ThreadLocalTProtocol(ThriftProtocolFactories.get(f))));
    }

    /**
     * A {@link ThreadLocal} that lazily creates, for each thread, a {@link TProtocol} from the given
     * factory.
     *
     * <p>The protocol reads from a {@link TMemoryInputTransport}.
     */
    private static final class ThreadLocalTProtocol extends ThreadLocal<TProtocol> {

        private final TProtocolFactory protoFactory;

        private ThreadLocalTProtocol(TProtocolFactory protoFactory) {
            this.protoFactory = protoFactory;
        }

        @Override
        protected TProtocol initialValue() {
            return protoFactory.getProtocol(new TMemoryInputTransport());
        }
    }
}

