/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.interop.tensorflow;

import com.google.protobuf.ByteString;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.Collectors;
import org.tensorflow.Graph;
import org.tensorflow.GraphOperation;
import org.tensorflow.GraphOperationBuilder;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.buffer.ByteDataBuffer;
import org.tensorflow.ndarray.buffer.DataBuffers;
import org.tensorflow.op.Scope;
import org.tensorflow.types.family.TType;
import org.tribuo.interop.tensorflow.protos.TensorTupleProto;
import org.tribuo.util.Util;

/**
 * Static helpers for moving the values of TensorFlow graph variables (ops of type
 * {@value #VARIABLE_V2}) out of and back into a {@link Session}, plus {@link TensorTuple}, a
 * {@link Serializable} holder for the type name, shape and raw bytes of a single tensor.
 */
public abstract class TensorFlowUtil {
    private static final Logger logger = Logger.getLogger(TensorFlowUtil.class.getName());
    /** Op type matched when searching a graph for variables. */
    public static final String VARIABLE_V2 = "VariableV2";
    /** Op type of the assignment ops added by {@link #annotateGraph}. */
    public static final String ASSIGN_OP = "Assign";
    /** Name suffix (after {@code "/"}) of the assignment op added for each variable. */
    public static final String ASSIGN_PLACEHOLDER = "Assign_from_Placeholder";
    /** Op type of the placeholder ops added by {@link #annotateGraph}. */
    public static final String PLACEHOLDER = "Placeholder";
    /** Attribute name used to set a placeholder's data type. */
    public static final String DTYPE = "dtype";

    private TensorFlowUtil() {
    }

    /**
     * Closes every tensor in the supplied collection.
     * <p>
     * Every tensor is closed even if an earlier {@code close()} throws; the first failure is
     * rethrown once all tensors have been visited, with any later failures attached as
     * suppressed exceptions.
     *
     * @param tensors The tensors to close.
     */
    public static void closeTensorCollection(Collection<? extends Tensor> tensors) {
        RuntimeException failure = null;
        for (Tensor t : tensors) {
            try {
                t.close();
            } catch (RuntimeException e) {
                if (failure == null) {
                    failure = e;
                } else if (failure != e) {
                    // addSuppressed rejects self-suppression, hence the identity check.
                    failure.addSuppressed(e);
                }
            }
        }
        if (failure != null) {
            throw failure;
        }
    }

    /**
     * Adds, for every {@value #VARIABLE_V2} op in the graph, a placeholder op and an assignment op
     * that copies the placeholder's value into the variable.
     * <p>
     * For a variable named {@code v} this adds a {@value #PLACEHOLDER} op named
     * {@code generatePlaceholderName(v)}, whose {@value #DTYPE} attribute is taken from the
     * variable's current value as fetched through {@code session}, and an {@value #ASSIGN_OP} op
     * named {@code v + "/" + ASSIGN_PLACEHOLDER}. These are the ops fed and run by
     * {@link #restoreMarshalledVariables}.
     *
     * @param graph   The graph to modify.
     * @param session A session over {@code graph}, used to fetch the variables' values.
     * @throws IllegalStateException If the session does not return one tensor per variable.
     */
    public static void annotateGraph(Graph graph, Session session) {
        List<GraphOperation> variables = findVariables(graph);
        List<String> variableNames = namesOf(variables);
        List<Tensor> output = fetchAll(session, variableNames, "annotate");
        try {
            Scope scope = graph.baseScope();
            for (int i = 0; i < output.size(); ++i) {
                String name = variableNames.get(i);
                GraphOperationBuilder builder = graph.opBuilder(PLACEHOLDER, generatePlaceholderName(name), scope);
                builder.setAttr(DTYPE, output.get(i).dataType());
                GraphOperation placeholder = builder.build();
                builder = graph.opBuilder(ASSIGN_OP, name + "/" + ASSIGN_PLACEHOLDER, scope);
                builder.addInput(variables.get(i).output(0));
                builder.addInput(placeholder.output(0));
                builder.build();
            }
        } finally {
            // The fetched values were only needed for their data types.
            closeTensorCollection(output);
        }
    }

