/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.listeners.debugging;

import java.util.Arrays;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.BaseListener;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.ScalarOp;

/**
 * Debugging listener that prints information about every SameDiff op just before it executes.
 *
 * <p>Output goes to {@code System.out}. The level of detail is controlled by {@link PrintMode}:
 * <ul>
 *   <li>{@link PrintMode#OPS_ONLY} - one line per op with its class name;</li>
 *   <li>{@link PrintMode#SHAPES_ONLY} - additionally the iArgs/bArgs/tArgs and the shape info of
 *       each input and output array;</li>
 *   <li>{@link PrintMode#REPRODUCE} - Java source code that recreates the op with the same
 *       arguments and array contents and runs it with {@code Nd4j.exec(op)}.</li>
 * </ul>
 *
 * <p>The iteration/step counters are plain mutable fields with no synchronization.
 */
public class ExecDebuggingListener extends BaseListener {
    private final PrintMode printMode;
    private final int maxIterations;
    private final boolean logIter;
    /** Number of distinct iterations seen so far; compared against {@link #maxIterations}. */
    private long printIterations = 0L;
    /** Iteration number of the most recently seen op; -1 before the first op. */
    private int lastIter = -1;
    /** Index of the next op within the current iteration. */
    private int stepThisIter = 0;

    /**
     * @param printMode     how much detail to print for each op
     * @param maxIterations maximum number of distinct iterations to print output for;
     *                      a value {@code <= 0} means no limit
     * @param logIter       if true, prefix each op line with the iteration and epoch numbers
     */
    public ExecDebuggingListener(PrintMode printMode, int maxIterations, boolean logIter) {
        this.printMode = printMode;
        this.maxIterations = maxIterations;
        this.logIter = logIter;
    }

    /** This listener is active for every {@link Operation}. */
    @Override
    public boolean isActive(Operation operation) {
        return true;
    }

    /**
     * Prints information about {@code op} (according to {@link #printMode}) before it executes,
     * unless the {@link #maxIterations} limit has been exceeded.
     */
    @Override
    public void preOpExecution(SameDiff sd, At at, SameDiffOp op) {
        if (this.lastIter != at.iteration()) {
            // First op of a new iteration: restart the step counter and count the iteration
            this.lastIter = at.iteration();
            this.stepThisIter = 0;
            ++this.printIterations;
        }
        if (this.maxIterations > 0 && this.printIterations > (long) this.maxIterations) {
            return;
        }

        DifferentialFunction df = op.getOp();
        CustomOp co = df instanceof CustomOp ? (CustomOp) df : null;
        Op lOp = df instanceof Op ? (Op) df : null;

        StringBuilder sb = new StringBuilder();
        if (this.logIter) {
            sb.append("(iter=").append(at.iteration()).append(",epoch=").append(at.epoch()).append(",");
        }
        sb.append("op=").append(this.stepThisIter++).append(this.logIter ? ") " : " - ");
        sb.append(df.getClass().getName());

        if (this.printMode == PrintMode.OPS_ONLY) {
            sb.append("\n");
        } else if (this.printMode == PrintMode.SHAPES_ONLY) {
            appendShapes(sb, co, lOp);
            sb.append("\n");
        } else if (this.printMode == PrintMode.REPRODUCE) {
            sb.append("\n");
            appendReproduceCode(sb, df, co, lOp);
        }
        System.out.print(sb.toString());
    }

