/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.sdk.expansion.service;

import com.google.auto.service.AutoService;
import java.util.ArrayList;
import java.util.List;
import org.apache.beam.model.expansion.v1.ExpansionApi;
import org.apache.beam.model.pipeline.v1.ExternalTransforms;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderException;
import org.apache.beam.sdk.expansion.service.ExpansionService;
import org.apache.beam.sdk.expansion.service.ExpansionServiceSchemaTransformProvider;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.schemas.JavaFieldSchema;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.SchemaCoder;
import org.apache.beam.sdk.schemas.SchemaTranslation;
import org.apache.beam.sdk.schemas.annotations.DefaultSchema;
import org.apache.beam.sdk.schemas.annotations.SchemaCreate;
import org.apache.beam.sdk.schemas.transforms.SchemaTransform;
import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider;
import org.apache.beam.sdk.schemas.transforms.TypedSchemaTransformProvider;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.Impulse;
import org.apache.beam.sdk.transforms.InferableFunction;
import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.util.CoderUtils;
import org.apache.beam.sdk.util.construction.BeamUrns;
import org.apache.beam.sdk.util.construction.ParDoTranslation;
import org.apache.beam.sdk.util.construction.PipelineTranslation;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionRowTuple;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ByteString;
import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.InvalidProtocolBufferException;
import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ProtocolMessageEnum;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
import org.junit.Assert;
import org.junit.Test;

public class ExpansionServiceSchemaTransformProviderTest {
    /** Unique name assigned to the external transform in the expansion requests built below. */
    private static final String TEST_NAME = "TestName";

    /** Namespace set on the expansion requests built below. */
    private static final String TEST_NAMESPACE = "namespace";

    /** Configuration schema listing the two INT32 fields first, then the two STRING fields. */
    private static final Schema TEST_SCHEMATRANSFORM_CONFIG_SCHEMA =
        Schema.of(
            Schema.Field.of("int1", Schema.FieldType.INT32),
            Schema.Field.of("int2", Schema.FieldType.INT32),
            Schema.Field.of("str1", Schema.FieldType.STRING),
            Schema.Field.of("str2", Schema.FieldType.STRING));

    /** Same four fields as {@link #TEST_SCHEMATRANSFORM_CONFIG_SCHEMA}, declared in reverse order. */
    private static final Schema TEST_SCHEMATRANSFORM_EQUIVALENT_CONFIG_SCHEMA =
        Schema.of(
            Schema.Field.of("str2", Schema.FieldType.STRING),
            Schema.Field.of("str1", Schema.FieldType.STRING),
            Schema.Field.of("int2", Schema.FieldType.INT32),
            Schema.Field.of("int1", Schema.FieldType.INT32));

    /** Service under test; a fresh instance is created together with the test class instance. */
    private final ExpansionService expansionService = new ExpansionService();

    /**
     * Sends an empty discovery request to the expansion service and checks that the response
     * advertises at least two SchemaTransform configurations.
     */
    @Test
    public void testSchemaTransformDiscovery() {
        ExpansionApi.DiscoverSchemaTransformResponse discovered =
            this.expansionService.discover(
                ExpansionApi.DiscoverSchemaTransformRequest.newBuilder().build());
        int advertisedConfigs = discovered.getSchemaTransformConfigsCount();
        Assert.assertTrue(advertisedConfigs >= 2);
    }

