/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.interop.tensorflow.sequence;

import com.google.protobuf.Any;
import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Message;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.proto.framework.GraphDef;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.interop.tensorflow.TensorFlowUtil;
import org.tribuo.interop.tensorflow.TensorMap;
import org.tribuo.interop.tensorflow.protos.SequenceFeatureConverterProto;
import org.tribuo.interop.tensorflow.protos.SequenceOutputConverterProto;
import org.tribuo.interop.tensorflow.protos.TensorFlowSequenceModelProto;
import org.tribuo.interop.tensorflow.protos.TensorTupleProto;
import org.tribuo.interop.tensorflow.sequence.SequenceFeatureConverter;
import org.tribuo.interop.tensorflow.sequence.SequenceOutputConverter;
import org.tribuo.protos.ProtoUtil;
import org.tribuo.protos.core.ModelDataProto;
import org.tribuo.protos.core.SequenceModelProto;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.sequence.SequenceExample;
import org.tribuo.sequence.SequenceModel;

/**
 * A {@link SequenceModel} which scores sequences by running a TensorFlow graph.
 * <p>
 * The model owns a TensorFlow {@link Graph} and {@link Session}; both are released by
 * {@link #close()}, so instances should be closed when no longer needed.
 * <p>
 * The graph and session are transient. Java serialization writes the graph definition bytes
 * and the marshalled variable tensors after the default fields, and rebuilds the graph and
 * session from them on read.
 *
 * @param <T> The output type.
 */
public class TensorFlowSequenceModel<T extends Output<T>> extends SequenceModel<T> implements AutoCloseable {
    private static final long serialVersionUID = 200L;

    /**
     * The protobuf serialization version written by {@link #serialize()} and the highest
     * version accepted by {@link #deserializeFromProto(int, String, Any)}.
     */
    public static final int CURRENT_VERSION = 0;

    /** The TensorFlow graph, rebuilt on deserialization. */
    private transient Graph modelGraph;

    /** The TensorFlow session running {@link #modelGraph}, rebuilt on deserialization. */
    private transient Session session;

    /** Encodes a sequence example into the tensors fed to the session. */
    protected final SequenceFeatureConverter featureConverter;

    /** Decodes the fetched output tensor into predictions. */
    protected final SequenceOutputConverter<T> outputConverter;

    /** The name of the graph operation fetched by {@link #predict(SequenceExample)}. */
    protected final String predictOp;

    /**
     * Builds a model from a graph definition and a map of marshalled variables.
     * <p>
     * If the graph cannot be imported or the variables cannot be restored, any TensorFlow
     * resources opened so far are closed before the exception propagates.
     *
     * @param name The model name.
     * @param description The model provenance.
     * @param featureIDMap The feature domain.
     * @param outputIDMap The output domain.
     * @param graphDef The TensorFlow graph definition to import.
     * @param featureConverter The converter used to encode examples.
     * @param outputConverter The converter used to decode the output tensor.
     * @param predictOp The name of the operation to fetch for predictions.
     * @param tensorMap The marshalled variable values to restore into the session.
     */
    TensorFlowSequenceModel(String name, ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDMap, GraphDef graphDef, SequenceFeatureConverter featureConverter, SequenceOutputConverter<T> outputConverter, String predictOp, Map<String, TensorFlowUtil.TensorTuple> tensorMap) {
        super(name, description, featureIDMap, outputIDMap);
        this.featureConverter = featureConverter;
        this.outputConverter = outputConverter;
        this.predictOp = predictOp;
        this.openSession(graphDef, tensorMap);
    }

