/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.sdk.schemas.transforms;

import com.google.auto.value.AutoValue;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.beam.sdk.annotations.Experimental;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.schemas.FieldAccessDescriptor;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.SchemaCoder;
import org.apache.beam.sdk.schemas.SchemaUtils;
import org.apache.beam.sdk.schemas.transforms.AutoValue_CoGroup_By;
import org.apache.beam.sdk.schemas.transforms.AutoValue_CoGroup_Result;
import org.apache.beam.sdk.schemas.utils.RowSelector;
import org.apache.beam.sdk.schemas.utils.SelectHelpers;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.View;
import org.apache.beam.sdk.transforms.join.CoGbkResult;
import org.apache.beam.sdk.transforms.join.CoGroupByKey;
import org.apache.beam.sdk.transforms.join.KeyedPCollectionTuple;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps;
import org.checkerframework.checker.nullness.qual.Nullable;

@Experimental(value=Experimental.Kind.SCHEMAS)
public class CoGroup {
    // Shared list that ends up holding a single null element (added by the static initializer of
    // this class). extractIterables() iterates over it in place of an empty input whose join
    // clause has optional participation, so that input contributes one null row.
    private static final List NULL_LIST = Lists.newArrayList();

    /**
     * Starts a join that uses one {@link By} clause for every input.
     *
     * @param clause the clause applied to all inputs, whatever their tag
     * @return a new {@link Impl} configured with that clause
     */
    public static Impl join(By clause) {
        JoinArguments sharedArgs = new JoinArguments(clause);
        return new Impl(sharedArgs);
    }

    /**
     * Starts a join with a clause for a single tagged input; clauses for further tags can be
     * added with {@link Impl#join(String, By)}.
     *
     * @param tag the id of the input the clause applies to
     * @param clause the join clause for that input
     * @return a new {@link Impl} holding the single tag-to-clause mapping
     */
    public static Impl join(String tag, By clause) {
        Map<String, By> perTagClauses = ImmutableMap.of(tag, clause);
        return new Impl(new JoinArguments(perTagClauses));
    }

    /**
     * Checks that per-tag join clauses line up with the tags of the input tuple.
     *
     * <p>Nothing is checked when a single clause was given for all inputs.
     *
     * @param input the tuple whose tag ids are compared
     * @param joinArgs the join configuration to validate
     * @throws IllegalArgumentException if the set of input tag ids differs from the set of tags
     *     that join clauses were registered for
     */
    static void verify(PCollectionTuple input, JoinArguments joinArgs) {
        if (joinArgs.allInputsJoinArgs != null) {
            // A global clause covers every input, so there are no per-tag entries to match.
            return;
        }
        Set<String> inputTags = input.getAll().keySet().stream()
                .map(TupleTag::getId)
                .collect(Collectors.toSet());
        Set<String> joinTags = joinArgs.joinArgsMap.keySet();
        if (inputTags.equals(joinTags)) {
            return;
        }
        throw new IllegalArgumentException("The input PCollectionTuple has tags: " + inputTags + " and the join was specified for tags " + joinTags + ". These do not match.");
    }

    static {
        // Give NULL_LIST its single null element.
        NULL_LIST.add(null);
    }

