/*
 * Decompiled with CFR 0.152.
 */
package ai.konduit.serving.models.nd4j.tensorflow.step;

import ai.konduit.serving.annotation.runner.CanRun;
import ai.konduit.serving.models.nd4j.tensorflow.step.Nd4jTensorFlowStep;
import ai.konduit.serving.pipeline.api.context.Context;
import ai.konduit.serving.pipeline.api.data.Data;
import ai.konduit.serving.pipeline.api.data.NDArray;
import ai.konduit.serving.pipeline.api.protocol.URIResolver;
import ai.konduit.serving.pipeline.api.step.PipelineStep;
import ai.konduit.serving.pipeline.api.step.PipelineStepRunner;
import java.io.File;
import java.util.LinkedHashMap;
import java.util.Map;
import lombok.NonNull;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.tensorflow.conversion.graphrunner.GraphRunner;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@CanRun(value={Nd4jTensorFlowStep.class})
public class Nd4jTensorFlowRunner
implements PipelineStepRunner {
    private static final Logger log = LoggerFactory.getLogger(Nd4jTensorFlowRunner.class);
    private final Nd4jTensorFlowStep step;
    private GraphRunner sess;

    public Nd4jTensorFlowRunner(@NonNull Nd4jTensorFlowStep step) {
        if (step == null) {
            throw new NullPointerException("step is marked non-null but is null");
        }
        this.step = step;
        File origFile = URIResolver.getFile((String)step.modelUri());
        Preconditions.checkState((boolean)origFile.exists(), (String)("Model file does not exist: " + step.modelUri()));
        this.sess = GraphRunner.builder().inputNames(step.inputNames()).graphPath(origFile).outputNames(step.outputNames()).build();
    }

    public void close() {
        this.sess.close();
    }

    public PipelineStep getPipelineStep() {
        return this.step;
    }

    public Data exec(Context ctx, Data data) {
        Preconditions.checkState((this.step.inputNames() != null ? 1 : 0) != 0, (String)"TensorFlowStep input array names are not set (null)");
        LinkedHashMap<String, INDArray> inputData = new LinkedHashMap<String, INDArray>();
        for (String key : data.keys()) {
            NDArray ndArray = data.getNDArray(key);
            INDArray arr = (INDArray)ndArray.getAs(INDArray.class);
            inputData.put(key, arr);
        }
        Map graphOutput = this.sess.run(inputData);
        Data out = Data.empty();
        for (Map.Entry entry : graphOutput.entrySet()) {
            out.put((String)entry.getKey(), NDArray.create(entry.getValue()));
        }
        return out;
    }
}

