/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.samediff.serde;

import com.google.flatbuffers.FlatBufferBuilder;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import lombok.NonNull;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.VariableType;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.internal.Variable;
import org.nd4j.autodiff.samediff.serde.LegacyOpMapper;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.util.ArrayUtil;
import org.nd4j.graph.FlatArray;
import org.nd4j.graph.FlatNode;
import org.nd4j.graph.FlatProperties;
import org.nd4j.graph.IntPair;
import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseIndexAccumulation;
import org.nd4j.linalg.api.ops.BaseReduceOp;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.CustomOpDescriptor;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.IndexAccumulation;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.ReduceOp;
import org.nd4j.linalg.api.ops.ScalarOp;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.exception.ND4UnresolvedOutputVariables;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.guava.primitives.Ints;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class FlatBuffersMapper {
    /** SLF4J logger for this class. */
    private static final Logger log = LoggerFactory.getLogger(FlatBuffersMapper.class);
    // Shared zero-length arrays. mapFunctionPropertiesToFlatProperties passes these to the
    // FlatProperties vector builders when a property carries no values of the given kind.
    private static final boolean[] EMPTY_BOOLEAN = new boolean[0];
    private static final int[] EMPTY_INT = new int[0];
    private static final long[] EMPTY_LONG = new long[0];
    private static final double[] EMPTY_DOUBLE = new double[0];

    /** Private constructor: prevents instantiation from outside this class. */
    private FlatBuffersMapper() {
    }

    /**
     * Encodes a {@link DataType} as the byte code used by this mapper.
     * {@link #getDataTypeFromByte(byte)} performs the reverse lookup.
     *
     * @param type data type to encode; must not be null
     * @return the byte code assigned to {@code type}
     * @throws NullPointerException      if {@code type} is null
     * @throws ND4JIllegalStateException if no byte code is defined for {@code type}
     */
    public static byte getDataTypeAsByte(@NonNull DataType type) {
        if (type == null) {
            throw new NullPointerException("type is marked non-null but is null");
        }
        // Cases are listed in ascending order of their byte code.
        switch (type) {
            case BOOL:
                return 1;
            case HALF:
                return 3;
            case FLOAT:
                return 5;
            case DOUBLE:
                return 6;
            case BYTE:
                return 7;
            case SHORT:
                return 8;
            case INT:
                return 9;
            case LONG:
                return 10;
            case UBYTE:
                return 11;
            case UINT16:
                return 12;
            case UINT32:
                return 13;
            case UINT64:
                return 14;
            case BFLOAT16:
                return 17;
            case UTF8:
                return 50;
            default:
                throw new ND4JIllegalStateException("Unknown or unsupported DataType used: [" + type + "]");
        }
    }

    /**
     * Decodes a byte code into the {@link DataType} it stands for.
     * This is the reverse of {@link #getDataTypeAsByte(DataType)}.
     *
     * @param val byte code to decode
     * @return the data type assigned to {@code val}
     * @throws RuntimeException if {@code val} is not one of the known codes
     */
    public static DataType getDataTypeFromByte(byte val) {
        switch (val) {
            case 1:
                return DataType.BOOL;
            case 3:
                return DataType.HALF;
            case 5:
                return DataType.FLOAT;
            case 6:
                return DataType.DOUBLE;
            case 7:
                return DataType.BYTE;
            case 8:
                return DataType.SHORT;
            case 9:
                return DataType.INT;
            case 10:
                return DataType.LONG;
            case 11:
                return DataType.UBYTE;
            case 12:
                return DataType.UINT16;
            case 13:
                return DataType.UINT32;
            case 14:
                return DataType.UINT64;
            case 17:
                return DataType.BFLOAT16;
            case 50:
                return DataType.UTF8;
            default:
                throw new RuntimeException("Unknown datatype: " + val);
        }
    }

    /**
     * Resolves the numeric identifier of an operation from its name and type.
     * <ul>
     *   <li>{@code LOOP}, {@code RETURN}, {@code CONDITIONAL} and {@code LOOP_COND} map to fixed numbers.</li>
     *   <li>{@code LOGIC} ops map to a fixed number chosen by {@code name}.</li>
     *   <li>{@code CUSTOM} ops return the hash of the registered custom-op descriptor. The lower-cased
     *       name is tried first, then the name as given; {@code 0} is returned if neither is registered.</li>
     *   <li>Every other type is resolved by instantiating the op by name and asking it for its op number.</li>
     * </ul>
     *
     * @param name op name
     * @param type op type
     * @return the op number, or the descriptor hash for custom ops
     * @throws IllegalStateException if {@code type} is {@code LOGIC} and {@code name} is not a known logic op
     * @throws RuntimeException      if the op number of a non-custom op cannot be determined
     */
    public static long getOpNum(String name, Op.Type type) {
        if (type == Op.Type.LOOP) {
            return 0L;
        }
        if (type == Op.Type.RETURN) {
            return 40L;
        }
        if (type == Op.Type.CONDITIONAL) {
            return 10L;
        }
        if (type == Op.Type.LOOP_COND) {
            return 70L;
        }
        if (type == Op.Type.LOGIC) {
            switch (name) {
                case "enter":
                    return 100L;
                case "exit":
                    return 90L;
                case "next_iteration":
                    return 80L;
                case "merge":
                    return 60L;
                case "switch":
                    return 30L;
                case "ExternalErrorsFn":
                    return 0L;
                default:
                    throw new IllegalStateException("Unknown LOGIC op with name: " + name);
            }
        }
        if (type == Op.Type.CUSTOM) {
            // Lower-case with Locale.ROOT: the default-locale variant depends on the JVM locale
            // (e.g. Turkish maps 'I' to a dotless 'i'), which would make the lookup miss.
            CustomOpDescriptor descriptor =
                    Nd4j.getExecutioner().getCustomOperations().get(name.toLowerCase(java.util.Locale.ROOT));
            if (descriptor == null) {
                // Fall back to the name exactly as given.
                descriptor = Nd4j.getExecutioner().getCustomOperations().get(name);
            }
            if (descriptor == null) {
                return 0L;
            }
            return descriptor.getHash();
        }
        try {
            DifferentialFunction op = DifferentialFunctionClassHolder.getInstance().getInstance(name);
            return op.opNum();
        } catch (Exception e) {
            throw new RuntimeException("Could not find op number for operation: [" + name + "]", e);
        }
    }

    /**
     * Decodes a serialized op-type byte into the matching {@link Op.Type}.
     *
     * @param type op-type byte code
     * @return the op type assigned to {@code type}
     * @throws UnsupportedOperationException if {@code type} is not a known code
     */
    public static Op.Type getTypeFromByte(byte type) {
        // Cases are listed in ascending order of the byte code.
        switch (type) {
            case 0:
                return Op.Type.TRANSFORM_FLOAT;
            case 1:
                return Op.Type.TRANSFORM_SAME;
            case 2:
                return Op.Type.TRANSFORM_BOOL;
            case 3:
                return Op.Type.TRANSFORM_STRICT;
            case 4:
                return Op.Type.TRANSFORM_ANY;
            case 5:
                return Op.Type.REDUCE_FLOAT;
            case 6:
                return Op.Type.REDUCE_SAME;
            case 7:
                return Op.Type.REDUCE_LONG;
            case 8:
                return Op.Type.REDUCE_BOOL;
            case 9:
                return Op.Type.INDEXREDUCE;
            case 10:
                return Op.Type.SCALAR;
            case 11:
                return Op.Type.SCALAR_BOOL;
            case 12:
                return Op.Type.BROADCAST;
            case 13:
                return Op.Type.BROADCAST_BOOL;
            case 14:
                return Op.Type.PAIRWISE;
            case 15:
                return Op.Type.PAIRWISE_BOOL;
            case 16:
                return Op.Type.REDUCE3;
            case 17:
                return Op.Type.SUMMARYSTATS;
            case 20:
                return Op.Type.RANDOM;
            case 21:
                return Op.Type.CUSTOM;
            case 119:
                return Op.Type.LOGIC;
            default:
                throw new UnsupportedOperationException("Unknown op type passed in: " + type);
        }
    }

    /**
     * Encodes an {@link Op.Type} as its serialized op-type byte.
     * Several op types share one code: {@code SPECIAL} is written with the {@code TRANSFORM_STRICT}
     * code, {@code VARIANCE} with the {@code SUMMARYSTATS} code, and the control-flow types
     * ({@code CONDITIONAL}, {@code LOOP}, {@code RETURN}, {@code LOOP_COND}) with the {@code LOGIC} code.
     *
     * @param type op type to encode
     * @return the byte code for {@code type}
     * @throws UnsupportedOperationException if no code is defined for {@code type}
     */
    public static byte getFlatOpType(Op.Type type) {
        switch (type) {
            case TRANSFORM_FLOAT:
                return 0;
            case TRANSFORM_SAME:
                return 1;
            case TRANSFORM_BOOL:
                return 2;
            case TRANSFORM_STRICT:
            case SPECIAL:
                return 3;
            case TRANSFORM_ANY:
                return 4;
            case REDUCE_FLOAT:
                return 5;
            case REDUCE_SAME:
                return 6;
            case REDUCE_LONG:
                return 7;
            case REDUCE_BOOL:
                return 8;
            case INDEXREDUCE:
                return 9;
            case SCALAR:
                return 10;
            case SCALAR_BOOL:
                return 11;
            case BROADCAST:
                return 12;
            case BROADCAST_BOOL:
                return 13;
            case PAIRWISE:
                return 14;
            case PAIRWISE_BOOL:
                return 15;
            case REDUCE3:
                return 16;
            case SUMMARYSTATS:
            case VARIANCE:
                return 17;
            case RANDOM:
                return 20;
            case CUSTOM:
                return 21;
            case CONDITIONAL:
            case LOOP:
            case RETURN:
            case LOOP_COND:
            case LOGIC:
                return 119;
            default:
                throw new UnsupportedOperationException("Unknown op type passed in: " + type);
        }
    }

    /**
     * Decodes a byte-order flag.
     *
     * @param val flag value; {@code 0} means little-endian, any other value means big-endian
     * @return the corresponding {@link ByteOrder}
     */
    public static ByteOrder getOrderFromByte(byte val) {
        return val == 0 ? ByteOrder.LITTLE_ENDIAN : ByteOrder.BIG_ENDIAN;
    }

    /**
     * Reports the native byte order of the current platform as a flag byte.
     *
     * @return {@code 1} if the native order is big-endian, {@code 0} otherwise
     */
    public static byte getOrderAsByte() {
        boolean bigEndian = ByteOrder.BIG_ENDIAN.equals(ByteOrder.nativeOrder());
        return (byte) (bigEndian ? 1 : 0);
    }

    /**
     * Rebuilds a {@link DifferentialFunction} from its serialized {@link FlatNode} form.
     * <p>
     * For {@code CUSTOM} and {@code LOGIC} nodes the op class is looked up by op hash and op name,
     * instantiated reflectively, and given the node's own name plus its integer, floating-point,
     * boolean and data-type arguments. For every other op type the class is resolved through
     * {@link LegacyOpMapper}; extra params, the scalar (scalar ops) or the dimensions
     * (reduction and index-reduction ops) are then applied. In both cases the node's properties
     * are set on the function last.
     *
     * @param fn serialized node to read
     * @return the reconstructed function
     * @throws UnsupportedOperationException if the node's op-type byte is unknown
     * @throws RuntimeException              if the op class cannot be instantiated; the reflective
     *                                       failure is attached as the cause
     */
    public static DifferentialFunction fromFlatNode(FlatNode fn) {
        String name = fn.name();
        Op.Type opType = FlatBuffersMapper.getTypeFromByte(fn.opType());
        long opNum = fn.opNum();

        double[] extraParams = new double[fn.extraParamsLength()];
        for (int i = 0; i < extraParams.length; ++i) {
            extraParams[i] = fn.extraParams(i);
        }
        long[] extraInteger = new long[fn.extraIntegerLength()];
        for (int i = 0; i < extraInteger.length; ++i) {
            extraInteger[i] = fn.extraInteger(i);
        }
        boolean[] extraBools = new boolean[fn.extraBoolsLength()];
        for (int i = 0; i < extraBools.length; ++i) {
            extraBools[i] = fn.extraBools(i);
        }
        DataType[] extraDTypes = new DataType[fn.extraTypesLength()];
        for (int i = 0; i < extraDTypes.length; ++i) {
            extraDTypes[i] = DataType.fromInt(fn.extraTypes(i));
        }
        int[] dimensions = new int[fn.dimensionsLength()];
        for (int i = 0; i < dimensions.length; ++i) {
            dimensions[i] = fn.dimensions(i);
        }

        // The scalar is optional; it stays null when the node does not carry one.
        FlatArray fa = fn.scalar();
        INDArray scalar = null;
        if (fa != null) {
            scalar = Nd4j.createFromFlatArray(fa);
        }

        FlatProperties[] flatProperties = new FlatProperties[fn.propertiesLength()];
        for (int i = 0; i < flatProperties.length; ++i) {
            flatProperties[i] = fn.properties(i);
        }
        Map<String, Object> props = FlatBuffersMapper.mapFlatPropertiesToFunctionProperties(Arrays.asList(flatProperties));

        if (opType == Op.Type.CUSTOM || opType == Op.Type.LOGIC) {
            String opName = fn.opName();
            Class<?> c = DifferentialFunctionClassHolder.getInstance().customOpClassForHashAndName(opNum, opName);
            Preconditions.checkNotNull(c, "Could not find class for hash %s", opNum);
            DifferentialFunction customFn;
            try {
                customFn = (DifferentialFunction) c.newInstance();
            } catch (IllegalAccessException | InstantiationException e) {
                // Keep the reflective failure as the cause so the real reason is not lost.
                throw new RuntimeException("Error creating differential function instance of type " + c, e);
            }
            customFn.setOwnName(name);
            CustomOp customOp = (CustomOp) customFn;
            customOp.addIArgument(extraInteger);
            customOp.addTArgument(extraParams);
            customOp.addBArgument(extraBools);
            customOp.addDArgument(extraDTypes);
            customFn.setPropertiesForFunction(props);
            return customFn;
        }

        Class<?> c = LegacyOpMapper.getLegacyOpClassForId(opType, (int) opNum);
        Op op;
        try {
            op = (Op) c.newInstance();
        } catch (IllegalAccessException | InstantiationException e) {
            // Keep the reflective failure as the cause so the real reason is not lost.
            throw new RuntimeException("Error creating differential function (Op) instance of type " + c, e);
        }

        if (extraParams.length > 0) {
            // setExtraArgs takes Object[], so each double is boxed.
            Object[] extraParamsObj = new Object[extraParams.length];
            for (int i = 0; i < extraParams.length; ++i) {
                extraParamsObj[i] = extraParams[i];
            }
            op.setExtraArgs(extraParamsObj);
        }

        if (opType == Op.Type.SCALAR || opType == Op.Type.SCALAR_BOOL) {
            ScalarOp sOp = (ScalarOp) op;
            sOp.setScalar(scalar);
        } else if (opType == Op.Type.REDUCE_FLOAT || opType == Op.Type.REDUCE3 || opType == Op.Type.SUMMARYSTATS
                || opType == Op.Type.VARIANCE || opType == Op.Type.REDUCE_BOOL || opType == Op.Type.REDUCE_LONG
                || opType == Op.Type.REDUCE_SAME) {
            BaseReduceOp ba = (BaseReduceOp) op;
            ba.setDimensions(dimensions);
            ba.setDimensionz(Shape.ndArrayDimFromInt(dimensions));
        } else if (opType == Op.Type.INDEXREDUCE) {
            BaseIndexAccumulation bia = (BaseIndexAccumulation) op;
            bia.setDimensions(dimensions);
            bia.setDimensionz(Shape.ndArrayDimFromInt(dimensions));
        }

        DifferentialFunction result = (DifferentialFunction) ((Object) op);
        result.setPropertiesForFunction(props);
        return result;
    }

    /**
     * Serializes a function's property map into {@link FlatProperties} tables.
     * <p>
     * Each entry becomes one table holding the property name plus seven vectors (ints, longs,
     * doubles, arrays, booleans, strings, shape); vectors that do not apply are written empty.
     * Supported values:
     * <ul>
     *   <li>{@code Boolean}, {@code Character} (stored as int), {@code Double}, {@code Float}
     *       (stored as double), {@code Integer}, {@code Long};</li>
     *   <li>{@code String}, {@code DataType} and any other {@code Enum}, stored via {@code toString()};</li>
     *   <li>{@code INDArray} and {@code INDArray[]} (no shape is recorded for the array form);</li>
     *   <li>1D {@code boolean/double/int/long/String} arrays, and 2D/3D {@code boolean/double/int/long}
     *       arrays, with their shape recorded.</li>
     * </ul>
     * A {@code null} value, or a value of any type not listed above that does not trigger an
     * exception, is written as a property with only its name and empty vectors.
     *
     * @param fbb     builder that receives the serialized data
     * @param fnProps properties to serialize, keyed by property name
     * @return offsets of the created tables, in the iteration order of {@code fnProps}
     * @throws UnsupportedOperationException for other {@code Number} subtypes, other primitive arrays,
     *                                       and other nested array types
     */
    public static int[] mapFunctionPropertiesToFlatProperties(FlatBufferBuilder fbb, Map<String, Object> fnProps) {
        int[] tableOffsets = new int[fnProps.size()];
        int next = 0;
        for (Map.Entry<String, Object> entry : fnProps.entrySet()) {
            String key = entry.getKey();
            Object value = entry.getValue();
            int nameOffset = fbb.createString(key);

            int[] ints = null;
            long[] longs = null;
            double[] doubles = null;
            boolean[] bools = null;
            int[] arrayOffsets = null;
            int[] stringOffsets = null;
            int[] shape = null;

            if (value == null) {
                // Nothing to encode: every vector stays empty.
            } else if (value instanceof Boolean) {
                bools = new boolean[]{(Boolean) value};
            } else if (value instanceof Character) {
                ints = new int[]{((Character) value).charValue()};
            } else if (value instanceof Double) {
                doubles = new double[]{(Double) value};
            } else if (value instanceof Float) {
                doubles = new double[]{((Float) value).floatValue()};
            } else if (value instanceof Integer) {
                ints = new int[]{(Integer) value};
            } else if (value instanceof Long) {
                longs = new long[]{(Long) value};
            } else if (value instanceof Number) {
                throw new UnsupportedOperationException("Unable to map property \"" + key + "\" of type " + value.getClass());
            } else if (value instanceof String || value instanceof DataType || value instanceof Enum) {
                // Strings and enum-like values are all stored by their toString() form.
                stringOffsets = new int[]{fbb.createString(value.toString())};
            } else if (value instanceof INDArray) {
                arrayOffsets = new int[]{((INDArray) value).toFlatArray(fbb)};
            } else if (value.getClass().isArray()) {
                Class<?> componentType = value.getClass().getComponentType();
                if (componentType.isPrimitive()) {
                    if (value instanceof boolean[]) {
                        bools = (boolean[]) value;
                        shape = new int[]{bools.length};
                    } else if (value instanceof double[]) {
                        doubles = (double[]) value;
                        shape = new int[]{doubles.length};
                    } else if (value instanceof int[]) {
                        ints = (int[]) value;
                        shape = new int[]{ints.length};
                    } else if (value instanceof long[]) {
                        longs = (long[]) value;
                        shape = new int[]{longs.length};
                    } else {
                        throw new UnsupportedOperationException("Unable to map property \"" + key + "\" of type " + value.getClass());
                    }
                } else if (value instanceof String[]) {
                    String[] strings = (String[]) value;
                    stringOffsets = new int[strings.length];
                    for (int k = 0; k < strings.length; ++k) {
                        stringOffsets[k] = fbb.createString(strings[k]);
                    }
                    shape = new int[]{strings.length};
                } else if (value instanceof INDArray[]) {
                    INDArray[] arrays = (INDArray[]) value;
                    arrayOffsets = new int[arrays.length];
                    for (int k = 0; k < arrays.length; ++k) {
                        arrayOffsets[k] = arrays[k].toFlatArray(fbb);
                    }
                } else if (componentType.isArray()) {
                    // Nested arrays are stored flattened, with the shape kept alongside.
                    shape = ArrayUtil.arrayShape(value, true);
                    if (value instanceof boolean[][]) {
                        bools = ArrayUtil.flatten((boolean[][]) value);
                    } else if (value instanceof boolean[][][]) {
                        bools = ArrayUtil.flatten((boolean[][][]) value);
                    } else if (value instanceof double[][]) {
                        doubles = ArrayUtil.flatten((double[][]) value);
                    } else if (value instanceof double[][][]) {
                        doubles = ArrayUtil.flatten((double[][][]) value);
                    } else if (value instanceof int[][]) {
                        ints = ArrayUtil.flatten((int[][]) value);
                    } else if (value instanceof int[][][]) {
                        ints = ArrayUtil.flatten((int[][][]) value);
                    } else if (value instanceof long[][]) {
                        longs = ArrayUtil.flatten((long[][]) value);
                    } else if (value instanceof long[][][]) {
                        longs = ArrayUtil.flatten((long[][][]) value);
                    } else {
                        throw new UnsupportedOperationException("Unable to map multidimensional array property \"" + key + "\" of type " + value.getClass());
                    }
                }
            }

            // All seven vectors are always written; absent kinds use the shared empty arrays.
            int doublesOffset = FlatProperties.createDVector(fbb, doubles == null ? EMPTY_DOUBLE : doubles);
            int intsOffset = FlatProperties.createIVector(fbb, ints == null ? EMPTY_INT : ints);
            int longsOffset = FlatProperties.createLVector(fbb, longs == null ? EMPTY_LONG : longs);
            int arraysOffset = FlatProperties.createAVector(fbb, arrayOffsets == null ? EMPTY_INT : arrayOffsets);
            int boolsOffset = FlatProperties.createBVector(fbb, bools == null ? EMPTY_BOOLEAN : bools);
            int stringsOffset = FlatProperties.createSVector(fbb, stringOffsets == null ? EMPTY_INT : stringOffsets);
            int shapeOffset = FlatProperties.createShapeVector(fbb, shape == null ? EMPTY_INT : shape);
            tableOffsets[next] = FlatProperties.createFlatProperties(fbb, nameOffset, intsOffset, longsOffset,
                    doublesOffset, arraysOffset, boolsOffset, stringsOffset, shapeOffset);
            next++;
        }
        return tableOffsets;
    }

    /**
     * Deserializes {@link FlatProperties} entries back into a property map.
     * <p>
     * A property with a non-empty shape vector is read as an array. The first non-empty value
     * vector, checked in the order int, double, long, boolean, string, array, supplies the elements:
     * rank 1 yields the flat array, rank 2 and 3 yield a reshaped nested array. An array property
     * of rank greater than 3 is not added to the result. A property without a shape is read as a
     * single value, taken from the first non-empty vector in the order boolean, int, long, double,
     * string, array. A property whose value vectors are all empty maps to {@code null}.
     *
     * @param list serialized properties to read
     * @return a new map from property name to decoded value
     */
    public static Map<String, Object> mapFlatPropertiesToFunctionProperties(Iterable<FlatProperties> list) {
        Map<String, Object> out = new HashMap<String, Object>();
        for (FlatProperties p : list) {
            String name = p.name();
            if (p.shapeLength() > 0) {
                int[] shape = new int[p.shapeLength()];
                for (int i = 0; i < shape.length; ++i) {
                    shape[i] = p.shape(i);
                }
                // rank is at least 1 here; ranks above 3 are not supported and the entry is skipped.
                int rank = shape.length;
                if (p.iLength() > 0) {
                    int[] iArr = new int[p.iLength()];
                    for (int i = 0; i < iArr.length; ++i) {
                        iArr[i] = p.i(i);
                    }
                    if (rank == 1) {
                        out.put(name, iArr);
                    } else if (rank == 2) {
                        out.put(name, ArrayUtil.reshapeInt(iArr, shape[0], shape[1]));
                    } else if (rank == 3) {
                        out.put(name, ArrayUtil.reshapeInt(iArr, shape[0], shape[1], shape[2]));
                    }
                } else if (p.dLength() > 0) {
                    double[] dArr = new double[p.dLength()];
                    for (int i = 0; i < dArr.length; ++i) {
                        dArr[i] = p.d(i);
                    }
                    if (rank == 1) {
                        out.put(name, dArr);
                    } else if (rank == 2) {
                        out.put(name, ArrayUtil.reshapeDouble(dArr, shape[0], shape[1]));
                    } else if (rank == 3) {
                        out.put(name, ArrayUtil.reshapeDouble(dArr, shape[0], shape[1], shape[2]));
                    }
                } else if (p.lLength() > 0) {
                    long[] lArr = new long[p.lLength()];
                    for (int i = 0; i < lArr.length; ++i) {
                        lArr[i] = p.l(i);
                    }
                    if (rank == 1) {
                        out.put(name, lArr);
                    } else if (rank == 2) {
                        out.put(name, ArrayUtil.reshapeLong(lArr, shape[0], shape[1]));
                    } else if (rank == 3) {
                        out.put(name, ArrayUtil.reshapeLong(lArr, shape[0], shape[1], shape[2]));
                    }
                } else if (p.bLength() > 0) {
                    boolean[] bArr = new boolean[p.bLength()];
                    for (int i = 0; i < bArr.length; ++i) {
                        bArr[i] = p.b(i);
                    }
                    if (rank == 1) {
                        out.put(name, bArr);
                    } else if (rank == 2) {
                        out.put(name, ArrayUtil.reshapeBoolean(bArr, shape[0], shape[1]));
                    } else if (rank == 3) {
                        out.put(name, ArrayUtil.reshapeBoolean(bArr, shape[0], shape[1], shape[2]));
                    }
                } else if (p.sLength() > 0) {
                    String[] sArr = new String[p.sLength()];
                    for (int i = 0; i < sArr.length; ++i) {
                        sArr[i] = p.s(i);
                    }
                    if (rank == 1) {
                        out.put(name, sArr);
                    } else if (rank == 2) {
                        out.put(name, ArrayUtil.reshapeObject(sArr, shape[0], shape[1]));
                    } else if (rank == 3) {
                        out.put(name, ArrayUtil.reshapeObject(sArr, shape[0], shape[1], shape[2]));
                    }
                } else if (p.aLength() > 0) {
                    INDArray[] arrays = new INDArray[p.aLength()];
                    for (int i = 0; i < arrays.length; ++i) {
                        // Read element i; reading index 0 each time would fill every slot with the first array.
                        FlatArray fa = p.a(i);
                        arrays[i] = Nd4j.createFromFlatArray(fa);
                    }
                    if (rank == 1) {
                        out.put(name, arrays);
                    } else if (rank == 2) {
                        out.put(name, ArrayUtil.reshapeObject(arrays, shape[0], shape[1]));
                    } else if (rank == 3) {
                        out.put(name, ArrayUtil.reshapeObject(arrays, shape[0], shape[1], shape[2]));
                    }
                } else {
                    out.put(name, null);
                }
            } else if (p.bLength() > 0) {
                out.put(name, p.b(0));
            } else if (p.iLength() > 0) {
                out.put(name, p.i(0));
            } else if (p.lLength() > 0) {
                out.put(name, p.l(0));
            } else if (p.dLength() > 0) {
                out.put(name, p.d(0));
            } else if (p.sLength() > 0) {
                out.put(name, p.s(0));
            } else if (p.aLength() > 0) {
                FlatArray fa = p.a(0);
                out.put(name, Nd4j.createFromFlatArray(fa));
            } else {
                out.put(name, null);
            }
        }
        return out;
    }

    public static int asFlatNode(@NonNull SameDiff sameDiff, @NonNull DifferentialFunction node, @NonNull FlatBufferBuilder bufferBuilder, List<SDVariable> variables, Map<String, Integer> reverseMap, Map<String, Integer> forwardMap, Map<String, Integer> framesMap, AtomicInteger idCounter, Integer id) {
        List<String> outVarNames;
        ScalarOp sOp;
        INDArray s;
        int[] dims;
        String[] outNames;
        SDVariable[] inputs;
        Op op;
        double[] extras;
        if (sameDiff == null) {
            throw new NullPointerException("sameDiff is marked non-null but is null");
        }
        if (node == null) {
            throw new NullPointerException("node is marked non-null but is null");
        }
        if (bufferBuilder == null) {
            throw new NullPointerException("bufferBuilder is marked non-null but is null");
        }
        String opName = node.opName();
        long hash = FlatBuffersMapper.getOpNum(node.opName(), node.opType());
        if (node.opType() == Op.Type.CUSTOM) {
            CustomOp op2 = (CustomOp)((Object)node);
            extras = op2.tArgs();
        } else {
            Object[] eArgs = node.getExtraArgs();
            extras = eArgs != null ? new double[eArgs.length] : new double[]{};
            for (int e = 0; e < extras.length; ++e) {
                extras[e] = ((Number)eArgs[e]).doubleValue();
            }
        }
        boolean[] boolArgs = null;
        byte[] dtypeArgs = null;
        long[] extraBits = null;
        if (node.opType() == Op.Type.CUSTOM) {
            DynamicCustomOp dynamicCustomOp = (DynamicCustomOp)node;
            extraBits = dynamicCustomOp.iArgs();
            boolArgs = dynamicCustomOp.bArgs();
            if (dynamicCustomOp.numDArguments() > 0) {
                dtypeArgs = new byte[dynamicCustomOp.numDArguments()];
                DataType[] d = dynamicCustomOp.dArgs();
                for (int e = 0; e < dtypeArgs.length; ++e) {
                    dtypeArgs[e] = (byte)d[e].toInt();
                }
            }
        } else if (node instanceof Enter) {
            String frameName = ((Enter)node).getFrameName();
            if (!framesMap.containsKey(frameName)) {
                framesMap.put(frameName, idCounter.incrementAndGet());
            }
            extraBits = new long[]{framesMap.get(frameName).intValue()};
        } else {
            extraBits = new long[]{};
        }
        if (node.opType() == Op.Type.REDUCE_BOOL || node.opType() == Op.Type.REDUCE_SAME || node.opType() == Op.Type.REDUCE_FLOAT || node.opType() == Op.Type.REDUCE_LONG) {
            op = (ReduceOp)((Object)node);
            boolArgs = new boolean[]{op.isKeepDims(), true};
        } else if (node.opType() == Op.Type.INDEXREDUCE) {
            op = (IndexAccumulation)((Object)node);
            boolArgs = new boolean[]{op.isKeepDims(), true};
        }
        ArrayList<Integer> inPaired = new ArrayList<Integer>();
        int[] outputIds = null;
        SDVariable[] outputVertexId = null;
        try {
            outputVertexId = node.outputVariables();
            outputIds = new int[outputVertexId.length];
            for (int i = 0; i < outputIds.length; ++i) {
                outputIds[i] = variables.indexOf(outputVertexId[i]);
            }
        }
        catch (ND4UnresolvedOutputVariables e) {
            outputIds = new int[]{};
            outputVertexId = null;
        }
        catch (Exception e) {
            throw new ND4JIllegalStateException(e);
        }
        for (SDVariable input : inputs = node.args()) {
            int outIdx;
            String varName = input.name();
            if (sameDiff.getVariables().get(varName).getOutputOfOp() != null) {
                DifferentialFunction df = sameDiff.getOps().get(sameDiff.getVariables().get(varName).getOutputOfOp()).getOp();
                outIdx = sameDiff.getOps().get(df.getOwnName()).getOutputsOfOp().indexOf(varName);
            } else {
                outIdx = 0;
            }
            if (!reverseMap.containsKey(varName)) {
                if (varName.contains("NextIteration")) {
                    int fwdNodeId = idCounter.incrementAndGet();
                    forwardMap.put(varName, fwdNodeId);
                    reverseMap.put(varName, fwdNodeId);
                } else {
                    throw new ND4JIllegalStateException("Unknown variable used in input: [" + varName + "]");
                }
            }
            int nodeId = reverseMap.get(varName);
            inPaired.add(IntPair.createIntPair(bufferBuilder, nodeId, outIdx));
        }
        log.trace("Own Name: {}", (Object)node.getOwnName());
        int ownId = id != null ? id.intValue() : idCounter.incrementAndGet();
        for (String s2 : outNames = node.outputVariablesNames()) {
            if (reverseMap.containsKey(s2)) continue;
            reverseMap.put(s2, ownId);
        }
        Op.Type t = node.opType();
        if (t == Op.Type.REDUCE_FLOAT || t == Op.Type.REDUCE_SAME || t == Op.Type.REDUCE_BOOL || t == Op.Type.REDUCE_LONG || t == Op.Type.INDEXREDUCE || t == Op.Type.REDUCE3 || t == Op.Type.VARIANCE || t == Op.Type.SUMMARYSTATS) {
            dims = node.getDimensions();
            if (dims == null) {
                dims = new int[]{};
            }
        } else {
            dims = new int[]{};
        }
        Map<String, Object> fnProps = node.propertiesForFunction();
        int[] flatProperties = FlatBuffersMapper.mapFunctionPropertiesToFlatProperties(bufferBuilder, fnProps);
        int propIdx = FlatNode.createPropertiesVector(bufferBuilder, flatProperties);
        int nodesIn = FlatNode.createInputVector(bufferBuilder, new int[0]);
        int nodesInPaired = FlatNode.createInputPairedVector(bufferBuilder, Ints.toArray(inPaired));
        int nodesOut = FlatNode.createOutputVector(bufferBuilder, outputIds);
        int extraz = FlatNode.createExtraParamsVector(bufferBuilder, extras);
        int integerArgs = FlatNode.createExtraIntegerVector(bufferBuilder, extraBits);
        int bArgs = FlatNode.createExtraBoolsVector(bufferBuilder, boolArgs != null ? boolArgs : new boolean[]{});
        int dArgs = FlatNode.createOutputTypesVector(bufferBuilder, dtypeArgs != null ? dtypeArgs : new byte[]{});
        int dimensions = FlatNode.createDimensionsVector(bufferBuilder, dims);
        int fname = bufferBuilder.createString((CharSequence)node.getOwnName());
        int scopeName = bufferBuilder.createString((CharSequence)"");
        int scalar = 0;
        if (node instanceof ScalarOp && (s = (sOp = (ScalarOp)((Object)node)).scalar()) != null) {
            scalar = s.toFlatArray(bufferBuilder);
        }
        if (node.opType() == null) {
            log.warn("Null-op node: {}", (Object)node);
        }
        int[] outVarNamesStringsOffsets = new int[(outVarNames = node.getSameDiff().getOps().get(node.getOwnName()).getOutputsOfOp()) == null ? 0 : outVarNames.size()];
        for (int i = 0; i < outVarNamesStringsOffsets.length; ++i) {
            outVarNamesStringsOffsets[i] = bufferBuilder.createString((CharSequence)outVarNames.get(i));
        }
        int outVarNamesOffset = FlatNode.createOutputNamesVector(bufferBuilder, outVarNamesStringsOffsets);
        int opNameOffset = bufferBuilder.createString((CharSequence)opName);
        byte[] outTypes = new byte[outVarNames.size()];
        int i = 0;
        for (String s3 : outVarNames) {
            SDVariable v = sameDiff.getVariable(s3);
            outTypes[i++] = FlatBuffersMapper.getDataTypeAsByte(v.dataType());
        }
        int outTypesOffset = FlatNode.createOutputTypesVector(bufferBuilder, outTypes);
        SameDiffOp sdo = sameDiff.getOps().get(node.getOwnName());
        int opCds = 0;
        int[] opCdsArr = FlatBuffersMapper.mapOrNull(sdo.getControlDeps(), bufferBuilder);
        if (opCdsArr != null) {
            opCds = FlatNode.createControlDepsVector(bufferBuilder, opCdsArr);
        }
        int varCds = 0;
        int[] varCdsArr = FlatBuffersMapper.mapOrNull(sdo.getVarControlDeps(), bufferBuilder);
        if (varCdsArr != null) {
            varCds = FlatNode.createVarControlDepsVector(bufferBuilder, varCdsArr);
        }
        int cdsFor = 0;
        int[] cdsForArr = FlatBuffersMapper.mapOrNull(sdo.getControlDepFor(), bufferBuilder);
        if (cdsForArr != null) {
            cdsFor = FlatNode.createControlDepForVector(bufferBuilder, cdsForArr);
        }
        int flatNode = FlatNode.createFlatNode(bufferBuilder, ownId, fname, FlatBuffersMapper.getFlatOpType(node.opType()), hash, propIdx, nodesIn, nodesInPaired, nodesOut, extraz, integerArgs, bArgs, dimensions, -1, 0, scopeName, outVarNamesOffset, opNameOffset, outTypesOffset, scalar, opCds, varCds, cdsFor, dArgs);
        return flatNode;
    }

    /**
     * Writes each string of {@code list} into {@code fbb} and collects the values
     * returned by {@link FlatBufferBuilder#createString(CharSequence)}.
     *
     * @param list strings to write, processed in iteration order; may be {@code null}
     * @param fbb  builder that receives every string
     * @return one entry per list element (in iteration order), or {@code null} when
     *         {@code list} itself is {@code null}; an empty list yields an empty array
     */
    public static int[] mapOrNull(List<String> list, FlatBufferBuilder fbb) {
        int[] offsets = null;
        if (list != null) {
            offsets = new int[list.size()];
            int pos = 0;
            for (String value : list) {
                offsets[pos] = fbb.createString((CharSequence)value);
                pos++;
            }
        }
        return offsets;
    }

    /**
     * Clones {@code df} by delegating to
     * {@link #cloneViaSerialize(SameDiff, DifferentialFunction, Map)} with a freshly
     * built name-to-index map.
     *
     * <p>The map assigns consecutive integers, starting at 0, to the names of the
     * variables in {@code sd.getVariables().values()}, in that collection's iteration order.
     *
     * @param sd SameDiff instance whose variables are enumerated
     * @param df function to clone
     * @return the result of the three-argument overload
     */
    public static DifferentialFunction cloneViaSerialize(SameDiff sd, DifferentialFunction df) {
        HashMap<String, Integer> indexByName = new HashMap<String, Integer>();
        int nextIndex = 0;
        for (Variable variable : sd.getVariables().values()) {
            String variableName = variable.getName();
            indexByName.put(variableName, nextIndex);
            nextIndex++;
        }
        return FlatBuffersMapper.cloneViaSerialize(sd, df, indexByName);
    }

    /**
     * Clones {@code df} via a serialization round trip: the function is written as a
     * flat node into a new {@link FlatBufferBuilder}, the buffer is finished, and the
     * node is read back with {@code fromFlatNode}.
     *
     * <p>{@code asFlatNode} is called with two empty scratch maps, a new counter and
     * {@code 0} as its last argument.
     *
     * <p>NOTE(review): {@code nameToIdxMap} is handed straight to {@code asFlatNode};
     * confirm whether that call adds entries to it, since callers would observe that.
     *
     * @param sd           SameDiff instance the function belongs to
     * @param df           function to clone
     * @param nameToIdxMap variable-name to index mapping forwarded to {@code asFlatNode}
     * @return the function reconstructed from the serialized node
     */
    public static DifferentialFunction cloneViaSerialize(SameDiff sd, DifferentialFunction df, Map<String, Integer> nameToIdxMap) {
        // Throwaway state required by asFlatNode; nothing here is read after the call.
        HashMap<String, Integer> scratchMapA = new HashMap<String, Integer>();
        HashMap<String, Integer> scratchMapB = new HashMap<String, Integer>();
        AtomicInteger scratchCounter = new AtomicInteger();
        FlatBufferBuilder builder = new FlatBufferBuilder(1024);

        int nodeOffset = FlatBuffersMapper.asFlatNode(sd, df, builder, sd.variables(), nameToIdxMap, scratchMapA, scratchMapB, scratchCounter, 0);
        builder.finish(nodeOffset);

        // Read the node back from the finished buffer and rebuild the function from it.
        FlatNode serialized = FlatNode.getRootAsFlatNode(builder.dataBuffer());
        return FlatBuffersMapper.fromFlatNode(serialized);
    }

    /**
     * Converts a {@link VariableType} to its byte code:
     * {@code VARIABLE -> 0}, {@code CONSTANT -> 1}, {@code ARRAY -> 2}, {@code PLACEHOLDER -> 3}.
     *
     * @param variableType type to encode
     * @return the byte code for {@code variableType}
     * @throws RuntimeException if {@code variableType} is not one of the four values above
     */
    public static byte toVarType(VariableType variableType) {
        final byte code;
        switch (variableType) {
            case VARIABLE:
                code = 0;
                break;
            case CONSTANT:
                code = 1;
                break;
            case ARRAY:
                code = 2;
                break;
            case PLACEHOLDER:
                code = 3;
                break;
            default:
                throw new RuntimeException("Unknown variable type: " + variableType);
        }
        return code;
    }

    /**
     * Converts a byte code back to a {@link VariableType}:
     * {@code 0 -> VARIABLE}, {@code 1 -> CONSTANT}, {@code 2 -> ARRAY}, {@code 3 -> PLACEHOLDER}.
     *
     * @param varType byte code to decode
     * @return the matching {@link VariableType}
     * @throws IllegalStateException if {@code varType} is not in the range 0..3
     */
    public static VariableType fromVarType(byte varType) {
        if (varType == 0) {
            return VariableType.VARIABLE;
        }
        if (varType == 1) {
            return VariableType.CONSTANT;
        }
        if (varType == 2) {
            return VariableType.ARRAY;
        }
        if (varType == 3) {
            return VariableType.PLACEHOLDER;
        }
        throw new IllegalStateException("Unknown VarType byte value:" + varType);
    }
}

