/*
 * Decompiled with CFR 0.152.
 */
package hex.genmodel.algos.glrm;

import hex.genmodel.utils.ArrayUtils;

/**
 * Loss functions used by GLRM, one enum constant per loss.
 *
 * <p>Each constant declares which column kinds it applies to ({@link #isForNumeric()},
 * {@link #isForCategorical()}, {@link #isForBinary()}) and overrides the operations that make
 * sense for it:
 * <ul>
 *   <li>scalar losses (numeric / binary columns): {@link #loss}, {@link #lgrad}, {@link #impute};</li>
 *   <li>multidimensional losses (categorical columns): {@link #mloss}, {@link #mlgrad},
 *       {@link #mimpute}.</li>
 * </ul>
 * Any operation a constant does not override throws {@link UnsupportedOperationException}.
 *
 * <p>Throughout, {@code u} is the model-side value (what the impute methods map back into the data
 * domain) and {@code a} is the observed value, written {@code L(u,a)} in the formulas below.
 */
public enum GlrmLoss {
    /** Squared error {@code L(u,a) = (u - a)^2}; numeric columns only. */
    Quadratic {

        @Override
        public boolean isForNumeric() {
            return true;
        }

        @Override
        public boolean isForCategorical() {
            return false;
        }

        @Override
        public boolean isForBinary() {
            return false;
        }

        @Override
        public double loss(double u, double a) {
            return (u - a) * (u - a);
        }

        @Override
        public double lgrad(double u, double a) {
            return 2.0 * (u - a);
        }

        /** The model value is already on the data scale. */
        @Override
        public double impute(double u) {
            return u;
        }
    },

    /** Absolute error {@code L(u,a) = |u - a|}; numeric columns only. */
    Absolute {

        @Override
        public boolean isForNumeric() {
            return true;
        }

        @Override
        public boolean isForCategorical() {
            return false;
        }

        @Override
        public boolean isForBinary() {
            return false;
        }

        @Override
        public double loss(double u, double a) {
            return Math.abs(u - a);
        }

        /** Subgradient {@code sign(u - a)}; 0 is used at the kink {@code u == a}. */
        @Override
        public double lgrad(double u, double a) {
            return Math.signum(u - a);
        }

        @Override
        public double impute(double u) {
            return u;
        }
    },

    /**
     * Huber loss with unit threshold, where {@code x = u - a}: {@code 0.5*x^2} for
     * {@code |x| <= 1}, and {@code |x| - 0.5} otherwise. Numeric columns only.
     */
    Huber {

        @Override
        public boolean isForNumeric() {
            return true;
        }

        @Override
        public boolean isForCategorical() {
            return false;
        }

        @Override
        public boolean isForBinary() {
            return false;
        }

        @Override
        public double loss(double u, double a) {
            double x = u - a;
            return x > 1.0 ? x - 0.5 : (x < -1.0 ? -x - 0.5 : 0.5 * x * x);
        }

        /** Gradient: {@code x} clipped to {@code [-1, 1]}. */
        @Override
        public double lgrad(double u, double a) {
            double x = u - a;
            return x > 1.0 ? 1.0 : (x < -1.0 ? -1.0 : x);
        }

        @Override
        public double impute(double u) {
            return u;
        }
    },

    /**
     * Poisson loss {@code L(u,a) = exp(u) - a*u + a*log(a) - a}, with the {@code a}-dependent
     * terms taken as 0 when {@code a == 0}. Requires {@code a >= 0}. Numeric columns only.
     */
    Poisson {

        @Override
        public boolean isForNumeric() {
            return true;
        }

        @Override
        public boolean isForCategorical() {
            return false;
        }

        @Override
        public boolean isForBinary() {
            return false;
        }

        @Override
        public double loss(double u, double a) {
            assert (a >= 0.0) : "Poisson loss L(u,a) requires variable a >= 0";
            // a*log(a) is undefined at a == 0; its limit is 0, so the whole a-part vanishes there.
            return Math.exp(u) + (a == 0.0 ? 0.0 : -a * u + a * Math.log(a) - a);
        }

        @Override
        public double lgrad(double u, double a) {
            assert (a >= 0.0) : "Poisson loss L(u,a) requires variable a >= 0";
            return Math.exp(u) - a;
        }

        /** Maps the model value back to the data scale via {@code exp(u)}. */
        @Override
        public double impute(double u) {
            return Math.exp(u);
        }
    },

