/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.fn.harness;

import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import org.apache.beam.fn.harness.CombineRunners;
import org.apache.beam.fn.harness.MapFnRunners;
import org.apache.beam.fn.harness.PTransformRunnerFactory;
import org.apache.beam.fn.harness.PTransformRunnerFactoryTestContext;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.BigEndianIntegerCoder;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.fn.data.FnDataReceiver;
import org.apache.beam.sdk.function.ThrowingRunnable;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.CombineFnBase;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.util.construction.ModelCoders;
import org.apache.beam.sdk.util.construction.PipelineTranslation;
import org.apache.beam.sdk.util.construction.SdkComponents;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
import org.hamcrest.MatcherAssert;
import org.hamcrest.Matchers;
import org.hamcrest.core.IsEqual;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

@RunWith(value=JUnit4.class)
public class CombineRunnersTest {
    private static final String TEST_COMBINE_ID = "combineId";
    private RunnerApi.PTransform pTransform;
    private String inputPCollectionId;
    private String outputPCollectionId;
    private RunnerApi.Pipeline pProto;

    @Before
    public void createPipeline() throws Exception {
        TestCombineFn combineFn = new TestCombineFn();
        Combine.PerKey combine = Combine.perKey((CombineFnBase.GlobalCombineFn)combineFn);
        Pipeline p = Pipeline.create();
        PCollection inputPCollection = (PCollection)p.apply((PTransform)Create.of((Object)KV.of((Object)"unused", (Object)"0"), (Object[])new KV[0]));
        inputPCollection.setCoder((Coder)KvCoder.of((Coder)StringUtf8Coder.of(), (Coder)StringUtf8Coder.of()));
        PCollection outputPCollection = (PCollection)inputPCollection.apply(TEST_COMBINE_ID, (PTransform)combine);
        outputPCollection.setCoder((Coder)KvCoder.of((Coder)StringUtf8Coder.of(), (Coder)BigEndianIntegerCoder.of()));
        SdkComponents sdkComponents = SdkComponents.create((PipelineOptions)p.getOptions());
        this.pProto = PipelineTranslation.toProto((Pipeline)p, (SdkComponents)sdkComponents);
        this.inputPCollectionId = sdkComponents.registerPCollection(inputPCollection);
        this.outputPCollectionId = sdkComponents.registerPCollection(outputPCollection);
        this.pTransform = this.pProto.getComponents().getTransformsOrThrow(TEST_COMBINE_ID);
    }

    @Test
    public void testPrecombine() throws Exception {
        PTransformRunnerFactoryTestContext context = PTransformRunnerFactoryTestContext.builder(TEST_COMBINE_ID, this.pTransform).components(RunnerApi.Components.newBuilder().putAllPcollections(this.pProto.getComponents().getPcollectionsMap()).putAllCoders(this.pProto.getComponents().getCodersMap()).putAllWindowingStrategies(this.pProto.getComponents().getWindowingStrategiesMap()).build()).build();
        ArrayDeque<WindowedValue> mainOutputValues = new ArrayDeque<WindowedValue>();
        context.addPCollectionConsumer((String)Iterables.getOnlyElement(this.pTransform.getOutputsMap().values()), mainOutputValues::add);
        new CombineRunners.PrecombineFactory().createRunnerForPTransform((PTransformRunnerFactory.Context)context);
        ((ThrowingRunnable)Iterables.getOnlyElement(context.getStartBundleFunctions())).run();
        mainOutputValues.clear();
        MatcherAssert.assertThat(context.getPCollectionConsumers().keySet(), Matchers.containsInAnyOrder(this.inputPCollectionId, this.outputPCollectionId));
        FnDataReceiver input = context.getPCollectionConsumer(this.inputPCollectionId);
        input.accept((Object)WindowedValue.valueInGlobalWindow((Object)KV.of((Object)"A", (Object)"1")));
        input.accept((Object)WindowedValue.valueInGlobalWindow((Object)KV.of((Object)"A", (Object)"2")));
        input.accept((Object)WindowedValue.valueInGlobalWindow((Object)KV.of((Object)"A", (Object)"6")));
        input.accept((Object)WindowedValue.valueInGlobalWindow((Object)KV.of((Object)"B", (Object)"2")));
        input.accept((Object)WindowedValue.valueInGlobalWindow((Object)KV.of((Object)"C", (Object)"3")));
        ((ThrowingRunnable)Iterables.getOnlyElement(context.getFinishBundleFunctions())).run();
        Integer sum = 0;
        for (WindowedValue outputValue : mainOutputValues) {
            if (!"A".equals(((KV)outputValue.getValue()).getKey())) continue;
            sum = sum + (Integer)((KV)outputValue.getValue()).getValue();
        }
        MatcherAssert.assertThat(sum, IsEqual.equalTo(9));
        mainOutputValues.removeIf(elem -> "A".equals(((KV)elem.getValue()).getKey()));
        MatcherAssert.assertThat(mainOutputValues, Matchers.containsInAnyOrder(WindowedValue.valueInGlobalWindow((Object)KV.of((Object)"B", (Object)2)), WindowedValue.valueInGlobalWindow((Object)KV.of((Object)"C", (Object)3))));
    }

