/*
 * Decompiled with CFR 0.152.
 */
package org.tensorflow.contrib.android;

import android.content.res.AssetManager;
import android.os.Build;
import android.os.Trace;
import android.text.TextUtils;
import android.util.Log;
import java.io.ByteArrayOutputStream;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.util.ArrayList;
import java.util.List;
import org.tensorflow.Graph;
import org.tensorflow.Operation;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlow;
import org.tensorflow.Tensors;
import org.tensorflow.contrib.android.RunStats;
import org.tensorflow.types.UInt8;

/**
 * Wrapper over the TensorFlow Java API for running inference from Android code.
 *
 * <p>Typical usage: stage inputs with the {@code feed*} methods, execute the graph with
 * {@link #run(String[])}, then copy outputs into caller-owned buffers with the {@code fetch}
 * methods. Staged input tensors are closed when {@code run} finishes; output tensors are kept
 * until the next {@code run} or {@link #close()}.
 *
 * <p>NOTE(review): per-call state (the pending runner and the feed/fetch lists) is mutated without
 * any synchronization, so presumably an instance is meant to be used from one thread at a time.
 */
public class TensorFlowInferenceInterface {
    private static final String TAG = "TensorFlowInferenceInterface";
    private static final String ASSET_FILE_PREFIX = "file:///android_asset/";
    /** Chunk size (and minimum buffer size) used when reading a serialized graph from a stream. */
    private static final int READ_CHUNK_SIZE = 16384;

    private final String modelName;
    private final Graph g;
    private final Session sess;
    /** Accumulates feeds, fetches and targets for the next {@link #run} call; replaced after each run. */
    private Session.Runner runner;
    private final List<String> feedNames = new ArrayList<String>();
    private final List<Tensor<?>> feedTensors = new ArrayList<Tensor<?>>();
    private final List<String> fetchNames = new ArrayList<String>();
    /** Outputs of the most recent run, index-aligned with {@link #fetchNames}. */
    private List<Tensor<?>> fetchTensors = new ArrayList<Tensor<?>>();
    private RunStats runStats;

    /**
     * Loads a serialized GraphDef from an Android asset or from the filesystem.
     *
     * <p>If {@code model} starts with {@code file:///android_asset/}, the remainder is opened as an
     * asset and no filesystem fallback is attempted. Otherwise {@code model} is first tried as an
     * asset name and, if that fails, as a filesystem path.
     *
     * @param assetManager used to open asset-backed models
     * @param model asset URI, asset name, or filesystem path of the serialized graph
     * @throws RuntimeException if the model cannot be opened, read, or imported
     */
    public TensorFlowInferenceInterface(AssetManager assetManager, String model) {
        prepareNativeRuntime();
        this.modelName = model;
        this.g = new Graph();
        this.sess = new Session(this.g);
        this.runner = this.sess.runner();

        InputStream is = openModelStream(assetManager, model);
        try {
            importGraphFromStream(is);
        } catch (IOException e) {
            throw new RuntimeException("Failed to load model from '" + model + "'", e);
        } finally {
            // Release the stream on every path, including read/import failures.
            try {
                is.close();
            } catch (IOException e) {
                Log.w(TAG, "Failed to close model stream for '" + model + "'", e);
            }
        }
        Log.i(TAG, "Successfully loaded model from '" + model + "'");
    }

    /**
     * Loads a serialized GraphDef by reading {@code is} until end of stream.
     *
     * <p>The stream is not closed; the caller keeps ownership of it.
     *
     * @param is stream containing the serialized graph
     * @throws RuntimeException if the stream cannot be read or the graph cannot be imported
     */
    public TensorFlowInferenceInterface(InputStream is) {
        prepareNativeRuntime();
        this.modelName = "";
        this.g = new Graph();
        this.sess = new Session(this.g);
        this.runner = this.sess.runner();
        try {
            importGraphFromStream(is);
        } catch (IOException e) {
            throw new RuntimeException("Failed to load model from the input stream", e);
        }
        Log.i(TAG, "Successfully loaded model from the input stream");
    }

    /**
     * Wraps an already-constructed graph in a new session.
     *
     * <p>Note that {@link #close()} also closes {@code g}.
     *
     * @param g graph to run inference on
     */
    public TensorFlowInferenceInterface(Graph g) {
        prepareNativeRuntime();
        this.modelName = "";
        this.g = g;
        this.sess = new Session(g);
        this.runner = this.sess.runner();
    }

    /** Runs the graph, fetching {@code outputNames}, without collecting run statistics. */
    public void run(String[] outputNames) {
        run(outputNames, false);
    }

    /** Runs the graph, fetching {@code outputNames}, optionally collecting run statistics. */
    public void run(String[] outputNames, boolean enableStats) {
        run(outputNames, enableStats, new String[0]);
    }

