/*
 * Decompiled with CFR 0.152.
 */
package org.wiremock.grpc.internal;

import com.github.tomakehurst.wiremock.common.Exceptions;
import com.github.tomakehurst.wiremock.common.Pair;
import com.github.tomakehurst.wiremock.http.StubRequestHandler;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.wiremock.grpc.internal.ClientStreamingServerCallHandler;
import org.wiremock.grpc.internal.HeaderCopyingServerInterceptor;
import org.wiremock.grpc.internal.JsonMessageConverter;
import org.wiremock.grpc.internal.UnaryServerCallHandler;
import wiremock.com.google.protobuf.Descriptors;
import wiremock.com.google.protobuf.DynamicMessage;
import wiremock.com.google.protobuf.TypeRegistry;
import wiremock.grpc.io.grpc.BindableService;
import wiremock.grpc.io.grpc.MethodDescriptor;
import wiremock.grpc.io.grpc.ServerCallHandler;
import wiremock.grpc.io.grpc.ServerInterceptor;
import wiremock.grpc.io.grpc.ServerInterceptors;
import wiremock.grpc.io.grpc.ServerServiceDefinition;
import wiremock.grpc.io.grpc.ServiceDescriptor;
import wiremock.grpc.io.grpc.protobuf.ProtoServiceDescriptorSupplier;
import wiremock.grpc.io.grpc.protobuf.ProtoUtils;
import wiremock.grpc.io.grpc.protobuf.services.ProtoReflectionServiceV1;
import wiremock.grpc.io.grpc.servlet.jakarta.ServletAdapter;
import wiremock.grpc.io.grpc.servlet.jakarta.ServletServerBuilder;
import wiremock.grpc.io.grpc.stub.ServerCalls;
import wiremock.jakarta.servlet.FilterChain;
import wiremock.jakarta.servlet.ServletException;
import wiremock.jakarta.servlet.http.HttpFilter;
import wiremock.jakarta.servlet.http.HttpServletRequest;
import wiremock.jakarta.servlet.http.HttpServletResponse;

/**
 * Servlet filter that serves gRPC traffic generated from protobuf descriptors.
 *
 * <p>Requests recognised by {@link ServletAdapter#isGrpc} are handed to an embedded gRPC
 * {@link ServletAdapter}; every other request continues down the filter chain unchanged. The
 * {@link StubRequestHandler} given to the constructor is passed to every generated method handler.
 */
public class GrpcFilter extends HttpFilter {
    /**
     * Adapter for the currently loaded services, or {@code null} before the first load.
     * Volatile because the public {@link #loadFileDescriptors} may replace it while request
     * threads are reading it in {@link #doFilter}.
     */
    private volatile ServletAdapter servletAdapter;
    private final StubRequestHandler stubRequestHandler;

    /**
     * @param stubRequestHandler handler passed to every generated gRPC method handler
     */
    public GrpcFilter(StubRequestHandler stubRequestHandler) {
        this.stubRequestHandler = stubRequestHandler;
    }

    /**
     * Loads the given descriptors without any additional interceptors.
     *
     * @param fileDescriptors protobuf files whose services should be served
     */
    public void loadFileDescriptors(List<Descriptors.FileDescriptor> fileDescriptors) {
        this.loadFileDescriptors(fileDescriptors, Collections.emptyList());
    }

    /**
     * Builds a gRPC service for every protobuf service in {@code fileDescriptors}, plus the v1
     * reflection service, and installs them behind a newly built {@link ServletAdapter},
     * replacing any adapter loaded earlier.
     *
     * <p>NOTE(review): the previously installed adapter is not destroyed here — confirm whether
     * it should be.
     *
     * @param fileDescriptors protobuf files whose services should be served
     * @param interceptors extra interceptors wrapped around every service
     */
    public void loadFileDescriptors(List<Descriptors.FileDescriptor> fileDescriptors, List<ServerInterceptor> interceptors) {
        List<BindableService> services = this.buildServices(fileDescriptors);
        this.servletAdapter = GrpcFilter.loadServices(services, interceptors);
    }

    /**
     * Registers each service on a new servlet server and returns that server's adapter.
     *
     * <p>Each service is wrapped first with a {@link HeaderCopyingServerInterceptor} and then with
     * {@code interceptors}, so the caller-supplied interceptors form the outer layer.
     */
    private static ServletAdapter loadServices(List<? extends BindableService> bindableServices, List<ServerInterceptor> interceptors) {
        HeaderCopyingServerInterceptor headerCopyingServerInterceptor = new HeaderCopyingServerInterceptor();
        ServletServerBuilder serverBuilder = new ServletServerBuilder();
        for (BindableService service : bindableServices) {
            ServerServiceDefinition withHeaderCopying = ServerInterceptors.intercept(service, headerCopyingServerInterceptor);
            serverBuilder.addService(ServerInterceptors.intercept(withHeaderCopying, interceptors));
        }
        return serverBuilder.buildServletAdapter();
    }