    /** Appends the op arguments and the shape info of each array it uses (SHAPES_ONLY mode). */
    private static void appendShapes(StringBuilder sb, CustomOp co, Op lOp) {
        if (co != null) {
            if (co.iArgs() != null && co.iArgs().length > 0) {
                sb.append("\n\tiArgs=").append(Arrays.toString(co.iArgs()));
            }
            if (co.bArgs() != null && co.bArgs().length > 0) {
                sb.append("\n\tbArgs=").append(Arrays.toString(co.bArgs()));
            }
            if (co.tArgs() != null && co.tArgs().length > 0) {
                sb.append("\n\ttArgs=").append(Arrays.toString(co.tArgs()));
            }
            INDArray[] inputs = co.inputArguments();
            if (inputs != null) {
                for (int i = 0; i < inputs.length; i++) {
                    sb.append("\n\tInput[").append(i).append("]=").append(shapeInfo(inputs[i]));
                }
            }
            INDArray[] outputs = co.outputArguments();
            if (outputs != null) {
                for (int i = 0; i < outputs.length; i++) {
                    sb.append("\n\tOutputs[").append(i).append("]=").append(shapeInfo(outputs[i]));
                }
            }
        } else if (lOp != null) {
            if (lOp.x() != null) {
                sb.append("\n\tx: ").append(lOp.x().shapeInfoToString());
            }
            if (lOp.y() != null) {
                sb.append("\n\ty: ").append(lOp.y().shapeInfoToString());
            }
            if (lOp.z() != null) {
                sb.append("\n\tz: ").append(lOp.z().shapeInfoToString());
            }
            if (lOp instanceof ScalarOp) {
                INDArray scalar = ((ScalarOp) lOp).scalar();
                if (scalar != null) {
                    sb.append("\n\tscalar: ").append(scalar.shapeInfoToString());
                }
            }
        }
        // Ops that are neither CustomOp nor Op: nothing more to print (previously an NPE)
    }

    /** Appends Java source code that recreates and executes the op (REPRODUCE mode). */
    private static void appendReproduceCode(StringBuilder sb, DifferentialFunction df, CustomOp co, Op lOp) {
        if (co != null) {
            sb.append("DynamicCustomOp op = new ").append(co.getClass().getName()).append("();\n");
            if (co.iArgs() != null && co.iArgs().length > 0) {
                sb.append("op.addIArgument(").append(stripBrackets(Arrays.toString(co.iArgs()))).append(");\n");
            }
            if (co.bArgs() != null && co.bArgs().length > 0) {
                sb.append("op.addBArgument(").append(stripBrackets(Arrays.toString(co.bArgs()))).append(");\n");
            }
            if (co.tArgs() != null && co.tArgs().length > 0) {
                sb.append("op.addTArgument(").append(doubleLiterals(co.tArgs())).append(");\n");
            }
            INDArray[] inputs = co.inputArguments();
            if (inputs != null) {
                sb.append("INDArray[] inputs = new INDArray[").append(inputs.length).append("];\n");
                for (int i = 0; i < inputs.length; i++) {
                    sb.append("inputs[").append(i).append("] = ").append(createString(inputs[i])).append(";\n");
                }
                sb.append("op.addInputArgument(inputs);\n");
            }
            INDArray[] outputs = co.outputArguments();
            if (outputs != null) {
                sb.append("INDArray[] outputs = new INDArray[").append(outputs.length).append("];\n");
                for (int i = 0; i < outputs.length; i++) {
                    sb.append("outputs[").append(i).append("] = ").append(createString(outputs[i])).append(";\n");
                }
                sb.append("op.addOutputArgument(outputs);\n");
            }
        } else if (lOp != null) {
            // Use the real op class; the SameDiffOp wrapper is not what should be instantiated
            sb.append("Op op = new ").append(df.getClass().getName()).append("();\n");
            if (lOp.x() != null) {
                sb.append("op.setX(").append(createString(lOp.x())).append(");\n");
            }
            if (lOp.y() != null) {
                sb.append("op.setY(").append(createString(lOp.y())).append(");\n");
            }
            if (lOp.z() != null) {
                sb.append("op.setZ(").append(createString(lOp.z())).append(");\n");
            }
            if (lOp instanceof ScalarOp) {
                INDArray scalar = ((ScalarOp) lOp).scalar();
                if (scalar != null) {
                    sb.append("((ScalarOp)op).setScalar(").append(createString(scalar)).append(");\n");
                }
            }
        } else {
            // Neither a CustomOp nor an Op: no generic way to generate construction code
            sb.append("// Cannot reproduce op of type ").append(df.getClass().getName()).append("\n");
            return;
        }
        sb.append("Nd4j.exec(op);\n");
    }

    /** Shape info of {@code arr}, or {@code "null"} if the array is not (yet) set. */
    private static String shapeInfo(INDArray arr) {
        return arr == null ? "null" : arr.shapeInfoToString();
    }