    /**
     * Transform that joins the inputs of a {@link PCollectionTuple} and emits one output row per
     * combination of matching input rows (see {@code Result.outputExpandedRows}).
     */
    public static class ExpandCrossProduct
    extends PTransform<PCollectionTuple, PCollection<Row>> {
        private final JoinArguments joinArgs;

        ExpandCrossProduct(JoinArguments joinArgs) {
            this.joinArgs = joinArgs;
        }

        /**
         * Returns a copy of this transform with a join clause added for {@code tag}.
         *
         * @param tag the id of the input the clause applies to
         * @param clause the join clause for that input
         * @throws IllegalStateException if a clause for all inputs was already configured
         */
        public ExpandCrossProduct join(String tag, By clause) {
            if (this.joinArgs.allInputsJoinArgs == null) {
                return new ExpandCrossProduct(this.joinArgs.with(tag, clause));
            }
            throw new IllegalStateException("Cannot set both a global and per-tag fields.");
        }

        /**
         * Validates the configuration, keys every input, and produces the expanded join rows.
         *
         * @param input the tagged inputs to join
         * @return the joined rows, with the expanded output schema set on the collection
         */
        @Override
        public PCollection<Row> expand(PCollectionTuple input) {
            CoGroup.verify(input, this.joinArgs);
            JoinInformation joinInformation = JoinInformation.from(
                    input,
                    tag -> this.joinArgs.getFieldAccessDescriptor(tag),
                    tag -> this.joinArgs.getSideInputSource(tag));
            Result.verifyExpandedArgs(joinInformation, this.joinArgs);
            Schema outputSchema = Result.getExpandedOutputSchema(joinInformation, this.joinArgs);
            Collection<PCollectionView<Map<Row, Iterable<Row>>>> views = joinInformation.sideInputs.values();

            PCollection<Row> expanded;
            boolean severalMainInputs = joinInformation.keyedPCollectionTuple.getKeyedCollections().size() > 1;
            if (severalMainInputs) {
                PCollection<KV<Row, CoGbkResult>> grouped =
                        joinInformation.keyedPCollectionTuple.apply("CoGroupByKey", CoGroupByKey.create());
                ConvertCoGbkResult convertFn = new ConvertCoGbkResult(
                        joinInformation, this.joinArgs, ConvertCoGbkResult.ConvertType.EXPANDED, outputSchema);
                expanded = grouped.apply(ParDo.of(convertFn).withSideInputs(views));
            } else {
                // Exactly one keyed main input: skip the CoGroupByKey and expand each keyed row
                // directly against the side inputs.
                KeyedPCollectionTuple.TaggedKeyedPCollection onlyInput =
                        Iterables.getOnlyElement(joinInformation.keyedPCollectionTuple.getKeyedCollections());
                @SuppressWarnings("unchecked")
                PCollection<KV<Row, Row>> keyedRows = onlyInput.getCollection();
                ExpandRowResult expandFn = new ExpandRowResult(joinInformation, this.joinArgs, outputSchema);
                expanded = keyedRows.apply(ParDo.of(expandFn).withSideInputs(views));
            }
            return expanded.setRowSchema(outputSchema);
        }
    }

