/*
 * Decompiled with CFR 0.152.
 */
package com.linecorp.armeria.server.grpc;

import com.linecorp.armeria.common.ExchangeType;
import com.linecorp.armeria.common.HttpHeaders;
import com.linecorp.armeria.common.HttpHeadersBuilder;
import com.linecorp.armeria.common.HttpRequest;
import com.linecorp.armeria.common.HttpResponse;
import com.linecorp.armeria.common.HttpResponseWriter;
import com.linecorp.armeria.common.HttpStatus;
import com.linecorp.armeria.common.MediaType;
import com.linecorp.armeria.common.RequestContext;
import com.linecorp.armeria.common.RequestHeaders;
import com.linecorp.armeria.common.ResponseHeaders;
import com.linecorp.armeria.common.ResponseHeadersBuilder;
import com.linecorp.armeria.common.SerializationFormat;
import com.linecorp.armeria.common.annotation.Nullable;
import com.linecorp.armeria.common.grpc.GrpcJsonMarshaller;
import com.linecorp.armeria.common.grpc.GrpcSerializationFormats;
import com.linecorp.armeria.common.grpc.GrpcStatusFunction;
import com.linecorp.armeria.common.grpc.protocol.GrpcHeaderNames;
import com.linecorp.armeria.common.logging.RequestLogProperty;
import com.linecorp.armeria.common.util.SafeCloseable;
import com.linecorp.armeria.common.util.TimeoutMode;
import com.linecorp.armeria.internal.common.grpc.GrpcStatus;
import com.linecorp.armeria.internal.common.grpc.MetadataUtil;
import com.linecorp.armeria.internal.common.grpc.TimeoutHeaderUtil;
import com.linecorp.armeria.internal.shaded.guava.base.MoreObjects;
import com.linecorp.armeria.internal.shaded.guava.collect.ImmutableList;
import com.linecorp.armeria.internal.shaded.guava.collect.ImmutableMap;
import com.linecorp.armeria.server.AbstractHttpService;
import com.linecorp.armeria.server.RequestTimeoutException;
import com.linecorp.armeria.server.Route;
import com.linecorp.armeria.server.ServiceConfig;
import com.linecorp.armeria.server.ServiceRequestContext;
import com.linecorp.armeria.server.grpc.ArmeriaServerCall;
import com.linecorp.armeria.server.grpc.GrpcHealthCheckService;
import com.linecorp.armeria.server.grpc.GrpcRequestUtil;
import com.linecorp.armeria.server.grpc.GrpcService;
import com.linecorp.armeria.server.grpc.GrpcServiceBuilder;
import com.linecorp.armeria.server.grpc.HandlerRegistry;
import com.linecorp.armeria.server.grpc.ProtoReflectionServiceInterceptor;
import io.grpc.Codec;
import io.grpc.CompressorRegistry;
import io.grpc.DecompressorRegistry;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.Server;
import io.grpc.ServerCall;
import io.grpc.ServerMethodDefinition;
import io.grpc.ServerServiceDefinition;
import io.grpc.ServiceDescriptor;
import io.grpc.Status;
import io.netty.util.AttributeKey;
import java.time.Duration;
import java.util.AbstractMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A framed gRPC {@link GrpcService}. For each POST request it picks a {@link SerializationFormat} from the
 * {@code Content-Type}, resolves the target {@link ServerMethodDefinition} from the {@link HandlerRegistry}
 * and hands the exchange over to an {@link ArmeriaServerCall}.
 */
final class FramedGrpcService extends AbstractHttpService implements GrpcService {

    private static final Logger logger = LoggerFactory.getLogger(FramedGrpcService.class);

    /**
     * Request-context attribute that may hold an already-resolved {@link ServerMethodDefinition}.
     * {@link #doPost(ServiceRequestContext, HttpRequest)} only consults it when
     * {@code lookupMethodFromAttribute} was enabled at construction time.
     */
    static final AttributeKey<ServerMethodDefinition<?, ?>> RESOLVED_GRPC_METHOD =
            AttributeKey.valueOf(FramedGrpcService.class, "RESOLVED_GRPC_METHOD");

