/*
 * Decompiled with CFR 0.152.
 */
package breeze.optimize;

import breeze.linalg.ImmutableNumericOps;
import breeze.linalg.NumericOps;
import breeze.linalg.norm$;
import breeze.linalg.package$;
import breeze.math.MutableInnerProductModule;
import breeze.optimize.BatchDiffFunction;
import breeze.optimize.FirstOrderMinimizer;
import breeze.optimize.FirstOrderMinimizer$;
import breeze.optimize.StochasticAveragedGradient$;
import breeze.optimize.StochasticAveragedGradient$History$;
import breeze.stats.distributions.Rand$;
import java.io.Serializable;
import scala.Conversion;
import scala.Predef;
import scala.Predef$;
import scala.Product;
import scala.Tuple2;
import scala.collection.immutable.IndexedSeq;
import scala.collection.immutable.Seq;
import scala.runtime.BoxesRunTime;
import scala.runtime.ScalaRunTime$;
import scala.runtime.Statics;

/**
 * Stochastic averaged gradient minimizer over a {@link BatchDiffFunction}.
 *
 * <p>How an iteration works:
 * <ul>
 *   <li>The objective and gradient are evaluated on a single example, index
 *       {@code history.nextPos()}.</li>
 *   <li>That example's latest gradient is stored in {@code History.previousGradients()}.</li>
 *   <li>The running sum of all stored gradients is kept in {@code History.currentSum()}.</li>
 *   <li>The descent direction is {@code -currentSum / n}, where {@code n = f.fullRange().size()}.</li>
 * </ul>
 *
 * <p>An L2 penalty of {@code l2Regularization / 2 * x.x} is added to the reported value and
 * gradient in {@link #adjust}. It is applied to the iterate in {@link #takeStep} as
 * multiplicative shrinkage.
 *
 * <p>This is decompiled output (see file header). Scala name-mangled members
 * ({@code $lessinit$greater$default$N}, {@code History$lzy1}, ...) are kept as-is so the
 * binary interface is unchanged.
 *
 * @param <T> the parameter/gradient vector type, operated on through {@code vs}
 */
