/*
 * Decompiled with CFR 0.152.
 */
package jsat.math.optimization;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ExecutorService;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.math.Function;
import jsat.math.optimization.Optimizer;
import jsat.utils.ProbailityMatch;

public class NelderMead
implements Optimizer {
    private static final long serialVersionUID = -2930235371787386607L;
    /** Reflection coefficient: the reflected point is {@code x0 + reflection * (x0 - xWorst)}. */
    private double reflection = 1.0;
    /** Expansion coefficient: the expanded point is {@code x0 + expansion * (x0 - xWorst)}. */
    private double expansion = 2.0;
    /**
     * Contraction coefficient, stored NEGATED. The contraction point is computed as
     * {@code x0 + contraction * (x0 - xWorst)}. A negative value therefore places it between
     * the centroid and the worst vertex (an inside contraction). {@link #setContraction(double)}
     * accepts the positive coefficient in (0, 1) and negates it before storing.
     */
    private double contraction = -0.5;
    /** Shrink coefficient: every non-best vertex becomes {@code xBest + shrink * (xi - xBest)}. */
    private double shrink = 0.5;

    /**
     * Sets the reflection coefficient.
     *
     * @param reflection a finite value &gt; 0
     * @throws ArithmeticException if {@code reflection} is not finite or is &lt;= 0
     */
    public void setReflection(double reflection) {
        if (reflection <= 0.0 || Double.isNaN(reflection) || Double.isInfinite(reflection)) {
            throw new ArithmeticException("Reflection constant must be > 0, not " + reflection);
        }
        this.reflection = reflection;
    }

    /**
     * Sets the expansion coefficient.
     *
     * @param expansion a finite value &gt; 1 that is also greater than the current reflection coefficient
     * @throws ArithmeticException if {@code expansion} is not finite, is &lt;= 1, or is &lt;= the
     *         current reflection coefficient
     */
    public void setExpansion(double expansion) {
        if (expansion <= 1.0 || Double.isNaN(expansion) || Double.isInfinite(expansion)) {
            throw new ArithmeticException("Expansion constant must be > 1, not " + expansion);
        }
        if (expansion <= this.reflection) {
            // The guard rejects expansion <= reflection, so the requirement is "greater than".
            throw new ArithmeticException("Expansion constant must be greater than the reflection constant (" + this.reflection + "), not " + expansion);
        }
        this.expansion = expansion;
    }

    /**
     * Sets the contraction coefficient.
     *
     * @param contraction a value strictly between 0 and 1; the contraction point is placed this
     *        fraction of the way from the centroid toward the worst vertex
     * @throws ArithmeticException if {@code contraction} is not in (0, 1)
     */
    public void setContraction(double contraction) {
        if (contraction >= 1.0 || contraction <= 0.0 || Double.isNaN(contraction) || Double.isInfinite(contraction)) {
            throw new ArithmeticException("Contraction constant must be > 0 and < 1, not " + contraction);
        }
        // Stored negated so that the step x0 + contraction * (x0 - xWorst) moves toward the worst
        // vertex, matching the default of -0.5 (i.e. a user-facing coefficient of 0.5).
        this.contraction = -contraction;
    }

    /**
     * Sets the shrink coefficient.
     *
     * @param shrink a value strictly between 0 and 1
     * @throws ArithmeticException if {@code shrink} is not in (0, 1)
     */
    public void setShrink(double shrink) {
        if (shrink >= 1.0 || shrink <= 0.0 || Double.isNaN(shrink) || Double.isInfinite(shrink)) {
            throw new ArithmeticException("Shrinkage constant must be > 0 and < 1, not " + shrink);
        }
        this.shrink = shrink;
    }

    /**
     * Minimizes {@code f} starting from {@code vars}.
     * {@code fd}, {@code inputs}, {@code outputs} and {@code threadpool} are ignored.
     */
    @Override
    public Vec optimize(double eps, int iterationLimit, Function f, Function fd, Vec vars, List<Vec> inputs, Vec outputs, ExecutorService threadpool) {
        return this.optimize(eps, iterationLimit, f, fd, vars, inputs, outputs);
    }

    /**
     * Minimizes {@code f} starting from {@code vars}.
     * {@code fd}, {@code inputs} and {@code outputs} are ignored.
     */
    @Override
    public Vec optimize(double eps, int iterationLimit, Function f, Function fd, Vec vars, List<Vec> inputs, Vec outputs) {
        List<Vec> initialPoints = new ArrayList<Vec>(1);
        initialPoints.add(vars);
        return this.optimize(eps, iterationLimit, f, initialPoints);
    }

    /**
     * Minimizes {@code f} with the Nelder-Mead simplex method.
     * <p>
     * The simplex is seeded with clones of {@code initalPoints}; the caller's vectors are not
     * modified. If fewer than N+1 points are given (N = length of the first point), random
     * vertices are generated from a {@link Random} seeded with {@code initalPoints.hashCode()}.
     * If more are given, only the N+1 best are kept.
     *
     * @param eps            iteration stops once |f(worst) - f(best)| &lt; eps
     * @param iterationLimit maximum number of iterations
     * @param f              the function to minimize
     * @param initalPoints   starting points; must be non-empty
     * @return the best vertex of the final simplex
     * @throws ArithmeticException if {@code initalPoints} is empty
     */
    public Vec optimize(double eps, int iterationLimit, Function f, List<Vec> initalPoints) {
        if (initalPoints.isEmpty()) {
            throw new ArithmeticException("Empty Initial list. Can not determin dimension of problem");
        }
        Vec init = initalPoints.get(0);
        int N = init.length();
        List<ProbailityMatch<Vec>> simplex = new ArrayList<ProbailityMatch<Vec>>(Math.max(N + 1, initalPoints.size()));
        for (Vec vars : initalPoints) {
            simplex.add(new ProbailityMatch<Vec>(f.f(vars), vars.clone()));
        }
        NelderMead.addRandomVertices(simplex, init, f, new Random(initalPoints.hashCode()));
        Collections.sort(simplex);
        while (simplex.size() > N + 1) {
            simplex.remove(simplex.size() - 1);
        }

        DenseVector centroid = new DenseVector(N);
        DenseVector xr = new DenseVector(N);
        DenseVector xec = new DenseVector(N);
        DenseVector dir = new DenseVector(N);
        int lastIndex = simplex.size() - 1;
        // lastIndex > 0 guards the degenerate 0-dimensional case (a single vertex, nothing to move).
        for (int iterationCount = 0; iterationCount < iterationLimit && lastIndex > 0
                && !(Math.abs(simplex.get(lastIndex).getProbability() - simplex.get(0).getProbability()) < eps); ++iterationCount) {
            // Centroid of every vertex EXCEPT the worst one (standard Nelder-Mead).
            centroid.zeroOut();
            for (int i = 0; i < lastIndex; ++i) {
                centroid.mutableAdd(simplex.get(i).getMatch());
            }
            centroid.mutableDivide(lastIndex);

            // dir = centroid - worst; every trial point is centroid + coefficient * dir.
            centroid.copyTo(dir);
            dir.mutableSubtract(simplex.get(lastIndex).getMatch());

            // Reflection
            centroid.copyTo(xr);
            xr.mutableAdd(this.reflection, dir);
            double fxr = f.f(xr);
            double fBest = simplex.get(0).getProbability();
            if (fBest <= fxr && fxr < simplex.get(lastIndex - 1).getProbability()) {
                NelderMead.insertIntoSimplex(simplex, xr, fxr);
                continue;
            }

            // Expansion: the reflected point beat the best vertex, so try going further.
            if (fxr < fBest) {
                centroid.copyTo(xec);
                xec.mutableAdd(this.expansion, dir);
                double fxe = f.f(xec);
                if (fxe < fxr) {
                    NelderMead.insertIntoSimplex(simplex, xec, fxe);
                } else {
                    NelderMead.insertIntoSimplex(simplex, xr, fxr);
                }
                continue;
            }

            // Contraction: the contraction coefficient is stored negative, so this point lies
            // between the centroid and the worst vertex.
            centroid.copyTo(xec);
            xec.mutableAdd(this.contraction, dir);
            double fxc = f.f(xec);
            if (fxc < simplex.get(lastIndex).getProbability()) {
                NelderMead.insertIntoSimplex(simplex, xec, fxc);
                continue;
            }

            // Nothing improved on the worst vertex: shrink the whole simplex toward the best one.
            this.shrinkTowardsBest(simplex, f);
        }
        return simplex.get(0).getMatch();
    }

    /**
     * Appends random vertices until the simplex has {@code init.length() + 1} points. Each
     * coordinate is {@code init[i] * gaussian}, or a plain gaussian when {@code init[i] == 0}.
     */
    private static void addRandomVertices(List<ProbailityMatch<Vec>> simplex, Vec init, Function f, Random rand) {
        int N = init.length();
        while (simplex.size() < N + 1) {
            DenseVector vertex = new DenseVector(N);
            for (int i = 0; i < N; ++i) {
                double v = init.get(i);
                vertex.set(i, v != 0.0 ? v * rand.nextGaussian() : rand.nextGaussian());
            }
            simplex.add(new ProbailityMatch<Vec>(f.f(vertex), vertex));
        }
    }

    /**
     * Moves every non-best vertex toward the best one and re-evaluates it. Vertices are
     * updated in place. The simplex is re-sorted afterwards.
     */
    private void shrinkTowardsBest(List<ProbailityMatch<Vec>> simplex, Function f) {
        Vec xBest = simplex.get(0).getMatch();
        for (int i = 1; i < simplex.size(); ++i) {
            ProbailityMatch<Vec> pm = simplex.get(i);
            Vec xi = pm.getMatch();
            xi.mutableSubtract(xBest);
            xi.mutableMultiply(this.shrink);
            xi.mutableAdd(xBest);
            pm.setProbability(f.f(xi));
        }
        Collections.sort(simplex);
    }

    /**
     * Replaces the worst (last) vertex with a copy of {@code x} and re-inserts it in sorted
     * position. The removed entry's vector is reused as storage, so {@code x} stays free to be
     * reused as a scratch buffer by the caller.
     *
     * @param simplex vertices sorted in ascending order of function value
     * @param x       the new vertex (copied, not stored)
     * @param fx      the function value at {@code x}
     */
    private static void insertIntoSimplex(List<ProbailityMatch<Vec>> simplex, Vec x, double fx) {
        ProbailityMatch<Vec> pm = simplex.remove(simplex.size() - 1);
        pm.setProbability(fx);
        x.copyTo(pm.getMatch());
        int sortInto = Collections.binarySearch(simplex, pm);
        if (sortInto < 0) {
            // Not found: binarySearch returns -(insertionPoint) - 1.
            sortInto = -sortInto - 1;
        }
        simplex.add(sortInto, pm);
    }
}

