/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.interop.tensorflow;

import com.google.protobuf.Any;
import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Message;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.io.Closeable;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.time.OffsetDateTime;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.proto.framework.GraphDef;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.OutputFactory;
import org.tribuo.Prediction;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.interop.ExternalDatasetProvenance;
import org.tribuo.interop.ExternalModel;
import org.tribuo.interop.ExternalTrainerProvenance;
import org.tribuo.interop.tensorflow.FeatureConverter;
import org.tribuo.interop.tensorflow.OutputConverter;
import org.tribuo.interop.tensorflow.TensorMap;
import org.tribuo.interop.tensorflow.protos.FeatureConverterProto;
import org.tribuo.interop.tensorflow.protos.OutputConverterProto;
import org.tribuo.interop.tensorflow.protos.TensorFlowFrozenExternalModelProto;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.protos.ProtoUtil;
import org.tribuo.protos.core.ModelDataProto;
import org.tribuo.protos.core.ModelProto;
import org.tribuo.provenance.DatasetProvenance;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.util.Util;

/**
 * An {@link ExternalModel} which scores examples by running a TensorFlow {@link Graph}
 * through a {@link Session}.
 * <p>
 * Features are turned into a {@link TensorMap} by the supplied {@link FeatureConverter}, the
 * tensor named by {@code outputName} is fetched from the session, and the supplied
 * {@link OutputConverter} turns that tensor into Tribuo predictions.
 * <p>
 * The graph and session are {@code transient}: Java serialization writes the graph's
 * {@link GraphDef} bytes after the default fields and rebuilds both objects on read.
 * Call {@link #close()} to release the session and the graph when the model is no longer needed.
 *
 * @param <T> The output type.
 */