public class StochasticAveragedGradient<T>
extends FirstOrderMinimizer<T, BatchDiffFunction<T>> {
    /** Step size stored in the first {@link History}. */
    private final double initialStepSize;
    /** When {@code > 0}, the step size is re-tuned on iterations divisible by this value. */
    private final int tuneStepFrequency;
    /** Coefficient of the L2 penalty {@code l2Regularization / 2 * x.x}. */
    private final double l2Regularization;
    /** Vector-space operations (arithmetic, dot products, zeros) for {@code T}. */
    private final MutableInnerProductModule<T, Object> vs;
    /** Factory object for {@link History}, created in the constructor. */
    public final StochasticAveragedGradient$History$ History$lzy1;

    /** Scala default-argument accessor for constructor parameter 1 ({@code maxIter}); the value lives in the companion {@code StochasticAveragedGradient$}. */
    public static <T> int $lessinit$greater$default$1() {
        return StochasticAveragedGradient$.MODULE$.$lessinit$greater$default$1();
    }

    /** Scala default-argument accessor for constructor parameter 2 ({@code initialStepSize}). */
    public static <T> double $lessinit$greater$default$2() {
        return StochasticAveragedGradient$.MODULE$.$lessinit$greater$default$2();
    }

    /** Scala default-argument accessor for constructor parameter 3 ({@code tuneStepFrequency}). */
    public static <T> int $lessinit$greater$default$3() {
        return StochasticAveragedGradient$.MODULE$.$lessinit$greater$default$3();
    }

    /** Scala default-argument accessor for constructor parameter 4 ({@code l2Regularization}). */
    public static <T> double $lessinit$greater$default$4() {
        return StochasticAveragedGradient$.MODULE$.$lessinit$greater$default$4();
    }

    /**
     * Creates the minimizer.
     *
     * @param maxIter iteration limit, forwarded to {@link FirstOrderMinimizer}
     * @param initialStepSize step size used before any tuning
     * @param tuneStepFrequency re-tune the step size every this many iterations; {@code <= 0} disables tuning
     * @param l2Regularization L2 penalty coefficient ({@code 0} for none)
     * @param vs vector-space operations for {@code T}
     */
    public StochasticAveragedGradient(int maxIter, double initialStepSize, int tuneStepFrequency, double l2Regularization, MutableInnerProductModule<T, Object> vs) {
        // Fields are assigned before the superclass constructor call, in the order the decompiler
        // recovered. As Java source this requires flexible constructor bodies (JEP 513).
        // NOTE(review): do not move super(...) first without checking whether FirstOrderMinimizer's
        // constructor reaches these fields through overridden methods.
        this.initialStepSize = initialStepSize;
        this.tuneStepFrequency = tuneStepFrequency;
        this.l2Regularization = l2Regularization;
        this.vs = vs;
        // Constructor parameters 2-4 of FirstOrderMinimizer take their own defaults.
        super(maxIter, FirstOrderMinimizer$.MODULE$.$lessinit$greater$default$2(), FirstOrderMinimizer$.MODULE$.$lessinit$greater$default$3(), FirstOrderMinimizer$.MODULE$.$lessinit$greater$default$4(), vs);
        this.History$lzy1 = new StochasticAveragedGradient$History$(this);
    }

    /** Returns the factory object used to build {@link History} instances. */
    public final StochasticAveragedGradient$History$ History() {
        return this.History$lzy1;
    }

    /**
     * Builds the starting history.
     *
     * <p>The initial state is:
     * <ul>
     *   <li>step size {@code initialStepSize};</li>
     *   <li>range {@code f.fullRange()};</li>
     *   <li>a zero running sum;</li>
     *   <li>one zero gradient per example;</li>
     *   <li>{@code nextPos = 0}.</li>
     * </ul>
     *
     * <p>All {@code previousGradients} slots share the same {@code zero} instance. This class
     * only ever replaces slots (via {@code updated}) and never mutates them in place, so the
     * sharing is safe here.
     *
     * @param f the batch objective
     * @param init the starting point, used only as a template for {@code zeroLike}
     * @return the initial history
     */
    public History initialHistory(BatchDiffFunction<T> f, T init) {
        Object zero = this.vs.zeroLike().apply(init);
        return this.History().apply(this.initialStepSize, f.fullRange(), this.vs.zeroLike().apply(init), (IndexedSeq)scala.package$.MODULE$.IndexedSeq().fill(f.fullRange().length(), () -> StochasticAveragedGradient.initialHistory$$anonfun$1(zero)), 0);
    }

    /**
     * Returns {@code -currentSum / n}, with {@code n = f.fullRange().size()}.
     *
     * <p>Slots for examples not yet visited hold zero (see {@link #initialHistory}). During the
     * first pass the sum is therefore still divided by the full {@code n}.
     */
    @Override
    public T chooseDescentDirection(FirstOrderMinimizer.State<T, Object, History> state, BatchDiffFunction<T> f) {
        return (T)((ImmutableNumericOps)((Conversion)this.vs.hasOps()).apply(state.history().currentSum())).$times(BoxesRunTime.boxToDouble((double)(-1.0 / (double)f.fullRange().size())), this.vs.mulVS_M());
    }

    /** Returns the step size stored in the history; no line search is performed here. */
    @Override
    public double determineStepSize(FirstOrderMinimizer.State<T, Object, History> state, BatchDiffFunction<T> f, T direction) {
        return state.history().stepSize();
    }

    /**
     * Evaluates {@code f} at {@code x} on the single-example batch {@code {history.nextPos()}}.
     *
     * <p>NOTE(review): {@code nextPos} serves as a data index here and as a position into
     * {@code previousGradients} elsewhere. This presumably assumes {@code fullRange} is
     * {@code 0 until n}; confirm against callers.
     *
     * @return the (value, gradient) pair for that example
     */
    public Tuple2<Object, T> calculateObjective(BatchDiffFunction<T> f, T x, History history) {
        return f.calculate(x, (IndexedSeq<Object>)((IndexedSeq)scala.package$.MODULE$.IndexedSeq().apply((Seq)ScalaRunTime$.MODULE$.wrapIntArray(new int[]{history.nextPos()}))));
    }

    /**
     * Adds the L2 penalty to a raw value/gradient pair.
     *
     * @return {@code (newVal + l2Regularization / 2 * newX.newX, newGrad + l2Regularization * newX)}
     */
    @Override
    public Tuple2<Object, T> adjust(T newX, T newGrad, double newVal) {
        double av = newVal + BoxesRunTime.unboxToDouble(((ImmutableNumericOps)((Conversion)this.vs.hasOps()).apply(newX)).dot(newX, this.vs.dotVV())) * this.l2Regularization / 2.0;
        Object ag = ((NumericOps)((Conversion)this.vs.hasOps()).apply(newGrad)).$plus(((ImmutableNumericOps)((Conversion)this.vs.hasOps()).apply(newX)).$times(BoxesRunTime.boxToDouble((double)this.l2Regularization), this.vs.mulVS_M()), this.vs.addVV());
        // Decompiled form of Scala's `av -> ag` (builds a Tuple2).
        Double d = (Double)Predef$.MODULE$.ArrowAssoc((Object)BoxesRunTime.boxToDouble((double)av));
        return Predef.ArrowAssoc$.MODULE$.$minus$greater$extension((Object)d, ag);
    }

    /**
     * Computes {@code (1 - stepSize * l2Regularization) * x + stepSize * dir}.
     *
     * <p>The shrinkage factor applies the L2 penalty's gradient step. {@code dir} carries only
     * the averaged data gradients.
     *
     * @return a new vector; {@code state.x()} is not modified
     */
    @Override
    public T takeStep(FirstOrderMinimizer.State<T, Object, History> state, T dir, double stepSize) {
        Object newx = ((ImmutableNumericOps)((Conversion)this.vs.hasOps()).apply(state.x())).$times(BoxesRunTime.boxToDouble((double)(1.0 - stepSize * this.l2Regularization)), this.vs.mulVS_M());
        // newx += stepSize * dir, in place.
        package$.MODULE$.axpy(BoxesRunTime.boxToDouble((double)stepSize), dir, newx, this.vs.scaleAddVV());
        return (T)newx;
    }

    /**
     * Produces the next history.
     *
     * <p>Steps:
     * <ol>
     *   <li>Swap this example's stale gradient in the running sum for {@code newGrad}.</li>
     *   <li>Store {@code newGrad} in its slot.</li>
     *   <li>Optionally re-tune the step size.</li>
     *   <li>Pick the next example. The first pass walks examples in order; after that an
     *       index is drawn from {@code range} via {@code Rand.choose}.</li>
     * </ol>
     *
     * <p>Step tuning runs when {@code tuneStepFrequency > 0} and the iteration count is a multiple
     * of it. The regularized change in value is compared with the quadratic model
     * {@code g.dx + dx.dx / (2 * stepSize)}. The step is halved if the change exceeds the model,
     * and otherwise multiplied by 1.5. The penalty term uses {@code newX.newX}, the same squared
     * norm as {@link #adjust}, so both sides of the test measure the same objective. (The
     * decompiled original used the unsquared {@code norm(newX)} here.)
     *
     * @param newX the new iterate
     * @param newGrad the gradient at {@code newX} for example {@code oldState.history().nextPos()}
     * @param newVal the value at {@code newX}; not read by this method
     * @param f the batch objective, re-evaluated at {@code newX} when tuning
     * @param oldState the state before this step
     * @return the updated history
     */
    public History updateHistory(T newX, T newGrad, double newVal, BatchDiffFunction<T> f, FirstOrderMinimizer.State<T, Object, History> oldState) {
        History old = oldState.history();
        // Remove the stale gradient of the example just evaluated; newGrad is added back below.
        Object newSum = ((ImmutableNumericOps)((Conversion)this.vs.hasOps()).apply(old.currentSum())).$minus(old.previousGradients().apply(old.nextPos()), this.vs.subVV());
        double newStepSize;
        if (this.tuneStepFrequency > 0 && oldState.iter() % this.tuneStepFrequency == 0) {
            Object xdiff = ((ImmutableNumericOps)((Conversion)this.vs.hasOps()).apply(newX)).$minus(oldState.x(), this.vs.subVV());
            double sampleValue = f.valueAt(newX, (IndexedSeq<Object>)((IndexedSeq)scala.package$.MODULE$.IndexedSeq().apply((Seq)ScalaRunTime$.MODULE$.wrapIntArray(new int[]{old.nextPos()}))));
            // Same penalty as adjust(): l2Regularization / 2 * ||newX||^2, not the plain norm.
            double penalty = this.l2Regularization / (double)2 * BoxesRunTime.unboxToDouble(((ImmutableNumericOps)((Conversion)this.vs.hasOps()).apply(newX)).dot(newX, this.vs.dotVV()));
            double actualChange = sampleValue + penalty - oldState.adjustedValue();
            double quadraticBound = BoxesRunTime.unboxToDouble(((ImmutableNumericOps)((Conversion)this.vs.hasOps()).apply(oldState.adjustedGradient())).dot(xdiff, this.vs.dotVV())) + BoxesRunTime.unboxToDouble(((ImmutableNumericOps)((Conversion)this.vs.hasOps()).apply(xdiff)).dot(xdiff, this.vs.dotVV())) / ((double)2 * old.stepSize());
            // Model violated -> step too large: halve it; otherwise grow it.
            newStepSize = actualChange > quadraticBound ? old.stepSize() / (double)2 : old.stepSize() * 1.5;
        } else {
            newStepSize = old.stepSize();
        }
        // newSum is the fresh result of the subtraction above, so in-place accumulation is safe.
        ((NumericOps)((Conversion)this.vs.hasOps()).apply(newSum)).$plus$eq(newGrad, this.vs.addIntoVV());
        return this.History().apply(newStepSize, old.range(), newSum, (IndexedSeq)old.previousGradients().updated(old.nextPos(), newGrad), oldState.iter() < old.previousGradients().length() - 1 ? oldState.iter() + 1 : BoxesRunTime.unboxToInt((Object)Rand$.MODULE$.choose(old.range()).draw()));
    }

    /** Element supplier for {@code IndexedSeq.fill} in {@link #initialHistory}: always returns {@code zero$1}. */
    private static final Object initialHistory$$anonfun$1(Object zero$1) {
        return zero$1;
    }

    /**
     * Immutable optimizer state carried between iterations (decompiled Scala case class).
     *
     * <p>Fields:
     * <ul>
     *   <li>{@code stepSize}: current step size.</li>
     *   <li>{@code range}: example indices.</li>
     *   <li>{@code currentSum}: sum of stored gradients.</li>
     *   <li>{@code previousGradients}: latest gradient per example.</li>
     *   <li>{@code nextPos}: example to evaluate next.</li>
     * </ul>
     *
     * <p>Equality also requires the same enclosing optimizer instance.
     */
    public class History
    implements Product,
    Serializable {
        private final double stepSize;
        private final IndexedSeq range;
        private final Object currentSum;
        private final IndexedSeq previousGradients;
        private final int nextPos;
        private final /* synthetic */ StochasticAveragedGradient $outer;

        /**
         * Creates a history snapshot.
         *
         * @throws NullPointerException if {@code $outer} is null
         */
        public History(StochasticAveragedGradient $outer, double stepSize, IndexedSeq<Object> range, T currentSum, IndexedSeq<T> previousGradients, int nextPos) {
            this.stepSize = stepSize;
            this.range = range;
            this.currentSum = currentSum;
            this.previousGradients = previousGradients;
            this.nextPos = nextPos;
            if ($outer == null) {
                throw new NullPointerException();
            }
            this.$outer = $outer;
        }

        /** Mixes the prefix and all five fields, consistent with {@link #equals}. */
        public int hashCode() {
            int n = -889275714;
            n = Statics.mix((int)n, (int)this.productPrefix().hashCode());
            n = Statics.mix((int)n, (int)Statics.doubleHash((double)this.stepSize()));
            n = Statics.mix((int)n, (int)Statics.anyHash(this.range()));
            n = Statics.mix((int)n, (int)Statics.anyHash(this.currentSum()));
            n = Statics.mix((int)n, (int)Statics.anyHash(this.previousGradients()));
            n = Statics.mix((int)n, (int)this.nextPos());
            return Statics.finalizeHash((int)n, (int)5);
        }

        /*
         * Enabled force condition propagation
         * Lifted jumps to return sites
         */
        /** Field-wise equality; both histories must belong to the same optimizer instance. */
        public boolean equals(Object x$0) {
            if (this == x$0) return true;
            Object object = x$0;
            if (!(object instanceof History)) return false;
            if (((History)object).breeze$optimize$StochasticAveragedGradient$History$$$outer() != this.$outer) return false;
            History history = (History)object;
            if (this.stepSize() != history.stepSize()) return false;
            if (this.nextPos() != history.nextPos()) return false;
            IndexedSeq<Object> indexedSeq = this.range();
            IndexedSeq<Object> indexedSeq2 = history.range();
            if (indexedSeq == null) {
                if (indexedSeq2 != null) {
                    return false;
                }
            } else if (!indexedSeq.equals(indexedSeq2)) return false;
            if (!BoxesRunTime.equals(this.currentSum(), history.currentSum())) return false;
            IndexedSeq indexedSeq3 = this.previousGradients();
            IndexedSeq indexedSeq4 = history.previousGradients();
            if (indexedSeq3 == null) {
                if (indexedSeq4 != null) {
                    return false;
                }
            } else if (!indexedSeq3.equals(indexedSeq4)) return false;
            if (!history.canEqual(this)) return false;
            return true;
        }

        public String toString() {
            return ScalaRunTime$.MODULE$._toString((Product)this);
        }

        public boolean canEqual(Object that) {
            return that instanceof History;
        }

        public int productArity() {
            return 5;
        }

        public String productPrefix() {
            return "History";
        }

        /** Returns field {@code n} (0-based, declaration order), boxing primitives. */
        public Object productElement(int n) {
            Object object;
            int n2 = n;
            switch (n2) {
                case 0: {
                    object = BoxesRunTime.boxToDouble((double)this._1());
                    break;
                }
                case 1: {
                    object = this._2();
                    break;
                }
                case 2: {
                    object = this._3();
                    break;
                }
                case 3: {
                    object = this._4();
                    break;
                }
                case 4: {
                    object = BoxesRunTime.boxToInteger((int)this._5());
                    break;
                }
                default: {
                    throw new IndexOutOfBoundsException(BoxesRunTime.boxToInteger((int)n).toString());
                }
            }
            return object;
        }

        /** Returns the name of field {@code n} (0-based, declaration order). */
        public String productElementName(int n) {
            String string;
            int n2 = n;
            switch (n2) {
                case 0: {
                    string = "stepSize";
                    break;
                }
                case 1: {
                    string = "range";
                    break;
                }
                case 2: {
                    string = "currentSum";
                    break;
                }
                case 3: {
                    string = "previousGradients";
                    break;
                }
                case 4: {
                    string = "nextPos";
                    break;
                }
                default: {
                    throw new IndexOutOfBoundsException(BoxesRunTime.boxToInteger((int)n).toString());
                }
            }
            return string;
        }

        public double stepSize() {
            return this.stepSize;
        }

        public IndexedSeq<Object> range() {
            return this.range;
        }

        public T currentSum() {
            return this.currentSum;
        }

        public IndexedSeq<T> previousGradients() {
            return this.previousGradients;
        }

        public int nextPos() {
            return this.nextPos;
        }

        /** Returns a new history bound to the same optimizer with the given fields. */
        public History copy(double stepSize, IndexedSeq<Object> range, T currentSum, IndexedSeq<T> previousGradients, int nextPos) {
            return new History(this.$outer, stepSize, range, currentSum, previousGradients, nextPos);
        }

        public double copy$default$1() {
            return this.stepSize();
        }

        public IndexedSeq<Object> copy$default$2() {
            return this.range();
        }

        public T copy$default$3() {
            return this.currentSum();
        }

        public IndexedSeq<T> copy$default$4() {
            return this.previousGradients();
        }

        public int copy$default$5() {
            return this.nextPos();
        }

        public double _1() {
            return this.stepSize();
        }

        public IndexedSeq<Object> _2() {
            return this.range();
        }

        public T _3() {
            return this.currentSum();
        }

        public IndexedSeq<T> _4() {
            return this.previousGradients();
        }

        public int _5() {
            return this.nextPos();
        }

        public final /* synthetic */ StochasticAveragedGradient breeze$optimize$StochasticAveragedGradient$History$$$outer() {
            return this.$outer;
        }
    }
}