    /**
     * Builds one service per protobuf service in {@code fileDescriptors}, followed by the gRPC v1
     * reflection service.
     *
     * <p>The top-level message types of every file are registered in a single {@link TypeRegistry},
     * which backs the {@link JsonMessageConverter} shared by all generated handlers.
     *
     * @return an unmodifiable list of services
     */
    private List<BindableService> buildServices(List<Descriptors.FileDescriptor> fileDescriptors) {
        TypeRegistry.Builder typeRegistryBuilder = TypeRegistry.newBuilder();
        fileDescriptors.forEach(fileDescriptor -> fileDescriptor.getMessageTypes().forEach(typeRegistryBuilder::add));
        JsonMessageConverter jsonMessageConverter = new JsonMessageConverter(typeRegistryBuilder.build());
        Stream<BindableService> servicesFromDescriptors = fileDescriptors.stream()
                .flatMap(fileDescriptor -> fileDescriptor.getServices().stream()
                        .map(serviceDescriptor -> this.toBindableService(fileDescriptor, serviceDescriptor, jsonMessageConverter)));
        BindableService reflectionService = ProtoReflectionServiceV1.newInstance();
        return Stream.concat(servicesFromDescriptors, Stream.of(reflectionService)).collect(Collectors.toUnmodifiableList());
    }

    /** Wraps one protobuf service as a {@link BindableService} whose definition is built when it is bound. */
    private BindableService toBindableService(Descriptors.FileDescriptor fileDescriptor, Descriptors.ServiceDescriptor serviceDescriptor, JsonMessageConverter jsonMessageConverter) {
        return () -> this.buildServiceDefinition(fileDescriptor, serviceDescriptor, jsonMessageConverter);
    }

    /**
     * Builds the server definition for one protobuf service, pairing every method's descriptor with
     * its call handler.
     *
     * <p>The schema descriptor is a {@link ProtoServiceDescriptorSupplier}, which exposes the
     * protobuf file and service descriptors on the gRPC service descriptor.
     */
    private ServerServiceDefinition buildServiceDefinition(final Descriptors.FileDescriptor fileDescriptor, final Descriptors.ServiceDescriptor serviceDescriptor, JsonMessageConverter jsonMessageConverter) {
        ServiceDescriptor.Builder serviceDescriptorBuilder = ServiceDescriptor.newBuilder(serviceDescriptor.getFullName())
                .setSchemaDescriptor(new ProtoServiceDescriptorSupplier() {

                    @Override
                    public Descriptors.FileDescriptor getFileDescriptor() {
                        return fileDescriptor;
                    }

                    @Override
                    public Descriptors.ServiceDescriptor getServiceDescriptor() {
                        return serviceDescriptor;
                    }
                });
        List<Pair<MethodDescriptor<DynamicMessage, DynamicMessage>, ServerCallHandler<DynamicMessage, DynamicMessage>>> methods =
                serviceDescriptor.getMethods().stream()
                        .map(methodDescriptor -> Pair.pair(
                                GrpcFilter.buildMessageDescriptorInstance(serviceDescriptor, methodDescriptor),
                                this.buildHandler(serviceDescriptor, methodDescriptor, jsonMessageConverter)))
                        .collect(Collectors.toList());
        // Methods must be declared on the service descriptor before the definition builder is
        // created from it.
        methods.forEach(method -> serviceDescriptorBuilder.addMethod(method.a));
        ServerServiceDefinition.Builder builder = ServerServiceDefinition.builder(serviceDescriptorBuilder.build());
        methods.forEach(method -> builder.addMethod(method.a, method.b));
        return builder.build();
    }

    /**
     * Chooses the call handler for one method.
     *
     * <p>Methods where the client streams (client-streaming and bidi) get a
     * {@link ClientStreamingServerCallHandler}. All other methods get a
     * {@link UnaryServerCallHandler}.
     *
     * <p>NOTE(review): server-streaming methods therefore use the unary handler — confirm this is
     * intended.
     */
    private ServerCallHandler<DynamicMessage, DynamicMessage> buildHandler(Descriptors.ServiceDescriptor serviceDescriptor, Descriptors.MethodDescriptor methodDescriptor, JsonMessageConverter jsonMessageConverter) {
        if (methodDescriptor.isClientStreaming()) {
            return ServerCalls.asyncClientStreamingCall(new ClientStreamingServerCallHandler(this.stubRequestHandler, serviceDescriptor, methodDescriptor, jsonMessageConverter));
        }
        return ServerCalls.asyncUnaryCall(new UnaryServerCallHandler(this.stubRequestHandler, serviceDescriptor, methodDescriptor, jsonMessageConverter));
    }

