/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.activations.impl;

import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Thresholded rectified linear unit activation, parameterized by a threshold {@code theta}.
 *
 * <p>The forward and backward passes delegate to the native custom ops {@code "thresholdedrelu"}
 * and {@code "thresholdedrelu_bp"}, passing {@code theta} as the single floating-point argument.
 * Presumably the forward function is {@code f(x) = x} for {@code x > theta} and {@code 0}
 * otherwise; the exact semantics (including the {@code x == theta} case) are defined by the
 * native op — confirm against its implementation.
 *
 * <p>Both passes run in place: the array passed as {@code in} is used as the op's output buffer
 * and is therefore overwritten.
 */
public class ActivationThresholdedReLU
extends BaseActivationFunction {

    /** Threshold used by the no-argument constructor. */
    public static final double DEFAULT_THETA = 1.0;

    /** Threshold forwarded to the native ops. */
    private double theta;

    /** Creates an activation whose threshold is {@link #DEFAULT_THETA}. */
    public ActivationThresholdedReLU() {
        this(DEFAULT_THETA);
    }

    /**
     * Creates an activation with the given threshold.
     *
     * @param theta threshold forwarded to the native forward/backward ops
     */
    public ActivationThresholdedReLU(double theta) {
        this.theta = theta;
    }

    /**
     * Applies the activation in place.
     *
     * @param in       input array; overwritten with the activation output
     * @param training not used by this activation
     * @return {@code in}, now holding the activated values
     */
    @Override
    public INDArray getActivation(INDArray in, boolean training) {
        DynamicCustomOp threshRelu = DynamicCustomOp.builder("thresholdedrelu")
                .addOutputs(in)
                .addInputs(in)
                .addFloatingPointArguments(this.theta)
                .build();
        Nd4j.getExecutioner().execAndReturn(threshRelu);
        return in;
    }

    /**
     * Back-propagates through the activation in place.
     *
     * @param in      the activation's input (assumed: the same values given to the forward pass);
     *                overwritten with the gradient with respect to the input
     * @param epsilon upstream gradient (presumably dL/d(output)); must have the same shape as
     *                {@code in}
     * @return a pair whose first element is {@code in} (now holding the input gradient) and whose
     *         second element is {@code null}, since no parameter gradient is produced
     */
    @Override
    public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
        this.assertShape(in, epsilon);
        DynamicCustomOp threshReluBp = DynamicCustomOp.builder("thresholdedrelu_bp")
                .addInputs(in, epsilon)
                .addOutputs(in)
                .addFloatingPointArguments(this.theta)
                .build();
        Nd4j.getExecutioner().execAndReturn(threshReluBp);
        // Diamond lets the compiler infer Pair<INDArray, INDArray> from the declared return type;
        // explicit <INDArray, Object> arguments would be an incompatible (non-compiling) type.
        return new Pair<>(in, null);
    }

    @Override
    public String toString() {
        return "thresholdedrelu(theta=" + this.theta + ")";
    }

    /**
     * Two instances are equal when their thresholds compare equal under {@link Double#compare}
     * (so {@code NaN} equals {@code NaN}, and {@code 0.0} differs from {@code -0.0}).
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof ActivationThresholdedReLU)) {
            return false;
        }
        ActivationThresholdedReLU other = (ActivationThresholdedReLU) o;
        // canEqual keeps equals symmetric if a subclass narrows what it considers equal.
        return other.canEqual(this) && Double.compare(this.getTheta(), other.getTheta()) == 0;
    }

    /**
     * Symmetry hook for {@link #equals}; subclasses that add state should override this.
     *
     * @param other candidate object
     * @return whether {@code other} is an {@code ActivationThresholdedReLU}
     */
    protected boolean canEqual(Object other) {
        return other instanceof ActivationThresholdedReLU;
    }

    /** Hash consistent with {@link #equals}; depends only on the threshold. */
    @Override
    public int hashCode() {
        // Constant offset kept so hash values stay identical to the previous implementation.
        return 59 + Double.hashCode(this.getTheta());
    }

    /**
     * @return the threshold forwarded to the native ops
     */
    public double getTheta() {
        return this.theta;
    }
}

