/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.functions;

import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.NonNull;
import org.apache.commons.lang3.ArrayUtils;
import org.nd4j.autodiff.loss.LossReduce;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.blas.params.MMulTranspose;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.NoOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd;
import org.nd4j.linalg.api.ops.impl.broadcast.BiasAddGrad;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch;
import org.nd4j.linalg.api.ops.impl.image.ExtractImagePatches;
import org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex;
import org.nd4j.linalg.api.ops.impl.indexaccum.IAMax;
import org.nd4j.linalg.api.ops.impl.indexaccum.IAMin;
import org.nd4j.linalg.api.ops.impl.indexaccum.IMax;
import org.nd4j.linalg.api.ops.impl.indexaccum.IMin;
import org.nd4j.linalg.api.ops.impl.indexaccum.LastIndex;
import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
import org.nd4j.linalg.api.ops.impl.layers.convolution.AvgPooling2D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNorm;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Col2Im;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Conv1D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Conv3D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv2D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv3D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv3DDerivative;
import org.nd4j.linalg.api.ops.impl.layers.convolution.DepthToSpace;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Im2col;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Im2colBp;
import org.nd4j.linalg.api.ops.impl.layers.convolution.LocalResponseNormalization;
import org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPooling2D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling3D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.SConv2D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.SpaceToDepth;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Upsampling2d;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Upsampling2dDerivative;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.LocalResponseNormalizationConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig;
import org.nd4j.linalg.api.ops.impl.loss.AbsoluteDifferenceLoss;
import org.nd4j.linalg.api.ops.impl.loss.CosineDistanceLoss;
import org.nd4j.linalg.api.ops.impl.loss.HingeLoss;
import org.nd4j.linalg.api.ops.impl.loss.HuberLoss;
import org.nd4j.linalg.api.ops.impl.loss.L2Loss;
import org.nd4j.linalg.api.ops.impl.loss.LogLoss;
import org.nd4j.linalg.api.ops.impl.loss.LogPoissonLoss;
import org.nd4j.linalg.api.ops.impl.loss.MeanPairwiseSquaredErrorLoss;
import org.nd4j.linalg.api.ops.impl.loss.MeanSquaredErrorLoss;
import org.nd4j.linalg.api.ops.impl.loss.SigmoidCrossEntropyLoss;
import org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyLoss;
import org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyWithLogitsLoss;
import org.nd4j.linalg.api.ops.impl.loss.SparseSoftmaxCrossEntropyLossWithLogits;
import org.nd4j.linalg.api.ops.impl.loss.WeightedCrossEntropyLoss;
import org.nd4j.linalg.api.ops.impl.loss.bp.AbsoluteDifferenceLossBp;
import org.nd4j.linalg.api.ops.impl.loss.bp.CosineDistanceLossBp;
import org.nd4j.linalg.api.ops.impl.loss.bp.HingeLossBp;
import org.nd4j.linalg.api.ops.impl.loss.bp.HuberLossBp;
import org.nd4j.linalg.api.ops.impl.loss.bp.LogLossBp;
import org.nd4j.linalg.api.ops.impl.loss.bp.LogPoissonLossBp;
import org.nd4j.linalg.api.ops.impl.loss.bp.MeanPairwiseSquaredErrorLossBp;
import org.nd4j.linalg.api.ops.impl.loss.bp.MeanSquaredErrorLossBp;
import org.nd4j.linalg.api.ops.impl.loss.bp.SigmoidCrossEntropyLossBp;
import org.nd4j.linalg.api.ops.impl.loss.bp.SoftmaxCrossEntropyLossBp;
import org.nd4j.linalg.api.ops.impl.loss.bp.SoftmaxCrossEntropyWithLogitsLossBp;
import org.nd4j.linalg.api.ops.impl.loss.bp.SparseSoftmaxCrossEntropyLossWithLogitsBp;
import org.nd4j.linalg.api.ops.impl.reduce.Mmul;
import org.nd4j.linalg.api.ops.impl.reduce.MmulBp;
import org.nd4j.linalg.api.ops.impl.reduce.Moments;
import org.nd4j.linalg.api.ops.impl.reduce.NormalizeMoments;
import org.nd4j.linalg.api.ops.impl.reduce.TensorMmul;
import org.nd4j.linalg.api.ops.impl.reduce.ZeroFraction;
import org.nd4j.linalg.api.ops.impl.reduce.bool.All;
import org.nd4j.linalg.api.ops.impl.reduce.bool.Any;
import org.nd4j.linalg.api.ops.impl.reduce.bp.CumProdBp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.CumSumBp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.DotBp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.MaxBp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.MeanBp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.MinBp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.Norm1Bp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.Norm2Bp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.NormMaxBp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.ProdBp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.SquaredNormBp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.StandardDeviationBp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.SumBp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.VarianceBp;
import org.nd4j.linalg.api.ops.impl.reduce.custom.BatchMmul;
import org.nd4j.linalg.api.ops.impl.reduce.custom.LogSumExp;
import org.nd4j.linalg.api.ops.impl.reduce.floating.AMean;
import org.nd4j.linalg.api.ops.impl.reduce.floating.Entropy;
import org.nd4j.linalg.api.ops.impl.reduce.floating.LogEntropy;
import org.nd4j.linalg.api.ops.impl.reduce.floating.Mean;
import org.nd4j.linalg.api.ops.impl.reduce.floating.Norm1;
import org.nd4j.linalg.api.ops.impl.reduce.floating.Norm2;
import org.nd4j.linalg.api.ops.impl.reduce.floating.NormMax;
import org.nd4j.linalg.api.ops.impl.reduce.floating.ShannonEntropy;
import org.nd4j.linalg.api.ops.impl.reduce.floating.SquaredNorm;
import org.nd4j.linalg.api.ops.impl.reduce.longer.CountNonZero;
import org.nd4j.linalg.api.ops.impl.reduce.longer.CountZero;
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
import org.nd4j.linalg.api.ops.impl.reduce.same.AMax;
import org.nd4j.linalg.api.ops.impl.reduce.same.AMin;
import org.nd4j.linalg.api.ops.impl.reduce.same.ASum;
import org.nd4j.linalg.api.ops.impl.reduce.same.Min;
import org.nd4j.linalg.api.ops.impl.reduce.same.Prod;
import org.nd4j.linalg.api.ops.impl.reduce.same.Sum;
import org.nd4j.linalg.api.ops.impl.reduce3.CosineDistance;
import org.nd4j.linalg.api.ops.impl.reduce3.CosineSimilarity;
import org.nd4j.linalg.api.ops.impl.reduce3.Dot;
import org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance;
import org.nd4j.linalg.api.ops.impl.reduce3.HammingDistance;
import org.nd4j.linalg.api.ops.impl.reduce3.JaccardDistance;
import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance;
import org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU;
import org.nd4j.linalg.api.ops.impl.scalar.LogX;
import org.nd4j.linalg.api.ops.impl.scalar.Pow;
import org.nd4j.linalg.api.ops.impl.scalar.PowDerivative;
import org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear;
import org.nd4j.linalg.api.ops.impl.scalar.Relu6;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarAdd;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarDivision;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarFMod;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarMax;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarMin;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarMultiplication;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarReverseDivision;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarReverseSubtraction;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarSet;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarSubtraction;
import org.nd4j.linalg.api.ops.impl.scalar.Step;
import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarEquals;
import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarGreaterThan;
import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarGreaterThanOrEqual;
import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarLessThan;
import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarLessThanOrEqual;
import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarNotEquals;
import org.nd4j.linalg.api.ops.impl.scatter.ScatterAdd;
import org.nd4j.linalg.api.ops.impl.scatter.ScatterDiv;
import org.nd4j.linalg.api.ops.impl.scatter.ScatterMax;
import org.nd4j.linalg.api.ops.impl.scatter.ScatterMin;
import org.nd4j.linalg.api.ops.impl.scatter.ScatterMul;
import org.nd4j.linalg.api.ops.impl.scatter.ScatterSub;
import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate;
import org.nd4j.linalg.api.ops.impl.shape.Broadcast;
import org.nd4j.linalg.api.ops.impl.shape.Concat;
import org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix;
import org.nd4j.linalg.api.ops.impl.shape.Cross;
import org.nd4j.linalg.api.ops.impl.shape.Diag;
import org.nd4j.linalg.api.ops.impl.shape.DiagPart;
import org.nd4j.linalg.api.ops.impl.shape.ExpandDims;
import org.nd4j.linalg.api.ops.impl.shape.Gather;
import org.nd4j.linalg.api.ops.impl.shape.GatherNd;
import org.nd4j.linalg.api.ops.impl.shape.Linspace;
import org.nd4j.linalg.api.ops.impl.shape.MergeAvg;
import org.nd4j.linalg.api.ops.impl.shape.MergeMax;
import org.nd4j.linalg.api.ops.impl.shape.MeshGrid;
import org.nd4j.linalg.api.ops.impl.shape.OneHot;
import org.nd4j.linalg.api.ops.impl.shape.OnesLike;
import org.nd4j.linalg.api.ops.impl.shape.ParallelStack;
import org.nd4j.linalg.api.ops.impl.shape.Permute;
import org.nd4j.linalg.api.ops.impl.shape.Rank;
import org.nd4j.linalg.api.ops.impl.shape.ReductionShape;
import org.nd4j.linalg.api.ops.impl.shape.Repeat;
import org.nd4j.linalg.api.ops.impl.shape.Reshape;
import org.nd4j.linalg.api.ops.impl.shape.SequenceMask;
import org.nd4j.linalg.api.ops.impl.shape.Size;
import org.nd4j.linalg.api.ops.impl.shape.SizeAt;
import org.nd4j.linalg.api.ops.impl.shape.Slice;
import org.nd4j.linalg.api.ops.impl.shape.Squeeze;
import org.nd4j.linalg.api.ops.impl.shape.Stack;
import org.nd4j.linalg.api.ops.impl.shape.StridedSlice;
import org.nd4j.linalg.api.ops.impl.shape.Tile;
import org.nd4j.linalg.api.ops.impl.shape.Transpose;
import org.nd4j.linalg.api.ops.impl.shape.Unstack;
import org.nd4j.linalg.api.ops.impl.shape.ZerosLike;
import org.nd4j.linalg.api.ops.impl.shape.bp.SliceBp;
import org.nd4j.linalg.api.ops.impl.shape.bp.StridedSliceBp;
import org.nd4j.linalg.api.ops.impl.shape.bp.TileBp;
import org.nd4j.linalg.api.ops.impl.summarystats.StandardDeviation;
import org.nd4j.linalg.api.ops.impl.summarystats.Variance;
import org.nd4j.linalg.api.ops.impl.transforms.Constant;
import org.nd4j.linalg.api.ops.impl.transforms.Pad;
import org.nd4j.linalg.api.ops.impl.transforms.ReluLayer;
import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax;
import org.nd4j.linalg.api.ops.impl.transforms.bool.IsFinite;
import org.nd4j.linalg.api.ops.impl.transforms.bool.IsInf;
import org.nd4j.linalg.api.ops.impl.transforms.bool.IsNaN;
import org.nd4j.linalg.api.ops.impl.transforms.bool.MatchConditionTransform;
import org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByNorm;
import org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByValue;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndReplace;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet;
import org.nd4j.linalg.api.ops.impl.transforms.custom.ATan2;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Assign;
import org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpace;
import org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd;
import org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Dilation2D;
import org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttention;
import org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttentionBp;
import org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicPartition;
import org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicStitch;
import org.nd4j.linalg.api.ops.impl.transforms.custom.EqualTo;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Fill;
import org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThan;
import org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThanOrEqual;
import org.nd4j.linalg.api.ops.impl.transforms.custom.InvertPermutation;
import org.nd4j.linalg.api.ops.impl.transforms.custom.IsNonDecreasing;
import org.nd4j.linalg.api.ops.impl.transforms.custom.IsNumericTensor;
import org.nd4j.linalg.api.ops.impl.transforms.custom.IsStrictlyIncreasing;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNormBp;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThan;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThanOrEqual;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixDeterminant;
import org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixInverse;
import org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixSetDiag;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Max;
import org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttention;
import org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttentionBp;
import org.nd4j.linalg.api.ops.impl.transforms.custom.NotEqualTo;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Reverse;
import org.nd4j.linalg.api.ops.impl.transforms.custom.ReverseSequence;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SpaceToBatch;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Standardize;
import org.nd4j.linalg.api.ops.impl.transforms.custom.StandardizeBp;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Trace;
import org.nd4j.linalg.api.ops.impl.transforms.custom.XwPlusB;
import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMax;
import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMean;
import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMin;
import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentProd;
import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentSum;
import org.nd4j.linalg.api.ops.impl.transforms.dtype.Cast;
import org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt;
import org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.DynamicPartitionBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.ELUDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LogSoftMaxDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.Relu6Derivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SELUDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.TanhDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.DivOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorDivOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorModOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MergeAddOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RDivOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RSubOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SquaredDifferenceOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SubOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.TruncateDivOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.AddBpOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.DivBpOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.FloorDivBpOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.FloorModBpOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.MulBpOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.RDivBpOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.RSubBpOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.SubBpOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.And;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Or;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Xor;
import org.nd4j.linalg.api.ops.impl.transforms.same.Abs;
import org.nd4j.linalg.api.ops.impl.transforms.same.Ceil;
import org.nd4j.linalg.api.ops.impl.transforms.same.Cube;
import org.nd4j.linalg.api.ops.impl.transforms.same.Floor;
import org.nd4j.linalg.api.ops.impl.transforms.same.Identity;
import org.nd4j.linalg.api.ops.impl.transforms.same.Negative;
import org.nd4j.linalg.api.ops.impl.transforms.same.Reciprocal;
import org.nd4j.linalg.api.ops.impl.transforms.same.Round;
import org.nd4j.linalg.api.ops.impl.transforms.same.Sign;
import org.nd4j.linalg.api.ops.impl.transforms.same.Square;
import org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMax;
import org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMean;
import org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMin;
import org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentProd;
import org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSqrtN;
import org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSum;
import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentMaxBp;
import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentMeanBp;
import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentMinBp;
import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentProdBp;
import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentSumBp;
import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentMaxBp;
import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentMeanBp;
import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentMinBp;
import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentProdBp;
import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentSqrtNBp;
import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentSumBp;
import org.nd4j.linalg.api.ops.impl.transforms.strict.ACos;
import org.nd4j.linalg.api.ops.impl.transforms.strict.ACosh;
import org.nd4j.linalg.api.ops.impl.transforms.strict.ASin;
import org.nd4j.linalg.api.ops.impl.transforms.strict.ASinh;
import org.nd4j.linalg.api.ops.impl.transforms.strict.ATan;
import org.nd4j.linalg.api.ops.impl.transforms.strict.ATanh;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Cos;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Cosh;
import org.nd4j.linalg.api.ops.impl.transforms.strict.ELU;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Erf;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Erfc;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Exp;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Expm1;
import org.nd4j.linalg.api.ops.impl.transforms.strict.GELU;
import org.nd4j.linalg.api.ops.impl.transforms.strict.GELUDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.strict.HardSigmoid;
import org.nd4j.linalg.api.ops.impl.transforms.strict.HardTanh;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Log;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Log1p;
import org.nd4j.linalg.api.ops.impl.transforms.strict.LogSigmoid;
import org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELU;
import org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELUDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.strict.RationalTanh;
import org.nd4j.linalg.api.ops.impl.transforms.strict.RectifiedTanh;
import org.nd4j.linalg.api.ops.impl.transforms.strict.SELU;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Sigmoid;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Sin;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Sinh;
import org.nd4j.linalg.api.ops.impl.transforms.strict.SoftPlus;
import org.nd4j.linalg.api.ops.impl.transforms.strict.SoftSign;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Swish;
import org.nd4j.linalg.api.ops.impl.transforms.strict.SwishDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Tan;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh;
import org.nd4j.linalg.api.ops.random.custom.DistributionUniform;
import org.nd4j.linalg.api.ops.random.custom.RandomBernoulli;
import org.nd4j.linalg.api.ops.random.custom.RandomExponential;
import org.nd4j.linalg.api.ops.random.custom.RandomNormal;
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
import org.nd4j.linalg.api.ops.random.impl.BinomialDistribution;
import org.nd4j.linalg.api.ops.random.impl.DropOutInverted;
import org.nd4j.linalg.api.ops.random.impl.GaussianDistribution;
import org.nd4j.linalg.api.ops.random.impl.LogNormalDistribution;
import org.nd4j.linalg.api.ops.random.impl.Range;
import org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution;
import org.nd4j.linalg.api.ops.random.impl.UniformDistribution;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.util.ArrayUtil;

public class DifferentialFunctionFactory {
    protected SameDiff sameDiff;

    /**
     * Reflective lookup table shared by every instance: lower-cased method name -> declared
     * {@link Method}. Filled on first construction and read by {@link #invoke(String, Object[])}.
     * Overloaded methods share one lower-cased key, so only one overload per name is retained.
     * Declared volatile so the fully-built map is safely published to other threads.
     */
    private static volatile Map<String, Method> methodNames;

    /**
     * Creates a factory that builds ops inside the given {@link SameDiff} instance.
     *
     * @param sameDiff the graph ops are created in; must not be null
     * @throws IllegalArgumentException if {@code sameDiff} is null
     */
    public DifferentialFunctionFactory(SameDiff sameDiff) {
        if (sameDiff == null) {
            throw new IllegalArgumentException("Input not null value.");
        }
        this.sameDiff = sameDiff;
        if (methodNames == null) {
            // Build into a private map and publish only once complete, so that a concurrent
            // constructor or invoke() never observes a half-populated HashMap.
            // NOTE(review): getClass() means a subclass constructed first determines the
            // contents of this shared table - confirm that is intended.
            Map<String, Method> byLowerName = new HashMap<String, Method>();
            for (Method method : this.getClass().getDeclaredMethods()) {
                byLowerName.put(method.getName().toLowerCase(), method);
            }
            methodNames = byLowerName;
        }
    }

    /**
     * Returns the {@link SameDiff} instance that every op built by this factory is attached to.
     *
     * @return the graph held by this factory
     */
    public SameDiff sameDiff() {
        return sameDiff;
    }

    /**
     * Reflectively invokes one of this factory's methods by name.
     *
     * <p>Method names are registered lower-cased, so {@code name} is lower-cased the same way
     * before the lookup. Without this, a mixed-case name such as {@code "zerosLike"} could never
     * match. Overloads share one lower-cased key, so only one overload per name is reachable.
     * Which one depends on the order of {@link Class#getDeclaredMethods()}, and that order is
     * unspecified.
     *
     * @param name method name, in any case
     * @param args arguments passed to {@link Method#invoke(Object, Object...)}
     * @return the method's result, cast to {@link SDVariable}
     * @throws IllegalArgumentException if no method is registered under {@code name}
     * @throws RuntimeException wrapping any reflection failure, any exception thrown by the target,
     *     or a non-{@link SDVariable} result
     */
    public SDVariable invoke(String name, Object[] args) {
        Method method = name == null ? null : methodNames.get(name.toLowerCase());
        if (method == null) {
            throw new IllegalArgumentException("No DifferentialFunctionFactory method found for name: " + name);
        }
        try {
            return (SDVariable) method.invoke(this, args);
        } catch (Exception e) {
            // Preserve the original cause (typically an InvocationTargetException) for diagnosis.
            throw new RuntimeException("Failed to invoke DifferentialFunctionFactory method: " + name, e);
        }
    }

    /**
     * Wraps {@code iX} in a {@link Constant} op that uses the variable's current shape.
     * The op itself is returned; no output variable is requested here.
     */
    public Constant val(SDVariable iX) {
        SameDiff graph = sameDiff();
        long[] currentShape = iX.getShape();
        return new Constant(graph, iX, currentShape);
    }

    /**
     * Convenience overload that delegates to the map-accepting {@code externalErrors} with a
     * {@code null} external-gradients map.
     */
    public ExternalErrorsFunction externalErrors(SDVariable ... inputs) {
        Map<String, INDArray> noGradients = null;
        return externalErrors(noGradients, inputs);
    }

    /**
     * Creates an {@link ExternalErrorsFunction} over {@code inputs} and requests its output
     * variable before returning the function.
     *
     * @param externalGradients gradients keyed by name, passed through to the op (may be null)
     * @param inputs            the variables the external errors apply to; at least one is required
     * @return the created function
     * @throws IllegalArgumentException if {@code inputs} is null or empty
     */
    public ExternalErrorsFunction externalErrors(Map<String, INDArray> externalGradients, SDVariable ... inputs) {
        if (inputs == null || inputs.length == 0) {
            // Render the offending array explicitly ("null" or "[]"). Spreading it as format
            // varargs left the message without a usable value.
            throw new IllegalArgumentException("Require at least one SDVariable to be specified when using external errors: got " + Arrays.toString(inputs));
        }
        ExternalErrorsFunction fn = new ExternalErrorsFunction(this.sameDiff(), Arrays.asList(inputs), externalGradients);
        // Called for its side effect; the returned variable is not needed here.
        fn.outputVariable();
        return fn;
    }

    /** Unnamed variant: delegates to {@code zerosLike(String, SDVariable)} with a null name. */
    public SDVariable zerosLike(SDVariable input) {
        final String noName = null;
        return zerosLike(noName, input);
    }

    /**
     * Validates {@code input} against this factory's graph, then builds a {@link ZerosLike} op
     * and returns its output variable.
     */
    public SDVariable zerosLike(String name, SDVariable input) {
        validateDifferentialFunctionsameDiff(input);
        ZerosLike op = new ZerosLike(name, sameDiff(), input);
        return op.outputVariable();
    }

    /**
     * Validates {@code input}, then builds a {@link OnesLike} op with the requested data type
     * and returns its output variable.
     */
    public SDVariable onesLike(String name, SDVariable input, DataType dataType) {
        validateDifferentialFunctionsameDiff(input);
        OnesLike op = new OnesLike(name, sameDiff(), input, dataType);
        return op.outputVariable();
    }

    /**
     * Builds a {@link Constant} op over {@code input}. A null or empty {@code shape} is passed
     * to the op as {@code null}.
     */
    public SDVariable constant(SDVariable input, long ... shape) {
        SameDiff graph = sameDiff();
        long[] effectiveShape = (shape == null || shape.length == 0) ? null : shape;
        return new Constant(graph, input, effectiveShape).outputVariable();
    }

    /** Builds a {@link Linspace} op from the given lower/upper/count variables and data type. */
    public SDVariable linspace(SDVariable lower, SDVariable upper, SDVariable count, DataType dt) {
        Linspace op = new Linspace(sameDiff(), lower, upper, count, dt);
        return op.outputVariable();
    }

    /** Builds a {@link Range} op with the given from/to/step values and output data type. */
    public SDVariable range(double from, double to, double step, DataType dataType) {
        Range op = new Range(sameDiff(), from, to, step, dataType);
        return op.outputVariable();
    }

    /** Builds a {@link Cast} op converting {@code toCast} to {@code toType}. */
    public SDVariable cast(SDVariable toCast, DataType toType) {
        Cast op = new Cast(sameDiff(), toCast, toType);
        return op.outputVariable();
    }

    /** Builds a {@link MeshGrid} op and returns all of its output variables. */
    public SDVariable[] meshgrid(boolean cartesian, SDVariable ... inputs) {
        MeshGrid op = new MeshGrid(sameDiff(), cartesian, inputs);
        return op.outputVariables();
    }