    /**
     * Returns the name of the placeholder op that {@link #annotateGraph} creates for a variable.
     *
     * @param variableName The variable's op name.
     * @return {@code variableName + "-tribuo-" + PLACEHOLDER}.
     */
    public static String generatePlaceholderName(String variableName) {
        return variableName + "-tribuo-" + PLACEHOLDER;
    }

    /**
     * Fetches the current value of every {@value #VARIABLE_V2} op in the graph and copies each one
     * into a {@link TensorTuple}.
     * <p>
     * The fetched tensors are always closed before this method returns or throws.
     *
     * @param graph   The graph whose variables are extracted.
     * @param session A session over {@code graph}.
     * @return A map from variable op name to a copy of its value.
     * @throws IllegalStateException    If the session does not return one tensor per variable.
     * @throws IllegalArgumentException If a variable is too large for {@link TensorTuple#of}.
     */
    public static Map<String, TensorTuple> extractMarshalledVariables(Graph graph, Session session) {
        List<String> variableNames = namesOf(findVariables(graph));
        List<Tensor> output = fetchAll(session, variableNames, "serialise");
        try {
            Map<String, TensorTuple> tensorMap = new HashMap<>();
            for (int i = 0; i < variableNames.size(); ++i) {
                tensorMap.put(variableNames.get(i), TensorTuple.of((TType) output.get(i)));
            }
            return tensorMap;
        } finally {
            closeTensorCollection(output);
        }
    }

    /**
     * Writes the supplied values back into their variables in a single session run.
     * <p>
     * Each entry's tensor is fed into the placeholder named {@code generatePlaceholderName(key)},
     * and the op {@code key + "/" + ASSIGN_PLACEHOLDER} is run as a target. The graph therefore
     * needs the ops added by {@link #annotateGraph}. The rebuilt tensors are always closed before
     * this method returns or throws. An empty map issues no session run.
     *
     * @param session   The session to restore into.
     * @param tensorMap A map from variable op name to the value to assign.
     */
    public static void restoreMarshalledVariables(Session session, Map<String, TensorTuple> tensorMap) {
        if (tensorMap.isEmpty()) {
            return;
        }
        List<Tensor> tensors = new ArrayList<>(tensorMap.size());
        try {
            Session.Runner runner = session.runner();
            for (Map.Entry<String, TensorTuple> e : tensorMap.entrySet()) {
                String name = e.getKey();
                TensorTuple tuple = e.getValue();
                logger.log(Level.FINEST, () -> "Loading " + name + " of type " + tuple.className);
                Tensor tensor = tuple.rebuildTensor();
                // Track the tensor before handing it to the runner so it is closed if feeding fails.
                tensors.add(tensor);
                runner.feed(generatePlaceholderName(name), tensor);
                runner.addTarget(name + "/" + ASSIGN_PLACEHOLDER);
            }
            runner.run();
        } finally {
            closeTensorCollection(tensors);
        }
    }

    /** Returns every {@value #VARIABLE_V2} op in the graph, in the graph's iteration order. */
    private static List<GraphOperation> findVariables(Graph graph) {
        List<GraphOperation> variables = new ArrayList<>();
        Iterator<?> opItr = graph.operations();
        while (opItr.hasNext()) {
            GraphOperation op = (GraphOperation) opItr.next();
            if (op.type().equals(VARIABLE_V2)) {
                variables.add(op);
            }
        }
        return variables;
    }

    /** Returns the names of the supplied ops, in the same order. */
    private static List<String> namesOf(List<GraphOperation> ops) {
        return ops.stream().map(GraphOperation::name).collect(Collectors.toList());
    }

    /**
     * Fetches the named ops in one session run and checks that one tensor came back per name.
     * No session run is issued when {@code names} is empty.
     *
     * @param session The session to run.
     * @param names   The op names to fetch.
     * @param action  The verb used in the failure message ({@code "annotate"} or {@code "serialise"}).
     * @return The fetched tensors, in request order. The caller must close them.
     * @throws IllegalStateException If the tensor count does not match; any returned tensors are
     *                               closed first.
     */
    private static List<Tensor> fetchAll(Session session, List<String> names, String action) {
        if (names.isEmpty()) {
            return new ArrayList<>();
        }
        Session.Runner runner = session.runner();
        for (String name : names) {
            runner.fetch(name);
        }
        List<Tensor> output = runner.run();
        if (output.size() != names.size()) {
            closeTensorCollection(output);
            throw new IllegalStateException("Failed to " + action + " all requested variables. Requested " + names.size() + ", found " + output.size());
        }
        return output;
    }