    /**
     * Builds the gRPC method descriptor for one protobuf method, using {@link DynamicMessage}
     * marshallers for the method's input and output types.
     */
    private static MethodDescriptor<DynamicMessage, DynamicMessage> buildMessageDescriptorInstance(Descriptors.ServiceDescriptor serviceDescriptor, Descriptors.MethodDescriptor methodDescriptor) {
        // The explicit type witness is required: without it the builder's type arguments infer to
        // Object and neither the marshallers nor the return type type-check.
        return MethodDescriptor.<DynamicMessage, DynamicMessage>newBuilder()
                .setType(GrpcFilter.getMethodTypeFromDesc(methodDescriptor))
                .setFullMethodName(MethodDescriptor.generateFullMethodName(serviceDescriptor.getFullName(), methodDescriptor.getName()))
                .setRequestMarshaller(ProtoUtils.marshaller(DynamicMessage.getDefaultInstance(methodDescriptor.getInputType())))
                .setResponseMarshaller(ProtoUtils.marshaller(DynamicMessage.getDefaultInstance(methodDescriptor.getOutputType())))
                .build();
    }

    /** Maps the protobuf client/server streaming flags onto the gRPC method type. */
    private static MethodDescriptor.MethodType getMethodTypeFromDesc(Descriptors.MethodDescriptor methodDesc) {
        if (methodDesc.isClientStreaming()) {
            return methodDesc.isServerStreaming()
                    ? MethodDescriptor.MethodType.BIDI_STREAMING
                    : MethodDescriptor.MethodType.CLIENT_STREAMING;
        }
        return methodDesc.isServerStreaming()
                ? MethodDescriptor.MethodType.SERVER_STREAMING
                : MethodDescriptor.MethodType.UNARY;
    }

    /**
     * Dispatches gRPC requests to the loaded adapter and passes all other requests down the chain.
     *
     * <p>For gRPC requests, the request's scheme, local address and port are also recorded via
     * {@link ServerAddress#set}.
     */
    @Override
    protected void doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain chain) throws IOException, ServletException {
        if (!ServletAdapter.isGrpc(request)) {
            chain.doFilter(request, response);
            return;
        }
        ServerAddress.set(request.getScheme(), request.getLocalAddr(), request.getLocalPort());
        // Read the volatile field once so the whole dispatch uses a single adapter instance.
        ServletAdapter adapter = this.servletAdapter;
        String method = request.getMethod();
        if (GrpcFilter.isPost(method)) {
            adapter.doPost(request, response);
        } else if (GrpcFilter.isGet(method)) {
            adapter.doGet(request, response);
        }
        // NOTE(review): gRPC requests using any other HTTP method are not answered here.
    }

    /** Destroys the loaded adapter, if any; does nothing when no descriptors were ever loaded. */
    @Override
    public void destroy() {
        ServletAdapter adapter = this.servletAdapter;
        if (adapter != null) {
            adapter.destroy();
        }
    }

    private static boolean isGet(String method) {
        return method.equalsIgnoreCase("GET");
    }

    private static boolean isPost(String method) {
        return method.equalsIgnoreCase("POST");
    }

    /**
     * Scheme, host and port of the server, held in a single static future.
     *
     * <p>{@link GrpcFilter} calls {@link #set} on every gRPC request. Only the first value is kept.
     */
    public static class ServerAddress {
        private static final CompletableFuture<ServerAddress> instance = new CompletableFuture<>();
        final String scheme;
        final String hostname;
        final int port;

        /**
         * Records the address. Only the first call has an effect, because
         * {@link CompletableFuture#complete} ignores values supplied after completion.
         */
        public static void set(String scheme, String hostname, int port) {
            instance.complete(new ServerAddress(scheme, hostname, port));
        }

        /**
         * Returns the recorded address, waiting up to 5 seconds for {@link #set} to be called.
         * Failures, including the timeout, are surfaced through {@code Exceptions.uncheck}.
         */
        public static ServerAddress get() {
            return Exceptions.uncheck(() -> instance.get(5L, TimeUnit.SECONDS), ServerAddress.class);
        }

        public ServerAddress(String scheme, String hostname, int port) {
            this.scheme = scheme;
            this.hostname = hostname;
            this.port = port;
        }
    }
}

