/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.runners.dataflow;

import org.apache.beam.runners.dataflow.DataflowPTransformMatchers;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.coders.VarIntCoder;
import org.apache.beam.sdk.runners.AppliedPTransform;
import org.apache.beam.sdk.runners.TransformHierarchy;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.CombineFnBase;
import org.apache.beam.sdk.transforms.CombineWithContext;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.View;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionView;
import org.hamcrest.Matcher;
import org.hamcrest.MatcherAssert;
import org.hamcrest.Matchers;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

@RunWith(value=JUnit4.class)
public class DataflowPTransformMatchersTest {
    @Test
    public void combineValuesWithoutSideInputsSuccessfulMatches() {
        DataflowPTransformMatchers.CombineValuesWithoutSideInputsPTransformMatcher matcher = new DataflowPTransformMatchers.CombineValuesWithoutSideInputsPTransformMatcher();
        AppliedPTransform<?, ?, ?> groupedValues = DataflowPTransformMatchersTest.getCombineGroupedValuesFrom(DataflowPTransformMatchersTest.createCombineGroupedValuesPipeline());
        MatcherAssert.assertThat((Object)matcher.matches(groupedValues), (Matcher)Matchers.is((Object)true));
        groupedValues = DataflowPTransformMatchersTest.getCombineGroupedValuesFrom(DataflowPTransformMatchersTest.createCombinePerKeyPipeline());
        MatcherAssert.assertThat((Object)matcher.matches(groupedValues), (Matcher)Matchers.is((Object)true));
    }

    @Test
    public void combineValuesWithoutSideInputsSkipsNonmatching() {
        DataflowPTransformMatchers.CombineValuesWithoutSideInputsPTransformMatcher matcher = new DataflowPTransformMatchers.CombineValuesWithoutSideInputsPTransformMatcher();
        AppliedPTransform<?, ?, ?> groupedValues = DataflowPTransformMatchersTest.getCombineGroupedValuesFrom(DataflowPTransformMatchersTest.createCombineGroupedValuesWithSideInputsPipeline());
        MatcherAssert.assertThat((Object)matcher.matches(groupedValues), (Matcher)Matchers.is((Object)false));
        groupedValues = DataflowPTransformMatchersTest.getCombineGroupedValuesFrom(DataflowPTransformMatchersTest.createCombinePerKeyWithSideInputsPipeline());
        MatcherAssert.assertThat((Object)matcher.matches(groupedValues), (Matcher)Matchers.is((Object)false));
    }

    @Test
    public void combineValuesWithParentCheckSuccessfulMatches() {
        DataflowPTransformMatchers.CombineValuesWithParentCheckPTransformMatcher matcher = new DataflowPTransformMatchers.CombineValuesWithParentCheckPTransformMatcher();
        AppliedPTransform<?, ?, ?> groupedValues = DataflowPTransformMatchersTest.getCombineGroupedValuesFrom(DataflowPTransformMatchersTest.createCombinePerKeyPipeline());
        MatcherAssert.assertThat((Object)matcher.matches(groupedValues), (Matcher)Matchers.is((Object)true));
    }

    @Test
    public void combineValuesWithParentCheckSkipsNonmatching() {
        DataflowPTransformMatchers.CombineValuesWithParentCheckPTransformMatcher matcher = new DataflowPTransformMatchers.CombineValuesWithParentCheckPTransformMatcher();
        AppliedPTransform<?, ?, ?> groupedValues = DataflowPTransformMatchersTest.getCombineGroupedValuesFrom(DataflowPTransformMatchersTest.createCombineGroupedValuesPipeline());
        MatcherAssert.assertThat((Object)matcher.matches(groupedValues), (Matcher)Matchers.is((Object)false));
        groupedValues = DataflowPTransformMatchersTest.getCombineGroupedValuesFrom(DataflowPTransformMatchersTest.createCombineGroupedValuesWithSideInputsPipeline());
        MatcherAssert.assertThat((Object)matcher.matches(groupedValues), (Matcher)Matchers.is((Object)false));
        groupedValues = DataflowPTransformMatchersTest.getCombineGroupedValuesFrom(DataflowPTransformMatchersTest.createCombinePerKeyWithSideInputsPipeline());
        MatcherAssert.assertThat((Object)matcher.matches(groupedValues), (Matcher)Matchers.is((Object)false));
    }

    private static TestPipeline createCombinePerKeyPipeline() {
        TestPipeline pipeline = TestPipeline.create().enableAbandonedNodeEnforcement(false);
        PCollection input = ((PCollection)pipeline.apply((PTransform)Create.of((Object)KV.of((Object)"key", (Object)1), (Object[])new KV[0]))).setCoder((Coder)KvCoder.of((Coder)StringUtf8Coder.of(), (Coder)VarIntCoder.of()));
        input.apply((PTransform)Combine.perKey((CombineFnBase.GlobalCombineFn)new SumCombineFn()));
        return pipeline;
    }

