/*
 * Decompiled with CFR 0.152.
 */
package eva2.optimization.strategies;

import eva2.optimization.individuals.AbstractEAIndividual;
import eva2.optimization.individuals.InterfaceDataTypeDouble;
import eva2.optimization.population.InterfaceSolutionSet;
import eva2.optimization.population.Population;
import eva2.optimization.population.SolutionSet;
import eva2.optimization.strategies.AbstractOptimizer;
import eva2.problems.F1Problem;
import eva2.problems.InterfaceFirstOrderDerivableProblem;
import eva2.tools.EVAERROR;
import eva2.tools.ReflectPackage;
import eva2.util.annotation.Description;
import java.io.Serializable;

/**
 * Gradient descent optimizer working on every individual of the population.
 *
 * <p>Each call to {@link #optimize()} performs {@code iterations} descent steps per
 * individual using the gradient supplied by an {@link InterfaceFirstOrderDerivableProblem},
 * re-evaluates the population and optionally
 * <ul>
 *   <li>adapts one step size per coordinate from the sign of consecutive gradients
 *       (local adaption),</li>
 *   <li>adapts one step size per individual from the fitness development
 *       (global adaption),</li>
 *   <li>adds a momentum term built from the previous change,</li>
 *   <li>reverts individuals whose fitness exceeds a threshold and temporarily locks the
 *       coordinate that changed most (recovery).</li>
 * </ul>
 *
 * <p>Per-individual bookkeeping (locks, step sizes, last gradient, last change, previous
 * parameters) is attached to the individuals through {@code putData}/{@code getData}
 * under the private string keys declared below.
 */