    /**
     * Transform that co-groups the inputs of a {@link PCollectionTuple} by key. Each output row
     * holds the key in a row field (named {@code "key"} unless overridden) followed by one
     * iterable-of-rows field per input.
     */
    public static class Impl
    extends PTransform<PCollectionTuple, PCollection<Row>> {
        private final JoinArguments joinArgs;
        private final String keyFieldName;

        private Impl() {
            this(new JoinArguments(Collections.emptyMap()));
        }

        private Impl(JoinArguments joinArgs) {
            this(joinArgs, "key");
        }

        private Impl(JoinArguments joinArgs, String keyFieldName) {
            this.joinArgs = joinArgs;
            this.keyFieldName = keyFieldName;
        }

        /**
         * Returns a copy of this transform whose output key field is named {@code keyFieldName}.
         */
        public Impl withKeyField(String keyFieldName) {
            return new Impl(this.joinArgs, keyFieldName);
        }

        /**
         * Returns a copy of this transform with a join clause added for {@code tag}.
         *
         * @param tag the id of the input the clause applies to
         * @param clause the join clause for that input
         * @throws IllegalStateException if a clause for all inputs was already configured
         */
        public Impl join(String tag, By clause) {
            if (this.joinArgs.allInputsJoinArgs == null) {
                return new Impl(this.joinArgs.with(tag, clause), this.keyFieldName);
            }
            throw new IllegalStateException("Cannot set both a global and per-tag fields.");
        }

        /**
         * Switches to the cross-product form of the join, carrying over the join clauses. The
         * configured key field name is not passed on.
         */
        public ExpandCrossProduct crossProductJoin() {
            return new ExpandCrossProduct(this.joinArgs);
        }

        /**
         * Validates the configuration, keys every input, applies a CoGroupByKey and converts each
         * grouped result into one unexpanded output row.
         *
         * @param input the tagged inputs to join
         * @return the grouped rows, with the unexpanded output schema set on the collection
         */
        @Override
        public PCollection<Row> expand(PCollectionTuple input) {
            CoGroup.verify(input, this.joinArgs);
            JoinInformation joinInformation = JoinInformation.from(
                    input,
                    tag -> this.joinArgs.getFieldAccessDescriptor(tag),
                    tag -> this.joinArgs.getSideInputSource(tag));
            Collection<PCollectionView<Map<Row, Iterable<Row>>>> views = joinInformation.sideInputs.values();
            Schema outputSchema = Result.getUnexandedOutputSchema(this.keyFieldName, joinInformation);

            PCollection<KV<Row, CoGbkResult>> grouped =
                    joinInformation.keyedPCollectionTuple.apply("CoGroupByKey", CoGroupByKey.create());
            ConvertCoGbkResult convertFn = new ConvertCoGbkResult(
                    joinInformation, this.joinArgs, ConvertCoGbkResult.ConvertType.UNEXPANDED, outputSchema);
            PCollection<Row> converted = grouped.apply(ParDo.of(convertFn).withSideInputs(views));
            return converted.setRowSchema(outputSchema);
        }
    }

    /**
     * DoFn used when there is a single keyed main input: turns each keyed row into a
     * {@link Result} and emits its expanded rows.
     */
    static class ExpandRowResult
    extends DoFn<KV<Row, Row>, Row> {
        private final JoinInformation joinInformation;
        private final JoinArguments joinArgs;
        private final Schema outputSchema;

        public ExpandRowResult(JoinInformation joinInformation, JoinArguments joinArgs, Schema outputSchema) {
            this.joinInformation = joinInformation;
            this.joinArgs = joinArgs;
            this.outputSchema = outputSchema;
        }

        /**
         * Builds the join result for one keyed row and writes its expanded rows to {@code o}.
         *
         * @param element the join key paired with the row it was extracted from
         * @param c the process context, used by {@link Result} to read side inputs
         * @param o the receiver for the expanded output rows
         */
        @DoFn.ProcessElement
        public void process(@DoFn.Element KV<Row, Row> element, DoFn.ProcessContext c, DoFn.OutputReceiver<Row> o) {
            Row joinKey = element.getKey();
            Row mainRow = element.getValue();
            Result.from(this.joinInformation, this.joinArgs, joinKey, this.outputSchema, mainRow, c)
                    .outputExpandedRows(o);
        }
    }

    /**
     * DoFn that converts the output of a CoGroupByKey into rows, either one unexpanded row per
     * key or the expanded rows, depending on the {@link ConvertType} it was built with.
     */
    static class ConvertCoGbkResult
    extends DoFn<KV<Row, CoGbkResult>, Row> {
        private final JoinInformation joinInformation;
        private final JoinArguments joinArgs;
        private final Schema outputSchema;
        // Assigned once in the constructor, so it is final like the other fields.
        private final ConvertType convertType;

        public ConvertCoGbkResult(JoinInformation joinInformation, JoinArguments joinArgs, ConvertType convertType, Schema outputSchema) {
            this.joinInformation = joinInformation;
            this.joinArgs = joinArgs;
            this.outputSchema = outputSchema;
            this.convertType = convertType;
        }

        /**
         * Builds the join result for one key and emits it in the configured form.
         *
         * @param element the join key paired with its co-grouped values
         * @param c the process context, used by {@link Result} to read side inputs
         * @param o the receiver for the output row(s)
         */
        @DoFn.ProcessElement
        public void process(@DoFn.Element KV<Row, CoGbkResult> element, DoFn.ProcessContext c, DoFn.OutputReceiver<Row> o) {
            Result result = Result.from(this.joinInformation, this.joinArgs, element.getKey(), this.outputSchema, element.getValue(), c);
            if (this.convertType == ConvertType.UNEXPANDED) {
                result.outputUnexpandedRow(this.outputSchema, o);
            } else {
                result.outputExpandedRows(o);
            }
        }

        /** Output form produced by {@link ConvertCoGbkResult}. */
        static enum ConvertType {
            UNEXPANDED,
            EXPANDED;

        }
    }

