/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.interop.tensorflow.example;

import java.util.Objects;

import org.tensorflow.ExecutionEnvironment;
import org.tensorflow.Graph;
import org.tensorflow.Operand;
import org.tensorflow.framework.initializers.Glorot;
import org.tensorflow.framework.initializers.VarianceScaling;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.op.core.Variable;
import org.tensorflow.op.linalg.MatMul;
import org.tensorflow.op.math.Add;
import org.tensorflow.op.nn.Relu;
import org.tensorflow.proto.framework.GraphDef;
import org.tensorflow.types.TFloat32;
import org.tribuo.interop.tensorflow.example.GraphDefTuple;

/**
 * Static helpers which build example multi-layer perceptron (MLP) TensorFlow graphs.
 * <p>
 * Not instantiable; use the static methods.
 */
public abstract class MLPExamples {
    /** Fixed seed passed to the Glorot weight initializer. */
    private static final long GRAPH_SEED = 12345L;

    /** Constant value every bias variable is initialized to. */
    private static final float INITIAL_BIAS = 0.1f;

    private MLPExamples() {
    }

    /**
     * Builds a fully connected feed-forward network and exports it as a {@link GraphDef}.
     * <p>
     * The graph contains:
     * <ul>
     *     <li>a float32 placeholder named {@code inputName} with shape {@code [-1, numFeatures]};</li>
     *     <li>one dense layer followed by a ReLU per entry of {@code hiddenSizes}, in order;</li>
     *     <li>a final dense layer with {@code numOutputs} units and no activation.</li>
     * </ul>
     * Weights are initialized with a truncated-normal Glorot initializer using a fixed seed, and
     * biases are initialized to {@code 0.1}.
     *
     * @param inputName   the name given to the input placeholder
     * @param numFeatures the number of input features; must be positive
     * @param hiddenSizes the width of each hidden layer, in order; must be non-empty with positive entries
     * @param numOutputs  the number of units in the output layer; must be positive
     * @return the graph definition together with the input placeholder name and the output op name
     * @throws IllegalArgumentException if a size is non-positive or {@code hiddenSizes} is empty
     * @throws NullPointerException     if {@code hiddenSizes} is null
     */
    public static GraphDefTuple buildMLPGraph(String inputName, int numFeatures, int[] hiddenSizes, int numOutputs) {
        if (numFeatures < 1) {
            throw new IllegalArgumentException("Must have a positive number of features, found " + numFeatures);
        }
        if (numOutputs < 1) {
            throw new IllegalArgumentException("Must have a positive number of outputs, found " + numOutputs);
        }
        // Checked here (not first) so the IllegalArgumentException checks above keep their precedence.
        Objects.requireNonNull(hiddenSizes, "hiddenSizes");
        if (hiddenSizes.length < 1) {
            throw new IllegalArgumentException("Must supply a hidden layer dimension.");
        }
        for (int hiddenSize : hiddenSizes) {
            if (hiddenSize < 1) {
                throw new IllegalArgumentException("Hidden dimensions must be positive, found " + hiddenSize);
            }
        }

        // try-with-resources releases the native graph even if op construction or export throws.
        try (Graph graph = new Graph()) {
            Ops tf = Ops.create(graph);
            Glorot<TFloat32> initializer =
                    new Glorot<>(VarianceScaling.Distribution.TRUNCATED_NORMAL, GRAPH_SEED);

            // Batch dimension is left unknown (-1) so any batch size can be fed.
            Placeholder<TFloat32> input = tf.withName(inputName)
                    .placeholder(TFloat32.class, Placeholder.shape(Shape.of(-1L, numFeatures)));

            Operand<TFloat32> prevOutput = input;
            long prevLayerSize = numFeatures;
            for (int hiddenSize : hiddenSizes) {
                Variable<TFloat32> weights = tf.variable(
                        initializer.call(tf, tf.array(new long[]{prevLayerSize, hiddenSize}), TFloat32.class));
                Variable<TFloat32> biases = tf.variable(
                        tf.fill(tf.array(new int[]{hiddenSize}), tf.constant(INITIAL_BIAS)));
                Relu<TFloat32> activation =
                        tf.nn.relu(tf.math.add(tf.linalg.matMul(prevOutput, weights), biases));
                prevOutput = activation;
                prevLayerSize = hiddenSize;
            }

            // Output layer: affine transform only, no activation.
            Variable<TFloat32> outputWeights = tf.variable(
                    initializer.call(tf, tf.array(new long[]{prevLayerSize, numOutputs}), TFloat32.class));
            Variable<TFloat32> outputBiases = tf.variable(
                    tf.fill(tf.array(new int[]{numOutputs}), tf.constant(INITIAL_BIAS)));
            Add<TFloat32> output = tf.math.add(tf.linalg.matMul(prevOutput, outputWeights), outputBiases);

            GraphDef graphDef = graph.toGraphDef();
            String outputName = output.op().name();
            return new GraphDefTuple(graphDef, inputName, outputName);
        }
    }
}