    /**
     * Builds a {@link DistributionUniform} op with the given min/max. The shape is supplied as a
     * variable (presumably resolved at execution time; confirm).
     */
    public SDVariable randomUniform(double min, double max, SDVariable shape) {
        SameDiff graph = sameDiff();
        return new DistributionUniform(graph, shape, min, max).outputVariable();
    }

    /** Builds a {@link UniformDistribution} op with the given min/max and fixed shape. */
    public SDVariable randomUniform(double min, double max, long ... shape) {
        SameDiff graph = sameDiff();
        return new UniformDistribution(graph, min, max, shape).outputVariable();
    }

    /** Builds a {@link RandomNormal} op with the given mean/std and a shape supplied as a variable. */
    public SDVariable randomNormal(double mean, double std, SDVariable shape) {
        SameDiff graph = sameDiff();
        return new RandomNormal(graph, shape, mean, std).outputVariable();
    }

    /** Builds a {@link GaussianDistribution} op with the given mean/std and fixed shape. */
    public SDVariable randomNormal(double mean, double std, long ... shape) {
        SameDiff graph = sameDiff();
        return new GaussianDistribution(graph, mean, std, shape).outputVariable();
    }

    /** Builds a {@link RandomBernoulli} op with probability {@code p} and a variable-supplied shape. */
    public SDVariable randomBernoulli(double p, SDVariable shape) {
        RandomBernoulli op = new RandomBernoulli(sameDiff(), shape, p);
        return op.outputVariable();
    }

    /** Builds a {@link BernoulliDistribution} op with probability {@code p} and a fixed shape. */
    public SDVariable randomBernoulli(double p, long ... shape) {
        BernoulliDistribution op = new BernoulliDistribution(sameDiff(), p, shape);
        return op.outputVariable();
    }

    /** Builds a {@link BinomialDistribution} op with {@code nTrials} trials, probability {@code p} and a fixed shape. */
    public SDVariable randomBinomial(int nTrials, double p, long ... shape) {
        BinomialDistribution op = new BinomialDistribution(sameDiff(), nTrials, p, shape);
        return op.outputVariable();
    }

    /** Builds a {@link LogNormalDistribution} op with the given mean/stdev and fixed shape. */
    public SDVariable randomLogNormal(double mean, double stdev, long ... shape) {
        LogNormalDistribution op = new LogNormalDistribution(sameDiff(), mean, stdev, shape);
        return op.outputVariable();
    }

    /** Builds a {@link TruncatedNormalDistribution} op with the given mean/stdev and fixed shape. */
    public SDVariable randomNormalTruncated(double mean, double stdev, long ... shape) {
        TruncatedNormalDistribution op = new TruncatedNormalDistribution(sameDiff(), mean, stdev, shape);
        return op.outputVariable();
    }

    /** Builds a {@link RandomExponential} op with rate parameter {@code lambda} and a variable-supplied shape. */
    public SDVariable randomExponential(double lambda, SDVariable shape) {
        RandomExponential op = new RandomExponential(sameDiff(), shape, lambda);
        return op.outputVariable();
    }

    /** Builds a {@link Pad} op over {@code input} using the given padding variable, mode and pad value. */
    public SDVariable pad(SDVariable input, SDVariable padding, Pad.Mode mode, double padValue) {
        Pad op = new Pad(sameDiff(), input, padding, mode, padValue);
        return op.outputVariable();
    }

    /** Builds a {@link LocalResponseNormalization} op over {@code input} with the given config. */
    public SDVariable localResponseNormalization(SDVariable input, LocalResponseNormalizationConfig lrnConfig) {
        SDVariable[] lrnInputs = {input};
        return LocalResponseNormalization.builder()
                .inputFunctions(lrnInputs)
                .sameDiff(sameDiff())
                .config(lrnConfig)
                .build()
                .outputVariable();
    }

    /** Builds a {@link Conv1D} op with inputs {@code [input, weights]} and the given config. */
    public SDVariable conv1d(SDVariable input, SDVariable weights, Conv1DConfig conv1DConfig) {
        SDVariable[] convInputs = {input, weights};
        return Conv1D.builder()
                .inputFunctions(convInputs)
                .sameDiff(sameDiff())
                .config(conv1DConfig)
                .build()
                .outputVariable();
    }

    /** Builds a {@link Conv2D} op over {@code inputs} with the given config. */
    public SDVariable conv2d(SDVariable[] inputs, Conv2DConfig conv2DConfig) {
        return Conv2D.builder()
                .inputFunctions(inputs)
                .sameDiff(sameDiff())
                .config(conv2DConfig)
                .build()
                .outputVariable();
    }

    /** Builds an {@link Upsampling2d} op with the given layout flag and height/width scale factors. */
    public SDVariable upsampling2d(SDVariable input, boolean nchw, int scaleH, int scaleW) {
        Upsampling2d op = new Upsampling2d(sameDiff(), input, nchw, scaleH, scaleW);
        return op.outputVariable();
    }

    /** Builds the {@link Upsampling2dDerivative} (backprop) op for the given input and incoming gradient. */
    public SDVariable upsampling2dBp(SDVariable input, SDVariable gradient, boolean nchw, int scaleH, int scaleW) {
        Upsampling2dDerivative op = new Upsampling2dDerivative(sameDiff(), input, gradient, nchw, scaleH, scaleW);
        return op.outputVariable();
    }

    /** Builds an {@link AvgPooling2D} op over {@code input} with the given pooling config. */
    public SDVariable avgPooling2d(SDVariable input, Pooling2DConfig pooling2DConfig) {
        return AvgPooling2D.builder()
                .input(input)
                .sameDiff(sameDiff())
                .config(pooling2DConfig)
                .build()
                .outputVariable();
    }

    /** Builds a {@link MaxPooling2D} op over {@code input} with the given pooling config. */
    public SDVariable maxPooling2d(SDVariable input, Pooling2DConfig pooling2DConfig) {
        return MaxPooling2D.builder()
                .input(input)
                .sameDiff(sameDiff())
                .config(pooling2DConfig)
                .build()
                .outputVariable();
    }

    /**
     * Average 3D pooling: sets the pooling type to AVG and delegates to {@code pooling3d}.
     * NOTE(review): this mutates the caller's {@code pooling3DConfig} (its type is overwritten).
     */
    public SDVariable avgPooling3d(SDVariable input, Pooling3DConfig pooling3DConfig) {
        Pooling3D.Pooling3DType avgType = Pooling3D.Pooling3DType.AVG;
        pooling3DConfig.setType(avgType);
        return pooling3d(input, pooling3DConfig);
    }

    /**
     * Max 3D pooling: sets the pooling type to MAX and delegates to {@code pooling3d}.
     * NOTE(review): this mutates the caller's {@code pooling3DConfig} (its type is overwritten).
     */
    public SDVariable maxPooling3d(SDVariable input, Pooling3DConfig pooling3DConfig) {
        Pooling3D.Pooling3DType maxType = Pooling3D.Pooling3DType.MAX;
        pooling3DConfig.setType(maxType);
        return pooling3d(input, pooling3DConfig);
    }

    /** Builds a {@link Pooling3D} op whose pooling type is taken from {@code pooling3DConfig}. */
    public SDVariable pooling3d(SDVariable input, Pooling3DConfig pooling3DConfig) {
        SDVariable[] poolInputs = {input};
        return Pooling3D.builder()
                .inputs(poolInputs)
                .sameDiff(sameDiff())
                .pooling3DConfig(pooling3DConfig)
                .type(pooling3DConfig.getType())
                .build()
                .outputVariable();
    }

    /** Builds an {@link SConv2D} (separable 2D convolution) op over {@code inputs} with the given config. */
    public SDVariable sconv2d(SDVariable[] inputs, Conv2DConfig conv2DConfig) {
        return SConv2D.sBuilder()
                .inputFunctions(inputs)
                .sameDiff(sameDiff())
                .conv2DConfig(conv2DConfig)
                .build()
                .outputVariable();
    }

    /**
     * Depthwise 2D convolution, built as an {@link SConv2D} op over {@code inputs}.
     * NOTE(review): this constructs exactly the same op as {@code sconv2d}; confirm a dedicated
     * depthwise op is not intended here.
     */
    public SDVariable depthWiseConv2d(SDVariable[] inputs, Conv2DConfig depthConv2DConfig) {
        return SConv2D.sBuilder()
                .inputFunctions(inputs)
                .sameDiff(sameDiff())
                .conv2DConfig(depthConv2DConfig)
                .build()
                .outputVariable();
    }

    /** Builds a {@link DeConv2D} op over {@code inputs} with the given config. */
    public SDVariable deconv2d(SDVariable[] inputs, DeConv2DConfig deconv2DConfig) {
        return DeConv2D.builder()
                .inputs(inputs)
                .sameDiff(sameDiff())
                .config(deconv2DConfig)
                .build()
                .outputVariable();
    }

    /** Builds a {@link DeConv3D} op from input, weights and bias with the given config. */
    public SDVariable deconv3d(SDVariable input, SDVariable weights, SDVariable bias, DeConv3DConfig config) {
        return new DeConv3D(sameDiff(), input, weights, bias, config).outputVariable();
    }

    /** Builds the {@link DeConv3DDerivative} op and returns all of its output variables. */
    public SDVariable[] deconv3dDerivative(SDVariable input, SDVariable weights, SDVariable bias, SDVariable grad, DeConv3DConfig config) {
        return new DeConv3DDerivative(sameDiff(), input, weights, bias, grad, config).outputVariables();
    }

    public SDVariable conv3d(SDVariable[] inputs, Conv3DConfig conv3DConfig) {
        Conv3D conv3D = Conv3D.builder().inputFunctions(inputs).conv3DConfig(conv3DConfig).sameDiff(this.sameDiff()).build();
        SDVariable[] outputVars = conv3D.outputVariables();
        return outputVars[0];
    }

    /**
     * Builds a {@code BatchNorm} op over {input, mean, variance, gamma, beta} and returns the first
     * output variable. All five inputs are always passed, regardless of {@code applyGamma}/{@code applyBeta}.
     */
    public SDVariable batchNorm(SDVariable input, SDVariable mean, SDVariable variance, SDVariable gamma, SDVariable beta, boolean applyGamma, boolean applyBeta, double epsilon, int ... axis) {
        SDVariable[] bnInputs = new SDVariable[]{input, mean, variance, gamma, beta};
        BatchNorm op = BatchNorm.builder()
                .inputFunctions(bnInputs)
                .applyGamma(applyGamma)
                .applyBeta(applyBeta)
                .epsilon(epsilon)
                .sameDiff(sameDiff())
                .axis(axis)
                .build();
        SDVariable[] outs = op.outputVariables();
        return outs[0];
    }

    /** Builds an {@code Im2col} op over {@code input}; returns its output. */
    public SDVariable im2Col(SDVariable input, Conv2DConfig config) {
        Im2col op = new Im2col(sameDiff(), input, config);
        return op.outputVariable();
    }

    /** Builds an {@code Im2colBp} op from the forward input and output gradient; returns its output. */
    public SDVariable im2ColBp(SDVariable im2colInput, SDVariable gradientAtOutput, Conv2DConfig config) {
        Im2colBp op = new Im2colBp(sameDiff(), im2colInput, gradientAtOutput, config);
        return op.outputVariable();
    }

    /** Builds a {@code Col2Im} op over {@code input}; returns its output. */
    public SDVariable col2Im(SDVariable input, Conv2DConfig config) {
        Col2Im op = new Col2Im(sameDiff(), input, config);
        return op.outputVariable();
    }

    /**
     * Builds an {@code ExtractImagePatches} op, packing the (kH,kW), (sH,sW) and (rH,rW) pairs into
     * two-element arrays; returns its output.
     */
    public SDVariable extractImagePatches(SDVariable input, int kH, int kW, int sH, int sW, int rH, int rW, boolean sameMode) {
        int[] kPair = {kH, kW};
        int[] sPair = {sH, sW};
        int[] rPair = {rH, rW};
        ExtractImagePatches op = new ExtractImagePatches(sameDiff(), input, kPair, sPair, rPair, sameMode);
        return op.outputVariable();
    }

    /** Builds a {@code Moments} op over the given axes and returns all of its output variables. */
    public SDVariable[] moments(SDVariable input, int ... axes) {
        Moments op = new Moments(sameDiff(), input, axes);
        return op.outputVariables();
    }

    /** Builds a {@code NormalizeMoments} op and returns all of its output variables. */
    public SDVariable[] normalizeMoments(SDVariable counts, SDVariable means, SDVariable variances, double shift) {
        NormalizeMoments op = new NormalizeMoments(sameDiff(), counts, means, variances, shift);
        return op.outputVariables();
    }

    /**
     * Builds a {@code Tile} op over {@code iX} with the given {@code repeat} array and returns its output.
     *
     * @throws NullPointerException if {@code iX} or {@code repeat} is null (explicit checks, in that order)
     */
    public SDVariable tile(@NonNull SDVariable iX, @NonNull int[] repeat) {
        if (iX == null) {
            throw new NullPointerException("iX is marked @NonNull but is null");
        }
        if (repeat == null) {
            throw new NullPointerException("repeat is marked @NonNull but is null");
        }
        Tile op = new Tile(sameDiff(), iX, repeat);
        return op.outputVariable();
    }

    /**
     * Builds a {@code TileBp} op from the forward input, the incoming gradient and the
     * {@code repeat} array used by the forward tile, and returns its output.
     *
     * <p>Obtains the graph through the {@code sameDiff()} accessor, like every other factory method
     * here, rather than reading the raw field directly.
     *
     * @throws NullPointerException if {@code in}, {@code grad} or {@code repeat} is null
     */
    public SDVariable tileBp(@NonNull SDVariable in, @NonNull SDVariable grad, @NonNull int[] repeat) {
        if (in == null) {
            throw new NullPointerException("in is marked @NonNull but is null");
        }
        if (grad == null) {
            throw new NullPointerException("grad is marked @NonNull but is null");
        }
        if (repeat == null) {
            throw new NullPointerException("repeat is marked @NonNull but is null");
        }
        return new TileBp(this.sameDiff(), in, grad, repeat).outputVariable();
    }

    /** Builds a {@code DropOutInverted} op over {@code input} with value {@code p}; returns its output. */
    public SDVariable dropout(SDVariable input, double p) {
        DropOutInverted op = new DropOutInverted(sameDiff(), input, p);
        return op.outputVariable();
    }

    /** Builds a {@code Sum} op over {@code i_x} with the keepDims flag and dimensions; returns its output. */
    public SDVariable sum(SDVariable i_x, boolean keepDims, int ... dimensions) {
        Sum op = new Sum(sameDiff(), i_x, keepDims, dimensions);
        return op.outputVariable();
    }

    /** Builds a {@code SumBp} op from the forward input and gradient; returns its output. */
    public SDVariable sumBp(SDVariable i_x, SDVariable grad, boolean keepDims, int ... dimensions) {
        SumBp op = new SumBp(sameDiff(), i_x, grad, keepDims, dimensions);
        return op.outputVariable();
    }

    /** Builds a {@code Prod} op over {@code i_x}; returns its output. */
    public SDVariable prod(SDVariable i_x, boolean keepDims, int ... dimensions) {
        Prod op = new Prod(sameDiff(), i_x, keepDims, dimensions);
        return op.outputVariable();
    }

    /** Builds a {@code ProdBp} op from the pre-reduction input and gradient; returns its output. */
    public SDVariable prodBp(SDVariable preReduceInput, SDVariable grad, boolean keepDims, int ... dimensions) {
        ProdBp op = new ProdBp(sameDiff(), preReduceInput, grad, keepDims, dimensions);
        return op.outputVariable();
    }

    /** Builds a {@code Mean} op over {@code in}; returns its output. */
    public SDVariable mean(SDVariable in, boolean keepDims, int ... dimensions) {
        Mean op = new Mean(sameDiff(), in, keepDims, dimensions);
        return op.outputVariable();
    }

    /** Builds a {@code MeanBp} op from the forward input and gradient; returns its output. */
    public SDVariable meanBp(SDVariable in, SDVariable grad, boolean keepDims, int ... dimensions) {
        MeanBp op = new MeanBp(sameDiff(), in, grad, keepDims, dimensions);
        return op.outputVariable();
    }

    /** Builds a {@code StandardDeviation} op over {@code i_x}; returns its output. */
    public SDVariable std(SDVariable i_x, boolean biasCorrected, boolean keepDims, int ... dimensions) {
        StandardDeviation op = new StandardDeviation(sameDiff(), i_x, biasCorrected, keepDims, dimensions);
        return op.outputVariable();
    }

    /** Builds a {@code StandardDeviationBp} op from the forward input and gradient; returns its output. */
    public SDVariable stdBp(SDVariable stdInput, SDVariable gradient, boolean biasCorrected, boolean keepDims, int ... dimensions) {
        StandardDeviationBp op = new StandardDeviationBp(sameDiff(), stdInput, gradient, biasCorrected, keepDims, dimensions);
        return op.outputVariable();
    }

    /** Builds a {@code Variance} op over {@code i_x}; returns its output. */
    public SDVariable variance(SDVariable i_x, boolean biasCorrected, boolean keepDims, int ... dimensions) {
        Variance op = new Variance(sameDiff(), i_x, biasCorrected, keepDims, dimensions);
        return op.outputVariable();
    }

    /** Builds a {@code VarianceBp} op from the forward input and gradient; returns its output. */
    public SDVariable varianceBp(SDVariable stdInput, SDVariable gradient, boolean biasCorrected, boolean keepDims, int ... dimensions) {
        VarianceBp op = new VarianceBp(sameDiff(), stdInput, gradient, biasCorrected, keepDims, dimensions);
        return op.outputVariable();
    }

    /** Builds a {@code Standardize} op over {@code i_x}; returns its output. */
    public SDVariable standardize(SDVariable i_x, int ... dimensions) {
        Standardize op = new Standardize(sameDiff(), i_x, dimensions);
        return op.outputVariable();
    }

    /** Builds a {@code StandardizeBp} op from the forward input and gradient; returns its output. */
    public SDVariable standardizeBp(SDVariable stdInput, SDVariable gradient, int ... dimensions) {
        StandardizeBp op = new StandardizeBp(sameDiff(), stdInput, gradient, dimensions);
        return op.outputVariable();
    }

    /** Builds a {@code LayerNorm} op with gain and bias; returns its output. */
    public SDVariable layerNorm(SDVariable input, SDVariable gain, SDVariable bias, int ... dimensions) {
        LayerNorm op = new LayerNorm(sameDiff(), input, gain, bias, dimensions);
        return op.outputVariable();
    }

    /** Builds a {@code LayerNormBp} op (with bias) and returns all of its output variables. */
    public SDVariable[] layerNormBp(SDVariable input, SDVariable gain, SDVariable bias, SDVariable gradient, int ... dimensions) {
        LayerNormBp op = new LayerNormBp(sameDiff(), input, gain, bias, gradient, dimensions);
        return op.outputVariables();
    }

    /** Builds a {@code LayerNorm} op with gain only (no bias argument); returns its output. */
    public SDVariable layerNorm(SDVariable input, SDVariable gain, int ... dimensions) {
        LayerNorm op = new LayerNorm(sameDiff(), input, gain, dimensions);
        return op.outputVariable();
    }

    /** Builds a {@code LayerNormBp} op (no bias argument) and returns all of its output variables. */
    public SDVariable[] layerNormBp(SDVariable input, SDVariable gain, SDVariable gradient, int ... dimensions) {
        LayerNormBp op = new LayerNormBp(sameDiff(), input, gain, gradient, dimensions);
        return op.outputVariables();
    }

    /** Builds a {@code SquaredNorm} op over {@code input}; returns its output. */
    public SDVariable squaredNorm(SDVariable input, boolean keepDims, int ... dimensions) {
        SquaredNorm op = new SquaredNorm(sameDiff(), input, keepDims, dimensions);
        return op.outputVariable();
    }

    /** Builds a {@code SquaredNormBp} op from the pre-reduction input and gradient; returns its output. */
    public SDVariable squaredNormBp(SDVariable preReduceInput, SDVariable gradient, boolean keepDims, int ... dimensions) {
        SquaredNormBp op = new SquaredNormBp(sameDiff(), preReduceInput, gradient, keepDims, dimensions);
        return op.outputVariable();
    }

    /** Builds an {@code Entropy} op over {@code in}; returns its output. */
    public SDVariable entropy(SDVariable in, int ... dimensions) {
        Entropy op = new Entropy(sameDiff(), in, dimensions);
        return op.outputVariable();
    }

    /** Builds a {@code LogEntropy} op over {@code in}; returns its output. */
    public SDVariable logEntropy(SDVariable in, int ... dimensions) {
        LogEntropy op = new LogEntropy(sameDiff(), in, dimensions);
        return op.outputVariable();
    }

    /** Builds a {@code ShannonEntropy} op over {@code in}; returns its output. */
    public SDVariable shannonEntropy(SDVariable in, int ... dimensions) {
        ShannonEntropy op = new ShannonEntropy(sameDiff(), in, dimensions);
        return op.outputVariable();
    }

    /** Builds a {@code CountNonZero} op over {@code input}; returns its output. */
    public SDVariable countNonZero(SDVariable input, int ... dimensions) {
        CountNonZero op = new CountNonZero(sameDiff(), input, dimensions);
        return op.outputVariable();
    }

    /** Builds a {@code CountZero} op over {@code input}; returns its output. */
    public SDVariable countZero(SDVariable input, int ... dimensions) {
        CountZero op = new CountZero(sameDiff(), input, dimensions);
        return op.outputVariable();
    }

    /** Builds a {@code ZeroFraction} op over {@code input}; returns its output. */
    public SDVariable zeroFraction(SDVariable input) {
        ZeroFraction op = new ZeroFraction(sameDiff(), input);
        return op.outputVariable();
    }

    /** Builds a {@code ScalarMax} op of {@code in} against {@code num}; returns its output. */
    public SDVariable scalarMax(SDVariable in, Number num) {
        ScalarMax op = new ScalarMax(sameDiff(), in, num);
        return op.outputVariable();
    }

    /** Builds a {@code ScalarMin} op of {@code in} against {@code num}; returns its output. */
    public SDVariable scalarMin(SDVariable in, Number num) {
        ScalarMin op = new ScalarMin(sameDiff(), in, num);
        return op.outputVariable();
    }

    /** Builds a {@code ScalarSet} op of {@code in} with {@code num}; returns its output. */
    public SDVariable scalarSet(SDVariable in, Number num) {
        ScalarSet op = new ScalarSet(sameDiff(), in, num);
        return op.outputVariable();
    }

    /** Builds a {@code ScalarFMod} op of {@code in} against {@code num}; returns its output. */
    public SDVariable scalarFloorMod(SDVariable in, Number num) {
        ScalarFMod op = new ScalarFMod(sameDiff(), in, num);
        return op.outputVariable();
    }

    /** Builds the {@code reduce.same.Max} op over {@code i_x} along {@code dimensions}; returns its output. */
    public SDVariable max(SDVariable i_x, boolean keepDims, int ... dimensions) {
        return new org.nd4j.linalg.api.ops.impl.reduce.same.Max(sameDiff(), i_x, keepDims, dimensions)
                .outputVariable();
    }

    /** Builds the two-input {@code Max} op (as imported by this file) over the pair; returns its output. */
    public SDVariable max(SDVariable first, SDVariable second) {
        Max op = new Max(sameDiff(), first, second);
        return op.outputVariable();
    }

    /** Builds a {@code MaxBp} op from the forward input and gradient; returns its output. */
    public SDVariable maxBp(SDVariable i_x, SDVariable grad, boolean keepDims, int ... dimensions) {
        MaxBp op = new MaxBp(sameDiff(), i_x, grad, keepDims, dimensions);
        return op.outputVariable();
    }