    private final HandlerRegistry registry;
    private final Set<Route> routes;
    /** Exchange type of every registered method, keyed by {@code '/'} followed by the registry's method key. */
    private final Map<String, ExchangeType> exchangeTypes;
    private final DecompressorRegistry decompressorRegistry;
    private final CompressorRegistry compressorRegistry;
    private final Set<SerializationFormat> supportedSerializationFormats;
    /**
     * JSON marshallers keyed by service name. Empty when no supported format is JSON or when creating the
     * marshallers failed.
     */
    private final Map<String, GrpcJsonMarshaller> jsonMarshallers;
    @Nullable
    private final ProtoReflectionServiceInterceptor protoReflectionServiceInterceptor;
    @Nullable
    private final GrpcStatusFunction statusFunction;
    private final int maxResponseMessageLength;
    private final boolean useBlockingTaskExecutor;
    private final boolean unsafeWrapRequestBuffers;
    private final boolean useClientTimeoutHeader;
    private final boolean lookupMethodFromAttribute;
    /** Comma-joined encodings advertised through {@link GrpcHeaderNames#GRPC_ACCEPT_ENCODING}. */
    private final String advertisedEncodingsHeader;
    /** Pre-built response headers for each supported serialization format. */
    private final Map<SerializationFormat, ResponseHeaders> defaultHeaders;
    @Nullable
    private final GrpcHealthCheckService grpcHealthCheckService;
    /**
     * Maximum request message length. {@code -1} means "not configured explicitly"; in that case the value
     * is derived from the {@link ServiceConfig} in {@link #serviceAdded(ServiceConfig)}.
     */
    private int maxRequestMessageLength;

    /**
     * Creates one JSON marshaller per distinct service descriptor in {@code registry}.
     *
     * @return a map keyed by service name; empty when none of {@code supportedSerializationFormats} is a
     *         JSON format, or when {@code jsonMarshallerFactory} throws (a warning is logged in that case)
     */
    private static Map<String, GrpcJsonMarshaller> getJsonMarshallers(
            HandlerRegistry registry,
            Set<SerializationFormat> supportedSerializationFormats,
            Function<? super ServiceDescriptor, ? extends GrpcJsonMarshaller> jsonMarshallerFactory) {
        if (supportedSerializationFormats.stream().noneMatch(GrpcSerializationFormats::isJson)) {
            return ImmutableMap.of();
        }
        try {
            return registry.services().stream()
                           .map(ServerServiceDefinition::getServiceDescriptor)
                           .distinct()
                           .collect(ImmutableMap.toImmutableMap(ServiceDescriptor::getName,
                                                                jsonMarshallerFactory));
        } catch (Exception e) {
            // Best effort: JSON support is dropped instead of failing service construction.
            logger.warn("Failed to instantiate a JSON marshaller. Consider disabling gRPC-JSON serialization with {}.supportedSerializationFormats() or using {}.ofGson() instead.",
                        GrpcServiceBuilder.class.getName(), GrpcJsonMarshaller.class.getName(), e);
            return ImmutableMap.of();
        }
    }

    /**
     * Builds the default response headers for {@code format}: {@code 200 OK}, the format's media type,
     * an identity {@link GrpcHeaderNames#GRPC_ENCODING}, and {@link GrpcHeaderNames#GRPC_ACCEPT_ENCODING}
     * when {@code advertisedEncodings} is not empty.
     */
    private static ResponseHeaders newDefaultHeaders(SerializationFormat format, String advertisedEncodings) {
        final ResponseHeadersBuilder builder =
                ResponseHeaders.builder(HttpStatus.OK)
                               .contentType(format.mediaType())
                               .add(GrpcHeaderNames.GRPC_ENCODING, Codec.Identity.NONE.getMessageEncoding());
        if (!advertisedEncodings.isEmpty()) {
            builder.add(GrpcHeaderNames.GRPC_ACCEPT_ENCODING, advertisedEncodings);
        }
        return builder.build();
    }