    /**
     * Runs the graph with all inputs staged since the previous run.
     *
     * <p>Output tensors from the previous run are closed first. Whether the run succeeds or fails,
     * every staged input tensor is closed and a fresh runner is prepared for the next call.
     *
     * @param outputNames tensors to fetch, as {@code "op"} or {@code "op:index"}
     * @param enableStats when true, the run is executed with {@link RunStats#runOptions()} and its
     *     metadata is accumulated for {@link #getStatString()}
     * @param targetNodeNames operations passed to {@code Session.Runner.addTarget}
     * @throws RuntimeException if setting up or executing the run fails (the error is logged first)
     */
    public void run(String[] outputNames, boolean enableStats, String[] targetNodeNames) {
        closeFetches();
        try {
            // Registration happens inside the try so a bad name still resets the runner and feeds.
            for (String o : outputNames) {
                fetchNames.add(o);
                TensorId tid = TensorId.parse(o);
                runner.fetch(tid.name, tid.outputIndex);
            }
            for (String t : targetNodeNames) {
                runner.addTarget(t);
            }
            if (enableStats) {
                Session.Run r = runner.setOptions(RunStats.runOptions()).runAndFetchMetadata();
                fetchTensors = r.outputs;
                if (runStats == null) {
                    runStats = new RunStats();
                }
                runStats.add(r.metadata);
            } else {
                fetchTensors = runner.run();
            }
        } catch (RuntimeException e) {
            Log.e(TAG, "Failed to run TensorFlow inference with inputs:[" + TextUtils.join(", ", feedNames)
                    + "], outputs:[" + TextUtils.join(", ", fetchNames) + "]");
            // No outputs exist after a failed run; forget the names so fetch() reports that clearly
            // instead of indexing into an empty tensor list.
            fetchNames.clear();
            throw e;
        } finally {
            closeFeeds();
            runner = sess.runner();
        }
    }

    /** Returns the underlying graph. */
    public Graph graph() {
        return g;
    }

    /**
     * Looks up an operation in the graph.
     *
     * @throws RuntimeException if no operation named {@code operationName} exists
     */
    public Operation graphOperation(String operationName) {
        Operation operation = g.operation(operationName);
        if (operation == null) {
            throw new RuntimeException("Node '" + operationName + "' does not exist in model '" + modelName + "'");
        }
        return operation;
    }

    /** Returns the accumulated run-statistics summary, or {@code ""} if stats were never enabled. */
    public String getStatString() {
        return runStats == null ? "" : runStats.summary();
    }

    /** Releases all tensors, the session, the graph and any run statistics. */
    public void close() {
        closeFeeds();
        closeFetches();
        sess.close();
        g.close();
        if (runStats != null) {
            runStats.close();
        }
        runStats = null;
    }

    @Override
    protected void finalize() throws Throwable {
        try {
            close();
        } finally {
            super.finalize();
        }
    }

    /** Stages a boolean input; each element is encoded as one byte (1 for true, 0 for false). */
    public void feed(String inputName, boolean[] src, long ... dims) {
        byte[] b = new byte[src.length];
        for (int i = 0; i < src.length; ++i) {
            b[i] = src[i] ? (byte) 1 : 0;
        }
        addFeed(inputName, Tensor.create(Boolean.class, dims, ByteBuffer.wrap(b)));
    }

    /** Stages a float input with shape {@code dims}. */
    public void feed(String inputName, float[] src, long ... dims) {
        addFeed(inputName, Tensor.create(dims, FloatBuffer.wrap(src)));
    }

    /** Stages an int input with shape {@code dims}. */
    public void feed(String inputName, int[] src, long ... dims) {
        addFeed(inputName, Tensor.create(dims, IntBuffer.wrap(src)));
    }

    /** Stages a long input with shape {@code dims}. */
    public void feed(String inputName, long[] src, long ... dims) {
        addFeed(inputName, Tensor.create(dims, LongBuffer.wrap(src)));
    }

    /** Stages a double input with shape {@code dims}. */
    public void feed(String inputName, double[] src, long ... dims) {
        addFeed(inputName, Tensor.create(dims, DoubleBuffer.wrap(src)));
    }

    /** Stages a byte input as a {@code UInt8} tensor with shape {@code dims}. */
    public void feed(String inputName, byte[] src, long ... dims) {
        addFeed(inputName, Tensor.create(UInt8.class, dims, ByteBuffer.wrap(src)));
    }

    /** Stages an input built by {@code Tensors.create(byte[])} (presumably a scalar string). */
    public void feedString(String inputName, byte[] src) {
        addFeed(inputName, Tensors.create(src));
    }

