/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.learning;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.learning.GradientUpdater;
import org.nd4j.linalg.learning.config.RmsProp;
import org.nd4j.linalg.ops.transforms.Transforms;

/**
 * {@link GradientUpdater} that applies the RMSProp rule to gradients in place.
 *
 * <p>The updater keeps one state array {@code s} (exposed as {@link #getLastGradient()}) with the
 * same shape as the gradient. For each call to {@link #applyUpdater(INDArray, int)}:
 *
 * <pre>
 *   s = rmsDecay * s + (1 - rmsDecay) * g^2
 *   g = learningRate * g / (sqrt(s) + epsilon)
 * </pre>
 *
 * <p>Note: despite its name, {@code lastGradient} holds the running average of squared gradients
 * {@code s}, not the most recent gradient. The state is backed by an externally supplied view array
 * passed to {@link #setStateViewArray(INDArray, int[], char, boolean)}.
 */
public class RmsPropUpdater implements GradientUpdater<RmsProp> {
    private final RmsProp config;
    /** Running average of squared gradients ({@code s}), reshaped to the gradient's shape. */
    private INDArray lastGradient;
    /** Ordering ('f' or other) the state view was reshaped with; reused when duplicating it. */
    private char gradientReshapeOrder;

    /**
     * Creates an updater for the given hyper-parameters.
     *
     * @param config learning rate, decay and epsilon source; stored as-is
     */
    public RmsPropUpdater(RmsProp config) {
        this.config = config;
    }

    /**
     * Binds the updater state to a row-vector view and reshapes it to the gradient's shape.
     *
     * @param viewArray     backing row vector for the state
     * @param gradientShape shape of the gradients this updater will receive
     * @param gradientOrder ordering of those gradients; {@code 'f'} requests Fortran-order reshape
     * @param initialize    if {@code true}, fill the state with {@code config.getEpsilon()} first
     * @throws IllegalArgumentException if {@code viewArray} is not a row vector
     * @throws IllegalStateException    if the view cannot be reshaped to {@code gradientShape}
     */
    @Override
    public void setStateViewArray(INDArray viewArray, int[] gradientShape, char gradientOrder, boolean initialize) {
        if (!viewArray.isRowVector()) {
            throw new IllegalArgumentException("Invalid input: expect row vector input");
        }
        if (initialize) {
            // Seed the squared-gradient average with epsilon rather than zero.
            viewArray.assign(this.config.getEpsilon());
        }
        // Shape.newShapeNoCopy signals an impossible reshape by returning null.
        this.lastGradient = Shape.newShapeNoCopy(viewArray, gradientShape, gradientOrder == 'f');
        if (this.lastGradient == null) {
            throw new IllegalStateException("Could not correctly reshape gradient view array");
        }
        this.gradientReshapeOrder = gradientOrder;
    }

    /**
     * Updates the state with {@code gradient} and rescales {@code gradient} in place.
     *
     * @param gradient  raw gradient; overwritten with the RMSProp-scaled update
     * @param iteration unused by this implementation
     * @throws IllegalStateException if {@link #setStateViewArray} has not been called
     */
    @Override
    public void applyUpdater(INDArray gradient, int iteration) {
        if (this.lastGradient == null) {
            throw new IllegalStateException("Updater has not been initialized with view state");
        }
        double learningRate = this.config.getLearningRate();
        double rmsDecay = this.config.getRmsDecay();
        double epsilon = this.config.getEpsilon();
        // s = rmsDecay * s + (1 - rmsDecay) * g^2, written into the state view in place.
        this.lastGradient.muli(rmsDecay).addi(gradient.mul(gradient).muli(1.0 - rmsDecay));
        // g = lr * g / (sqrt(s) + eps). sqrt is called with dup=false (in place), so it runs on a
        // copy of the state to leave s untouched.
        INDArray denominator = Transforms.sqrt(this.lastGradient.dup(this.gradientReshapeOrder), false).addi(epsilon);
        gradient.muli(learningRate).divi(denominator);
    }

    @Override
    public RmsProp getConfig() {
        return this.config;
    }

    /** Returns the running squared-gradient average (not the last raw gradient). */
    public INDArray getLastGradient() {
        return this.lastGradient;
    }

    /** Returns the ordering character recorded by the last {@code setStateViewArray} call. */
    public char getGradientReshapeOrder() {
        return this.gradientReshapeOrder;
    }

    /** Replaces the state array directly, without the shape checks of {@code setStateViewArray}. */
    public void setLastGradient(INDArray lastGradient) {
        this.lastGradient = lastGradient;
    }

    public void setGradientReshapeOrder(char gradientReshapeOrder) {
        this.gradientReshapeOrder = gradientReshapeOrder;
    }

    /**
     * Two updaters are equal when their config, state array and reshape order are equal.
     * Subclasses may opt out of cross-type equality via {@link #canEqual(Object)}.
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof RmsPropUpdater)) {
            return false;
        }
        RmsPropUpdater other = (RmsPropUpdater) o;
        if (!other.canEqual(this)) {
            return false;
        }
        RmsProp thisConfig = this.getConfig();
        RmsProp otherConfig = other.getConfig();
        if (thisConfig == null ? otherConfig != null : !thisConfig.equals(otherConfig)) {
            return false;
        }
        INDArray thisState = this.getLastGradient();
        INDArray otherState = other.getLastGradient();
        if (thisState == null ? otherState != null : !thisState.equals(otherState)) {
            return false;
        }
        return this.getGradientReshapeOrder() == other.getGradientReshapeOrder();
    }

    /** Symmetry hook for {@link #equals(Object)}; subclasses overriding equals should override this. */
    protected boolean canEqual(Object other) {
        return other instanceof RmsPropUpdater;
    }

    /**
     * Hash over config, state array and reshape order, consistent with {@link #equals(Object)}.
     * Caution: the state array is mutated by {@link #applyUpdater}, so this value changes as the
     * updater is used.
     */
    @Override
    public int hashCode() {
        final int prime = 59;
        final int nullHash = 43;
        int result = 1;
        RmsProp cfg = this.getConfig();
        result = result * prime + (cfg == null ? nullHash : cfg.hashCode());
        INDArray state = this.getLastGradient();
        result = result * prime + (state == null ? nullHash : state.hashCode());
        result = result * prime + this.getGradientReshapeOrder();
        return result;
    }

    @Override
    public String toString() {
        return "RmsPropUpdater(config=" + this.getConfig() + ", lastGradient=" + this.getLastGradient() + ", gradientReshapeOrder=" + this.getGradientReshapeOrder() + ")";
    }
}

