/*
 * Decompiled with CFR 0.152.
 */
package jsat.distributions.kernels;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import jsat.distributions.kernels.KernelTrick;
import jsat.linear.DenseMatrix;
import jsat.linear.DenseVector;
import jsat.linear.Matrix;
import jsat.linear.SubMatrix;
import jsat.linear.Vec;
import jsat.math.FastMath;
import jsat.math.FunctionBase;
import jsat.math.optimization.GoldenSearch;
import jsat.utils.DoubleList;
import jsat.utils.ListUtils;
import jsat.utils.random.RandomUtil;

/**
 * A point in the feature space induced by a kernel, stored as a weighted sum
 * of basis vectors: {@code sum_i alpha[i] * phi(vecs[i])}.
 *
 * <p>Vectors are accumulated with {@link #mutableAdd(double, Vec, List)}. How
 * the basis set grows (or is kept bounded) is governed by the configured
 * {@link BudgetStrategy} together with {@link #getMaxBudget() the budget}.
 *
 * <p>The squared norm of the point is cached and lazily recomputed after any
 * addition.
 */
public class KernelPoint {
    /** Kernel used for every inner-product evaluation. */
    protected KernelTrick k;
    /** Residual threshold used by {@link BudgetStrategy#PROJECTION} to decide whether a new basis vector is needed. */
    private double errorTolerance;
    /** The basis vectors of the expansion. */
    protected List<Vec> vecs;
    /**
     * Flat kernel acceleration cache; {@code null} when the kernel reports it
     * does not support acceleration. Each added vector's query info is
     * appended to the end, so per-vector entries are contiguous.
     */
    protected List<Double> kernelAccel;
    /** View of the top-left corner of {@link #KExpanded}: kernel matrix of the current basis (PROJECTION only). */
    protected Matrix K;
    /** View of the top-left corner of {@link #InvKExpanded}: incrementally maintained inverse of {@link #K}. */
    protected Matrix InvK;
    /** Over-allocated backing storage for {@link #K}; doubled when full. */
    protected Matrix KExpanded;
    /** Over-allocated backing storage for {@link #InvK}; doubled when full. */
    protected Matrix InvKExpanded;
    /** Coefficient of each basis vector, index-aligned with {@link #vecs}. */
    protected DoubleList alpha;
    protected BudgetStrategy budgetStrategy = BudgetStrategy.PROJECTION;
    protected int maxBudget = Integer.MAX_VALUE;
    /** Cached squared norm; only meaningful while {@link #normGood} is true. */
    private double sqrdNorm = 0.0;
    private boolean normGood = true;

    /**
     * Creates an empty kernel point using the PROJECTION strategy and an
     * unbounded budget.
     *
     * @param k the kernel to use
     * @param errorTolerance the projection tolerance, in [0, 1]
     * @throws IllegalArgumentException if the tolerance is NaN or outside [0, 1]
     */
    public KernelPoint(KernelTrick k, double errorTolerance) {
        this.k = k;
        this.setErrorTolerance(errorTolerance);
        this.setBudgetStrategy(BudgetStrategy.PROJECTION);
        this.setMaxBudget(Integer.MAX_VALUE);
        if (k.supportsAcceleration()) {
            this.kernelAccel = new DoubleList(16);
        }
        this.alpha = new DoubleList(16);
        this.vecs = new ArrayList<Vec>(16);
    }

    /**
     * Copy constructor. The kernel, basis vectors, coefficients, cache and
     * matrices are all copied so the new point does not share mutable state
     * with {@code toCopy}.
     *
     * @param toCopy the point to copy
     */
    public KernelPoint(KernelPoint toCopy) {
        this.k = toCopy.k.clone();
        this.errorTolerance = toCopy.errorTolerance;
        // The strategy must travel with the copy: without it a copied
        // MERGE_RBF/STOP/RANDOM point would silently fall back to PROJECTION.
        this.budgetStrategy = toCopy.budgetStrategy;
        if (toCopy.vecs != null) {
            this.vecs = new ArrayList<Vec>(toCopy.vecs.size());
            for (Vec v : toCopy.vecs) {
                this.vecs.add(v.clone());
            }
            if (toCopy.kernelAccel != null) {
                this.kernelAccel = new DoubleList(toCopy.kernelAccel);
            }
            this.alpha = new DoubleList(toCopy.alpha);
        }
        if (toCopy.KExpanded != null) {
            this.KExpanded = toCopy.KExpanded.clone();
            this.InvKExpanded = toCopy.InvKExpanded.clone();
            // Re-create the views over the cloned storage with the same extent.
            this.K = new SubMatrix(this.KExpanded, 0, 0, toCopy.K.rows(), toCopy.K.cols());
            this.InvK = new SubMatrix(this.InvKExpanded, 0, 0, toCopy.InvK.rows(), toCopy.InvK.cols());
        }
        this.maxBudget = toCopy.maxBudget;
        this.sqrdNorm = toCopy.sqrdNorm;
        this.normGood = toCopy.normGood;
    }

