/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.modelimport.keras.layers.convolutional;

import java.util.HashMap;
import java.util.Map;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public abstract class KerasConvolution
extends KerasLayer {
    private static final Logger log = LoggerFactory.getLogger(KerasConvolution.class);
    protected int numTrainableParams;
    protected boolean hasBias;

    public KerasConvolution(Integer kerasVersion) throws UnsupportedKerasConfigurationException {
        super(kerasVersion);
    }

    public KerasConvolution(Map<String, Object> layerConfig) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        this(layerConfig, true);
    }

    public KerasConvolution(Map<String, Object> layerConfig, boolean enforceTrainingConfig) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        super(layerConfig, enforceTrainingConfig);
    }

    @Override
    public int getNumParams() {
        return this.numTrainableParams;
    }

    @Override
    public void setWeights(Map<String, INDArray> weights) throws InvalidKerasConfigurationException {
        this.weights = new HashMap();
        if (!weights.containsKey(this.conf.getKERAS_PARAM_NAME_W())) {
            throw new InvalidKerasConfigurationException("Parameter " + this.conf.getKERAS_PARAM_NAME_W() + " does not exist in weights");
        }
        INDArray kerasParamValue = weights.get(this.conf.getKERAS_PARAM_NAME_W());
        INDArray paramValue = this.getConvParameterValues(kerasParamValue);
        this.weights.put("W", paramValue);
        if (this.hasBias) {
            if (weights.containsKey(this.conf.getKERAS_PARAM_NAME_B())) {
                this.weights.put("b", weights.get(this.conf.getKERAS_PARAM_NAME_B()));
            } else {
                throw new InvalidKerasConfigurationException("Parameter " + this.conf.getKERAS_PARAM_NAME_B() + " does not exist in weights");
            }
        }
        KerasLayerUtils.removeDefaultWeights(weights, this.conf);
    }

    public INDArray getConvParameterValues(INDArray kerasParamValue) throws InvalidKerasConfigurationException {
        INDArray paramValue;
        switch (this.getDimOrder()) {
            case TENSORFLOW: {
                if (kerasParamValue.rank() == 5) {
                    paramValue = kerasParamValue.permute(new int[]{4, 3, 0, 1, 2});
                    break;
                }
                paramValue = kerasParamValue.permute(new int[]{3, 2, 0, 1});
                break;
            }
            case THEANO: {
                paramValue = kerasParamValue.dup();
                int i = 0;
                while ((long)i < paramValue.tensorsAlongDimension(new int[]{2, 3})) {
                    INDArray copyFilter = paramValue.tensorAlongDimension((long)i, new int[]{2, 3}).dup();
                    double[] flattenedFilter = copyFilter.ravel().data().asDouble();
                    ArrayUtils.reverse((double[])flattenedFilter);
                    INDArray newFilter = Nd4j.create((double[])flattenedFilter, (long[])copyFilter.shape());
                    INDArray inPlaceFilter = paramValue.tensorAlongDimension((long)i, new int[]{2, 3});
                    inPlaceFilter.muli((Number)0).addi(newFilter);
                    ++i;
                }
                break;
            }
            default: {
                throw new InvalidKerasConfigurationException("Unknown keras backend " + (Object)((Object)this.getDimOrder()));
            }
        }
        return paramValue;
    }

    public int getNumTrainableParams() {
        return this.numTrainableParams;
    }

    public boolean isHasBias() {
        return this.hasBias;
    }

    public void setNumTrainableParams(int numTrainableParams) {
        this.numTrainableParams = numTrainableParams;
    }

    public void setHasBias(boolean hasBias) {
        this.hasBias = hasBias;
    }

    public String toString() {
        return "KerasConvolution(numTrainableParams=" + this.getNumTrainableParams() + ", hasBias=" + this.isHasBias() + ")";
    }

    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof KerasConvolution)) {
            return false;
        }
        KerasConvolution other = (KerasConvolution)o;
        if (!other.canEqual(this)) {
            return false;
        }
        if (this.getNumTrainableParams() != other.getNumTrainableParams()) {
            return false;
        }
        return this.isHasBias() == other.isHasBias();
    }

    protected boolean canEqual(Object other) {
        return other instanceof KerasConvolution;
    }

    public int hashCode() {
        int PRIME = 59;
        int result = 1;
        result = result * 59 + this.getNumTrainableParams();
        result = result * 59 + (this.isHasBias() ? 79 : 97);
        return result;
    }
}

