/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.interop.tensorflow;

import com.oracle.labs.mlrg.olcut.util.Pair;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.logging.Logger;
import org.tensorflow.Graph;
import org.tensorflow.GraphOperation;
import org.tensorflow.Operand;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.SessionFunction;
import org.tensorflow.Signature;
import org.tensorflow.Tensor;
import org.tensorflow.proto.framework.GraphDef;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.interop.tensorflow.FeatureConverter;
import org.tribuo.interop.tensorflow.OutputConverter;
import org.tribuo.interop.tensorflow.TensorMap;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.provenance.ModelProvenance;

/**
 * Base class for Tribuo models backed by a TensorFlow {@link Graph} and {@link Session}.
 * <p>
 * Examples are converted into input tensors by a {@link FeatureConverter}. The session is run,
 * fetching the operation named by {@link #getOutputName()}. The resulting tensor is turned into
 * Tribuo {@link Prediction}s by an {@link OutputConverter}.
 * <p>
 * Instances hold TensorFlow resources and should be {@link #close() closed} when no longer needed.
 * After closing, prediction and export throw {@link IllegalStateException}.
 * @param <T> The output type.
 */
public abstract class TensorFlowModel<T extends Output<T>> extends Model<T> implements AutoCloseable {
    private static final Logger logger = Logger.getLogger(TensorFlowModel.class.getName());
    private static final long serialVersionUID = 200L;

    /** Number of examples sent to the session per run in {@link #innerPredict(Iterable)}. */
    protected int batchSize;
    /** Name of the graph operation fetched to produce the model output. */
    protected final String outputName;
    /** Converts Tribuo feature vectors into the input tensors of the graph. */
    protected final FeatureConverter featureConverter;
    /** Converts the fetched output tensor into Tribuo predictions. */
    protected final OutputConverter<T> outputConverter;
    // The TensorFlow state is transient and is not Java-serialized.
    // NOTE(review): subclasses presumably rebuild it after deserialization - confirm.
    /** The TensorFlow graph; null once closed. */
    protected transient Graph modelGraph = null;
    /** The session running {@link #modelGraph}; null once closed. */
    protected transient Session session = null;
    /** True once {@link #close()} has been called. */
    protected transient boolean closed = false;