    private static TestPipeline createCombinePerKeyWithSideInputsPipeline() {
        TestPipeline pipeline = TestPipeline.create().enableAbandonedNodeEnforcement(false);
        PCollection input = ((PCollection)pipeline.apply((PTransform)Create.of((Object)KV.of((Object)"key", (Object)1), (Object[])new KV[0]))).setCoder((Coder)KvCoder.of((Coder)StringUtf8Coder.of(), (Coder)VarIntCoder.of()));
        PCollection sideInput = (PCollection)pipeline.apply((PTransform)Create.of((Object)"side input", (Object[])new String[0]));
        PCollectionView sideInputView = (PCollectionView)sideInput.apply((PTransform)View.asSingleton());
        input.apply((PTransform)Combine.perKey((CombineFnBase.GlobalCombineFn)new SumCombineFnWithContext()).withSideInputs(new PCollectionView[]{sideInputView}));
        return pipeline;
    }

    private static TestPipeline createCombineGroupedValuesPipeline() {
        TestPipeline pipeline = TestPipeline.create().enableAbandonedNodeEnforcement(false);
        PCollection input = ((PCollection)pipeline.apply((PTransform)Create.of((Object)KV.of((Object)"key", (Object)1), (Object[])new KV[0]))).setCoder((Coder)KvCoder.of((Coder)StringUtf8Coder.of(), (Coder)VarIntCoder.of()));
        ((PCollection)input.apply((PTransform)GroupByKey.create())).apply((PTransform)Combine.groupedValues((CombineFnBase.GlobalCombineFn)new SumCombineFn()));
        return pipeline;
    }

    private static TestPipeline createCombineGroupedValuesWithSideInputsPipeline() {
        TestPipeline pipeline = TestPipeline.create().enableAbandonedNodeEnforcement(false);
        PCollection input = ((PCollection)pipeline.apply((PTransform)Create.of((Object)KV.of((Object)"key", (Object)1), (Object[])new KV[0]))).setCoder((Coder)KvCoder.of((Coder)StringUtf8Coder.of(), (Coder)VarIntCoder.of()));
        PCollection sideInput = (PCollection)pipeline.apply((PTransform)Create.of((Object)"side input", (Object[])new String[0]));
        PCollectionView sideInputView = (PCollectionView)sideInput.apply((PTransform)View.asSingleton());
        ((PCollection)input.apply((PTransform)GroupByKey.create())).apply((PTransform)Combine.groupedValues((CombineFnBase.GlobalCombineFn)new SumCombineFnWithContext()).withSideInputs(new PCollectionView[]{sideInputView}));
        return pipeline;
    }

    private static AppliedPTransform<?, ?, ?> getCombineGroupedValuesFrom(TestPipeline pipeline) {
        final AppliedPTransform[] transform = new AppliedPTransform[1];
        pipeline.traverseTopologically((Pipeline.PipelineVisitor)new Pipeline.PipelineVisitor.Defaults(){

            public Pipeline.PipelineVisitor.CompositeBehavior enterCompositeTransform(TransformHierarchy.Node node) {
                if (!node.isRootNode() && node.toAppliedPTransform(this.getPipeline()).getTransform().getClass().equals(Combine.GroupedValues.class)) {
                    transform[0] = node.toAppliedPTransform(this.getPipeline());
                    return Pipeline.PipelineVisitor.CompositeBehavior.DO_NOT_ENTER_TRANSFORM;
                }
                return Pipeline.PipelineVisitor.CompositeBehavior.ENTER_TRANSFORM;
            }
        });
        return transform[0];
    }

    private static class SumCombineFnWithContext
    extends CombineWithContext.CombineFnWithContext<Integer, Integer, Integer> {
        SumCombineFn delegate;

        private SumCombineFnWithContext() {
        }

        public Integer createAccumulator(CombineWithContext.Context c) {
            return this.delegate.createAccumulator();
        }

        public Integer addInput(Integer accum, Integer input, CombineWithContext.Context c) {
            return this.delegate.addInput(accum, input);
        }

        public Integer mergeAccumulators(Iterable<Integer> accumulators, CombineWithContext.Context c) {
            return this.delegate.mergeAccumulators((Iterable)accumulators);
        }

        public Integer extractOutput(Integer accumulator, CombineWithContext.Context c) {
            return this.delegate.extractOutput(accumulator);
        }
    }

    private static class SumCombineFn
    extends Combine.CombineFn<Integer, Integer, Integer> {
        private SumCombineFn() {
        }

        public Integer createAccumulator() {
            return 0;
        }

        public Integer addInput(Integer accum, Integer input) {
            return accum + input;
        }

        public Integer mergeAccumulators(Iterable<Integer> accumulators) {
            Integer sum = 0;
            for (Integer accum : accumulators) {
                sum = sum + accum;
            }
            return sum;
        }

        public Integer extractOutput(Integer accumulator) {
            return accumulator;
        }
    }
}