    /**
     * The joined data for a single key: the key row plus, for each input in sorted-tag order, the
     * rows of that input that carry the key. Also holds the helpers that compute the output
     * schemas and emit the output rows.
     */
    @AutoValue
    public static abstract class Result {
        abstract Row getKey();

        abstract List<Iterable<Row>> getIterables();

        abstract List<String> getTags();

        abstract JoinArguments getJoinArguments();

        abstract Schema getOutputSchema();

        /**
         * Builds a result whose main-input rows are looked up in {@code coGbkResult} by keyed tag.
         */
        static Result from(JoinInformation joinInformation, JoinArguments joinArgs, Row key, Schema outputSchema, CoGbkResult coGbkResult, DoFn.ProcessContext processContext) {
            return Result.from(joinInformation, joinArgs, key, outputSchema, coGbkResult::getAll, processContext);
        }

        /**
         * Builds a result whose main-input rows are the single row {@code leftRow}; that row is
         * returned for whichever tag is not backed by a side input.
         */
        static Result from(JoinInformation joinInformation, JoinArguments joinArgs, Row key, Schema outputSchema, Row leftRow, DoFn.ProcessContext processContext) {
            return Result.from(joinInformation, joinArgs, key, outputSchema, (String t) -> Lists.newArrayList(leftRow), processContext);
        }

        /**
         * Collects the rows for {@code key} from every input, in sorted-tag order.
         *
         * @param leftSideSupplier maps a keyed tag to the rows of that main input
         * @param processContext used to read the inputs that were turned into side inputs
         * @return a result with one iterable per input; an input without rows for the key gets an
         *     empty iterable
         */
        private static Result from(JoinInformation joinInformation, JoinArguments joinArgs, Row key, Schema outputSchema, Function<String, Iterable<Row>> leftSideSupplier, DoFn.ProcessContext processContext) {
            int tagCount = joinInformation.sortedTags.size();
            List<Iterable<Row>> fields = Lists.newArrayListWithCapacity(tagCount);
            List<String> tags = Lists.newArrayListWithCapacity(tagCount);
            for (int i = 0; i < tagCount; ++i) {
                String keyedTag = joinInformation.tagToKeyedTag.get(i);
                PCollectionView<Map<Row, Iterable<Row>>> sideView = joinInformation.sideInputs.get(keyedTag);
                Iterable<Row> rows;
                if (sideView != null) {
                    // processContext is a raw DoFn.ProcessContext, so the lookup needs a cast.
                    @SuppressWarnings("unchecked")
                    Map<Row, Iterable<Row>> rowsByKey = (Map<Row, Iterable<Row>>) processContext.sideInput(sideView);
                    rows = rowsByKey.get(key);
                } else {
                    rows = leftSideSupplier.apply(keyedTag);
                }
                if (rows == null) {
                    // No rows for this key on this input: use an empty iterable instead of null.
                    rows = Collections::emptyIterator;
                }
                fields.add(rows);
                tags.add(joinInformation.sortedTags.get(i));
            }
            return new AutoValue_CoGroup_Result(key, fields, tags, joinArgs, outputSchema);
        }