    /**
     * Scans every transform in the expansion response for ParDo transforms whose DoFn is a
     * {@link TestDoFn}, checks that each one carries the expected configuration values, and
     * asserts how many were found.
     *
     * @param response expansion response whose components are inspected
     * @param count expected number of {@link TestDoFn}-backed ParDo transforms
     * @throws RuntimeException if a ParDo payload cannot be parsed
     */
    private void verifyLeafTransforms(ExpansionApi.ExpansionResponse response, int count) {
        int matched = 0;
        for (RunnerApi.PTransform candidate : response.getComponents().getTransformsMap().values()) {
            RunnerApi.FunctionSpec spec = candidate.getSpec();
            if ("beam:transform:pardo:v1".equals(spec.getUrn())) {
                DoFn<?, ?> fn;
                try {
                    fn = ParDoTranslation.getDoFn(RunnerApi.ParDoPayload.parseFrom(spec.getPayload()));
                }
                catch (InvalidProtocolBufferException parseFailure) {
                    throw new RuntimeException(parseFailure);
                }
                // Only ParDos built from TestDoFn count; any other DoFn is ignored.
                if (fn instanceof TestDoFn) {
                    TestDoFn configured = (TestDoFn)fn;
                    Assert.assertEquals("aaa", configured.str1);
                    Assert.assertEquals("bbb", configured.str2);
                    Assert.assertEquals(111, configured.int1);
                    Assert.assertEquals(222, configured.int2);
                    matched++;
                }
            }
        }
        Assert.assertEquals(count, matched);
    }

    /**
     * Expands the single-input {@code dummy_id} SchemaTransform on top of a one-Impulse pipeline
     * and checks the shape of the expanded transform: three subtransforms, one input, one output,
     * and exactly one {@link TestDoFn} leaf carrying the configured values.
     */
    @Test
    public void testSchemaTransformExpansion() {
        Pipeline pipeline = Pipeline.create();
        pipeline.apply(Impulse.create());
        RunnerApi.Components components = PipelineTranslation.toProto(pipeline).getComponents();

        // The pipeline holds exactly one transform (the Impulse) with exactly one output.
        RunnerApi.PTransform impulse = Iterables.getOnlyElement(components.getTransformsMap().values());
        String inputPcollId = Iterables.getOnlyElement(impulse.getOutputsMap().values());

        Row configRow =
            Row.withSchema(TEST_SCHEMATRANSFORM_CONFIG_SCHEMA)
                .withFieldValue("int1", 111)
                .withFieldValue("int2", 222)
                .withFieldValue("str1", "aaa")
                .withFieldValue("str2", "bbb")
                .build();

        RunnerApi.PTransform.Builder externalTransform =
            RunnerApi.PTransform.newBuilder()
                .setUniqueName(TEST_NAME)
                .setSpec(this.createSpec("dummy_id", configRow))
                .putInputs("input1", inputPcollId);
        ExpansionApi.ExpansionRequest request =
            ExpansionApi.ExpansionRequest.newBuilder()
                .setComponents(components)
                .setTransform(externalTransform)
                .setNamespace(TEST_NAMESPACE)
                .build();

        ExpansionApi.ExpansionResponse response = this.expansionService.expand(request);
        RunnerApi.PTransform expanded = response.getTransform();
        Assert.assertEquals(3, expanded.getSubtransformsCount());
        Assert.assertEquals(1, expanded.getInputsCount());
        Assert.assertEquals(1, expanded.getOutputsCount());
        this.verifyLeafTransforms(response, 1);
    }

    /**
     * Expands the {@code dummy_id_multi_input_multi_output} SchemaTransform on top of a pipeline
     * with two Impulse roots and checks the shape of the expanded transform: six subtransforms,
     * two inputs, two outputs, and two {@link TestDoFn} leaves carrying the configured values.
     */
    @Test
    public void testSchemaTransformExpansionMultiInputMultiOutput() {
        Pipeline pipeline = Pipeline.create();
        pipeline.apply(Impulse.create());
        pipeline.apply(Impulse.create());
        RunnerApi.Components components = PipelineTranslation.toProto(pipeline).getComponents();

        // Collect the single output PCollection id of every transform in the pipeline.
        List<String> inputPcollIds = new ArrayList<>();
        for (RunnerApi.PTransform root : components.getTransformsMap().values()) {
            inputPcollIds.add(Iterables.getOnlyElement(root.getOutputsMap().values()));
        }
        Assert.assertEquals(2, inputPcollIds.size());

        Row configRow =
            Row.withSchema(TEST_SCHEMATRANSFORM_CONFIG_SCHEMA)
                .withFieldValue("int1", 111)
                .withFieldValue("int2", 222)
                .withFieldValue("str1", "aaa")
                .withFieldValue("str2", "bbb")
                .build();

        RunnerApi.PTransform.Builder externalTransform =
            RunnerApi.PTransform.newBuilder()
                .setUniqueName(TEST_NAME)
                .setSpec(this.createSpec("dummy_id_multi_input_multi_output", configRow))
                .putInputs("input1", inputPcollIds.get(0))
                .putInputs("input2", inputPcollIds.get(1));
        ExpansionApi.ExpansionRequest request =
            ExpansionApi.ExpansionRequest.newBuilder()
                .setComponents(components)
                .setTransform(externalTransform)
                .setNamespace(TEST_NAMESPACE)
                .build();

        ExpansionApi.ExpansionResponse response = this.expansionService.expand(request);
        RunnerApi.PTransform expanded = response.getTransform();
        Assert.assertEquals(6, expanded.getSubtransformsCount());
        Assert.assertEquals(2, expanded.getInputsCount());
        Assert.assertEquals(2, expanded.getOutputsCount());
        this.verifyLeafTransforms(response, 2);
    }

