/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.learning;

import java.util.Objects;

import org.apache.commons.math3.util.FastMath;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.learning.GradientUpdater;
import org.nd4j.linalg.learning.config.AMSGrad;
import org.nd4j.linalg.ops.transforms.Transforms;

/**
 * {@link GradientUpdater} implementing the AMSGrad variant of Adam.
 *
 * <p>Like Adam, it keeps exponentially decaying estimates of the first moment ({@code m}) and the
 * second moment ({@code v}) of the gradient. Unlike Adam, the update denominator is built from
 * {@code vHat}, the element-wise running maximum of {@code v}, so that denominator never decreases
 * from one iteration to the next.
 *
 * <p>The three state arrays are no-copy views into a single state array supplied through
 * {@link #setStateViewArray}, laid out as {@code [m | v | vHat]}.
 */
public class AMSGradUpdater implements GradientUpdater<AMSGrad> {
    private AMSGrad config;
    /** First-moment (mean) estimate of the gradient. */
    private INDArray m;
    /** Second-moment (uncentered variance) estimate of the gradient. */
    private INDArray v;
    /** Element-wise running maximum of {@link #v}. */
    private INDArray vHat;
    /** Order passed to the last {@link #setStateViewArray} call. */
    private char gradientReshapeOrder;

    /**
     * Creates an updater for the given configuration. The updater cannot be applied until
     * {@link #setStateViewArray} has been called.
     *
     * @param config hyper-parameters (betas, epsilon, learning-rate schedule)
     */
    public AMSGradUpdater(AMSGrad config) {
        this.config = config;
    }

    /**
     * Binds this updater to its state array.
     *
     * <p>The array is split into three equal contiguous segments, {@code [m | v | vHat]}. Each
     * segment is reshaped, without copying, to {@code gradientShape}.
     *
     * @param viewArray row vector holding the updater state; its length must be a multiple of 3
     * @param gradientShape shape of the gradient this updater will be applied to
     * @param gradientOrder memory order of the gradient; {@code 'f'} reshapes in Fortran order and
     *     any other value uses the non-Fortran (C) order
     * @param initialize if {@code true}, zero the whole state array before binding to it
     * @throws IllegalArgumentException if {@code viewArray} is not a row vector or its length is not
     *     a multiple of 3
     * @throws IllegalStateException if a segment cannot be reshaped to {@code gradientShape}
     *     without a copy
     */
    @Override
    public void setStateViewArray(INDArray viewArray, long[] gradientShape, char gradientOrder, boolean initialize) {
        if (!viewArray.isRowVector()) {
            throw new IllegalArgumentException("Invalid input: expect row vector input");
        }
        long length = viewArray.length();
        // Validate before touching the array, so a malformed state array is neither zeroed nor
        // silently truncated into mis-sized segments.
        if (length % 3L != 0L) {
            throw new IllegalArgumentException(
                    "Invalid input: state view array length must be a multiple of 3 (m, v, vHat), got " + length);
        }
        if (initialize) {
            viewArray.assign(0);
        }
        long n = length / 3L;
        boolean fortranOrder = gradientOrder == 'f';
        this.m = reshapeSegment(viewArray, 0L, n, gradientShape, fortranOrder);
        this.v = reshapeSegment(viewArray, n, 2L * n, gradientShape, fortranOrder);
        this.vHat = reshapeSegment(viewArray, 2L * n, 3L * n, gradientShape, fortranOrder);
        if (this.m == null || this.v == null || this.vHat == null) {
            throw new IllegalStateException("Could not correctly reshape gradient view arrays");
        }
        this.gradientReshapeOrder = gradientOrder;
    }

    /**
     * Returns a view of elements {@code [from, to)} of row 0 of {@code viewArray}, reshaped to
     * {@code shape} by {@link Shape#newShapeNoCopy}. The result may be {@code null}, which the
     * caller treats as a reshape failure.
     */
    private static INDArray reshapeSegment(INDArray viewArray, long from, long to, long[] shape, boolean fortranOrder) {
        INDArray segment = viewArray.get(NDArrayIndex.point(0L), NDArrayIndex.interval(from, to));
        return Shape.newShapeNoCopy(segment, shape, fortranOrder);
    }