    /**
     * Creates a new instance.
     *
     * @param jsonMarshallerFactory only invoked when at least one of {@code supportedSerializationFormats}
     *                              is a JSON format
     * @param maxRequestMessageLength {@code -1} to derive the limit from the {@link ServiceConfig} later
     * @param lookupMethodFromAttribute whether {@link #RESOLVED_GRPC_METHOD} is consulted before the registry
     */
    FramedGrpcService(HandlerRegistry registry,
                      Set<Route> routes,
                      DecompressorRegistry decompressorRegistry,
                      CompressorRegistry compressorRegistry,
                      Set<SerializationFormat> supportedSerializationFormats,
                      Function<? super ServiceDescriptor, ? extends GrpcJsonMarshaller> jsonMarshallerFactory,
                      @Nullable ProtoReflectionServiceInterceptor protoReflectionServiceInterceptor,
                      @Nullable GrpcStatusFunction statusFunction,
                      int maxRequestMessageLength,
                      int maxResponseMessageLength,
                      boolean useBlockingTaskExecutor,
                      boolean unsafeWrapRequestBuffers,
                      boolean useClientTimeoutHeader,
                      boolean lookupMethodFromAttribute,
                      @Nullable GrpcHealthCheckService grpcHealthCheckService) {
        this.registry = Objects.requireNonNull(registry, "registry");
        this.routes = Objects.requireNonNull(routes, "routes");
        exchangeTypes = registry.methods().entrySet().stream()
                                .collect(ImmutableMap.toImmutableMap(
                                        e -> '/' + e.getKey(),
                                        e -> toExchangeType(e.getValue())));
        this.decompressorRegistry = Objects.requireNonNull(decompressorRegistry, "decompressorRegistry");
        this.compressorRegistry = Objects.requireNonNull(compressorRegistry, "compressorRegistry");
        this.supportedSerializationFormats =
                Objects.requireNonNull(supportedSerializationFormats, "supportedSerializationFormats");
        this.useClientTimeoutHeader = useClientTimeoutHeader;
        jsonMarshallers = getJsonMarshallers(registry, supportedSerializationFormats, jsonMarshallerFactory);
        this.protoReflectionServiceInterceptor = protoReflectionServiceInterceptor;
        this.statusFunction = statusFunction;
        this.maxRequestMessageLength = maxRequestMessageLength;
        this.maxResponseMessageLength = maxResponseMessageLength;
        this.useBlockingTaskExecutor = useBlockingTaskExecutor;
        this.unsafeWrapRequestBuffers = unsafeWrapRequestBuffers;
        this.lookupMethodFromAttribute = lookupMethodFromAttribute;
        advertisedEncodingsHeader = String.join(",", decompressorRegistry.getAdvertisedMessageEncodings());

        final ImmutableMap.Builder<SerializationFormat, ResponseHeaders> headersBuilder = ImmutableMap.builder();
        for (SerializationFormat format : supportedSerializationFormats) {
            headersBuilder.put(format, newDefaultHeaders(format, advertisedEncodingsHeader));
        }
        defaultHeaders = headersBuilder.build();
        this.grpcHealthCheckService = grpcHealthCheckService;
    }

    /**
     * Returns the {@link ExchangeType} of the registered method whose key matches {@code headers.path()},
     * or {@link ExchangeType#BIDI_STREAMING} when there is no match. Unknown or invalid paths are rejected
     * later by {@link #doPost(ServiceRequestContext, HttpRequest)}.
     *
     * <p>NOTE(review): the lookup uses the raw request path against keys of the form {@code '/' + key}.
     * If this service is mounted under a path prefix, the prefixed path presumably never matches and every
     * method falls back to bidi-streaming. Confirm before relying on the exact exchange type.
     */
    @Override
    public ExchangeType exchangeType(RequestHeaders headers, Route route) {
        final ExchangeType exchangeType = exchangeTypes.get(headers.path());
        return MoreObjects.firstNonNull(exchangeType, ExchangeType.BIDI_STREAMING);
    }