    @Test
    public void testMergeAccumulators() throws Exception {
        RunnerApi.PCollection pCollection = RunnerApi.PCollection.newBuilder().setUniqueName(this.inputPCollectionId).setCoderId("coder-id").build();
        HashMap<String, RunnerApi.PCollection> pCollectionMap = new HashMap<String, RunnerApi.PCollection>(this.pProto.getComponents().getPcollectionsMap());
        pCollectionMap.put(this.inputPCollectionId, pCollection);
        HashMap<String, RunnerApi.Coder> coderMap = new HashMap<String, RunnerApi.Coder>(this.pProto.getComponents().getCodersMap());
        coderMap.put("coder-id", RunnerApi.Coder.newBuilder().setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn(ModelCoders.KV_CODER_URN).build()).addComponentCoderIds("StringUtf8Coder").addComponentCoderIds("coder-id-iterable").build());
        coderMap.put("coder-id-iterable", RunnerApi.Coder.newBuilder().setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn(ModelCoders.ITERABLE_CODER_URN).build()).addComponentCoderIds("BigEndianIntegerCoder").build());
        PTransformRunnerFactoryTestContext context = PTransformRunnerFactoryTestContext.builder(TEST_COMBINE_ID, this.pTransform).components(RunnerApi.Components.newBuilder().putAllPcollections(pCollectionMap).putAllCoders(coderMap).build()).build();
        ArrayDeque mainOutputValues = new ArrayDeque();
        context.addPCollectionConsumer((String)Iterables.getOnlyElement(this.pTransform.getOutputsMap().values()), mainOutputValues::add);
        MapFnRunners.forValueMapFnFactory(CombineRunners::createMergeAccumulatorsMapFunction).createRunnerForPTransform((PTransformRunnerFactory.Context)context);
        MatcherAssert.assertThat(context.getStartBundleFunctions(), Matchers.empty());
        MatcherAssert.assertThat(context.getFinishBundleFunctions(), Matchers.empty());
        mainOutputValues.clear();
        MatcherAssert.assertThat(context.getPCollectionConsumers().keySet(), Matchers.containsInAnyOrder(this.inputPCollectionId, this.outputPCollectionId));
        FnDataReceiver input = context.getPCollectionConsumer(this.inputPCollectionId);
        input.accept((Object)WindowedValue.valueInGlobalWindow((Object)KV.of((Object)"A", Arrays.asList(1, 2, 6))));
        input.accept((Object)WindowedValue.valueInGlobalWindow((Object)KV.of((Object)"B", Arrays.asList(2, 3))));
        input.accept((Object)WindowedValue.valueInGlobalWindow((Object)KV.of((Object)"C", Arrays.asList(5, 2))));
        MatcherAssert.assertThat(mainOutputValues, Matchers.contains(WindowedValue.valueInGlobalWindow((Object)KV.of((Object)"A", (Object)9)), WindowedValue.valueInGlobalWindow((Object)KV.of((Object)"B", (Object)5)), WindowedValue.valueInGlobalWindow((Object)KV.of((Object)"C", (Object)7))));
    }

    @Test
    public void testExtractOutputs() throws Exception {
        RunnerApi.PCollection pCollection = RunnerApi.PCollection.newBuilder().setUniqueName(this.inputPCollectionId).setCoderId("coder-id").build();
        HashMap<String, RunnerApi.PCollection> pCollectionMap = new HashMap<String, RunnerApi.PCollection>(this.pProto.getComponents().getPcollectionsMap());
        pCollectionMap.put(this.inputPCollectionId, pCollection);
        HashMap<String, RunnerApi.Coder> coderMap = new HashMap<String, RunnerApi.Coder>(this.pProto.getComponents().getCodersMap());
        coderMap.put("coder-id", RunnerApi.Coder.newBuilder().setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn(ModelCoders.KV_CODER_URN).build()).addComponentCoderIds("StringUtf8Coder").addComponentCoderIds("BigEndianIntegerCoder").build());
        PTransformRunnerFactoryTestContext context = PTransformRunnerFactoryTestContext.builder(TEST_COMBINE_ID, this.pTransform).components(RunnerApi.Components.newBuilder().putAllPcollections(pCollectionMap).putAllCoders(coderMap).build()).build();
        ArrayDeque mainOutputValues = new ArrayDeque();
        context.addPCollectionConsumer((String)Iterables.getOnlyElement(this.pTransform.getOutputsMap().values()), mainOutputValues::add);
        MapFnRunners.forValueMapFnFactory(CombineRunners::createExtractOutputsMapFunction).createRunnerForPTransform((PTransformRunnerFactory.Context)context);
        MatcherAssert.assertThat(context.getStartBundleFunctions(), Matchers.empty());
        MatcherAssert.assertThat(context.getFinishBundleFunctions(), Matchers.empty());
        mainOutputValues.clear();
        MatcherAssert.assertThat(context.getPCollectionConsumers().keySet(), Matchers.containsInAnyOrder(this.inputPCollectionId, this.outputPCollectionId));
        FnDataReceiver input = context.getPCollectionConsumer(this.inputPCollectionId);
        input.accept((Object)WindowedValue.valueInGlobalWindow((Object)KV.of((Object)"A", (Object)9)));
        input.accept((Object)WindowedValue.valueInGlobalWindow((Object)KV.of((Object)"B", (Object)5)));
        input.accept((Object)WindowedValue.valueInGlobalWindow((Object)KV.of((Object)"C", (Object)7)));
        MatcherAssert.assertThat(mainOutputValues, Matchers.contains(WindowedValue.valueInGlobalWindow((Object)KV.of((Object)"A", (Object)-9)), WindowedValue.valueInGlobalWindow((Object)KV.of((Object)"B", (Object)-5)), WindowedValue.valueInGlobalWindow((Object)KV.of((Object)"C", (Object)-7))));
    }

    @Test
    public void testConvertToAccumulators() throws Exception {
        PTransformRunnerFactoryTestContext context = PTransformRunnerFactoryTestContext.builder(TEST_COMBINE_ID, this.pTransform).components(RunnerApi.Components.newBuilder().putAllPcollections(this.pProto.getComponents().getPcollectionsMap()).putAllCoders(this.pProto.getComponents().getCodersMap()).build()).build();
        ArrayDeque mainOutputValues = new ArrayDeque();
        context.addPCollectionConsumer((String)Iterables.getOnlyElement(this.pTransform.getOutputsMap().values()), mainOutputValues::add);
        MapFnRunners.forValueMapFnFactory(CombineRunners::createConvertToAccumulatorsMapFunction).createRunnerForPTransform((PTransformRunnerFactory.Context)context);
        MatcherAssert.assertThat(context.getStartBundleFunctions(), Matchers.empty());
        MatcherAssert.assertThat(context.getFinishBundleFunctions(), Matchers.empty());
        mainOutputValues.clear();
        MatcherAssert.assertThat(context.getPCollectionConsumers().keySet(), Matchers.containsInAnyOrder(this.inputPCollectionId, this.outputPCollectionId));
        FnDataReceiver input = context.getPCollectionConsumer(this.inputPCollectionId);
        input.accept((Object)WindowedValue.valueInGlobalWindow((Object)KV.of((Object)"A", (Object)"9")));
        input.accept((Object)WindowedValue.valueInGlobalWindow((Object)KV.of((Object)"B", (Object)"5")));
        input.accept((Object)WindowedValue.valueInGlobalWindow((Object)KV.of((Object)"C", (Object)"7")));
        MatcherAssert.assertThat(mainOutputValues, Matchers.contains(WindowedValue.valueInGlobalWindow((Object)KV.of((Object)"A", (Object)9)), WindowedValue.valueInGlobalWindow((Object)KV.of((Object)"B", (Object)5)), WindowedValue.valueInGlobalWindow((Object)KV.of((Object)"C", (Object)7))));
    }

    @Test
    public void testCombineGroupedValues() throws Exception {
        RunnerApi.PCollection pCollection = RunnerApi.PCollection.newBuilder().setUniqueName(this.inputPCollectionId).setCoderId("coder-id").build();
        HashMap<String, RunnerApi.PCollection> pCollectionMap = new HashMap<String, RunnerApi.PCollection>(this.pProto.getComponents().getPcollectionsMap());
        pCollectionMap.put(this.inputPCollectionId, pCollection);
        HashMap<String, RunnerApi.Coder> coderMap = new HashMap<String, RunnerApi.Coder>(this.pProto.getComponents().getCodersMap());
        coderMap.put("coder-id", RunnerApi.Coder.newBuilder().setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn(ModelCoders.KV_CODER_URN).build()).addComponentCoderIds("StringUtf8Coder").addComponentCoderIds("coder-id-iterable").build());
        coderMap.put("coder-id-iterable", RunnerApi.Coder.newBuilder().setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn(ModelCoders.ITERABLE_CODER_URN).build()).addComponentCoderIds("StringUtf8Coder").build());
        PTransformRunnerFactoryTestContext context = PTransformRunnerFactoryTestContext.builder(TEST_COMBINE_ID, this.pTransform).components(RunnerApi.Components.newBuilder().putAllPcollections(pCollectionMap).putAllCoders(coderMap).build()).build();
        ArrayDeque mainOutputValues = new ArrayDeque();
        context.addPCollectionConsumer((String)Iterables.getOnlyElement(this.pTransform.getOutputsMap().values()), mainOutputValues::add);
        MapFnRunners.forValueMapFnFactory(CombineRunners::createCombineGroupedValuesMapFunction).createRunnerForPTransform((PTransformRunnerFactory.Context)context);
        MatcherAssert.assertThat(context.getStartBundleFunctions(), Matchers.empty());
        MatcherAssert.assertThat(context.getFinishBundleFunctions(), Matchers.empty());
        mainOutputValues.clear();
        MatcherAssert.assertThat(context.getPCollectionConsumers().keySet(), Matchers.containsInAnyOrder(this.inputPCollectionId, this.outputPCollectionId));
        FnDataReceiver input = context.getPCollectionConsumer(this.inputPCollectionId);
        input.accept((Object)WindowedValue.valueInGlobalWindow((Object)KV.of((Object)"A", Arrays.asList("1", "2", "6"))));
        input.accept((Object)WindowedValue.valueInGlobalWindow((Object)KV.of((Object)"B", Arrays.asList("2", "3"))));
        input.accept((Object)WindowedValue.valueInGlobalWindow((Object)KV.of((Object)"C", Arrays.asList("5", "2"))));
        MatcherAssert.assertThat(mainOutputValues, Matchers.contains(WindowedValue.valueInGlobalWindow((Object)KV.of((Object)"A", (Object)-9)), WindowedValue.valueInGlobalWindow((Object)KV.of((Object)"B", (Object)-5)), WindowedValue.valueInGlobalWindow((Object)KV.of((Object)"C", (Object)-7))));
    }

    private static class TestCombineFn
    extends Combine.CombineFn<String, Integer, Integer> {
        private TestCombineFn() {
        }

        public Integer createAccumulator() {
            return 0;
        }

        public Integer addInput(Integer accum, String input) {
            accum = accum + Integer.parseInt(input);
            return accum;
        }

        public Integer mergeAccumulators(Iterable<Integer> accums) {
            Integer merged = 0;
            for (Integer accum : accums) {
                merged = merged + accum;
            }
            return merged;
        }

        public Integer extractOutput(Integer accum) {
            return -accum.intValue();
        }
    }
}