    /**
     * Builds two specs for {@code dummy_id} from rows that hold the same values but use schemas
     * with different field orders, checks that the encoded payloads differ, and checks that the
     * transforms produced from both specs end up with the same four configuration values.
     */
    @Test
    public void testSchematransformEquivalentConfigSchema() throws CoderException {
        Row configRow =
            Row.withSchema(TEST_SCHEMATRANSFORM_CONFIG_SCHEMA)
                .withFieldValue("int1", 111)
                .withFieldValue("int2", 222)
                .withFieldValue("str1", "aaa")
                .withFieldValue("str2", "bbb")
                .build();
        RunnerApi.FunctionSpec spec = this.createSpec("dummy_id", configRow);

        Row reorderedConfigRow =
            Row.withSchema(TEST_SCHEMATRANSFORM_EQUIVALENT_CONFIG_SCHEMA)
                .withFieldValue("int1", 111)
                .withFieldValue("int2", 222)
                .withFieldValue("str1", "aaa")
                .withFieldValue("str2", "bbb")
                .build();
        RunnerApi.FunctionSpec reorderedSpec = this.createSpec("dummy_id", reorderedConfigRow);

        // The two payloads are expected to differ on the wire.
        Assert.assertNotEquals(spec.getPayload(), reorderedSpec.getPayload());

        TestSchemaTransform fromOriginal =
            (TestSchemaTransform)ExpansionServiceSchemaTransformProvider.of()
                .getTransform(spec, PipelineOptionsFactory.create());
        TestSchemaTransform fromReordered =
            (TestSchemaTransform)ExpansionServiceSchemaTransformProvider.of()
                .getTransform(reorderedSpec, PipelineOptionsFactory.create());

        Assert.assertEquals(fromOriginal.int1, fromReordered.int1);
        Assert.assertEquals(fromOriginal.int2, fromReordered.int2);
        Assert.assertEquals(fromOriginal.str1, fromReordered.str1);
        Assert.assertEquals(fromOriginal.str2, fromReordered.str2);
    }

    /**
     * Wraps a configuration row in a SchemaTransform function spec.
     *
     * <p>The row is encoded with a {@link SchemaCoder} built from the row's own schema, and both
     * the encoded row and the schema proto are stored in a {@code SchemaTransformPayload}.
     *
     * @param identifier SchemaTransform identifier written into the payload
     * @param configRow configuration row to encode
     * @return a function spec with the SCHEMA_TRANSFORM expansion-method URN and the payload
     * @throws RuntimeException if the row cannot be encoded
     */
    private RunnerApi.FunctionSpec createSpec(String identifier, Row configRow) {
        Schema configSchema = configRow.getSchema();

        ByteString encodedConfig;
        try {
            encodedConfig =
                ByteString.copyFrom(
                    CoderUtils.encodeToByteArray(SchemaCoder.of(configSchema), configRow));
        }
        catch (CoderException encodingFailure) {
            throw new RuntimeException(encodingFailure);
        }

        ExternalTransforms.SchemaTransformPayload payload =
            ExternalTransforms.SchemaTransformPayload.newBuilder()
                .setIdentifier(identifier)
                .setConfigurationRow(encodedConfig)
                .setConfigurationSchema(SchemaTranslation.schemaToProto(configSchema, true))
                .build();

        String urn = BeamUrns.getUrn(ExternalTransforms.ExpansionMethods.Enum.SCHEMA_TRANSFORM);
        return RunnerApi.FunctionSpec.newBuilder()
            .setUrn(urn)
            .setPayload(payload.toByteString())
            .build();
    }