    /** Maps a gRPC method type onto Armeria's {@link ExchangeType}; anything else is bidi-streaming. */
    private static ExchangeType toExchangeType(ServerMethodDefinition<?, ?> methodDefinition) {
        switch (methodDefinition.getMethodDescriptor().getType()) {
            case UNARY:
                return ExchangeType.UNARY;
            case CLIENT_STREAMING:
                return ExchangeType.REQUEST_STREAMING;
            case SERVER_STREAMING:
                return ExchangeType.RESPONSE_STREAMING;
            default:
                return ExchangeType.BIDI_STREAMING;
        }
    }

    /**
     * Dispatches a gRPC request.
     *
     * <p>The request is answered immediately, without invoking the method, in these cases:
     * <ul>
     *   <li>{@code 415} when the {@code Content-Type} is missing or not a supported format;</li>
     *   <li>{@code 400} when no method name can be determined from the request;</li>
     *   <li>a headers-only response with {@link Status#UNIMPLEMENTED} when the method is not registered;</li>
     *   <li>a headers-only response with the status mapped from the parse error when
     *       {@code useClientTimeoutHeader} is enabled and the {@code grpc-timeout} value is invalid.</li>
     * </ul>
     * Otherwise a streaming response is returned and the method's handler is started. A {@code grpc-timeout}
     * of zero clears the request timeout.
     */
    @Override
    protected HttpResponse doPost(ServiceRequestContext ctx, HttpRequest req) throws Exception {
        final SerializationFormat serializationFormat = findSerializationFormat(req.contentType());
        if (serializationFormat == null) {
            return HttpResponse.of(HttpStatus.UNSUPPORTED_MEDIA_TYPE,
                                   MediaType.PLAIN_TEXT_UTF_8,
                                   "Missing or invalid Content-Type header.");
        }
        ctx.logBuilder().serializationFormat(serializationFormat);

        ServerMethodDefinition<?, ?> method =
                lookupMethodFromAttribute ? ctx.attr(RESOLVED_GRPC_METHOD) : null;
        if (method == null) {
            final String methodName = GrpcRequestUtil.determineMethod(ctx);
            if (methodName == null) {
                return HttpResponse.of(HttpStatus.BAD_REQUEST,
                                       MediaType.PLAIN_TEXT_UTF_8,
                                       "Invalid path.");
            }
            method = registry.lookupMethod(methodName);
            if (method == null) {
                return newTrailersOnlyResponse(
                        ctx, serializationFormat,
                        Status.UNIMPLEMENTED.withDescription("Method not found: " + methodName),
                        new Metadata());
            }
        }

        if (useClientTimeoutHeader) {
            final String timeoutHeader = req.headers().get(GrpcHeaderNames.GRPC_TIMEOUT);
            if (timeoutHeader != null) {
                try {
                    final long timeout = TimeoutHeaderUtil.fromHeaderValue(timeoutHeader);
                    if (timeout == 0L) {
                        ctx.clearRequestTimeout();
                    } else {
                        ctx.setRequestTimeout(TimeoutMode.SET_FROM_NOW, Duration.ofNanos(timeout));
                    }
                } catch (IllegalArgumentException e) {
                    final Metadata metadata = new Metadata();
                    return newTrailersOnlyResponse(
                            ctx, serializationFormat,
                            GrpcStatus.fromThrowable(statusFunction, ctx, e, metadata), metadata);
                }
            }
        }

        ctx.logBuilder().defer(RequestLogProperty.REQUEST_CONTENT, RequestLogProperty.RESPONSE_CONTENT);

        final HttpResponseWriter res = HttpResponse.streaming();
        final ArmeriaServerCall<?, ?> call =
                startCall(registry.simpleMethodName(method.getMethodDescriptor()), method, ctx, req, res,
                          serializationFormat);
        if (call != null) {
            // Translate request cancellation (including request timeout) into a CANCELLED gRPC status.
            ctx.whenRequestCancelling().handle((cancellationCause, unused) -> {
                Status status = Status.CANCELLED.withCause(cancellationCause);
                if (cancellationCause instanceof RequestTimeoutException) {
                    status = status.withDescription("Request timed out");
                }
                call.close(status, new Metadata());
                return null;
            });
            call.startDeframing();
        }
        return res;
    }

