/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.serverless.workflow.rpc;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.node.NullNode;
import com.google.protobuf.DescriptorProtos;
import com.google.protobuf.Descriptors;
import com.google.protobuf.DynamicMessage;
import com.google.protobuf.Message;
import io.grpc.CallOptions;
import io.grpc.Channel;
import io.grpc.ClientCall;
import io.grpc.MethodDescriptor;
import io.grpc.Status;
import io.grpc.protobuf.ProtoUtils;
import io.grpc.stub.ClientCalls;
import io.grpc.stub.StreamObserver;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.function.Function;
import java.util.function.UnaryOperator;
import java.util.stream.Collectors;
import org.kie.kogito.internal.process.workitem.KogitoWorkItem;
import org.kie.kogito.jackson.utils.JsonObjectUtils;
import org.kie.kogito.serverless.workflow.WorkflowWorkItemHandler;
import org.kie.kogito.serverless.workflow.rpc.DefaultEnumRpcDecorator;
import org.kie.kogito.serverless.workflow.rpc.FileDescriptorHolder;
import org.kie.kogito.serverless.workflow.rpc.RPCConverterFactory;
import org.kie.kogito.serverless.workflow.rpc.RPCDecorator;

/**
 * Work item handler that performs a gRPC call described by a protobuf descriptor set.
 *
 * <p>The proto file, service and method to invoke are read from the metadata of the node that
 * owns the work item (keys {@link #FILE_PROP}, {@link #SERVICE_PROP} and {@link #METHOD_PROP}).
 * Requests are built as {@link DynamicMessage}s from the work item parameters and responses are
 * converted to {@link JsonNode}s. Unary, server-streaming, client-streaming and bidirectional
 * streaming methods are supported.
 *
 * <p>Subclasses supply the {@link Channel} through {@link #getChannel(String, String)}.
 */
public abstract class RPCWorkItemHandler extends WorkflowWorkItemHandler {
    /** Node metadata key holding the gRPC service name. */
    public static final String SERVICE_PROP = "serviceName";
    /** Node metadata key holding the proto file name. */
    public static final String FILE_PROP = "fileName";
    /** Node metadata key holding the gRPC method name. */
    public static final String METHOD_PROP = "methodName";
    /** Configuration property name; it is not read by this class itself. */
    public static final String GRPC_ENUM_DEFAULT_PROPERTY = "kogito.grpc.enum.includeDefault";
    /** Configuration property name; it is not read by this class itself. */
    public static final String GRPC_STREAM_TIMEOUT_PROPERTY = "kogito.grpc.stream.timeout";
    /** Value of {@code enumDefault} used by the no-argument constructor. */
    public static final boolean GRPC_ENUM_DEFAULT_VALUE = false;
    /** Value of {@code streamTimeout} (seconds) used by the no-argument constructor. */
    public static final int GRPC_STREAM_TIMEOUT_VALUE = 20;

    /** Parameter key under which the list of request payloads of a streaming call is expected. */
    private static final String STREAM_CONTENT_PARAM = "ContentData";

    private final Collection<RPCDecorator> decorators = new ArrayList<>();
    private final int streamTimeout;
    /** Cache of already built file descriptors, keyed by proto file name. */
    private final Map<String, Descriptors.FileDescriptor> fileDescriptors = new ConcurrentHashMap<>();

    /** Creates a handler using {@link #GRPC_ENUM_DEFAULT_VALUE} and {@link #GRPC_STREAM_TIMEOUT_VALUE}. */
    public RPCWorkItemHandler() {
        this(GRPC_ENUM_DEFAULT_VALUE, GRPC_STREAM_TIMEOUT_VALUE);
    }

    /**
     * @param enumDefault when {@code true}, a {@link DefaultEnumRpcDecorator} is applied to every converted response
     * @param streamTimeout number of seconds to wait for the responses of a client or bidirectional streaming call
     */
    public RPCWorkItemHandler(boolean enumDefault, int streamTimeout) {
        this.streamTimeout = streamTimeout;
        if (enumDefault) {
            this.decorators.add(new DefaultEnumRpcDecorator());
        }
    }