    /**
     * Periodic loss {@code L(u,a) = 1 - cos(2*pi*(u - a) / period)}; numeric columns only.
     *
     * <p>The period must be supplied via {@link #setParameters(int)} before use. Until then the
     * frequency is 0, so both the loss and its gradient are identically 0. Enum constants are
     * singletons, so the period stored here is shared by every user of {@code Periodic} in the JVM.
     */
    Periodic {
        /** Angular frequency {@code 2*pi/period}; 0 until {@link #setParameters(int)} is called. */
        private double f;
        /** The period last passed to {@link #setParameters(int)}; 0 until then. */
        private int period;

        @Override
        public boolean isForNumeric() {
            return true;
        }

        @Override
        public boolean isForCategorical() {
            return false;
        }

        @Override
        public boolean isForBinary() {
            return false;
        }

        @Override
        public double loss(double u, double a) {
            return 1.0 - Math.cos((u - a) * this.f);
        }

        @Override
        public double lgrad(double u, double a) {
            return this.f * Math.sin((u - a) * this.f);
        }

        @Override
        public double impute(double u) {
            return u;
        }

        /**
         * Sets the period of this loss.
         *
         * @param period the period; must be positive
         * @throws IllegalArgumentException if {@code period <= 0}; such a period would make the
         *         frequency infinite or negative and silently produce NaN losses
         */
        @Override
        public void setParameters(int period) {
            if (period <= 0) {
                throw new IllegalArgumentException(
                        "Periodic loss requires a positive period, got " + period);
            }
            this.period = period;
            this.f = Math.PI * 2 / (double) period;
        }

        @Override
        public String toString() {
            return "Periodic(" + this.period + ")";
        }
    },

    /**
     * Logistic loss {@code L(u,a) = log(1 + exp((1 - 2a) * u))} for {@code a} in {0, 1}; binary
     * columns only. It is small when {@code u > 0} for {@code a == 1} and {@code u < 0} for
     * {@code a == 0}.
     */
    Logistic {

        @Override
        public boolean isForNumeric() {
            return false;
        }

        @Override
        public boolean isForCategorical() {
            return false;
        }

        @Override
        public boolean isForBinary() {
            return true;
        }

        @Override
        public double loss(double u, double a) {
            assert (a == 0.0 || a == 1.0) : "Logistic loss should be applied to binary features only";
            double x = (1.0 - 2.0 * a) * u;
            // Numerically stable softplus: log(1 + e^x) overflows to +Infinity once e^x does
            // (x > ~709), whereas x + log1p(e^-x) stays finite and is exact in that range.
            return x > 0.0 ? x + Math.log1p(Math.exp(-x)) : Math.log1p(Math.exp(x));
        }

        /**
         * Gradient with respect to {@code u}, where {@code s = 1 - 2a}:
         * {@code d/du log(1 + e^(s*u)) = s * e^(s*u) / (1 + e^(s*u)) = s / (1 + e^(-s*u))}.
         */
        @Override
        public double lgrad(double u, double a) {
            double s = 1.0 - 2.0 * a;
            // Note the minus sign in the exponent: s / (1 + e^(s*u)) would be the gradient of
            // the loss at -u, not at u.
            return s / (1.0 + Math.exp(-s * u));
        }

        /** Thresholds at 0: returns 1 when {@code u > 0}, else 0. */
        @Override
        public double impute(double u) {
            return u > 0.0 ? 1.0 : 0.0;
        }
    },

    /**
     * Hinge loss {@code L(u,a) = max(1 + (1 - 2a) * u, 0)} for {@code a} in {0, 1}. This is
     * {@code max(1 - u, 0)} when {@code a == 1} and {@code max(1 + u, 0)} when {@code a == 0}.
     * Binary columns only.
     */
    Hinge {

        @Override
        public boolean isForNumeric() {
            return false;
        }

        @Override
        public boolean isForCategorical() {
            return false;
        }

        @Override
        public boolean isForBinary() {
            return true;
        }

        @Override
        public double loss(double u, double a) {
            assert (a == 0.0 || a == 1.0) : "Hinge loss should be applied to binary variables only";
            return Math.max(1.0 + (1.0 - 2.0 * a) * u, 0.0);
        }

        /** Subgradient: {@code s = 1 - 2a} while the hinge is active, 0 otherwise. */
        @Override
        public double lgrad(double u, double a) {
            double s = 1.0 - 2.0 * a;
            return 1.0 + s * u > 0.0 ? s : 0.0;
        }

        /** Thresholds at 0: returns 1 when {@code u > 0}, else 0. */
        @Override
        public double impute(double u) {
            return u > 0.0 ? 1.0 : 0.0;
        }
    },

