/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.examples.multilanguage;

import java.util.ArrayList;
import java.util.List;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.VarLongCoder;
import org.apache.beam.sdk.extensions.python.transforms.RunInference;
import org.apache.beam.sdk.io.TextIO;
import org.apache.beam.sdk.options.Default;
import org.apache.beam.sdk.options.Description;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.options.Validation;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.transforms.Filter;
import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.SimpleFunction;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Splitter;
import org.checkerframework.checker.initialization.qual.Initialized;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.checkerframework.checker.nullness.qual.UnknownKeyFor;

public class SklearnMnistClassification {
    private @UnknownKeyFor @NonNull @Initialized String getModelLoaderScript() {
        String s = "from apache_beam.ml.inference.sklearn_inference import SklearnModelHandlerNumpy\n";
        s = s + "from apache_beam.ml.inference.base import KeyedModelHandler\n";
        s = s + "def get_model_handler(model_uri):\n";
        s = s + "  return KeyedModelHandler(SklearnModelHandlerNumpy(model_uri))\n";
        return s;
    }

    void runExample(@UnknownKeyFor @NonNull @Initialized SklearnMnistClassificationOptions options, @UnknownKeyFor @NonNull @Initialized String expansionService) {
        Schema schema = Schema.of((Schema.Field[])new Schema.Field[]{Schema.Field.of((String)"example", (Schema.FieldType)Schema.FieldType.array((Schema.FieldType)Schema.FieldType.INT64)), Schema.Field.of((String)"inference", (Schema.FieldType)Schema.FieldType.STRING)});
        Pipeline pipeline = Pipeline.create((PipelineOptions)options);
        PCollection col = (PCollection)((PCollection)((PCollection)pipeline.apply((PTransform)TextIO.read().from(options.getInput()))).apply((PTransform)Filter.by((SerializableFunction)new FilterNonRecordsFn()))).apply((PTransform)MapElements.via((SimpleFunction)new RecordsToLabeledPixelsFn()));
        ((PCollection)((PCollection)col.apply((PTransform)RunInference.ofKVs((String)this.getModelLoaderScript(), (Schema)schema, (Coder)VarLongCoder.of()).withKwarg("model_uri", (Object)options.getModelPath()).withExpansionService(expansionService))).apply((PTransform)MapElements.via((SimpleFunction)new FormatOutput()))).apply((PTransform)TextIO.write().to(options.getOutput()));
        pipeline.run().waitUntilFinish();
    }

    public static void main(@UnknownKeyFor @NonNull @Initialized String @UnknownKeyFor @NonNull @Initialized [] args) {
        SklearnMnistClassificationOptions options = (SklearnMnistClassificationOptions)PipelineOptionsFactory.fromArgs((String[])args).as(SklearnMnistClassificationOptions.class);
        SklearnMnistClassification example = new SklearnMnistClassification();
        example.runExample(options, options.getExpansionService());
    }

    public static interface SklearnMnistClassificationOptions
    extends PipelineOptions {
        @Description(value="Path to an input file that contains labels and pixels to feed into the model")
        @Default.String(value="gs://apache-beam-samples/multi-language/mnist/example_input.csv")
        public @UnknownKeyFor @NonNull @Initialized String getInput();

        public void setInput(@UnknownKeyFor @NonNull @Initialized String var1);

        @Description(value="Path for storing the output")
        @Validation.Required
        public @UnknownKeyFor @NonNull @Initialized String getOutput();

        public void setOutput(@UnknownKeyFor @NonNull @Initialized String var1);

        @Description(value="Path to a model file that contains the pickled file of a scikit-learn model trained on MNIST data")
        @Default.String(value="gs://apache-beam-samples/multi-language/mnist/example_model")
        public @UnknownKeyFor @NonNull @Initialized String getModelPath();

        public void setModelPath(@UnknownKeyFor @NonNull @Initialized String var1);

        @Description(value="URL of Python expansion service")
        @Default.String(value="")
        public @UnknownKeyFor @NonNull @Initialized String getExpansionService();

        public void setExpansionService(@UnknownKeyFor @NonNull @Initialized String var1);
    }

    static class FormatOutput
    extends SimpleFunction<KV<Long, Row>, String> {
        FormatOutput() {
        }

        public @UnknownKeyFor @NonNull @Initialized String apply(@UnknownKeyFor @NonNull @Initialized KV<@UnknownKeyFor @NonNull @Initialized Long, @UnknownKeyFor @NonNull @Initialized Row> input) {
            return input.getKey() + "," + ((Row)input.getValue()).getString("inference");
        }
    }

    static class RecordsToLabeledPixelsFn
    extends SimpleFunction<String, KV<Long, Iterable<Long>>> {
        RecordsToLabeledPixelsFn() {
        }

        public @UnknownKeyFor @NonNull @Initialized KV<@UnknownKeyFor @NonNull @Initialized Long, @UnknownKeyFor @NonNull @Initialized Iterable<@UnknownKeyFor @NonNull @Initialized Long>> apply(@UnknownKeyFor @NonNull @Initialized String input) {
            String[] data = Splitter.on((char)',').splitToList((CharSequence)input).toArray(new String[0]);
            Long label = Long.valueOf(data[0]);
            ArrayList<Long> pixels = new ArrayList<Long>();
            for (int i = 1; i < data.length; ++i) {
                pixels.add(Long.valueOf(data[i]));
            }
            return KV.of((Object)label, pixels);
        }
    }

    static class FilterNonRecordsFn
    implements SerializableFunction<String, Boolean> {
        FilterNonRecordsFn() {
        }

        public @UnknownKeyFor @NonNull @Initialized Boolean apply(@UnknownKeyFor @NonNull @Initialized String input) {
            return !input.startsWith("label");
        }
    }
}