    /** Builds the imported {@code Min} op over {@code i_x} along {@code dimensions}; returns its output. */
    public SDVariable min(SDVariable i_x, boolean keepDims, int ... dimensions) {
        Min op = new Min(sameDiff(), i_x, keepDims, dimensions);
        return op.outputVariable();
    }

    /** Builds a {@code MinBp} op from the forward input and gradient; returns its output. */
    public SDVariable minBp(SDVariable i_x, SDVariable grad, boolean keepDims, int ... dimensions) {
        MinBp op = new MinBp(sameDiff(), i_x, grad, keepDims, dimensions);
        return op.outputVariable();
    }

    /** Builds the {@code transforms.custom.Min} op over the pair; returns its output. */
    public SDVariable min(SDVariable first, SDVariable second) {
        return new org.nd4j.linalg.api.ops.impl.transforms.custom.Min(sameDiff(), first, second)
                .outputVariable();
    }

    /** Builds an {@code AMax} op over {@code in}; returns its output. */
    public SDVariable amax(SDVariable in, int ... dimensions) {
        AMax op = new AMax(sameDiff(), in, dimensions);
        return op.outputVariable();
    }

    /** Builds an {@code AMin} op over {@code in}; returns its output. */
    public SDVariable amin(SDVariable in, int ... dimensions) {
        AMin op = new AMin(sameDiff(), in, dimensions);
        return op.outputVariable();
    }

    /** Builds an {@code AMean} op over {@code in}; returns its output. */
    public SDVariable amean(SDVariable in, int ... dimensions) {
        AMean op = new AMean(sameDiff(), in, dimensions);
        return op.outputVariable();
    }

    /** Builds an {@code ASum} op over {@code in}; returns its output. */
    public SDVariable asum(SDVariable in, int ... dimensions) {
        ASum op = new ASum(sameDiff(), in, dimensions);
        return op.outputVariable();
    }

    /** Builds an {@code IMax} index-accumulation op over {@code in}; returns its output. */
    public SDVariable argmax(SDVariable in, boolean keepDims, int ... dimensions) {
        IMax op = new IMax(sameDiff(), in, keepDims, dimensions);
        return op.outputVariable();
    }

    /** Builds an {@code IMin} index-accumulation op over {@code in}; returns its output. */
    public SDVariable argmin(SDVariable in, boolean keepDims, int ... dimensions) {
        IMin op = new IMin(sameDiff(), in, keepDims, dimensions);
        return op.outputVariable();
    }

    /** Builds an {@code IAMax} index-accumulation op over {@code in}; returns its output. */
    public SDVariable iamax(SDVariable in, boolean keepDims, int ... dimensions) {
        IAMax op = new IAMax(sameDiff(), in, keepDims, dimensions);
        return op.outputVariable();
    }

    /** Builds an {@code IAMin} index-accumulation op over {@code in}; returns its output. */
    public SDVariable iamin(SDVariable in, boolean keepDims, int ... dimensions) {
        IAMin op = new IAMin(sameDiff(), in, keepDims, dimensions);
        return op.outputVariable();
    }

    /** Builds a {@code FirstIndex} op over {@code in} with the given condition; returns its output. */
    public SDVariable firstIndex(SDVariable in, Condition condition, boolean keepDims, int ... dimensions) {
        FirstIndex op = new FirstIndex(sameDiff(), in, condition, keepDims, dimensions);
        return op.outputVariable();
    }

    /** Builds a {@code LastIndex} op over {@code in} with the given condition; returns its output. */
    public SDVariable lastIndex(SDVariable in, Condition condition, boolean keepDims, int ... dimensions) {
        LastIndex op = new LastIndex(sameDiff(), in, condition, keepDims, dimensions);
        return op.outputVariable();
    }

    /** Builds a {@code MatchCondition} op over {@code in}; returns its output. */
    public SDVariable matchConditionCount(SDVariable in, Condition condition, boolean keepDims, int ... dimensions) {
        MatchCondition op = new MatchCondition(sameDiff(), in, condition, keepDims, dimensions);
        return op.outputVariable();
    }

    /** Builds a {@code MatchConditionTransform} op over {@code in}; returns its output. */
    public SDVariable matchCondition(SDVariable in, Condition condition) {
        MatchConditionTransform op = new MatchConditionTransform(sameDiff(), in, condition);
        return op.outputVariable();
    }

    /** Builds a {@code CumSum} op over {@code in}; returns its output. */
    public SDVariable cumsum(SDVariable in, boolean exclusive, boolean reverse, int ... axis) {
        CumSum op = new CumSum(sameDiff(), in, exclusive, reverse, axis);
        return op.outputVariable();
    }

    /** Builds a {@code CumSumBp} op from the forward input and gradient; returns its output. */
    public SDVariable cumsumBp(SDVariable in, SDVariable grad, boolean exclusive, boolean reverse, int ... axis) {
        CumSumBp op = new CumSumBp(sameDiff(), in, grad, exclusive, reverse, axis);
        return op.outputVariable();
    }

    /** Builds a {@code CumProd} op over {@code in}; returns its output. */
    public SDVariable cumprod(SDVariable in, boolean exclusive, boolean reverse, int ... axis) {
        CumProd op = new CumProd(sameDiff(), in, exclusive, reverse, axis);
        return op.outputVariable();
    }

    /** Builds a {@code CumProdBp} op from the forward input and gradient; returns its output. */
    public SDVariable cumprodBp(SDVariable in, SDVariable grad, boolean exclusive, boolean reverse, int ... axis) {
        CumProdBp op = new CumProdBp(sameDiff(), in, grad, exclusive, reverse, axis);
        return op.outputVariable();
    }

    /** Builds a {@code BiasAdd} op of {@code input} and {@code bias}; returns its output. */
    public SDVariable biasAdd(SDVariable input, SDVariable bias) {
        BiasAdd op = new BiasAdd(sameDiff(), input, bias);
        return op.outputVariable();
    }

    /** Builds a {@code BiasAddGrad} op and returns all of its output variables. */
    public SDVariable[] biasAddBp(SDVariable input, SDVariable bias, SDVariable grad) {
        BiasAddGrad op = new BiasAddGrad(sameDiff(), input, bias, grad);
        return op.outputVariables();
    }

    /** Builds a {@code Norm1} op over {@code i_x}; returns its output. */
    public SDVariable norm1(SDVariable i_x, boolean keepDims, int ... dimensions) {
        Norm1 op = new Norm1(sameDiff(), i_x, keepDims, dimensions);
        return op.outputVariable();
    }

    /** Builds a {@code Norm1Bp} op from the pre-reduction input and gradient; returns its output. */
    public SDVariable norm1Bp(SDVariable preReduceIn, SDVariable grad, boolean keepDims, int ... dimensions) {
        Norm1Bp op = new Norm1Bp(sameDiff(), preReduceIn, grad, keepDims, dimensions);
        return op.outputVariable();
    }

    /** Builds a {@code Norm2} op over {@code i_x}; returns its output. */
    public SDVariable norm2(SDVariable i_x, boolean keepDims, int ... dimensions) {
        Norm2 op = new Norm2(sameDiff(), i_x, keepDims, dimensions);
        return op.outputVariable();
    }

    /** Builds a {@code Norm2Bp} op from the pre-reduction input and gradient; returns its output. */
    public SDVariable norm2Bp(SDVariable preReduceIn, SDVariable grad, boolean keepDims, int ... dimensions) {
        Norm2Bp op = new Norm2Bp(sameDiff(), preReduceIn, grad, keepDims, dimensions);
        return op.outputVariable();
    }

    /** Builds a {@code NormMax} op over {@code i_x}; returns its output. */
    public SDVariable normmax(SDVariable i_x, boolean keepDims, int ... dimensions) {
        NormMax op = new NormMax(sameDiff(), i_x, keepDims, dimensions);
        return op.outputVariable();
    }

    /** Builds a {@code NormMaxBp} op from the pre-reduction input and gradient; returns its output. */
    public SDVariable normmaxBp(SDVariable preReduceIn, SDVariable grad, boolean keepDims, int ... dimensions) {
        NormMaxBp op = new NormMaxBp(sameDiff(), preReduceIn, grad, keepDims, dimensions);
        return op.outputVariable();
    }

    /** Builds a {@code ReductionShape} op from a shape variable and axis variable; returns its output. */
    public SDVariable reductionShape(SDVariable shape, SDVariable axis, boolean keepDim) {
        ReductionShape op = new ReductionShape(sameDiff(), shape, axis, keepDim);
        return op.outputVariable();
    }

    /**
     * Re-inserts a size-1 axis into {@code toExpand} at every index in {@code reduceDims}, so that a
     * reduction result lines up with an array of rank {@code origRank}.
     *
     * <p>{@code toExpand} is returned unchanged in two cases:
     * <ul>
     *   <li>when {@code Shape.isWholeArray(origRank, reduceDims)} holds;</li>
     *   <li>when {@code origRank == 2} and exactly one dimension was reduced.</li>
     * </ul>
     *
     * <p>{@code expandDims} is applied one axis at a time, and each insertion shifts the positions of
     * later axes. The axes are therefore normalized to non-negative indices relative to
     * {@code origRank} and processed in ascending order. With that ordering, each index refers to its
     * position in the original (unreduced) shape, whatever order the caller listed them in. The
     * caller's {@code reduceDims} array is not modified.
     *
     * @param origRank   rank of the array before reduction
     * @param reduceDims dimensions that were reduced; may be negative or unsorted
     * @param toExpand   the reduced variable
     * @return {@code toExpand}, possibly with size-1 axes inserted
     */
    public SDVariable reductionBroadcastableWithOrigShape(int origRank, int[] reduceDims, SDVariable toExpand) {
        if (Shape.isWholeArray(origRank, reduceDims)) {
            // presumably a full (whole-array) reduction, whose result is returned as-is
            return toExpand;
        }
        if (origRank == 2 && reduceDims.length == 1) {
            // NOTE(review): original behaviour kept; presumably a single-axis reduction of a rank-2
            // array is already broadcastable in this backend. Confirm against the reduce ops.
            return toExpand;
        }
        int[] axes = new int[reduceDims.length];
        for (int i = 0; i < reduceDims.length; i++) {
            int d = reduceDims[i];
            axes[i] = d < 0 ? d + origRank : d;
        }
        // Ascending order: inserting lower axes first leaves higher original indices valid.
        Arrays.sort(axes);
        SDVariable expanded = toExpand;
        for (int d : axes) {
            expanded = this.sameDiff().expandDims(expanded, d);
        }
        return expanded;
    }

    /**
     * Reshapes {@code toExpand} to the shape computed by {@link #reductionShape} from
     * {@code origInput.shape()} and {@code axis} with keepDim = true.
     */
    public SDVariable reductionBroadcastableWithOrigShape(SDVariable origInput, SDVariable axis, SDVariable toExpand) {
        SDVariable keepDimsShape = reductionShape(origInput.shape(), axis, true);
        return toExpand.reshape(keepDimsShape);
    }

    /**
     * Builds a {@code GradientBackwardsMarker} op for {@code iX} and returns its output.
     *
     * <p>The op is paired with a new scalar variable named {@code iX.getVarName() + "-pairgrad"} and
     * holding 1.0. Both the op and the scalar are created through the {@code sameDiff()} accessor,
     * consistent with the other factory methods, instead of mixing the accessor with direct field
     * access.
     */
    public SDVariable gradientBackwardsMarker(SDVariable iX) {
        SDVariable pairGrad = this.sameDiff().scalar(iX.getVarName() + "-pairgrad", 1.0);
        return new GradientBackwardsMarker(this.sameDiff(), iX, pairGrad).outputVariable();
    }

    // NOTE(review): the trailing boolean false passed to several op constructors below is presumably
    // the in-place flag; confirm against the op classes.

    /** Builds an {@code Abs} op over {@code iX}; returns its output. */
    public SDVariable abs(SDVariable iX) {
        Abs op = new Abs(sameDiff(), iX, false);
        return op.outputVariable();
    }

    /** Builds a {@code Negative} op over {@code iX}; returns its output. */
    public SDVariable neg(SDVariable iX) {
        Negative op = new Negative(sameDiff(), iX, false);
        return op.outputVariable();
    }

    /** Builds a {@code Cos} op over {@code iX}; returns its output. */
    public SDVariable cos(SDVariable iX) {
        Cos op = new Cos(sameDiff(), iX, false);
        return op.outputVariable();
    }

    /** Builds a {@code Sin} op over {@code iX}; returns its output. */
    public SDVariable sin(SDVariable iX) {
        Sin op = new Sin(sameDiff(), iX, false);
        return op.outputVariable();
    }

    /** Builds a {@code Tan} op over {@code iX}; returns its output. */
    public SDVariable tan(SDVariable iX) {
        Tan op = new Tan(sameDiff(), iX, false);
        return op.outputVariable();
    }

    /** Builds a {@code Permute} op over {@code iX} with the given dimensions; returns its output. */
    public SDVariable permute(SDVariable iX, int ... dimensions) {
        Permute op = new Permute(sameDiff(), iX, dimensions);
        return op.outputVariable();
    }

    /** Builds a {@code NoOp} over {@code input}; returns its output. */
    public SDVariable noop(SDVariable input) {
        NoOp op = new NoOp(sameDiff(), input);
        return op.outputVariable();
    }

    /** Builds an {@code Identity} op over {@code input}; returns its output. */
    public SDVariable identity(SDVariable input) {
        Identity op = new Identity(sameDiff(), input);
        return op.outputVariable();
    }

    /** Builds an {@code All} op over {@code input}; returns its output. */
    public SDVariable all(SDVariable input, int ... dimensions) {
        All op = new All(sameDiff(), input, dimensions);
        return op.outputVariable();
    }

    /** Builds an {@code Any} op over {@code input}; returns its output. */
    public SDVariable any(SDVariable input, int ... dimensions) {
        Any op = new Any(sameDiff(), input, dimensions);
        return op.outputVariable();
    }

    /** Builds an {@code InvertPermutation} op over {@code input}; returns its output. */
    public SDVariable invertPermutation(SDVariable input, boolean inPlace) {
        InvertPermutation op = new InvertPermutation(sameDiff(), input, inPlace);
        return op.outputVariable();
    }

    /** Builds a {@code Transpose} op over {@code iX}; returns its output. */
    public SDVariable transpose(SDVariable iX) {
        Transpose op = new Transpose(sameDiff(), iX);
        return op.outputVariable();
    }

    /** Builds an {@code ACos} op over {@code iX}; returns its output. */
    public SDVariable acos(SDVariable iX) {
        ACos op = new ACos(sameDiff(), iX, false);
        return op.outputVariable();
    }

    /** Builds an {@code ASin} op over {@code iX}; returns its output. */
    public SDVariable asin(SDVariable iX) {
        ASin op = new ASin(sameDiff(), iX, false);
        return op.outputVariable();
    }

    /** Builds an {@code ATan} op over {@code iX}; returns its output. */
    public SDVariable atan(SDVariable iX) {
        ATan op = new ATan(sameDiff(), iX, false);
        return op.outputVariable();
    }

    /** Builds an {@code ATan2} op from {@code y} and {@code x} (in that order); returns its output. */
    public SDVariable atan2(SDVariable y, SDVariable x) {
        ATan2 op = new ATan2(sameDiff(), y, x);
        return op.outputVariable();
    }

    /** Builds a {@code Cosh} op over {@code iX}; returns its output. */
    public SDVariable cosh(SDVariable iX) {
        Cosh op = new Cosh(sameDiff(), iX, false);
        return op.outputVariable();
    }

    /** Builds a {@code Sinh} op over {@code iX}; returns its output. */
    public SDVariable sinh(SDVariable iX) {
        Sinh op = new Sinh(sameDiff(), iX, false);
        return op.outputVariable();
    }

    /** Builds a {@code Tanh} op over {@code iX}; returns its output. */
    public SDVariable tanh(SDVariable iX) {
        Tanh op = new Tanh(sameDiff(), iX, false);
        return op.outputVariable();
    }

    /** Builds a {@code RationalTanh} op over {@code in}; returns its output. */
    public SDVariable tanhRational(SDVariable in) {
        RationalTanh op = new RationalTanh(sameDiff(), in, false);
        return op.outputVariable();
    }

    /** Builds a {@code RectifiedTanh} op over {@code in}; returns its output. */
    public SDVariable tanhRectified(SDVariable in) {
        RectifiedTanh op = new RectifiedTanh(sameDiff(), in, false);
        return op.outputVariable();
    }

    /** Builds a {@code TanhDerivative} op from {@code iX} and {@code wrt}; returns its output. */
    public SDVariable tanhDerivative(SDVariable iX, SDVariable wrt) {
        TanhDerivative op = new TanhDerivative(sameDiff(), iX, wrt);
        return op.outputVariable();
    }

    /** Builds a {@code RationalTanhDerivative} op over {@code in}; returns its output. */
    public SDVariable tanhRationalDerivative(SDVariable in) {
        RationalTanhDerivative op = new RationalTanhDerivative(sameDiff(), in, false);
        return op.outputVariable();
    }

    /** Builds a {@code RectifiedTanhDerivative} op over {@code in}; returns its output. */
    public SDVariable tanhRectifiedDerivative(SDVariable in) {
        RectifiedTanhDerivative op = new RectifiedTanhDerivative(sameDiff(), in, false);
        return op.outputVariable();
    }

    /** Builds a {@code Step} op over {@code in} with the given cutoff; returns its output. */
    public SDVariable step(SDVariable in, double cutoff) {
        Step op = new Step(sameDiff(), in, false, cutoff);
        return op.outputVariable();
    }

    /** Builds an {@code ACosh} op over {@code iX}; returns its output. */
    public SDVariable acosh(SDVariable iX) {
        ACosh op = new ACosh(sameDiff(), iX);
        return op.outputVariable();
    }

    /** Builds an {@code ASinh} op over {@code iX}; returns its output. */
    public SDVariable asinh(SDVariable iX) {
        ASinh op = new ASinh(sameDiff(), iX);
        return op.outputVariable();
    }

    /** Builds an {@code ATanh} op over {@code iX}; returns its output. */
    public SDVariable atanh(SDVariable iX) {
        ATanh op = new ATanh(sameDiff(), iX);
        return op.outputVariable();
    }

    /** Builds an {@code Exp} op over {@code iX}; returns its output. */
    public SDVariable exp(SDVariable iX) {
        Exp op = new Exp(sameDiff(), iX, false);
        return op.outputVariable();
    }

    /** Builds an {@code Expm1} op over {@code iX}; returns its output. */
    public SDVariable expm1(SDVariable iX) {
        Expm1 op = new Expm1(sameDiff(), iX, false);
        return op.outputVariable();
    }

    /** Builds an {@code RSqrt} op over {@code iX}; returns its output. */
    public SDVariable rsqrt(SDVariable iX) {
        RSqrt op = new RSqrt(sameDiff(), iX, false);
        return op.outputVariable();
    }

    /** Builds a {@code Log} op over {@code iX}; returns its output. */
    public SDVariable log(SDVariable iX) {
        Log op = new Log(sameDiff(), iX, false);
        return op.outputVariable();
    }

    /** Builds a {@code LogX} op over {@code in} with the given base; returns its output. */
    public SDVariable log(SDVariable in, double base) {
        LogX op = new LogX(sameDiff(), in, base);
        return op.outputVariable();
    }

    /** Builds a {@code Log1p} op over {@code iX}; returns its output. */
    public SDVariable log1p(SDVariable iX) {
        Log1p op = new Log1p(sameDiff(), iX, false);
        return op.outputVariable();
    }

    /** Builds an {@code IsFinite} op over {@code ix}; returns its output. */
    public SDVariable isFinite(SDVariable ix) {
        IsFinite op = new IsFinite(sameDiff(), ix, false);
        return op.outputVariable();
    }

    /** Builds an {@code IsInf} op over {@code ix}; returns its output. */
    public SDVariable isInfinite(SDVariable ix) {
        IsInf op = new IsInf(sameDiff(), ix, false);
        return op.outputVariable();
    }

    /** Builds an {@code IsNaN} op over {@code ix}; returns its output. */
    public SDVariable isNaN(SDVariable ix) {
        IsNaN op = new IsNaN(sameDiff(), ix, false);
        return op.outputVariable();
    }

    /** Builds an {@code IsMax} op over {@code ix}; returns its output. */
    public SDVariable isMax(SDVariable ix) {
        IsMax op = new IsMax(sameDiff(), ix, false);
        return op.outputVariable();
    }

    /** Builds a {@code CompareAndReplace} op from {@code to}, {@code from} and a condition; returns its output. */
    public SDVariable replaceWhere(SDVariable to, SDVariable from, Condition condition) {
        CompareAndReplace op = new CompareAndReplace(sameDiff(), to, from, condition);
        return op.outputVariable();
    }

    /** Builds a {@code CompareAndSet} op from {@code to}, a scalar {@code set} and a condition; returns its output. */
    public SDVariable replaceWhere(SDVariable to, Number set, Condition condition) {
        CompareAndSet op = new CompareAndSet(sameDiff(), to, set, condition);
        return op.outputVariable();
    }

    /** Builds a {@code Round} op over {@code ix}; returns its output. */
    public SDVariable round(SDVariable ix) {
        Round op = new Round(sameDiff(), ix, false);
        return op.outputVariable();
    }

    /** Builds an {@code Or} op over the pair; returns its output. */
    public SDVariable or(SDVariable iX, SDVariable i_y) {
        Or op = new Or(sameDiff(), iX, i_y);
        return op.outputVariable();
    }

    /** Builds an {@code And} op over the pair; returns its output. */
    public SDVariable and(SDVariable ix, SDVariable iy) {
        And op = new And(sameDiff(), ix, iy);
        return op.outputVariable();
    }

    /** Builds an {@code Xor} op over the pair; returns its output. */
    public SDVariable xor(SDVariable ix, SDVariable iy) {
        Xor op = new Xor(sameDiff(), ix, iy);
        return op.outputVariable();
    }

    /** Builds an {@code EqualTo} op over {iX, i_y} with the boolean flag false; returns its output. */
    public SDVariable eq(SDVariable iX, SDVariable i_y) {
        SDVariable[] operands = {iX, i_y};
        EqualTo op = new EqualTo(sameDiff(), operands, false);
        return op.outputVariable();
    }

    /** Builds a {@code ScalarNotEquals} op of {@code iX} against the boxed scalar {@code i_y}; returns its output. */
    public SDVariable neq(SDVariable iX, double i_y) {
        Number scalar = i_y;
        ScalarNotEquals op = new ScalarNotEquals(sameDiff(), iX, scalar);
        return op.outputVariable();
    }

    /** As {@link #neq(SDVariable, double)} but passes {@code true} as the extra constructor flag. */
    public SDVariable neqi(SDVariable iX, double i_y) {
        Number scalar = i_y;
        ScalarNotEquals op = new ScalarNotEquals(sameDiff(), iX, scalar, true);
        return op.outputVariable();
    }

    /** Builds a {@code NotEqualTo} op over {iX, i_y} with the boolean flag true; returns its output. */
    public SDVariable neqi(SDVariable iX, SDVariable i_y) {
        SDVariable[] operands = {iX, i_y};
        NotEqualTo op = new NotEqualTo(sameDiff(), operands, true);
        return op.outputVariable();
    }

    /** Builds a {@code NotEqualTo} op over {iX, i_y} with the boolean flag false; returns its output. */
    public SDVariable neq(SDVariable iX, SDVariable i_y) {
        SDVariable[] operands = {iX, i_y};
        NotEqualTo op = new NotEqualTo(sameDiff(), operands, false);
        return op.outputVariable();
    }

