/*
 * Decompiled with CFR 0.152.
 */
package com.barrybecker4.optimization.strategy.gradient;

import com.barrybecker4.math.MathUtil$;
import com.barrybecker4.math.linear.Vector;
import com.barrybecker4.optimization.optimizee.Optimizee;
import com.barrybecker4.optimization.parameter.Direction$;
import com.barrybecker4.optimization.parameter.NumericParameterArray;
import com.barrybecker4.optimization.parameter.ParameterArray;
import com.barrybecker4.optimization.parameter.ParameterArrayWithFitness;
import com.barrybecker4.optimization.parameter.types.Parameter;
import com.barrybecker4.optimization.strategy.gradient.ImprovementIteration$;
import java.io.Serializable;
import scala.Function1;
import scala.Predef$;
import scala.runtime.BoxesRunTime;
import scala.runtime.RichInt$;
import scala.runtime.Scala3RunTime$;
import scala.runtime.java8.JFunction1;

public class ImprovementIteration {
    private final ParameterArrayWithFitness params;
    private Vector oldGradient;
    private Vector delta;
    private Vector fitnessDelta;
    private Vector gradient;

    public static Vector $lessinit$greater$default$2() {
        return ImprovementIteration$.MODULE$.$lessinit$greater$default$2();
    }

    public ImprovementIteration(ParameterArrayWithFitness params, Vector oldGradient) {
        this.params = params;
        this.oldGradient = oldGradient;
        this.delta = ((NumericParameterArray)params.pa()).asVector();
        this.fitnessDelta = ((NumericParameterArray)params.pa()).asVector();
        this.gradient = ((NumericParameterArray)params.pa()).asVector();
        if (this.oldGradient() == null) {
            this.oldGradient_$eq(((NumericParameterArray)params.pa()).asVector());
            RichInt$.MODULE$.until$extension(Predef$.MODULE$.intWrapper(0), params.pa().size()).foreach((Function1 & Serializable)i -> this.$init$$$anonfun$1(BoxesRunTime.unboxToInt((Object)i)));
            this.oldGradient_$eq(this.oldGradient().normalize());
        }
    }

    public Vector oldGradient() {
        return this.oldGradient;
    }

    public void oldGradient_$eq(Vector x$1) {
        this.oldGradient = x$1;
    }

    public Vector gradient() {
        return this.gradient;
    }

    public void gradient_$eq(Vector x$1) {
        this.gradient = x$1;
    }

    public double incSumOfSqs(int i, Optimizee optimizee) {
        Parameter p = this.params.pa().get(i);
        this.delta = this.delta.set(i, p.getIncrementForDirection(Direction$.MODULE$.FORWARD()));
        NumericParameterArray nparams = (NumericParameterArray)this.params.pa();
        NumericParameterArray forwardParams = nparams.incrementByEps(i, Direction$.MODULE$.FORWARD());
        double fwdFitnessDelta = this.findFitnessDelta(optimizee, this.params, forwardParams);
        this.fitnessDelta = this.fitnessDelta.set(i, fwdFitnessDelta);
        Predef$.MODULE$.println((Object)new StringBuilder(16).append("fitDelta for ").append(i).append(" = ").append(this.fitnessDelta).toString());
        double d = this.delta.apply(i);
        if (d == 0.0) {
            throw Scala3RunTime$.MODULE$.assertFailed();
        }
        return fwdFitnessDelta * fwdFitnessDelta / (d * d);
    }

    private double findFitnessDelta(Optimizee optimizee, ParameterArrayWithFitness params, ParameterArray testParams) {
        return optimizee.evaluateByComparison() ? optimizee.compareFitness(testParams, params.pa()) : optimizee.evaluateFitness(testParams) - params.fitness();
    }

    public void updateGradient(double jumpSize, double gradLength) {
        double gradLen = gradLength == 0.0 ? MathUtil$.MODULE$.EPS_MEDIUM() : gradLength;
        RichInt$.MODULE$.until$extension(Predef$.MODULE$.intWrapper(0), this.delta.size()).foreach((Function1)(JFunction1.mcVI.sp & Serializable)i -> {
            double denominator = this.delta.apply(i) * gradLen;
            this.gradient_$eq(this.gradient().set(i, -jumpSize * this.fitnessDelta.apply(i) / denominator));
        });
    }

    private final /* synthetic */ Vector $init$$$anonfun$1(int i) {
        return this.oldGradient().set(i, 1.0);
    }
}