    /**
     * Reads file, service and method names from the node metadata and performs the gRPC call.
     *
     * @param workItem work item whose node metadata identifies the call
     * @param parameters source data for the request message(s)
     * @return the response converted to a {@link JsonNode}
     * @throws IllegalStateException if the descriptor set is not available
     */
    protected Object internalExecute(KogitoWorkItem workItem, Map<String, Object> parameters) {
        Map<String, Object> metadata = workItem.getNodeInstance().getNode().getMetaData();
        String file = (String) metadata.get(FILE_PROP);
        String service = (String) metadata.get(SERVICE_PROP);
        String method = (String) metadata.get(METHOD_PROP);
        // The descriptor set is resolved before the channel is requested, as in the original call order.
        DescriptorProtos.FileDescriptorSet fdSet = FileDescriptorHolder.get().descriptor()
                .orElseThrow(() -> new IllegalStateException("Descriptor protobuf/descriptor-sets/output.protobin is not present"));
        return this.doCall(fdSet, parameters, this.getChannel(file, service), file, service, method);
    }

    /**
     * Provides the channel used to reach the given service.
     *
     * @param file proto file name taken from the node metadata
     * @param service service name taken from the node metadata
     * @return channel on which the call is created
     */
    protected abstract Channel getChannel(String file, String service);

    /**
     * Resolves the method descriptor, creates the client call and dispatches according to the method type.
     *
     * @throws NullPointerException if the service or the method cannot be found in the descriptor
     */
    private JsonNode doCall(DescriptorProtos.FileDescriptorSet fdSet, Map<String, Object> parameters, Channel channel, String fileName, String serviceName, String methodName) {
        Descriptors.FileDescriptor descriptor = this.buildFileDescriptor(fdSet, fileName);
        Descriptors.ServiceDescriptor serviceDesc = Objects.requireNonNull(descriptor.findServiceByName(serviceName), "Cannot find service name " + serviceName);
        Descriptors.MethodDescriptor methodDesc = Objects.requireNonNull(serviceDesc.findMethodByName(methodName), "Cannot find method name " + methodName);
        MethodDescriptor.MethodType methodType = getMethodType(methodDesc);

        // Empty dynamic messages act as prototypes for the request/response marshallers.
        Message requestPrototype = DynamicMessage.newBuilder(methodDesc.getInputType()).buildPartial();
        Message responsePrototype = DynamicMessage.newBuilder(methodDesc.getOutputType()).buildPartial();
        MethodDescriptor<Message, Message> grpcMethod = MethodDescriptor.<Message, Message>newBuilder()
                .setType(methodType)
                .setFullMethodName(MethodDescriptor.generateFullMethodName(serviceDesc.getFullName(), methodDesc.getName()))
                .setRequestMarshaller(ProtoUtils.marshaller(requestPrototype))
                .setResponseMarshaller(ProtoUtils.marshaller(responsePrototype))
                .build();
        ClientCall<Message, Message> call = channel.newCall(grpcMethod, CallOptions.DEFAULT.withWaitForReady());

        if (methodType == MethodDescriptor.MethodType.CLIENT_STREAMING) {
            // Only the first response is returned; a null node when there is none.
            return this.asyncStreamingCall(parameters, methodDesc,
                    responseObserver -> ClientCalls.asyncClientStreamingCall(call, responseObserver),
                    nodes -> nodes.isEmpty() ? NullNode.instance : nodes.get(0));
        }
        if (methodType == MethodDescriptor.MethodType.BIDI_STREAMING) {
            return this.asyncStreamingCall(parameters, methodDesc,
                    responseObserver -> ClientCalls.asyncBidiStreamingCall(call, responseObserver),
                    JsonObjectUtils::fromValue);
        }
        if (methodType == MethodDescriptor.MethodType.SERVER_STREAMING) {
            List<JsonNode> nodes = new ArrayList<>();
            ClientCalls.blockingServerStreamingCall(call, this.buildRequest(parameters, methodDesc))
                    .forEachRemaining(m -> nodes.add(this.convert(m, methodDesc)));
            return JsonObjectUtils.fromValue(nodes);
        }
        return this.convert(ClientCalls.blockingUnaryCall(call, this.buildRequest(parameters, methodDesc)), methodDesc);
    }

