/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.fn.harness.control;

import java.io.IOException;
import java.io.OutputStream;
import java.io.Serializable;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Supplier;
import org.apache.beam.fn.harness.BeamFnDataReadRunner;
import org.apache.beam.fn.harness.Cache;
import org.apache.beam.fn.harness.Caches;
import org.apache.beam.fn.harness.PTransformRunnerFactory;
import org.apache.beam.fn.harness.control.BundleProgressReporter;
import org.apache.beam.fn.harness.control.BundleSplitListener;
import org.apache.beam.fn.harness.control.ExecutionStateSampler;
import org.apache.beam.fn.harness.control.FinalizeBundleHandler;
import org.apache.beam.fn.harness.control.ProcessBundleHandler;
import org.apache.beam.fn.harness.data.BeamFnDataClient;
import org.apache.beam.fn.harness.data.PCollectionConsumerRegistry;
import org.apache.beam.fn.harness.data.PTransformFunctionRegistry;
import org.apache.beam.fn.harness.state.BeamFnStateClient;
import org.apache.beam.fn.harness.state.BeamFnStateGrpcClientCache;
import org.apache.beam.model.fnexecution.v1.BeamFnApi;
import org.apache.beam.model.pipeline.v1.Endpoints;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.runners.core.metrics.ShortIdMap;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.fn.data.BeamFnDataOutboundAggregator;
import org.apache.beam.sdk.fn.data.CloseableFnDataReceiver;
import org.apache.beam.sdk.fn.data.DataEndpoint;
import org.apache.beam.sdk.fn.data.TimerEndpoint;
import org.apache.beam.sdk.fn.test.TestExecutors;
import org.apache.beam.sdk.function.ThrowingRunnable;
import org.apache.beam.sdk.metrics.MetricsEnvironment;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.state.TimeDomain;
import org.apache.beam.sdk.state.TimerMap;
import org.apache.beam.sdk.state.TimerSpec;
import org.apache.beam.sdk.state.TimerSpecs;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.DoFnSchemaInformation;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
import org.apache.beam.sdk.transforms.windowing.PaneInfo;
import org.apache.beam.sdk.util.ByteStringOutputStream;
import org.apache.beam.sdk.util.DoFnWithExecutionInformation;
import org.apache.beam.sdk.util.SerializableUtils;
import org.apache.beam.sdk.util.construction.BeamUrns;
import org.apache.beam.sdk.util.construction.CoderTranslation;
import org.apache.beam.sdk.util.construction.ModelCoders;
import org.apache.beam.sdk.util.construction.Timer;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ByteString;
import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ProtocolMessageEnum;
import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.stub.StreamObserver;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Uninterruptibles;
import org.hamcrest.MatcherAssert;
import org.hamcrest.Matchers;
import org.hamcrest.collection.IsEmptyCollection;
import org.joda.time.Instant;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentMatchers;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;

@RunWith(value=JUnit4.class)
public class ProcessBundleHandlerTest {
    /** URN under which tests register their own data-input (source) runner factory. */
    private static final String DATA_INPUT_URN = "beam:runner:source:v1";
    /** URN under which tests register their own data-output (sink) runner factory. */
    private static final String DATA_OUTPUT_URN = "beam:runner:sink:v1";
    // Executor backed by a cached thread pool, exposed through the TestExecutors JUnit rule.
    @Rule
    public TestExecutors.TestExecutorService executor = TestExecutors.from(Executors::newCachedThreadPool);
    // Mockito mock; populated by MockitoAnnotations.initMocks(this) in setUp().
    @Mock
    private BeamFnDataClient beamFnDataClient;
    // Created fresh for every test in setUp() and stopped in tearDown().
    private ExecutionStateSampler executionStateSampler;

    /** Prepares per-test state: a new sampler, a zeroed reset counter and initialized mocks. */
    @Before
    public void setUp() {
        // A dedicated sampler per test, timed by the wall clock; tearDown() stops it again.
        executionStateSampler = new ExecutionStateSampler(PipelineOptionsFactory.create(), System::currentTimeMillis);
        // TestBundleProcessor tracks resets in a static counter, so clear it before every test.
        TestBundleProcessor.resetCnt = 0;
        // Populate the @Mock-annotated fields.
        MockitoAnnotations.initMocks(this);
    }

    /** Stops the execution-state sampler created in {@link #setUp()}. */
    @After
    public void tearDown() {
        ExecutionStateSampler sampler = executionStateSampler;
        sampler.stop();
    }

    /** A split request for an instruction the handler has never seen yields an empty split. */
    @Test
    public void testTrySplitBeforeBundleDoesNotFail() {
        // No descriptor registry and no runner factories: no bundle can be active.
        ProcessBundleHandler handler = new ProcessBundleHandler(
                PipelineOptionsFactory.create(),
                Collections.emptySet(),
                null,
                beamFnDataClient,
                null,
                null,
                new ShortIdMap(),
                executionStateSampler,
                (Map) ImmutableMap.of(),
                Caches.noop(),
                new ProcessBundleHandler.BundleProcessorCache(Duration.ZERO),
                null);
        BeamFnApi.ProcessBundleSplitRequest splitRequest =
                BeamFnApi.ProcessBundleSplitRequest.newBuilder().setInstructionId("unknown-id").build();
        BeamFnApi.InstructionRequest request = BeamFnApi.InstructionRequest.newBuilder()
                .setInstructionId("999L")
                .setProcessBundleSplit(splitRequest)
                .build();

        BeamFnApi.InstructionResponse response = handler.trySplit(request).build();

        Assert.assertNotNull(response.getProcessBundleSplit());
        Assert.assertEquals(0L, response.getProcessBundleSplit().getChannelSplitsCount());
    }

    /**
     * A progress request for an instruction the handler has never seen must not fail and must
     * produce an empty progress report.
     */
    @Test
    public void testProgressBeforeBundleDoesNotFail() throws Exception {
        ProcessBundleHandler handler = new ProcessBundleHandler(PipelineOptionsFactory.create(), Collections.emptySet(), null, this.beamFnDataClient, null, null, new ShortIdMap(), this.executionStateSampler, (Map) ImmutableMap.of(), Caches.noop(), new ProcessBundleHandler.BundleProcessorCache(Duration.ZERO), null);
        // Assert on the response returned by progress() itself. Progress fields read from any other
        // response would just be protobuf defaults, which makes those assertions vacuous.
        BeamFnApi.InstructionResponse response = handler.progress(
                BeamFnApi.InstructionRequest.newBuilder()
                        .setInstructionId("999L")
                        .setProcessBundleProgress(
                                BeamFnApi.ProcessBundleProgressRequest.newBuilder().setInstructionId("unknown-id"))
                        .build())
                .build();
        Assert.assertTrue(response.hasProcessBundleProgress());
        Assert.assertEquals(0L, response.getProcessBundleProgress().getMonitoringInfosCount());
    }