    /**
     * Builds a model by importing the supplied graph definition into a fresh {@link Graph}
     * and opening a {@link Session} on it.
     * <p>
     * If importing the graph or creating the session fails, the newly allocated graph is closed
     * before the exception propagates.
     * <p>
     * Note: {@code batchSize} is not validated here, unlike in {@link #setBatchSize(int)}.
     * With a non-positive value, {@link #innerPredict(Iterable)} sends all examples as one batch.
     * @param name The model name.
     * @param provenance The model provenance.
     * @param featureIDMap The feature domain.
     * @param outputIDInfo The output domain.
     * @param trainedGraphDef The graph definition to import.
     * @param batchSize The number of examples per session run during batch prediction.
     * @param outputName The name of the operation to fetch as the model output.
     * @param featureConverter Converts feature vectors into input tensors.
     * @param outputConverter Converts the output tensor into predictions. It also determines
     *                        whether this model generates probabilities.
     */
    protected TensorFlowModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, GraphDef trainedGraphDef, int batchSize, String outputName, FeatureConverter featureConverter, OutputConverter<T> outputConverter) {
        super(name, provenance, featureIDMap, outputIDInfo, outputConverter.generatesProbabilities());
        Graph graph = new Graph();
        try {
            graph.importGraphDef(trainedGraphDef);
            this.session = new Session(graph);
        } catch (RuntimeException | Error e) {
            // Nothing else owns the graph yet, so release it before propagating the failure.
            graph.close();
            throw e;
        }
        this.modelGraph = graph;
        this.batchSize = batchSize;
        this.outputName = outputName;
        this.featureConverter = featureConverter;
        this.outputConverter = outputConverter;
    }

    /**
     * Predicts the output for a single example by running the session once.
     * @param example The example to predict.
     * @return The prediction produced by the output converter.
     * @throws IllegalStateException If the model has been closed.
     */
    @Override
    public Prediction<T> predict(Example<T> example) {
        ensureOpen();
        // The sparse vector is built explicitly so its active element count can be
        // passed to the output converter.
        SparseVector vec = SparseVector.createSparseVector(example, featureIDMap, false);
        // Resources close in reverse order: the output tensor first, then the input tensors.
        try (TensorMap transformedInput = featureConverter.convert(vec);
             Tensor outputTensor = transformedInput.feedInto(session.runner()).fetch(outputName).run().get(0)) {
            return outputConverter.convertToPrediction(outputTensor, outputIDInfo, vec.numActiveElements(), example);
        }
    }

    /**
     * Predicts the outputs for the supplied examples.
     * <p>
     * Examples are grouped into batches of {@link #batchSize}, and any trailing partial batch is
     * run on its own. Predictions are returned in iteration order.
     * @param examples The examples to predict.
     * @return The predictions, one per example.
     * @throws IllegalStateException If the model has been closed, even when {@code examples} is empty.
     */
    @Override
    protected List<Prediction<T>> innerPredict(Iterable<Example<T>> examples) {
        ensureOpen();
        List<Prediction<T>> predictions = new ArrayList<>();
        List<Example<T>> batch = new ArrayList<>();
        for (Example<T> example : examples) {
            batch.add(example);
            if (batch.size() == batchSize) {
                predictions.addAll(predictBatch(batch));
                batch.clear();
            }
        }
        if (!batch.isEmpty()) {
            predictions.addAll(predictBatch(batch));
        }
        return predictions;
    }

    /**
     * Runs the session once over a batch of examples.
     * @param batchExamples The examples in this batch.
     * @return The predictions for the batch, as produced by the output converter.
     * @throws IllegalStateException If the model has been closed.
     */
    private List<Prediction<T>> predictBatch(List<Example<T>> batchExamples) {
        ensureOpen();
        int size = batchExamples.size();
        List<SparseVector> vectors = new ArrayList<>(size);
        int[] numActiveElements = new int[size];
        for (int i = 0; i < size; i++) {
            SparseVector vec = SparseVector.createSparseVector(batchExamples.get(i), featureIDMap, false);
            numActiveElements[i] = vec.numActiveElements();
            vectors.add(vec);
        }
        try (TensorMap transformedInput = featureConverter.convert(vectors);
             Tensor outputTensor = transformedInput.feedInto(session.runner()).fetch(outputName).run().get(0)) {
            return outputConverter.convertToBatchPrediction(outputTensor, outputIDInfo, numActiveElements, batchExamples);
        }
    }

    /**
     * Throws if this model has been closed, because its graph and session are gone.
     */
    private void ensureOpen() {
        if (closed) {
            throw new IllegalStateException("Can't use a closed model, the state has gone.");
        }
    }

    /**
     * Gets the number of examples sent to the session per run during batch prediction.
     * @return The batch size.
     */
    public int getBatchSize() {
        return batchSize;
    }

    /**
     * Sets the number of examples sent to the session per run during batch prediction.
     * @param batchSize The new batch size. Must be positive.
     * @throws IllegalArgumentException If {@code batchSize} is not positive.
     */
    public void setBatchSize(int batchSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("Batch size must be positive, found " + batchSize);
        }
        this.batchSize = batchSize;
    }

    /**
     * Feature importances are not available from this model.
     * @param n Ignored.
     * @return An empty map.
     */
    @Override
    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
        return Collections.emptyMap();
    }

    /**
     * Excuses are not available from this model.
     * @param example Ignored.
     * @return {@link Optional#empty()}.
     */
    @Override
    public Optional<Excuse<T>> getExcuse(Example<T> example) {
        return Optional.empty();
    }

    /**
     * Gets the name of the operation fetched as the model output.
     * @return The output operation name.
     */
    public String getOutputName() {
        return outputName;
    }

    /**
     * Exports the graph and session as a TensorFlow SavedModel.
     * <p>
     * The exported signature maps each input name from the feature converter, and the output
     * name, to output 0 of the graph operation with that name.
     * @param path The directory to write the SavedModel to.
     * @throws IOException If the export fails.
     * @throws IllegalStateException If the model has been closed, or if an input or output
     *                               operation is missing from the graph.
     */
    public void exportModel(String path) throws IOException {
        if (closed) {
            throw new IllegalStateException("Can't serialize a closed model, the state has gone.");
        }
        Signature.Builder sigBuilder = Signature.builder();
        for (String inputName : featureConverter.inputNamesSet()) {
            sigBuilder.input(inputName, requireOperation(inputName).output(0));
        }
        Signature modelSig = sigBuilder.output(outputName, requireOperation(outputName).output(0)).build();
        SessionFunction function = SessionFunction.create(modelSig, session);
        SavedModelBundle.exporter(path).withFunction(function).export();
    }

    /**
     * Looks up a graph operation by name, failing with a descriptive message if it is absent.
     * @param opName The operation name.
     * @return The operation.
     * @throws IllegalStateException If the graph has no operation with that name.
     */
    private GraphOperation requireOperation(String opName) {
        GraphOperation op = modelGraph.operation(opName);
        if (op == null) {
            throw new IllegalStateException("Operation '" + opName + "' was not found in the model graph.");
        }
        return op;
    }

    /**
     * Closes the session and then the graph, and marks this model as closed.
     * <p>
     * Repeated calls are no-ops for the already-released session and graph.
     */
    @Override
    public void close() {
        if (session != null) {
            session.close();
            session = null;
        }
        if (modelGraph != null) {
            modelGraph.close();
            modelGraph = null;
        }
        closed = true;
    }
}

