/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.fn.harness.data;

import java.util.Arrays;
import java.util.Collections;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.beam.fn.harness.data.BeamFnDataGrpcClient;
import org.apache.beam.model.fnexecution.v1.BeamFnApi;
import org.apache.beam.model.fnexecution.v1.BeamFnDataGrpc;
import org.apache.beam.model.pipeline.v1.Endpoints;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.LengthPrefixCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.fn.data.BeamFnDataInboundObserver;
import org.apache.beam.sdk.fn.data.BeamFnDataOutboundAggregator;
import org.apache.beam.sdk.fn.data.CloseableFnDataReceiver;
import org.apache.beam.sdk.fn.data.DataEndpoint;
import org.apache.beam.sdk.fn.data.FnDataReceiver;
import org.apache.beam.sdk.fn.data.LogicalEndpoint;
import org.apache.beam.sdk.fn.stream.OutboundObserverFactory;
import org.apache.beam.sdk.fn.test.TestStreams;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
import org.apache.beam.sdk.util.CoderUtils;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ByteString;
import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.BindableService;
import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.ManagedChannel;
import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.Server;
import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.inprocess.InProcessChannelBuilder;
import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.inprocess.InProcessServerBuilder;
import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.stub.CallStreamObserver;
import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.stub.StreamObserver;
import org.hamcrest.MatcherAssert;
import org.hamcrest.Matchers;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