        /**
         * Builds the schema of the unexpanded output: a row field for the key followed by one
         * iterable-of-rows field per input, in the iteration order of the component schemas.
         *
         * @param keyFieldName the name of the key field
         * @param joinInformation supplies the key schema and the per-input schemas
         */
        static Schema getUnexandedOutputSchema(String keyFieldName, JoinInformation joinInformation) {
            Schema.Builder schemaBuilder = Schema.builder().addRowField(keyFieldName, joinInformation.keySchema);
            for (Map.Entry<String, Schema> entry : joinInformation.componentSchemas.entrySet()) {
                schemaBuilder.addIterableField(entry.getKey(), Schema.FieldType.row(entry.getValue()));
            }
            return schemaBuilder.build();
        }

        /**
         * Emits one row made of the key followed by every per-input iterable.
         *
         * @param outputSchema the schema of the emitted row
         * @param o the receiver for the row
         */
        void outputUnexpandedRow(Schema outputSchema, DoFn.OutputReceiver<Row> o) {
            List<Object> fields = Lists.newArrayListWithCapacity(this.getIterables().size() + 1);
            fields.add(this.getKey());
            fields.addAll(this.getIterables());
            o.output(Row.withSchema(outputSchema).attachValues(fields));
        }

        /**
         * Rejects an expanded join that has a side input while every main input is optional
         * (which includes having no main input at all).
         *
         * @throws IllegalArgumentException if that combination is configured
         */
        static void verifyExpandedArgs(JoinInformation joinInformation, JoinArguments joinArgs) {
            boolean hasSideInput = false;
            boolean allMainInputsOptional = true;
            for (int i = 0; i < joinInformation.sortedTags.size(); ++i) {
                String keyedTag = joinInformation.tagToKeyedTag.get(i);
                if (joinInformation.sideInputs.get(keyedTag) != null) {
                    hasSideInput = true;
                } else if (!joinArgs.getOptionalParticipation(joinInformation.sortedTags.get(i))) {
                    allMainInputsOptional = false;
                }
            }
            Preconditions.checkArgument(!hasSideInput || !allMainInputsOptional, "Cannot perform join when all main inputs are optional and there is a side input.  consider removing the side input.");
        }

        /**
         * Builds the schema of the expanded output: one row field per input, nullable when that
         * input participates optionally.
         */
        static Schema getExpandedOutputSchema(JoinInformation joinInformation, JoinArguments joinArgs) {
            Schema.Builder joinedSchemaBuilder = Schema.builder();
            for (Map.Entry<String, Schema> entry : joinInformation.componentSchemas.entrySet()) {
                Schema.FieldType fieldType = Schema.FieldType.row(entry.getValue());
                if (joinArgs.getOptionalParticipation(entry.getKey())) {
                    fieldType = fieldType.withNullable(true);
                }
                joinedSchemaBuilder.addField(entry.getKey(), fieldType);
            }
            return joinedSchemaBuilder.build();
        }

        /**
         * Emits one row for every combination that takes one row from each input.
         *
         * @param o the receiver for the rows
         */
        void outputExpandedRows(DoFn.OutputReceiver<Row> o) {
            List<Iterable<Row>> allIterables = this.extractIterables();
            List<Row> accumulatedRows = Lists.newArrayListWithCapacity(this.getIterables().size());
            this.crossProduct(0, accumulatedRows, allIterables, o);
        }

        /**
         * Returns the per-input iterables, replacing an empty one with a single null row when the
         * input participates optionally, so that input does not empty out the cross product.
         */
        private List<Iterable<Row>> extractIterables() {
            List<Iterable<Row>> iterables = Lists.newArrayListWithCapacity(this.getIterables().size());
            for (int i = 0; i < this.getIterables().size(); ++i) {
                Iterable<Row> items = this.getIterables().get(i);
                String tag = this.getTags().get(i);
                if (!items.iterator().hasNext() && this.getJoinArguments().getOptionalParticipation(tag)) {
                    items = () -> NULL_LIST.iterator();
                }
                iterables.add(items);
            }
            return iterables;
        }