    /** Stages an input built by {@code Tensors.create(byte[][])} (presumably a string vector). */
    public void feedString(String inputName, byte[][] src) {
        addFeed(inputName, Tensors.create(src));
    }

    /** Stages a float input read from {@code src} with shape {@code dims}. */
    public void feed(String inputName, FloatBuffer src, long ... dims) {
        addFeed(inputName, Tensor.create(dims, src));
    }

    /** Stages an int input read from {@code src} with shape {@code dims}. */
    public void feed(String inputName, IntBuffer src, long ... dims) {
        addFeed(inputName, Tensor.create(dims, src));
    }

    /** Stages a long input read from {@code src} with shape {@code dims}. */
    public void feed(String inputName, LongBuffer src, long ... dims) {
        addFeed(inputName, Tensor.create(dims, src));
    }

    /** Stages a double input read from {@code src} with shape {@code dims}. */
    public void feed(String inputName, DoubleBuffer src, long ... dims) {
        addFeed(inputName, Tensor.create(dims, src));
    }

    /** Stages a {@code UInt8} input read from {@code src} with shape {@code dims}. */
    public void feed(String inputName, ByteBuffer src, long ... dims) {
        addFeed(inputName, Tensor.create(UInt8.class, dims, src));
    }

    /** Copies output {@code outputName} of the last run into {@code dst}. */
    public void fetch(String outputName, float[] dst) {
        fetch(outputName, FloatBuffer.wrap(dst));
    }

    /** Copies output {@code outputName} of the last run into {@code dst}. */
    public void fetch(String outputName, int[] dst) {
        fetch(outputName, IntBuffer.wrap(dst));
    }

    /** Copies output {@code outputName} of the last run into {@code dst}. */
    public void fetch(String outputName, long[] dst) {
        fetch(outputName, LongBuffer.wrap(dst));
    }

    /** Copies output {@code outputName} of the last run into {@code dst}. */
    public void fetch(String outputName, double[] dst) {
        fetch(outputName, DoubleBuffer.wrap(dst));
    }

    /** Copies output {@code outputName} of the last run into {@code dst}. */
    public void fetch(String outputName, byte[] dst) {
        fetch(outputName, ByteBuffer.wrap(dst));
    }

    /** Copies output {@code outputName} of the last run into {@code dst}. */
    public void fetch(String outputName, FloatBuffer dst) {
        getTensor(outputName).writeTo(dst);
    }

    /** Copies output {@code outputName} of the last run into {@code dst}. */
    public void fetch(String outputName, IntBuffer dst) {
        getTensor(outputName).writeTo(dst);
    }

    /** Copies output {@code outputName} of the last run into {@code dst}. */
    public void fetch(String outputName, LongBuffer dst) {
        getTensor(outputName).writeTo(dst);
    }

    /** Copies output {@code outputName} of the last run into {@code dst}. */
    public void fetch(String outputName, DoubleBuffer dst) {
        getTensor(outputName).writeTo(dst);
    }

    /** Copies output {@code outputName} of the last run into {@code dst}. */
    public void fetch(String outputName, ByteBuffer dst) {
        getTensor(outputName).writeTo(dst);
    }

    /**
     * Ensures the TensorFlow native methods are available, loading {@code tensorflow_inference}
     * if they are not.
     */
    private void prepareNativeRuntime() {
        Log.i(TAG, "Checking to see if TensorFlow native methods are already loaded");
        try {
            // Probe: constructing a RunStats throws UnsatisfiedLinkError when the natives are
            // missing. The probe instance is closed right away so it does not leak.
            new RunStats().close();
            Log.i(TAG, "TensorFlow native methods already loaded");
        } catch (UnsatisfiedLinkError e1) {
            Log.i(TAG, "TensorFlow native methods not found, attempting to load via tensorflow_inference");
            try {
                System.loadLibrary("tensorflow_inference");
                Log.i(TAG, "Successfully loaded TensorFlow native methods (RunStats error may be ignored)");
            } catch (UnsatisfiedLinkError e2) {
                throw new RuntimeException(
                        "Native TF methods not found; check that the correct native libraries are present in the APK.",
                        e2);
            }
        }
    }

    /**
     * Opens {@code model} as an asset, falling back to the filesystem when it has no asset prefix.
     *
     * @throws RuntimeException if no source could be opened
     */
    private static InputStream openModelStream(AssetManager assetManager, String model) {
        boolean hasAssetPrefix = model.startsWith(ASSET_FILE_PREFIX);
        // substring instead of split(): split parses a regex, fails when nothing follows the
        // prefix, and truncates the name if the prefix text occurs again later in the path.
        String assetName = hasAssetPrefix ? model.substring(ASSET_FILE_PREFIX.length()) : model;
        try {
            return assetManager.open(assetName);
        } catch (IOException assetError) {
            if (hasAssetPrefix) {
                throw new RuntimeException("Failed to load model from '" + model + "'", assetError);
            }
            try {
                return new FileInputStream(model);
            } catch (IOException fileError) {
                // The filesystem attempt is the last one made, so its failure is reported as the cause.
                throw new RuntimeException("Failed to load model from '" + model + "'", fileError);
            }
        }
    }