    /** Builds the scalar-exponent {@code Pow} op over {@code iX}; returns its output. */
    public SDVariable pow(SDVariable iX, double i_y) {
        Pow op = new Pow(sameDiff(), iX, false, i_y);
        return op.outputVariable();
    }

    /** Builds the {@code transforms.custom.Pow} op from {@code x} and {@code y}; returns its output. */
    public SDVariable pow(SDVariable x, SDVariable y) {
        return new org.nd4j.linalg.api.ops.impl.transforms.custom.Pow(sameDiff(), x, y)
                .outputVariable();
    }

    /** Builds a {@code Sqrt} op over {@code iX}; returns its output. */
    public SDVariable sqrt(SDVariable iX) {
        Sqrt op = new Sqrt(sameDiff(), iX, false);
        return op.outputVariable();
    }

    /** Builds a {@code Square} op over {@code iX}; returns its output. */
    public SDVariable square(SDVariable iX) {
        Square op = new Square(sameDiff(), iX, false);
        return op.outputVariable();
    }

    /** Builds a {@code Cube} op over {@code iX}; returns its output. */
    public SDVariable cube(SDVariable iX) {
        Cube op = new Cube(sameDiff(), iX, false);
        return op.outputVariable();
    }

    /** Builds a {@code CubeDerivative} op over {@code iX}; returns its output. */
    public SDVariable cubeDerivative(SDVariable iX) {
        CubeDerivative op = new CubeDerivative(sameDiff(), iX, false);
        return op.outputVariable();
    }

    /** Builds a {@code Floor} op over {@code iX}; returns its output. */
    public SDVariable floor(SDVariable iX) {
        Floor op = new Floor(sameDiff(), iX, false);
        return op.outputVariable();
    }

    /** Builds a {@code FloorDivOp} from {@code x} and {@code y}; returns its output. */
    public SDVariable floorDiv(SDVariable x, SDVariable y) {
        FloorDivOp op = new FloorDivOp(sameDiff(), x, y);
        return op.outputVariable();
    }

    /** Builds a {@code FloorDivBpOp}; returns its outputs as a fixed-size list view. */
    public List<SDVariable> floorDivBp(SDVariable x, SDVariable y, SDVariable grad) {
        FloorDivBpOp op = new FloorDivBpOp(sameDiff(), x, y, grad);
        return Arrays.asList(op.outputVariables());
    }

    /** Builds a {@code FloorModOp} from {@code x} and {@code y}; returns its output. */
    public SDVariable floorMod(SDVariable x, SDVariable y) {
        FloorModOp op = new FloorModOp(sameDiff(), x, y);
        return op.outputVariable();
    }

    /** Builds a {@code FloorModBpOp}; returns its outputs as a fixed-size list view. */
    public List<SDVariable> floorModBp(SDVariable x, SDVariable y, SDVariable grad) {
        FloorModBpOp op = new FloorModBpOp(sameDiff(), x, y, grad);
        return Arrays.asList(op.outputVariables());
    }

    /** Builds a {@code Ceil} op over {@code x}; returns its output. */
    public SDVariable ceil(SDVariable x) {
        Ceil op = new Ceil(sameDiff(), x);
        return op.outputVariable();
    }

    /** Builds a {@code ClipByValue} op over {@code x} with the given min/max; returns its output. */
    public SDVariable clipByValue(SDVariable x, double clipValueMin, double clipValueMax) {
        ClipByValue op = new ClipByValue(sameDiff(), x, clipValueMin, clipValueMax);
        return op.outputVariable();
    }

    /** Same as {@link #clipByNorm(SDVariable, double, int...)} with an empty dimensions array. */
    public SDVariable clipByNorm(SDVariable x, double clipValue) {
        return clipByNorm(x, clipValue, new int[0]);
    }

    /** Builds a {@code ClipByNorm} op over {@code x} along {@code dimensions}; returns its output. */
    public SDVariable clipByNorm(SDVariable x, double clipValue, int ... dimensions) {
        ClipByNorm op = new ClipByNorm(sameDiff(), x, clipValue, dimensions);
        return op.outputVariable();
    }

    /** Builds a {@code RectifiedLinear} op over {@code iX} with the given cutoff; returns its output. */
    public SDVariable relu(SDVariable iX, double cutoff) {
        RectifiedLinear op = new RectifiedLinear(sameDiff(), iX, false, cutoff);
        return op.outputVariable();
    }

    /** Builds a {@code Relu6} op over {@code iX} with the given cutoff; returns its output. */
    public SDVariable relu6(SDVariable iX, double cutoff) {
        Relu6 op = new Relu6(sameDiff(), iX, false, cutoff);
        return op.outputVariable();
    }

    /** Builds a {@code Relu6Derivative} op from {@code iX}, {@code wrt} and the cutoff; returns its output. */
    public SDVariable relu6Derivative(SDVariable iX, SDVariable wrt, double cutoff) {
        Relu6Derivative op = new Relu6Derivative(sameDiff(), iX, wrt, cutoff);
        return op.outputVariable();
    }

    /** Builds a {@code SoftMax} op over {@code iX} using the constructor without a dimension; returns its output. */
    public SDVariable softmax(SDVariable iX) {
        SDVariable[] args = {iX};
        SoftMax op = new SoftMax(sameDiff(), args);
        return op.outputVariable();
    }

    /** Builds a {@code SoftMax} op over {@code iX} along {@code dimension}; returns its output. */
    public SDVariable softmax(SDVariable iX, int dimension) {
        SDVariable[] args = {iX};
        SoftMax op = new SoftMax(sameDiff(), args, dimension);
        return op.outputVariable();
    }

    /** Builds a {@code HardTanh} op over {@code iX}; returns its output. */
    public SDVariable hardTanh(SDVariable iX) {
        HardTanh op = new HardTanh(sameDiff(), iX, false);
        return op.outputVariable();
    }

    /** Builds a {@code HardTanhDerivative} op over {@code iX}; returns its output. */
    public SDVariable hardTanhDerivative(SDVariable iX) {
        HardTanhDerivative op = new HardTanhDerivative(sameDiff(), iX, false);
        return op.outputVariable();
    }

    /** Builds a {@code HardSigmoid} op over {@code in}; returns its output. */
    public SDVariable hardSigmoid(SDVariable in) {
        HardSigmoid op = new HardSigmoid(sameDiff(), in, false);
        return op.outputVariable();
    }

    /** Builds a {@code Sigmoid} op over {@code iX}; returns its output. */
    public SDVariable sigmoid(SDVariable iX) {
        Sigmoid op = new Sigmoid(sameDiff(), iX, false);
        return op.outputVariable();
    }

    /** Builds a {@code SigmoidDerivative} op from {@code iX} and {@code wrt}; returns its output. */
    public SDVariable sigmoidDerivative(SDVariable iX, SDVariable wrt) {
        SigmoidDerivative op = new SigmoidDerivative(sameDiff(), iX, wrt);
        return op.outputVariable();
    }

    /** Builds a {@code LogSigmoid} op over {@code iX}; returns its output. */
    public SDVariable logSigmoid(SDVariable iX) {
        LogSigmoid op = new LogSigmoid(sameDiff(), iX, false);
        return op.outputVariable();
    }

    /** Builds a {@code PowDerivative} op over {@code iX} with exponent {@code pow}; returns its output. */
    public SDVariable powDerivative(SDVariable iX, double pow) {
        PowDerivative op = new PowDerivative(sameDiff(), iX, false, pow);
        return op.outputVariable();
    }

    /** Builds a {@code Swish} op over {@code iX}; returns its output. */
    public SDVariable swish(SDVariable iX) {
        Swish op = new Swish(sameDiff(), iX, false);
        return op.outputVariable();
    }

    /** Builds a {@code SwishDerivative} op over {@code iX}; returns its output. */
    public SDVariable swishDerivative(SDVariable iX) {
        SwishDerivative op = new SwishDerivative(sameDiff(), iX, false);
        return op.outputVariable();
    }

    /** Builds {@code PreciseGELU} when {@code precise} is true, otherwise {@code GELU}; returns its output. */
    public SDVariable gelu(SDVariable iX, boolean precise) {
        return precise
                ? new PreciseGELU(sameDiff(), iX, false, precise).outputVariable()
                : new GELU(sameDiff(), iX, false, precise).outputVariable();
    }

    /** Builds {@code PreciseGELUDerivative} when {@code precise} is true, otherwise {@code GELUDerivative}. */
    public SDVariable geluDerivative(SDVariable iX, boolean precise) {
        return precise
                ? new PreciseGELUDerivative(sameDiff(), iX, false, precise).outputVariable()
                : new GELUDerivative(sameDiff(), iX, false).outputVariable();
    }

    /** Builds a {@code Sign} op over {@code iX}; returns its output. */
    public SDVariable sign(SDVariable iX) {
        Sign op = new Sign(sameDiff(), iX, false);
        return op.outputVariable();
    }

    /** Builds an {@code ExpandDims} op over {@code iX} at {@code axis}; returns its output. */
    public SDVariable expandDims(SDVariable iX, int axis) {
        SDVariable[] args = {iX};
        ExpandDims op = new ExpandDims(sameDiff(), args, axis);
        return op.outputVariable();
    }

    /** Builds a {@code Squeeze} op over {@code iX} with the given axes; returns its output. */
    public SDVariable squeeze(SDVariable iX, int ... axis) {
        Squeeze op = new Squeeze(sameDiff(), iX, axis);
        return op.outputVariable();
    }

    /** Builds a {@code ConfusionMatrix} op with an output data type; returns its output. */
    public SDVariable confusionMatrix(SDVariable labels, SDVariable pred, DataType dataType) {
        ConfusionMatrix op = new ConfusionMatrix(sameDiff(), labels, pred, dataType);
        return op.outputVariable();
    }

    /** Builds a {@code ConfusionMatrix} op with a class count; returns its output. */
    public SDVariable confusionMatrix(SDVariable labels, SDVariable pred, Integer numClasses) {
        ConfusionMatrix op = new ConfusionMatrix(sameDiff(), labels, pred, numClasses);
        return op.outputVariable();
    }

    /** Builds a {@code ConfusionMatrix} op with weights; returns its output. */
    public SDVariable confusionMatrix(SDVariable labels, SDVariable pred, SDVariable weights) {
        ConfusionMatrix op = new ConfusionMatrix(sameDiff(), labels, pred, weights);
        return op.outputVariable();
    }

    /** Builds a {@code ConfusionMatrix} op with a class count and weights; returns its output. */
    public SDVariable confusionMatrix(SDVariable labels, SDVariable pred, Integer numClasses, SDVariable weights) {
        ConfusionMatrix op = new ConfusionMatrix(sameDiff(), labels, pred, numClasses, weights);
        return op.outputVariable();
    }

    /** Builds a {@code MatrixDeterminant} op over {@code in}; returns its output. */
    public SDVariable matrixDeterminant(SDVariable in) {
        MatrixDeterminant op = new MatrixDeterminant(sameDiff(), in, false);
        return op.outputVariable();
    }

    /** Builds a {@code MatrixInverse} op over {@code in}; returns its output. */
    public SDVariable matrixInverse(SDVariable in) {
        MatrixInverse op = new MatrixInverse(sameDiff(), in, false);
        return op.outputVariable();
    }

    /** Widens the int shape to long via {@code ArrayUtil.toLongArray} and delegates to the long overload. */
    public SDVariable broadcast(SDVariable iX, int ... shape) {
        long[] longShape = ArrayUtil.toLongArray(shape);
        return broadcast(iX, longShape);
    }

    /** Builds a {@code Broadcast} op over {@code iX} to {@code shape}; returns its output. */
    public SDVariable broadcast(SDVariable iX, long ... shape) {
        Broadcast op = new Broadcast(sameDiff(), iX, shape);
        return op.outputVariable();
    }

    /** Builds a fully-parameterized {@code OneHot} op; returns its output. */
    public SDVariable onehot(SDVariable indices, int depth, int axis, double on, double off, DataType dataType) {
        OneHot op = new OneHot(sameDiff(), indices, depth, axis, on, off, dataType);
        return op.outputVariable();
    }

    /** Builds a {@code OneHot} op using only {@code indices} and {@code depth}; returns its output. */
    public SDVariable onehot(SDVariable indices, int depth) {
        OneHot op = new OneHot(sameDiff(), indices, depth);
        return op.outputVariable();
    }

    /** Builds a {@code Reciprocal} op over {@code a}; returns its output. */
    public SDVariable reciprocal(SDVariable a) {
        Reciprocal op = new Reciprocal(sameDiff(), a, false);
        return op.outputVariable();
    }

    /** Builds a {@code Repeat} op over {@code iX} along {@code axis}; returns its output. */
    public SDVariable repeat(SDVariable iX, int axis) {
        SDVariable[] args = {iX};
        Repeat op = new Repeat(sameDiff(), args, axis);
        return op.outputVariable();
    }

    /** Builds a {@code Stack} op over {@code values} along {@code axis}; returns its output. */
    public SDVariable stack(SDVariable[] values, int axis) {
        Stack op = new Stack(sameDiff(), values, axis);
        return op.outputVariable();
    }

    /** Builds a {@code ParallelStack} op over {@code values}; returns its output. */
    public SDVariable parallel_stack(SDVariable[] values) {
        ParallelStack op = new ParallelStack(sameDiff(), values);
        return op.outputVariable();
    }

    /** Builds an {@code Unstack} op along {@code axis}; returns all of its output variables. */
    public SDVariable[] unstack(SDVariable value, int axis) {
        Unstack op = new Unstack(sameDiff(), value, axis);
        return op.outputVariables();
    }

    /** Builds an {@code Unstack} op along {@code axis} with an explicit {@code num}; returns all outputs. */
    public SDVariable[] unstack(SDVariable value, int axis, int num) {
        Unstack op = new Unstack(sameDiff(), value, axis, num);
        return op.outputVariables();
    }

    /** Builds an {@code Assign} op from {@code x} and {@code y}; returns its output. */
    public SDVariable assign(SDVariable x, SDVariable y) {
        Assign op = new Assign(sameDiff(), x, y);
        return op.outputVariable();
    }

    /** Builds a {@code ScalarSet} op of {@code x} with {@code num} (same op as scalarSet); returns its output. */
    public SDVariable assign(SDVariable x, Number num) {
        ScalarSet op = new ScalarSet(sameDiff(), x, num);
        return op.outputVariable();
    }

    /** Builds a {@code SoftSign} op over {@code iX}; returns its output. */
    public SDVariable softsign(SDVariable iX) {
        SoftSign op = new SoftSign(sameDiff(), iX, false);
        return op.outputVariable();
    }

    /** Builds a {@code SoftSignDerivative} op over {@code iX}; returns its output. */
    public SDVariable softsignDerivative(SDVariable iX) {
        SoftSignDerivative op = new SoftSignDerivative(sameDiff(), iX, false);
        return op.outputVariable();
    }

    /** Builds a {@code SoftPlus} op over {@code iX}; returns its output. */
    public SDVariable softplus(SDVariable iX) {
        SoftPlus op = new SoftPlus(sameDiff(), iX, false);
        return op.outputVariable();
    }

    /** Builds an {@code ELU} op over {@code iX}; returns its output. */
    public SDVariable elu(SDVariable iX) {
        ELU op = new ELU(sameDiff(), iX, false);
        return op.outputVariable();
    }

    /** Builds an {@code ELUDerivative} op over {@code iX}; returns its output. */
    public SDVariable eluDerivative(SDVariable iX) {
        ELUDerivative op = new ELUDerivative(sameDiff(), iX, false);
        return op.outputVariable();
    }

    /** Builds a {@code LeakyReLU} op over {@code iX} with {@code alpha}; returns its output. */
    public SDVariable leakyRelu(SDVariable iX, double alpha) {
        LeakyReLU op = new LeakyReLU(sameDiff(), iX, false, alpha);
        return op.outputVariable();
    }

    /**
     * Builds a {@code LeakyReLUDerivative} op over {@code iX}; returns its output.
     * NOTE(review): the parameter is named {@code cutoff} but occupies the same constructor slot as
     * {@code alpha} in {@link #leakyRelu}; confirm the intended meaning.
     */
    public SDVariable leakyReluDerivative(SDVariable iX, double cutoff) {
        LeakyReLUDerivative op = new LeakyReLUDerivative(sameDiff(), iX, false, cutoff);
        return op.outputVariable();
    }

    public SDVariable reshape(SDVariable iX, int[] shape) {
        return new Reshape(this.sameDiff(), iX, ArrayUtil.toLongArray((int[])shape)).outputVariable();
    }

    public SDVariable reshape(SDVariable iX, long[] shape) {
        return new Reshape(this.sameDiff(), iX, shape).outputVariable();
    }

    public SDVariable reshape(SDVariable iX, SDVariable shape) {
        return new Reshape(this.sameDiff(), iX, shape).outputVariable();
    }

    public SDVariable reverse(SDVariable x, int ... dimensions) {
        return new Reverse(this.sameDiff(), x, dimensions).outputVariable();
    }

    public SDVariable reverseSequence(SDVariable x, SDVariable seq_lengths, int seq_dim, int batch_dim) {
        return new ReverseSequence(this.sameDiff(), x, seq_lengths, seq_dim, batch_dim).outputVariable();
    }

    /**
     * Builds a {@code ReverseSequence} op using the op's own default sequence/batch dimensions.
     *
     * @param x           input variable
     * @param seq_lengths sequence-lengths variable
     * @return the op's output variable
     */
    public SDVariable reverseSequence(SDVariable x, SDVariable seq_lengths) {
        ReverseSequence reverseOp = new ReverseSequence(sameDiff(), x, seq_lengths);
        return reverseOp.outputVariable();
    }

    /**
     * Builds a {@code SequenceMask} op whose maximum length is given as a variable.
     *
     * @param lengths  lengths variable
     * @param maxLen   variable holding the maximum length
     * @param dataType output data type passed to the op
     * @return the op's output variable
     */
    public SDVariable sequenceMask(SDVariable lengths, SDVariable maxLen, DataType dataType) {
        SequenceMask maskOp = new SequenceMask(sameDiff(), lengths, maxLen, dataType);
        return maskOp.outputVariable();
    }

    /**
     * Builds a {@code SequenceMask} op with a constant maximum length.
     *
     * @param lengths  lengths variable
     * @param maxLen   maximum length
     * @param dataType output data type passed to the op
     * @return the op's output variable
     */
    public SDVariable sequenceMask(SDVariable lengths, int maxLen, DataType dataType) {
        SequenceMask maskOp = new SequenceMask(sameDiff(), lengths, maxLen, dataType);
        return maskOp.outputVariable();
    }

    /**
     * Builds a {@code SequenceMask} op without an explicit maximum length.
     *
     * @param lengths  lengths variable
     * @param dataType output data type passed to the op
     * @return the op's output variable
     */
    public SDVariable sequenceMask(SDVariable lengths, DataType dataType) {
        SequenceMask maskOp = new SequenceMask(sameDiff(), lengths, dataType);
        return maskOp.outputVariable();
    }

    /**
     * Builds a {@code Concat} op joining {@code inputs} along {@code dimension}.
     *
     * @param dimension concatenation dimension
     * @param inputs    variables to concatenate
     * @return the op's output variable
     */
    public SDVariable concat(int dimension, SDVariable ... inputs) {
        Concat concatOp = new Concat(sameDiff(), dimension, inputs);
        return concatOp.outputVariable();
    }

    /**
     * Builds a {@code Fill} op for the given shape, data type and fill value.
     *
     * @param shape    variable holding the output shape
     * @param dataType output data type
     * @param value    fill value
     * @return the op's output variable
     */
    public SDVariable fill(SDVariable shape, DataType dataType, double value) {
        Fill fillOp = new Fill(sameDiff(), shape, dataType, value);
        return fillOp.outputVariable();
    }

    /**
     * Builds a {@code Dot} op over {@code x} and {@code y} along the given dimensions.
     *
     * @param x          first operand
     * @param y          second operand
     * @param dimensions dimensions passed to the op
     * @return the op's output variable
     */
    public SDVariable dot(SDVariable x, SDVariable y, int ... dimensions) {
        Dot dotOp = new Dot(sameDiff(), x, y, dimensions);
        return dotOp.outputVariable();
    }

    /**
     * Builds the {@code DotBp} backprop op for {@link #dot}.
     *
     * @param in1        first forward input
     * @param in2        second forward input
     * @param grad       incoming gradient
     * @param keepDims   keepDims flag passed to the op
     * @param dimensions dimensions passed to the op
     * @return all output variables of the op
     */
    public SDVariable[] dotBp(SDVariable in1, SDVariable in2, SDVariable grad, boolean keepDims, int ... dimensions) {
        DotBp backprop = new DotBp(sameDiff(), in1, in2, grad, keepDims, dimensions);
        return backprop.outputVariables();
    }

    /**
     * Builds a {@code CosineSimilarity} op between {@code iX} and {@code i_y}.
     *
     * @param iX         first operand
     * @param i_y        second operand
     * @param dimensions dimensions passed to the op
     * @return the op's output variable
     */
    public SDVariable cosineSimilarity(SDVariable iX, SDVariable i_y, int ... dimensions) {
        CosineSimilarity similarity = new CosineSimilarity(sameDiff(), iX, i_y, dimensions);
        return similarity.outputVariable();
    }

    /**
     * Builds a {@code CosineDistance} op between {@code ix} and {@code iy}.
     *
     * @param ix         first operand
     * @param iy         second operand
     * @param dimensions dimensions passed to the op
     * @return the op's output variable
     */
    public SDVariable cosineDistance(SDVariable ix, SDVariable iy, int ... dimensions) {
        CosineDistance distance = new CosineDistance(sameDiff(), ix, iy, dimensions);
        return distance.outputVariable();
    }

    /**
     * Builds a {@code EuclideanDistance} op between {@code iX} and {@code i_y}.
     *
     * @param iX         first operand
     * @param i_y        second operand
     * @param dimensions dimensions passed to the op
     * @return the op's output variable
     */
    public SDVariable euclideanDistance(SDVariable iX, SDVariable i_y, int ... dimensions) {
        EuclideanDistance distance = new EuclideanDistance(sameDiff(), iX, i_y, dimensions);
        return distance.outputVariable();
    }

    /**
     * Builds a {@code ManhattanDistance} op between {@code iX} and {@code i_y}.
     *
     * @param iX         first operand
     * @param i_y        second operand
     * @param dimensions dimensions passed to the op
     * @return the op's output variable
     */
    public SDVariable manhattanDistance(SDVariable iX, SDVariable i_y, int ... dimensions) {
        ManhattanDistance distance = new ManhattanDistance(sameDiff(), iX, i_y, dimensions);
        return distance.outputVariable();
    }

    /**
     * Builds a {@code HammingDistance} op between {@code ix} and {@code iy}.
     *
     * @param ix         first operand
     * @param iy         second operand
     * @param dimensions dimensions passed to the op
     * @return the op's output variable
     */
    public SDVariable hammingDistance(SDVariable ix, SDVariable iy, int ... dimensions) {
        HammingDistance distance = new HammingDistance(sameDiff(), ix, iy, dimensions);
        return distance.outputVariable();
    }

    /**
     * Builds a {@code JaccardDistance} op between {@code ix} and {@code iy}.
     *
     * @param ix         first operand
     * @param iy         second operand
     * @param dimensions dimensions passed to the op
     * @return the op's output variable
     */
    public SDVariable jaccardDistance(SDVariable ix, SDVariable iy, int ... dimensions) {
        JaccardDistance distance = new JaccardDistance(sameDiff(), ix, iy, dimensions);
        return distance.outputVariable();
    }