        /** Recurses over every row of the input at {@code tagIndex}; a no-op past the last input. */
        private void crossProduct(int tagIndex, List<Row> accumulatedRows, List<Iterable<Row>> iterables, DoFn.OutputReceiver<Row> o) {
            if (tagIndex >= iterables.size()) {
                return;
            }
            for (Row row : iterables.get(tagIndex)) {
                this.crossProductHelper(tagIndex, accumulatedRows, row, iterables, o);
            }
        }

        /**
         * Appends {@code newRow} to the partial combination, then either emits the finished
         * combination (last input) or descends to the next input, and finally backtracks.
         */
        private void crossProductHelper(int tagIndex, List<Row> accumulatedRows, Row newRow, List<Iterable<Row>> iterables, DoFn.OutputReceiver<Row> o) {
            boolean atBottom = tagIndex == iterables.size() - 1;
            accumulatedRows.add(newRow);
            if (atBottom) {
                // Copy the accumulator: it keeps being mutated while other combinations are built.
                Row row = Row.withSchema(this.getOutputSchema()).attachValues(Lists.newArrayList(accumulatedRows));
                o.output(row);
            } else {
                this.crossProduct(tagIndex + 1, accumulatedRows, iterables, o);
            }
            accumulatedRows.remove(accumulatedRows.size() - 1);
        }
    }