@RunWith(value=JUnit4.class)
public class BeamFnDataGrpcClientTest {
    private static final Coder<WindowedValue<String>> CODER = LengthPrefixCoder.of((Coder)WindowedValue.getFullCoder((Coder)StringUtf8Coder.of(), (Coder)GlobalWindow.Coder.INSTANCE));
    private static final String INSTRUCTION_ID_A = "12L";
    private static final String INSTRUCTION_ID_B = "56L";
    private static final String TRANSFORM_ID_A = "34L";
    private static final String TRANSFORM_ID_B = "78L";
    private static final LogicalEndpoint ENDPOINT_A = LogicalEndpoint.data((String)"12L", (String)"34L");
    private static final LogicalEndpoint ENDPOINT_B = LogicalEndpoint.data((String)"56L", (String)"78L");
    private static final BeamFnApi.Elements ELEMENTS_A_1;
    private static final BeamFnApi.Elements ELEMENTS_A_2;
    private static final BeamFnApi.Elements ELEMENTS_B_1;

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    @Test
    public void testForInboundConsumer() throws Exception {
        final CountDownLatch waitForClientToConnect = new CountDownLatch(1);
        ConcurrentLinkedQueue inboundValuesA = new ConcurrentLinkedQueue();
        ConcurrentLinkedQueue inboundValuesB = new ConcurrentLinkedQueue();
        ConcurrentLinkedQueue inboundServerValues = new ConcurrentLinkedQueue();
        final AtomicReference outboundServerObserver = new AtomicReference();
        final CallStreamObserver inboundServerObserver = TestStreams.withOnNext(inboundServerValues::add).build();
        Endpoints.ApiServiceDescriptor apiServiceDescriptor = Endpoints.ApiServiceDescriptor.newBuilder().setUrl(this.getClass().getName() + "-" + UUID.randomUUID()).build();
        Server server = ((InProcessServerBuilder)InProcessServerBuilder.forName((String)apiServiceDescriptor.getUrl()).addService((BindableService)new BeamFnDataGrpc.BeamFnDataImplBase(){

            public StreamObserver<BeamFnApi.Elements> data(StreamObserver<BeamFnApi.Elements> outboundObserver) {
                outboundServerObserver.set(outboundObserver);
                waitForClientToConnect.countDown();
                return inboundServerObserver;
            }
        })).build();
        server.start();
        try {
            ManagedChannel channel = InProcessChannelBuilder.forName((String)apiServiceDescriptor.getUrl()).build();
            BeamFnDataGrpcClient clientFactory = new BeamFnDataGrpcClient(PipelineOptionsFactory.create(), descriptor -> channel, OutboundObserverFactory.trivial());
            DataEndpoint[] dataEndpointArray = new DataEndpoint[1];
            dataEndpointArray[0] = DataEndpoint.create((String)TRANSFORM_ID_A, CODER, inboundValuesA::add);
            BeamFnDataInboundObserver observerA = BeamFnDataInboundObserver.forConsumers(Arrays.asList(dataEndpointArray), Collections.emptyList());
            DataEndpoint[] dataEndpointArray2 = new DataEndpoint[1];
            dataEndpointArray2[0] = DataEndpoint.create((String)TRANSFORM_ID_B, CODER, inboundValuesB::add);
            BeamFnDataInboundObserver observerB = BeamFnDataInboundObserver.forConsumers(Arrays.asList(dataEndpointArray2), Collections.emptyList());
            clientFactory.registerReceiver(INSTRUCTION_ID_A, Arrays.asList(apiServiceDescriptor), (CloseableFnDataReceiver)observerA);
            waitForClientToConnect.await();
            ((StreamObserver)outboundServerObserver.get()).onNext((Object)ELEMENTS_A_1);
            ((StreamObserver)outboundServerObserver.get()).onNext((Object)ELEMENTS_B_1);
            Thread.sleep(100L);
            clientFactory.registerReceiver(INSTRUCTION_ID_B, Arrays.asList(apiServiceDescriptor), (CloseableFnDataReceiver)observerB);
            observerB.awaitCompletion();
            MatcherAssert.assertThat(inboundValuesB, Matchers.contains(WindowedValue.valueInGlobalWindow((Object)"JKL"), WindowedValue.valueInGlobalWindow((Object)"MNO")));
            ((StreamObserver)outboundServerObserver.get()).onNext((Object)ELEMENTS_A_2);
            observerA.awaitCompletion();
            MatcherAssert.assertThat(inboundValuesA, Matchers.contains(WindowedValue.valueInGlobalWindow((Object)"ABC"), WindowedValue.valueInGlobalWindow((Object)"DEF"), WindowedValue.valueInGlobalWindow((Object)"GHI")));
        }
        finally {
            server.shutdownNow();
        }
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    @Test
    public void testForInboundConsumerThatThrows() throws Exception {
        final CountDownLatch waitForClientToConnect = new CountDownLatch(1);
        AtomicInteger consumerInvoked = new AtomicInteger();
        ConcurrentLinkedQueue inboundServerValues = new ConcurrentLinkedQueue();
        final AtomicReference outboundServerObserver = new AtomicReference();
        final CallStreamObserver inboundServerObserver = TestStreams.withOnNext(inboundServerValues::add).build();
        Endpoints.ApiServiceDescriptor apiServiceDescriptor = Endpoints.ApiServiceDescriptor.newBuilder().setUrl(this.getClass().getName() + "-" + UUID.randomUUID()).build();
        Server server = ((InProcessServerBuilder)InProcessServerBuilder.forName((String)apiServiceDescriptor.getUrl()).addService((BindableService)new BeamFnDataGrpc.BeamFnDataImplBase(){

            public StreamObserver<BeamFnApi.Elements> data(StreamObserver<BeamFnApi.Elements> outboundObserver) {
                outboundServerObserver.set(outboundObserver);
                waitForClientToConnect.countDown();
                return inboundServerObserver;
            }
        })).build();
        server.start();
        RuntimeException exceptionToThrow = new RuntimeException("TestFailure");
        try {
            ManagedChannel channel = InProcessChannelBuilder.forName((String)apiServiceDescriptor.getUrl()).build();
            BeamFnDataGrpcClient clientFactory = new BeamFnDataGrpcClient(PipelineOptionsFactory.create(), descriptor -> channel, OutboundObserverFactory.trivial());
            BeamFnDataInboundObserver observer = BeamFnDataInboundObserver.forConsumers(Arrays.asList(DataEndpoint.create((String)TRANSFORM_ID_A, CODER, t -> {
                consumerInvoked.incrementAndGet();
                throw exceptionToThrow;
            })), Collections.emptyList());
            clientFactory.registerReceiver(INSTRUCTION_ID_A, Arrays.asList(apiServiceDescriptor), (CloseableFnDataReceiver)observer);
            waitForClientToConnect.await();
            ((StreamObserver)outboundServerObserver.get()).onNext((Object)ELEMENTS_A_1);
            ((StreamObserver)outboundServerObserver.get()).onNext((Object)ELEMENTS_A_2);
            try {
                observer.awaitCompletion();
                Assert.fail("Expected channel to fail");
            }
            catch (Exception e) {
                Assert.assertEquals(exceptionToThrow, e);
            }
            MatcherAssert.assertThat(inboundServerValues, Matchers.empty());
            Assert.assertEquals(1L, consumerInvoked.get());
        }
        finally {
            server.shutdownNow();
        }
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    @Test
    public void testForInboundConsumerThatIsPoisoned() throws Exception {
        final CountDownLatch waitForClientToConnect = new CountDownLatch(1);
        CountDownLatch receivedAElement = new CountDownLatch(1);
        ConcurrentLinkedQueue inboundValuesA = new ConcurrentLinkedQueue();
        ConcurrentLinkedQueue inboundServerValues = new ConcurrentLinkedQueue();
        final AtomicReference outboundServerObserver = new AtomicReference();
        final CallStreamObserver inboundServerObserver = TestStreams.withOnNext(inboundServerValues::add).build();
        Endpoints.ApiServiceDescriptor apiServiceDescriptor = Endpoints.ApiServiceDescriptor.newBuilder().setUrl(this.getClass().getName() + "-" + UUID.randomUUID()).build();
        Server server = ((InProcessServerBuilder)InProcessServerBuilder.forName((String)apiServiceDescriptor.getUrl()).addService((BindableService)new BeamFnDataGrpc.BeamFnDataImplBase(){

            public StreamObserver<BeamFnApi.Elements> data(StreamObserver<BeamFnApi.Elements> outboundObserver) {
                outboundServerObserver.set(outboundObserver);
                waitForClientToConnect.countDown();
                return inboundServerObserver;
            }
        })).build();
        server.start();
        try {
            ManagedChannel channel = InProcessChannelBuilder.forName((String)apiServiceDescriptor.getUrl()).build();
            BeamFnDataGrpcClient clientFactory = new BeamFnDataGrpcClient(PipelineOptionsFactory.create(), descriptor -> channel, OutboundObserverFactory.trivial());
            BeamFnDataInboundObserver observerA = BeamFnDataInboundObserver.forConsumers(Arrays.asList(DataEndpoint.create((String)TRANSFORM_ID_A, CODER, elem -> {
                receivedAElement.countDown();
                inboundValuesA.add(elem);
            })), Collections.emptyList());
            CompletableFuture<Void> future = CompletableFuture.runAsync(() -> {
                try {
                    observerA.awaitCompletion();
                }
                catch (Exception e) {
                    throw new RuntimeException(e);
                }
            });
            clientFactory.registerReceiver(INSTRUCTION_ID_A, Arrays.asList(apiServiceDescriptor), (CloseableFnDataReceiver)observerA);
            waitForClientToConnect.await();
            ((StreamObserver)outboundServerObserver.get()).onNext((Object)ELEMENTS_B_1);
            clientFactory.poisonInstructionId(INSTRUCTION_ID_B);
            ((StreamObserver)outboundServerObserver.get()).onNext((Object)ELEMENTS_B_1);
            ((StreamObserver)outboundServerObserver.get()).onNext((Object)ELEMENTS_A_1);
            Assert.assertTrue(receivedAElement.await(5L, TimeUnit.SECONDS));
            clientFactory.poisonInstructionId(INSTRUCTION_ID_A);
            try {
                future.get();
                Assert.fail();
            }
            catch (Exception exception) {
                // empty catch block
            }
            ((StreamObserver)outboundServerObserver.get()).onNext((Object)ELEMENTS_A_2);
            MatcherAssert.assertThat(inboundValuesA, Matchers.contains(WindowedValue.valueInGlobalWindow((Object)"ABC"), WindowedValue.valueInGlobalWindow((Object)"DEF")));
        }
        finally {
            server.shutdownNow();
        }
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    @Test
    public void testForOutboundConsumer() throws Exception {
        CountDownLatch waitForInboundServerValuesCompletion = new CountDownLatch(2);
        ConcurrentLinkedQueue inboundServerValues = new ConcurrentLinkedQueue();
        final CallStreamObserver inboundServerObserver = TestStreams.withOnNext(t -> {
            inboundServerValues.add(t);
            waitForInboundServerValuesCompletion.countDown();
        }).build();
        Endpoints.ApiServiceDescriptor apiServiceDescriptor = Endpoints.ApiServiceDescriptor.newBuilder().setUrl(this.getClass().getName() + "-" + UUID.randomUUID()).build();
        Server server = ((InProcessServerBuilder)InProcessServerBuilder.forName((String)apiServiceDescriptor.getUrl()).addService((BindableService)new BeamFnDataGrpc.BeamFnDataImplBase(){

            public StreamObserver<BeamFnApi.Elements> data(StreamObserver<BeamFnApi.Elements> outboundObserver) {
                return inboundServerObserver;
            }
        })).build();
        server.start();
        try {
            ManagedChannel channel = InProcessChannelBuilder.forName((String)apiServiceDescriptor.getUrl()).build();
            BeamFnDataGrpcClient clientFactory = new BeamFnDataGrpcClient(PipelineOptionsFactory.fromArgs((String[])new String[]{"--experiments=data_buffer_size_limit=20"}).create(), descriptor -> channel, OutboundObserverFactory.trivial());
            BeamFnDataOutboundAggregator aggregator = clientFactory.createOutboundAggregator(apiServiceDescriptor, () -> INSTRUCTION_ID_A, false);
            FnDataReceiver fnDataReceiver = aggregator.registerOutputDataLocation(TRANSFORM_ID_A, CODER);
            fnDataReceiver.accept((Object)WindowedValue.valueInGlobalWindow((Object)"ABC"));
            fnDataReceiver.accept((Object)WindowedValue.valueInGlobalWindow((Object)"DEF"));
            fnDataReceiver.accept((Object)WindowedValue.valueInGlobalWindow((Object)"GHI"));
            aggregator.sendOrCollectBufferedDataAndFinishOutboundStreams();
            waitForInboundServerValuesCompletion.await();
            MatcherAssert.assertThat(inboundServerValues, Matchers.contains(ELEMENTS_A_1, ELEMENTS_A_2));
        }
        finally {
            server.shutdownNow();
        }
    }

    static {
        try {
            ELEMENTS_A_1 = BeamFnApi.Elements.newBuilder().addData(BeamFnApi.Elements.Data.newBuilder().setInstructionId(ENDPOINT_A.getInstructionId()).setTransformId(ENDPOINT_A.getTransformId()).setData(ByteString.copyFrom((byte[])CoderUtils.encodeToByteArray(CODER, (Object)WindowedValue.valueInGlobalWindow((Object)"ABC"))).concat(ByteString.copyFrom((byte[])CoderUtils.encodeToByteArray(CODER, (Object)WindowedValue.valueInGlobalWindow((Object)"DEF")))))).build();
            ELEMENTS_A_2 = BeamFnApi.Elements.newBuilder().addData(BeamFnApi.Elements.Data.newBuilder().setInstructionId(ENDPOINT_A.getInstructionId()).setTransformId(ENDPOINT_A.getTransformId()).setData(ByteString.copyFrom((byte[])CoderUtils.encodeToByteArray(CODER, (Object)WindowedValue.valueInGlobalWindow((Object)"GHI"))))).addData(BeamFnApi.Elements.Data.newBuilder().setInstructionId(ENDPOINT_A.getInstructionId()).setTransformId(ENDPOINT_A.getTransformId()).setIsLast(true)).build();
            ELEMENTS_B_1 = BeamFnApi.Elements.newBuilder().addData(BeamFnApi.Elements.Data.newBuilder().setInstructionId(ENDPOINT_B.getInstructionId()).setTransformId(ENDPOINT_B.getTransformId()).setData(ByteString.copyFrom((byte[])CoderUtils.encodeToByteArray(CODER, (Object)WindowedValue.valueInGlobalWindow((Object)"JKL"))).concat(ByteString.copyFrom((byte[])CoderUtils.encodeToByteArray(CODER, (Object)WindowedValue.valueInGlobalWindow((Object)"MNO")))))).addData(BeamFnApi.Elements.Data.newBuilder().setInstructionId(ENDPOINT_B.getInstructionId()).setTransformId(ENDPOINT_B.getTransformId()).setIsLast(true)).build();
        }
        catch (Exception e) {
            throw new ExceptionInInitializerError(e);
        }
    }
}