    /**
     * Verifies the runner-creation order and the bundle-function ordering.
     *
     * <p>Runners are created starting from the consumer ("3L") and then its producer ("2L"). Start
     * functions run in that order; finish functions run in the reverse order.
     */
    @Test
    public void testOrderOfStartAndFinishCalls() throws Exception {
        // "2L" is a data input producing "2L-output-pc", which the data output "3L" consumes.
        RunnerApi.PTransform source = RunnerApi.PTransform.newBuilder()
                .setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn(DATA_INPUT_URN).build())
                .putOutputs("2L-output", "2L-output-pc")
                .build();
        RunnerApi.PTransform sink = RunnerApi.PTransform.newBuilder()
                .setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn(DATA_OUTPUT_URN).build())
                .putInputs("3L-input", "2L-output-pc")
                .build();
        BeamFnApi.ProcessBundleDescriptor processBundleDescriptor = BeamFnApi.ProcessBundleDescriptor.newBuilder()
                .putTransforms("2L", source)
                .putTransforms("3L", sink)
                .putPcollections("2L-output-pc", RunnerApi.PCollection.getDefaultInstance())
                .build();
        // Typed so that fnApiRegistry::get returns a descriptor rather than a raw Object.
        Map<String, BeamFnApi.ProcessBundleDescriptor> fnApiRegistry = ImmutableMap.of("1L", processBundleDescriptor);
        List<RunnerApi.PTransform> transformsProcessed = new ArrayList<>();
        List<String> orderOfOperations = new ArrayList<>();
        PTransformRunnerFactory startFinishRecorder = context -> {
            String pTransformId = context.getPTransformId();
            transformsProcessed.add(context.getPTransform());
            Supplier<String> processBundleInstructionId = context.getProcessBundleInstructionIdSupplier();
            // Each function checks that it sees the instruction being processed, then records itself.
            context.addStartBundleFunction(() -> {
                Assert.assertEquals("999L", processBundleInstructionId.get());
                orderOfOperations.add("Start" + pTransformId);
            });
            context.addFinishBundleFunction(() -> {
                Assert.assertEquals("999L", processBundleInstructionId.get());
                orderOfOperations.add("Finish" + pTransformId);
            });
            return null;
        };
        ProcessBundleHandler handler = new ProcessBundleHandler(PipelineOptionsFactory.create(), Collections.emptySet(), fnApiRegistry::get, this.beamFnDataClient, null, null, new ShortIdMap(), this.executionStateSampler, (Map) ImmutableMap.of((Object) DATA_INPUT_URN, (Object) startFinishRecorder, (Object) DATA_OUTPUT_URN, (Object) startFinishRecorder), Caches.noop(), new ProcessBundleHandler.BundleProcessorCache(Duration.ZERO), null);
        handler.processBundle(BeamFnApi.InstructionRequest.newBuilder().setInstructionId("999L").setProcessBundle(BeamFnApi.ProcessBundleRequest.newBuilder().setProcessBundleDescriptorId("1L")).build());
        MatcherAssert.assertThat(transformsProcessed, Matchers.contains(processBundleDescriptor.getTransformsMap().get("3L"), processBundleDescriptor.getTransformsMap().get("2L")));
        MatcherAssert.assertThat(orderOfOperations, Matchers.contains("Start3L", "Start2L", "Finish2L", "Finish3L"));
    }

    /**
     * Runs two bundles through a ParDo of {@code TestDoFn}. Expects setUp once, start/finish for
     * each bundle, and tearDown once after the handler shuts down.
     */
    @Test
    public void testOrderOfSetupTeardownCalls() throws Exception {
        // Serialize TestDoFn into a Java-SDK ParDo payload.
        DoFnWithExecutionInformation doFnWithExecutionInformation = DoFnWithExecutionInformation.of((DoFn) new TestDoFn(), (TupleTag) TestDoFn.mainOutput, Collections.emptyMap(), DoFnSchemaInformation.create());
        RunnerApi.FunctionSpec functionSpec = RunnerApi.FunctionSpec.newBuilder()
                .setUrn("beam:dofn:javasdk:0.1")
                .setPayload(ByteString.copyFrom(SerializableUtils.serializeToByteArray((Serializable) doFnWithExecutionInformation)))
                .build();
        RunnerApi.ParDoPayload parDoPayload = RunnerApi.ParDoPayload.newBuilder().setDoFn(functionSpec).build();

        // "2L" (data input) -> "2L-output-pc" (String elements, global windows) -> "3L" (ParDo).
        RunnerApi.PTransform source = RunnerApi.PTransform.newBuilder()
                .setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn(DATA_INPUT_URN).build())
                .putOutputs("2L-output", "2L-output-pc")
                .build();
        RunnerApi.PTransform parDo = RunnerApi.PTransform.newBuilder()
                .setSpec(RunnerApi.FunctionSpec.newBuilder()
                        .setUrn("beam:transform:pardo:v1")
                        .setPayload(parDoPayload.toByteString()))
                .putInputs("3L-input", "2L-output-pc")
                .build();
        RunnerApi.PCollection sourceOutput = RunnerApi.PCollection.newBuilder()
                .setWindowingStrategyId("window-strategy")
                .setCoderId("2L-output-coder")
                .setIsBounded(RunnerApi.IsBounded.Enum.BOUNDED)
                .build();
        RunnerApi.WindowingStrategy globalWindowing = RunnerApi.WindowingStrategy.newBuilder()
                .setWindowCoderId("window-strategy-coder")
                .setWindowFn(RunnerApi.FunctionSpec.newBuilder().setUrn("beam:window_fn:global_windows:v1"))
                .setOutputTime(RunnerApi.OutputTime.Enum.END_OF_WINDOW)
                .setAccumulationMode(RunnerApi.AccumulationMode.Enum.ACCUMULATING)
                .setTrigger(RunnerApi.Trigger.newBuilder().setAlways(RunnerApi.Trigger.Always.getDefaultInstance()))
                .setClosingBehavior(RunnerApi.ClosingBehavior.Enum.EMIT_ALWAYS)
                .setOnTimeBehavior(RunnerApi.OnTimeBehavior.Enum.FIRE_ALWAYS)
                .build();
        RunnerApi.Coder globalWindowCoder = RunnerApi.Coder.newBuilder()
                .setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn(ModelCoders.GLOBAL_WINDOW_CODER_URN).build())
                .build();
        BeamFnApi.ProcessBundleDescriptor processBundleDescriptor = BeamFnApi.ProcessBundleDescriptor.newBuilder()
                .putTransforms("2L", source)
                .putTransforms("3L", parDo)
                .putPcollections("2L-output-pc", sourceOutput)
                .putWindowingStrategies("window-strategy", globalWindowing)
                .putCoders("2L-output-coder", CoderTranslation.toProto((Coder) StringUtf8Coder.of()).getCoder())
                .putCoders("window-strategy-coder", globalWindowCoder)
                .build();
        Map<String, BeamFnApi.ProcessBundleDescriptor> fnApiRegistry = ImmutableMap.of("1L", processBundleDescriptor);

        // Registered factories handle the ParDo; the data input gets a factory that creates nothing.
        // The map is typed so the lambda has a functional-interface target (a raw map's put takes Object).
        Map<String, PTransformRunnerFactory> urnToPTransformRunnerFactoryMap = Maps.newHashMap((Map) ProcessBundleHandler.REGISTERED_RUNNER_FACTORIES);
        urnToPTransformRunnerFactoryMap.put(DATA_INPUT_URN, context -> null);
        ProcessBundleHandler handler = new ProcessBundleHandler(PipelineOptionsFactory.create(), Collections.emptySet(), fnApiRegistry::get, this.beamFnDataClient, null, null, new ShortIdMap(), this.executionStateSampler, (Map) urnToPTransformRunnerFactoryMap, Caches.noop(), new ProcessBundleHandler.BundleProcessorCache(Duration.ZERO), null);

        for (String instructionId : Arrays.asList("998L", "999L")) {
            handler.processBundle(BeamFnApi.InstructionRequest.newBuilder()
                    .setInstructionId(instructionId)
                    .setProcessBundle(BeamFnApi.ProcessBundleRequest.newBuilder().setProcessBundleDescriptorId("1L"))
                    .build());
        }
        handler.shutdown();
        MatcherAssert.assertThat(TestDoFn.orderOfOperations, Matchers.contains("setUp", "startBundle", "finishBundle", "startBundle", "finishBundle", "tearDown"));
    }

    /**
     * A bundle processor is reset when it goes back into the cache. If its reset fails, it is
     * dropped from the cache instead.
     */
    @Test
    public void testBundleProcessorIsResetWhenAddedBackToCache() throws Exception {
        BeamFnApi.ProcessBundleDescriptor processBundleDescriptor = BeamFnApi.ProcessBundleDescriptor.newBuilder()
                .putTransforms("2L", RunnerApi.PTransform.newBuilder()
                        .setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn(DATA_INPUT_URN).build())
                        .build())
                .build();
        Map<String, BeamFnApi.ProcessBundleDescriptor> fnApiRegistry = ImmutableMap.of("1L", processBundleDescriptor);
        // A typed local gives the lambda a functional-interface target. Passed directly to
        // ImmutableMap.of, the value type would be inferred as Object and the lambda would not compile.
        PTransformRunnerFactory noopFactory = context -> null;
        ProcessBundleHandler handler = new ProcessBundleHandler(PipelineOptionsFactory.create(), Collections.emptySet(), fnApiRegistry::get, this.beamFnDataClient, null, null, new ShortIdMap(), this.executionStateSampler, (Map) ImmutableMap.of((Object) DATA_INPUT_URN, (Object) noopFactory), Caches.noop(), (ProcessBundleHandler.BundleProcessorCache) new TestBundleProcessorCache(), null);
        MatcherAssert.assertThat(TestBundleProcessor.resetCnt, Matchers.equalTo(0));

        // First bundle: the processor is reset once and parked in the cache.
        handler.processBundle(BeamFnApi.InstructionRequest.newBuilder().setInstructionId("998L").setProcessBundle(BeamFnApi.ProcessBundleRequest.newBuilder().setProcessBundleDescriptorId("1L")).build());
        MatcherAssert.assertThat(TestBundleProcessor.resetCnt, Matchers.equalTo(1));
        MatcherAssert.assertThat(handler.bundleProcessorCache.getCachedBundleProcessors().size(), Matchers.equalTo(1));
        MatcherAssert.assertThat(((ConcurrentLinkedQueue) handler.bundleProcessorCache.getCachedBundleProcessors().get("1L")).size(), Matchers.equalTo(1));

        // Make the next reset fail; the processor must then not be returned to the cache.
        ((ProcessBundleHandler.BundleProcessor) Iterables.getOnlyElement((Iterable) handler.bundleProcessorCache.getCachedBundleProcessors().get("1L"))).getResetFunctions().add(() -> {
            throw new IllegalStateException("ResetFailed");
        });
        handler.processBundle(BeamFnApi.InstructionRequest.newBuilder().setInstructionId("999L").setProcessBundle(BeamFnApi.ProcessBundleRequest.newBuilder().setProcessBundleDescriptorId("1L")).build());
        MatcherAssert.assertThat(((ConcurrentLinkedQueue) handler.bundleProcessorCache.getCachedBundleProcessors().get("1L")).size(), Matchers.equalTo(0));
    }

    /**
     * Builds a process-bundle instruction.
     *
     * @param instructionId id of the instruction
     * @param bundleDescriptorId id of the process bundle descriptor to execute
     * @param cacheTokens cache tokens attached to the request, in order
     * @return the assembled instruction request
     */
    private static BeamFnApi.InstructionRequest processBundleRequestFor(String instructionId, String bundleDescriptorId, BeamFnApi.ProcessBundleRequest.CacheToken ... cacheTokens) {
        BeamFnApi.ProcessBundleRequest.Builder processBundle =
                BeamFnApi.ProcessBundleRequest.newBuilder().setProcessBundleDescriptorId(bundleDescriptorId);
        for (BeamFnApi.ProcessBundleRequest.CacheToken token : cacheTokens) {
            processBundle.addCacheTokens(token);
        }
        BeamFnApi.InstructionRequest.Builder instruction = BeamFnApi.InstructionRequest.newBuilder();
        instruction.setInstructionId(instructionId);
        instruction.setProcessBundle(processBundle);
        return instruction.build();
    }

    /**
     * A processor can be found by instruction id only while it is checked out of the cache. Once it
     * is released or discarded, it can no longer be found.
     */
    @Test
    public void testBundleProcessorIsFoundWhenActive() {
        ProcessBundleHandler.BundleProcessor processor = Mockito.mock(ProcessBundleHandler.BundleProcessor.class);
        Mockito.when(processor.getInstructionId()).thenReturn("known");
        ProcessBundleHandler.BundleProcessorCache processorCache = new ProcessBundleHandler.BundleProcessorCache(Duration.ZERO);
        BeamFnApi.InstructionRequest knownRequest = processBundleRequestFor("known", "descriptorId");

        Assert.assertNull(processorCache.find("unknown"));

        // Checked out -> discoverable; released -> gone.
        processorCache.get(knownRequest, () -> processor);
        Assert.assertSame(processor, processorCache.find("known"));
        processorCache.release("descriptorId", processor);
        Assert.assertNull(processorCache.find("known"));

        // Checked out again -> discoverable; discarded -> gone, and the processor is told to discard.
        processorCache.get(knownRequest, () -> processor);
        Assert.assertSame(processor, processorCache.find("known"));
        processorCache.discard(processor);
        Mockito.verify(processor).discard();
        Assert.assertNull(processorCache.find("known"));
    }

    /**
     * BundleProcessor.reset() clears per-bundle state and its collaborators, and it runs the
     * registered reset functions. A later setup starts with a new bundle cache.
     */
    @Test
    public void testBundleProcessorReset() throws Exception {
        PTransformFunctionRegistry startRegistry = Mockito.mock(PTransformFunctionRegistry.class);
        PTransformFunctionRegistry finishRegistry = Mockito.mock(PTransformFunctionRegistry.class);
        BundleSplitListener.InMemory splitListener = Mockito.mock(BundleSplitListener.InMemory.class);
        Collection finalizationCallbacks = Mockito.mock(Collection.class);
        PCollectionConsumerRegistry consumerRegistry = Mockito.mock(PCollectionConsumerRegistry.class);
        ExecutionStateSampler.ExecutionStateTracker stateTracker = Mockito.mock(ExecutionStateSampler.ExecutionStateTracker.class);
        ProcessBundleHandler.HandleStateCallsForBundle stateClient = Mockito.mock(ProcessBundleHandler.HandleStateCallsForBundle.class);
        ThrowingRunnable resetFunction = Mockito.mock(ThrowingRunnable.class);
        Cache processWideCache = Caches.eternal();
        ProcessBundleHandler.BundleProcessor processor = ProcessBundleHandler.BundleProcessor.create(
                processWideCache,
                new BundleProgressReporter.InMemory(),
                BeamFnApi.ProcessBundleDescriptor.getDefaultInstance(),
                startRegistry,
                finishRegistry,
                Collections.singletonList(resetFunction),
                new ArrayList(),
                splitListener,
                consumerRegistry,
                new ProcessBundleHandler.MetricsEnvironmentStateForBundle(),
                stateTracker,
                stateClient,
                finalizationCallbacks,
                new HashSet());
        processor.finish();

        // Set up a first request carrying one side-input cache token.
        BeamFnApi.ProcessBundleRequest.CacheToken sideInputToken = BeamFnApi.ProcessBundleRequest.CacheToken.newBuilder()
                .setSideInput(BeamFnApi.ProcessBundleRequest.CacheToken.SideInput.newBuilder().setTransformId("transformId"))
                .build();
        processor.setupForProcessBundleRequest(processBundleRequestFor("instructionId", "descriptorId", sideInputToken));
        Assert.assertEquals("instructionId", processor.getInstructionId());
        MatcherAssert.assertThat(processor.getCacheTokens(), Matchers.containsInAnyOrder(sideInputToken));
        Cache firstBundleCache = processor.getBundleCache();
        firstBundleCache.put("A", "B");
        Assert.assertEquals("B", firstBundleCache.peek("A"));
        Assert.assertTrue(processor.getProgressRequestLock().tryLock());

        processor.reset();

        // Per-bundle state is cleared and every collaborator is reset exactly once.
        Assert.assertNull(processor.getInstructionId());
        Assert.assertNull(processor.getCacheTokens());
        Assert.assertNull(firstBundleCache.peek("A"));
        Mockito.verify(splitListener, Mockito.times(1)).clear();
        Mockito.verify(stateTracker, Mockito.times(1)).reset();
        Mockito.verify(finalizationCallbacks, Mockito.times(1)).clear();
        Mockito.verify(resetFunction, Mockito.times(1)).run();
        Assert.assertNull(MetricsEnvironment.getCurrentContainer());

        // A second setup gets its own bundle cache and no cache tokens.
        processor.setupForProcessBundleRequest(processBundleRequestFor("instructionId2", "descriptorId2"));
        Assert.assertNotSame(firstBundleCache, processor.getBundleCache());
        Assert.assertEquals("instructionId2", processor.getInstructionId());
        MatcherAssert.assertThat(processor.getCacheTokens(), Matchers.is(Matchers.emptyIterable()));
    }

    /** An exception thrown while a runner factory creates its PTransform escapes processBundle. */
    @Test
    public void testCreatingPTransformExceptionsArePropagated() throws Exception {
        BeamFnApi.ProcessBundleDescriptor processBundleDescriptor = BeamFnApi.ProcessBundleDescriptor.newBuilder()
                .putTransforms("2L", RunnerApi.PTransform.newBuilder()
                        .setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn(DATA_INPUT_URN).build())
                        .build())
                .build();
        Map<String, BeamFnApi.ProcessBundleDescriptor> fnApiRegistry = ImmutableMap.of("1L", processBundleDescriptor);
        // Keep a reference to the exact failure so its propagation can be checked by identity.
        IllegalStateException factoryFailure = new IllegalStateException("TestException");
        // A typed local gives the lambda a functional-interface target (ImmutableMap.of would infer Object).
        PTransformRunnerFactory failingFactory = context -> {
            throw factoryFailure;
        };
        ProcessBundleHandler handler = new ProcessBundleHandler(PipelineOptionsFactory.create(), Collections.emptySet(), fnApiRegistry::get, this.beamFnDataClient, null, null, new ShortIdMap(), this.executionStateSampler, (Map) ImmutableMap.of((Object) DATA_INPUT_URN, (Object) failingFactory), Caches.noop(), new ProcessBundleHandler.BundleProcessorCache(Duration.ZERO), null);

        IllegalStateException thrown = Assert.assertThrows(IllegalStateException.class, () -> handler.processBundle(BeamFnApi.InstructionRequest.newBuilder().setProcessBundle(BeamFnApi.ProcessBundleRequest.newBuilder().setProcessBundleDescriptorId("1L")).build()));

        // assertThrows' String overload only sets the failure message, so check explicitly that the
        // factory's exception is what surfaced, either directly or as a cause.
        boolean propagated = false;
        for (Throwable cause = thrown; cause != null && !propagated; cause = cause.getCause()) {
            propagated = cause == factoryFailure;
        }
        Assert.assertTrue("Expected the runner factory's exception to be propagated", propagated);
    }

    /**
     * A bundle-finalization callback registered during start-bundle is reported in the response.
     * It is then handed to the FinalizeBundleHandler together with its expiry time.
     */
    @Test
    public void testBundleFinalizationIsPropagated() throws Exception {
        BeamFnApi.ProcessBundleDescriptor processBundleDescriptor = BeamFnApi.ProcessBundleDescriptor.newBuilder()
                .putTransforms("2L", RunnerApi.PTransform.newBuilder()
                        .setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn(DATA_INPUT_URN).build())
                        .build())
                .build();
        Map<String, BeamFnApi.ProcessBundleDescriptor> fnApiRegistry = ImmutableMap.of("1L", processBundleDescriptor);
        FinalizeBundleHandler mockFinalizeBundleHandler = Mockito.mock(FinalizeBundleHandler.class);
        DoFn.BundleFinalizer.Callback mockCallback = Mockito.mock(DoFn.BundleFinalizer.Callback.class);
        // The factory registers mockCallback with a 42ms-epoch expiry when the bundle starts. The typed
        // local gives the lambda a functional-interface target (ImmutableMap.of would infer Object).
        PTransformRunnerFactory finalizationRegistrar = context -> {
            DoFn.BundleFinalizer bundleFinalizer = context.getBundleFinalizer();
            context.addStartBundleFunction(() -> bundleFinalizer.afterBundleCommit(Instant.ofEpochMilli(42L), mockCallback));
            return null;
        };
        ProcessBundleHandler handler = new ProcessBundleHandler(PipelineOptionsFactory.create(), Collections.emptySet(), fnApiRegistry::get, this.beamFnDataClient, null, mockFinalizeBundleHandler, new ShortIdMap(), this.executionStateSampler, (Map) ImmutableMap.of((Object) DATA_INPUT_URN, (Object) finalizationRegistrar), Caches.noop(), new ProcessBundleHandler.BundleProcessorCache(Duration.ZERO), null);

        BeamFnApi.InstructionResponse.Builder response = handler.processBundle(BeamFnApi.InstructionRequest.newBuilder().setInstructionId("2L").setProcessBundle(BeamFnApi.ProcessBundleRequest.newBuilder().setProcessBundleDescriptorId("1L")).build());

        Assert.assertTrue(response.getProcessBundle().getRequiresFinalization());
        Mockito.verify(mockFinalizeBundleHandler).registerCallbacks(Mockito.eq("2L"), Mockito.argThat(arg -> {
            FinalizeBundleHandler.CallbackRegistration registration = (FinalizeBundleHandler.CallbackRegistration) Iterables.getOnlyElement((Iterable) arg);
            Assert.assertEquals(Instant.ofEpochMilli(42L), registration.getExpiryTime());
            Assert.assertSame(mockCallback, registration.getCallback());
            return true;
        }));
    }

    /**
     * An exception thrown by a start-bundle function escapes processBundle. The failed processor
     * must not be returned to the cache.
     */
    @Test
    public void testPTransformStartExceptionsArePropagated() {
        BeamFnApi.ProcessBundleDescriptor processBundleDescriptor = BeamFnApi.ProcessBundleDescriptor.newBuilder()
                .putTransforms("2L", RunnerApi.PTransform.newBuilder()
                        .setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn(DATA_INPUT_URN).build())
                        .build())
                .build();
        Map<String, BeamFnApi.ProcessBundleDescriptor> fnApiRegistry = ImmutableMap.of("1L", processBundleDescriptor);
        // Keep a reference to the exact failure so its propagation can be checked by identity.
        IllegalStateException startFailure = new IllegalStateException("TestException");
        // A typed local gives the lambda a functional-interface target (ImmutableMap.of would infer Object).
        PTransformRunnerFactory failingStartFactory = context -> {
            context.addStartBundleFunction(() -> {
                throw startFailure;
            });
            return null;
        };
        ProcessBundleHandler handler = new ProcessBundleHandler(PipelineOptionsFactory.create(), Collections.emptySet(), fnApiRegistry::get, this.beamFnDataClient, null, null, new ShortIdMap(), this.executionStateSampler, (Map) ImmutableMap.of((Object) DATA_INPUT_URN, (Object) failingStartFactory), Caches.noop(), new ProcessBundleHandler.BundleProcessorCache(Duration.ZERO), null);

        IllegalStateException thrown = Assert.assertThrows(IllegalStateException.class, () -> handler.processBundle(BeamFnApi.InstructionRequest.newBuilder().setProcessBundle(BeamFnApi.ProcessBundleRequest.newBuilder().setProcessBundleDescriptorId("1L")).build()));

        // assertThrows' String overload only sets the failure message, so check explicitly that the
        // start function's exception is what surfaced, either directly or as a cause.
        boolean propagated = false;
        for (Throwable cause = thrown; cause != null && !propagated; cause = cause.getCause()) {
            propagated = cause == startFailure;
        }
        Assert.assertTrue("Expected the start function's exception to be propagated", propagated);
        MatcherAssert.assertThat((ConcurrentLinkedQueue) handler.bundleProcessorCache.getCachedBundleProcessors().get("1L"), IsEmptyCollection.empty());
    }

    /**
     * Builds a handler for descriptor "1L". In it, a data input "2L" carrying KV&lt;String, String&gt;
     * elements feeds a {@code SimpleDoFn} ParDo "3L" that declares timer family "tfs-timer_family".
     *
     * @param dataOutput receives the value of every KV element delivered to the data input
     * @param timerOutput receives every Timers block passed to the mocked outbound aggregator's observer
     * @param enableOutputEmbedding whether the runner capabilities advertise
     *     CONTROL_RESPONSE_ELEMENTS_EMBEDDING
     * @return the configured handler
     */
    private ProcessBundleHandler setupProcessBundleHandlerForSimpleRecordingDoFn(List<String> dataOutput, final List<BeamFnApi.Elements.Timers> timerOutput, boolean enableOutputEmbedding) throws Exception {
        // Serialize SimpleDoFn into a Java-SDK ParDo payload with one event-time timer family.
        DoFnWithExecutionInformation doFnWithExecutionInformation = DoFnWithExecutionInformation.of((DoFn) new SimpleDoFn(), (TupleTag) SimpleDoFn.MAIN_OUTPUT_TAG, Collections.emptyMap(), DoFnSchemaInformation.create());
        RunnerApi.FunctionSpec functionSpec = RunnerApi.FunctionSpec.newBuilder()
                .setUrn("beam:dofn:javasdk:0.1")
                .setPayload(ByteString.copyFrom(SerializableUtils.serializeToByteArray((Serializable) doFnWithExecutionInformation)))
                .build();
        RunnerApi.ParDoPayload parDoPayload = RunnerApi.ParDoPayload.newBuilder()
                .setDoFn(functionSpec)
                .putTimerFamilySpecs("tfs-timer_family", RunnerApi.TimerFamilySpec.newBuilder()
                        .setTimeDomain(RunnerApi.TimeDomain.Enum.EVENT_TIME)
                        .setTimerFamilyCoderId("timer-coder")
                        .build())
                .build();

        RunnerApi.PTransform source = RunnerApi.PTransform.newBuilder()
                .setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn(DATA_INPUT_URN).build())
                .putOutputs("2L-output", "2L-output-pc")
                .build();
        RunnerApi.PTransform parDo = RunnerApi.PTransform.newBuilder()
                .setSpec(RunnerApi.FunctionSpec.newBuilder()
                        .setUrn("beam:transform:pardo:v1")
                        .setPayload(parDoPayload.toByteString()))
                .putInputs("3L-input", "2L-output-pc")
                .build();
        RunnerApi.PCollection sourceOutput = RunnerApi.PCollection.newBuilder()
                .setWindowingStrategyId("window-strategy")
                .setCoderId("2L-output-coder")
                .setIsBounded(RunnerApi.IsBounded.Enum.BOUNDED)
                .build();
        RunnerApi.WindowingStrategy globalWindowing = RunnerApi.WindowingStrategy.newBuilder()
                .setWindowCoderId("window-strategy-coder")
                .setWindowFn(RunnerApi.FunctionSpec.newBuilder().setUrn("beam:window_fn:global_windows:v1"))
                .setOutputTime(RunnerApi.OutputTime.Enum.END_OF_WINDOW)
                .setAccumulationMode(RunnerApi.AccumulationMode.Enum.ACCUMULATING)
                .setTrigger(RunnerApi.Trigger.newBuilder().setAlways(RunnerApi.Trigger.Always.getDefaultInstance()))
                .setClosingBehavior(RunnerApi.ClosingBehavior.Enum.EMIT_ALWAYS)
                .setOnTimeBehavior(RunnerApi.OnTimeBehavior.Enum.FIRE_ALWAYS)
                .build();
        RunnerApi.Coder kvOfStringsCoder = RunnerApi.Coder.newBuilder()
                .setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn(ModelCoders.KV_CODER_URN).build())
                .addComponentCoderIds("string_coder")
                .addComponentCoderIds("string_coder")
                .build();
        RunnerApi.Coder globalWindowCoder = RunnerApi.Coder.newBuilder()
                .setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn(ModelCoders.GLOBAL_WINDOW_CODER_URN).build())
                .build();
        RunnerApi.Coder timerCoder = RunnerApi.Coder.newBuilder()
                .setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn(ModelCoders.TIMER_CODER_URN))
                .addComponentCoderIds("string_coder")
                .addComponentCoderIds("window-strategy-coder")
                .build();
        BeamFnApi.ProcessBundleDescriptor processBundleDescriptor = BeamFnApi.ProcessBundleDescriptor.newBuilder()
                .putTransforms("2L", source)
                .putTransforms("3L", parDo)
                .putPcollections("2L-output-pc", sourceOutput)
                .putWindowingStrategies("window-strategy", globalWindowing)
                .setTimerApiServiceDescriptor(Endpoints.ApiServiceDescriptor.newBuilder().setUrl("url").build())
                .putCoders("string_coder", CoderTranslation.toProto((Coder) StringUtf8Coder.of()).getCoder())
                .putCoders("2L-output-coder", kvOfStringsCoder)
                .putCoders("window-strategy-coder", globalWindowCoder)
                .putCoders("timer-coder", timerCoder)
                .build();
        Map<String, BeamFnApi.ProcessBundleDescriptor> fnApiRegistry = ImmutableMap.of("1L", processBundleDescriptor);

        // The data input decodes KV<String, String> elements and records each value. The map is typed
        // so the lambda has a functional-interface target, and the coder is typed so that `input` is
        // a KV rather than a raw Object.
        Map<String, PTransformRunnerFactory> urnToPTransformRunnerFactoryMap = Maps.newHashMap((Map) ProcessBundleHandler.REGISTERED_RUNNER_FACTORIES);
        urnToPTransformRunnerFactoryMap.put(DATA_INPUT_URN, context -> {
            context.addIncomingDataEndpoint(Endpoints.ApiServiceDescriptor.getDefaultInstance(), KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()), input -> dataOutput.add(input.getValue()));
            return null;
        });

        // Every outbound aggregator writes to an observer that records the Timers blocks it receives.
        StreamObserver<BeamFnApi.Elements> timerRecorder = new StreamObserver<BeamFnApi.Elements>() {
            @Override
            public void onNext(BeamFnApi.Elements elements) {
                // Record each Timers block exactly once (not once per timer in the message).
                timerOutput.addAll(elements.getTimersList());
            }

            @Override
            public void onError(Throwable throwable) {
            }

            @Override
            public void onCompleted() {
            }
        };
        Mockito.doAnswer(invocation -> new BeamFnDataOutboundAggregator(PipelineOptionsFactory.create(), (Supplier) invocation.getArgument(1), timerRecorder, ((Boolean) invocation.getArgument(2)).booleanValue()))
                .when(this.beamFnDataClient)
                .createOutboundAggregator((Endpoints.ApiServiceDescriptor) ArgumentMatchers.any(), (Supplier) ArgumentMatchers.any(), ArgumentMatchers.anyBoolean());

        Set<String> runnerCapabilities = enableOutputEmbedding
                ? Collections.singleton(BeamUrns.getUrn((ProtocolMessageEnum) RunnerApi.StandardRunnerProtocols.Enum.CONTROL_RESPONSE_ELEMENTS_EMBEDDING))
                : Collections.emptySet();
        return new ProcessBundleHandler(PipelineOptionsFactory.create(), runnerCapabilities, fnApiRegistry::get, this.beamFnDataClient, null, null, new ShortIdMap(), this.executionStateSampler, (Map) urnToPTransformRunnerFactoryMap, Caches.noop(), new ProcessBundleHandler.BundleProcessorCache(Duration.ZERO), null);
    }

    /**
     * Data and timers embedded directly in the ProcessBundleRequest are delivered to the pipeline.
     * The DoFn's outbound timer is then observed.
     */
    @Test
    public void testInstructionEmbeddedElementsAreProcessed() throws Exception {
        List<String> dataOutput = new ArrayList<>();
        List<BeamFnApi.Elements.Timers> timerOutput = new ArrayList<>();
        ProcessBundleHandler handler = setupProcessBundleHandlerForSimpleRecordingDoFn(dataOutput, timerOutput, false);

        // One KV("", "data") element for the data input "2L".
        ByteStringOutputStream dataBytes = new ByteStringOutputStream();
        KvCoder.of((Coder) StringUtf8Coder.of(), (Coder) StringUtf8Coder.of()).encode(KV.of((Object) "", (Object) "data"), (OutputStream) dataBytes);

        // One timer for the ParDo "3L", in the global window.
        Coder timerCoder = Timer.Coder.of((Coder) StringUtf8Coder.of(), (Coder) GlobalWindow.Coder.INSTANCE);
        ByteStringOutputStream timerBytes = new ByteStringOutputStream();
        timerCoder.encode(Timer.of((Object) "", (String) "timer_id", Collections.singletonList(GlobalWindow.INSTANCE), (Instant) Instant.ofEpochMilli((long) 1L), (Instant) Instant.ofEpochMilli((long) 1L), (PaneInfo) PaneInfo.ON_TIME_AND_ONLY_FIRING), (OutputStream) timerBytes);

        // Each stream is followed by its is_last terminator.
        BeamFnApi.Elements.Data dataChunk = BeamFnApi.Elements.Data.newBuilder()
                .setInstructionId("998L").setTransformId("2L").setData(dataBytes.toByteString()).build();
        BeamFnApi.Elements.Data dataEnd = BeamFnApi.Elements.Data.newBuilder()
                .setInstructionId("998L").setTransformId("2L").setIsLast(true).build();
        BeamFnApi.Elements.Timers timerChunk = BeamFnApi.Elements.Timers.newBuilder()
                .setInstructionId("998L").setTransformId("3L").setTimerFamilyId("tfs-timer_family").setTimers(timerBytes.toByteString()).build();
        BeamFnApi.Elements.Timers timerEnd = BeamFnApi.Elements.Timers.newBuilder()
                .setInstructionId("998L").setTransformId("3L").setTimerFamilyId("tfs-timer_family").setIsLast(true).build();
        BeamFnApi.Elements elements = BeamFnApi.Elements.newBuilder()
                .addData(dataChunk)
                .addData(dataEnd)
                .addTimers(timerChunk)
                .addTimers(timerEnd)
                .build();

        BeamFnApi.InstructionRequest request = BeamFnApi.InstructionRequest.newBuilder()
                .setInstructionId("998L")
                .setProcessBundle(BeamFnApi.ProcessBundleRequest.newBuilder()
                        .setProcessBundleDescriptorId("1L")
                        .setElements(elements))
                .build();
        handler.processBundle(request);
        handler.shutdown();

        MatcherAssert.assertThat(dataOutput, Matchers.contains("data"));
        // The first recorded outbound timer is expected to carry the dynamic tag "output_timer".
        Timer firstOutputTimer = (Timer) timerCoder.decode(timerOutput.get(0).getTimers().newInput());
        Assert.assertEquals("output_timer", firstOutputTimer.getDynamicTimerTag());
    }

    /**
     * Exercises two malformed embedded-data requests against the same handler:
     * data for transform "3L" (no inbound data receiver) must fail with an
     * IllegalStateException, and data for "2L" sent without any stream terminators must fail
     * with a RuntimeException.
     */
    @Test
    public void testInstructionEmbeddedElementsWithMalformedData() throws Exception {
        ArrayList<String> recordedData = new ArrayList<String>();
        ArrayList<BeamFnApi.Elements.Timers> recordedTimers = new ArrayList<BeamFnApi.Elements.Timers>();
        ProcessBundleHandler handler = this.setupProcessBundleHandlerForSimpleRecordingDoFn(recordedData, recordedTimers, false);
        ByteStringOutputStream payload = new ByteStringOutputStream();
        KvCoder.of((Coder)StringUtf8Coder.of(), (Coder)StringUtf8Coder.of()).encode(KV.of((Object)"", (Object)"data"), (OutputStream)payload);

        // Case 1: data addressed to a transform that has no inbound data receiver.
        BeamFnApi.Elements.Data orphanData = BeamFnApi.Elements.Data.newBuilder()
                .setInstructionId("998L")
                .setTransformId("3L")
                .setData(payload.toByteString())
                .build();
        BeamFnApi.InstructionRequest unknownTransformRequest = BeamFnApi.InstructionRequest.newBuilder()
                .setInstructionId("998L")
                .setProcessBundle(BeamFnApi.ProcessBundleRequest.newBuilder()
                        .setProcessBundleDescriptorId("1L")
                        .setElements(BeamFnApi.Elements.newBuilder().addData(orphanData).build()))
                .build();
        Assert.assertThrows("Expect java.lang.IllegalStateException: Unable to find inbound data receiver for instruction 998L and transform 3L.", IllegalStateException.class, () -> handler.processBundle(unknownTransformRequest));

        // Case 2: data for a known input, but no isLast terminators for any input.
        BeamFnApi.Elements.Data unterminatedData = BeamFnApi.Elements.Data.newBuilder()
                .setInstructionId("998L")
                .setTransformId("2L")
                .setData(payload.toByteString())
                .build();
        BeamFnApi.InstructionRequest unterminatedRequest = BeamFnApi.InstructionRequest.newBuilder()
                .setInstructionId("998L")
                .setProcessBundle(BeamFnApi.ProcessBundleRequest.newBuilder()
                        .setProcessBundleDescriptorId("1L")
                        .setElements(BeamFnApi.Elements.newBuilder().addData(unterminatedData).build()))
                .build();
        Assert.assertThrows("Elements embedded in ProcessBundleRequest do not contain stream terminators for all data and timer inputs. Unterminated endpoints: [2L:data, 3L:timers:tfs-timer_family]", RuntimeException.class, () -> handler.processBundle(unterminatedRequest));

        handler.shutdown();
    }

    /**
     * Exercises three malformed embedded-timer requests against the same handler:
     * timers for an unknown transform ("4L") and for an undeclared timer family on "3L" must
     * each fail with an IllegalStateException, and a valid timer sent without stream terminators
     * must fail with a RuntimeException.
     */
    @Test
    public void testInstructionEmbeddedElementsWithMalformedTimers() throws Exception {
        ArrayList<String> recordedData = new ArrayList<String>();
        ArrayList<BeamFnApi.Elements.Timers> recordedTimers = new ArrayList<BeamFnApi.Elements.Timers>();
        ProcessBundleHandler handler = this.setupProcessBundleHandlerForSimpleRecordingDoFn(recordedData, recordedTimers, false);
        ByteStringOutputStream timerPayload = new ByteStringOutputStream();
        Timer.Coder.of((Coder)StringUtf8Coder.of(), (Coder)GlobalWindow.Coder.INSTANCE).encode(Timer.of((Object)"", (String)"timer_id", Collections.singletonList(GlobalWindow.INSTANCE), (Instant)Instant.ofEpochMilli((long)1L), (Instant)Instant.ofEpochMilli((long)1L), (PaneInfo)PaneInfo.ON_TIME_AND_ONLY_FIRING), (OutputStream)timerPayload);

        // Case 1: timers addressed to a transform with no inbound timer receiver.
        BeamFnApi.Elements.Timers unknownTransformTimers = BeamFnApi.Elements.Timers.newBuilder()
                .setInstructionId("998L")
                .setTransformId("4L")
                .setTimerFamilyId("tfs-timer_family")
                .setTimers(timerPayload.toByteString())
                .build();
        BeamFnApi.InstructionRequest unknownTransformRequest = BeamFnApi.InstructionRequest.newBuilder()
                .setInstructionId("998L")
                .setProcessBundle(BeamFnApi.ProcessBundleRequest.newBuilder()
                        .setProcessBundleDescriptorId("1L")
                        .setElements(BeamFnApi.Elements.newBuilder().addTimers(unknownTransformTimers).build()))
                .build();
        Assert.assertThrows("Expect java.lang.IllegalStateException: Unable to find inbound timer receiver for instruction 998L, transform 4L, and timer family tfs-timer_family.", IllegalStateException.class, () -> handler.processBundle(unknownTransformRequest));

        // Case 2: known transform, but a timer family that was never declared.
        BeamFnApi.Elements.Timers undeclaredFamilyTimers = BeamFnApi.Elements.Timers.newBuilder()
                .setInstructionId("998L")
                .setTransformId("3L")
                .setTimerFamilyId("tfs-not_declared_id")
                .setTimers(timerPayload.toByteString())
                .build();
        BeamFnApi.InstructionRequest undeclaredFamilyRequest = BeamFnApi.InstructionRequest.newBuilder()
                .setInstructionId("998L")
                .setProcessBundle(BeamFnApi.ProcessBundleRequest.newBuilder()
                        .setProcessBundleDescriptorId("1L")
                        .setElements(BeamFnApi.Elements.newBuilder().addTimers(undeclaredFamilyTimers).build()))
                .build();
        Assert.assertThrows("Expect java.lang.IllegalStateException: Unable to find inbound timer receiver for instruction 998L, transform 3L, and timer family tfs-not_declared_id.", IllegalStateException.class, () -> handler.processBundle(undeclaredFamilyRequest));

        // Case 3: a valid timer, but no isLast terminators for any input.
        BeamFnApi.Elements.Timers unterminatedTimers = BeamFnApi.Elements.Timers.newBuilder()
                .setInstructionId("998L")
                .setTransformId("3L")
                .setTimerFamilyId("tfs-timer_family")
                .setTimers(timerPayload.toByteString())
                .build();
        BeamFnApi.InstructionRequest unterminatedRequest = BeamFnApi.InstructionRequest.newBuilder()
                .setInstructionId("998L")
                .setProcessBundle(BeamFnApi.ProcessBundleRequest.newBuilder()
                        .setProcessBundleDescriptorId("1L")
                        .setElements(BeamFnApi.Elements.newBuilder().addTimers(unterminatedTimers).build()))
                .build();
        Assert.assertThrows("Elements embedded in ProcessBundleRequest do not contain stream terminators for all data and timer inputs. Unterminated endpoints: [2L:data, 3L:timers:tfs-timer_family]", RuntimeException.class, () -> handler.processBundle(unterminatedRequest));

        handler.shutdown();
    }

    /**
     * With output embedding enabled (third setup argument {@code true}), sends an embedded data
     * terminator plus one timer and its terminator. Checks that nothing reaches the recording
     * timer sink and that the instruction response itself carries two timer entries.
     */
    @Test
    public void testOutputEmbeddedElementsAreProcessed() throws Exception {
        ArrayList<String> recordedData = new ArrayList<String>();
        ArrayList<BeamFnApi.Elements.Timers> recordedTimers = new ArrayList<BeamFnApi.Elements.Timers>();
        ProcessBundleHandler handler = this.setupProcessBundleHandlerForSimpleRecordingDoFn(recordedData, recordedTimers, true);
        ByteStringOutputStream timerPayload = new ByteStringOutputStream();
        Timer.Coder.of((Coder)StringUtf8Coder.of(), (Coder)GlobalWindow.Coder.INSTANCE).encode(Timer.of((Object)"", (String)"timer_id", Collections.singletonList(GlobalWindow.INSTANCE), (Instant)Instant.ofEpochMilli((long)1L), (Instant)Instant.ofEpochMilli((long)1L), (PaneInfo)PaneInfo.ON_TIME_AND_ONLY_FIRING), (OutputStream)timerPayload);

        BeamFnApi.Elements embeddedInputs = BeamFnApi.Elements.newBuilder()
                .addData(BeamFnApi.Elements.Data.newBuilder()
                        .setInstructionId("998L")
                        .setTransformId("2L")
                        .setIsLast(true)
                        .build())
                .addTimers(BeamFnApi.Elements.Timers.newBuilder()
                        .setInstructionId("998L")
                        .setTransformId("3L")
                        .setTimerFamilyId("tfs-timer_family")
                        .setTimers(timerPayload.toByteString())
                        .build())
                .addTimers(BeamFnApi.Elements.Timers.newBuilder()
                        .setInstructionId("998L")
                        .setTransformId("3L")
                        .setTimerFamilyId("tfs-timer_family")
                        .setIsLast(true)
                        .build())
                .build();
        BeamFnApi.InstructionRequest request = BeamFnApi.InstructionRequest.newBuilder()
                .setInstructionId("998L")
                .setProcessBundle(BeamFnApi.ProcessBundleRequest.newBuilder()
                        .setProcessBundleDescriptorId("1L")
                        .setElements(embeddedInputs))
                .build();

        BeamFnApi.InstructionResponse.Builder responseBuilder = handler.processBundle(request);
        handler.shutdown();

        // Timers must come back inside the response instead of through the data client.
        MatcherAssert.assertThat(recordedTimers, IsEmptyCollection.empty());
        Assert.assertEquals(2L, responseBuilder.build().getProcessBundle().getElements().getTimersCount());
    }

    /**
     * On a successful bundle, the handler must register one inbound receiver for the
     * instruction with the data client, unregister it afterwards, and make no other calls on
     * the data client.
     */
    @Test
    public void testInstructionIsUnregisteredFromBeamFnDataClientOnSuccess() throws Exception {
        RunnerApi.PTransform dataInput = RunnerApi.PTransform.newBuilder().setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn(DATA_INPUT_URN).build()).build();
        BeamFnApi.ProcessBundleDescriptor processBundleDescriptor = BeamFnApi.ProcessBundleDescriptor.newBuilder().putTransforms("2L", dataInput).build();
        ImmutableMap fnApiRegistry = ImmutableMap.of((Object)"1L", (Object)processBundleDescriptor);

        // As soon as a receiver is registered, hand it only the stream terminator for "2L".
        Mockito.doAnswer(invocation -> {
            String requestedInstruction = invocation.getArgument(0, String.class);
            CloseableFnDataReceiver receiver = invocation.getArgument(2, CloseableFnDataReceiver.class);
            BeamFnApi.Elements terminatorOnly = BeamFnApi.Elements.newBuilder()
                    .addData(BeamFnApi.Elements.Data.newBuilder()
                            .setInstructionId(requestedInstruction)
                            .setTransformId("2L")
                            .setIsLast(true))
                    .build();
            receiver.accept((Object)terminatorOnly);
            return null;
        }).when(this.beamFnDataClient).registerReceiver((String)ArgumentMatchers.any(), (List)ArgumentMatchers.any(), (CloseableFnDataReceiver)ArgumentMatchers.any());

        ProcessBundleHandler handler = new ProcessBundleHandler(PipelineOptionsFactory.create(), Collections.emptySet(), ((Map)fnApiRegistry)::get, this.beamFnDataClient, null, null, new ShortIdMap(), this.executionStateSampler, (Map)ImmutableMap.of((Object)DATA_INPUT_URN, context -> {
            context.addIncomingDataEndpoint(Endpoints.ApiServiceDescriptor.getDefaultInstance(), (Coder)StringUtf8Coder.of(), ignored -> {});
            return null;
        }), Caches.noop(), new ProcessBundleHandler.BundleProcessorCache(Duration.ZERO), null);

        BeamFnApi.InstructionRequest request = BeamFnApi.InstructionRequest.newBuilder()
                .setInstructionId("instructionId")
                .setProcessBundle(BeamFnApi.ProcessBundleRequest.newBuilder().setProcessBundleDescriptorId("1L"))
                .build();
        handler.processBundle(request);

        Mockito.verify(this.beamFnDataClient).registerReceiver(Mockito.eq("instructionId"), (List)ArgumentMatchers.any(), (CloseableFnDataReceiver)ArgumentMatchers.any());
        Mockito.verify(this.beamFnDataClient).unregisterReceiver(Mockito.eq("instructionId"), (List)ArgumentMatchers.any());
        Mockito.verifyNoMoreInteractions(this.beamFnDataClient);
    }

    /**
     * An exception thrown by an inbound data consumer must surface from processBundle. The data
     * client must then see only the receiver registration and a poisoning of the instruction,
     * with no unregistration and no other calls.
     */
    @Test
    public void testDataProcessingExceptionsArePropagated() throws Exception {
        RunnerApi.PTransform dataInput = RunnerApi.PTransform.newBuilder().setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn(DATA_INPUT_URN).build()).build();
        BeamFnApi.ProcessBundleDescriptor processBundleDescriptor = BeamFnApi.ProcessBundleDescriptor.newBuilder().putTransforms("2L", dataInput).build();
        ImmutableMap fnApiRegistry = ImmutableMap.of((Object)"1L", (Object)processBundleDescriptor);

        // On registration, deliver one encoded "A" element for "2L", already marked as the last chunk.
        Mockito.doAnswer(invocation -> {
            ByteStringOutputStream elementBytes = new ByteStringOutputStream();
            StringUtf8Coder.of().encode("A", (OutputStream)elementBytes);
            String requestedInstruction = invocation.getArgument(0, String.class);
            CloseableFnDataReceiver receiver = invocation.getArgument(2, CloseableFnDataReceiver.class);
            BeamFnApi.Elements.Data.Builder onlyChunk = BeamFnApi.Elements.Data.newBuilder()
                    .setInstructionId(requestedInstruction)
                    .setTransformId("2L")
                    .setData(elementBytes.toByteString())
                    .setIsLast(true);
            receiver.accept((Object)BeamFnApi.Elements.newBuilder().addData(onlyChunk).build());
            return null;
        }).when(this.beamFnDataClient).registerReceiver((String)ArgumentMatchers.any(), (List)ArgumentMatchers.any(), (CloseableFnDataReceiver)ArgumentMatchers.any());

        // The data endpoint's consumer fails on every element.
        ProcessBundleHandler handler = new ProcessBundleHandler(PipelineOptionsFactory.create(), Collections.emptySet(), ((Map)fnApiRegistry)::get, this.beamFnDataClient, null, null, new ShortIdMap(), this.executionStateSampler, (Map)ImmutableMap.of((Object)DATA_INPUT_URN, context -> {
            context.addIncomingDataEndpoint(Endpoints.ApiServiceDescriptor.getDefaultInstance(), (Coder)StringUtf8Coder.of(), element -> {
                throw new IllegalStateException("TestException");
            });
            return null;
        }), Caches.noop(), new ProcessBundleHandler.BundleProcessorCache(Duration.ZERO), null);

        BeamFnApi.InstructionRequest request = BeamFnApi.InstructionRequest.newBuilder()
                .setInstructionId("instructionId")
                .setProcessBundle(BeamFnApi.ProcessBundleRequest.newBuilder().setProcessBundleDescriptorId("1L"))
                .build();
        Assert.assertThrows("TestException", IllegalStateException.class, () -> handler.processBundle(request));

        Mockito.verify(this.beamFnDataClient).registerReceiver(Mockito.eq("instructionId"), (List)ArgumentMatchers.any(), (CloseableFnDataReceiver)ArgumentMatchers.any());
        Mockito.verify(this.beamFnDataClient).poisonInstructionId(Mockito.eq("instructionId"));
        Mockito.verifyNoMoreInteractions(this.beamFnDataClient);
    }

    /**
     * A finish-bundle function that throws must make processBundle throw IllegalStateException.
     * Afterwards the handler's cached-processor queue for descriptor "1L" must be empty.
     */
    @Test
    public void testPTransformFinishExceptionsArePropagated() throws Exception {
        RunnerApi.PTransform dataInput = RunnerApi.PTransform.newBuilder().setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn(DATA_INPUT_URN).build()).build();
        BeamFnApi.ProcessBundleDescriptor processBundleDescriptor = BeamFnApi.ProcessBundleDescriptor.newBuilder().putTransforms("2L", dataInput).build();
        ImmutableMap fnApiRegistry = ImmutableMap.of((Object)"1L", (Object)processBundleDescriptor);
        ProcessBundleHandler handler = new ProcessBundleHandler(PipelineOptionsFactory.create(), Collections.emptySet(), ((Map)fnApiRegistry)::get, this.beamFnDataClient, null, null, new ShortIdMap(), this.executionStateSampler, (Map)ImmutableMap.of((Object)DATA_INPUT_URN, context -> {
            context.addFinishBundleFunction(ProcessBundleHandlerTest::throwException);
            return null;
        }), Caches.noop(), new ProcessBundleHandler.BundleProcessorCache(Duration.ZERO), null);

        BeamFnApi.InstructionRequest request = BeamFnApi.InstructionRequest.newBuilder()
                .setProcessBundle(BeamFnApi.ProcessBundleRequest.newBuilder().setProcessBundleDescriptorId("1L"))
                .build();
        Assert.assertThrows("TestException", IllegalStateException.class, () -> handler.processBundle(request));

        // The processor that failed during finish must not be left in the cache.
        ConcurrentLinkedQueue cachedProcessors = (ConcurrentLinkedQueue)handler.bundleProcessorCache.getCachedBundleProcessors().get("1L");
        MatcherAssert.assertThat(cachedProcessors, IsEmptyCollection.empty());
    }

    /**
     * Verifies that processBundle only returns after state requests issued during start-bundle
     * have completed. The mocked state client completes each request on a separate thread after
     * about 500ms: normally for instruction "SUCCESS", exceptionally for "FAIL". Both futures
     * must report isDone() once processBundle returns.
     */
    @Test
    public void testPendingStateCallsBlockTillCompletion() throws Exception {
        BeamFnApi.ProcessBundleDescriptor processBundleDescriptor = BeamFnApi.ProcessBundleDescriptor.newBuilder().putTransforms("2L", RunnerApi.PTransform.newBuilder().setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn(DATA_INPUT_URN).build()).build()).setStateApiServiceDescriptor(Endpoints.ApiServiceDescriptor.getDefaultInstance()).build();
        ImmutableMap fnApiRegistry = ImmutableMap.of((Object)"1L", (Object)processBundleDescriptor);
        // One-element holders so the anonymous factory below can publish the futures it obtains.
        final CompletableFuture[] successfulResponse = new CompletableFuture[1];
        final CompletableFuture[] unsuccessfulResponse = new CompletableFuture[1];
        BeamFnStateGrpcClientCache mockBeamFnStateGrpcClient = Mockito.mock(BeamFnStateGrpcClientCache.class);
        BeamFnStateClient mockBeamFnStateClient = Mockito.mock(BeamFnStateClient.class);
        // Any state API descriptor resolves to the single mocked state client.
        Mockito.when(mockBeamFnStateGrpcClient.forApiServiceDescriptor((Endpoints.ApiServiceDescriptor)ArgumentMatchers.any())).thenReturn(mockBeamFnStateClient);
        // Each handle() call returns an incomplete future; a background thread completes it after
        // 500ms according to the request's instruction id.
        Mockito.doAnswer(invocation -> {
            BeamFnApi.StateRequest.Builder stateRequestBuilder = (BeamFnApi.StateRequest.Builder)invocation.getArguments()[0];
            CompletableFuture completableFuture = new CompletableFuture();
            new Thread(() -> {
                Uninterruptibles.sleepUninterruptibly((long)500L, (TimeUnit)TimeUnit.MILLISECONDS);
                switch (stateRequestBuilder.getInstructionId()) {
                    case "SUCCESS": {
                        completableFuture.complete(BeamFnApi.StateResponse.getDefaultInstance());
                        break;
                    }
                    case "FAIL": {
                        completableFuture.completeExceptionally(new RuntimeException("TEST ERROR"));
                        // Last case: the missing break cannot fall through to anything.
                    }
                }
            }).start();
            return completableFuture;
        }).when(mockBeamFnStateClient).handle((BeamFnApi.StateRequest.Builder)ArgumentMatchers.any());
        ProcessBundleHandler handler = new ProcessBundleHandler(PipelineOptionsFactory.create(), Collections.emptySet(), ((Map)fnApiRegistry)::get, this.beamFnDataClient, mockBeamFnStateGrpcClient, null, new ShortIdMap(), this.executionStateSampler, (Map)ImmutableMap.of((Object)DATA_INPUT_URN, (Object)new PTransformRunnerFactory<Object>(){

            // Registers a start-bundle function that fires one successful and one failing state request.
            public Object createRunnerForPTransform(PTransformRunnerFactory.Context context) throws IOException {
                BeamFnStateClient beamFnStateClient = context.getBeamFnStateClient();
                context.addStartBundleFunction(() -> this.doStateCalls(beamFnStateClient));
                return null;
            }

            // Issues both requests without waiting on them; completion is checked after processBundle returns.
            private void doStateCalls(BeamFnStateClient beamFnStateClient) {
                successfulResponse[0] = beamFnStateClient.handle(BeamFnApi.StateRequest.newBuilder().setInstructionId("SUCCESS"));
                unsuccessfulResponse[0] = beamFnStateClient.handle(BeamFnApi.StateRequest.newBuilder().setInstructionId("FAIL"));
            }
        }), Caches.noop(), new ProcessBundleHandler.BundleProcessorCache(Duration.ZERO), null);
        handler.processBundle(BeamFnApi.InstructionRequest.newBuilder().setProcessBundle(BeamFnApi.ProcessBundleRequest.newBuilder().setProcessBundleDescriptorId("1L")).build());
        // Both futures, including the exceptionally completed one, must be settled by now.
        Assert.assertTrue(successfulResponse[0].isDone());
        Assert.assertTrue(unsuccessfulResponse[0].isDone());
    }

    /**
     * When the bundle descriptor has no state API service descriptor, a state request made from
     * start-bundle must cause processBundle to throw IllegalStateException.
     */
    @Test
    public void testStateCallsFailIfNoStateApiServiceDescriptorSpecified() throws Exception {
        RunnerApi.PTransform dataInput = RunnerApi.PTransform.newBuilder().setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn(DATA_INPUT_URN).build()).build();
        BeamFnApi.ProcessBundleDescriptor descriptorWithoutStateApi = BeamFnApi.ProcessBundleDescriptor.newBuilder().putTransforms("2L", dataInput).build();
        ImmutableMap fnApiRegistry = ImmutableMap.of((Object)"1L", (Object)descriptorWithoutStateApi);

        // Factory whose start-bundle function issues a single state request.
        PTransformRunnerFactory<Object> stateCallingFactory = new PTransformRunnerFactory<Object>(){

            @Override
            public Object createRunnerForPTransform(PTransformRunnerFactory.Context context) throws IOException {
                BeamFnStateClient stateClient = context.getBeamFnStateClient();
                context.addStartBundleFunction(() -> stateClient.handle(BeamFnApi.StateRequest.newBuilder().setInstructionId("SUCCESS")));
                return null;
            }
        };
        ProcessBundleHandler handler = new ProcessBundleHandler(PipelineOptionsFactory.create(), Collections.emptySet(), ((Map)fnApiRegistry)::get, this.beamFnDataClient, null, null, new ShortIdMap(), this.executionStateSampler, (Map)ImmutableMap.of((Object)DATA_INPUT_URN, (Object)stateCallingFactory), Caches.noop(), new ProcessBundleHandler.BundleProcessorCache(Duration.ZERO), null);

        BeamFnApi.InstructionRequest request = BeamFnApi.InstructionRequest.newBuilder()
                .setProcessBundle(BeamFnApi.ProcessBundleRequest.newBuilder().setProcessBundleDescriptorId("1L"))
                .build();
        Assert.assertThrows("State API calls are unsupported", IllegalStateException.class, () -> handler.processBundle(request));
    }

    /**
     * Verifies that progress reporting is serialized with bundle execution.
     *
     * <p>The bundle's only transform signals {@code startLatch} from start-bundle and blocks in
     * finish-bundle on {@code finishLatch}. Meanwhile {@code numProgressThreads} tasks poll
     * {@code handler.progress} for instruction 999L. The test reporter checks that:
     * <ul>
     *   <li>every callback runs while holding the bundle processor's progress-request lock;</li>
     *   <li>intermediate updates run off the main processing thread and before the final update;</li>
     *   <li>the final update and then reset() run on the main processing thread.</li>
     * </ul>
     * At least {@code minNumResults} intermediate results must be observed. They must be exactly
     * the counter values {@code 0..n-1}, and the final monitoring data must report {@code n}.
     */
    @Test
    public void testProgressReportingIsExecutedSerially() throws Exception {
        BeamFnApi.ProcessBundleDescriptor processBundleDescriptor = BeamFnApi.ProcessBundleDescriptor.newBuilder().putTransforms("2L", RunnerApi.PTransform.newBuilder().setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn(DATA_INPUT_URN).build()).putOutputs("2L-output", "2L-output-pc").build()).putPcollections("2L-output-pc", RunnerApi.PCollection.getDefaultInstance()).build();
        ImmutableMap fnApiRegistry = ImmutableMap.of((Object)"1L", (Object)processBundleDescriptor);
        // startLatch: released once the bundle has started.
        // finishLatch: holds finish-bundle open until enough progress requests have returned.
        CountDownLatch startLatch = new CountDownLatch(1);
        CountDownLatch finishLatch = new CountDownLatch(1);
        final AtomicReference<ProcessBundleHandler.BundleProcessor> bundleProcessor = new AtomicReference<ProcessBundleHandler.BundleProcessor>();
        final AtomicReference mainBundleProcessingThread = new AtomicReference();
        // Number of intermediate progress updates produced so far.
        final AtomicInteger counter = new AtomicInteger();
        final AtomicBoolean finalWasCalled = new AtomicBoolean();
        final AtomicBoolean resetWasCalled = new AtomicBoolean();
        BundleProgressReporter testReporter = new BundleProgressReporter(){

            @Override
            public void updateIntermediateMonitoringData(Map<String, ByteString> monitoringData) {
                Assert.assertTrue(((ReentrantLock)((ProcessBundleHandler.BundleProcessor)bundleProcessor.get()).getProgressRequestLock()).isHeldByCurrentThread());
                Assert.assertNotEquals(Thread.currentThread(), mainBundleProcessingThread.get());
                Assert.assertFalse(finalWasCalled.get());
                Assert.assertFalse(resetWasCalled.get());
                monitoringData.put("testId", ByteString.copyFromUtf8((String)Long.toString(counter.getAndIncrement())));
            }

            @Override
            public void updateFinalMonitoringData(Map<String, ByteString> monitoringData) {
                Assert.assertTrue(((ReentrantLock)((ProcessBundleHandler.BundleProcessor)bundleProcessor.get()).getProgressRequestLock()).isHeldByCurrentThread());
                Assert.assertEquals(Thread.currentThread(), mainBundleProcessingThread.get());
                Assert.assertFalse(finalWasCalled.getAndSet(true));
                Assert.assertFalse(resetWasCalled.get());
                monitoringData.put("testId", ByteString.copyFromUtf8((String)Long.toString(counter.get())));
            }

            @Override
            public void reset() {
                Assert.assertTrue(((ReentrantLock)((ProcessBundleHandler.BundleProcessor)bundleProcessor.get()).getProgressRequestLock()).isHeldByCurrentThread());
                Assert.assertEquals(Thread.currentThread(), mainBundleProcessingThread.get());
                Assert.assertTrue(finalWasCalled.get());
                Assert.assertFalse(resetWasCalled.getAndSet(true));
            }
        };
        // The descriptor's only transform: reports progress through testReporter, signals
        // startLatch on start-bundle and blocks finish-bundle until finishLatch is released.
        PTransformRunnerFactory startFinishGuard = context -> {
            context.addBundleProgressReporter(testReporter);
            context.addStartBundleFunction(() -> startLatch.countDown());
            context.addFinishBundleFunction(() -> finishLatch.await());
            return null;
        };
        ProcessBundleHandler.BundleProcessorCache bundleProcessorCache = new ProcessBundleHandler.BundleProcessorCache(Duration.ZERO);
        ProcessBundleHandler handler = new ProcessBundleHandler(PipelineOptionsFactory.create(), Collections.singleton(BeamUrns.getUrn((ProtocolMessageEnum)RunnerApi.StandardRunnerProtocols.Enum.MONITORING_INFO_SHORT_IDS)), ((Map)fnApiRegistry)::get, this.beamFnDataClient, null, null, new ShortIdMap(), this.executionStateSampler, (Map)ImmutableMap.of((Object)DATA_INPUT_URN, (Object)startFinishGuard), Caches.noop(), bundleProcessorCache, null);
        // Set once the bundle has finished so pollers still waiting for "testId" can stop.
        AtomicBoolean progressShouldExit = new AtomicBoolean();
        Future bundleProcessorTask = this.executor.submit(() -> {
            mainBundleProcessingThread.set(Thread.currentThread());
            BeamFnApi.InstructionResponse response = handler.processBundle(BeamFnApi.InstructionRequest.newBuilder().setInstructionId("999L").setProcessBundle(BeamFnApi.ProcessBundleRequest.newBuilder().setProcessBundleDescriptorId("1L")).build()).build();
            progressShouldExit.set(true);
            return response;
        });
        // Wait until the bundle is running so its processor can be looked up by instruction id.
        startLatch.await();
        bundleProcessor.set(bundleProcessorCache.find("999L"));
        int minNumResults = 5;
        // Must be >= minNumResults, otherwise someProgressIsDone could never reach zero.
        int numProgressThreads = 20;
        CountDownLatch progressLatch = new CountDownLatch(1);
        CountDownLatch someProgressIsDone = new CountDownLatch(minNumResults);
        ArrayList<Future> progressReportingTasks = new ArrayList<Future>();
        for (int i = 0; i < numProgressThreads; ++i) {
            final int threadId = i;
            progressReportingTasks.add(this.executor.submit(() -> {
                BeamFnApi.InstructionResponse.Builder response;
                progressLatch.await();
                int requestCount = 0;
                try {
                    // Poll until a response carries "testId" or the bundle has completed.
                    while (!(response = handler.progress(BeamFnApi.InstructionRequest.newBuilder().setInstructionId("thread-" + threadId + "-" + ++requestCount).setProcessBundleProgress(BeamFnApi.ProcessBundleProgressRequest.newBuilder().setInstructionId("999L").build()).build())).getProcessBundleProgress().getMonitoringDataMap().containsKey("testId") && !progressShouldExit.get()) {
                    }
                }
                finally {
                    someProgressIsDone.countDown();
                }
                return response.build();
            }));
        }
        // Release all pollers at once. Let the bundle finish once at least minNumResults of them
        // have returned.
        progressLatch.countDown();
        someProgressIsDone.await();
        finishLatch.countDown();
        ArrayList<ByteString> progressReportingResults = new ArrayList<ByteString>();
        for (Future progressReportingTask : progressReportingTasks) {
            ByteString result = ((BeamFnApi.InstructionResponse)progressReportingTask.get()).getProcessBundleProgress().getMonitoringDataOrDefault("testId", null);
            if (result == null) continue;
            progressReportingResults.add(result);
        }
        Assert.assertTrue(progressReportingResults.size() >= minNumResults);
        ArrayList<ByteString> expectedIntermediateResults = new ArrayList<ByteString>();
        for (int i = 0; i < counter.get(); ++i) {
            expectedIntermediateResults.add(ByteString.copyFromUtf8((String)Long.toString(i)));
        }
        MatcherAssert.assertThat(progressReportingResults, Matchers.containsInAnyOrder(expectedIntermediateResults.toArray()));
        Assert.assertEquals(ByteString.copyFromUtf8((String)Long.toString(counter.get())), ((BeamFnApi.InstructionResponse)bundleProcessorTask.get()).getProcessBundle().getMonitoringDataOrThrow("testId"));
        Assert.assertTrue(finalWasCalled.get());
        Assert.assertTrue(resetWasCalled.get());
    }

    /**
     * A transform that registers an outgoing timers endpoint must make processBundle throw
     * IllegalStateException. Here the bundle descriptor declares no API service descriptor.
     */
    @Test
    public void testTimerRegistrationsFailIfNoTimerApiServiceDescriptorSpecified() throws Exception {
        RunnerApi.PTransform dataInput = RunnerApi.PTransform.newBuilder().setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn(DATA_INPUT_URN).build()).build();
        BeamFnApi.ProcessBundleDescriptor processBundleDescriptor = BeamFnApi.ProcessBundleDescriptor.newBuilder().putTransforms("2L", dataInput).build();
        ImmutableMap fnApiRegistry = ImmutableMap.of((Object)"1L", (Object)processBundleDescriptor);

        // Factory that only registers an outgoing timer endpoint for timer "timer".
        PTransformRunnerFactory<Object> timerRegisteringFactory = context -> {
            context.addOutgoingTimersEndpoint("timer", (Coder)Timer.Coder.of((Coder)StringUtf8Coder.of(), (Coder)GlobalWindow.Coder.INSTANCE));
            return null;
        };
        ProcessBundleHandler handler = new ProcessBundleHandler(PipelineOptionsFactory.create(), Collections.emptySet(), ((Map)fnApiRegistry)::get, this.beamFnDataClient, null, null, new ShortIdMap(), this.executionStateSampler, (Map)ImmutableMap.of((Object)DATA_INPUT_URN, (Object)timerRegisteringFactory), Caches.noop(), new ProcessBundleHandler.BundleProcessorCache(Duration.ZERO), null);

        BeamFnApi.InstructionRequest request = BeamFnApi.InstructionRequest.newBuilder()
                .setProcessBundle(BeamFnApi.ProcessBundleRequest.newBuilder().setProcessBundleDescriptorId("1L"))
                .build();
        Assert.assertThrows("Timers are unsupported", IllegalStateException.class, () -> handler.processBundle(request));
    }

    /** Bundle-function body that always fails with {@code IllegalStateException("TestException")}. */
    private static void throwException() {
        IllegalStateException failure = new IllegalStateException("TestException");
        throw failure;
    }

    /**
     * Decompiler-extracted lambda body (synthetic name kept, since callers outside this view may
     * refer to it). It checks that the current process-bundle instruction id is "999L", then
     * appends {@code "Finish" + pTransformId} to {@code orderOfOperations}. Presumably this is the
     * finish-bundle callback in testOrderOfStartAndFinishCalls; confirm against the caller.
     */
    private static /* synthetic */ void lambda$testOrderOfStartAndFinishCalls$1(Supplier processBundleInstructionId, List orderOfOperations, String pTransformId) throws Exception {
        String currentInstructionId = (String)processBundleInstructionId.get();
        MatcherAssert.assertThat(currentInstructionId, Matchers.equalTo("999L"));
        String operationEntry = "Finish" + pTransformId;
        orderOfOperations.add(operationEntry);
    }

    /**
     * Decompiler-extracted lambda body (synthetic name kept, since callers outside this view may
     * refer to it). It checks that the current process-bundle instruction id is "999L", then
     * appends {@code "Start" + pTransformId} to {@code orderOfOperations}. Presumably this is the
     * start-bundle callback in testOrderOfStartAndFinishCalls; confirm against the caller.
     */
    private static /* synthetic */ void lambda$testOrderOfStartAndFinishCalls$0(Supplier processBundleInstructionId, List orderOfOperations, String pTransformId) throws Exception {
        String currentInstructionId = (String)processBundleInstructionId.get();
        MatcherAssert.assertThat(currentInstructionId, Matchers.equalTo("999L"));
        String operationEntry = "Start" + pTransformId;
        orderOfOperations.add(operationEntry);
    }

    /**
     * Minimal DoFn used by the embedded-element tests. It declares one event-time timer family
     * ({@link #TIMER_FAMILY_ID}). When a timer in that family fires, it sets a dynamic timer
     * tagged {@code "output_timer"} at t=100ms with an output timestamp of t=100ms. Element
     * processing is a no-op.
     */
    private static final class SimpleDoFn
    extends DoFn<KV<String, String>, String> {
        private static final TupleTag<String> MAIN_OUTPUT_TAG = new TupleTag<>("mainOutput");
        // Single source of truth for the timer family id. Every annotation below refers to it,
        // so the declaration and its users cannot drift apart.
        private static final String TIMER_FAMILY_ID = "timer_family";
        @DoFn.TimerFamily(value=TIMER_FAMILY_ID)
        private final TimerSpec timer = TimerSpecs.timerMap((TimeDomain)TimeDomain.EVENT_TIME);

        private SimpleDoFn() {
        }

        @DoFn.ProcessElement
        public void processElement(DoFn.ProcessContext context, BoundedWindow window) {
            // Intentionally empty: this DoFn exists to exercise timer handling.
        }

        @DoFn.OnTimerFamily(value=TIMER_FAMILY_ID)
        public void onTimer(@DoFn.TimerFamily(value=TIMER_FAMILY_ID) TimerMap timerFamily) {
            timerFamily.get("output_timer").withOutputTimestamp(Instant.ofEpochMilli((long)100L)).set(Instant.ofEpochMilli((long)100L));
        }
    }

    private static class TestBundleProcessorCache
    extends ProcessBundleHandler.BundleProcessorCache {
        TestBundleProcessorCache() {
            super(Duration.ZERO);
        }

        ProcessBundleHandler.BundleProcessor get(BeamFnApi.InstructionRequest processBundleRequest, Supplier<ProcessBundleHandler.BundleProcessor> bundleProcessorSupplier) {
            return new TestBundleProcessor(super.get(processBundleRequest, bundleProcessorSupplier));
        }
    }

    private static class TestBundleProcessor
    extends ProcessBundleHandler.BundleProcessor {
        static int resetCnt = 0;
        private ProcessBundleHandler.BundleProcessor wrappedBundleProcessor;

        TestBundleProcessor(ProcessBundleHandler.BundleProcessor wrappedBundleProcessor) {
            this.wrappedBundleProcessor = wrappedBundleProcessor;
        }

        Cache<?, ?> getProcessWideCache() {
            return this.wrappedBundleProcessor.getProcessWideCache();
        }

        BundleProgressReporter.InMemory getBundleProgressReporterAndRegistrar() {
            return this.wrappedBundleProcessor.getBundleProgressReporterAndRegistrar();
        }

        BeamFnApi.ProcessBundleDescriptor getProcessBundleDescriptor() {
            return this.wrappedBundleProcessor.getProcessBundleDescriptor();
        }

        PTransformFunctionRegistry getStartFunctionRegistry() {
            return this.wrappedBundleProcessor.getStartFunctionRegistry();
        }

        PTransformFunctionRegistry getFinishFunctionRegistry() {
            return this.wrappedBundleProcessor.getFinishFunctionRegistry();
        }

        List<ThrowingRunnable> getResetFunctions() {
            return this.wrappedBundleProcessor.getResetFunctions();
        }

        List<ThrowingRunnable> getTearDownFunctions() {
            return this.wrappedBundleProcessor.getTearDownFunctions();
        }

        BundleSplitListener.InMemory getSplitListener() {
            return this.wrappedBundleProcessor.getSplitListener();
        }

        PCollectionConsumerRegistry getpCollectionConsumerRegistry() {
            return this.wrappedBundleProcessor.getpCollectionConsumerRegistry();
        }

        ProcessBundleHandler.MetricsEnvironmentStateForBundle getMetricsEnvironmentStateForBundle() {
            return this.wrappedBundleProcessor.getMetricsEnvironmentStateForBundle();
        }

        public ExecutionStateSampler.ExecutionStateTracker getStateTracker() {
            return this.wrappedBundleProcessor.getStateTracker();
        }

        ProcessBundleHandler.HandleStateCallsForBundle getBeamFnStateClient() {
            return this.wrappedBundleProcessor.getBeamFnStateClient();
        }

        List<Endpoints.ApiServiceDescriptor> getInboundEndpointApiServiceDescriptors() {
            return this.wrappedBundleProcessor.getInboundEndpointApiServiceDescriptors();
        }

        List<DataEndpoint<?>> getInboundDataEndpoints() {
            return this.wrappedBundleProcessor.getInboundDataEndpoints();
        }

        List<TimerEndpoint<?>> getTimerEndpoints() {
            return this.wrappedBundleProcessor.getTimerEndpoints();
        }

        Collection<FinalizeBundleHandler.CallbackRegistration> getBundleFinalizationCallbackRegistrations() {
            return this.wrappedBundleProcessor.getBundleFinalizationCallbackRegistrations();
        }

        Collection<BeamFnDataReadRunner> getChannelRoots() {
            return this.wrappedBundleProcessor.getChannelRoots();
        }

        Map<Endpoints.ApiServiceDescriptor, BeamFnDataOutboundAggregator> getOutboundAggregators() {
            return this.wrappedBundleProcessor.getOutboundAggregators();
        }

        /** Returns the runner capabilities known to the wrapped bundle processor. */
        Set<String> getRunnerCapabilities() {
            return wrappedBundleProcessor.getRunnerCapabilities();
        }

        /** Returns the progress-request lock of the wrapped bundle processor. */
        Lock getProgressRequestLock() {
            return wrappedBundleProcessor.getProgressRequestLock();
        }

        /**
         * Counts this reset in {@code resetCnt}, then resets the wrapped bundle processor.
         *
         * <p>The counter is bumped before delegating, so a failing delegate reset is still counted.
         *
         * @throws Exception whatever the wrapped processor's {@code reset()} throws
         */
        void reset() throws Exception {
            resetCnt += 1;
            wrappedBundleProcessor.reset();
        }
    }

    /**
     * Test {@link DoFn} that validates the order of its lifecycle callbacks and records each
     * setup/bundle/teardown callback name in {@link #orderOfOperations}.
     *
     * <p>Enforced ordering: {@code setUp} must be the first callback on a fresh instance;
     * {@code processElement} and {@code finishBundle} require that {@code startBundle} was the most
     * recent state transition; {@code tearDown} may run at most once. Violations fail via
     * {@code Preconditions.checkState}.
     */
    private static class TestDoFn
    extends DoFn<String, String> {
        // Parameterized (was the raw type TupleTag) so the assignment is type-checked.
        private static final TupleTag<String> mainOutput = new TupleTag<>("mainOutput");
        // Static, so shared across all TestDoFn instances.
        // NOTE(review): presumably cleared by tests that inspect it; confirm at call sites.
        static List<String> orderOfOperations = new ArrayList<>();
        private State state = State.NOT_SET_UP;

        private TestDoFn() {
        }

        /** Requires a fresh instance, moves to SET_UP and records "setUp". */
        @DoFn.Setup
        public void setUp() {
            Preconditions.checkState(State.NOT_SET_UP.equals(state), "Unexpected state: %s", state);
            state = State.SET_UP;
            orderOfOperations.add("setUp");
        }

        /** Rejects a second teardown, moves to TEAR_DOWN and records "tearDown". */
        @DoFn.Teardown
        public void tearDown() {
            Preconditions.checkState(!State.TEAR_DOWN.equals(state), "Unexpected state: %s", state);
            state = State.TEAR_DOWN;
            orderOfOperations.add("tearDown");
        }

        /** Moves to START_BUNDLE unconditionally and records "startBundle". */
        @DoFn.StartBundle
        public void startBundle() {
            state = State.START_BUNDLE;
            orderOfOperations.add("startBundle");
        }

        /**
         * Only checks that a bundle has been started; elements are not recorded or emitted.
         *
         * <p>The context is declared as the parameterized {@code DoFn<String, String>.ProcessContext}
         * rather than the raw {@code DoFn.ProcessContext}, matching this DoFn's type arguments.
         */
        @DoFn.ProcessElement
        public void processElement(DoFn<String, String>.ProcessContext context, BoundedWindow window) {
            Preconditions.checkState(State.START_BUNDLE.equals(state), "Unexpected state: %s", state);
        }

        /** Requires an open bundle, moves to FINISH_BUNDLE and records "finishBundle". */
        @DoFn.FinishBundle
        public void finishBundle(DoFn<String, String>.FinishBundleContext context) {
            Preconditions.checkState(State.START_BUNDLE.equals(state), "Unexpected state: %s", state);
            state = State.FINISH_BUNDLE;
            orderOfOperations.add("finishBundle");
        }

        /** Lifecycle states tracked by this test DoFn. */
        private enum State {
            NOT_SET_UP,
            SET_UP,
            START_BUNDLE,
            FINISH_BUNDLE,
            TEAR_DOWN
        }
    }
}