    /**
     * Sets the maximum number of basis vectors.
     *
     * @param maxBudget the budget, must be at least 1
     * @throws IllegalArgumentException if {@code maxBudget < 1}
     */
    public void setMaxBudget(int maxBudget) {
        if (maxBudget < 1) {
            throw new IllegalArgumentException("Budget must be positive, not " + maxBudget);
        }
        this.maxBudget = maxBudget;
    }

    /**
     * @return the maximum number of basis vectors
     */
    public int getMaxBudget() {
        return this.maxBudget;
    }

    /**
     * Sets the strategy used to keep the basis within budget. May only be
     * called while the basis is empty.
     *
     * @param budgetStrategy the strategy to use
     * @throws RuntimeException if basis vectors have already been added
     */
    public void setBudgetStrategy(BudgetStrategy budgetStrategy) {
        if (this.getBasisSize() > 0) {
            throw new RuntimeException("KerenlPoint already started, budget may not be changed");
        }
        this.budgetStrategy = budgetStrategy;
    }

    /**
     * @return the strategy used to keep the basis within budget
     */
    public BudgetStrategy getBudgetStrategy() {
        return this.budgetStrategy;
    }

    /**
     * Sets the tolerance compared against the projection residual in the
     * PROJECTION strategy.
     *
     * @param errorTolerance the tolerance, in [0, 1]
     * @throws IllegalArgumentException if the value is NaN or outside [0, 1]
     */
    public void setErrorTolerance(double errorTolerance) {
        if (Double.isNaN(errorTolerance) || errorTolerance < 0.0 || errorTolerance > 1.0) {
            throw new IllegalArgumentException("Error tolerance must be in [0, 1], not " + errorTolerance);
        }
        this.errorTolerance = errorTolerance;
    }

    /**
     * @return the projection tolerance
     */
    public double getErrorTolerance() {
        return this.errorTolerance;
    }

    /**
     * Returns the squared norm {@code sum_i sum_j alpha[i] alpha[j] k(v_i, v_j)}
     * of this point, recomputing it only if the cached value is stale.
     *
     * @return the squared norm of this point in the kernel space
     */
    public double getSqrdNorm() {
        if (!this.normGood) {
            double sum = 0.0;
            int n = this.alpha.size();
            for (int i = 0; i < n; ++i) {
                double a_i = this.alpha.getD(i);
                sum += a_i * a_i * this.basisKernel(i, i);
                // The kernel is symmetric, so visit the upper triangle only and double it.
                for (int j = i + 1; j < n; ++j) {
                    sum += 2.0 * a_i * this.alpha.getD(j) * this.basisKernel(i, j);
                }
            }
            this.sqrdNorm = sum;
            this.normGood = true;
        }
        return this.sqrdNorm;
    }

    /**
     * Kernel value between basis vectors {@code i} and {@code j}, read from
     * the stored kernel matrix when one exists and evaluated otherwise.
     */
    private double basisKernel(int i, int j) {
        if (this.K != null) {
            return this.K.get(i, j);
        }
        return this.k.eval(i, j, this.vecs, this.kernelAccel);
    }

    /**
     * Computes the kernel-space dot product between this point and {@code x}.
     *
     * @param x the vector to take the dot product with
     * @return the dot product
     */
    public double dot(Vec x) {
        return this.dot(x, this.k.getQueryInfo(x));
    }