    /**
     * Performs one AMSGrad step. It updates the state arrays and overwrites {@code gradient}, in
     * place, with the update to apply:
     *
     * <pre>
     *   m_t     = beta1 * m_{t-1} + (1 - beta1) * g_t
     *   v_t     = beta2 * v_{t-1} + (1 - beta2) * g_t^2
     *   vHat_t  = max(vHat_{t-1}, v_t)
     *   alpha_t = lr * sqrt(1 - beta2^t) / (1 - beta1^t),   t = iteration + 1
     *   g_t    &lt;- alpha_t * m_t / (sqrt(vHat_t) + epsilon)
     * </pre>
     *
     * @param gradient raw gradient on entry; replaced by the update on return
     * @param iteration zero-based iteration counter (used for bias correction and the LR schedule)
     * @param epoch epoch counter (passed to the LR schedule)
     * @throws IllegalStateException if {@link #setStateViewArray} has not been called
     */
    @Override
    public void applyUpdater(INDArray gradient, int iteration, int epoch) {
        if (this.m == null || this.v == null || this.vHat == null) {
            throw new IllegalStateException("Updater has not been initialized with view state");
        }
        double beta1 = this.config.getBeta1();
        double beta2 = this.config.getBeta2();
        double learningRate = this.config.getLearningRate(iteration, epoch);
        double epsilon = this.config.getEpsilon();

        // Both moment updates must read the raw gradient, before it is overwritten below.
        INDArray oneMinusBeta1Grad = gradient.mul(1.0 - beta1);
        this.m.muli(beta1).addi(oneMinusBeta1Grad);
        INDArray oneMinusBeta2GradSquared = gradient.mul(gradient).muli(1.0 - beta2);
        this.v.muli(beta2).addi(oneMinusBeta2GradSquared);

        double beta1t = FastMath.pow(beta1, iteration + 1);
        double beta2t = FastMath.pow(beta2, iteration + 1);

        // dup == false: the element-wise max is written back into vHat (the return value is unused).
        Transforms.max(this.vHat, this.v, false);

        double alphat = learningRate * FastMath.sqrt(1.0 - beta2t) / (1.0 - beta1t);
        // NOTE(review): a learning rate of exactly 0 (or beta2 == 1) also gives alphat == 0 and
        // therefore falls back to epsilon, so parameters still move. Kept for compatibility; confirm
        // this is intended.
        if (Double.isNaN(alphat) || alphat == 0.0) {
            alphat = epsilon;
        }

        // gradient <- sqrt(vHat) + epsilon
        Nd4j.getExecutioner().exec(new Sqrt(this.vHat, gradient)).addi(epsilon);
        // gradient <- alphat * m / gradient
        gradient.rdivi(this.m).muli(alphat);
    }

    @Override
    public AMSGrad getConfig() {
        return this.config;
    }

    public INDArray getM() {
        return this.m;
    }

    public INDArray getV() {
        return this.v;
    }

    public INDArray getVHat() {
        return this.vHat;
    }

    public char getGradientReshapeOrder() {
        return this.gradientReshapeOrder;
    }

    public void setConfig(AMSGrad config) {
        this.config = config;
    }

    public void setM(INDArray m) {
        this.m = m;
    }

    public void setV(INDArray v) {
        this.v = v;
    }

    public void setVHat(INDArray vHat) {
        this.vHat = vHat;
    }

    public void setGradientReshapeOrder(char gradientReshapeOrder) {
        this.gradientReshapeOrder = gradientReshapeOrder;
    }

    /** Value equality over config, the three state arrays and the reshape order. */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof AMSGradUpdater)) {
            return false;
        }
        AMSGradUpdater other = (AMSGradUpdater) o;
        return other.canEqual(this)
                && Objects.equals(this.getConfig(), other.getConfig())
                && Objects.equals(this.getM(), other.getM())
                && Objects.equals(this.getV(), other.getV())
                && Objects.equals(this.getVHat(), other.getVHat())
                && this.getGradientReshapeOrder() == other.getGradientReshapeOrder();
    }

    /** Symmetry hook for {@link #equals}; subclasses that add state should override it. */
    protected boolean canEqual(Object other) {
        return other instanceof AMSGradUpdater;
    }

    /** Hash over the same fields as {@link #equals} (null fields contribute 43). */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        result = result * prime + hashOrDefault(this.getConfig());
        result = result * prime + hashOrDefault(this.getM());
        result = result * prime + hashOrDefault(this.getV());
        result = result * prime + hashOrDefault(this.getVHat());
        result = result * prime + this.getGradientReshapeOrder();
        return result;
    }

    /** Returns {@code o.hashCode()}, or 43 for {@code null}. */
    private static int hashOrDefault(Object o) {
        return o == null ? 43 : o.hashCode();
    }

    @Override
    public String toString() {
        return "AMSGradUpdater(config=" + this.getConfig() + ", m=" + this.getM() + ", v=" + this.getV() + ", vHat=" + this.getVHat() + ", gradientReshapeOrder=" + this.getGradientReshapeOrder() + ")";
    }
}