    public static class TestSchemaTransformMultiInputOutput
    extends SchemaTransform {
        private String str1;
        private String str2;
        private Integer int1;
        private Integer int2;

        public TestSchemaTransformMultiInputOutput(String str1, String str2, Integer int1, Integer int2) {
            this.str1 = str1;
            this.str2 = str2;
            this.int1 = int1;
            this.int2 = int2;
        }

        public PCollectionRowTuple expand(PCollectionRowTuple input) {
            PCollection outputPC1 = ((PCollection)((PCollection)((PCollection)input.get("input1").apply((PTransform)MapElements.via((InferableFunction)new InferableFunction<Row, String>(){

                public String apply(Row input) throws Exception {
                    return input.getString("in_str");
                }
            }))).apply((PTransform)ParDo.of((DoFn)new TestDoFn(this.str1, this.str2, this.int1, this.int2)))).apply((PTransform)MapElements.via((InferableFunction)new InferableFunction<String, Row>(){

                public Row apply(String input) throws Exception {
                    return Row.withSchema((Schema)Schema.of((Schema.Field[])new Schema.Field[]{Schema.Field.of((String)"out_str", (Schema.FieldType)Schema.FieldType.STRING)})).withFieldValue("out_str", (Object)input).build();
                }
            }))).setRowSchema(Schema.of((Schema.Field[])new Schema.Field[]{Schema.Field.of((String)"out_str", (Schema.FieldType)Schema.FieldType.STRING)}));
            PCollection outputPC2 = ((PCollection)((PCollection)((PCollection)input.get("input2").apply((PTransform)MapElements.via((InferableFunction)new InferableFunction<Row, String>(){

                public String apply(Row input) throws Exception {
                    return input.getString("in_str");
                }
            }))).apply((PTransform)ParDo.of((DoFn)new TestDoFn(this.str1, this.str2, this.int1, this.int2)))).apply((PTransform)MapElements.via((InferableFunction)new InferableFunction<String, Row>(){

                public Row apply(String input) throws Exception {
                    return Row.withSchema((Schema)Schema.of((Schema.Field[])new Schema.Field[]{Schema.Field.of((String)"out_str", (Schema.FieldType)Schema.FieldType.STRING)})).withFieldValue("out_str", (Object)input).build();
                }
            }))).setRowSchema(Schema.of((Schema.Field[])new Schema.Field[]{Schema.Field.of((String)"out_str", (Schema.FieldType)Schema.FieldType.STRING)}));
            return PCollectionRowTuple.of((String)"output1", (PCollection)outputPC1, (String)"output2", (PCollection)outputPC2);
        }
    }

    @AutoService(value={SchemaTransformProvider.class})
    public static class TestSchemaTransformProviderMultiInputMultiOutput
    extends TypedSchemaTransformProvider<TestSchemaTransformConfiguration> {
        protected Class<TestSchemaTransformConfiguration> configurationClass() {
            return TestSchemaTransformConfiguration.class;
        }

        protected SchemaTransform from(TestSchemaTransformConfiguration configuration) {
            return new TestSchemaTransformMultiInputOutput(configuration.str1, configuration.str2, configuration.int1, configuration.int2);
        }

        public String identifier() {
            return "dummy_id_multi_input_multi_output";
        }

        public List<String> inputCollectionNames() {
            return ImmutableList.of((Object)"input1", (Object)"input2");
        }

        public List<String> outputCollectionNames() {
            return ImmutableList.of((Object)"output1", (Object)"output2");
        }
    }

