/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.samediff.optimize.optimizations;

import java.util.List;
import org.nd4j.autodiff.samediff.ArrayHolder;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.optimize.OptimizationHelper;
import org.nd4j.autodiff.samediff.optimize.Optimizer;
import org.nd4j.autodiff.samediff.optimize.optimizations.BaseOptimizerSet;
import org.nd4j.autodiff.samediff.optimize.optimizations.OptimizationUtils;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2D;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Graph optimizations that rewrite SameDiff ops into layouts intended for the cuDNN helpers.
 *
 * <p>{@link #isCudaBackend} is computed once, during class initialization, from the
 * {@code "backend"} property reported by the current executioner's environment information.
 */
public class CuDNNFunctionOptimizations
extends BaseOptimizerSet {
    /**
     * True when the executioner's environment information reports a {@code "backend"} of
     * {@code "CUDA"} (case-insensitive). If the lookup yields {@code null}, this is {@code false}.
     */
    protected static final boolean isCudaBackend;

    static {
        String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
        // String.equalsIgnoreCase(null) returns false, so a missing value is handled without a null check.
        isCudaBackend = "CUDA".equalsIgnoreCase(backend);
    }

    /**
     * Rewrites a {@link Conv2D} op so that it runs on NHWC activations with permuted weights.
     *
     * <p>Activations: {@code in -> Conv2D(NCHW) -> consumers} becomes
     * {@code in -> permute(0,2,3,1) -> Conv2D(NHWC) -> permute(0,3,1,2) -> consumers}. Ops
     * downstream of the convolution therefore still receive data permuted back by (0,3,1,2).
     *
     * <p>Weights: a {@code permute(3,0,1,2)} is inserted after the weight variable. The variable
     * name suffix describes this as YXIO to OYXI.
     */
    public static class CudnnConv2dNCHWtoNHWCConversion
    implements Optimizer {
        /** Name suffix for the permuted (NHWC) convolution input. */
        private static final String INPUT_SUFFIX = "_cudnn_nchw_to_nhwc";
        /** Name suffix for the convolution output permuted back to NCHW. */
        private static final String OUTPUT_SUFFIX = "_cudnn_nhwc_to_nchw";
        /** Name suffix for the permuted weights; also used to detect that weights were already permuted. */
        private static final String WEIGHTS_SUFFIX = "_cudnn_yxio_to_oyxi";

        /**
         * Applies the NCHW to NHWC conversion to {@code op} if it is a {@link Conv2D} that has not been
         * converted yet.
         *
         * @return {@code true} if the graph was modified; {@code false} if {@code op} is not a
         *     {@code Conv2D} or both its activations and weights are already in the target layout
         */
        @Override
        public boolean checkAndApply(SameDiff sd, OptimizationHelper helper, SameDiffOp op, ArrayHolder constantArrays, ArrayHolder variablesArrays) {
            if (!(op.getOp() instanceof Conv2D)) {
                return false;
            }
            Conv2D c2d = (Conv2D) op.getOp();
            List<String> inputs = op.getInputsToOp();
            String wArgName = inputs.get(1);

            boolean activationsCorrect = c2d.getConfig().isNHWC();
            // No weight-layout flag is read from the Conv2D config here, so the name this optimizer gives
            // the permuted weights serves as the marker. A hard-coded `false` would make this early exit
            // unreachable and stack an extra weight permute on every invocation.
            boolean weightsCorrect = wArgName.endsWith(WEIGHTS_SUFFIX);
            if (activationsCorrect && weightsCorrect) {
                return false;
            }

            if (!activationsCorrect) {
                convertActivationsToNhwc(sd, op, c2d, inputs.get(0));
            }
            if (!weightsCorrect) {
                permuteWeights(sd, wArgName);
            }
            return true;
        }

        /** Wraps the convolution in NCHW->NHWC / NHWC->NCHW permutes and switches the op to NHWC. */
        private static void convertActivationsToNhwc(SameDiff sd, SameDiffOp op, Conv2D c2d, String inArgName) {
            SDVariable in = sd.getVariable(inArgName);
            String nhwcName = in.name() + INPUT_SUFFIX;
            // Redirect readers of `in` first, then create the permute. The permute op does not exist yet
            // during the redirect, so it keeps reading `in`.
            OptimizationUtils.replaceOpInputsWith(sd, in.name(), nhwcName);
            in.permute(0, 2, 3, 1).rename(nhwcName);

            SDVariable outNhwc = sd.getVariable(op.getOutputsOfOp().get(0));
            String nchwName = outNhwc.name() + OUTPUT_SUFFIX;
            // Same ordering for the output. If the back-permute were created before the redirect,
            // replaceOpInputsWith (which, going by its use here, rewires every op reading outNhwc) would
            // also rewire the permute to read its own output.
            OptimizationUtils.replaceOpInputsWith(sd, outNhwc.name(), nchwName);
            outNhwc.permute(0, 3, 1, 2).rename(nchwName);

            c2d.getConfig().isNHWC(true);
        }

        /**
         * Redirects readers of the weight variable to a new {@code permute(3, 0, 1, 2)} of it.
         * The permuted variable is named with {@link #WEIGHTS_SUFFIX}.
         */
        private static void permuteWeights(SameDiff sd, String wArgName) {
            SDVariable w = sd.getVariable(wArgName);
            String permutedName = w.name() + WEIGHTS_SUFFIX;
            // Redirect before creating the permute, for the same reason as in convertActivationsToNhwc.
            OptimizationUtils.replaceOpInputsWith(sd, w.name(), permutedName);
            w.permute(3, 0, 1, 2).rename(permutedName);
            // TODO(review): no weight-layout setting on the Conv2D config is updated here; confirm how the
            // op is expected to learn that its weights are now in the permuted layout.
        }
    }
}