    /**
     * Computes the kernel-space dot product between this point and {@code x}.
     *
     * @param x the vector to take the dot product with
     * @param qi the kernel query info for {@code x}
     * @return the dot product, or 0 if this point has no basis vectors
     */
    public double dot(Vec x, List<Double> qi) {
        if (this.getBasisSize() == 0) {
            return 0.0;
        }
        return this.k.evalSum(this.vecs, this.kernelAccel, this.alpha.getBackingArray(), x, qi, 0, this.alpha.size());
    }

    /**
     * Computes the kernel-space dot product between this point and another
     * kernel point.
     *
     * @param x the other point
     * @return the dot product, or 0 if either point has no basis vectors
     */
    public double dot(KernelPoint x) {
        if (this.getBasisSize() == 0 || x.getBasisSize() == 0) {
            return 0.0;
        }
        // Both bases are exposed as one list so the index-based kernel call can
        // be used; x's vectors start at offset `shift`.
        int shift = this.alpha.size();
        List<Vec> mergedVecs = ListUtils.mergedView(this.vecs, x.vecs);
        List<Double> mergedCache = this.kernelAccel == null || x.kernelAccel == null ? null : ListUtils.mergedView(this.kernelAccel, x.kernelAccel);
        double dot = 0.0;
        for (int i = 0; i < this.alpha.size(); ++i) {
            for (int j = 0; j < x.alpha.size(); ++j) {
                dot += this.alpha.get(i) * x.alpha.get(j) * this.k.eval(i, j + shift, mergedVecs, mergedCache);
            }
        }
        return dot;
    }

    /**
     * Computes the kernel-space Euclidean distance between this point and
     * {@code x}.
     *
     * @param x the vector to measure the distance to
     * @return the distance
     */
    public double dist(Vec x) {
        return this.dist(x, this.k.getQueryInfo(x));
    }

    /**
     * Computes the kernel-space Euclidean distance between this point and
     * {@code x}.
     *
     * @param x the vector to measure the distance to
     * @param qi the kernel query info for {@code x}
     * @return the distance
     */
    public double dist(Vec x, List<Double> qi) {
        double k_xx = this.k.eval(0, 0, Arrays.asList(x), qi);
        double d = k_xx + this.getSqrdNorm() - 2.0 * this.dot(x, qi);
        // Round-off can push the squared distance slightly below zero; clamp so
        // the result is 0 rather than NaN, matching dist(KernelPoint).
        return Math.sqrt(Math.max(0.0, d));
    }

    /**
     * Computes the kernel-space Euclidean distance between this point and
     * another kernel point.
     *
     * @param x the other point
     * @return the distance; exactly 0 when {@code x} is this same object
     */
    public double dist(KernelPoint x) {
        if (this == x) {
            return 0.0;
        }
        double d = this.getSqrdNorm() + x.getSqrdNorm() - 2.0 * this.dot(x);
        // Clamp: 2*dot may exceed the norm sum by round-off.
        return Math.sqrt(Math.max(0.0, d));
    }

    /**
     * Scales this point by {@code c} in place.
     *
     * @param c the finite multiplier
     * @throws IllegalArgumentException if {@code c} is NaN or infinite
     */
    public void mutableMultiply(double c) {
        if (Double.isNaN(c) || Double.isInfinite(c)) {
            throw new IllegalArgumentException("multiplier must be a real value, not " + c);
        }
        if (this.getBasisSize() == 0) {
            return;
        }
        // Scaling every coefficient by c scales the squared norm by c^2, so a
        // valid cache stays valid (a stale one is recomputed later anyway).
        this.sqrdNorm *= c * c;
        this.alpha.getVecView().mutableMultiply(c);
    }

    /**
     * Adds {@code x_t} to this point with coefficient 1.
     *
     * @param x_t the vector to add
     */
    public void mutableAdd(Vec x_t) {
        this.mutableAdd(1.0, x_t);
    }

    /**
     * Adds {@code c * x_t} to this point.
     *
     * @param c the coefficient
     * @param x_t the vector to add
     */
    public void mutableAdd(double c, Vec x_t) {
        this.mutableAdd(c, x_t, this.k.getQueryInfo(x_t));
    }