    public static class TestSchemaTransform
    extends SchemaTransform {
        private String str1;
        private String str2;
        private Integer int1;
        private Integer int2;

        public TestSchemaTransform(String str1, String str2, Integer int1, Integer int2) {
            this.str1 = str1;
            this.str2 = str2;
            this.int1 = int1;
            this.int2 = int2;
        }

        public PCollectionRowTuple expand(PCollectionRowTuple input) {
            PCollection outputPC = ((PCollection)((PCollection)((PCollection)((PCollection)input.getAll().values().iterator().next()).apply((PTransform)MapElements.via((InferableFunction)new InferableFunction<Row, String>(){

                public String apply(Row input) throws Exception {
                    return input.getString("in_str");
                }
            }))).apply((PTransform)ParDo.of((DoFn)new TestDoFn(this.str1, this.str2, this.int1, this.int2)))).apply((PTransform)MapElements.via((InferableFunction)new InferableFunction<String, Row>(){

                public Row apply(String input) throws Exception {
                    return Row.withSchema((Schema)Schema.of((Schema.Field[])new Schema.Field[]{Schema.Field.of((String)"out_str", (Schema.FieldType)Schema.FieldType.STRING)})).withFieldValue("out_str", (Object)input).build();
                }
            }))).setRowSchema(Schema.of((Schema.Field[])new Schema.Field[]{Schema.Field.of((String)"out_str", (Schema.FieldType)Schema.FieldType.STRING)}));
            return PCollectionRowTuple.of((String)"output1", (PCollection)outputPC);
        }
    }

    /**
     * Identity DoFn over strings that also carries the four configuration values it was built
     * with. The fields are public so that {@code verifyLeafTransforms} can read them back from
     * the DoFn recovered from an expanded ParDo payload.
     */
    public static class TestDoFn
    extends DoFn<String, String> {
        public String str1;
        public String str2;
        public int int1;
        public int int2;

        /**
         * @param str1 first string configuration value
         * @param str2 second string configuration value
         * @param int1 first integer configuration value; unboxed on assignment, so must not be null
         * @param int2 second integer configuration value; unboxed on assignment, so must not be null
         */
        public TestDoFn(String str1, String str2, Integer int1, Integer int2) {
            this.str1 = str1;
            this.str2 = str2;
            this.int1 = int1;
            this.int2 = int2;
        }

        /** Emits every input element unchanged. */
        @DoFn.ProcessElement
        public void processElement(@DoFn.Element String element, DoFn.OutputReceiver<String> receiver) {
            // Pass the String directly: an (Object) cast does not type-check
            // against OutputReceiver<String>.
            receiver.output(element);
        }
    }

    @AutoService(value={SchemaTransformProvider.class})
    public static class TestSchemaTransformProvider
    extends TypedSchemaTransformProvider<TestSchemaTransformConfiguration> {
        protected Class<TestSchemaTransformConfiguration> configurationClass() {
            return TestSchemaTransformConfiguration.class;
        }

        protected SchemaTransform from(TestSchemaTransformConfiguration configuration) {
            return new TestSchemaTransform(configuration.str1, configuration.str2, configuration.int1, configuration.int2);
        }

        public String identifier() {
            return "dummy_id";
        }

        public List<String> inputCollectionNames() {
            return ImmutableList.of((Object)"input1");
        }

        public List<String> outputCollectionNames() {
            return ImmutableList.of((Object)"output1");
        }
    }

    /**
     * Configuration object shared by both test providers in this file: two strings and two
     * integers held in public final fields.
     *
     * <p>NOTE(review): the schema is presumably inferred from the public fields by
     * {@code JavaFieldSchema}, so field names/types should stay in sync with the config schemas
     * used by the tests — confirm before renaming or reordering anything here.
     */
    @DefaultSchema(value=JavaFieldSchema.class)
    public static class TestSchemaTransformConfiguration {
        public final String str1;
        public final String str2;
        public final Integer int1;
        public final Integer int2;

        /**
         * Constructor marked with {@code @SchemaCreate}; stores each argument in the field of the
         * same name.
         */
        @SchemaCreate
        public TestSchemaTransformConfiguration(String str1, String str2, Integer int1, Integer int2) {
            this.str1 = str1;
            this.str2 = str2;
            this.int1 = int1;
            this.int2 = int2;
        }
    }
}

