/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.optimize.solvers.accumulation.encoding.threshold;

import org.deeplearning4j.optimize.solvers.accumulation.encoding.ThresholdAlgorithm;
import org.deeplearning4j.optimize.solvers.accumulation.encoding.ThresholdAlgorithmReducer;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link ThresholdAlgorithm} that adapts the threshold between calls so that the observed sparsity ratio of the
 * previous update stays within the range [{@code minTargetSparsity}, {@code maxTargetSparsity}].
 * <ul>
 *     <li>First call (no threshold supplied and none recorded yet): returns {@code initialThreshold}.</li>
 *     <li>Previous sparsity below {@code minTargetSparsity}: the threshold is multiplied by {@code decayRate}
 *     (i.e. decreased, since {@code decayRate < 1}).</li>
 *     <li>Previous sparsity above {@code maxTargetSparsity}: the threshold is multiplied by {@code 1 / decayRate}
 *     (i.e. increased).</li>
 *     <li>Otherwise the threshold is kept unchanged.</li>
 * </ul>
 * Instances carry mutable state (last threshold and last sparsity) and perform no synchronization.
 * {@link #equals(Object)} and {@link #hashCode()} consider only the configuration, not that mutable state.
 */
public class AdaptiveThresholdAlgorithm implements ThresholdAlgorithm {
    private static final Logger log = LoggerFactory.getLogger(AdaptiveThresholdAlgorithm.class);

    public static final double DEFAULT_INITIAL_THRESHOLD = 1.0E-4;
    public static final double DEFAULT_MIN_SPARSITY_TARGET = 1.0E-4;
    public static final double DEFAULT_MAX_SPARSITY_TARGET = 0.01;
    /** 0.5^(1/20): twenty consecutive decays halve the threshold (and twenty increases double it). */
    public static final double DEFAULT_DECAY_RATE = Math.pow(0.5, 0.05);

    /**
     * Sparsity assumed for the previous update when {@code lastWasDense} is true but no explicit ratio was
     * supplied (the exact ratio is not known to this method).
     */
    private static final double ASSUMED_DENSE_SPARSITY = 1.0 / 16;

    private final double initialThreshold;
    private final double minTargetSparsity;
    private final double maxTargetSparsity;
    private final double decayRate;
    /** Threshold most recently returned/used; NaN until the first call. */
    private double lastThreshold = Double.NaN;
    /** Sparsity ratio most recently used for adaptation; NaN until known. */
    private double lastSparsity = Double.NaN;

    /** Creates an instance with all default settings. */
    public AdaptiveThresholdAlgorithm() {
        this(DEFAULT_INITIAL_THRESHOLD);
    }

    /**
     * Creates an instance with the given initial threshold and default sparsity targets and decay rate.
     *
     * @param initialThreshold threshold returned on the first call; must be {@code > 0}
     */
    public AdaptiveThresholdAlgorithm(double initialThreshold) {
        this(initialThreshold, DEFAULT_MIN_SPARSITY_TARGET, DEFAULT_MAX_SPARSITY_TARGET, DEFAULT_DECAY_RATE);
    }

    /**
     * @param initialThreshold  threshold returned on the first call; must be {@code > 0}
     * @param minTargetSparsity lower bound of the target sparsity range; must be {@code > 0}
     * @param maxTargetSparsity upper bound of the target sparsity range; must be {@code >= minTargetSparsity}
     * @param decayRate         multiplicative step used to adapt the threshold; must be in {@code [0.5, 1.0)}
     * @throws IllegalArgumentException if any argument violates the constraints above
     */
    public AdaptiveThresholdAlgorithm(double initialThreshold, double minTargetSparsity, double maxTargetSparsity, double decayRate) {
        Preconditions.checkArgument(initialThreshold > 0.0, "Initial threshold must be positive. Got: %s", initialThreshold);
        Preconditions.checkArgument(minTargetSparsity > 0.0 && maxTargetSparsity > 0.0,
                "Minimum and maximum target sparsities must be > 0. Got minTargetSparsity=%s, maxTargetSparsity=%s",
                minTargetSparsity, maxTargetSparsity);
        Preconditions.checkArgument(minTargetSparsity <= maxTargetSparsity,
                "Min target sparsity must be less than or equal to max target sparsity. Got minTargetSparsity=%s, maxTargetSparsity=%s",
                minTargetSparsity, maxTargetSparsity);
        Preconditions.checkArgument(decayRate >= 0.5 && decayRate < 1.0,
                "Decay rate must be a number in range 0.5 (inclusive) to 1.0 (exclusive). Usually decay rate is in range 0.95 to 0.999. Got decay rate: %s",
                decayRate);
        this.initialThreshold = initialThreshold;
        this.minTargetSparsity = minTargetSparsity;
        this.maxTargetSparsity = maxTargetSparsity;
        this.decayRate = decayRate;
    }

    /**
     * Computes the threshold for the current iteration from the previous threshold and sparsity.
     *
     * @param iteration           current iteration; used only in log/exception messages
     * @param epoch               current epoch; used only in log/exception messages
     * @param lastThreshold       previous threshold, or null to use the one recorded by this instance
     * @param lastWasDense        whether the previous update was dense; consulted only when
     *                            {@code lastSparsityRatio} is null
     * @param lastSparsityRatio   sparsity ratio of the previous update, or null if unknown
     * @param updatesPlusResidual not used by this implementation
     * @return the threshold to use; it is also recorded as this instance's last threshold
     * @throws IllegalStateException if this is not the first call but no previous sparsity is available,
     *                               or if the previous sparsity is NaN
     */
    @Override
    public double calculateThreshold(int iteration, int epoch, Double lastThreshold, Boolean lastWasDense, Double lastSparsityRatio, INDArray updatesPlusResidual) {
        // First call: nothing to adapt from yet.
        if (lastThreshold == null && Double.isNaN(this.lastThreshold)) {
            this.lastThreshold = this.initialThreshold;
            return this.initialThreshold;
        }

        double adaptFromThreshold = lastThreshold != null ? lastThreshold : this.lastThreshold;

        double prevSparsity;
        if (lastSparsityRatio != null) {
            prevSparsity = lastSparsityRatio;
        } else if (Boolean.TRUE.equals(lastWasDense)) {
            prevSparsity = ASSUMED_DENSE_SPARSITY;
        } else if (!Double.isNaN(this.lastSparsity)) {
            prevSparsity = this.lastSparsity;
        } else {
            throw new IllegalStateException("Unexpected state: not first iteration but no last sparsity value is available: iteration=" + iteration + ", epoch=" + epoch + ", lastThreshold=" + lastThreshold + ", lastWasDense=" + lastWasDense + ", lastSparsityRatio=" + lastSparsityRatio + ", this.lastSparsity=" + this.lastSparsity);
        }

        // Reject NaN before recording it, so an invalid input cannot corrupt this instance's state.
        if (Double.isNaN(prevSparsity)) {
            throw new IllegalStateException("Invalid previous sparsity value: " + prevSparsity);
        }
        this.lastSparsity = prevSparsity;

        double newThreshold;
        if (prevSparsity < this.minTargetSparsity) {
            // Below the target range: decrease the threshold (decayRate < 1).
            newThreshold = this.decayRate * adaptFromThreshold;
            if (log.isDebugEnabled()) {
                log.debug("AdaptiveThresholdAlgorithm: iter {} epoch {}: prev sparsity {} < min sparsity {}, reducing threshold from {} to  {}",
                        iteration, epoch, prevSparsity, this.minTargetSparsity, adaptFromThreshold, newThreshold);
            }
        } else if (prevSparsity > this.maxTargetSparsity) {
            // Above the target range: increase the threshold.
            newThreshold = 1.0 / this.decayRate * adaptFromThreshold;
            if (log.isDebugEnabled()) {
                log.debug("AdaptiveThresholdAlgorithm: iter {} epoch {}: prev sparsity {} > max sparsity {}, increasing threshold from {} to  {}",
                        iteration, epoch, prevSparsity, this.maxTargetSparsity, adaptFromThreshold, newThreshold);
            }
        } else {
            // Within the target range: keep the threshold unchanged.
            newThreshold = adaptFromThreshold;
            if (log.isDebugEnabled()) {
                log.debug("AdaptiveThresholdAlgorithm: iter {} epoch {}: prev sparsity {}, keeping existing threshold of {}",
                        iteration, epoch, prevSparsity, adaptFromThreshold);
            }
        }

        // Record in every branch (including "keep") so getLastThreshold(), clone() and the reducer
        // see the threshold actually in use, even when the caller supplied lastThreshold explicitly.
        this.lastThreshold = newThreshold;
        return newThreshold;
    }

    /** @return a reducer that averages the last threshold/sparsity of instances with this configuration */
    @Override
    public ThresholdAlgorithmReducer newReducer() {
        return new Reducer(this.initialThreshold, this.minTargetSparsity, this.maxTargetSparsity, this.decayRate);
    }

    /** @return a copy with the same configuration and the same last threshold/sparsity state */
    @Override
    public AdaptiveThresholdAlgorithm clone() {
        AdaptiveThresholdAlgorithm ret = new AdaptiveThresholdAlgorithm(this.initialThreshold, this.minTargetSparsity, this.maxTargetSparsity, this.decayRate);
        ret.lastThreshold = this.lastThreshold;
        ret.lastSparsity = this.lastSparsity;
        return ret;
    }

    @Override
    public String toString() {
        String s = "AdaptiveThresholdAlgorithm(initialThreshold=" + this.initialThreshold + ",minTargetSparsity=" + this.minTargetSparsity + ",maxTargetSparsity=" + this.maxTargetSparsity + ",decayRate=" + this.decayRate;
        if (Double.isNaN(this.lastThreshold)) {
            return s + ")";
        }
        return s + ",lastThreshold=" + this.lastThreshold + ")";
    }

    /** Compares configuration only; the mutable last threshold/sparsity state is not considered. */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof AdaptiveThresholdAlgorithm)) {
            return false;
        }
        AdaptiveThresholdAlgorithm other = (AdaptiveThresholdAlgorithm) o;
        return other.canEqual(this)
                && Double.compare(this.initialThreshold, other.initialThreshold) == 0
                && Double.compare(this.minTargetSparsity, other.minTargetSparsity) == 0
                && Double.compare(this.maxTargetSparsity, other.maxTargetSparsity) == 0
                && Double.compare(this.decayRate, other.decayRate) == 0;
    }

    protected boolean canEqual(Object other) {
        return other instanceof AdaptiveThresholdAlgorithm;
    }

    /** Hash over the configuration fields, consistent with {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        // Double.hashCode(x) == (int) (bits ^ (bits >>> 32)) with bits = doubleToLongBits(x)
        result = result * prime + Double.hashCode(this.initialThreshold);
        result = result * prime + Double.hashCode(this.minTargetSparsity);
        result = result * prime + Double.hashCode(this.maxTargetSparsity);
        result = result * prime + Double.hashCode(this.decayRate);
        return result;
    }

    /** @return the last threshold returned/used, or NaN before the first call */
    public double getLastThreshold() {
        return this.lastThreshold;
    }

    /** @return the last sparsity ratio used for adaptation, or NaN if none is known yet */
    public double getLastSparsity() {
        return this.lastSparsity;
    }

    /**
     * Combines several {@link AdaptiveThresholdAlgorithm} instances by averaging their last threshold and,
     * separately, their last sparsity. Instances with no recorded threshold are skipped entirely. Instances
     * with a threshold but no known sparsity contribute only to the threshold average.
     */
    private static class Reducer implements ThresholdAlgorithmReducer {
        private final double initialThreshold;
        private final double minTargetSparsity;
        private final double maxTargetSparsity;
        private final double decayRate;
        private double lastThresholdSum;
        private double lastSparsitySum;
        /** Number of instances contributing to lastThresholdSum. */
        private int count;
        /** Number of instances contributing to lastSparsitySum (may be lower than count). */
        private int sparsityCount;

        private Reducer(double initialThreshold, double minTargetSparsity, double maxTargetSparsity, double decayRate) {
            this.initialThreshold = initialThreshold;
            this.minTargetSparsity = minTargetSparsity;
            this.maxTargetSparsity = maxTargetSparsity;
            this.decayRate = decayRate;
        }

        @Override
        public void add(ThresholdAlgorithm instance) {
            AdaptiveThresholdAlgorithm a = (AdaptiveThresholdAlgorithm) instance;
            if (a == null || Double.isNaN(a.lastThreshold)) {
                return;
            }
            this.lastThresholdSum += a.lastThreshold;
            ++this.count;
            // After the very first call an instance has a threshold but no sparsity yet; a NaN here
            // would otherwise turn the whole averaged sparsity into NaN.
            if (!Double.isNaN(a.lastSparsity)) {
                this.lastSparsitySum += a.lastSparsity;
                ++this.sparsityCount;
            }
        }

        @Override
        public ThresholdAlgorithmReducer merge(ThresholdAlgorithmReducer other) {
            Reducer r = (Reducer) other;
            this.lastThresholdSum += r.lastThresholdSum;
            this.lastSparsitySum += r.lastSparsitySum;
            this.count += r.count;
            this.sparsityCount += r.sparsityCount;
            return this;
        }

        @Override
        public ThresholdAlgorithm getFinalResult() {
            AdaptiveThresholdAlgorithm ret = new AdaptiveThresholdAlgorithm(this.initialThreshold, this.minTargetSparsity, this.maxTargetSparsity, this.decayRate);
            if (this.count > 0) {
                ret.lastThreshold = this.lastThresholdSum / this.count;
            }
            if (this.sparsityCount > 0) {
                ret.lastSparsity = this.lastSparsitySum / this.sparsityCount;
            }
            return ret;
        }
    }
}