    /**
     * Builds a {@code WeightedCrossEntropyLoss} op.
     *
     * @param targets target variable
     * @param inputs  input (logits) variable
     * @param weights weights variable
     * @return the op's output variable
     */
    public SDVariable weightedCrossEntropyWithLogits(SDVariable targets, SDVariable inputs, SDVariable weights) {
        WeightedCrossEntropyLoss lossOp = new WeightedCrossEntropyLoss(sameDiff(), targets, inputs, weights);
        return lossOp.outputVariable();
    }

    /**
     * Builds an {@code L2Loss} op over {@code var}.
     *
     * @param var input variable
     * @return the op's output variable
     */
    public SDVariable lossL2(SDVariable var) {
        L2Loss lossOp = new L2Loss(sameDiff(), var);
        return lossOp.outputVariable();
    }

    /**
     * Builds an {@code AbsoluteDifferenceLoss} op. The op constructor takes
     * (lossReduce, predictions, weights, label), which is a different order from this method's parameters.
     *
     * @param label       label variable
     * @param predictions predictions variable
     * @param weights     weights variable
     * @param lossReduce  reduction mode
     * @return the op's output variable
     */
    public SDVariable lossAbsoluteDifference(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce) {
        AbsoluteDifferenceLoss lossOp = new AbsoluteDifferenceLoss(sameDiff(), lossReduce, predictions, weights, label);
        return lossOp.outputVariable();
    }

    /**
     * Builds the {@code AbsoluteDifferenceLossBp} backprop op. Arguments are reordered for the constructor.
     *
     * @param label       label variable
     * @param predictions predictions variable
     * @param weights     weights variable
     * @param lossReduce  reduction mode
     * @return all output variables of the op
     */
    public SDVariable[] lossAbsoluteDifferenceBP(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce) {
        AbsoluteDifferenceLossBp backprop = new AbsoluteDifferenceLossBp(sameDiff(), lossReduce, predictions, weights, label);
        return backprop.outputVariables();
    }

    /**
     * Builds a {@code CosineDistanceLoss} op. Arguments are reordered for the constructor.
     *
     * @param label       label variable
     * @param predictions predictions variable
     * @param weights     weights variable
     * @param lossReduce  reduction mode
     * @param dimension   dimension passed to the op
     * @return the op's output variable
     */
    public SDVariable lossCosineDistance(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce, int dimension) {
        CosineDistanceLoss lossOp = new CosineDistanceLoss(sameDiff(), lossReduce, predictions, weights, label, dimension);
        return lossOp.outputVariable();
    }

    /**
     * Builds the {@code CosineDistanceLossBp} backprop op. Arguments are reordered for the constructor.
     *
     * @param label       label variable
     * @param predictions predictions variable
     * @param weights     weights variable
     * @param lossReduce  reduction mode
     * @param dimension   dimension passed to the op
     * @return all output variables of the op
     */
    public SDVariable[] lossCosineDistanceBp(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce, int dimension) {
        CosineDistanceLossBp backprop = new CosineDistanceLossBp(sameDiff(), lossReduce, predictions, weights, label, dimension);
        return backprop.outputVariables();
    }

    /**
     * Builds a {@code HingeLoss} op. Arguments are reordered for the constructor.
     *
     * @param label       label variable
     * @param predictions predictions variable
     * @param weights     weights variable
     * @param lossReduce  reduction mode
     * @return the op's output variable
     */
    public SDVariable lossHinge(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce) {
        HingeLoss lossOp = new HingeLoss(sameDiff(), lossReduce, predictions, weights, label);
        return lossOp.outputVariable();
    }

    /**
     * Builds the {@code HingeLossBp} backprop op. Arguments are reordered for the constructor.
     *
     * @param label       label variable
     * @param predictions predictions variable
     * @param weights     weights variable
     * @param lossReduce  reduction mode
     * @return all output variables of the op
     */
    public SDVariable[] lossHingeBp(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce) {
        HingeLossBp backprop = new HingeLossBp(sameDiff(), lossReduce, predictions, weights, label);
        return backprop.outputVariables();
    }

    /**
     * Builds a {@code HuberLoss} op. Arguments are reordered for the constructor.
     *
     * @param label       label variable
     * @param predictions predictions variable
     * @param weights     weights variable
     * @param lossReduce  reduction mode
     * @param delta       delta passed to the op
     * @return the op's output variable
     */
    public SDVariable lossHuber(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce, double delta) {
        HuberLoss lossOp = new HuberLoss(sameDiff(), lossReduce, predictions, weights, label, delta);
        return lossOp.outputVariable();
    }

    /**
     * Builds the {@code HuberLossBp} backprop op. Arguments are reordered for the constructor.
     *
     * @param label       label variable
     * @param predictions predictions variable
     * @param weights     weights variable
     * @param lossReduce  reduction mode
     * @param delta       delta passed to the op
     * @return all output variables of the op
     */
    public SDVariable[] lossHuberBp(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce, double delta) {
        HuberLossBp backprop = new HuberLossBp(sameDiff(), lossReduce, predictions, weights, label, delta);
        return backprop.outputVariables();
    }

    /**
     * Builds a {@code LogLoss} op. Arguments are reordered for the constructor.
     *
     * @param label       label variable
     * @param predictions predictions variable
     * @param weights     weights variable
     * @param lossReduce  reduction mode
     * @param epsilon     epsilon passed to the op
     * @return the op's output variable
     */
    public SDVariable lossLog(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce, double epsilon) {
        LogLoss lossOp = new LogLoss(sameDiff(), lossReduce, predictions, weights, label, epsilon);
        return lossOp.outputVariable();
    }

    /**
     * Builds the {@code LogLossBp} backprop op. Arguments are reordered for the constructor.
     *
     * @param label       label variable
     * @param predictions predictions variable
     * @param weights     weights variable
     * @param lossReduce  reduction mode
     * @param epsilon     epsilon passed to the op
     * @return all output variables of the op
     */
    public SDVariable[] lossLogBp(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce, double epsilon) {
        LogLossBp backprop = new LogLossBp(sameDiff(), lossReduce, predictions, weights, label, epsilon);
        return backprop.outputVariables();
    }

    /**
     * Builds a {@code LogPoissonLoss} op (non-"full" variant). Arguments are reordered for the constructor.
     *
     * @param label       label variable
     * @param predictions predictions variable
     * @param weights     weights variable
     * @param lossReduce  reduction mode
     * @return the op's output variable
     */
    public SDVariable lossLogPoisson(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce) {
        LogPoissonLoss lossOp = new LogPoissonLoss(sameDiff(), lossReduce, predictions, weights, label);
        return lossOp.outputVariable();
    }

    /**
     * Builds the {@code LogPoissonLossBp} backprop op (non-"full" variant).
     *
     * @param label       label variable
     * @param predictions predictions variable
     * @param weights     weights variable
     * @param lossReduce  reduction mode
     * @return all output variables of the op
     */
    public SDVariable[] lossLogPoissonBp(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce) {
        LogPoissonLossBp backprop = new LogPoissonLossBp(sameDiff(), lossReduce, predictions, weights, label);
        return backprop.outputVariables();
    }

    /**
     * Builds a {@code LogPoissonLoss} op with its trailing boolean set to {@code true}.
     * That is the only difference from {@code lossLogPoisson}, so it is presumably the "full" flag.
     *
     * @param label       label variable
     * @param predictions predictions variable
     * @param weights     weights variable
     * @param lossReduce  reduction mode
     * @return the op's output variable
     */
    public SDVariable lossLogPoissonFull(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce) {
        LogPoissonLoss lossOp = new LogPoissonLoss(sameDiff(), lossReduce, predictions, weights, label, true);
        return lossOp.outputVariable();
    }

    /**
     * Builds the {@code LogPoissonLossBp} backprop op with its trailing boolean set to {@code true}
     * (presumably the "full" flag).
     *
     * @param label       label variable
     * @param predictions predictions variable
     * @param weights     weights variable
     * @param lossReduce  reduction mode
     * @return all output variables of the op
     */
    public SDVariable[] lossLogPoissonFullBp(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce) {
        LogPoissonLossBp backprop = new LogPoissonLossBp(sameDiff(), lossReduce, predictions, weights, label, true);
        return backprop.outputVariables();
    }

    /**
     * Builds a {@code MeanPairwiseSquaredErrorLoss} op. Arguments are reordered for the constructor.
     *
     * @param label       label variable
     * @param predictions predictions variable
     * @param weights     weights variable
     * @param lossReduce  reduction mode
     * @return the op's output variable
     */
    public SDVariable lossMeanPairwiseSquaredError(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce) {
        MeanPairwiseSquaredErrorLoss lossOp = new MeanPairwiseSquaredErrorLoss(sameDiff(), lossReduce, predictions, weights, label);
        return lossOp.outputVariable();
    }

    /**
     * Builds the {@code MeanPairwiseSquaredErrorLossBp} backprop op. Arguments are reordered for the constructor.
     *
     * @param label       label variable
     * @param predictions predictions variable
     * @param weights     weights variable
     * @param lossReduce  reduction mode
     * @return all output variables of the op
     */
    public SDVariable[] lossMeanPairwiseSquaredErrorBp(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce) {
        MeanPairwiseSquaredErrorLossBp backprop = new MeanPairwiseSquaredErrorLossBp(sameDiff(), lossReduce, predictions, weights, label);
        return backprop.outputVariables();
    }

    /**
     * Builds a {@code MeanSquaredErrorLoss} op. Arguments are reordered for the constructor.
     *
     * @param label       label variable
     * @param predictions predictions variable
     * @param weights     weights variable
     * @param lossReduce  reduction mode
     * @return the op's output variable
     */
    public SDVariable lossMeanSquaredError(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce) {
        MeanSquaredErrorLoss lossOp = new MeanSquaredErrorLoss(sameDiff(), lossReduce, predictions, weights, label);
        return lossOp.outputVariable();
    }

    /**
     * Builds the {@code MeanSquaredErrorLossBp} backprop op. Arguments are reordered for the constructor.
     *
     * @param label       label variable
     * @param predictions predictions variable
     * @param weights     weights variable
     * @param lossReduce  reduction mode
     * @return all output variables of the op
     */
    public SDVariable[] lossMeanSquaredErrorBp(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce) {
        MeanSquaredErrorLossBp backprop = new MeanSquaredErrorLossBp(sameDiff(), lossReduce, predictions, weights, label);
        return backprop.outputVariables();
    }

    /**
     * Builds a {@code SigmoidCrossEntropyLoss} op. The constructor takes (lossReduce, logits, weights, labels, labelSmoothing).
     *
     * @param labels         labels variable
     * @param logits         logits variable
     * @param weights        weights variable
     * @param lossReduce     reduction mode
     * @param labelSmoothing label smoothing value passed to the op
     * @return the op's output variable
     */
    public SDVariable lossSigmoidCrossEntropy(SDVariable labels, SDVariable logits, SDVariable weights, LossReduce lossReduce, double labelSmoothing) {
        SigmoidCrossEntropyLoss lossOp = new SigmoidCrossEntropyLoss(sameDiff(), lossReduce, logits, weights, labels, labelSmoothing);
        return lossOp.outputVariable();
    }

    /**
     * Builds the {@code SigmoidCrossEntropyLossBp} backprop op. Arguments are reordered for the constructor.
     *
     * @param labels         labels variable
     * @param logits         logits variable
     * @param weights        weights variable
     * @param lossReduce     reduction mode
     * @param labelSmoothing label smoothing value passed to the op
     * @return all output variables of the op
     */
    public SDVariable[] lossSigmoidCrossEntropyBp(SDVariable labels, SDVariable logits, SDVariable weights, LossReduce lossReduce, double labelSmoothing) {
        SigmoidCrossEntropyLossBp backprop = new SigmoidCrossEntropyLossBp(sameDiff(), lossReduce, logits, weights, labels, labelSmoothing);
        return backprop.outputVariables();
    }

    /**
     * Builds a {@code SoftmaxCrossEntropyLoss} op. Arguments are reordered for the constructor.
     *
     * @param labels         labels variable
     * @param logits         logits variable
     * @param weights        weights variable
     * @param lossReduce     reduction mode
     * @param labelSmoothing label smoothing value passed to the op
     * @return the op's output variable
     */
    public SDVariable lossSoftmaxCrossEntropy(SDVariable labels, SDVariable logits, SDVariable weights, LossReduce lossReduce, double labelSmoothing) {
        SoftmaxCrossEntropyLoss lossOp = new SoftmaxCrossEntropyLoss(sameDiff(), lossReduce, logits, weights, labels, labelSmoothing);
        return lossOp.outputVariable();
    }

    /**
     * Builds the {@code SoftmaxCrossEntropyLossBp} backprop op. Arguments are reordered for the constructor.
     *
     * @param labels         labels variable
     * @param logits         logits variable
     * @param weights        weights variable
     * @param lossReduce     reduction mode
     * @param labelSmoothing label smoothing value passed to the op
     * @return all output variables of the op
     */
    public SDVariable[] lossSoftmaxCrossEntropyBp(SDVariable labels, SDVariable logits, SDVariable weights, LossReduce lossReduce, double labelSmoothing) {
        SoftmaxCrossEntropyLossBp backprop = new SoftmaxCrossEntropyLossBp(sameDiff(), lossReduce, logits, weights, labels, labelSmoothing);
        return backprop.outputVariables();
    }

    /**
     * Builds a {@code SoftmaxCrossEntropyWithLogitsLoss} op. The constructor takes (logits, weights, labels, classDim).
     *
     * @param labels   labels variable
     * @param logits   logits variable
     * @param weights  weights variable
     * @param classDim class dimension passed to the op
     * @return the op's output variable
     */
    public SDVariable lossSoftmaxCrossEntropyWithLogits(SDVariable labels, SDVariable logits, SDVariable weights, int classDim) {
        SoftmaxCrossEntropyWithLogitsLoss lossOp = new SoftmaxCrossEntropyWithLogitsLoss(sameDiff(), logits, weights, labels, classDim);
        return lossOp.outputVariable();
    }

    /**
     * Builds the {@code SoftmaxCrossEntropyWithLogitsLossBp} backprop op. Arguments are reordered for the constructor.
     *
     * @param labels   labels variable
     * @param logits   logits variable
     * @param weights  weights variable
     * @param classDim class dimension passed to the op
     * @return all output variables of the op
     */
    public SDVariable[] lossSoftmaxCrossEntropyWithLogitsBp(SDVariable labels, SDVariable logits, SDVariable weights, int classDim) {
        SoftmaxCrossEntropyWithLogitsLossBp backprop = new SoftmaxCrossEntropyWithLogitsLossBp(sameDiff(), logits, weights, labels, classDim);
        return backprop.outputVariables();
    }

    /**
     * Builds a {@code SparseSoftmaxCrossEntropyLossWithLogits} op.
     * Unlike the other loss helpers here, this one takes logits BEFORE labels.
     *
     * @param logits logits variable
     * @param labels labels variable
     * @return the op's output variable
     */
    public SDVariable lossSparseSoftmaxCrossEntropy(SDVariable logits, SDVariable labels) {
        SparseSoftmaxCrossEntropyLossWithLogits lossOp = new SparseSoftmaxCrossEntropyLossWithLogits(sameDiff(), logits, labels);
        return lossOp.outputVariable();
    }

    /**
     * Builds the {@code SparseSoftmaxCrossEntropyLossWithLogitsBp} backprop op (logits first, then labels).
     *
     * @param logits logits variable
     * @param labels labels variable
     * @return all output variables of the op
     */
    public SDVariable[] lossSparseSoftmaxCrossEntropyBp(SDVariable logits, SDVariable labels) {
        SparseSoftmaxCrossEntropyLossWithLogitsBp backprop = new SparseSoftmaxCrossEntropyLossWithLogitsBp(sameDiff(), logits, labels);
        return backprop.outputVariables();
    }

    /**
     * Builds an {@code XwPlusB} op from input, weights and bias.
     *
     * @param input   input variable
     * @param weights weights variable
     * @param bias    bias variable
     * @return the op's output variable
     */
    public SDVariable xwPlusB(SDVariable input, SDVariable weights, SDVariable bias) {
        XwPlusB layer = new XwPlusB(sameDiff(), input, weights, bias);
        return layer.outputVariable();
    }

    /**
     * Builds a {@code ReluLayer} op from input, weights and bias.
     *
     * @param input   input variable
     * @param weights weights variable
     * @param bias    bias variable
     * @return the op's output variable
     */
    public SDVariable reluLayer(SDVariable input, SDVariable weights, SDVariable bias) {
        ReluLayer layer = new ReluLayer(sameDiff(), input, weights, bias);
        return layer.outputVariable();
    }

    /**
     * Builds an {@code Mmul} op after validating both operands against this factory's SameDiff.
     *
     * @param x             left operand
     * @param y             right operand
     * @param mMulTranspose transpose configuration passed to the op
     * @return the op's output variable
     */
    public SDVariable mmul(SDVariable x, SDVariable y, MMulTranspose mMulTranspose) {
        validateDifferentialFunctionsameDiff(x);
        validateDifferentialFunctionsameDiff(y);
        Mmul product = new Mmul(sameDiff(), x, y, mMulTranspose);
        return product.outputVariable();
    }

    /**
     * Matrix multiply with no transposition; delegates to the three-argument overload.
     *
     * @param x left operand
     * @param y right operand
     * @return the op's output variable
     */
    public SDVariable mmul(SDVariable x, SDVariable y) {
        MMulTranspose noTranspose = MMulTranspose.allFalse();
        return mmul(x, y, noTranspose);
    }

    /**
     * Builds the {@code MmulBp} backprop op.
     *
     * @param x   left forward operand
     * @param y   right forward operand
     * @param eps incoming gradient
     * @param mt  transpose configuration passed to the op
     * @return the op's output variables as a fixed-size list
     */
    public List<SDVariable> mmulBp(SDVariable x, SDVariable y, SDVariable eps, MMulTranspose mt) {
        MmulBp backprop = new MmulBp(sameDiff(), x, y, eps, mt);
        SDVariable[] grads = backprop.outputVariables();
        return Arrays.asList(grads);
    }

    /**
     * Batched matrix multiply with no transposition; delegates to the four-argument overload.
     *
     * @param matricesA left operands
     * @param matricesB right operands
     * @return the op's output variables
     */
    public SDVariable[] batchMmul(SDVariable[] matricesA, SDVariable[] matricesB) {
        final boolean transposeNone = false;
        return batchMmul(matricesA, matricesB, transposeNone, transposeNone);
    }

    /**
     * Batched matrix multiply of the pairs {@code (matricesA[i], matricesB[i])}.
     * <p>
     * The two arrays are concatenated (all A matrices, then all B matrices). Only that combined
     * array is handed to {@link #batchMmul(SDVariable[], boolean, boolean)}, so the op cannot
     * recover the A/B boundary unless both halves are the same size. Mismatched or null arrays
     * would silently mis-pair operands, so they are rejected up front.
     *
     * @param matricesA  left operands
     * @param matricesB  right operands; must have the same length as {@code matricesA}
     * @param transposeA transpose flag applied to the A matrices
     * @param transposeB transpose flag applied to the B matrices
     * @return the op's output variables
     * @throws IllegalArgumentException if either array is null or their lengths differ
     */
    public SDVariable[] batchMmul(SDVariable[] matricesA, SDVariable[] matricesB, boolean transposeA, boolean transposeB) {
        if (matricesA == null || matricesB == null) {
            throw new IllegalArgumentException("batchMmul: matricesA and matricesB must not be null");
        }
        if (matricesA.length != matricesB.length) {
            throw new IllegalArgumentException("batchMmul: matricesA and matricesB must have the same length, got "
                    + matricesA.length + " and " + matricesB.length);
        }
        // Generic ArrayUtils.addAll keeps the SDVariable[] runtime type; no Object[] casts needed.
        SDVariable[] combined = ArrayUtils.addAll(matricesA, matricesB);
        return this.batchMmul(combined, transposeA, transposeB);
    }

    /**
     * Builds a {@code BatchMmul} op over an already-combined array of matrices.
     *
     * @param matrices   all operand matrices, as assembled by the two-array overload
     * @param transposeA transpose flag for the A matrices
     * @param transposeB transpose flag for the B matrices
     * @return the op's output variables
     */
    public SDVariable[] batchMmul(SDVariable[] matrices, boolean transposeA, boolean transposeB) {
        BatchMmul batched = new BatchMmul(sameDiff(), matrices, transposeA, transposeB);
        return batched.outputVariables();
    }

    /**
     * Builds a {@code TensorMmul} op after validating both operands against this factory's SameDiff.
     *
     * @param x          left operand
     * @param y          right operand
     * @param dimensions contraction dimensions passed to the op
     * @return the op's output variable
     */
    public SDVariable tensorMmul(SDVariable x, SDVariable y, int[][] dimensions) {
        validateDifferentialFunctionsameDiff(x);
        validateDifferentialFunctionsameDiff(y);
        TensorMmul product = new TensorMmul(sameDiff(), x, y, dimensions);
        return product.outputVariable();
    }

    /**
     * Builds a {@code DotProductAttention} op with {@code withWeights = false} and returns only its first output.
     *
     * @param queries queries variable
     * @param keys    keys variable
     * @param values  values variable
     * @param mask    mask variable
     * @param scaled  scaling flag passed to the op
     * @return the op's output variable
     */
    public SDVariable dotProductAttention(SDVariable queries, SDVariable keys, SDVariable values, SDVariable mask, boolean scaled) {
        final boolean withWeights = false;
        DotProductAttention attention = new DotProductAttention(sameDiff(), queries, keys, values, mask, scaled, withWeights);
        return attention.outputVariable();
    }

    /**
     * Builds a {@code DotProductAttention} op and returns all of its outputs.
     *
     * @param queries     queries variable
     * @param keys        keys variable
     * @param values      values variable
     * @param mask        mask variable
     * @param scaled      scaling flag passed to the op
     * @param withWeights withWeights flag passed to the op
     * @return the op's output variables as a fixed-size list
     */
    public List<SDVariable> dotProductAttention(SDVariable queries, SDVariable keys, SDVariable values, SDVariable mask, boolean scaled, boolean withWeights) {
        DotProductAttention attention = new DotProductAttention(sameDiff(), queries, keys, values, mask, scaled, withWeights);
        SDVariable[] outputs = attention.outputVariables();
        return Arrays.asList(outputs);
    }

    /**
     * Builds the {@code DotProductAttentionBp} backprop op.
     *
     * @param queries  queries variable
     * @param keys     keys variable
     * @param values   values variable
     * @param gradient incoming gradient
     * @param mask     mask variable
     * @param scaled   scaling flag passed to the op
     * @return the op's output variables as a fixed-size list
     */
    public List<SDVariable> dotProductAttentionBp(SDVariable queries, SDVariable keys, SDVariable values, SDVariable gradient, SDVariable mask, boolean scaled) {
        DotProductAttentionBp backprop = new DotProductAttentionBp(sameDiff(), queries, keys, values, gradient, mask, scaled);
        SDVariable[] grads = backprop.outputVariables();
        return Arrays.asList(grads);
    }