    /**
     * A serializable copy of a single tensor: the name of its {@link TType} class, its shape and
     * the bytes of its raw data buffer.
     * <p>
     * The arrays are stored and exposed without defensive copies.
     */
    public static final class TensorTuple
    implements Serializable {
        private static final long serialVersionUID = 1L;
        /** Fully qualified name of the tensor's {@link TType} class. */
        public final String className;
        /** The tensor's shape. */
        public final long[] shape;
        /** The contents of the tensor's raw data buffer. */
        public final byte[] data;

        /**
         * Creates a tuple from its components (stored as-is, without copying).
         *
         * @param className Fully qualified name of the tensor's {@link TType} class.
         * @param shape     The tensor's shape.
         * @param data      The tensor's raw bytes.
         */
        public TensorTuple(String className, long[] shape, byte[] data) {
            this.className = className;
            this.shape = shape;
            this.data = data;
        }

        /**
         * Creates a tuple from its protobuf form, as produced by {@link #serialize()}.
         *
         * @param proto The protobuf to read.
         */
        public TensorTuple(TensorTupleProto proto) {
            this.className = proto.getClassName();
            this.shape = Util.toPrimitiveLong(proto.getShapeList());
            this.data = proto.getData().toByteArray();
        }

        /**
         * Builds a new tensor of type {@link #className} with this tuple's shape and bytes.
         * The caller owns the returned tensor and must close it.
         *
         * @return A new tensor.
         * @throws IllegalStateException If {@link #className} cannot be loaded or is not a
         *                               {@link TType}.
         */
        public Tensor rebuildTensor() {
            Class<?> clazz;
            try {
                clazz = Class.forName(className);
            } catch (ClassNotFoundException e) {
                throw new IllegalStateException("Failed to instantiate Tensor class", e);
            }
            if (!TType.class.isAssignableFrom(clazz)) {
                throw new IllegalStateException("Unexpected Tensor type, found " + className);
            }
            // asSubclass gives a properly bounded Class for Tensor.of without an unchecked cast;
            // it cannot fail here because of the isAssignableFrom guard above.
            Class<? extends TType> tensorClass = clazz.asSubclass(TType.class);
            Shape shapeObj = Shape.of(shape);
            ByteDataBuffer buf = DataBuffers.of(data);
            return Tensor.of(tensorClass, shapeObj, buf);
        }

        /**
         * Converts this tuple into its protobuf form. The data bytes are copied.
         *
         * @return The protobuf representation.
         */
        public TensorTupleProto serialize() {
            TensorTupleProto.Builder builder = TensorTupleProto.newBuilder();
            builder.setClassName(className);
            builder.addAllShape(Arrays.stream(shape).boxed().collect(Collectors.toList()));
            builder.setData(ByteString.copyFrom(data));
            return builder.build();
        }

        /**
         * Copies the type name, shape and raw bytes out of a tensor. The tensor is not closed.
         *
         * @param tensor The tensor to copy.
         * @return A tuple holding a copy of the tensor's contents.
         * @throws IllegalArgumentException If the raw buffer has more than
         *                                  {@code Integer.MAX_VALUE} bytes, since it must fit in a
         *                                  single {@code byte[]}.
         */
        public static TensorTuple of(TType tensor) {
            ByteDataBuffer buffer = tensor.asRawTensor().data();
            long size = buffer.size();
            if (size > Integer.MAX_VALUE) {
                throw new IllegalArgumentException("Cannot serialize Tensors bigger than Integer.MAX_VALUE, found " + size);
            }
            String className = tensor.type().getName();
            long[] shape = tensor.shape().asArray();
            byte[] data = new byte[(int) size];
            buffer.read(data);
            return new TensorTuple(className, shape, data);
        }
    }
}