    /**
     * Returns a response that consists only of the default headers for {@code serializationFormat}, with
     * {@code status} and {@code metadata} merged in by {@link ArmeriaServerCall#statusToTrailers}.
     */
    private HttpResponse newTrailersOnlyResponse(ServiceRequestContext ctx,
                                                 SerializationFormat serializationFormat,
                                                 Status status, Metadata metadata) {
        final ResponseHeadersBuilder trailersBuilder = defaultHeaders.get(serializationFormat).toBuilder();
        return HttpResponse.of(
                (ResponseHeaders) ArmeriaServerCall.statusToTrailers(ctx, trailersBuilder, status, metadata));
    }

    /**
     * Creates an {@link ArmeriaServerCall} for {@code methodDef} and starts its handler with the request
     * headers converted to gRPC {@link Metadata}.
     *
     * @return the started call, or {@code null} if the handler threw. In that case the call has already been
     *         closed with the status mapped from the exception.
     * @throws NullPointerException if the handler returned a {@code null} listener
     */
    @Nullable
    private <I, O> ArmeriaServerCall<I, O> startCall(String simpleMethodName,
                                                     ServerMethodDefinition<I, O> methodDef,
                                                     ServiceRequestContext ctx,
                                                     HttpRequest req,
                                                     HttpResponseWriter res,
                                                     SerializationFormat serializationFormat) {
        final MethodDescriptor<I, O> methodDescriptor = methodDef.getMethodDescriptor();
        final ArmeriaServerCall<I, O> call = new ArmeriaServerCall<>(
                req, methodDescriptor, simpleMethodName, compressorRegistry, decompressorRegistry, res,
                maxRequestMessageLength, maxResponseMessageLength, ctx, serializationFormat,
                jsonMarshallers.get(methodDescriptor.getServiceName()),
                unsafeWrapRequestBuffers, useBlockingTaskExecutor,
                defaultHeaders.get(serializationFormat), statusFunction);

        final ServerCall.Listener<I> listener;
        try (SafeCloseable ignored = ctx.push()) {
            listener = methodDef.getServerCallHandler()
                                .startCall(call, MetadataUtil.copyFromHeaders(req.headers()));
        } catch (Throwable t) {
            // The call needs a listener before it can be closed; install a no-op one.
            call.setListener(new EmptyListener<>());
            final Metadata metadata = new Metadata();
            call.close(GrpcStatus.fromThrowable(statusFunction, ctx, t, metadata), metadata);
            logger.warn("Exception thrown from streaming request stub method before processing any request data - this is likely a bug in the stub implementation.", t);
            return null;
        }
        if (listener == null) {
            throw new NullPointerException(
                    "startCall() returned a null listener for method " + methodDescriptor.getFullMethodName());
        }
        call.setListener(listener);
        return call;
    }