    /**
     * Builds a {@code MultiHeadDotProductAttention} op with {@code withWeights = false} and returns its first output.
     *
     * @param queries queries variable
     * @param keys    keys variable
     * @param values  values variable
     * @param Wq      query projection
     * @param Wk      key projection
     * @param Wv      value projection
     * @param Wo      output projection
     * @param mask    mask variable
     * @param scaled  scaling flag passed to the op
     * @return the op's output variable
     */
    public SDVariable multiHeadDotProductAttention(SDVariable queries, SDVariable keys, SDVariable values, SDVariable Wq, SDVariable Wk, SDVariable Wv, SDVariable Wo, SDVariable mask, boolean scaled) {
        final boolean withWeights = false;
        MultiHeadDotProductAttention attention = new MultiHeadDotProductAttention(sameDiff(), queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled, withWeights);
        return attention.outputVariable();
    }

    /**
     * Builds a {@code MultiHeadDotProductAttention} op and returns all of its outputs.
     *
     * @param queries     queries variable
     * @param keys        keys variable
     * @param values      values variable
     * @param Wq          query projection
     * @param Wk          key projection
     * @param Wv          value projection
     * @param Wo          output projection
     * @param mask        mask variable
     * @param scaled      scaling flag passed to the op
     * @param withWeights withWeights flag passed to the op
     * @return the op's output variables as a fixed-size list
     */
    public List<SDVariable> multiHeadDotProductAttention(SDVariable queries, SDVariable keys, SDVariable values, SDVariable Wq, SDVariable Wk, SDVariable Wv, SDVariable Wo, SDVariable mask, boolean scaled, boolean withWeights) {
        MultiHeadDotProductAttention attention = new MultiHeadDotProductAttention(sameDiff(), queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled, withWeights);
        SDVariable[] outputs = attention.outputVariables();
        return Arrays.asList(outputs);
    }

    /**
     * Builds the {@code MultiHeadDotProductAttentionBp} backprop op.
     *
     * @param queries  queries variable
     * @param keys     keys variable
     * @param values   values variable
     * @param Wq       query projection
     * @param Wk       key projection
     * @param Wv       value projection
     * @param Wo       output projection
     * @param gradient incoming gradient
     * @param mask     mask variable
     * @param scaled   scaling flag passed to the op
     * @return the op's output variables as a fixed-size list
     */
    public List<SDVariable> multiHeadDotProductAttentionBp(SDVariable queries, SDVariable keys, SDVariable values, SDVariable Wq, SDVariable Wk, SDVariable Wv, SDVariable Wo, SDVariable gradient, SDVariable mask, boolean scaled) {
        MultiHeadDotProductAttentionBp backprop = new MultiHeadDotProductAttentionBp(sameDiff(), queries, keys, values, Wq, Wk, Wv, Wo, gradient, mask, scaled);
        SDVariable[] grads = backprop.outputVariables();
        return Arrays.asList(grads);
    }

    /**
     * Builds a {@code SoftmaxBp} op after validating {@code functionInput} (only that argument is validated).
     *
     * @param functionInput forward input
     * @param wrt           second input passed to the op
     * @param dimension     dimension passed to the op (boxed; nullability depends on {@code SoftmaxBp})
     * @return the op's output variable
     */
    public SDVariable softmaxDerivative(SDVariable functionInput, SDVariable wrt, Integer dimension) {
        validateDifferentialFunctionsameDiff(functionInput);
        SoftmaxBp backprop = new SoftmaxBp(sameDiff(), functionInput, wrt, dimension);
        return backprop.outputVariable();
    }

    /**
     * Builds a {@code LogSoftMax} op after validating the input against this factory's SameDiff.
     *
     * @param i_v input variable
     * @return the op's output variable
     */
    public SDVariable logSoftmax(SDVariable i_v) {
        validateDifferentialFunctionsameDiff(i_v);
        LogSoftMax activation = new LogSoftMax(sameDiff(), i_v);
        return activation.outputVariable();
    }

    /**
     * Builds a {@code LogSoftMaxDerivative} op after validating {@code arg}.
     *
     * @param arg forward input
     * @param wrt second input passed to the op
     * @return the op's output variable
     */
    public SDVariable logSoftmaxDerivative(SDVariable arg, SDVariable wrt) {
        validateDifferentialFunctionsameDiff(arg);
        LogSoftMaxDerivative derivative = new LogSoftMaxDerivative(sameDiff(), arg, wrt);
        return derivative.outputVariable();
    }

    /**
     * Builds a {@code LogSumExp} op over {@code arg} along the given dimensions.
     *
     * @param arg       input variable
     * @param dimension dimensions passed to the op
     * @return the op's output variable
     */
    public SDVariable logSumExp(SDVariable arg, int ... dimension) {
        LogSumExp reduction = new LogSumExp(sameDiff(), arg, dimension);
        return reduction.outputVariable();
    }

    /**
     * Builds a {@code SELU} op after validating {@code arg} (flag {@code false}, presumably not in-place).
     *
     * @param arg input variable
     * @return the op's output variable
     */
    public SDVariable selu(SDVariable arg) {
        validateDifferentialFunctionsameDiff(arg);
        SELU activation = new SELU(sameDiff(), arg, false);
        return activation.outputVariable();
    }

    /**
     * Builds a {@code SELUDerivative} op after validating {@code arg} (flag {@code false}, presumably not in-place).
     *
     * @param arg input variable
     * @return the op's output variable
     */
    public SDVariable seluDerivative(SDVariable arg) {
        validateDifferentialFunctionsameDiff(arg);
        SELUDerivative derivative = new SELUDerivative(sameDiff(), arg, false);
        return derivative.outputVariable();
    }

    /**
     * Builds an {@code RSubOp} (reverse subtraction) over two variables; only the first operand is validated.
     *
     * @param differentialFunction first operand
     * @param i_v                  second operand
     * @return the op's output variable
     */
    public SDVariable rsub(SDVariable differentialFunction, SDVariable i_v) {
        validateDifferentialFunctionsameDiff(differentialFunction);
        RSubOp reverseSub = new RSubOp(sameDiff(), differentialFunction, i_v);
        return reverseSub.outputVariable();
    }

    /**
     * Builds the {@code RSubBpOp} backprop op.
     *
     * @param x    first forward operand
     * @param y    second forward operand
     * @param grad incoming gradient
     * @return the op's output variables as a fixed-size list
     */
    public List<SDVariable> rsubBp(SDVariable x, SDVariable y, SDVariable grad) {
        RSubBpOp backprop = new RSubBpOp(sameDiff(), x, y, grad);
        SDVariable[] grads = backprop.outputVariables();
        return Arrays.asList(grads);
    }

    /**
     * Builds an {@code RDivOp} (reverse division) with flag {@code false}. The {@code rdivi} variant passes {@code true}.
     *
     * @param differentialFunction first operand (the only one validated)
     * @param i_v                  second operand
     * @return the op's output variable
     */
    public SDVariable rdiv(SDVariable differentialFunction, SDVariable i_v) {
        validateDifferentialFunctionsameDiff(differentialFunction);
        SDVariable[] operands = {differentialFunction, i_v};
        RDivOp reverseDiv = new RDivOp(sameDiff(), operands, false);
        return reverseDiv.outputVariable();
    }

    /**
     * Builds the {@code RDivBpOp} backprop op.
     *
     * @param x    first forward operand
     * @param y    second forward operand
     * @param grad incoming gradient
     * @return the op's output variables as a fixed-size list
     */
    public List<SDVariable> rdivBp(SDVariable x, SDVariable y, SDVariable grad) {
        RDivBpOp backprop = new RDivBpOp(sameDiff(), x, y, grad);
        SDVariable[] grads = backprop.outputVariables();
        return Arrays.asList(grads);
    }

    /**
     * Builds an {@code RDivOp} with flag {@code true} (presumably the in-place variant of {@code rdiv}).
     *
     * @param differentialFunction first operand (the only one validated)
     * @param i_v                  second operand
     * @return the op's output variable
     */
    public SDVariable rdivi(SDVariable differentialFunction, SDVariable i_v) {
        validateDifferentialFunctionsameDiff(differentialFunction);
        SDVariable[] operands = {differentialFunction, i_v};
        RDivOp reverseDiv = new RDivOp(sameDiff(), operands, true);
        return reverseDiv.outputVariable();
    }

    /**
     * Intended "in-place" reverse subtraction over two variables.
     * <p>
     * NOTE(review): this builds exactly the same {@code RSubOp} as {@code rsub}, with no in-place
     * flag. By contrast, {@code rdivi}/{@code addi}/{@code subi} pass {@code true}. Confirm whether
     * {@code RSubOp} exposes an in-place constructor and whether this is intended.
     *
     * @param differentialFunction first operand (the only one validated)
     * @param i_v                  second operand
     * @return the op's output variable
     */
    public SDVariable rsubi(SDVariable differentialFunction, SDVariable i_v) {
        validateDifferentialFunctionsameDiff(differentialFunction);
        RSubOp reverseSub = new RSubOp(sameDiff(), differentialFunction, i_v);
        return reverseSub.outputVariable();
    }

    /**
     * Builds an {@code AddOp} over two variables with flag {@code false}. {@code addi} passes {@code true}.
     *
     * @param differentialFunction first operand (the only one validated)
     * @param i_v                  second operand
     * @return the op's output variable
     */
    public SDVariable add(SDVariable differentialFunction, SDVariable i_v) {
        validateDifferentialFunctionsameDiff(differentialFunction);
        SDVariable[] operands = {differentialFunction, i_v};
        AddOp sum = new AddOp(sameDiff(), operands, false);
        return sum.outputVariable();
    }

    /**
     * Validates every input (in order) and builds a {@code MergeAddOp} over all of them.
     *
     * @param differentialFunctions variables to merge
     * @return the op's output variable
     */
    public SDVariable mergeAdd(SDVariable ... differentialFunctions) {
        for (int i = 0; i < differentialFunctions.length; i++) {
            validateDifferentialFunctionsameDiff(differentialFunctions[i]);
        }
        MergeAddOp merged = new MergeAddOp(sameDiff(), differentialFunctions, false);
        return merged.outputVariable();
    }

    /**
     * Validates every input (in order) and builds a {@code MergeMax} op over all of them.
     *
     * @param differentialFunctions variables to merge
     * @return the op's output variable
     */
    public SDVariable mergeMax(SDVariable ... differentialFunctions) {
        for (int i = 0; i < differentialFunctions.length; i++) {
            validateDifferentialFunctionsameDiff(differentialFunctions[i]);
        }
        MergeMax merged = new MergeMax(sameDiff(), differentialFunctions);
        return merged.outputVariable();
    }

    /**
     * Validates every input (in order) and builds a {@code MergeAvg} op over all of them.
     *
     * @param differentialFunctions variables to merge
     * @return the op's output variable
     */
    public SDVariable mergeAvg(SDVariable ... differentialFunctions) {
        for (int i = 0; i < differentialFunctions.length; i++) {
            validateDifferentialFunctionsameDiff(differentialFunctions[i]);
        }
        MergeAvg merged = new MergeAvg(sameDiff(), differentialFunctions);
        return merged.outputVariable();
    }

    /**
     * Builds a {@code Diag} op over a single validated input (flag {@code false}).
     *
     * @param sdVariable input variable
     * @return the op's output variable
     */
    public SDVariable diag(SDVariable sdVariable) {
        validateDifferentialFunctionsameDiff(sdVariable);
        SDVariable[] args = {sdVariable};
        Diag diagOp = new Diag(sameDiff(), args, false);
        return diagOp.outputVariable();
    }

    /**
     * Builds a {@code DiagPart} op over a single validated input (flag {@code false}).
     *
     * @param sdVariable input variable
     * @return the op's output variable
     */
    public SDVariable diagPart(SDVariable sdVariable) {
        validateDifferentialFunctionsameDiff(sdVariable);
        SDVariable[] args = {sdVariable};
        DiagPart diagOp = new DiagPart(sameDiff(), args, false);
        return diagOp.outputVariable();
    }

    /**
     * Builds a {@code MatrixSetDiag} op (flag {@code false}); inputs are not validated.
     *
     * @param in   input variable
     * @param diag diagonal variable
     * @return the op's output variable
     */
    public SDVariable setDiag(SDVariable in, SDVariable diag) {
        MatrixSetDiag setDiagOp = new MatrixSetDiag(sameDiff(), in, diag, false);
        return setDiagOp.outputVariable();
    }

    /**
     * Builds a {@code BatchToSpace} op over a single validated input.
     *
     * @param differentialFunction input variable
     * @param blocks               block sizes passed to the op
     * @param crops                crop amounts passed to the op
     * @return the op's output variable
     */
    public SDVariable batchToSpace(SDVariable differentialFunction, int[] blocks, int[][] crops) {
        validateDifferentialFunctionsameDiff(differentialFunction);
        SDVariable[] args = {differentialFunction};
        BatchToSpace transform = new BatchToSpace(sameDiff(), args, blocks, crops, false);
        return transform.outputVariable();
    }

    /**
     * Builds a {@code SpaceToBatch} op over a single validated input.
     *
     * @param differentialFunction input variable
     * @param blocks               block sizes passed to the op
     * @param padding              padding amounts passed to the op
     * @return the op's output variable
     */
    public SDVariable spaceToBatch(SDVariable differentialFunction, int[] blocks, int[][] padding) {
        validateDifferentialFunctionsameDiff(differentialFunction);
        SDVariable[] args = {differentialFunction};
        SpaceToBatch transform = new SpaceToBatch(sameDiff(), args, blocks, padding, false);
        return transform.outputVariable();
    }

    /**
     * Builds a {@code DepthToSpace} op over a single validated input.
     *
     * @param differentialFunction input variable
     * @param blocksSize           block size passed to the op
     * @param dataFormat           data format string passed to the op
     * @return the op's output variable
     */
    public SDVariable depthToSpace(SDVariable differentialFunction, int blocksSize, String dataFormat) {
        validateDifferentialFunctionsameDiff(differentialFunction);
        SDVariable[] args = {differentialFunction};
        DepthToSpace transform = new DepthToSpace(sameDiff(), args, blocksSize, dataFormat);
        return transform.outputVariable();
    }

    /**
     * Builds a {@code SpaceToDepth} op over a single validated input.
     *
     * @param differentialFunction input variable
     * @param blocksSize           block size passed to the op
     * @param dataFormat           data format string passed to the op
     * @return the op's output variable
     */
    public SDVariable spaceToDepth(SDVariable differentialFunction, int blocksSize, String dataFormat) {
        validateDifferentialFunctionsameDiff(differentialFunction);
        SDVariable[] args = {differentialFunction};
        SpaceToDepth transform = new SpaceToDepth(sameDiff(), args, blocksSize, dataFormat);
        return transform.outputVariable();
    }

    /**
     * Builds a {@code DynamicPartition} op after validating the data input.
     *
     * @param differentialFunction data variable
     * @param partitions           partition-index variable
     * @param numPartitions        number of partitions passed to the op
     * @return all output variables of the op
     */
    public SDVariable[] dynamicPartition(SDVariable differentialFunction, SDVariable partitions, int numPartitions) {
        validateDifferentialFunctionsameDiff(differentialFunction);
        DynamicPartition partitionOp = new DynamicPartition(sameDiff(), differentialFunction, partitions, numPartitions);
        return partitionOp.outputVariables();
    }

    /**
     * Builds the {@code DynamicPartitionBp} backprop op.
     *
     * @param input         forward data input
     * @param partitions    partition-index variable
     * @param grads         incoming gradients
     * @param numPartitions number of partitions passed to the op
     * @return all output variables of the op
     */
    public SDVariable[] dynamicPartitionBp(SDVariable input, SDVariable partitions, SDVariable[] grads, int numPartitions) {
        DynamicPartitionBp backprop = new DynamicPartitionBp(sameDiff(), input, partitions, grads, numPartitions);
        return backprop.outputVariables();
    }

    /**
     * Validates every data input (indices are not validated) and builds a {@code DynamicStitch} op.
     *
     * @param indices               index variables
     * @param differentialFunctions data variables
     * @return the op's output variable
     */
    public SDVariable dynamicStitch(SDVariable[] indices, SDVariable[] differentialFunctions) {
        for (int i = 0; i < differentialFunctions.length; i++) {
            validateDifferentialFunctionsameDiff(differentialFunctions[i]);
        }
        DynamicStitch stitched = new DynamicStitch(sameDiff(), indices, differentialFunctions);
        return stitched.outputVariable();
    }

    /**
     * Builds a {@code SegmentMax} op.
     *
     * @param data       data variable
     * @param segmentIds segment-id variable
     * @return the op's output variable
     */
    public SDVariable segmentMax(SDVariable data, SDVariable segmentIds) {
        SegmentMax segmentOp = new SegmentMax(sameDiff(), data, segmentIds);
        return segmentOp.outputVariable();
    }

    /**
     * Builds the {@code SegmentMaxBp} backprop op.
     *
     * @param data       forward data variable
     * @param segmentIds segment-id variable
     * @param gradient   incoming gradient
     * @return all output variables of the op
     */
    public SDVariable[] segmentMaxBp(SDVariable data, SDVariable segmentIds, SDVariable gradient) {
        SegmentMaxBp backprop = new SegmentMaxBp(sameDiff(), data, segmentIds, gradient);
        return backprop.outputVariables();
    }

    /**
     * Builds a {@code SegmentMin} op.
     *
     * @param data       data variable
     * @param segmentIds segment-id variable
     * @return the op's output variable
     */
    public SDVariable segmentMin(SDVariable data, SDVariable segmentIds) {
        SegmentMin segmentOp = new SegmentMin(sameDiff(), data, segmentIds);
        return segmentOp.outputVariable();
    }

    /**
     * Builds the {@code SegmentMinBp} backprop op.
     *
     * @param data       forward data variable
     * @param segmentIds segment-id variable
     * @param gradient   incoming gradient
     * @return all output variables of the op
     */
    public SDVariable[] segmentMinBp(SDVariable data, SDVariable segmentIds, SDVariable gradient) {
        SegmentMinBp backprop = new SegmentMinBp(sameDiff(), data, segmentIds, gradient);
        return backprop.outputVariables();
    }

    /**
     * Builds a {@code SegmentMean} op.
     *
     * @param data       data variable
     * @param segmentIds segment-id variable
     * @return the op's output variable
     */
    public SDVariable segmentMean(SDVariable data, SDVariable segmentIds) {
        SegmentMean segmentOp = new SegmentMean(sameDiff(), data, segmentIds);
        return segmentOp.outputVariable();
    }

    /**
     * Builds the {@code SegmentMeanBp} backprop op.
     *
     * @param data       forward data variable
     * @param segmentIds segment-id variable
     * @param gradient   incoming gradient
     * @return all output variables of the op
     */
    public SDVariable[] segmentMeanBp(SDVariable data, SDVariable segmentIds, SDVariable gradient) {
        SegmentMeanBp backprop = new SegmentMeanBp(sameDiff(), data, segmentIds, gradient);
        return backprop.outputVariables();
    }

    /**
     * Builds a {@code SegmentProd} op.
     *
     * @param data       data variable
     * @param segmentIds segment-id variable
     * @return the op's output variable
     */
    public SDVariable segmentProd(SDVariable data, SDVariable segmentIds) {
        SegmentProd segmentOp = new SegmentProd(sameDiff(), data, segmentIds);
        return segmentOp.outputVariable();
    }

    /**
     * Builds the {@code SegmentProdBp} backprop op.
     *
     * @param data       forward data variable
     * @param segmentIds segment-id variable
     * @param gradient   incoming gradient
     * @return all output variables of the op
     */
    public SDVariable[] segmentProdBp(SDVariable data, SDVariable segmentIds, SDVariable gradient) {
        SegmentProdBp backprop = new SegmentProdBp(sameDiff(), data, segmentIds, gradient);
        return backprop.outputVariables();
    }

    /**
     * Builds a {@code SegmentSum} op.
     *
     * @param data       data variable
     * @param segmentIds segment-id variable
     * @return the op's output variable
     */
    public SDVariable segmentSum(SDVariable data, SDVariable segmentIds) {
        SegmentSum segmentOp = new SegmentSum(sameDiff(), data, segmentIds);
        return segmentOp.outputVariable();
    }

    /**
     * Builds the {@code SegmentSumBp} backprop op.
     *
     * @param data       forward data variable
     * @param segmentIds segment-id variable
     * @param gradient   incoming gradient
     * @return all output variables of the op
     */
    public SDVariable[] segmentSumBp(SDVariable data, SDVariable segmentIds, SDVariable gradient) {
        SegmentSumBp backprop = new SegmentSumBp(sameDiff(), data, segmentIds, gradient);
        return backprop.outputVariables();
    }

    /**
     * Builds an {@code UnsortedSegmentMax} op.
     *
     * @param data        data variable
     * @param segmentIds  segment-id variable
     * @param numSegments number of segments passed to the op
     * @return the op's output variable
     */
    public SDVariable unsortedSegmentMax(SDVariable data, SDVariable segmentIds, int numSegments) {
        UnsortedSegmentMax segmentOp = new UnsortedSegmentMax(sameDiff(), data, segmentIds, numSegments);
        return segmentOp.outputVariable();
    }

    /**
     * Builds the {@code UnsortedSegmentMaxBp} backprop op.
     *
     * @param data        forward data variable
     * @param segmentIds  segment-id variable
     * @param gradient    incoming gradient
     * @param numSegments number of segments passed to the op
     * @return all output variables of the op
     */
    public SDVariable[] unsortedSegmentMaxBp(SDVariable data, SDVariable segmentIds, SDVariable gradient, int numSegments) {
        UnsortedSegmentMaxBp backprop = new UnsortedSegmentMaxBp(sameDiff(), data, segmentIds, gradient, numSegments);
        return backprop.outputVariables();
    }

    /**
     * Builds an {@code UnsortedSegmentMin} op.
     *
     * @param data        data variable
     * @param segmentIds  segment-id variable
     * @param numSegments number of segments passed to the op
     * @return the op's output variable
     */
    public SDVariable unsortedSegmentMin(SDVariable data, SDVariable segmentIds, int numSegments) {
        UnsortedSegmentMin segmentOp = new UnsortedSegmentMin(sameDiff(), data, segmentIds, numSegments);
        return segmentOp.outputVariable();
    }

    /**
     * Builds the {@code UnsortedSegmentMinBp} backprop op.
     *
     * @param data        forward data variable
     * @param segmentIds  segment-id variable
     * @param gradient    incoming gradient
     * @param numSegments number of segments passed to the op
     * @return all output variables of the op
     */
    public SDVariable[] unsortedSegmentMinBp(SDVariable data, SDVariable segmentIds, SDVariable gradient, int numSegments) {
        UnsortedSegmentMinBp backprop = new UnsortedSegmentMinBp(sameDiff(), data, segmentIds, gradient, numSegments);
        return backprop.outputVariables();
    }

    /**
     * Builds an {@code UnsortedSegmentMean} op.
     *
     * @param data        data variable
     * @param segmentIds  segment-id variable
     * @param numSegments number of segments passed to the op
     * @return the op's output variable
     */
    public SDVariable unsortedSegmentMean(SDVariable data, SDVariable segmentIds, int numSegments) {
        UnsortedSegmentMean segmentOp = new UnsortedSegmentMean(sameDiff(), data, segmentIds, numSegments);
        return segmentOp.outputVariable();
    }

    /**
     * Builds the {@code UnsortedSegmentMeanBp} backprop op.
     *
     * @param data        forward data variable
     * @param segmentIds  segment-id variable
     * @param gradient    incoming gradient
     * @param numSegments number of segments passed to the op
     * @return all output variables of the op
     */
    public SDVariable[] unsortedSegmentMeanBp(SDVariable data, SDVariable segmentIds, SDVariable gradient, int numSegments) {
        UnsortedSegmentMeanBp backprop = new UnsortedSegmentMeanBp(sameDiff(), data, segmentIds, gradient, numSegments);
        return backprop.outputVariables();
    }