@Description(value="Gradient Descent can be applied to derivable functions (InterfaceFirstOrderDerivableProblem).")
public class GradientDescentAlgorithm
extends AbstractOptimizer
implements Serializable {
    // Not referenced anywhere inside this class; kept because they are package-visible.
    InterfaceDataTypeDouble bestDataTypeDouble;
    InterfaceDataTypeDouble testDataTypeDouble;
    /** Number of descent steps performed per individual in one optimize() call. */
    private int iterations = 1;
    /** Factor applied to a step size when it has to shrink. */
    private double wDecreaseStepSize = 0.5;
    /** Factor applied to a step size when it may grow. */
    private double wIncreaseStepSize = 1.1;
    /** If set, individuals above the recovery threshold are reverted after a step. */
    boolean recovery = false;
    /** Number of descent steps a coordinate stays locked after a recovery. */
    private int recoverylocksteps = 5;
    /** Fitness value above which a step is considered unstable. */
    private double recoverythreshold = 100000.0;
    boolean localStepSizeAdaption = true;
    boolean globalStepSizeAdaption = false;
    private double globalinitstepsize = 1.0;
    double globalmaxstepsize = 3.0;
    double globalminstepsize = 1.0E-10;
    /** If set, a step has the length of the step size and only the sign is taken from the gradient. */
    boolean manhattan = false;
    double localmaxstepsize = 10.0;
    double localminstepsize = 1.0E-10;
    private boolean momentumterm = false;
    /** Upper bound for the absolute gradient-based change of one coordinate in one step. */
    public double maximumabsolutechange = 0.2;
    // Keys of the bookkeeping data attached to every individual.
    private static final String lockKey = "gdaLockDataKey";
    private static final String lastFitnessKey = "gdaLastFitDataKey";
    private static final String stepSizeKey = "gdaStepSizeDataKey";
    private static final String wStepSizeKey = "gdaWStepSizeDataKey";
    private static final String gradientKey = "gdaGradientDataKey";
    private static final String changesKey = "gdaChangesDataKey";
    private static final String oldParamsKey = "gdaOldParamsDataKey";
    /** Weight of the previous change when the momentum term is active. */
    private double momentumweigth = 0.1;

    /**
     * Uses a clone of the given population. With {@code reset} the clone is
     * re-initialized, evaluated, and listeners are notified.
     *
     * @param pop   population to clone and adopt
     * @param reset whether to re-initialize and evaluate the adopted population
     */
    @Override
    public void initializeByPopulation(Population pop, boolean reset) {
        this.setPopulation((Population)pop.clone());
        if (reset) {
            this.getPopulation().initialize();
            this.optimizationProblem.evaluate(this.getPopulation());
            this.firePropertyChangedEvent("NextGenerationPerformed");
        }
    }

    /** Creates an optimizer with default settings and a population of target size 1. */
    public GradientDescentAlgorithm() {
        this.population = new Population();
        this.population.setTargetSize(1);
    }

    /**
     * Creates an optimizer with local step size adaption and the given bounds, which are
     * used for both the local and the global step size limits.
     *
     * <p>NOTE(review): unlike the no-argument constructor this one does not install a
     * population of target size 1; it relies on whatever the superclass provides.
     *
     * @param minStepSize       lower bound for local and global step sizes
     * @param maxStepSize       upper bound for local and global step sizes
     * @param maxAbsoluteChange maximum absolute change of a coordinate per step
     */
    public GradientDescentAlgorithm(double minStepSize, double maxStepSize, double maxAbsoluteChange) {
        this.globalStepSizeAdaption = false;
        this.localStepSizeAdaption = true;
        this.localminstepsize = minStepSize;
        this.globalminstepsize = minStepSize;
        this.localmaxstepsize = maxStepSize;
        this.globalmaxstepsize = maxStepSize;
        this.maximumabsolutechange = maxAbsoluteChange;
    }

    /**
     * Cloning is not supported.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public Object clone() {
        throw new UnsupportedOperationException("Method clone() not yet implemented.");
    }

    @Override
    public String getName() {
        return "GradientDescentAlgorithm";
    }

    /** Lets the problem initialize the current population and evaluates it. */
    @Override
    public void initialize() {
        this.optimizationProblem.initializePopulation(this.population);
        this.optimizationProblem.evaluate(this.population);
    }

    /**
     * Two-valued sign function.
     *
     * @param val value to inspect
     * @return {@code -1.0} if {@code val} is negative, otherwise {@code 1.0}
     *         (this includes zero and NaN)
     */
    public double signum(double val) {
        return val < 0.0 ? -1.0 : 1.0;
    }

    /**
     * Performs one generation: descent steps for every individual, evaluation, and the
     * optional recovery and global step size adaption phases.
     *
     * @throws RuntimeException if the problem is not an
     *         {@link InterfaceFirstOrderDerivableProblem} or an individual is not an
     *         {@link InterfaceDataTypeDouble}
     */
    @Override
    public void optimize() {
        int size = this.population.size();
        for (int i = 0; i < size; ++i) {
            AbstractEAIndividual indy = (AbstractEAIndividual)this.population.get(i);
            // Validate before any cast so that the descriptive error is actually reachable.
            this.requireGradientDescentApplicable(indy);
            this.initializeIndividualData(indy);
        }
        for (int i = 0; i < size; ++i) {
            this.performDescentSteps((AbstractEAIndividual)this.population.get(i));
        }
        this.optimizationProblem.evaluate(this.population);
        this.population.incrGeneration();
        if (this.recovery) {
            this.recoverUnstableIndividuals();
        }
        if (this.globalStepSizeAdaption) {
            this.adaptGlobalStepSizes();
        }
        this.firePropertyChangedEvent("NextGenerationPerformed");
    }

    /**
     * Fails with a descriptive message unless the problem provides first order gradients
     * and the individual carries double data.
     */
    private void requireGradientDescentApplicable(AbstractEAIndividual indy) {
        if (this.optimizationProblem instanceof InterfaceFirstOrderDerivableProblem && indy instanceof InterfaceDataTypeDouble) {
            return;
        }
        String msg = "Warning, problem of type InterfaceFirstOrderDerivableProblem and template of type InterfaceDataTypeDouble is required for " + this.getClass();
        EVAERROR.errorMsgOnce(msg);
        Class<?>[] clsArr = ReflectPackage.getAssignableClasses(InterfaceFirstOrderDerivableProblem.class.getName(), true, true);
        StringBuilder detailed = new StringBuilder(msg).append(" (available: ");
        for (Class<?> cls : clsArr) {
            detailed.append(" ").append(cls.getSimpleName());
        }
        detailed.append(")");
        throw new RuntimeException(detailed.toString());
    }

    /**
     * Attaches the initial bookkeeping data to an individual that has not been processed
     * before. The presence of a stored gradient marks an individual as initialized.
     */
    private void initializeIndividualData(AbstractEAIndividual indy) {
        if (indy.hasData(gradientKey)) {
            return;
        }
        int dim = ((InterfaceDataTypeDouble)indy).getDoubleData().length;
        // A new int array is zero-filled, i.e. no coordinate is locked.
        int[] lock = new int[dim];
        double[] wstepsize = new double[dim];
        for (int k = 0; k < dim; ++k) {
            wstepsize[k] = 1.0;
        }
        indy.putData(lockKey, lock);
        indy.putData(lastFitnessKey, 0.0);
        indy.putData(stepSizeKey, this.globalinitstepsize);
        indy.putData(wStepSizeKey, wstepsize);
    }

    /**
     * Runs {@code iterations} descent steps on one individual and writes the resulting
     * parameters back as its genotype. The parameters before the first step are stored
     * under {@code oldParamsKey} for a possible recovery.
     */
    private void performDescentSteps(AbstractEAIndividual indy) {
        InterfaceDataTypeDouble dblIndy = (InterfaceDataTypeDouble)indy;
        InterfaceFirstOrderDerivableProblem derivable = (InterfaceFirstOrderDerivableProblem)((Object)this.optimizationProblem);
        double[][] range = dblIndy.getDoubleRange();
        double[] params = dblIndy.getDoubleData();
        indy.putData(oldParamsKey, params);
        int[] lock = (int[])indy.getData(lockKey);
        double indystepsize = (Double)indy.getData(stepSizeKey);
        for (int step = 0; step < this.iterations; ++step) {
            double[] oldgradient = indy.hasData(gradientKey) ? (double[])indy.getData(gradientKey) : null;
            double[] wstepsize = (double[])indy.getData(wStepSizeKey);
            double[] gradient = derivable.getFirstOrderGradients(params);
            if (oldgradient != null && wstepsize != null) {
                this.adaptLocalStepSizes(wstepsize, gradient, oldgradient);
            }
            indy.putData(gradientKey, gradient);
            double[] oldchange = indy.hasData(changesKey) ? (double[])indy.getData(changesKey) : null;
            boolean useMomentum = this.momentumterm && oldchange != null;
            double[] newparams = new double[params.length];
            double[] change = new double[params.length];
            for (int j = 0; j < newparams.length; ++j) {
                if (lock[j] != 0) {
                    // A locked coordinate must keep its value; leaving the fresh array
                    // entry untouched would silently reset it to 0.0.
                    newparams[j] = params[j];
                    lock[j] = lock[j] - 1;
                    continue;
                }
                double tempstepsize = 1.0;
                if (this.localStepSizeAdaption) {
                    tempstepsize *= wstepsize[j];
                }
                if (this.globalStepSizeAdaption) {
                    tempstepsize *= indystepsize;
                }
                double scaled = tempstepsize * gradient[j];
                // Gradient step, limited in magnitude to maximumabsolutechange.
                double wchange = this.signum(scaled) * Math.min(this.maximumabsolutechange, Math.abs(scaled));
                if (this.manhattan) {
                    // NOTE(review): signum(0) is 1, so a zero gradient component still
                    // moves the coordinate by +tempstepsize here.
                    wchange = this.signum(wchange) * tempstepsize;
                }
                if (useMomentum) {
                    wchange += this.momentumweigth * oldchange[j];
                }
                // Descend and clip to the allowed range of the coordinate.
                double candidate = params[j] - wchange;
                if (candidate < range[j][0]) {
                    candidate = range[j][0];
                }
                if (candidate > range[j][1]) {
                    candidate = range[j][1];
                }
                newparams[j] = candidate;
                change[j] = change[j] + wchange;
            }
            params = newparams;
            indy.putData(changesKey, change);
        }
        dblIndy.setDoubleGenotype(params);
    }

    /**
     * Updates the per-coordinate step sizes in place: a sign change between the old and
     * the new gradient shrinks the step size, equal signs enlarge it, a zero product
     * leaves it as is. The result is limited to the local minimum and maximum.
     */
    private void adaptLocalStepSizes(double[] wstepsize, double[] gradient, double[] oldgradient) {
        for (int li = 0; li < wstepsize.length; ++li) {
            double prod = gradient[li] * oldgradient[li];
            if (prod < 0.0) {
                wstepsize[li] = this.wDecreaseStepSize * wstepsize[li];
            } else if (prod > 0.0) {
                wstepsize[li] = this.wIncreaseStepSize * wstepsize[li];
            }
            wstepsize[li] = Math.min(Math.max(wstepsize[li], this.localminstepsize), this.localmaxstepsize);
        }
    }

    /**
     * Reverts every individual whose first fitness value exceeds the recovery threshold
     * to the parameters it had before the last descent, locks the unlocked coordinate
     * with the largest absolute change and re-evaluates the population.
     */
    private void recoverUnstableIndividuals() {
        for (int i = 0; i < this.population.size(); ++i) {
            AbstractEAIndividual indy = (AbstractEAIndividual)this.population.get(i);
            if (!(indy.getFitness()[0] > this.recoverythreshold)) {
                continue;
            }
            // NOTE(review): the descent writes via setDoubleGenotype while the restore
            // uses setDoublePhenotype - confirm this asymmetry is intended.
            ((InterfaceDataTypeDouble)indy).setDoublePhenotype((double[])indy.getData(oldParamsKey));
            double[] changes = (double[])indy.getData(changesKey);
            int[] lock = (int[])indy.getData(lockKey);
            int indexmaxchange = 0;
            double maxchangeval = Double.NEGATIVE_INFINITY;
            for (int j = 0; j < changes.length; ++j) {
                // Compare magnitudes: a large negative step destabilizes as much as a
                // large positive one.
                double magnitude = Math.abs(changes[j]);
                if (magnitude > maxchangeval && lock[j] == 0) {
                    indexmaxchange = j;
                    maxchangeval = magnitude;
                }
            }
            lock[indexmaxchange] = this.recoverylocksteps;
            indy.putData(lockKey, lock);
        }
        this.optimizationProblem.evaluate(this.population);
        this.population.incrGeneration();
    }

    /**
     * Adapts the per-individual step size: it shrinks if the stored last fitness is
     * smaller than the current one and grows otherwise, limited to the global bounds.
     * Afterwards the current fitness is stored as last fitness.
     */
    private void adaptGlobalStepSizes() {
        for (int i = 0; i < this.population.size(); ++i) {
            AbstractEAIndividual indy = (AbstractEAIndividual)this.population.get(i);
            double currentfit = indy.getFitness()[0];
            if (indy.getData(lastFitnessKey) != null) {
                double lastfit = (Double)indy.getData(lastFitnessKey);
                double indystepsize = (Double)indy.getData(stepSizeKey);
                double factor = lastfit < currentfit ? this.wDecreaseStepSize : this.wIncreaseStepSize;
                indystepsize = Math.max(Math.min(indystepsize * factor, this.globalmaxstepsize), this.globalminstepsize);
                indy.putData(stepSizeKey, indystepsize);
            }
            indy.putData(lastFitnessKey, currentfit);
        }
    }

    /** @return a solution set wrapping the current population */
    @Override
    public InterfaceSolutionSet getAllSolutions() {
        return new SolutionSet(this.getPopulation());
    }

    @Override
    public String getStringRepresentation() {
        return "GradientDescentAlgorithm";
    }

    /**
     * Small demo: runs 100 generations on an {@link F1Problem}, printing the best fitness
     * after each one and the best parameters at the end.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        GradientDescentAlgorithm program = new GradientDescentAlgorithm();
        F1Problem problem = new F1Problem();
        program.setProblem(problem);
        program.initialize();
        for (int i = 0; i < 100; ++i) {
            program.optimize();
            System.out.println(program.getPopulation().getBestFitness()[0]);
        }
        double[] res = ((InterfaceDataTypeDouble)((Object)program.getPopulation().getBestIndividual())).getDoubleData();
        for (int i = 0; i < res.length; ++i) {
            System.out.print(res[i] + " ");
        }
    }

    public boolean isAdaptStepSizeGlobally() {
        return this.globalStepSizeAdaption;
    }

    /**
     * Switches global step size adaption; enabling it disables local adaption.
     *
     * @param globalstepsizeadaption new state of the global adaption
     */
    public void setAdaptStepSizeGlobally(boolean globalstepsizeadaption) {
        this.globalStepSizeAdaption = globalstepsizeadaption;
        if (globalstepsizeadaption && this.localStepSizeAdaption) {
            this.setAdaptStepSizeLocally(false);
        }
    }

    public String adaptStepSizeGloballyTipText() {
        return "Use a single step size per individual - (priority over local step size).";
    }

    public double getGlobalMaxStepSize() {
        return this.globalmaxstepsize;
    }

    public void setGlobalMaxStepSize(double p) {
        this.globalmaxstepsize = p;
    }

    public String globalMaxStepSizeTipText() {
        return "Maximum step size for global adaption.";
    }

    public double getGlobalMinStepSize() {
        return this.globalminstepsize;
    }

    public void setGlobalMinStepSize(double p) {
        this.globalminstepsize = p;
    }

    /**
     * Tip text for the {@code globalMinStepSize} property, named like the other
     * {@code <property>TipText} methods of this class.
     */
    public String globalMinStepSizeTipText() {
        return "Minimum step size for global adaption.";
    }

    /**
     * Misspelled predecessor of {@link #globalMinStepSizeTipText()}, kept for callers
     * that still use the old name.
     *
     * @deprecated use {@link #globalMinStepSizeTipText()}
     */
    @Deprecated
    public String globalMindStepSizeTipText() {
        return this.globalMinStepSizeTipText();
    }

    public double getGlobalInitStepSize() {
        return this.globalinitstepsize;
    }

    public void setGlobalInitStepSize(double initstepsize) {
        this.globalinitstepsize = initstepsize;
    }

    public String globalInitStepSizeTipText() {
        return "Initial step size for global adaption.";
    }

    public boolean isAdaptStepSizeLocally() {
        return this.localStepSizeAdaption;
    }

    /**
     * Switches local step size adaption; enabling it disables global adaption.
     *
     * @param stepsizeadaption new state of the local adaption
     */
    public void setAdaptStepSizeLocally(boolean stepsizeadaption) {
        this.localStepSizeAdaption = stepsizeadaption;
        if (this.globalStepSizeAdaption && this.localStepSizeAdaption) {
            this.setAdaptStepSizeGlobally(false);
        }
    }

    public String adaptStepSizeLocallyTipText() {
        return "Use a step size parameter in any dimension.";
    }

    public double getLocalMinStepSize() {
        return this.localminstepsize;
    }

    public void setLocalMinStepSize(double localminstepsize) {
        this.localminstepsize = localminstepsize;
    }

    public double getLocalMaxStepSize() {
        return this.localmaxstepsize;
    }

    public void setLocalMaxStepSize(double localmaxstepsize) {
        this.localmaxstepsize = localmaxstepsize;
    }

    public void setStepSizeIncreaseFact(double nplus) {
        this.wIncreaseStepSize = nplus;
    }

    public double getStepSizeIncreaseFact() {
        return this.wIncreaseStepSize;
    }

    public String stepSizeIncreaseFactTipText() {
        return "Factor for increasing the step size in adaption.";
    }

    public void setStepSizeDecreaseFact(double nminus) {
        this.wDecreaseStepSize = nminus;
    }

    public double getStepSizeDecreaseFact() {
        return this.wDecreaseStepSize;
    }

    public String stepSizeDecreaseFactTipText() {
        return "Factor for decreasing the step size in adaption.";
    }

    public boolean isRecovery() {
        return this.recovery;
    }

    public void setRecovery(boolean recovery) {
        this.recovery = recovery;
    }

    public int getRecoveryLocksteps() {
        return this.recoverylocksteps;
    }

    public void setRecoveryLocksteps(int locksteps) {
        this.recoverylocksteps = locksteps;
    }

    public double getRecoveryThreshold() {
        return this.recoverythreshold;
    }

    public void setRecoveryThreshold(double recoverythreshold) {
        this.recoverythreshold = recoverythreshold;
    }

    public String recoveryThresholdTipText() {
        return "If the fitness exceeds this threshold, an unstable area is assumed and one step recovered.";
    }

    public int getIterations() {
        return this.iterations;
    }

    public void setIterations(int iterations) {
        this.iterations = iterations;
    }

    public String iterationsTipText() {
        return "The number of GD-iterations per generation.";
    }

    public boolean isManhattan() {
        return this.manhattan;
    }

    public void setManhattan(boolean manhattan) {
        this.manhattan = manhattan;
    }

    public boolean isMomentumTerm() {
        return this.momentumterm;
    }

    public void setMomentumTerm(boolean momentum) {
        this.momentumterm = momentum;
    }

    public double getMomentumWeigth() {
        return this.momentumweigth;
    }

    public void setMomentumWeigth(double momentumweigth) {
        this.momentumweigth = momentumweigth;
    }

    public double getMaximumAbsoluteChange() {
        return this.maximumabsolutechange;
    }

    public void setMaximumAbsoluteChange(double maximumabsolutechange) {
        this.maximumabsolutechange = maximumabsolutechange;
    }

    public String maximumAbsoluteChangeTipText() {
        return "The maximum change along a coordinate in one step.";
    }
}

