/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.sdk.io;

import com.google.auto.service.AutoService;
import java.util.Arrays;
import java.util.List;
import org.apache.beam.sdk.coders.ByteArrayCoder;
import org.apache.beam.sdk.io.Compression;
import org.apache.beam.sdk.io.TFRecordIO;
import org.apache.beam.sdk.io.TFRecordReadSchemaTransformProvider;
import org.apache.beam.sdk.io.TFRecordWriteSchemaTransformConfiguration;
import org.apache.beam.sdk.metrics.Counter;
import org.apache.beam.sdk.metrics.Metrics;
import org.apache.beam.sdk.schemas.NoSuchSchemaException;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.SchemaRegistry;
import org.apache.beam.sdk.schemas.transforms.SchemaTransform;
import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider;
import org.apache.beam.sdk.schemas.transforms.TypedSchemaTransformProvider;
import org.apache.beam.sdk.schemas.transforms.providers.ErrorHandling;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.SimpleFunction;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionRowTuple;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.TupleTagList;
import org.checkerframework.checker.initialization.qual.Initialized;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.checkerframework.checker.nullness.qual.UnknownKeyFor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@AutoService(value={SchemaTransformProvider.class})
public class TFRecordWriteSchemaTransformProvider
extends TypedSchemaTransformProvider<TFRecordWriteSchemaTransformConfiguration> {
    private static final @UnknownKeyFor @NonNull @Initialized String IDENTIFIER = "beam:schematransform:org.apache.beam:tfrecord_write:v1";
    private static final @UnknownKeyFor @NonNull @Initialized String INPUT = "input";
    private static final @UnknownKeyFor @NonNull @Initialized String OUTPUT = "output";
    private static final @UnknownKeyFor @NonNull @Initialized String ERROR = "errors";
    public static final @UnknownKeyFor @NonNull @Initialized TupleTag<@UnknownKeyFor @NonNull @Initialized byte @UnknownKeyFor @NonNull @Initialized []> OUTPUT_TAG = new TupleTag<byte[]>(){};
    public static final @UnknownKeyFor @NonNull @Initialized TupleTag<@UnknownKeyFor @NonNull @Initialized Row> ERROR_TAG = new TupleTag<Row>(){};
    private static final @UnknownKeyFor @NonNull @Initialized Logger LOG = LoggerFactory.getLogger(TFRecordWriteSchemaTransformProvider.class);

    @Override
    protected @UnknownKeyFor @NonNull @Initialized SchemaTransform from(@UnknownKeyFor @NonNull @Initialized TFRecordWriteSchemaTransformConfiguration configuration) {
        return new TFRecordWriteSchemaTransform(configuration);
    }

    @Override
    public @UnknownKeyFor @NonNull @Initialized String identifier() {
        return IDENTIFIER;
    }

    @Override
    public @UnknownKeyFor @NonNull @Initialized List<@UnknownKeyFor @NonNull @Initialized String> inputCollectionNames() {
        return Arrays.asList(INPUT);
    }

    @Override
    public @UnknownKeyFor @NonNull @Initialized List<@UnknownKeyFor @NonNull @Initialized String> outputCollectionNames() {
        return Arrays.asList(OUTPUT, ERROR);
    }

    public static @UnknownKeyFor @NonNull @Initialized SerializableFunction<@UnknownKeyFor @NonNull @Initialized Row, @UnknownKeyFor @NonNull @Initialized byte @UnknownKeyFor @NonNull @Initialized []> getRowToBytesFn(final @UnknownKeyFor @NonNull @Initialized String rowFieldName) {
        return new SimpleFunction<Row, byte[]>(){

            @Override
            public @UnknownKeyFor @NonNull @Initialized byte @UnknownKeyFor @NonNull @Initialized [] apply(@UnknownKeyFor @NonNull @Initialized Row input) {
                byte[] rawBytes = input.getBytes(rowFieldName);
                if (rawBytes == null) {
                    throw new NullPointerException();
                }
                return rawBytes;
            }
        };
    }

    public static class ErrorFn
    extends DoFn<Row, byte[]> {
        private final @UnknownKeyFor @NonNull @Initialized SerializableFunction<@UnknownKeyFor @NonNull @Initialized Row, @UnknownKeyFor @NonNull @Initialized byte @UnknownKeyFor @NonNull @Initialized []> toBytesFn;
        private final @UnknownKeyFor @NonNull @Initialized Counter errorCounter;
        private @UnknownKeyFor @NonNull @Initialized Long errorsInBundle = 0L;
        private final @UnknownKeyFor @NonNull @Initialized boolean handleErrors;
        private final @UnknownKeyFor @NonNull @Initialized Schema errorSchema;

        public ErrorFn(@UnknownKeyFor @NonNull @Initialized String name, @UnknownKeyFor @NonNull @Initialized SerializableFunction<@UnknownKeyFor @NonNull @Initialized Row, @UnknownKeyFor @NonNull @Initialized byte @UnknownKeyFor @NonNull @Initialized []> toBytesFn, @UnknownKeyFor @NonNull @Initialized Schema errorSchema, @UnknownKeyFor @NonNull @Initialized boolean handleErrors) {
            this.toBytesFn = toBytesFn;
            this.errorCounter = Metrics.counter(TFRecordReadSchemaTransformProvider.class, name);
            this.handleErrors = handleErrors;
            this.errorSchema = errorSchema;
        }

        @DoFn.ProcessElement
        public void process(@DoFn.Element @UnknownKeyFor @NonNull @Initialized Row row, @UnknownKeyFor @NonNull @Initialized DoFn.MultiOutputReceiver receiver) {
            byte[] output = null;
            try {
                output = this.toBytesFn.apply(row);
            }
            catch (Exception e) {
                if (!this.handleErrors) {
                    throw new RuntimeException(e);
                }
                this.errorsInBundle = this.errorsInBundle + 1L;
                LOG.warn("Error while processing the element", (Throwable)e);
                receiver.get(ERROR_TAG).output(ErrorHandling.errorRecord(this.errorSchema, row, (Throwable)e));
            }
            if (output != null) {
                receiver.get(OUTPUT_TAG).output(output);
            }
        }

        @DoFn.FinishBundle
        public void finish() {
            this.errorCounter.inc(this.errorsInBundle);
            this.errorsInBundle = 0L;
        }
    }

    static class TFRecordWriteSchemaTransform
    extends SchemaTransform {
        private final @UnknownKeyFor @NonNull @Initialized TFRecordWriteSchemaTransformConfiguration configuration;

        TFRecordWriteSchemaTransform(@UnknownKeyFor @NonNull @Initialized TFRecordWriteSchemaTransformConfiguration configuration) {
            this.configuration = configuration;
        }

        public @UnknownKeyFor @NonNull @Initialized Row getConfigurationRow() {
            try {
                return SchemaRegistry.createDefault().getToRowFunction(TFRecordWriteSchemaTransformConfiguration.class).apply(this.configuration).sorted().toSnakeCase();
            }
            catch (NoSuchSchemaException e) {
                throw new RuntimeException(e);
            }
        }

        @Override
        public @UnknownKeyFor @NonNull @Initialized PCollectionRowTuple expand(@UnknownKeyFor @NonNull @Initialized PCollectionRowTuple input) {
            Schema inputSchema;
            int numFields;
            String filenameSuffix;
            this.configuration.validate();
            TFRecordIO.Write writeTransform = TFRecordIO.write().withCompression(Compression.valueOf(this.configuration.getCompression()));
            writeTransform = writeTransform.to(this.configuration.getOutputPrefix());
            String shardTemplate = this.configuration.getShardTemplate();
            if (shardTemplate != null) {
                writeTransform = writeTransform.withShardNameTemplate(shardTemplate);
            }
            if ((filenameSuffix = this.configuration.getFilenameSuffix()) != null) {
                writeTransform = writeTransform.withSuffix(filenameSuffix);
            }
            writeTransform = this.configuration.getNumShards() > 0 ? writeTransform.withNumShards(this.configuration.getNumShards()) : writeTransform.withoutSharding();
            if (Boolean.TRUE.equals(this.configuration.getNoSpilling())) {
                writeTransform = writeTransform.withNoSpilling();
            }
            if ((numFields = (inputSchema = input.get(TFRecordWriteSchemaTransformProvider.INPUT).getSchema()).getFields().size()) != 1) {
                throw new IllegalArgumentException("Expecting exactly one field, found " + numFields);
            }
            if (!inputSchema.getField(0).getType().equals(Schema.FieldType.BYTES)) {
                throw new IllegalArgumentException("The input schema must have exactly one field of type byte.");
            }
            String schemaField = inputSchema.getField(0).getName() != null ? inputSchema.getField(0).getName() : "record";
            PCollection<Row> inputRows = input.get(TFRecordWriteSchemaTransformProvider.INPUT);
            SerializableFunction<Row, byte[]> rowToBytesFn = TFRecordWriteSchemaTransformProvider.getRowToBytesFn(schemaField);
            Schema errorSchema = ErrorHandling.errorSchema(inputSchema);
            boolean handleErrors = ErrorHandling.hasOutput(this.configuration.getErrorHandling());
            PCollectionTuple byteArrays = (PCollectionTuple)inputRows.apply(ParDo.of(new ErrorFn("TFRecord-write-error-counter", rowToBytesFn, errorSchema, handleErrors)).withOutputTags(OUTPUT_TAG, TupleTagList.of(ERROR_TAG)));
            byteArrays.get(OUTPUT_TAG).setCoder(ByteArrayCoder.of()).apply(writeTransform);
            String output = "";
            ErrorHandling errorHandler = this.configuration.getErrorHandling();
            if (errorHandler != null) {
                String outputHandler = errorHandler.getOutput();
                output = outputHandler != null ? outputHandler : "";
            }
            PCollection<Row> errorOutput = byteArrays.get(ERROR_TAG).setRowSchema(ErrorHandling.errorSchema(errorSchema));
            return PCollectionRowTuple.of(handleErrors ? output : TFRecordWriteSchemaTransformProvider.ERROR, errorOutput);
        }
    }
}