    /**
     * Builds an {@code UnsortedSegmentProd} op.
     *
     * @param data        data variable
     * @param segmentIds  segment-id variable
     * @param numSegments number of segments passed to the op
     * @return the op's output variable
     */
    public SDVariable unsortedSegmentProd(SDVariable data, SDVariable segmentIds, int numSegments) {
        UnsortedSegmentProd segmentOp = new UnsortedSegmentProd(sameDiff(), data, segmentIds, numSegments);
        return segmentOp.outputVariable();
    }

    /**
     * Builds the {@code UnsortedSegmentProdBp} backprop op.
     *
     * @param data        forward data variable
     * @param segmentIds  segment-id variable
     * @param gradient    incoming gradient
     * @param numSegments number of segments passed to the op
     * @return all output variables of the op
     */
    public SDVariable[] unsortedSegmentProdBp(SDVariable data, SDVariable segmentIds, SDVariable gradient, int numSegments) {
        UnsortedSegmentProdBp backprop = new UnsortedSegmentProdBp(sameDiff(), data, segmentIds, gradient, numSegments);
        return backprop.outputVariables();
    }

    /**
     * Builds an {@code UnsortedSegmentSum} op.
     *
     * @param data        data variable
     * @param segmentIds  segment-id variable
     * @param numSegments number of segments passed to the op
     * @return the op's output variable
     */
    public SDVariable unsortedSegmentSum(SDVariable data, SDVariable segmentIds, int numSegments) {
        UnsortedSegmentSum segmentOp = new UnsortedSegmentSum(sameDiff(), data, segmentIds, numSegments);
        return segmentOp.outputVariable();
    }

    /**
     * Builds the {@code UnsortedSegmentSumBp} backprop op.
     *
     * @param data        forward data variable
     * @param segmentIds  segment-id variable
     * @param gradient    incoming gradient
     * @param numSegments number of segments passed to the op
     * @return all output variables of the op
     */
    public SDVariable[] unsortedSegmentSumBp(SDVariable data, SDVariable segmentIds, SDVariable gradient, int numSegments) {
        UnsortedSegmentSumBp backprop = new UnsortedSegmentSumBp(sameDiff(), data, segmentIds, gradient, numSegments);
        return backprop.outputVariables();
    }

    /**
     * Builds an {@code UnsortedSegmentSqrtN} op.
     *
     * @param data        data variable
     * @param segmentIds  segment-id variable
     * @param numSegments number of segments passed to the op
     * @return the op's output variable
     */
    public SDVariable unsortedSegmentSqrtN(SDVariable data, SDVariable segmentIds, int numSegments) {
        UnsortedSegmentSqrtN segmentOp = new UnsortedSegmentSqrtN(sameDiff(), data, segmentIds, numSegments);
        return segmentOp.outputVariable();
    }

    /**
     * Builds the {@code UnsortedSegmentSqrtNBp} backprop op.
     *
     * @param data        forward data variable
     * @param segmentIds  segment-id variable
     * @param gradient    incoming gradient
     * @param numSegments number of segments passed to the op
     * @return all output variables of the op
     */
    public SDVariable[] unsortedSegmentSqrtNBp(SDVariable data, SDVariable segmentIds, SDVariable gradient, int numSegments) {
        UnsortedSegmentSqrtNBp backprop = new UnsortedSegmentSqrtNBp(sameDiff(), data, segmentIds, gradient, numSegments);
        return backprop.outputVariables();
    }

    /**
     * Builds a {@code Dilation2D} op; only {@code df} is validated.
     *
     * @param df         input variable
     * @param weights    weights variable
     * @param strides    strides passed to the op
     * @param rates      rates passed to the op
     * @param isSameMode same-mode padding flag passed to the op
     * @return the op's output variable
     */
    public SDVariable dilation2D(SDVariable df, SDVariable weights, int[] strides, int[] rates, boolean isSameMode) {
        validateDifferentialFunctionsameDiff(df);
        SDVariable[] args = {df, weights};
        Dilation2D dilation = new Dilation2D(sameDiff(), args, strides, rates, isSameMode, false);
        return dilation.outputVariable();
    }

    /**
     * Builds a {@code Shape} op over a validated input. The class is fully qualified here, presumably to avoid a
     * simple-name clash.
     *
     * @param df input variable
     * @return the op's output variable
     */
    public SDVariable shape(SDVariable df) {
        validateDifferentialFunctionsameDiff(df);
        org.nd4j.linalg.api.ops.impl.shape.Shape shapeOp = new org.nd4j.linalg.api.ops.impl.shape.Shape(sameDiff(), df, false);
        return shapeOp.outputVariable();
    }

    /**
     * Builds a {@code Size} op over {@code in}.
     *
     * @param in input variable
     * @return the op's output variable
     */
    public SDVariable size(SDVariable in) {
        Size sizeOp = new Size(sameDiff(), in);
        return sizeOp.outputVariable();
    }

    /**
     * Builds a {@code SizeAt} op over {@code in} for one dimension.
     *
     * @param in        input variable
     * @param dimension dimension passed to the op
     * @return the op's output variable
     */
    public SDVariable sizeAt(SDVariable in, int dimension) {
        SizeAt sizeOp = new SizeAt(sameDiff(), in, dimension);
        return sizeOp.outputVariable();
    }

    /**
     * Builds a {@code Rank} op over {@code df} (not validated; flag {@code false}).
     *
     * @param df input variable
     * @return the op's output variable
     */
    public SDVariable rank(SDVariable df) {
        Rank rankOp = new Rank(sameDiff(), df, false);
        return rankOp.outputVariable();
    }

    /**
     * Builds a {@code Gather} op with constant indices after validating {@code df}.
     *
     * @param df      input variable
     * @param indices indices passed to the op
     * @param axis    axis passed to the op
     * @return the op's output variable
     */
    public SDVariable gather(SDVariable df, int[] indices, int axis) {
        validateDifferentialFunctionsameDiff(df);
        Gather gatherOp = new Gather(sameDiff(), df, indices, axis, false);
        return gatherOp.outputVariable();
    }

    /**
     * Builds a {@code Gather} op with variable indices after validating {@code df}.
     *
     * @param df      input variable
     * @param indices indices variable
     * @param axis    axis passed to the op
     * @return the op's output variable
     */
    public SDVariable gather(SDVariable df, SDVariable indices, int axis) {
        validateDifferentialFunctionsameDiff(df);
        Gather gatherOp = new Gather(sameDiff(), df, indices, axis, false);
        return gatherOp.outputVariable();
    }

    /**
     * Builds a {@code GatherNd} op after validating {@code df}.
     *
     * @param df      input variable
     * @param indices indices variable
     * @return the op's output variable
     */
    public SDVariable gatherNd(SDVariable df, SDVariable indices) {
        validateDifferentialFunctionsameDiff(df);
        GatherNd gatherOp = new GatherNd(sameDiff(), df, indices, false);
        return gatherOp.outputVariable();
    }

    /**
     * Builds a {@code Trace} op over {@code in}.
     *
     * @param in input variable
     * @return the op's output variable
     */
    public SDVariable trace(SDVariable in) {
        Trace traceOp = new Trace(sameDiff(), in);
        return traceOp.outputVariable();
    }

    /**
     * Builds a {@code Cross} op over {@code a} and {@code b}; only {@code a} is validated.
     *
     * @param a first operand
     * @param b second operand
     * @return the op's output variable
     */
    public SDVariable cross(SDVariable a, SDVariable b) {
        validateDifferentialFunctionsameDiff(a);
        SDVariable[] operands = {a, b};
        Cross crossOp = new Cross(sameDiff(), operands);
        return crossOp.outputVariable();
    }

    /**
     * Builds an {@code Erf} op over a validated input (flag {@code false}).
     *
     * @param differentialFunction input variable
     * @return the op's output variable
     */
    public SDVariable erf(SDVariable differentialFunction) {
        validateDifferentialFunctionsameDiff(differentialFunction);
        Erf erfOp = new Erf(sameDiff(), differentialFunction, false);
        return erfOp.outputVariable();
    }

    /**
     * Builds an {@code Erfc} op over a validated input (flag {@code false}).
     *
     * @param differentialFunction input variable
     * @return the op's output variable
     */
    public SDVariable erfc(SDVariable differentialFunction) {
        validateDifferentialFunctionsameDiff(differentialFunction);
        Erfc erfcOp = new Erfc(sameDiff(), differentialFunction, false);
        return erfcOp.outputVariable();
    }

    /**
     * Builds an {@code AddOp} with flag {@code true} (presumably the in-place variant of {@code add}).
     *
     * @param differentialFunction first operand (the only one validated)
     * @param i_v                  second operand
     * @return the op's output variable
     */
    public SDVariable addi(SDVariable differentialFunction, SDVariable i_v) {
        validateDifferentialFunctionsameDiff(differentialFunction);
        SDVariable[] operands = {differentialFunction, i_v};
        AddOp sum = new AddOp(sameDiff(), operands, true);
        return sum.outputVariable();
    }

    /**
     * Builds the {@code AddBpOp} backprop op.
     *
     * @param x    first forward operand
     * @param y    second forward operand
     * @param grad incoming gradient
     * @return the op's output variables as a fixed-size list
     */
    public List<SDVariable> addBp(SDVariable x, SDVariable y, SDVariable grad) {
        AddBpOp backprop = new AddBpOp(sameDiff(), x, y, grad);
        return Arrays.asList(backprop.outputVariables());
    }

    /**
     * Builds a {@code SubOp} with flag {@code false}. {@code subi} passes {@code true}.
     *
     * @param differentialFunction first operand (the only one validated)
     * @param i_v                  second operand
     * @return the op's output variable
     */
    public SDVariable sub(SDVariable differentialFunction, SDVariable i_v) {
        validateDifferentialFunctionsameDiff(differentialFunction);
        SDVariable[] operands = {differentialFunction, i_v};
        SubOp difference = new SubOp(sameDiff(), operands, false);
        return difference.outputVariable();
    }

    /**
     * Builds a {@code SquaredDifferenceOp} with flag {@code false}.
     *
     * @param differentialFunction first operand (the only one validated)
     * @param i_v                  second operand
     * @return the op's output variable
     */
    public SDVariable squaredDifference(SDVariable differentialFunction, SDVariable i_v) {
        validateDifferentialFunctionsameDiff(differentialFunction);
        SDVariable[] operands = {differentialFunction, i_v};
        SquaredDifferenceOp difference = new SquaredDifferenceOp(sameDiff(), operands, false);
        return difference.outputVariable();
    }

    /**
     * Builds the {@code SubBpOp} backprop op.
     *
     * @param x    first forward operand
     * @param y    second forward operand
     * @param grad incoming gradient
     * @return the op's output variables as a fixed-size list
     */
    public List<SDVariable> subBp(SDVariable x, SDVariable y, SDVariable grad) {
        SubBpOp backprop = new SubBpOp(sameDiff(), x, y, grad);
        SDVariable[] grads = backprop.outputVariables();
        return Arrays.asList(grads);
    }

    /**
     * Builds a {@code SubOp} with flag {@code true} (presumably the in-place variant of {@code sub}).
     *
     * @param differentialFunction first operand (the only one validated)
     * @param i_v                  second operand
     * @return the op's output variable
     */
    public SDVariable subi(SDVariable differentialFunction, SDVariable i_v) {
        validateDifferentialFunctionsameDiff(differentialFunction);
        SDVariable[] operands = {differentialFunction, i_v};
        SubOp difference = new SubOp(sameDiff(), operands, true);
        return difference.outputVariable();
    }

    /**
     * Builds a {@code MulOp} with flag {@code false}. {@code muli} passes {@code true}.
     *
     * @param differentialFunction first operand (the only one validated)
     * @param i_v                  second operand
     * @return the op's output variable
     */
    public SDVariable mul(SDVariable differentialFunction, SDVariable i_v) {
        validateDifferentialFunctionsameDiff(differentialFunction);
        SDVariable[] operands = {differentialFunction, i_v};
        MulOp product = new MulOp(sameDiff(), operands, false);
        return product.outputVariable();
    }

    /**
     * Builds the {@code MulBpOp} backprop op.
     *
     * @param x    first forward operand
     * @param y    second forward operand
     * @param grad incoming gradient
     * @return the op's output variables as a fixed-size list
     */
    public List<SDVariable> mulBp(SDVariable x, SDVariable y, SDVariable grad) {
        MulBpOp backprop = new MulBpOp(sameDiff(), x, y, grad);
        SDVariable[] grads = backprop.outputVariables();
        return Arrays.asList(grads);
    }

    /**
     * Builds a {@code MulOp} with flag {@code true} (presumably the in-place variant of {@code mul}).
     *
     * @param differentialFunction first operand (the only one validated)
     * @param i_v                  second operand
     * @return the op's output variable
     */
    public SDVariable muli(SDVariable differentialFunction, SDVariable i_v) {
        validateDifferentialFunctionsameDiff(differentialFunction);
        SDVariable[] operands = {differentialFunction, i_v};
        MulOp product = new MulOp(sameDiff(), operands, true);
        return product.outputVariable();
    }

    /**
     * Builds a {@code DivOp} with flag {@code false}. {@code divi} passes {@code true}.
     *
     * @param differentialFunction first operand (the only one validated)
     * @param i_v                  second operand
     * @return the op's output variable
     */
    public SDVariable div(SDVariable differentialFunction, SDVariable i_v) {
        validateDifferentialFunctionsameDiff(differentialFunction);
        SDVariable[] operands = {differentialFunction, i_v};
        DivOp quotient = new DivOp(sameDiff(), operands, false);
        return quotient.outputVariable();
    }

    /**
     * Builds a {@code TruncateDivOp} with flag {@code false}.
     *
     * @param differentialFunction first operand (the only one validated)
     * @param i_v                  second operand
     * @return the op's output variable
     */
    public SDVariable truncatedDiv(SDVariable differentialFunction, SDVariable i_v) {
        validateDifferentialFunctionsameDiff(differentialFunction);
        TruncateDivOp quotient = new TruncateDivOp(sameDiff(), differentialFunction, i_v, false);
        return quotient.outputVariable();
    }

    /**
     * Builds the {@code DivBpOp} backprop op.
     *
     * @param x    first forward operand
     * @param y    second forward operand
     * @param grad incoming gradient
     * @return the op's output variables as a fixed-size list
     */
    public List<SDVariable> divBp(SDVariable x, SDVariable y, SDVariable grad) {
        DivBpOp backprop = new DivBpOp(sameDiff(), x, y, grad);
        SDVariable[] grads = backprop.outputVariables();
        return Arrays.asList(grads);
    }

    /**
     * Builds a {@code DivOp} with flag {@code true} (presumably the in-place variant of {@code div}).
     *
     * @param differentialFunction first operand (the only one validated)
     * @param i_v                  second operand
     * @return the op's output variable
     */
    public SDVariable divi(SDVariable differentialFunction, SDVariable i_v) {
        validateDifferentialFunctionsameDiff(differentialFunction);
        SDVariable[] operands = {differentialFunction, i_v};
        DivOp quotient = new DivOp(sameDiff(), operands, true);
        return quotient.outputVariable();
    }

    /**
     * Builds a {@code ScalarReverseSubtraction} op with a scalar operand.
     *
     * @param differentialFunction variable operand (validated)
     * @param i_v                  scalar operand, boxed as {@link Number} for the op constructor
     * @return the op's output variable
     */
    public SDVariable rsub(SDVariable differentialFunction, double i_v) {
        validateDifferentialFunctionsameDiff(differentialFunction);
        Number scalar = i_v;
        ScalarReverseSubtraction scalarOp = new ScalarReverseSubtraction(sameDiff(), differentialFunction, scalar);
        return scalarOp.outputVariable();
    }

    /**
     * Builds a {@code ScalarReverseDivision} op with a scalar operand.
     *
     * @param differentialFunction variable operand (validated)
     * @param i_v                  scalar operand, boxed as {@link Number} for the op constructor
     * @return the op's output variable
     */
    public SDVariable rdiv(SDVariable differentialFunction, double i_v) {
        validateDifferentialFunctionsameDiff(differentialFunction);
        Number scalar = i_v;
        ScalarReverseDivision scalarOp = new ScalarReverseDivision(sameDiff(), differentialFunction, scalar);
        return scalarOp.outputVariable();
    }

    /**
     * Builds a {@code ScalarReverseDivision} op with trailing flag {@code true} (presumably in-place).
     *
     * @param differentialFunction variable operand (validated)
     * @param i_v                  scalar operand, boxed as {@link Number} for the op constructor
     * @return the op's output variable
     */
    public SDVariable rdivi(SDVariable differentialFunction, double i_v) {
        validateDifferentialFunctionsameDiff(differentialFunction);
        Number scalar = i_v;
        ScalarReverseDivision scalarOp = new ScalarReverseDivision(sameDiff(), differentialFunction, scalar, true);
        return scalarOp.outputVariable();
    }

    /**
     * Builds a {@code ScalarReverseSubtraction} op with trailing flag {@code true} (presumably in-place).
     *
     * @param differentialFunction variable operand (validated)
     * @param i_v                  scalar operand, boxed as {@link Number} for the op constructor
     * @return the op's output variable
     */
    public SDVariable rsubi(SDVariable differentialFunction, double i_v) {
        validateDifferentialFunctionsameDiff(differentialFunction);
        Number scalar = i_v;
        ScalarReverseSubtraction scalarOp = new ScalarReverseSubtraction(sameDiff(), differentialFunction, scalar, true);
        return scalarOp.outputVariable();
    }

    /**
     * Builds a {@code ScalarAdd} op with trailing flag {@code false}. {@code addi} passes {@code true}.
     *
     * @param differentialFunction variable operand (validated)
     * @param i_v                  scalar operand, boxed as {@link Number} for the op constructor
     * @return the op's output variable
     */
    public SDVariable add(SDVariable differentialFunction, double i_v) {
        validateDifferentialFunctionsameDiff(differentialFunction);
        Number scalar = i_v;
        ScalarAdd scalarOp = new ScalarAdd(sameDiff(), differentialFunction, scalar, false);
        return scalarOp.outputVariable();
    }

    /**
     * Builds a {@code ScalarAdd} op with trailing flag {@code true} (presumably in-place).
     *
     * @param differentialFunction variable operand (validated)
     * @param i_v                  scalar operand, boxed as {@link Number} for the op constructor
     * @return the op's output variable
     */
    public SDVariable addi(SDVariable differentialFunction, double i_v) {
        validateDifferentialFunctionsameDiff(differentialFunction);
        Number scalar = i_v;
        ScalarAdd scalarOp = new ScalarAdd(sameDiff(), differentialFunction, scalar, true);
        return scalarOp.outputVariable();
    }

    /**
     * Builds a {@code ScalarSubtraction} op with a scalar operand (no explicit flag).
     *
     * @param differentialFunction variable operand (validated)
     * @param i_v                  scalar operand, boxed as {@link Number} for the op constructor
     * @return the op's output variable
     */
    public SDVariable sub(SDVariable differentialFunction, double i_v) {
        validateDifferentialFunctionsameDiff(differentialFunction);
        Number scalar = i_v;
        ScalarSubtraction scalarOp = new ScalarSubtraction(sameDiff(), differentialFunction, scalar);
        return scalarOp.outputVariable();
    }

    /**
     * Builds a {@code ScalarSubtraction} op with trailing flag {@code true} (presumably in-place).
     *
     * @param differentialFunction variable operand (validated)
     * @param i_v                  scalar operand, boxed as {@link Number} for the op constructor
     * @return the op's output variable
     */
    public SDVariable subi(SDVariable differentialFunction, double i_v) {
        validateDifferentialFunctionsameDiff(differentialFunction);
        Number scalar = i_v;
        ScalarSubtraction scalarOp = new ScalarSubtraction(sameDiff(), differentialFunction, scalar, true);
        return scalarOp.outputVariable();
    }

    /**
     * Builds a {@code ScalarMultiplication} op over {@code differentialFunction} and the scalar {@code i_v}.
     *
     * @param differentialFunction input variable; validated against this factory's SameDiff instance
     * @param i_v                  scalar operand
     * @return the op's output variable
     */
    public SDVariable mul(SDVariable differentialFunction, double i_v) {
        validateDifferentialFunctionsameDiff(differentialFunction);
        ScalarMultiplication op = new ScalarMultiplication(sameDiff(), differentialFunction, (Number) i_v);
        return op.outputVariable();
    }

    /**
     * Builds a {@code ScalarMultiplication} op with the trailing constructor flag set to {@code true}
     * (presumably the in-place flag, per the {@code i} suffix).
     *
     * @param differentialFunction input variable; validated against this factory's SameDiff instance
     * @param i_v                  scalar operand
     * @return the op's output variable
     */
    public SDVariable muli(SDVariable differentialFunction, double i_v) {
        validateDifferentialFunctionsameDiff(differentialFunction);
        ScalarMultiplication op = new ScalarMultiplication(sameDiff(), differentialFunction, (Number) i_v, true);
        return op.outputVariable();
    }

    /**
     * Builds a {@code ScalarDivision} op over {@code differentialFunction} and the scalar {@code i_v}.
     *
     * @param differentialFunction input variable; validated against this factory's SameDiff instance
     * @param i_v                  scalar operand
     * @return the op's output variable
     */
    public SDVariable div(SDVariable differentialFunction, double i_v) {
        validateDifferentialFunctionsameDiff(differentialFunction);
        ScalarDivision op = new ScalarDivision(sameDiff(), differentialFunction, (Number) i_v);
        return op.outputVariable();
    }

    /**
     * Builds a {@code ScalarDivision} op with the trailing constructor flag set to {@code true}
     * (presumably the in-place flag, per the {@code i} suffix).
     *
     * @param differentialFunction input variable; validated against this factory's SameDiff instance
     * @param i_v                  scalar operand
     * @return the op's output variable
     */
    public SDVariable divi(SDVariable differentialFunction, double i_v) {
        validateDifferentialFunctionsameDiff(differentialFunction);
        ScalarDivision op = new ScalarDivision(sameDiff(), differentialFunction, (Number) i_v, true);
        return op.outputVariable();
    }

    /**
     * Builds a {@code GreaterThan} op over two variables, passing {@code false} as the trailing constructor flag.
     *
     * @param functionInput  first operand; validated against this factory's SameDiff instance
     * @param functionInput1 second operand; validated against this factory's SameDiff instance
     * @return the op's output variable
     */
    public SDVariable gt(SDVariable functionInput, SDVariable functionInput1) {
        validateDifferentialFunctionsameDiff(functionInput);
        validateDifferentialFunctionsameDiff(functionInput1);
        SDVariable[] operands = {functionInput, functionInput1};
        return new GreaterThan(sameDiff(), operands, false).outputVariable();
    }

    /**
     * Builds a {@code LessThan} op over two variables, passing {@code false} as the trailing constructor flag.
     *
     * @param functionInput  first operand; validated against this factory's SameDiff instance
     * @param functionInput1 second operand; validated against this factory's SameDiff instance
     * @return the op's output variable
     */
    public SDVariable lt(SDVariable functionInput, SDVariable functionInput1) {
        validateDifferentialFunctionsameDiff(functionInput);
        validateDifferentialFunctionsameDiff(functionInput1);
        SDVariable[] operands = {functionInput, functionInput1};
        return new LessThan(sameDiff(), operands, false).outputVariable();
    }

