/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.interop.tensorflow.example;

import java.util.Arrays;
import org.tensorflow.ExecutionEnvironment;
import org.tensorflow.Graph;
import org.tensorflow.Operand;
import org.tensorflow.framework.initializers.Glorot;
import org.tensorflow.framework.initializers.VarianceScaling;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Concat;
import org.tensorflow.op.core.Constant;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.op.core.Reshape;
import org.tensorflow.op.core.Variable;
import org.tensorflow.op.linalg.MatMul;
import org.tensorflow.op.math.Add;
import org.tensorflow.op.math.Div;
import org.tensorflow.op.nn.BiasAdd;
import org.tensorflow.op.nn.Conv2d;
import org.tensorflow.op.nn.MaxPool;
import org.tensorflow.op.nn.Relu;
import org.tensorflow.proto.framework.GraphDef;
import org.tensorflow.types.TFloat32;
import org.tribuo.interop.tensorflow.example.GraphDefTuple;

/**
 * Static factory methods which build example convolutional neural network graphs.
 * <p>
 * Each factory builds a TensorFlow {@link Graph}, serialises it to a {@link GraphDef}
 * and closes the graph before returning, so callers only receive the serialised form.
 */
public abstract class CNNExamples {

    /**
     * Seed passed to the Glorot weight initializer by the seedless overload.
     */
    private static final long DEFAULT_SEED = 12345L;

    /**
     * Padding mode used by every convolution and max-pooling op in the LeNet graph.
     */
    private static final String PADDING_TYPE = "SAME";

    private CNNExamples() {
    }

    /**
     * Builds a LeNet-style convolutional network for single-channel square images.
     * <p>
     * Equivalent to {@link #buildLeNetGraph(String, int, int, int, long)} with a seed of {@code 12345}.
     *
     * @param inputName  the name given to the input placeholder op.
     * @param imageSize  the height and width of the (square) input images.
     * @param pixelDepth the pixel value range used to centre and scale the input.
     * @param numOutputs the number of output logits.
     * @return a tuple containing the serialised graph, the input op name and the output op name.
     * @throws IllegalArgumentException if {@code imageSize}, {@code pixelDepth} or {@code numOutputs} is less than one.
     */
    public static GraphDefTuple buildLeNetGraph(String inputName, int imageSize, int pixelDepth, int numOutputs) {
        return buildLeNetGraph(inputName, imageSize, pixelDepth, numOutputs, DEFAULT_SEED);
    }