public final class TensorFlowFrozenExternalModel<T extends Output<T>>
extends ExternalModel<T, TensorMap, Tensor>
implements Closeable {
    private static final long serialVersionUID = 200L;

    /**
     * The protobuf serialization version written by {@link #serialize()} and the only
     * version accepted by {@link #deserializeFromProto(int, String, Any)}.
     */
    public static final int CURRENT_VERSION = 0;

    // The graph and the session are rebuilt in readObject, hence transient.
    private transient Graph model;
    private transient Session session;

    private final FeatureConverter featureConverter;
    private final OutputConverter<T> outputConverter;

    // Unused by the code in this class; NOTE(review): presumably retained only so the
    // Java-serialized form stays compatible with older versions - confirm before removing.
    @Deprecated
    private final String inputName = "";
    private final String outputName;

    /**
     * Builds a model from a feature name to external id mapping and opens a session on the graph.
     *
     * @param name The model name.
     * @param provenance The model provenance.
     * @param featureIDMap The feature domain.
     * @param outputIDInfo The output domain.
     * @param featureMapping The mapping from feature names to external ids.
     * @param model The TensorFlow graph, closed by {@link #close()}.
     * @param outputName The name of the tensor to fetch as the model output.
     * @param featureConverter Converts feature vectors into a {@link TensorMap}.
     * @param outputConverter Converts the output tensor into predictions.
     */
    private TensorFlowFrozenExternalModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, Map<String, Integer> featureMapping, Graph model, String outputName, FeatureConverter featureConverter, OutputConverter<T> outputConverter) {
        super(name, provenance, featureIDMap, outputIDInfo, outputConverter.generatesProbabilities(), featureMapping);
        this.model = model;
        this.session = new Session(model);
        this.outputName = outputName;
        this.featureConverter = featureConverter;
        this.outputConverter = outputConverter;
    }

    /**
     * Builds a model from precomputed forward and backward feature id mappings and opens a
     * session on the graph.
     *
     * @param name The model name.
     * @param provenance The model provenance.
     * @param featureIDMap The feature domain.
     * @param outputIDInfo The output domain.
     * @param featureForwardMapping The forward feature id mapping.
     * @param featureBackwardMapping The backward feature id mapping.
     * @param model The TensorFlow graph, closed by {@link #close()}.
     * @param outputName The name of the tensor to fetch as the model output.
     * @param featureConverter Converts feature vectors into a {@link TensorMap}.
     * @param outputConverter Converts the output tensor into predictions.
     */
    private TensorFlowFrozenExternalModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, int[] featureForwardMapping, int[] featureBackwardMapping, Graph model, String outputName, FeatureConverter featureConverter, OutputConverter<T> outputConverter) {
        super(name, provenance, featureIDMap, outputIDInfo, featureForwardMapping, featureBackwardMapping, outputConverter.generatesProbabilities());
        this.model = model;
        this.session = new Session(model);
        this.outputName = outputName;
        this.featureConverter = featureConverter;
        this.outputConverter = outputConverter;
    }

    /**
     * Creates a new graph and imports the supplied definition into it.
     * <p>
     * The freshly allocated graph is closed before the failure is rethrown if the import
     * does not succeed, so a bad definition does not leak the graph.
     *
     * @param graphDef The graph definition to import.
     * @return A graph containing the imported definition, owned by the caller.
     */
    private static Graph buildGraph(GraphDef graphDef) {
        Graph graph = new Graph();
        try {
            graph.importGraphDef(graphDef);
        } catch (RuntimeException | Error e) {
            graph.close();
            throw e;
        }
        return graph;
    }

    /**
     * Deserialization factory.
     *
     * @param version The serialized object version, must equal {@link #CURRENT_VERSION}.
     * @param className The class name (not used by this method).
     * @param message The serialized data, a packed {@code TensorFlowFrozenExternalModelProto}.
     * @return The deserialized model.
     * @throws InvalidProtocolBufferException If the message or the embedded graph definition cannot be parsed.
     * @throws IllegalArgumentException If the version is not supported.
     * @throws IllegalStateException If the output domain does not match the output converter, or
     *                               the feature mappings fail validation.
     */
    // Raw types are required as the output type is unknown at this point; the getClass check
    // below guards the compatibility of the output domain and the converter.
    @SuppressWarnings({"unchecked", "rawtypes"})
    public static TensorFlowFrozenExternalModel<?> deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
        }
        TensorFlowFrozenExternalModelProto proto = message.unpack(TensorFlowFrozenExternalModelProto.class);
        OutputConverter outputConverter = (OutputConverter)ProtoUtil.deserialize((Message)proto.getOutputConverter());
        FeatureConverter featureConverter = (FeatureConverter)ProtoUtil.deserialize((Message)proto.getFeatureConverter());
        ModelDataCarrier carrier = ModelDataCarrier.deserialize((ModelDataProto)proto.getMetadata());
        if (!carrier.outputDomain().getOutput(0).getClass().equals(outputConverter.getTypeWitness())) {
            throw new IllegalStateException("Invalid protobuf, output domain does not match converter, found " + carrier.outputDomain().getClass() + " and " + outputConverter.getTypeWitness());
        }
        int[] featureForwardMapping = Util.toPrimitiveInt(proto.getForwardFeatureMappingList());
        int[] featureBackwardMapping = Util.toPrimitiveInt(proto.getBackwardFeatureMappingList());
        if (!TensorFlowFrozenExternalModel.validateFeatureMapping(featureForwardMapping, featureBackwardMapping, carrier.featureDomain())) {
            throw new IllegalStateException("Invalid protobuf, external<->Tribuo feature mapping does not form a bijection");
        }
        // Parse before allocating the graph so a malformed definition allocates nothing.
        GraphDef graphDef = GraphDef.parseFrom(proto.getModelDef());
        Graph graph = buildGraph(graphDef);
        try {
            return new TensorFlowFrozenExternalModel(carrier.name(), carrier.provenance(), carrier.featureDomain(), carrier.outputDomain(), featureForwardMapping, featureBackwardMapping, graph, proto.getOutputName(), featureConverter, outputConverter);
        } catch (RuntimeException | Error e) {
            // The model never took ownership of the graph, so release it here.
            graph.close();
            throw e;
        }
    }

    /**
     * Converts a single feature vector into the tensors fed to the session.
     *
     * @param input The feature vector.
     * @return The feature converter's output for this vector.
     */
    protected TensorMap convertFeatures(SparseVector input) {
        return this.featureConverter.convert((SGDVector)input);
    }

    /**
     * Converts a batch of feature vectors into the tensors fed to the session.
     *
     * @param input The feature vectors.
     * @return The feature converter's output for this batch.
     */
    protected TensorMap convertFeaturesList(List<SparseVector> input) {
        return this.featureConverter.convert(input);
    }

    /**
     * Feeds the input into a session runner, fetches the tensor named by {@code outputName}
     * and returns the first result.
     * <p>
     * The input is closed before this method returns, whether or not the run succeeds.
     *
     * @param input The input tensors, closed by this method.
     * @return The fetched output tensor; the caller is responsible for closing it.
     */
    protected Tensor externalPrediction(TensorMap input) {
        try {
            return (Tensor)input.feedInto(this.session.runner()).fetch(this.outputName).run().get(0);
        } finally {
            input.close();
        }
    }

    /**
     * Converts the output tensor into a single prediction.
     * <p>
     * The tensor is closed before this method returns, whether or not the conversion succeeds.
     *
     * @param output The output tensor, closed by this method.
     * @param numValidFeatures The number of valid features in the example.
     * @param example The example being predicted.
     * @return The prediction produced by the output converter.
     */
    protected Prediction<T> convertOutput(Tensor output, int numValidFeatures, Example<T> example) {
        try {
            return this.outputConverter.convertToPrediction(output, this.outputIDInfo, numValidFeatures, example);
        } finally {
            output.close();
        }
    }

    /**
     * Converts the output tensor into a batch of predictions.
     * <p>
     * The tensor is closed before this method returns, whether or not the conversion succeeds.
     *
     * @param output The output tensor, closed by this method.
     * @param numValidFeatures The number of valid features in each example.
     * @param examples The examples being predicted.
     * @return The predictions produced by the output converter.
     */
    protected List<Prediction<T>> convertOutput(Tensor output, int[] numValidFeatures, List<Example<T>> examples) {
        try {
            return this.outputConverter.convertToBatchPrediction(output, this.outputIDInfo, numValidFeatures, examples);
        } finally {
            output.close();
        }
    }

    /**
     * This implementation reports no top features.
     *
     * @param n The number of features requested (ignored).
     * @return An empty map.
     */
    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
        return Collections.emptyMap();
    }

    /**
     * Copies this model by exporting the graph definition and importing it into a new graph,
     * so the copy has its own graph and session. The converters and feature mappings are
     * shared with this model.
     *
     * @param newName The name of the copy.
     * @param newProvenance The provenance of the copy.
     * @return The copied model.
     */
    protected Model<T> copy(String newName, ModelProvenance newProvenance) {
        Graph newGraph = buildGraph(this.model.toGraphDef());
        try {
            return new TensorFlowFrozenExternalModel<T>(newName, newProvenance, this.featureIDMap, this.outputIDInfo, this.featureForwardMapping, this.featureBackwardMapping, newGraph, this.outputName, this.featureConverter, this.outputConverter);
        } catch (RuntimeException | Error e) {
            // The copy never took ownership of the graph, so release it here.
            newGraph.close();
            throw e;
        }
    }

    /**
     * Closes the session and then the graph. The graph is closed even if closing the
     * session throws.
     */
    @Override
    public void close() {
        try {
            if (this.session != null) {
                this.session.close();
            }
        } finally {
            if (this.model != null) {
                this.model.close();
            }
        }
    }

    /**
     * Serializes this model to a {@link ModelProto} holding the model metadata, the graph
     * definition bytes, the output name, both feature mappings and both converters.
     *
     * @return The serialized model, tagged with this class name and {@link #CURRENT_VERSION}.
     */
    public ModelProto serialize() {
        TensorFlowFrozenExternalModelProto.Builder modelBuilder = TensorFlowFrozenExternalModelProto.newBuilder();
        modelBuilder.setMetadata(this.createDataCarrier().serialize());
        modelBuilder.setModelDef(this.model.toGraphDef().toByteString());
        modelBuilder.setOutputName(this.outputName);
        modelBuilder.addAllForwardFeatureMapping(Arrays.stream(this.featureForwardMapping).boxed().collect(Collectors.toList()));
        modelBuilder.addAllBackwardFeatureMapping(Arrays.stream(this.featureBackwardMapping).boxed().collect(Collectors.toList()));
        modelBuilder.setOutputConverter((OutputConverterProto)this.outputConverter.serialize());
        modelBuilder.setFeatureConverter((FeatureConverterProto)this.featureConverter.serialize());

        ModelProto.Builder builder = ModelProto.newBuilder();
        builder.setSerializedData(Any.pack(modelBuilder.build()));
        builder.setClassName(TensorFlowFrozenExternalModel.class.getName());
        builder.setVersion(CURRENT_VERSION);
        return builder.build();
    }

    /**
     * Creates a model by reading a serialized {@link GraphDef} from disk.
     * <p>
     * The returned model is named {@code "tf-frozen-graph"} and its provenance records the
     * URL of the file and the time of this call.
     *
     * @param factory The output factory.
     * @param featureMapping The mapping from feature names to external ids.
     * @param outputMapping The mapping from outputs to external ids.
     * @param outputName The name of the tensor to fetch as the model output.
     * @param featureConverter Converts feature vectors into a {@link TensorMap}.
     * @param outputConverter Converts the output tensor into predictions.
     * @param filename The path of the file containing the serialized graph definition.
     * @param <T> The output type.
     * @return The model wrapping the loaded graph.
     * @throws IllegalArgumentException If the file cannot be read or parsed.
     */
    public static <T extends Output<T>> TensorFlowFrozenExternalModel<T> createTensorflowModel(OutputFactory<T> factory, Map<String, Integer> featureMapping, Map<T, Integer> outputMapping, String outputName, FeatureConverter featureConverter, OutputConverter<T> outputConverter, String filename) {
        try {
            Path path = Paths.get(filename);
            // Read and parse before allocating the graph so an unreadable or malformed
            // file allocates nothing.
            GraphDef graphDef = GraphDef.parseFrom(Files.readAllBytes(path));
            Graph graph = buildGraph(graphDef);
            try {
                URL provenanceLocation = path.toUri().toURL();
                ImmutableFeatureMap featureMap = ExternalModel.createFeatureMap(featureMapping.keySet());
                ImmutableOutputInfo<T> outputInfo = ExternalModel.createOutputInfo(factory, outputMapping);
                OffsetDateTime now = OffsetDateTime.now();
                ExternalTrainerProvenance trainerProvenance = new ExternalTrainerProvenance(provenanceLocation);
                ExternalDatasetProvenance datasetProvenance = new ExternalDatasetProvenance("unknown-external-data", factory, false, featureMapping.size(), outputMapping.size());
                ModelProvenance provenance = new ModelProvenance(TensorFlowFrozenExternalModel.class.getName(), now, (DatasetProvenance)datasetProvenance, (TrainerProvenance)trainerProvenance);
                return new TensorFlowFrozenExternalModel<T>("tf-frozen-graph", provenance, featureMap, outputInfo, featureMapping, graph, outputName, featureConverter, outputConverter);
            } catch (IOException | RuntimeException | Error e) {
                // The model never took ownership of the graph, so release it here.
                graph.close();
                throw e;
            }
        }
        catch (IOException e) {
            throw new IllegalArgumentException("Unable to load model from path " + filename, e);
        }
    }

    /**
     * Writes the default serializable fields followed by the graph definition as a byte array.
     *
     * @param out The stream to write to.
     * @throws IOException If the stream cannot be written.
     */
    private void writeObject(ObjectOutputStream out) throws IOException {
        out.defaultWriteObject();
        out.writeObject(this.model.toGraphDef().toByteArray());
    }

    /**
     * Reads the default serializable fields, then rebuilds the transient graph and session
     * from the graph definition bytes written by {@code writeObject}.
     *
     * @param in The stream to read from.
     * @throws IOException If the stream cannot be read or the graph definition cannot be parsed.
     * @throws ClassNotFoundException If a serialized class cannot be found.
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
        byte[] modelBytes = (byte[])in.readObject();
        Graph graph = buildGraph(GraphDef.parseFrom(modelBytes));
        try {
            this.session = new Session(graph);
        } catch (RuntimeException | Error e) {
            // Without a session the graph would be unreachable, so release it here.
            graph.close();
            throw e;
        }
        this.model = graph;
    }
}