    /**
     * Adds {@code c * x_t} to this point, applying the configured budget
     * strategy. A coefficient of exactly 0 is a no-op. The vector is stored
     * by reference, not copied.
     *
     * @param c the coefficient
     * @param x_t the vector to add
     * @param qi the kernel query info for {@code x_t}
     */
    public void mutableAdd(double c, Vec x_t, List<Double> qi) {
        if (c == 0.0) {
            return;
        }
        this.normGood = false;
        if (this.budgetStrategy == BudgetStrategy.PROJECTION) {
            this.addByProjection(c, x_t, qi);
        } else if (this.budgetStrategy == BudgetStrategy.MERGE_RBF) {
            this.addPoint(x_t, qi, c);
            if (this.vecs.size() > this.maxBudget) {
                this.mergeSmallestCoefficient();
            }
        } else if (this.budgetStrategy == BudgetStrategy.STOP) {
            // Once the budget is reached further additions are dropped.
            if (this.getBasisSize() < this.maxBudget) {
                this.addPoint(x_t, qi, c);
            }
        } else if (this.budgetStrategy == BudgetStrategy.RANDOM) {
            // Evict a uniformly random basis vector to make room.
            if (this.getBasisSize() >= this.maxBudget) {
                Random rand = RandomUtil.getRandom();
                this.removeIndex(rand.nextInt(this.vecs.size()));
            }
            this.addPoint(x_t, qi, c);
        } else {
            throw new RuntimeException("BUG: report me!");
        }
    }

    /**
     * PROJECTION strategy: projects the new vector onto the span of the
     * current basis. If the residual exceeds the tolerance and there is
     * budget left, the vector becomes a new basis vector and K / InvK are
     * grown by one row and column; otherwise its projection coefficients are
     * folded into the existing alphas.
     */
    private void addByProjection(double y_t, Vec x_t, List<Double> qi) {
        double k_tt = this.k.eval(0, 0, Arrays.asList(x_t), qi);
        if (this.K == null) {
            // First point: start 1x1 views over 16x16 backing storage.
            this.KExpanded = new DenseMatrix(16, 16);
            this.K = new SubMatrix(this.KExpanded, 0, 0, 1, 1);
            this.K.set(0, 0, k_tt);
            this.InvKExpanded = new DenseMatrix(16, 16);
            this.InvK = new SubMatrix(this.InvKExpanded, 0, 0, 1, 1);
            this.InvK.set(0, 0, 1.0 / k_tt);
            this.addPoint(x_t, qi, y_t);
            return;
        }
        // Kernel values between the new vector and every basis vector.
        DenseVector kxt = new DenseVector(this.K.rows());
        for (int i = 0; i < kxt.length(); ++i) {
            kxt.set(i, this.k.eval(i, x_t, qi, this.vecs, this.kernelAccel));
        }
        // Projection coefficients and the squared residual of the projection.
        Vec alphas_t = this.InvK.multiply(kxt);
        double delta_t = k_tt - alphas_t.dot(kxt);
        int size = this.K.rows();
        if (delta_t > this.errorTolerance && size < this.maxBudget) {
            this.vecs.add(x_t);
            if (this.kernelAccel != null) {
                this.kernelAccel.addAll(qi);
            }
            if (size == this.KExpanded.rows()) {
                this.KExpanded.changeSize(size * 2, size * 2);
                this.InvKExpanded.changeSize(size * 2, size * 2);
            }
            // Block-inverse update: the old block gains a rank-one term, then
            // the new row/column and corner are filled in below.
            Matrix.OuterProductUpdate(this.InvK, alphas_t, alphas_t, 1.0 / delta_t);
            this.K = new SubMatrix(this.KExpanded, 0, 0, size + 1, size + 1);
            this.InvK = new SubMatrix(this.InvKExpanded, 0, 0, size + 1, size + 1);
            for (int i = 0; i < size; ++i) {
                this.K.set(size, i, kxt.get(i));
                this.K.set(i, size, kxt.get(i));
                this.InvK.set(size, i, -alphas_t.get(i) / delta_t);
                this.InvK.set(i, size, -alphas_t.get(i) / delta_t);
            }
            this.K.set(size, size, k_tt);
            this.InvK.set(size, size, 1.0 / delta_t);
            this.alpha.add(y_t);
        } else {
            // Well approximated (or out of budget): distribute onto the basis.
            Vec alphaVec = this.alpha.getVecView();
            alphaVec.mutableAdd(y_t, alphas_t);
        }
    }