    /**
     * Builds a LeNet-style convolutional network for single-channel square images.
     * <p>
     * The graph has the following structure:
     * <ol>
     *     <li>A float32 input placeholder named {@code inputName}.</li>
     *     <li>The input is scaled as {@code (x - pixelDepth / 2) / pixelDepth}.</li>
     *     <li>A 5x5 convolution with 32 filters, stride 1, "SAME" padding.
     *     It is followed by a bias (initialised to 0.0), a ReLU, and a 2x2 max pool with stride 2.</li>
     *     <li>A 5x5 convolution with 64 filters, stride 1, "SAME" padding.
     *     It is followed by a bias (initialised to 0.1), a ReLU, and a 2x2 max pool with stride 2.</li>
     *     <li>A flatten, then a 512-unit fully connected layer with bias (0.1) and ReLU.</li>
     *     <li>A fully connected layer producing {@code numOutputs} unnormalised logits (no softmax is applied).</li>
     * </ol>
     * The input placeholder has shape {@code [-1, imageSize, imageSize, 1]}.
     * All weight matrices use a truncated-normal Glorot initializer seeded with {@code seed}.
     *
     * @param inputName  the name given to the input placeholder op.
     * @param imageSize  the height and width of the (square) input images.
     * @param pixelDepth the pixel value range used to centre and scale the input
     *                   (presumably the maximum pixel value, e.g. 255 — confirm against callers).
     * @param numOutputs the number of output logits.
     * @param seed       the seed for the Glorot weight initializer.
     * @return a tuple containing the serialised graph, the input op name and the output op name.
     * @throws IllegalArgumentException if {@code imageSize}, {@code pixelDepth} or {@code numOutputs} is less than one.
     */
    public static GraphDefTuple buildLeNetGraph(String inputName, int imageSize, int pixelDepth, int numOutputs, long seed) {
        if (imageSize < 1) {
            throw new IllegalArgumentException("Must have a positive image size, found " + imageSize);
        }
        if (pixelDepth < 1) {
            throw new IllegalArgumentException("Must have a positive pixel depth, found " + pixelDepth);
        }
        if (numOutputs < 1) {
            throw new IllegalArgumentException("Must have a positive number of outputs, found " + numOutputs);
        }

        // try-with-resources so the native graph is released even if op construction throws.
        try (Graph graph = new Graph()) {
            Ops tf = Ops.create(graph);

            // NOTE(review): the same initializer (and seed) is reused for every weight tensor;
            // confirm whether per-layer seeds are intended.
            Glorot<TFloat32> initializer = new Glorot<>(VarianceScaling.Distribution.TRUNCATED_NORMAL, seed);

            // Input: a batch of single-channel square images.
            Placeholder<TFloat32> input = tf.withName(inputName).placeholder(TFloat32.class,
                    Placeholder.shape(Shape.of(-1L, imageSize, imageSize, 1L)));

            // Centre the pixel values on zero and scale them by the pixel depth.
            Constant<TFloat32> centeringFactor = tf.constant(pixelDepth / 2.0f);
            Constant<TFloat32> scalingFactor = tf.constant((float) pixelDepth);
            Div<TFloat32> scaledInput = tf.math.div(tf.math.sub(input, centeringFactor), scalingFactor);

            // First conv block: 5x5x1 -> 32 filters, bias, ReLU, 2x2 max pool.
            Variable<TFloat32> conv1Weights = tf.variable(initializer.call(tf, tf.array(5L, 5L, 1L, 32L), TFloat32.class));
            Conv2d<TFloat32> conv1 = tf.nn.conv2d(scaledInput, conv1Weights, Arrays.asList(1L, 1L, 1L, 1L), PADDING_TYPE);
            Variable<TFloat32> conv1Biases = tf.variable(tf.fill(tf.array(new int[]{32}), tf.constant(0.0f)));
            Relu<TFloat32> relu1 = tf.nn.relu(tf.nn.biasAdd(conv1, conv1Biases));
            MaxPool<TFloat32> pool1 = tf.nn.maxPool(relu1, tf.array(new int[]{1, 2, 2, 1}), tf.array(new int[]{1, 2, 2, 1}), PADDING_TYPE);

            // Second conv block: 5x5x32 -> 64 filters, bias, ReLU, 2x2 max pool.
            Variable<TFloat32> conv2Weights = tf.variable(initializer.call(tf, tf.array(5L, 5L, 32L, 64L), TFloat32.class));
            Conv2d<TFloat32> conv2 = tf.nn.conv2d(pool1, conv2Weights, Arrays.asList(1L, 1L, 1L, 1L), PADDING_TYPE);
            Variable<TFloat32> conv2Biases = tf.variable(tf.fill(tf.array(new int[]{64}), tf.constant(0.1f)));
            Relu<TFloat32> relu2 = tf.nn.relu(tf.nn.biasAdd(conv2, conv2Biases));
            MaxPool<TFloat32> pool2 = tf.nn.maxPool(relu2, tf.array(new int[]{1, 2, 2, 1}), tf.array(new int[]{1, 2, 2, 1}), PADDING_TYPE);

            // Flatten every non-batch dimension of the pooled output into one feature axis.
            long[] poolShape = pool2.shape().subShape(1, 4).asArray();
            long numFlattenedFeatures = poolShape[0] * poolShape[1] * poolShape[2];
            Reshape<TFloat32> flatten = tf.reshape(pool2, tf.array(-1L, numFlattenedFeatures));

            // Fully connected hidden layer: numFlattenedFeatures -> 512, ReLU.
            Variable<TFloat32> fc1Weights = tf.variable(initializer.call(tf, tf.array(numFlattenedFeatures, 512L), TFloat32.class));
            Variable<TFloat32> fc1Biases = tf.variable(tf.fill(tf.array(new int[]{512}), tf.constant(0.1f)));
            Relu<TFloat32> relu3 = tf.nn.relu(tf.math.add(tf.linalg.matMul(flatten, fc1Weights), fc1Biases));

            // Output layer: 512 -> numOutputs logits.
            Variable<TFloat32> fc2Weights = tf.variable(initializer.call(tf, tf.array(512L, numOutputs), TFloat32.class));
            Variable<TFloat32> fc2Biases = tf.variable(tf.fill(tf.array(new int[]{numOutputs}), tf.constant(0.1f)));
            Add<TFloat32> logits = tf.math.add(tf.linalg.matMul(relu3, fc2Weights), fc2Biases);

            // Serialise and capture the output name before the graph is closed.
            GraphDef graphDef = graph.toGraphDef();
            String outputName = logits.op().name();
            return new GraphDefTuple(graphDef, inputName, outputName);
        }
    }
}

