/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.samediff.transform;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import lombok.NonNull;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.internal.Variable;
import org.nd4j.autodiff.samediff.transform.SubGraph;
import org.nd4j.autodiff.samediff.transform.SubGraphPredicate;
import org.nd4j.autodiff.samediff.transform.SubGraphProcessor;
import org.nd4j.common.base.Preconditions;

/**
 * Utilities for locating subgraphs of a {@link SameDiff} graph that satisfy a {@link SubGraphPredicate}
 * and replacing them with the ops/variables produced by a {@link SubGraphProcessor}.
 */
public class GraphTransformUtil {
    private GraphTransformUtil() {
    }

    /**
     * Returns a copy of {@code sd} in which every subgraph matched by {@code p} has been replaced by the
     * outputs returned from {@code processor}. The instance passed in is not modified; all work is done on
     * {@code sd.dup()}.
     *
     * <p>For each matched subgraph, the i-th output of the old subgraph is replaced by the i-th variable
     * returned by the processor. Ops and control dependencies outside the subgraph that referred to the old
     * output are rewired to the new one. The subgraph's ops and their output variables are then removed.
     *
     * @param sd        graph to transform (not modified)
     * @param p         predicate selecting the subgraphs to replace
     * @param processor creates the replacement for each matched subgraph and returns its outputs
     * @return the transformed copy of {@code sd}
     * @throws NullPointerException  if any argument is null
     * @throws IllegalStateException if the processor returns a different number of outputs than the subgraph
     *                               has, or returns one of the subgraph's own outputs unchanged
     */
    public static SameDiff replaceSubgraphsMatching(@NonNull SameDiff sd, @NonNull SubGraphPredicate p, @NonNull SubGraphProcessor processor) {
        if (sd == null) {
            throw new NullPointerException("sd is marked non-null but is null");
        }
        if (p == null) {
            throw new NullPointerException("p is marked non-null but is null");
        }
        if (processor == null) {
            throw new NullPointerException("processor is marked non-null but is null");
        }
        sd = sd.dup();
        List<SubGraph> subgraphs = GraphTransformUtil.getSubgraphsMatching(sd, p);
        for (SubGraph sg : subgraphs) {
            List<SDVariable> newOutputs = processor.processSubgraph(sd, sg);
            List<SDVariable> oldOutputs = sg.outputs();
            Preconditions.checkState(oldOutputs.size() == newOutputs.size(), "Error applying subgraph processor: different number of outputs for subgraph (%s) vs. returned by preprocessor (%s)", oldOutputs.size(), newOutputs.size());
            List<DifferentialFunction> allSubGraphFns = sg.allFunctionsInSubgraph();
            for (int i = 0; i < oldOutputs.size(); ++i) {
                String oldOutVarName = oldOutputs.get(i).name();
                String newOutVarName = newOutputs.get(i).name();
                Preconditions.checkState(!oldOutVarName.equals(newOutVarName), "Reusing old variables not yet implemented");
                transferExternalConsumers(sd, allSubGraphFns, oldOutVarName, newOutVarName);
                renameVariableReferences(sd, oldOutVarName, newOutVarName);
            }
            detachSubgraphInputs(sd, sg.inputs(), allSubGraphFns);
            removeSubgraph(sd, allSubGraphFns);
        }
        return sd;
    }

    /**
     * Records every op outside the subgraph that consumed {@code oldName} as a consumer of {@code newName}.
     * Existing consumers of {@code newName} are kept. Example: when the processor returns one of the
     * subgraph's inputs, that variable's other consumers must survive.
     */
    private static void transferExternalConsumers(SameDiff sd, List<DifferentialFunction> subgraphFns, String oldName, String newName) {
        List<String> oldConsumers = sd.getVariables().get(oldName).getInputsForOp();
        if (oldConsumers == null) {
            return;
        }
        Variable newVar = sd.getVariables().get(newName);
        List<String> consumers = new ArrayList<>();
        if (newVar.getInputsForOp() != null) {
            consumers.addAll(newVar.getInputsForOp());
        }
        for (String opName : oldConsumers) {
            // Ops inside the subgraph are about to be deleted, so they must not be carried over.
            if (!subgraphFns.contains(sd.getOpById(opName))) {
                consumers.add(opName);
            }
        }
        newVar.setInputsForOp(consumers);
    }

    /**
     * Replaces {@code oldName} with {@code newName} in the following lists:
     * the control-dependency lists of every variable, and the input and control-dependency lists of every op.
     */
    private static void renameVariableReferences(SameDiff sd, String oldName, String newName) {
        for (Variable variable : sd.getVariables().values()) {
            replaceAllOccurrences(variable.getControlDepsForVar(), oldName, newName);
            replaceAllOccurrences(variable.getControlDeps(), oldName, newName);
        }
        for (SameDiffOp op : sd.getOps().values()) {
            replaceAllOccurrences(op.getInputsToOp(), oldName, newName);
            replaceAllOccurrences(op.getControlDeps(), oldName, newName);
        }
    }

    /**
     * Sets every element of {@code names} equal to {@code oldName} to {@code newName}, in place,
     * including the element at index 0. Does nothing when {@code names} is null. {@code set} is only called on
     * matching elements, so lists with no match are never mutated.
     */
    private static void replaceAllOccurrences(List<String> names, String oldName, String newName) {
        if (names == null) {
            return;
        }
        for (int i = 0; i < names.size(); i++) {
            if (oldName.equals(names.get(i))) {
                names.set(i, newName);
            }
        }
    }

    /**
     * For each subgraph input, drops the subgraph's ops from that variable's consumer list.
     * One entry is removed per occurrence in the original list.
     */
    private static void detachSubgraphInputs(SameDiff sd, List<SDVariable> inputs, List<DifferentialFunction> subgraphFns) {
        for (SDVariable v : inputs) {
            Variable variable = sd.getVariables().get(v.name());
            List<String> consumers = variable.getInputsForOp();
            if (consumers == null) {
                continue;
            }
            List<String> remaining = new ArrayList<>(consumers);
            for (String opName : consumers) {
                if (subgraphFns.contains(sd.getOpById(opName))) {
                    remaining.remove(opName);
                }
            }
            variable.setInputsForOp(remaining);
        }
    }

    /**
     * Deletes the subgraph's ops and their output variables from {@code sd}. Each op is removed
     * before its outputs are queried, matching the original statement order.
     */
    private static void removeSubgraph(SameDiff sd, List<DifferentialFunction> subgraphFns) {
        Map<String, SameDiffOp> ops = sd.getOps();
        Map<String, Variable> vars = sd.getVariables();
        for (DifferentialFunction df : subgraphFns) {
            ops.remove(df.getOwnName());
            SDVariable[] outputs = df.outputVariables();
            if (outputs == null) {
                continue;
            }
            for (SDVariable out : outputs) {
                vars.remove(out.name());
            }
        }
    }

    /**
     * Collects one {@link SubGraph} for every op in {@code sd} that satisfies {@code p}. The subgraph is
     * built by {@link SubGraphPredicate#getSubGraph} rooted at that op.
     *
     * @param sd graph to search
     * @param p  predicate evaluated against each op in {@code sd.ops()}
     * @return the matched subgraphs, in the iteration order of {@code sd.ops()}
     */
    public static List<SubGraph> getSubgraphsMatching(SameDiff sd, SubGraphPredicate p) {
        List<SubGraph> out = new ArrayList<>();
        for (DifferentialFunction df : sd.ops()) {
            if (p.matches(sd, df)) {
                out.add(p.getSubGraph(sd, df));
            }
        }
        return out;
    }
}