    /** Builds a request message of the method input type from the given source object. */
    private Message buildRequest(Object source, Descriptors.MethodDescriptor methodDesc) {
        Message.Builder builder = DynamicMessage.newBuilder(methodDesc.getInputType());
        return RPCConverterFactory.get().buildMessage(source, builder).build();
    }

    /**
     * Returns the file descriptor for {@code fileName}, building it (and, recursively, its
     * dependencies) from the descriptor set on first use.
     *
     * @throws IllegalArgumentException if the descriptor set contains no file with that name
     * @throws IllegalStateException if the descriptor fails protobuf validation
     */
    private Descriptors.FileDescriptor buildFileDescriptor(DescriptorProtos.FileDescriptorSet fdSet, String fileName) {
        Descriptors.FileDescriptor cached = this.fileDescriptors.get(fileName);
        if (cached != null) {
            return cached;
        }
        // The descriptor is deliberately built outside ConcurrentHashMap.computeIfAbsent: resolving
        // dependencies re-enters this method and updates the same map, which a mapping function
        // must not do (it can fail with "Recursive update" or block).
        DescriptorProtos.FileDescriptorProto fdProto = fdSet.getFileList().stream()
                .filter(f -> f.getName().equals(fileName))
                .findFirst()
                .orElseThrow(() -> new IllegalArgumentException("Cannot find file name " + fileName));
        Descriptors.FileDescriptor built;
        try {
            Descriptors.FileDescriptor[] dependencies = fdProto.getDependencyList().stream()
                    .map(dependencyName -> this.buildFileDescriptor(fdSet, dependencyName))
                    .toArray(Descriptors.FileDescriptor[]::new);
            built = Descriptors.FileDescriptor.buildFrom(fdProto, dependencies);
        } catch (Descriptors.DescriptorValidationException e) {
            throw new IllegalStateException(e);
        }
        // If another thread cached the same file in the meantime, keep and return its instance.
        Descriptors.FileDescriptor existing = this.fileDescriptors.putIfAbsent(fileName, built);
        return existing != null ? existing : built;
    }

    /** Converts a response message to JSON and applies every registered decorator in order. */
    private JsonNode convert(Message m, Descriptors.MethodDescriptor descriptor) {
        JsonNode node = RPCConverterFactory.get().getJsonNode(m);
        for (RPCDecorator decorator : this.decorators) {
            node = decorator.decorate(node, descriptor.getOutputType());
        }
        return node;
    }

    /**
     * Sends every element of the {@code ContentData} parameter as a request message on a
     * client/bidirectional stream, then waits for the responses.
     *
     * @param streamObserverFunction starts the call with the given response observer and returns the request observer
     * @param nodesFunction turns the list of converted responses into the returned node
     * @throws NullPointerException if the {@code ContentData} parameter is missing
     * @throws IllegalStateException if the server reports an error or the responses do not arrive in time
     */
    private JsonNode asyncStreamingCall(Map<String, Object> parameters, Descriptors.MethodDescriptor methodDesc, UnaryOperator<StreamObserver<Message>> streamObserverFunction, Function<List<JsonNode>, JsonNode> nodesFunction) {
        // Validate the input before the call is started, so a missing parameter does not leave
        // an open stream that nobody completes or cancels.
        List<?> messageParams = Objects.requireNonNull((List<?>) parameters.get(STREAM_CONTENT_PARAM), "Missing streaming call parameter");
        WaitingStreamObserver responseObserver = new WaitingStreamObserver(this.streamTimeout);
        StreamObserver<Message> requestObserver = streamObserverFunction.apply(responseObserver);
        for (Object messageParam : messageParams) {
            try {
                requestObserver.onNext(this.buildRequest(messageParam, methodDesc));
            } catch (Exception e) {
                // Signal the failure on the request stream before propagating it.
                requestObserver.onError(e);
                throw e;
            }
            // Stop sending as soon as the server has reported an error.
            responseObserver.checkForServerStreamErrors();
        }
        requestObserver.onCompleted();
        List<JsonNode> nodes = responseObserver.get().stream()
                .map(m -> this.convert(m, methodDesc))
                .collect(Collectors.toList());
        return nodesFunction.apply(nodes);
    }