    /**
     * MERGE_RBF strategy: picks the basis vector {@code m} with the smallest
     * absolute coefficient, finds the partner {@code n} whose merge with
     * {@code m} has the smallest loss, and replaces both with the single
     * vector {@code z = h*m + (1-h)*n}.
     */
    private void mergeSmallestCoefficient() {
        int m = 0;
        // Keep the SIGNED coefficient: it is used as a_m in the merge below,
        // so taking the absolute value here would flip a negative alpha[0].
        double alpha_m = this.alpha.getD(m);
        for (int i = 1; i < this.alpha.size(); ++i) {
            if (Math.abs(this.alpha.getD(i)) < Math.abs(alpha_m)) {
                alpha_m = this.alpha.getD(i);
                m = i;
            }
        }
        double minLoss = Double.POSITIVE_INFINITY;
        int n = -1;
        double n_h = 0.0;
        double n_alpha_z = 0.0;
        double tol = 0.001;
        // If every candidate was rejected for near-cancelling coefficients,
        // retry with a tighter tolerance.
        while (n == -1) {
            for (int i = 0; i < this.alpha.size(); ++i) {
                if (i == m) {
                    continue;
                }
                double a_m = alpha_m;
                double a_n = this.alpha.getD(i);
                double normalize = a_m + a_n;
                if (Math.abs(normalize) < tol) {
                    // Coefficients nearly cancel; skip this candidate.
                    continue;
                }
                double k_mn = this.k.eval(i, m, this.vecs, this.kernelAccel);
                double h = KernelPoint.getH(k_mn, a_m / normalize, a_n / normalize);
                // For an RBF kernel k(m, z) with z = h*m + (1-h)*n equals
                // k(m, n)^((1-h)^2), and likewise k(n, z) = k(m, n)^(h^2), so z
                // does not have to be formed to score the candidate.
                double k_mz = Math.pow(k_mn, (1.0 - h) * (1.0 - h));
                double k_nz = Math.pow(k_mn, h * h);
                double alpha_z = a_m * k_mz + a_n * k_nz;
                // NOTE(review): the loss is written as if k(x, x) == 1, which
                // holds for RBF kernels - confirm before using other kernels.
                double loss = a_m * a_m + a_n * a_n + 2.0 * k_mn * a_m * a_n - alpha_z * alpha_z;
                if (loss < minLoss) {
                    minLoss = loss;
                    n = i;
                    n_h = h;
                    n_alpha_z = alpha_z;
                }
            }
            tol /= 10.0;
        }
        Vec n_z = this.vecs.get(m).multiply(n_h);
        n_z.mutableAdd(1.0 - n_h, this.vecs.get(n));
        List<Double> nz_qi = this.k.getQueryInfo(n_z);
        this.finalMergeStep(m, n, n_z, nz_qi, n_alpha_z, true);
    }

    /**
     * Appends a basis vector, its cache entries (if caching is active) and
     * its coefficient without any budget handling.
     */
    private void addPoint(Vec x_t, List<Double> qi, double y_t) {
        this.vecs.add(x_t);
        if (this.kernelAccel != null) {
            this.kernelAccel.addAll(qi);
        }
        this.alpha.add(y_t);
    }

    /**
     * Completes a merge by removing entries {@code m} and {@code n} and
     * appending the merged entry.
     *
     * @param m index of the first merged basis vector
     * @param n index of the second merged basis vector
     * @param n_z the merged vector that replaces both
     * @param nz_qi the kernel query info for {@code n_z}
     * @param n_alpha_z the coefficient of the merged vector
     * @param alterVecs {@code true} to also update the vector list and the
     * acceleration cache, {@code false} to update only the coefficients
     */
    protected void finalMergeStep(int m, int n, Vec n_z, List<Double> nz_qi, double n_alpha_z, boolean alterVecs) {
        int smallIndx = Math.min(m, n);
        int largeIndx = Math.max(m, n);
        // Remove the larger index first so the smaller one stays valid.
        this.alpha.remove(largeIndx);
        this.alpha.remove(smallIndx);
        if (alterVecs) {
            if (this.kernelAccel != null && !this.vecs.isEmpty()) {
                // Must be computed before vecs shrinks.
                int perVec = this.kernelAccel.size() / this.vecs.size();
                KernelPoint.removeAccelEntries(this.kernelAccel, largeIndx, perVec);
                KernelPoint.removeAccelEntries(this.kernelAccel, smallIndx, perVec);
            }
            this.vecs.remove(largeIndx);
            this.vecs.remove(smallIndx);
            this.vecs.add(n_z);
            // The cache is null when the kernel has no acceleration support.
            if (this.kernelAccel != null) {
                this.kernelAccel.addAll(nz_qi);
            }
        }
        this.alpha.add(n_alpha_z);
    }