    /**
     * One-vs-all hinge loss for categorical columns. For observed level {@code a}:
     * {@code max(1 - u[a], 0)} plus {@code max(1 + u[i], 0)} summed over every other level
     * {@code i}.
     */
    Categorical {

        @Override
        public boolean isForNumeric() {
            return false;
        }

        @Override
        public boolean isForCategorical() {
            return true;
        }

        @Override
        public boolean isForBinary() {
            return false;
        }

        @Override
        public double mloss(double[] u, int a) {
            return this.mloss(u, a, u.length);
        }

        @Override
        public double mloss(double[] u, int a, int u_len) {
            if (a < 0 || a >= u_len) {
                throw new IndexOutOfBoundsException("a must be between 0 and " + (u_len - 1));
            }
            double sum = 0.0;
            for (int i = 0; i < u_len; ++i) {
                sum += Math.max(1.0 + u[i], 0.0);
            }
            // The loop charged level a as if it were a "negative" level; swap in its "positive" term.
            return sum + (Math.max(1.0 - u[a], 0.0) - Math.max(1.0 + u[a], 0.0));
        }

        @Override
        public double[] mlgrad(double[] u, int a) {
            double[] grad = new double[u.length];
            return this.mlgrad(u, a, grad, u.length);
        }

        @Override
        public double[] mlgrad(double[] u, int a, double[] grad, int u_len) {
            if (a < 0 || a >= u_len) {
                throw new IndexOutOfBoundsException("a must be between 0 and " + (u_len - 1));
            }
            for (int i = 0; i < u_len; ++i) {
                grad[i] = 1.0 + u[i] > 0.0 ? 1.0 : 0.0;
            }
            // The observed level uses the opposite hinge, max(1 - u[a], 0).
            grad[a] = 1.0 - u[a] > 0.0 ? -1.0 : 0.0;
            return grad;
        }

        /** Returns the index of the largest entry of {@code u} (via {@code ArrayUtils.maxIndex}). */
        @Override
        public int mimpute(double[] u) {
            return ArrayUtils.maxIndex(u);
        }
    },

    /**
     * Ordinal loss for categorical columns with ordered levels {@code 0..u_len-1}. For observed
     * level {@code a}: {@code max(1 - u[i], 0)} for each {@code i < a}, plus a constant 1 for each
     * {@code a <= i < u_len - 1}. The last entry {@code u[u_len-1]} does not affect the loss.
     */
    Ordinal {

        @Override
        public boolean isForNumeric() {
            return false;
        }

        @Override
        public boolean isForCategorical() {
            return true;
        }

        @Override
        public boolean isForBinary() {
            return false;
        }

        @Override
        public double mloss(double[] u, int a) {
            return this.mloss(u, a, u.length);
        }

        @Override
        public double mloss(double[] u, int a, int u_len) {
            if (a < 0 || a >= u_len) {
                throw new IndexOutOfBoundsException("a must be between 0 and " + (u_len - 1));
            }
            double sum = 0.0;
            for (int i = 0; i < u_len - 1; ++i) {
                sum += a > i ? Math.max(1.0 - u[i], 0.0) : 1.0;
            }
            return sum;
        }

        @Override
        public double[] mlgrad(double[] u, int a) {
            double[] grad = new double[u.length];
            return this.mlgrad(u, a, grad, u.length);
        }

        @Override
        public double[] mlgrad(double[] u, int a, double[] grad, int u_len) {
            if (a < 0 || a >= u_len) {
                throw new IndexOutOfBoundsException("a must be between 0 and " + (u_len - 1));
            }
            for (int i = 0; i < u_len - 1; ++i) {
                grad[i] = a > i && 1.0 - u[i] > 0.0 ? -1.0 : 0.0;
            }
            // The loss does not depend on u[u_len-1]. Write its zero gradient explicitly so a
            // reused 'grad' buffer never carries a stale value there (u_len >= 1 given the check above).
            grad[u_len - 1] = 0.0;
            return grad;
        }

        /**
         * Returns the level {@code a} that minimizes {@code mloss(u, a)}, taking the smallest such
         * level on ties. Returns 0 for an empty {@code u}.
         */
        @Override
        public int mimpute(double[] u) {
            // mloss(u, 0) == u.length - 1. Moving from a-1 to a changes the loss by
            // max(1 - u[a-1], 0) - 1 == -min(1, u[a-1]), so scan with a running sum.
            double sum = (double) (u.length - 1);
            double bestLoss = sum;
            int bestA = 0;
            for (int a = 1; a < u.length; ++a) {
                sum -= Math.min(1.0, u[a - 1]);
                if (sum < bestLoss) {
                    bestLoss = sum;
                    bestA = a;
                }
            }
            return bestA;
        }
    };