    /**
     * Everything derived from the input tuple that the join needs: the keyed main inputs, the
     * side-input views, the merged key schema, the per-input schemas, the sorted tag ids, and the
     * mapping from sorted-tag index to the generated keyed tag.
     */
    static class JoinInformation
    implements Serializable {
        private final transient KeyedPCollectionTuple<Row> keyedPCollectionTuple;
        private final Map<String, PCollectionView<Map<Row, Iterable<Row>>>> sideInputs;
        private final Schema keySchema;
        private final Map<String, Schema> componentSchemas;
        private final List<String> sortedTags;
        private final Map<Integer, String> tagToKeyedTag;

        private JoinInformation(KeyedPCollectionTuple<Row> keyedPCollectionTuple, Map<String, PCollectionView<Map<Row, Iterable<Row>>>> sideInputs, Schema keySchema, Map<String, Schema> componentSchemas, List<String> sortedTags, Map<Integer, String> tagToKeyedTag) {
            this.keyedPCollectionTuple = keyedPCollectionTuple;
            this.sideInputs = sideInputs;
            this.keySchema = keySchema;
            this.componentSchemas = componentSchemas;
            this.sortedTags = sortedTags;
            this.tagToKeyedTag = tagToKeyedTag;
        }

        /**
         * Keys every input of {@code input} by its join fields and sorts the inputs into main
         * inputs and side inputs.
         *
         * @param input the tagged inputs
         * @param getFieldAccessDescriptor maps a tag id to the join fields of that input
         * @param getIsSideInput maps a tag id to whether that input becomes a side input
         * @return the collected join information
         * @throws IllegalStateException if no join fields are returned for some input
         */
        private static JoinInformation from(PCollectionTuple input, Function<String, FieldAccessDescriptor> getFieldAccessDescriptor, Function<String, Boolean> getIsSideInput) {
            KeyedPCollectionTuple<Row> keyedPCollectionTuple = KeyedPCollectionTuple.empty(input.getPipeline());
            List<String> sortedTags = input.getAll().keySet().stream()
                    .map(TupleTag::getId)
                    .sorted()
                    .collect(Collectors.toList());
            // A TreeMap, so iteration over the component schemas is ordered by tag id.
            TreeMap<String, Schema> componentSchemas = Maps.newTreeMap();
            HashMap<String, PCollectionView<Map<Row, Iterable<Row>>>> sideInputs = Maps.newHashMap();
            HashMap<Integer, String> tagToKeyedTag = Maps.newHashMap();
            Schema keySchema = null;

            for (Map.Entry<TupleTag<?>, PCollection<?>> entry : input.getAll().entrySet()) {
                String tag = entry.getKey().getId();
                PCollection<?> pc = entry.getValue();
                Schema schema = pc.getSchema();
                componentSchemas.put(tag, schema);

                FieldAccessDescriptor fieldAccessDescriptor = getFieldAccessDescriptor.apply(tag);
                if (fieldAccessDescriptor == null) {
                    throw new IllegalStateException("No fields were set for input " + tag);
                }
                FieldAccessDescriptor resolved = fieldAccessDescriptor.resolve(schema);
                Schema currentKeySchema = SelectHelpers.getOutputSchema(schema, resolved);
                // The first input fixes the key schema; each later one is merged into it.
                if (keySchema == null) {
                    keySchema = currentKeySchema;
                } else {
                    keySchema = SchemaUtils.mergeWideningNullable(keySchema, currentKeySchema);
                }

                // Suffix the tag id with a freshly created TupleTag to form the keyed tag.
                TupleTag<?> freshTag = new TupleTag<>();
                String keyedTag = tag + "_" + freshTag;
                tagToKeyedTag.put(sortedTags.indexOf(tag), keyedTag);

                PCollection<KV<Row, Row>> keyedRows = JoinInformation.extractKey(pc, schema, keySchema, resolved, tag);
                if (getIsSideInput.apply(tag)) {
                    sideInputs.put(keyedTag, keyedRows.apply("computeSideInputView" + tag, View.<Row, Row>asMultimap()));
                } else {
                    keyedPCollectionTuple = keyedPCollectionTuple.and(keyedTag, keyedRows);
                }
            }
            return new JoinInformation(keyedPCollectionTuple, sideInputs, keySchema, componentSchemas, sortedTags, tagToKeyedTag);
        }

        /**
         * Pairs every element of {@code pCollection} with the key row selected from it.
         *
         * @param pCollection the input to key
         * @param schema the schema of the input rows
         * @param keySchema the schema used for the key coder
         * @param keyFields the fields that make up the key
         * @param tag the tag id, used to name the applied transform
         * @return the keyed collection, with a KV coder built from the key and row schemas
         */
        private static <T> PCollection<KV<Row, Row>> extractKey(PCollection<T> pCollection, final Schema schema, Schema keySchema, final FieldAccessDescriptor keyFields, String tag) {
            DoFn<T, KV<Row, Row>> keyExtractor = new DoFn<T, KV<Row, Row>>(){
                private RowSelector rowSelector = new SelectHelpers.RowSelectorContainer(schema, keyFields, true);

                @DoFn.ProcessElement
                public void process(@DoFn.Element Row row, DoFn.OutputReceiver<KV<Row, Row>> o) {
                    o.output(KV.of(this.rowSelector.select(row), row));
                }
            };
            PCollection<KV<Row, Row>> keyedRows = pCollection.apply("extractKey" + tag, ParDo.of(keyExtractor));
            return keyedRows.setCoder(KvCoder.of(SchemaCoder.of(keySchema), SchemaCoder.of(schema)));
        }
    }