    /**
     * Builds a {@code GreaterThan} op over two variables with the trailing constructor flag set to {@code true}
     * (presumably the in-place flag, per the {@code i} suffix).
     *
     * @param functionInput  first operand; validated against this factory's SameDiff instance
     * @param functionInput1 second operand; validated against this factory's SameDiff instance
     * @return the op's output variable
     */
    public SDVariable gti(SDVariable functionInput, SDVariable functionInput1) {
        validateDifferentialFunctionsameDiff(functionInput);
        validateDifferentialFunctionsameDiff(functionInput1);
        SDVariable[] operands = {functionInput, functionInput1};
        return new GreaterThan(sameDiff(), operands, true).outputVariable();
    }

    /**
     * Builds a {@code LessThan} op over two variables with the trailing constructor flag set to {@code true}
     * (presumably the in-place flag, per the {@code i} suffix).
     *
     * @param functionInput  first operand; validated against this factory's SameDiff instance
     * @param functionInput1 second operand; validated against this factory's SameDiff instance
     * @return the op's output variable
     */
    public SDVariable lti(SDVariable functionInput, SDVariable functionInput1) {
        validateDifferentialFunctionsameDiff(functionInput);
        validateDifferentialFunctionsameDiff(functionInput1);
        SDVariable[] operands = {functionInput, functionInput1};
        return new LessThan(sameDiff(), operands, true).outputVariable();
    }

    /**
     * Builds a {@code GreaterThanOrEqual} op over two variables, passing {@code false} as the trailing constructor flag.
     *
     * @param functionInput  first operand; validated against this factory's SameDiff instance
     * @param functionInput1 second operand; validated against this factory's SameDiff instance
     * @return the op's output variable
     */
    public SDVariable gte(SDVariable functionInput, SDVariable functionInput1) {
        validateDifferentialFunctionsameDiff(functionInput);
        validateDifferentialFunctionsameDiff(functionInput1);
        SDVariable[] operands = {functionInput, functionInput1};
        return new GreaterThanOrEqual(sameDiff(), operands, false).outputVariable();
    }

    /**
     * Builds a {@code LessThanOrEqual} op over two variables, passing {@code false} as the trailing constructor flag.
     *
     * @param functionInput  first operand; validated against this factory's SameDiff instance
     * @param functionInput1 second operand; validated against this factory's SameDiff instance
     * @return the op's output variable
     */
    public SDVariable lte(SDVariable functionInput, SDVariable functionInput1) {
        validateDifferentialFunctionsameDiff(functionInput);
        validateDifferentialFunctionsameDiff(functionInput1);
        SDVariable[] operands = {functionInput, functionInput1};
        return new LessThanOrEqual(sameDiff(), operands, false).outputVariable();
    }

    /**
     * Builds a {@code GreaterThanOrEqual} op over two variables with the trailing constructor flag set to {@code true}
     * (presumably the in-place flag, per the {@code i} suffix).
     *
     * @param functionInput  first operand; validated against this factory's SameDiff instance
     * @param functionInput1 second operand; validated against this factory's SameDiff instance
     * @return the op's output variable
     */
    public SDVariable gtei(SDVariable functionInput, SDVariable functionInput1) {
        validateDifferentialFunctionsameDiff(functionInput);
        validateDifferentialFunctionsameDiff(functionInput1);
        SDVariable[] operands = {functionInput, functionInput1};
        return new GreaterThanOrEqual(sameDiff(), operands, true).outputVariable();
    }

    /**
     * Builds a {@code LessThanOrEqual} op over two variables with the trailing constructor flag set to {@code true}
     * (presumably the in-place flag). Note the name differs from the {@code gtei}/{@code ltei} pattern used
     * elsewhere; it is kept as-is because it is public API.
     *
     * @param functionInput  first operand; validated against this factory's SameDiff instance
     * @param functionInput1 second operand; validated against this factory's SameDiff instance
     * @return the op's output variable
     */
    public SDVariable ltOrEqi(SDVariable functionInput, SDVariable functionInput1) {
        validateDifferentialFunctionsameDiff(functionInput);
        validateDifferentialFunctionsameDiff(functionInput1);
        SDVariable[] operands = {functionInput, functionInput1};
        return new LessThanOrEqual(sameDiff(), operands, true).outputVariable();
    }

    /**
     * Builds a {@code ScalarGreaterThan} op comparing a variable with a scalar,
     * passing {@code false} as the trailing constructor flag.
     *
     * @param functionInput  input variable; validated against this factory's SameDiff instance
     * @param functionInput1 scalar operand
     * @return the op's output variable
     */
    public SDVariable gt(SDVariable functionInput, double functionInput1) {
        validateDifferentialFunctionsameDiff(functionInput);
        ScalarGreaterThan op = new ScalarGreaterThan(sameDiff(), functionInput, (Number) functionInput1, false);
        return op.outputVariable();
    }

    /**
     * Builds a {@code ScalarLessThan} op comparing a variable with a scalar,
     * passing {@code false} as the trailing constructor flag.
     *
     * @param functionInput  input variable; validated against this factory's SameDiff instance
     * @param functionInput1 scalar operand
     * @return the op's output variable
     */
    public SDVariable lt(SDVariable functionInput, double functionInput1) {
        validateDifferentialFunctionsameDiff(functionInput);
        ScalarLessThan op = new ScalarLessThan(sameDiff(), functionInput, (Number) functionInput1, false);
        return op.outputVariable();
    }

    /**
     * Builds a {@code ScalarGreaterThan} op comparing a variable with a scalar, with the trailing constructor
     * flag set to {@code true} (presumably the in-place flag).
     *
     * @param functionInput  input variable; validated against this factory's SameDiff instance
     * @param functionInput1 scalar operand
     * @return the op's output variable
     */
    public SDVariable gti(SDVariable functionInput, double functionInput1) {
        validateDifferentialFunctionsameDiff(functionInput);
        ScalarGreaterThan op = new ScalarGreaterThan(sameDiff(), functionInput, (Number) functionInput1, true);
        return op.outputVariable();
    }

    /**
     * Builds a {@code ScalarLessThan} op comparing a variable with a scalar, with the trailing constructor
     * flag set to {@code true} (presumably the in-place flag).
     *
     * @param functionInput  input variable; validated against this factory's SameDiff instance
     * @param functionInput1 scalar operand
     * @return the op's output variable
     */
    public SDVariable lti(SDVariable functionInput, double functionInput1) {
        validateDifferentialFunctionsameDiff(functionInput);
        ScalarLessThan op = new ScalarLessThan(sameDiff(), functionInput, (Number) functionInput1, true);
        return op.outputVariable();
    }

    /**
     * Builds a {@code ScalarGreaterThanOrEqual} op comparing a variable with a scalar,
     * passing {@code false} as the trailing constructor flag.
     *
     * @param functionInput  input variable; validated against this factory's SameDiff instance
     * @param functionInput1 scalar operand
     * @return the op's output variable
     */
    public SDVariable gte(SDVariable functionInput, double functionInput1) {
        validateDifferentialFunctionsameDiff(functionInput);
        ScalarGreaterThanOrEqual op = new ScalarGreaterThanOrEqual(sameDiff(), functionInput, (Number) functionInput1, false);
        return op.outputVariable();
    }

    /**
     * Builds a {@code ScalarLessThanOrEqual} op comparing a variable with a scalar,
     * passing {@code false} as the trailing constructor flag.
     *
     * @param functionInput  input variable; validated against this factory's SameDiff instance
     * @param functionInput1 scalar operand
     * @return the op's output variable
     */
    public SDVariable lte(SDVariable functionInput, double functionInput1) {
        validateDifferentialFunctionsameDiff(functionInput);
        ScalarLessThanOrEqual op = new ScalarLessThanOrEqual(sameDiff(), functionInput, (Number) functionInput1, false);
        return op.outputVariable();
    }

    /**
     * Builds a {@code ScalarGreaterThanOrEqual} op comparing a variable with a scalar, with the trailing
     * constructor flag set to {@code true} (presumably the in-place flag).
     *
     * @param functionInput  input variable; validated against this factory's SameDiff instance
     * @param functionInput1 scalar operand
     * @return the op's output variable
     */
    public SDVariable gtei(SDVariable functionInput, double functionInput1) {
        validateDifferentialFunctionsameDiff(functionInput);
        ScalarGreaterThanOrEqual op = new ScalarGreaterThanOrEqual(sameDiff(), functionInput, (Number) functionInput1, true);
        return op.outputVariable();
    }

    /**
     * Builds a {@code ScalarLessThanOrEqual} op comparing a variable with a scalar, with the trailing
     * constructor flag set to {@code true} (presumably the in-place flag).
     *
     * @param functionInput  input variable; validated against this factory's SameDiff instance
     * @param functionInput1 scalar operand
     * @return the op's output variable
     */
    public SDVariable ltei(SDVariable functionInput, double functionInput1) {
        validateDifferentialFunctionsameDiff(functionInput);
        ScalarLessThanOrEqual op = new ScalarLessThanOrEqual(sameDiff(), functionInput, (Number) functionInput1, true);
        return op.outputVariable();
    }

    /**
     * Builds a {@code ScalarEquals} op comparing {@code iX} with the scalar {@code i_y}.
     * <p>
     * The input is validated with {@link #validateDifferentialFunctionsameDiff(SDVariable)}, consistent with
     * every other comparison helper in this factory, so that a null variable or a variable from a different
     * SameDiff instance is rejected up front rather than wired into the wrong graph.
     *
     * @param iX  input variable; must be non-null and belong to this factory's SameDiff instance
     * @param i_y scalar operand
     * @return the op's output variable
     */
    public SDVariable eq(SDVariable iX, double i_y) {
        this.validateDifferentialFunctionsameDiff(iX);
        return new ScalarEquals(this.sameDiff(), iX, (Number) i_y).outputVariable();
    }

    /**
     * Builds a {@code ScalarEquals} op comparing {@code iX} with the scalar {@code i_y}, with the trailing
     * constructor flag set to {@code true} (presumably the in-place flag, per the {@code i} suffix).
     * <p>
     * The input is validated with {@link #validateDifferentialFunctionsameDiff(SDVariable)}, consistent with
     * every other comparison helper in this factory.
     *
     * @param iX  input variable; must be non-null and belong to this factory's SameDiff instance
     * @param i_y scalar operand
     * @return the op's output variable
     */
    public SDVariable eqi(SDVariable iX, double i_y) {
        this.validateDifferentialFunctionsameDiff(iX);
        return new ScalarEquals(this.sameDiff(), iX, (Number) i_y, true).outputVariable();
    }

    /**
     * Builds an {@code IsNonDecreasing} op on {@code iX}, passing {@code false} as the trailing constructor flag.
     *
     * @param iX input variable; validated against this factory's SameDiff instance
     * @return the op's output variable
     */
    public SDVariable isNonDecreasing(SDVariable iX) {
        validateDifferentialFunctionsameDiff(iX);
        SDVariable[] args = {iX};
        return new IsNonDecreasing(sameDiff(), args, false).outputVariable();
    }

    /**
     * Builds an {@code IsStrictlyIncreasing} op on {@code iX}, passing {@code false} as the trailing constructor flag.
     *
     * @param iX input variable; validated against this factory's SameDiff instance
     * @return the op's output variable
     */
    public SDVariable isStrictlyIncreasing(SDVariable iX) {
        validateDifferentialFunctionsameDiff(iX);
        SDVariable[] args = {iX};
        return new IsStrictlyIncreasing(sameDiff(), args, false).outputVariable();
    }

    /**
     * Builds an {@code IsNumericTensor} op on {@code iX}, passing {@code false} as the trailing constructor flag.
     *
     * @param iX input variable; validated against this factory's SameDiff instance
     * @return the op's output variable
     */
    public SDVariable isNumericTensor(SDVariable iX) {
        validateDifferentialFunctionsameDiff(iX);
        SDVariable[] args = {iX};
        return new IsNumericTensor(sameDiff(), args, false).outputVariable();
    }

    /**
     * Builds a {@code Slice} op from static {@code begin}/{@code size} arrays. No SameDiff validation is done here.
     *
     * @param input variable to slice
     * @param begin start indices
     * @param size  slice sizes
     * @return the op's output variable
     */
    public SDVariable slice(SDVariable input, int[] begin, int[] size) {
        Slice op = new Slice(sameDiff(), input, begin, size);
        return op.outputVariable();
    }

    /**
     * Builds a {@code Slice} op whose {@code begin}/{@code size} are supplied as graph variables.
     * No SameDiff validation is done here.
     *
     * @param input variable to slice
     * @param begin variable holding start indices
     * @param size  variable holding slice sizes
     * @return the op's output variable
     */
    public SDVariable slice(SDVariable input, SDVariable begin, SDVariable size) {
        Slice op = new Slice(sameDiff(), input, begin, size);
        return op.outputVariable();
    }

    /**
     * Builds a {@code SliceBp} op (backprop counterpart of {@code Slice}, judging by the name).
     *
     * @param input    original slice input
     * @param gradient incoming gradient
     * @param begin    start indices used by the forward slice
     * @param size     slice sizes used by the forward slice
     * @return the op's output variable
     */
    public SDVariable sliceBp(SDVariable input, SDVariable gradient, int[] begin, int[] size) {
        SliceBp op = new SliceBp(sameDiff(), input, gradient, begin, size);
        return op.outputVariable();
    }

    /**
     * Builds a {@code StridedSlice} op from {@code int[]} begin/end/strides with no mask arguments.
     *
     * @param input   variable to slice
     * @param begin   start indices
     * @param end     end indices
     * @param strides step per dimension
     * @return the op's output variable
     */
    public SDVariable stridedSlice(SDVariable input, int[] begin, int[] end, int[] strides) {
        StridedSlice op = new StridedSlice(sameDiff(), input, begin, end, strides);
        return op.outputVariable();
    }

    /**
     * Builds a {@code StridedSlice} op from {@code long[]} begin/end/strides with no mask arguments.
     *
     * @param input   variable to slice
     * @param begin   start indices
     * @param end     end indices
     * @param strides step per dimension
     * @return the op's output variable
     */
    public SDVariable stridedSlice(SDVariable input, long[] begin, long[] end, long[] strides) {
        StridedSlice op = new StridedSlice(sameDiff(), input, begin, end, strides);
        return op.outputVariable();
    }

    /**
     * Builds a {@code StridedSlice} op from {@code int[]} begin/end/strides plus the five mask arguments,
     * which are forwarded unchanged to the op constructor.
     *
     * @return the op's output variable
     */
    public SDVariable stridedSlice(SDVariable in, int[] begin, int[] end, int[] strides, int beginMask, int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) {
        StridedSlice op = new StridedSlice(sameDiff(), in, begin, end, strides,
                beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask);
        return op.outputVariable();
    }

    /**
     * Builds a {@code StridedSlice} op from {@code long[]} begin/end/strides plus the five mask arguments,
     * which are forwarded unchanged to the op constructor.
     *
     * @return the op's output variable
     */
    public SDVariable stridedSlice(SDVariable in, long[] begin, long[] end, long[] strides, int beginMask, int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) {
        StridedSlice op = new StridedSlice(sameDiff(), in, begin, end, strides,
                beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask);
        return op.outputVariable();
    }

    /**
     * Builds a {@code StridedSliceBp} op from static {@code long[]} begin/end/strides; the incoming gradient
     * and mask arguments are forwarded unchanged.
     *
     * @return the op's output variable
     */
    public SDVariable stridedSliceBp(SDVariable in, SDVariable grad, long[] begin, long[] end, long[] strides, int beginMask, int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) {
        StridedSliceBp op = new StridedSliceBp(sameDiff(), in, grad, begin, end, strides,
                beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask);
        return op.outputVariable();
    }

    /**
     * Builds a {@code StridedSliceBp} op whose begin/end/strides are graph variables; the incoming gradient
     * and mask arguments are forwarded unchanged.
     *
     * @return the op's output variable
     */
    public SDVariable stridedSliceBp(SDVariable in, SDVariable grad, SDVariable begin, SDVariable end, SDVariable strides, int beginMask, int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) {
        StridedSliceBp op = new StridedSliceBp(sameDiff(), in, grad, begin, end, strides,
                beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask);
        return op.outputVariable();
    }

    /**
     * Builds a {@code ScatterAdd} op over {@code ref}, {@code indices} and {@code updates}.
     *
     * @return the op's output variable
     */
    public SDVariable scatterAdd(SDVariable ref, SDVariable indices, SDVariable updates) {
        ScatterAdd op = new ScatterAdd(sameDiff(), ref, indices, updates);
        return op.outputVariable();
    }

    /**
     * Builds a {@code ScatterSub} op over {@code ref}, {@code indices} and {@code updates}.
     *
     * @return the op's output variable
     */
    public SDVariable scatterSub(SDVariable ref, SDVariable indices, SDVariable updates) {
        ScatterSub op = new ScatterSub(sameDiff(), ref, indices, updates);
        return op.outputVariable();
    }

    /**
     * Builds a {@code ScatterMul} op over {@code ref}, {@code indices} and {@code updates}.
     *
     * @return the op's output variable
     */
    public SDVariable scatterMul(SDVariable ref, SDVariable indices, SDVariable updates) {
        ScatterMul op = new ScatterMul(sameDiff(), ref, indices, updates);
        return op.outputVariable();
    }

    /**
     * Builds a {@code ScatterDiv} op over {@code ref}, {@code indices} and {@code updates}.
     *
     * @return the op's output variable
     */
    public SDVariable scatterDiv(SDVariable ref, SDVariable indices, SDVariable updates) {
        ScatterDiv op = new ScatterDiv(sameDiff(), ref, indices, updates);
        return op.outputVariable();
    }

    /**
     * Builds a {@code ScatterMax} op over {@code ref}, {@code indices} and {@code updates}.
     *
     * @return the op's output variable
     */
    public SDVariable scatterMax(SDVariable ref, SDVariable indices, SDVariable updates) {
        ScatterMax op = new ScatterMax(sameDiff(), ref, indices, updates);
        return op.outputVariable();
    }

    /**
     * Builds a {@code ScatterMin} op over {@code ref}, {@code indices} and {@code updates}.
     *
     * @return the op's output variable
     */
    public SDVariable scatterMin(SDVariable ref, SDVariable indices, SDVariable updates) {
        ScatterMin op = new ScatterMin(sameDiff(), ref, indices, updates);
        return op.outputVariable();
    }

    /**
     * Builds a {@code ScatterUpdate} op over {@code ref}, {@code indices} and {@code updates}.
     *
     * @return the op's output variable
     */
    public SDVariable scatterUpdate(SDVariable ref, SDVariable indices, SDVariable updates) {
        ScatterUpdate op = new ScatterUpdate(sameDiff(), ref, indices, updates);
        return op.outputVariable();
    }

    /**
     * Builds a control-flow {@code Merge} op over the given inputs.
     *
     * @param inputs variables to merge
     * @return the op's (single) output variable
     */
    public SDVariable merge(SDVariable ... inputs) {
        Merge op = new Merge(sameDiff(), inputs);
        return op.outputVariable();
    }

    /**
     * Builds a control-flow {@code Switch} op routing {@code input} according to {@code predicate}.
     *
     * @param input     value to route
     * @param predicate routing condition
     * @return all output variables of the op
     */
    public SDVariable[] switchOp(SDVariable input, SDVariable predicate) {
        Switch op = new Switch(sameDiff(), input, predicate);
        return op.outputVariables();
    }

    /**
     * Checks that {@code function} is non-null and is attached to the same SameDiff instance as this factory.
     * <p>
     * Failures are reported via {@link Preconditions#checkState} (presumably {@code IllegalStateException}).
     * The earlier version ran the identity check once without a message, so a mismatch never reached the
     * descriptive message. It also ran a self-comparison of {@code sameDiff} with {@code getSameDiff()}, which
     * is implied once the identity check passes, and that check used a format string with fewer placeholders
     * than arguments.
     *
     * @param function the variable to validate
     */
    public void validateDifferentialFunctionsameDiff(SDVariable function) {
        Preconditions.checkState(function != null, "Passed in function was null.");
        // Reference identity: the variable must live in exactly this graph instance, not an equal copy.
        Preconditions.checkState(function.getSameDiff() == this.sameDiff,
                "Function applications must be contained in same sameDiff. The left %s must match this function %s",
                function, this);
    }

    /**
     * Checks, by reference identity, that {@code function} belongs to the same SameDiff instance as this factory.
     * Unlike {@code validateDifferentialFunctionsameDiff}, no explicit null check is performed, so a null
     * argument fails when {@code getSameDiff()} is called on it.
     *
     * @param function the variable to validate
     */
    public void validateDifferentialFunctionGraph(SDVariable function) {
        boolean sameGraph = function.getSameDiff() == getSameDiff();
        Preconditions.checkState(sameGraph,
                "Function applications must be contained in same graph. The left %s must match this function %s",
                function, this);
    }

    /**
     * Validates both variables, then tiles {@code func} using {@code input}'s shape (converted to {@code int[]})
     * as the repetition counts.
     *
     * @param func  variable to tile
     * @param input variable whose shape supplies the repetition counts
     * @return the tiled variable
     */
    public SDVariable doRepeat(SDVariable func, SDVariable input) {
        validateDifferentialFunctionsameDiff(func);
        validateDifferentialFunctionsameDiff(input);
        long[] inputShape = input.getShape();
        int[] repeats = ArrayUtil.toInts(inputShape);
        return tile(func, repeats);
    }

    /**
     * Returns a debug string containing this class's {@code methodNames} field.
     */
    public String toString() {
        StringBuilder text = new StringBuilder("DifferentialFunctionFactory{methodNames=");
        text.append(methodNames);
        text.append("}");
        return text.toString();
    }

    /**
     * @return the SameDiff instance this factory creates ops in (the {@code sameDiff} field)
     */
    public SameDiff getSameDiff() {
        final SameDiff current = sameDiff;
        return current;
    }

    /**
     * Replaces the SameDiff instance this factory is bound to.
     *
     * @param sameDiff the new SameDiff instance (stored as-is; may be null)
     */
    public void setSameDiff(SameDiff sameDiff) {
        // Parameter shadows the field, hence the explicit receiver.
        this.sameDiff = sameDiff;
    }

    /**
     * Two factories are equal when both accept each other via {@link #canEqual(Object)} and their
     * SameDiff instances are equal (both null, or {@code equals}).
     */
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (!(o instanceof DifferentialFunctionFactory)) {
            return false;
        }
        DifferentialFunctionFactory that = (DifferentialFunctionFactory) o;
        if (!that.canEqual(this)) {
            return false;
        }
        // Typed as Object so dispatch goes to equals(Object), exactly as in the original cast.
        Object mine = this.getSameDiff();
        Object theirs = that.getSameDiff();
        if (mine == null) {
            return theirs == null;
        }
        return mine.equals(theirs);
    }

    /**
     * Symmetry hook used by {@link #equals(Object)}: accepts any {@code DifferentialFunctionFactory} instance.
     *
     * @param other candidate object
     * @return whether {@code other} is a {@code DifferentialFunctionFactory}
     */
    protected boolean canEqual(Object other) {
        final boolean compatible = other instanceof DifferentialFunctionFactory;
        return compatible;
    }

    /**
     * Hash code consistent with {@link #equals(Object)}: derived only from {@link #getSameDiff()}.
     * <p>
     * The constants (multiplier 59, null marker 43, seed 1) are the same as in the previous generated
     * implementation, so hash values are unchanged. The old version declared an unused {@code PRIME} local
     * and then repeated the literal 59.
     */
    public int hashCode() {
        final int prime = 59;
        final int nullHash = 43;
        int result = 1;
        SameDiff sd = this.getSameDiff();
        result = result * prime + (sd == null ? nullHash : sd.hashCode());
        return result;
    }
}