    /** @return whether this loss can be applied to numeric columns */
    public abstract boolean isForNumeric();

    /** @return whether this loss can be applied to categorical columns */
    public abstract boolean isForCategorical();

    /** @return whether this loss can be applied to binary columns */
    public abstract boolean isForBinary();

    /**
     * Scalar loss {@code L(u,a)}.
     *
     * @param u model-side value
     * @param a observed value
     * @return the loss
     * @throws UnsupportedOperationException if this is not a scalar loss
     */
    public double loss(double u, double a) {
        throw unsupported("loss(double, double)");
    }

    /**
     * Derivative of {@link #loss(double, double)} with respect to {@code u}.
     *
     * @throws UnsupportedOperationException if this is not a scalar loss
     */
    public double lgrad(double u, double a) {
        throw unsupported("lgrad(double, double)");
    }

    /**
     * Maps a model-side value {@code u} back into the data domain.
     *
     * @throws UnsupportedOperationException if this is not a scalar loss
     */
    public double impute(double u) {
        throw unsupported("impute(double)");
    }

    /**
     * Multidimensional loss over all of {@code u} for observed level {@code a}.
     *
     * @throws UnsupportedOperationException if this is not a categorical loss
     */
    public double mloss(double[] u, int a) {
        throw unsupported("mloss(double[], int)");
    }

    /**
     * Multidimensional loss over the first {@code u_len} entries of {@code u} for observed level
     * {@code a}.
     *
     * @throws UnsupportedOperationException if this is not a categorical loss
     */
    public double mloss(double[] u, int a, int u_len) {
        throw unsupported("mloss(double[], int, int)");
    }

    /**
     * Gradient of {@link #mloss(double[], int)} with respect to {@code u}, in a newly allocated array.
     *
     * @throws UnsupportedOperationException if this is not a categorical loss
     */
    public double[] mlgrad(double[] u, int a) {
        throw unsupported("mlgrad(double[], int)");
    }

    /**
     * Gradient of {@link #mloss(double[], int, int)} with respect to the first {@code u_len}
     * entries of {@code u}.
     *
     * @param grad output buffer of length at least {@code u_len}; its first {@code u_len} entries
     *             are overwritten
     * @return {@code grad}
     * @throws UnsupportedOperationException if this is not a categorical loss
     */
    public double[] mlgrad(double[] u, int a, double[] grad, int u_len) {
        throw unsupported("mlgrad(double[], int, double[], int)");
    }

    /**
     * Maps a model-side vector {@code u} to a categorical level.
     *
     * @throws UnsupportedOperationException if this is not a categorical loss
     */
    public int mimpute(double[] u) {
        throw unsupported("mimpute(double[])");
    }

    /**
     * Sets a loss-specific integer parameter (currently only the period of {@link #Periodic}).
     *
     * @throws UnsupportedOperationException if this loss takes no parameter
     */
    public void setParameters(int p) {
        throw unsupported("setParameters(int)");
    }

    /** Builds the exception thrown when {@code op} is not implemented by this constant. */
    private UnsupportedOperationException unsupported(String op) {
        return new UnsupportedOperationException(name() + " loss does not support " + op);
    }
}