    /**
     * Creates the graph and session from the supplied definition and variable values, and
     * stores them in {@link #modelGraph} and {@link #session}.
     * <p>
     * The fields are only assigned once every step has succeeded; on failure the partially
     * constructed session and graph are closed so they do not leak.
     *
     * @param graphDef The graph definition to import.
     * @param tensorMap The marshalled variable values to restore.
     */
    private void openSession(GraphDef graphDef, Map<String, TensorFlowUtil.TensorTuple> tensorMap) {
        Graph graph = new Graph();
        Session newSession = null;
        try {
            graph.importGraphDef(graphDef);
            newSession = new Session(graph);
            TensorFlowUtil.restoreMarshalledVariables(newSession, tensorMap);
        } catch (RuntimeException | Error e) {
            // Close in the same order as close(): session first, then the graph it runs on.
            if (newSession != null) {
                try {
                    newSession.close();
                } catch (RuntimeException closeFailure) {
                    e.addSuppressed(closeFailure);
                }
            }
            try {
                graph.close();
            } catch (RuntimeException closeFailure) {
                e.addSuppressed(closeFailure);
            }
            throw e;
        }
        this.modelGraph = graph;
        this.session = newSession;
    }

    /**
     * Deserialization factory.
     *
     * @param version The serialized object version.
     * @param className The class name (not used by this method).
     * @param message The serialized data, expected to wrap a {@link TensorFlowSequenceModelProto}.
     * @return The deserialized model.
     * @throws InvalidProtocolBufferException If the message or the embedded graph definition cannot be parsed.
     * @throws IllegalArgumentException If the version is negative or greater than {@link #CURRENT_VERSION}.
     * @throws IllegalStateException If the output domain has no output with id 0, or that output's
     *                               class does not match the output converter's type witness.
     */
    // The output type is not known statically here; it is validated at runtime against the
    // converter's type witness before the raw constructor call.
    @SuppressWarnings({"rawtypes", "unchecked"})
    public static TensorFlowSequenceModel<?> deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
        }
        TensorFlowSequenceModelProto proto = (TensorFlowSequenceModelProto) message.unpack(TensorFlowSequenceModelProto.class);
        SequenceOutputConverter outputConverter = (SequenceOutputConverter) ProtoUtil.deserialize((Message) proto.getOutputConverter());
        SequenceFeatureConverter featureConverter = (SequenceFeatureConverter) ProtoUtil.deserialize((Message) proto.getFeatureConverter());
        ModelDataCarrier carrier = ModelDataCarrier.deserialize((ModelDataProto) proto.getMetadata());

        Object firstOutput = carrier.outputDomain().getOutput(0);
        if (firstOutput == null) {
            throw new IllegalStateException("Invalid protobuf, output domain has no output with id 0");
        }
        if (!firstOutput.getClass().equals(outputConverter.getTypeWitness())) {
            // Report the class that was actually compared, not the class of the domain container.
            throw new IllegalStateException("Invalid protobuf, output domain does not match converter, found " + firstOutput.getClass() + " and " + outputConverter.getTypeWitness());
        }

        GraphDef graphDef = GraphDef.parseFrom((ByteString) proto.getModelDef());
        Map<String, TensorFlowUtil.TensorTuple> tensorMap = new HashMap<>();
        for (Map.Entry<String, TensorTupleProto> e : proto.getTensorsMap().entrySet()) {
            tensorMap.put(e.getKey(), new TensorFlowUtil.TensorTuple(e.getValue()));
        }
        return new TensorFlowSequenceModel(carrier.name(), carrier.provenance(), carrier.featureDomain(), carrier.outputDomain(), graphDef, featureConverter, outputConverter, proto.getPredictOp(), tensorMap);
    }

    /**
     * Predicts the outputs for a sequence example.
     * <p>
     * The example is encoded by the feature converter, fed into a session runner, and the first
     * tensor fetched from {@link #predictOp} is decoded by the output converter. The feed and
     * the output tensor are closed before this method returns, whether or not decoding succeeds.
     *
     * @param example The sequence example to label.
     * @return The predictions produced by the output converter.
     */
    public List<Prediction<T>> predict(SequenceExample<T> example) {
        try (TensorMap feed = this.featureConverter.encode(example, this.featureIDMap)) {
            Session.Runner runner = this.session.runner();
            runner = feed.feedInto(runner);
            try (Tensor outputTensor = (Tensor) runner.fetch(this.predictOp).run().get(0)) {
                return this.outputConverter.decode(outputTensor, example, this.outputIDMap);
            }
        }
    }

    /**
     * Returns an empty map; this implementation does not report top features.
     *
     * @param i The number of features requested (ignored).
     * @return An empty map.
     */
    public Map<String, List<Pair<String, Double>>> getTopFeatures(int i) {
        return Collections.emptyMap();
    }

    /**
     * Closes the session and then the graph, skipping whichever is {@code null}.
     */
    @Override
    public void close() {
        if (this.session != null) {
            this.session.close();
        }
        if (this.modelGraph != null) {
            this.modelGraph.close();
        }
    }

    /**
     * Serializes this model to a protobuf.
     * <p>
     * The payload contains the model metadata, the graph definition, the marshalled variable
     * tensors extracted from the session, the prediction operation name and both converters.
     *
     * @return The serialized model, tagged with this class name and {@link #CURRENT_VERSION}.
     */
    public SequenceModelProto serialize() {
        ModelDataCarrier carrier = this.createDataCarrier();
        Map<String, TensorTupleProto> tensors = new HashMap<>();
        for (Map.Entry<String, TensorFlowUtil.TensorTuple> e : TensorFlowUtil.extractMarshalledVariables(this.modelGraph, this.session).entrySet()) {
            tensors.put(e.getKey(), e.getValue().serialize());
        }

        TensorFlowSequenceModelProto.Builder modelBuilder = TensorFlowSequenceModelProto.newBuilder();
        modelBuilder.setMetadata(carrier.serialize());
        // toByteString() yields the same bytes as toByteArray() without the extra array copy.
        modelBuilder.setModelDef(this.modelGraph.toGraphDef().toByteString());
        modelBuilder.putAllTensors(tensors);
        modelBuilder.setPredictOp(this.predictOp);
        modelBuilder.setOutputConverter((SequenceOutputConverterProto) this.outputConverter.serialize());
        modelBuilder.setFeatureConverter((SequenceFeatureConverterProto) this.featureConverter.serialize());

        SequenceModelProto.Builder builder = SequenceModelProto.newBuilder();
        builder.setSerializedData(Any.pack((Message) modelBuilder.build()));
        builder.setClassName(TensorFlowSequenceModel.class.getName());
        builder.setVersion(CURRENT_VERSION);
        return builder.build();
    }

    /**
     * Java serialization hook: writes the default fields, then the graph definition bytes,
     * then the map of marshalled variable tensors.
     *
     * @param out The stream to write to.
     * @throws IOException If the stream cannot be written.
     */
    private void writeObject(ObjectOutputStream out) throws IOException {
        out.defaultWriteObject();
        byte[] modelBytes = this.modelGraph.toGraphDef().toByteArray();
        out.writeObject(modelBytes);
        Map<String, TensorFlowUtil.TensorTuple> tensorMap = TensorFlowUtil.extractMarshalledVariables(this.modelGraph, this.session);
        out.writeObject(tensorMap);
    }

    /**
     * Java serialization hook: reads the default fields, then the graph definition bytes and
     * variable tensors in the order written by {@code writeObject}, and rebuilds the graph
     * and session from them.
     * <p>
     * The graph definition is parsed before any TensorFlow resource is created, so a corrupt
     * stream does not leave an open graph behind.
     * NOTE(review): as with any Java-native deserialization, this should only be used on
     * trusted streams.
     *
     * @param in The stream to read from.
     * @throws IOException If the stream cannot be read or the graph definition cannot be parsed.
     * @throws ClassNotFoundException If a class in the stream cannot be resolved.
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
        byte[] modelBytes = (byte[]) in.readObject();
        // writeObject stores this value as a Map<String, TensorTuple>.
        @SuppressWarnings("unchecked")
        Map<String, TensorFlowUtil.TensorTuple> tensorMap = (Map<String, TensorFlowUtil.TensorTuple>) in.readObject();
        GraphDef graphDef = GraphDef.parseFrom((byte[]) modelBytes);
        this.openSession(graphDef, tensorMap);
    }
}