    /**
     * The join clauses of a CoGroup: either one {@link By} clause shared by all inputs, or a map
     * from tag id to the clause for that input. Exactly one of the two forms is populated.
     */
    static class JoinArguments
    implements Serializable {
        private final @Nullable By allInputsJoinArgs;
        private final Map<String, By> joinArgsMap;

        /** Creates arguments that apply {@code allInputsJoinArgs} to every input. */
        JoinArguments(@Nullable By allInputsJoinArgs) {
            this.allInputsJoinArgs = allInputsJoinArgs;
            this.joinArgsMap = Collections.emptyMap();
        }

        /** Creates arguments with one clause per tag id and no shared clause. */
        JoinArguments(Map<String, By> joinArgsMap) {
            this.allInputsJoinArgs = null;
            this.joinArgsMap = joinArgsMap;
        }

        /**
         * Returns new per-tag arguments made of the existing entries plus {@code tag -> clause}.
         * Any shared clause is not carried over.
         */
        JoinArguments with(String tag, By clause) {
            return new JoinArguments(new ImmutableMap.Builder<String, By>().putAll(this.joinArgsMap).put(tag, clause).build());
        }

        /**
         * Returns the join fields for {@code tag}: those of the shared clause when there is one,
         * otherwise those of the clause registered for the tag.
         *
         * @return the join fields, or {@code null} when no clause is registered for {@code tag}
         */
        private @Nullable FieldAccessDescriptor getFieldAccessDescriptor(String tag) {
            if (this.allInputsJoinArgs != null) {
                return this.allInputsJoinArgs.getFieldAccessDescriptor();
            }
            // Return null for an unknown tag instead of dereferencing a missing map entry, so the
            // caller's null check can report which input has no join fields.
            By clause = this.joinArgsMap.get(tag);
            return clause == null ? null : clause.getFieldAccessDescriptor();
        }

        /** Returns whether the input with the given tag participates optionally in the join. */
        private boolean getOptionalParticipation(String tag) {
            return this.allInputsJoinArgs != null ? this.allInputsJoinArgs.getOptionalParticipation() : this.joinArgsMap.get(tag).getOptionalParticipation();
        }

        /** Returns whether the input with the given tag is to be read as a side input. */
        private boolean getSideInputSource(String tag) {
            return this.allInputsJoinArgs != null ? this.allInputsJoinArgs.getSideInput() : this.joinArgsMap.get(tag).getSideInput();
        }
    }

    /**
     * A join clause for one input: the fields to join on, plus two flags saying whether the input
     * participates optionally and whether it is to be read as a side input. Both flags start out
     * false.
     */
    @AutoValue
    public static abstract class By
    implements Serializable {
        abstract FieldAccessDescriptor getFieldAccessDescriptor();

        abstract boolean getOptionalParticipation();

        abstract boolean getSideInput();

        abstract Builder toBuilder();

        /** Creates a clause that joins on the fields with the given names. */
        public static By fieldNames(String ... fieldNames) {
            FieldAccessDescriptor byName = FieldAccessDescriptor.withFieldNames(fieldNames);
            return By.fieldAccessDescriptor(byName);
        }

        /** Creates a clause that joins on the fields with the given ids. */
        public static By fieldIds(Integer ... fieldIds) {
            FieldAccessDescriptor byId = FieldAccessDescriptor.withFieldIds(fieldIds);
            return By.fieldAccessDescriptor(byId);
        }

        /**
         * Creates a clause that joins on the fields selected by {@code fieldAccessDescriptor},
         * with optional participation and side input both switched off.
         */
        public static By fieldAccessDescriptor(FieldAccessDescriptor fieldAccessDescriptor) {
            return new AutoValue_CoGroup_By.Builder()
                    .setFieldAccessDescriptor(fieldAccessDescriptor)
                    .setOptionalParticipation(false)
                    .setSideInput(false)
                    .build();
        }

        /** Returns a copy of this clause with optional participation switched on. */
        public By withOptionalParticipation() {
            Builder optional = this.toBuilder().setOptionalParticipation(true);
            return optional.build();
        }

        /** Returns a copy of this clause with the side-input flag switched on. */
        public By withSideInput() {
            Builder sideInput = this.toBuilder().setSideInput(true);
            return sideInput.build();
        }

        /** AutoValue builder for {@link By}. */
        @AutoValue.Builder
        static abstract class Builder {
            Builder() {
            }

            abstract Builder setFieldAccessDescriptor(FieldAccessDescriptor var1);

            abstract Builder setOptionalParticipation(boolean var1);

            abstract Builder setSideInput(boolean var1);

            abstract By build();
        }
    }
}