    /**
     * Builds a Java expression that recreates {@code arr} (same values, shape and data type).
     * The result is an expression without a trailing {@code ';'}, because callers embed it
     * inside larger statements.
     *
     * @return the expression, or {@code "null"} when {@code arr} is null
     */
    private static String createString(INDArray arr) {
        if (arr == null) {
            return "null";
        }
        if (arr.isEmpty()) {
            return "Nd4j.empty(DataType." + arr.dataType() + ")";
        }
        StringBuilder sb = new StringBuilder("Nd4j.createFromArray(");
        DataType dt = arr.dataType();
        // dup('c') yields a compact c-ordered copy, so the flat buffer matches the c-order reshape below
        switch (dt) {
            case DOUBLE:
                sb.append(doubleLiterals(arr.dup('c').data().asDouble()));
                break;
            case FLOAT:
            case HALF:
            case BFLOAT16:
                sb.append(floatLiterals(arr.dup('c').data().asFloat()));
                break;
            case LONG:
            case UINT32:
            case UINT64:
                sb.append(longLiterals(arr.dup('c').data().asLong()));
                break;
            case INT:
            case SHORT:
            case UBYTE:
            case BYTE:
            case UINT16:
            case BOOL:
                sb.append(stripBrackets(Arrays.toString(arr.dup('c').data().asInt())));
                break;
            default:
                // e.g. UTF8: no literal values are generated
                break;
        }
        sb.append(").reshape(").append(stripBrackets(Arrays.toString(arr.shape()))).append(")");
        // These types are created via a wider Java primitive above, so cast back to the original type
        if (dt == DataType.HALF || dt == DataType.BFLOAT16 || dt == DataType.UINT32 || dt == DataType.UINT64
                || dt == DataType.SHORT || dt == DataType.UBYTE || dt == DataType.BYTE || dt == DataType.UINT16
                || dt == DataType.BOOL) {
            sb.append(".cast(DataType.").append(arr.dataType()).append(")");
        }
        return sb.toString();
    }

    /** Removes '[' and ']' from an {@link Arrays#toString} result, e.g. {@code "[1, 2]" -> "1, 2"}. */
    private static String stripBrackets(String s) {
        return s.replace("[", "").replace("]", "");
    }

    /**
     * Formats values as comma-separated Java {@code double} expressions. NaN and infinities, which
     * have no literal form, become {@code Double.NaN} / {@code Double.POSITIVE_INFINITY} /
     * {@code Double.NEGATIVE_INFINITY}.
     */
    private static String doubleLiterals(double[] values) {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < values.length; i++) {
            if (i > 0) {
                sb.append(", ");
            }
            double v = values[i];
            if (Double.isNaN(v)) {
                sb.append("Double.NaN");
            } else if (Double.isInfinite(v)) {
                sb.append(v > 0 ? "Double.POSITIVE_INFINITY" : "Double.NEGATIVE_INFINITY");
            } else {
                sb.append(v);
            }
        }
        return sb.toString();
    }

    /**
     * Formats values as comma-separated Java {@code float} expressions ({@code 1.0f}, ...).
     * NaN and infinities become the corresponding {@code Float} constants.
     */
    private static String floatLiterals(float[] values) {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < values.length; i++) {
            if (i > 0) {
                sb.append(", ");
            }
            float v = values[i];
            if (Float.isNaN(v)) {
                sb.append("Float.NaN");
            } else if (Float.isInfinite(v)) {
                sb.append(v > 0 ? "Float.POSITIVE_INFINITY" : "Float.NEGATIVE_INFINITY");
            } else {
                sb.append(v).append('f');
            }
        }
        return sb.toString();
    }

    /** Formats values as comma-separated Java {@code long} literals ({@code 1L, 2L, ...}). */
    private static String longLiterals(long[] values) {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < values.length; i++) {
            if (i > 0) {
                sb.append(", ");
            }
            sb.append(values[i]).append('L');
        }
        return sb.toString();
    }

    /** Amount of detail printed for each executed op. */
    public enum PrintMode {
        /** Only the op class name. */
        OPS_ONLY,
        /** Op class name, op arguments, and the shape info of the input/output arrays. */
        SHAPES_ONLY,
        /** Java source code that recreates the op with its arguments and arrays and executes it. */
        REPRODUCE
    }
}