    /**
     * Removes the contiguous run of {@code perVec} cache values belonging to
     * the vector at {@code vecIndex}.
     */
    private static void removeAccelEntries(List<Double> accel, int vecIndex, int perVec) {
        int start = vecIndex * perVec;
        for (int i = 0; i < perVec; ++i) {
            // Later values shift down after each removal, so the position is fixed.
            accel.remove(start);
        }
    }

    /**
     * Finds the interpolation weight {@code h} for merging two basis vectors
     * by minimising {@code -(a_m * k_mn^((1-h)^2) + a_n * k_mn^(h^2))} with a
     * golden-section search over a sub-interval of [0, 1] chosen from the
     * signs and relative sizes of the coefficients.
     *
     * @param k_mn the kernel value between the two vectors
     * @param a_m the (normalised) coefficient of the first vector
     * @param a_n the (normalised) coefficient of the second vector
     * @return 0.5 when the coefficients are equal, otherwise the search result
     */
    protected static double getH(final double k_mn, final double a_m, final double a_n) {
        if (a_m == a_n) {
            return 0.5;
        }
        FunctionBase f = new FunctionBase(){
            private static final long serialVersionUID = -6891301465754898634L;

            @Override
            public double f(Vec x) {
                double h = x.get(0);
                // Negated so that minimising f maximises the merged coefficient.
                return -(a_m * FastMath.pow(k_mn, (1.0 - h) * (1.0 - h)) + a_n * FastMath.pow(k_mn, h * h));
            }
        };
        // NOTE(review): the GoldenSearch arguments look like (tolerance, max
        // iterations, lower bound, upper bound, variable index, function,
        // initial values) - confirm against GoldenSearch.minimize.
        if (Math.signum(a_m) != Math.signum(a_n)) {
            if (a_m < 0.0) {
                return GoldenSearch.minimize(0.001, 100, 0.0, 0.2, 0, f, 0.0);
            }
            if (a_n < 0.0) {
                return GoldenSearch.minimize(0.001, 100, 0.8, 1.0, 0, f, 0.0);
            }
        }
        if (a_m > a_n) {
            return GoldenSearch.minimize(0.001, 100, 0.5, 1.0, 0, f, 0.0);
        }
        return GoldenSearch.minimize(0.001, 100, 0.0, 0.5, 0, f, 0.0);
    }

    /**
     * Removes the basis vector at the given index together with its
     * coefficient and its acceleration cache entries.
     *
     * @param toRemove the index of the basis vector to remove
     */
    protected void removeIndex(int toRemove) {
        if (this.kernelAccel != null) {
            int num = this.kernelAccel.size() / this.vecs.size();
            // The vector's values sit at [toRemove*num, (toRemove+1)*num).
            KernelPoint.removeAccelEntries(this.kernelAccel, toRemove, num);
        }
        this.alpha.remove(toRemove);
        this.vecs.remove(toRemove);
    }

    /**
     * @return the number of basis vectors currently stored (0 if none)
     */
    public int getBasisSize() {
        if (this.vecs == null) {
            return 0;
        }
        return this.vecs.size();
    }

    /**
     * @return a read-only view of the basis vectors
     */
    public List<Vec> getRawBasisVecs() {
        return Collections.unmodifiableList(this.vecs);
    }

    /**
     * @return a copy of this point made with the copy constructor
     */
    @Override
    public KernelPoint clone() {
        return new KernelPoint(this);
    }

    /**
     * Strategies for handling additions with respect to the basis budget.
     */
    public static enum BudgetStrategy {
        /** Project the new vector onto the existing basis when the residual is within tolerance or the budget is full. */
        PROJECTION,
        /** When over budget, merge the smallest-coefficient vector with its lowest-loss partner. */
        MERGE_RBF,
        /** Ignore additions once the budget has been reached. */
        STOP,
        /** Evict a randomly chosen basis vector when the budget has been reached. */
        RANDOM;

    }
}

