/*
 * Decompiled with CFR 0.152.
 */
package com.spotify.zoltar.tf;

import com.google.protobuf.AbstractMessageLite;
import com.spotify.futures.CompletableFutures;
import com.spotify.zoltar.PredictFns;
import com.spotify.zoltar.Prediction;
import com.spotify.zoltar.Vector;
import com.spotify.zoltar.tf.TensorFlowExtras;
import com.spotify.zoltar.tf.TensorFlowModel;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.ndarray.NdArray;
import org.tensorflow.ndarray.NdArrays;
import org.tensorflow.proto.example.Example;
import org.tensorflow.types.TString;

/**
 * Asynchronous prediction function whose model type is fixed to {@link TensorFlowModel}.
 *
 * <p>The static factories build prediction functions that serialize {@link Example} protos into
 * a string tensor, feed it to the model's session under the name {@code "input_example_tensor"},
 * fetch the requested operations and turn the fetched tensors into {@link Prediction}s.
 *
 * @param <InputT> type of the raw input carried by each {@link Vector}.
 * @param <VectorT> type of the featurized value carried by each {@link Vector}.
 * @param <ValueT> type of the prediction value.
 */
@FunctionalInterface
public interface TensorFlowPredictFn<InputT, VectorT, ValueT>
        extends PredictFns.AsyncPredictFn<TensorFlowModel, InputT, VectorT, ValueT> {

    /**
     * Builds a prediction function that scores all vectors with a single session run.
     *
     * <p>Every vector's {@link Example} is serialized to bytes, the serialized examples are packed
     * into one string tensor and fed as {@code "input_example_tensor"}. The work is submitted with
     * {@link CompletableFuture#supplyAsync(java.util.function.Supplier)}, i.e. on the default
     * asynchronous executor.
     *
     * <p>Inputs and extracted values are paired positionally. If the extractor returns fewer (or
     * more) values than there are vectors, pairing stops at the shorter of the two and the
     * remainder is silently dropped.
     *
     * @param outTensorExtractor converts the fetched tensors (keyed by name) into a list of
     *     prediction values, expected to be in the same order as the input vectors.
     * @param fetchOps operations to fetch; passed through to
     *     {@link TensorFlowExtras#runAndExtract}.
     * @param <InputT> type of the raw input.
     * @param <ValueT> type of the prediction value.
     * @return a prediction function producing one {@link Prediction} per paired input/value.
     */
    static <InputT, ValueT> TensorFlowPredictFn<InputT, Example, ValueT> example(
            final Function<Map<String, Tensor<?>>, List<ValueT>> outTensorExtractor,
            final String... fetchOps) {
        return (model, vectors) -> CompletableFuture.supplyAsync(() -> {
            final byte[][] bytes = vectors.stream()
                    .map(Vector::value)
                    .map(Example::toByteArray)
                    .toArray(byte[][]::new);
            final NdArray<byte[]> examplesNdArray = NdArrays.vectorOfObjects(bytes);

            // The input tensor is released as soon as the session run and extraction are done.
            try (Tensor<TString> t = TString.tensorOfBytes(examplesNdArray)) {
                final Session.Runner runner =
                        model.instance().session().runner().feed("input_example_tensor", t);
                // NOTE(review): the tensors in `result` are not closed in this method; whether
                // runAndExtract or the extractor owns them is not visible here - confirm.
                final Map<String, Tensor<?>> result =
                        TensorFlowExtras.runAndExtract(runner, fetchOps);

                final Iterator<Vector<InputT, Example>> vectorIterator = vectors.iterator();
                final Iterator<ValueT> valueIterator =
                        outTensorExtractor.apply(result).iterator();
                final List<Prediction<InputT, ValueT>> predictions = new ArrayList<>();
                // Zip inputs with values; stops at whichever sequence ends first.
                while (vectorIterator.hasNext() && valueIterator.hasNext()) {
                    predictions.add(
                            Prediction.create(vectorIterator.next().input(), valueIterator.next()));
                }
                return predictions;
            }
        });
    }

    /**
     * Builds a prediction function in which each vector carries its own list of {@link Example}s
     * and is scored with its own session run.
     *
     * <p>For every vector a separate asynchronous task is submitted with
     * {@link CompletableFuture#supplyAsync(java.util.function.Supplier)}; the task serializes the
     * vector's examples into one string tensor, feeds it as {@code "input_example_tensor"}, fetches
     * {@code fetchOps} and applies the extractor to obtain a single value for that vector. The
     * per-vector futures are then combined into one future of the full prediction list.
     *
     * @param outTensorExtractor converts the fetched tensors (keyed by name) into the prediction
     *     value for one vector.
     * @param fetchOps operations to fetch; passed through to
     *     {@link TensorFlowExtras#runAndExtract}.
     * @param <InputT> type of the raw input.
     * @param <ValueT> type of the prediction value.
     * @return a prediction function producing one {@link Prediction} per input vector.
     * @deprecated {@link #example(Function, String...)} is the non-deprecated factory in this
     *     interface; it feeds all vectors in a single session run.
     */
    @Deprecated
    static <InputT, ValueT> TensorFlowPredictFn<InputT, List<Example>, ValueT> exampleBatch(
            final Function<Map<String, Tensor<?>>, ValueT> outTensorExtractor,
            final String... fetchOps) {
        // Scores one batch of examples with one session run.
        final BiFunction<TensorFlowModel, List<Example>, ValueT> predictFn = (model, examples) -> {
            final byte[][] bytes = examples.stream()
                    .map(Example::toByteArray)
                    .toArray(byte[][]::new);
            final NdArray<byte[]> examplesNdArray = NdArrays.vectorOfObjects(bytes);

            // The input tensor is released as soon as the session run and extraction are done.
            try (Tensor<TString> t = TString.tensorOfBytes(examplesNdArray)) {
                final Session.Runner runner =
                        model.instance().session().runner().feed("input_example_tensor", t);
                final Map<String, Tensor<?>> result =
                        TensorFlowExtras.runAndExtract(runner, fetchOps);
                return outTensorExtractor.apply(result);
            }
        };

        return (model, vectors) -> {
            // One asynchronous task per vector, each paired back with its own input.
            final List<CompletableFuture<Prediction<InputT, ValueT>>> predictions = vectors.stream()
                    .map(vector -> CompletableFuture
                            .supplyAsync(() -> predictFn.apply(model, vector.value()))
                            .thenApply(value -> Prediction.create(vector.input(), value)))
                    .collect(Collectors.toList());
            return CompletableFutures.allAsList(predictions);
        };
    }
}