    /** Derives the gRPC method type from the client/server streaming flags of the method descriptor. */
    private static MethodDescriptor.MethodType getMethodType(Descriptors.MethodDescriptor methodDesc) {
        DescriptorProtos.MethodDescriptorProto methodDescProto = methodDesc.toProto();
        boolean clientStreaming = methodDescProto.getClientStreaming();
        boolean serverStreaming = methodDescProto.getServerStreaming();
        if (clientStreaming) {
            return serverStreaming ? MethodDescriptor.MethodType.BIDI_STREAMING : MethodDescriptor.MethodType.CLIENT_STREAMING;
        }
        return serverStreaming ? MethodDescriptor.MethodType.SERVER_STREAMING : MethodDescriptor.MethodType.UNARY;
    }

    /**
     * Response observer that accumulates messages and exposes them through a future which is
     * completed when the stream completes, or completed exceptionally when the stream reports an error.
     */
    private static class WaitingStreamObserver implements StreamObserver<Message> {
        private final List<Message> responses = new ArrayList<>();
        private final CompletableFuture<List<Message>> responsesFuture = new CompletableFuture<>();
        /** Maximum wait in seconds used by {@link #get()}. */
        private final int timeout;

        public WaitingStreamObserver(int timeout) {
            this.timeout = timeout;
        }

        /** Implements {@code StreamObserver.onNext}: records one response message. */
        public void onNext(Message messageReply) {
            this.responses.add(messageReply);
        }

        /** Implements {@code StreamObserver.onError}: fails the pending result with the reported error. */
        public void onError(Throwable throwable) {
            this.responsesFuture.completeExceptionally(throwable);
        }

        /** Implements {@code StreamObserver.onCompleted}: publishes the accumulated responses. */
        public void onCompleted() {
            this.responsesFuture.complete(this.responses);
        }

        /**
         * Waits up to the configured number of seconds for the stream to finish.
         *
         * @return the responses received
         * @throws IllegalStateException on interruption, timeout or stream error
         */
        public List<Message> get() {
            try {
                return this.responsesFuture.get(this.timeout, TimeUnit.SECONDS);
            } catch (InterruptedException e) {
                // Restore the interrupt flag for callers further up the stack.
                Thread.currentThread().interrupt();
                throw new IllegalStateException(e);
            } catch (TimeoutException e) {
                throw new IllegalStateException(String.format("gRPC call timed out after %d seconds", this.timeout), e);
            } catch (ExecutionException e) {
                throw new IllegalStateException(this.getServerStreamErrorMessage(e.getCause()), e.getCause());
            }
        }

        /**
         * Throws if the stream has already reported an error; returns immediately otherwise.
         *
         * @throws IllegalStateException wrapping the error reported by the stream
         */
        public void checkForServerStreamErrors() {
            if (!this.responsesFuture.isCompletedExceptionally()) {
                return;
            }
            try {
                // The future is already failed, so join() does not block; it only surfaces the cause.
                this.responsesFuture.join();
            } catch (CompletionException e) {
                throw new IllegalStateException(this.getServerStreamErrorMessage(e.getCause()), e.getCause());
            }
        }

        private String getServerStreamErrorMessage(Throwable throwable) {
            return String.format("Received an error through gRPC server stream with status: %s", Status.fromThrowable(throwable));
        }
    }
}