    /**
     * Completes configuration once the service is bound to a server.
     * <ul>
     *   <li>If {@code maxRequestMessageLength} is {@code -1}, it becomes {@link ServiceConfig#maxRequestLength()}
     *       clamped to {@link Integer#MAX_VALUE}.</li>
     *   <li>If proto reflection is enabled, the interceptor receives a dummy {@link Server} exposing the
     *       services of every {@code FramedGrpcService} on every virtual host. When a service name occurs
     *       more than once, the first definition is kept.</li>
     *   <li>The health-check service, if any, is notified as well.</li>
     * </ul>
     */
    @Override
    public void serviceAdded(ServiceConfig cfg) {
        if (maxRequestMessageLength == -1) {
            maxRequestMessageLength = (int) Math.min(cfg.maxRequestLength(), Integer.MAX_VALUE);
        }
        if (protoReflectionServiceInterceptor != null) {
            final Map<String, ServerServiceDefinition> grpcServices =
                    cfg.server().config().virtualHosts().stream()
                       .flatMap(host -> host.serviceConfigs().stream())
                       .map(serviceConfig -> serviceConfig.service().as(FramedGrpcService.class))
                       .filter(Objects::nonNull)
                       .flatMap(service -> service.services().stream())
                       .collect(ImmutableMap.toImmutableMap(def -> def.getServiceDescriptor().getName(),
                                                            Function.identity(),
                                                            (a, b) -> a));
            protoReflectionServiceInterceptor.setServer(newDummyServer(grpcServices));
        }
        if (grpcHealthCheckService != null) {
            grpcHealthCheckService.serviceAdded(cfg);
        }
    }

    /**
     * Returns a {@link Server} that only reports {@code grpcServices} as its (immutable) services.
     * It has no mutable services, and every lifecycle method throws {@link UnsupportedOperationException}.
     */
    private static Server newDummyServer(final Map<String, ServerServiceDefinition> grpcServices) {
        return new Server() {
            @Override
            public Server start() {
                throw new UnsupportedOperationException();
            }

            @Override
            public List<ServerServiceDefinition> getServices() {
                return ImmutableList.copyOf(grpcServices.values());
            }

            @Override
            public List<ServerServiceDefinition> getImmutableServices() {
                return getServices();
            }

            @Override
            public List<ServerServiceDefinition> getMutableServices() {
                return ImmutableList.of();
            }

            @Override
            public Server shutdown() {
                throw new UnsupportedOperationException();
            }

            @Override
            public Server shutdownNow() {
                throw new UnsupportedOperationException();
            }

            @Override
            public boolean isShutdown() {
                throw new UnsupportedOperationException();
            }

            @Override
            public boolean isTerminated() {
                throw new UnsupportedOperationException();
            }

            @Override
            public boolean awaitTermination(long timeout, TimeUnit unit) {
                throw new UnsupportedOperationException();
            }

            @Override
            public void awaitTermination() {
                throw new UnsupportedOperationException();
            }
        };
    }

    @Override
    public boolean isFramed() {
        return true;
    }

    @Override
    public List<ServerServiceDefinition> services() {
        final List<ServerServiceDefinition> services = registry.services();
        assert services instanceof ImmutableList;
        return services;
    }

    @Override
    public Map<String, ServerMethodDefinition<?, ?>> methods() {
        final Map<String, ServerMethodDefinition<?, ?>> methods = registry.methods();
        assert methods instanceof ImmutableMap;
        return methods;
    }

    @Override
    public Set<SerializationFormat> supportedSerializationFormats() {
        return supportedSerializationFormats;
    }

    /**
     * Returns the first supported format that accepts {@code contentType}, or {@code null} if
     * {@code contentType} is {@code null} or no format accepts it.
     */
    @Nullable
    private SerializationFormat findSerializationFormat(@Nullable MediaType contentType) {
        if (contentType == null) {
            return null;
        }
        for (SerializationFormat format : supportedSerializationFormats) {
            if (format.isAccepted(contentType)) {
                return format;
            }
        }
        return null;
    }

    @Override
    public Set<Route> routes() {
        return routes;
    }

    /** A listener that ignores every event; installed on a call whose handler failed to start. */
    private static final class EmptyListener<T> extends ServerCall.Listener<T> {
        private EmptyListener() {}
    }
}

