/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.core.optimizing.graddesc;

import ai.libs.jaicore.math.linearalgebra.Vector;
import ai.libs.jaicore.ml.core.optimizing.IGradientBasedOptimizer;
import ai.libs.jaicore.ml.core.optimizing.IGradientDescendableFunction;
import ai.libs.jaicore.ml.core.optimizing.IGradientFunction;
import ai.libs.jaicore.ml.core.optimizing.graddesc.GradientDescentOptimizerConfig;
import java.util.Map;
import java.util.Objects;
import org.aeonbits.owner.ConfigFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Plain gradient descent with a fixed learning rate.
 *
 * <p>In every iteration the gradient is evaluated at the current weights, and each weight
 * {@code w_i} is moved against its gradient component {@code g_i}:
 * {@code w_i = w_i - learningRate * g_i}. Components whose magnitude is below the configured
 * gradient threshold, or that are not finite (NaN or infinite), leave their weight unchanged.
 * The loop stops once every gradient component is below the threshold or not finite, or after
 * {@code maxIterations} iterations, whichever happens first.
 */
public class GradientDescentOptimizer
implements IGradientBasedOptimizer {
    private static final Logger log = LoggerFactory.getLogger(GradientDescentOptimizer.class);

    /** Step size multiplied with every gradient component. */
    private final double learningRate;
    /** Gradient components with an absolute value below this are treated as converged. */
    private final double gradientThreshold;
    /** Upper bound on the number of iterations performed by {@link #optimize}. */
    private final int maxIterations;

    /**
     * Creates an optimizer whose hyper-parameters are read from the given configuration.
     *
     * @param config source of the learning rate, gradient threshold and iteration limit
     * @throws NullPointerException if {@code config} is {@code null}
     */
    public GradientDescentOptimizer(GradientDescentOptimizerConfig config) {
        Objects.requireNonNull(config, "config");
        this.learningRate = config.learningRate();
        this.gradientThreshold = config.gradientThreshold();
        this.maxIterations = config.maxIterations();
    }

    /**
     * Creates an optimizer from a {@link GradientDescentOptimizerConfig} instantiated by
     * {@link ConfigFactory} without any additional import maps.
     */
    public GradientDescentOptimizer() {
        this(ConfigFactory.create(GradientDescentOptimizerConfig.class));
    }

    /**
     * Runs gradient descent starting from {@code initialGuess}.
     *
     * <p>At least one iteration is always performed. The stopping criterion is evaluated on the
     * gradient computed at the start of an iteration, i.e. before that iteration's update.
     *
     * @param descendableFunction the function associated with {@code gradient}; it is not
     *     referenced by this implementation, which only evaluates {@code gradient}
     * @param gradient computes the gradient vector at a given weight vector
     * @param initialGuess starting weights; this vector is modified <em>in place</em>
     * @return {@code initialGuess} itself, now holding the final weights
     * @throws NullPointerException if {@code gradient} or {@code initialGuess} is {@code null}
     */
    @Override
    public Vector optimize(IGradientDescendableFunction descendableFunction, IGradientFunction gradient, Vector initialGuess) {
        Objects.requireNonNull(gradient, "gradient");
        Objects.requireNonNull(initialGuess, "initialGuess");
        Vector gradients;
        int iterations = 0;
        do {
            gradients = gradient.apply(initialGuess);
            this.updateWeights(initialGuess, gradients);
            iterations++;
            // Per-iteration trace output: debug level, so it does not flood logs by default.
            log.debug("iteration {}:\n weights \t{} \n gradients \t{}", iterations, initialGuess, gradients);
        } while (!this.allGradientsAreBelowThreshold(gradients) && iterations < this.maxIterations);
        log.info("Gradient descent based optimization took {} iterations.", iterations);
        return initialGuess;
    }

    /**
     * Returns whether a gradient component should be ignored. Such a component neither updates
     * its weight nor prevents termination: it is either smaller in magnitude than the threshold
     * or not a finite number.
     */
    private boolean isNegligible(double grad) {
        return Math.abs(grad) < this.gradientThreshold || !Double.isFinite(grad);
    }

    /** Returns {@code true} when every component of {@code gradients} is negligible. */
    private boolean allGradientsAreBelowThreshold(Vector gradients) {
        return gradients.stream().allMatch(grad -> this.isNegligible(grad));
    }

    /**
     * Applies one descent step in place: {@code weights[i] -= learningRate * gradients[i]} for
     * every non-negligible component. Non-finite components are skipped so that a diverging
     * gradient cannot overwrite a weight with NaN or infinity.
     */
    private void updateWeights(Vector weights, Vector gradients) {
        for (int i = 0; i < weights.length(); ++i) {
            double grad = gradients.getValue(i);
            if (this.isNegligible(grad)) {
                continue;
            }
            weights.setValue(i, weights.getValue(i) - grad * this.learningRate);
        }
    }
}