    /** Reads a serialized graph from {@code is} and imports it into {@link #g}, with trace sections. */
    private void importGraphFromStream(InputStream is) throws IOException {
        traceBegin("initializeTensorFlow");
        try {
            byte[] graphDef;
            traceBegin("readGraphDef");
            try {
                graphDef = readFully(is);
            } finally {
                traceEnd();
            }
            loadGraph(graphDef, g);
        } finally {
            traceEnd();
        }
    }

    /** Reads {@code is} until end of stream; the buffer is presized from {@code available()}. */
    private static byte[] readFully(InputStream is) throws IOException {
        ByteArrayOutputStream out = new ByteArrayOutputStream(Math.max(is.available(), READ_CHUNK_SIZE));
        byte[] chunk = new byte[READ_CHUNK_SIZE];
        int numBytesRead;
        while ((numBytesRead = is.read(chunk, 0, chunk.length)) != -1) {
            out.write(chunk, 0, numBytesRead);
        }
        return out.toByteArray();
    }

    /**
     * Imports {@code graphDef} into {@code graph} and logs how long it took.
     *
     * @throws IOException if the bytes are rejected by {@code importGraphDef}
     */
    private void loadGraph(byte[] graphDef, Graph graph) throws IOException {
        long startMs = System.currentTimeMillis();
        traceBegin("importGraphDef");
        try {
            graph.importGraphDef(graphDef);
        } catch (IllegalArgumentException e) {
            throw new IOException("Not a valid TensorFlow Graph serialization: " + e.getMessage(), e);
        } finally {
            traceEnd();
        }
        long endMs = System.currentTimeMillis();
        Log.i(TAG, "Model load took " + (endMs - startMs) + "ms, TensorFlow version: " + TensorFlow.version());
    }

    /** Begins a systrace section on devices that support it (API 18+). */
    private static void traceBegin(String sectionName) {
        if (Build.VERSION.SDK_INT >= 18) {
            Trace.beginSection(sectionName);
        }
    }

    /** Ends the innermost systrace section on devices that support it (API 18+). */
    private static void traceEnd() {
        if (Build.VERSION.SDK_INT >= 18) {
            Trace.endSection();
        }
    }

    /** Registers {@code t} with the pending runner and remembers it so it can be closed after the run. */
    private void addFeed(String inputName, Tensor<?> t) {
        TensorId tid = TensorId.parse(inputName);
        runner.feed(tid.name, tid.outputIndex, t);
        feedNames.add(inputName);
        feedTensors.add(t);
    }

    /**
     * Returns the tensor fetched as {@code outputName} by the last run.
     *
     * @throws RuntimeException if {@code outputName} was not requested by the last successful run
     */
    private Tensor<?> getTensor(String outputName) {
        int i = fetchNames.indexOf(outputName);
        if (i < 0) {
            throw new RuntimeException("Node '" + outputName + "' was not provided to run(), so it cannot be read");
        }
        return fetchTensors.get(i);
    }

    /** Closes and forgets all staged input tensors. */
    private void closeFeeds() {
        for (Tensor<?> t : feedTensors) {
            t.close();
        }
        feedTensors.clear();
        feedNames.clear();
    }

    /** Closes and forgets all output tensors from the previous run. */
    private void closeFetches() {
        for (Tensor<?> t : fetchTensors) {
            t.close();
        }
        fetchTensors.clear();
        fetchNames.clear();
    }

    /** An operation name plus output index parsed from {@code "op"} or {@code "op:index"}. */
    private static final class TensorId {
        final String name;
        final int outputIndex;

        private TensorId(String name, int outputIndex) {
            this.name = name;
            this.outputIndex = outputIndex;
        }

        /**
         * Splits {@code name} at its last ':' when the suffix parses as an int; otherwise the whole
         * string is the operation name and the output index is 0.
         */
        static TensorId parse(String name) {
            int colonIndex = name.lastIndexOf(':');
            if (colonIndex < 0) {
                return new TensorId(name, 0);
            }
            try {
                int index = Integer.parseInt(name.substring(colonIndex + 1));
                return new TensorId(name.substring(0, colonIndex), index);
            } catch (NumberFormatException e) {
                // Suffix after the last ':' is not an integer: keep the full string as the op name.
                return new TensorId(name, 0);
            }
        }
    }
}

