/*
 * Decompiled with CFR 0.152.
 */
package hivemall.factorization.fm;

import hivemall.factorization.fm.FMHyperParameters;
import hivemall.factorization.fm.Feature;
import hivemall.optimizer.EtaEstimator;
import hivemall.utils.lang.NumberUtils;
import hivemall.utils.math.MathUtils;
import java.util.Arrays;
import java.util.Objects;
import java.util.Random;
import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.hadoop.hive.ql.metadata.HiveException;

/**
 * Abstract base class for Factorization Machine (FM) models.
 *
 * <p>Holds the hyper-parameters shared by all concrete models: number of latent factors, learning-rate
 * estimator, V-initialization scheme, target clipping range, and L2 regularization strengths. It
 * implements prediction, the loss gradient, and the SGD updates of the parameters and of the
 * regularization strengths. All of this is built on abstract accessors for {@code W0}, {@code W}
 * and {@code V}, which subclasses implement against their own storage.
 *
 * <p>{@link #predict(Feature[])} evaluates
 * {@code y(x) = w0 + sum_i w_i*x_i + sum_f 0.5*((sum_j v_jf*x_j)^2 - sum_j (v_jf*x_j)^2)}.
 *
 * <p>NOTE(review): instances carry mutable state (the regularization strengths) without any
 * synchronization. Presumably a model is confined to a single thread; confirm with callers.
 */
public abstract class FactorizationMachineModel {
    /** If true, {@link #dloss(double, double)} uses the logistic loss; otherwise the squared loss. */
    protected final boolean _classification;
    /** Number of latent factors, i.e. the length of each V vector. */
    protected final int _factor;
    /** Copied from {@code FMHyperParameters.sigma}; not read by this base class. */
    protected final double _sigma;
    /** Learning-rate estimator; not read by this base class (presumably used by trainers/subclasses). */
    @Nonnull
    protected final EtaEstimator _eta;
    /** Scheme used by {@link #initV()} to initialize a latent-factor vector. */
    @Nonnull
    protected final VInitScheme _initScheme;
    /** Random source seeded with {@code params.seed}; not read by this base class. */
    @Nonnull
    protected final Random _rnd;
    /** Lower clipping bound applied to regression predictions in {@link #dloss(double, double)}. */
    protected final double _min_target;
    /** Upper clipping bound applied to regression predictions in {@link #dloss(double, double)}. */
    protected final double _max_target;
    /** L2 strength for W0; adapted by {@link #updateLambdaW0} and kept non-negative. */
    protected float _lambdaW0;
    /** L2 strength for W; adapted by {@link #updateLambdaW} and kept non-negative. */
    protected float _lambdaW;
    /** Per-factor L2 strengths for V; adapted by {@link #updateLambdaV} and kept non-negative. */
    @Nonnull
    protected final float[] _lambdaV;

    /**
     * Creates a model from the given hyper-parameters.
     *
     * @param params hyper-parameters; {@code params.eta} and {@code params.vInit} must be non-null
     * @throws NullPointerException if {@code params.eta} or {@code params.vInit} is null
     */
    public FactorizationMachineModel(@Nonnull FMHyperParameters params) {
        this._classification = params.classification;
        this._factor = params.factors;
        this._sigma = params.sigma;
        this._eta = Objects.requireNonNull(params.eta, "params.eta");
        this._initScheme = Objects.requireNonNull(params.vInit, "params.vInit");
        this._rnd = new Random(params.seed);
        this._min_target = params.minTarget;
        this._max_target = params.maxTarget;
        this._lambdaW0 = params.lambdaW0;
        this._lambdaW = params.lambdaW;
        // Every factor starts from the same regularization strength; they diverge via updateLambdaV.
        this._lambdaV = new float[params.factors];
        Arrays.fill(this._lambdaV, params.lambdaV);
    }

    /**
     * Returns the model size as defined by the concrete subclass.
     *
     * <p>NOTE(review): the exact meaning (e.g. number of stored features) is not visible here.
     */
    public abstract int getSize();

    /** Index-based accessor; the base implementation always throws. */
    protected int getMinIndex() {
        throw new UnsupportedOperationException();
    }

    /** Index-based accessor; the base implementation always throws. */
    protected int getMaxIndex() {
        throw new UnsupportedOperationException();
    }

    /** Returns the global bias W0. */
    public abstract float getW0();

    /** Stores a new global bias W0. */
    protected abstract void setW0(float nextW0);

    /** Index-based weight accessor; the base implementation always throws. */
    protected float getW(int i) {
        throw new UnsupportedOperationException();
    }

    /** Returns the linear weight associated with feature {@code x}. */
    public abstract float getW(@Nonnull Feature x);

    /** Stores a new linear weight for feature {@code x}. */
    protected abstract void setW(@Nonnull Feature x, float nextWi);

    /** Index-based factor-vector accessor; the base implementation always throws. */
    @Nullable
    protected float[] getV(int i, boolean init) {
        throw new UnsupportedOperationException();
    }

    /** Returns the {@code f}-th latent factor of feature {@code x}. */
    public abstract float getV(@Nonnull Feature x, int f);

    /** Stores a new value for the {@code f}-th latent factor of feature {@code x}. */
    protected abstract void setV(@Nonnull Feature x, int f, float nextVif);

    /** Returns the current L2 regularization strength for factor {@code f}. */
    float getLambdaV(int f) {
        return _lambdaV[f];
    }

    /**
     * Predicts {@code x} and returns the derivative of the loss w.r.t. the prediction.
     *
     * @throws HiveException if the prediction is not finite
     */
    final double dloss(@Nonnull Feature[] x, double y) throws HiveException {
        double p = predict(x);
        return dloss(p, y);
    }

    /**
     * Returns the derivative of the loss w.r.t. the prediction {@code p}.
     *
     * <p>Classification: {@code (sigmoid(p*y) - 1) * y}, the derivative of {@code log(1 + exp(-y*p))}
     * (assumes labels in {-1, +1}; confirm with callers). Regression: {@code clip(p) - y}, the derivative
     * of {@code 0.5*(p - y)^2}, where {@code p} is first clipped into {@code [_min_target, _max_target]}.
     */
    final double dloss(double p, double y) {
        if (_classification) {
            return (MathUtils.sigmoid(p * y) - 1.0d) * y;
        }
        double clipped = Math.max(Math.min(p, _max_target), _min_target);
        return clipped - y;
    }

    /**
     * Computes the FM prediction for {@code x}: bias + linear terms + pairwise interactions.
     *
     * <p>Pairwise interactions are computed in O(k*n) with the identity
     * {@code sum_{i<j} <v_i,v_j> x_i x_j = 0.5 * sum_f ((sum_j v_jf x_j)^2 - sum_j (v_jf x_j)^2)}.
     *
     * @throws HiveException if the result is NaN or infinite (a variable dump is included)
     */
    protected double predict(@Nonnull Feature[] x) throws HiveException {
        double ret = getW0();
        for (Feature e : x) {
            double xj = e.getValue();
            float w = getW(e);
            ret += (double) w * xj;
        }
        for (int f = 0, k = _factor; f < k; f++) {
            double sumVjfXj = 0.0d;
            double sumV2X2 = 0.0d;
            for (Feature e : x) {
                double xj = e.getValue();
                float vjf = getV(e, f);
                double vx = (double) vjf * xj;
                sumVjfXj += vx;
                sumV2X2 += vx * vx;
            }
            // The accumulation must NOT live inside the assert: assertions are disabled by default,
            // which would silently drop every interaction term from the prediction.
            ret += 0.5d * (sumVjfXj * sumVjfXj - sumV2X2);
            assert !Double.isNaN(ret);
        }
        if (!NumberUtils.isFinite(ret)) {
            throw new HiveException("Detected " + ret
                    + " in predict. We recommend to normalize training examples.\nDumping variables ...\n"
                    + varDump(x));
        }
        return ret;
    }

    /**
     * Renders the inputs, W0, the linear weights and every latent factor of the features in
     * {@code x} as a multi-line string for diagnostics.
     */
    protected String varDump(@Nonnull Feature[] x) {
        final StringBuilder buf = new StringBuilder(1024);
        for (int i = 0; i < x.length; i++) {
            Feature e = x[i];
            String j = e.getFeature();
            double xj = e.getValue();
            if (i != 0) {
                buf.append(", ");
            }
            buf.append("x[").append(j).append("] = ").append(xj);
        }
        buf.append("\n");
        buf.append("W0 = ").append(getW0()).append('\n');
        for (int i = 0; i < x.length; i++) {
            Feature e = x[i];
            String j = e.getFeature();
            float wi = getW(e);
            if (i != 0) {
                buf.append(", ");
            }
            buf.append("W[").append(j).append("] = ").append(wi);
        }
        buf.append("\n");
        for (int f = 0, k = _factor; f < k; f++) {
            for (int i = 0; i < x.length; i++) {
                Feature e = x[i];
                String j = e.getFeature();
                float vjf = getV(e, f);
                if (i != 0) {
                    buf.append(", ");
                }
                buf.append('V').append(f).append('[').append(j).append("] = ").append(vjf);
            }
            buf.append('\n');
        }
        return buf.toString();
    }

    /**
     * Applies one SGD step to W0: {@code w0 -= eta * (dloss + 2 * lambdaW0 * w0)}.
     *
     * @throws IllegalStateException if the updated W0 is not finite
     */
    final void updateW0(double dloss, float eta) {
        float gradW0 = (float) dloss;
        float prevW0 = getW0();
        float nextW0 = prevW0 - eta * (gradW0 + 2.0f * _lambdaW0 * prevW0);
        if (!NumberUtils.isFinite(nextW0)) {
            throw new IllegalStateException("Got " + nextW0 + " for next W0\ngradW0=" + gradW0
                    + ", prevW0=" + prevW0 + ", dloss=" + dloss + ", eta=" + eta);
        }
        setW0(nextW0);
    }

    /**
     * Applies one SGD step to the linear weight of {@code x}:
     * {@code w_i -= eta * (dloss * x_i + 2 * lambdaW * w_i)}.
     *
     * @throws IllegalStateException if the updated weight is not finite
     */
    void updateWi(double dloss, @Nonnull Feature x, float eta) {
        double Xi = x.getValue();
        float gradWi = (float) (dloss * Xi);
        float wi = getW(x);
        float nextWi = wi - eta * (gradWi + 2.0f * _lambdaW * wi);
        if (!NumberUtils.isFinite(nextWi)) {
            throw new IllegalStateException("Got " + nextWi + " for next W[" + x.getFeature()
                    + "]\nXi=" + Xi + ", gradWi=" + gradWi + ", wi=" + wi + ", dloss=" + dloss
                    + ", eta=" + eta);
        }
        setW(x, nextWi);
    }

    /**
     * Applies one SGD step to latent factor {@code f} of {@code x}:
     * {@code v_if -= eta * (dloss * x_i * (sumViX - v_if * x_i) + 2 * lambdaV_f * v_if)}.
     *
     * @param sumViX precomputed {@code sum_j v_jf * x_j} over the whole example (see {@link #sumVfX(Feature[])})
     * @throws IllegalStateException if the updated factor is not finite
     */
    final void updateV(double dloss, @Nonnull Feature x, int f, double sumViX, float eta) {
        double Xi = x.getValue();
        float Vif = getV(x, f);
        double h = gradV(Xi, Vif, sumViX);
        float gradV = (float) (dloss * h);
        float LambdaVf = getLambdaV(f);
        float nextVif = Vif - eta * (gradV + 2.0f * LambdaVf * Vif);
        if (!NumberUtils.isFinite(nextVif)) {
            throw new IllegalStateException("Got " + nextVif + " for next V" + f + '['
                    + x.getFeature() + "]\nXi=" + Xi + ", Vif=" + Vif + ", h=" + h + ", gradV="
                    + gradV + ", lambdaVf=" + LambdaVf + ", dloss=" + dloss + ", sumViX=" + sumViX
                    + ", eta=" + eta);
        }
        setV(x, f, nextVif);
    }

    /**
     * Adapts the W0 regularization strength by one gradient step and clamps it at zero.
     *
     * <p>NOTE(review): presumably follows Rendle's adaptive-regularization scheme; confirm.
     */
    final void updateLambdaW0(double dloss, float eta) {
        float lambda_w_grad = -2.0f * eta * getW0();
        float lambdaW0 = _lambdaW0 - (float) ((double) eta * dloss * (double) lambda_w_grad);
        _lambdaW0 = Math.max(0.0f, lambdaW0);
    }

    /**
     * Adapts the W regularization strength from {@code sum_i w_i * x_i} over {@code x} by one
     * gradient step and clamps it at zero.
     */
    final void updateLambdaW(@Nonnull Feature[] x, double dloss, float eta) {
        double sumWX = 0.0d;
        for (Feature e : x) {
            assert (e != null) : Arrays.toString(x);
            double xi = e.getValue();
            sumWX += (double) getW(e) * xi;
        }
        double lambda_w_grad = (double) (-2.0f * eta) * sumWX;
        float lambdaW = _lambdaW - (float) ((double) eta * dloss * lambda_w_grad);
        _lambdaW = Math.max(0.0f, lambdaW);
    }

    /**
     * Adapts each per-factor V regularization strength by one gradient step and clamps it at zero.
     *
     * <p>For each factor {@code f}, a tentative updated factor {@code v'} is computed for every feature
     * in {@code x} (without storing it). The gradient is then formed from
     * {@code (sum_j x_j v'_jf)(sum_j x_j v_jf) - sum_j x_j v'_jf x_j v_jf}.
     *
     * @throws IllegalStateException if {@code sum_j v_jf * x_j} is not finite for some factor
     */
    final void updateLambdaV(@Nonnull Feature[] x, double dloss, float eta) {
        for (int f = 0, k = _factor; f < k; f++) {
            double sum_f_dash = 0.0d;
            double sum_f = 0.0d;
            double sum_f_dash_f = 0.0d;
            float lambdaVf = getLambdaV(f);
            double sumVX = sumVfX(x, f);
            for (Feature e : x) {
                assert (e != null) : Arrays.toString(x);
                double x_j = e.getValue();
                float v_jf = getV(e, f);
                double gradV = gradV(x_j, v_jf, sumVX);
                double v_dash = (double) v_jf - (double) eta * (gradV + 2.0d * (double) lambdaVf * (double) v_jf);
                sum_f_dash += x_j * v_dash;
                sum_f += x_j * (double) v_jf;
                sum_f_dash_f += x_j * v_dash * x_j * (double) v_jf;
            }
            double lambda_v_grad = (double) (-2.0f * eta) * (sum_f_dash * sum_f - sum_f_dash_f);
            float nextLambdaVf = (float) ((double) lambdaVf - (double) eta * dloss * lambda_v_grad);
            _lambdaV[f] = Math.max(0.0f, nextLambdaVf);
        }
    }

    /**
     * Returns {@code sum_j v_jf * x_j} for every factor {@code f} in {@code [0, _factor)}.
     *
     * @throws IllegalStateException if any of the sums is not finite
     */
    double[] sumVfX(@Nonnull Feature[] x) {
        final int k = _factor;
        final double[] ret = new double[k];
        for (int f = 0; f < k; f++) {
            ret[f] = sumVfX(x, f);
        }
        return ret;
    }

    /**
     * Returns {@code sum_j v_jf * x_j} for a single factor {@code f}.
     *
     * @throws IllegalStateException if the sum is not finite
     */
    private double sumVfX(@Nonnull Feature[] x, int f) {
        double ret = 0.0d;
        for (Feature e : x) {
            double xj = e.getValue();
            float Vjf = getV(e, f);
            ret += (double) Vjf * xj;
        }
        if (!NumberUtils.isFinite(ret)) {
            throw new IllegalStateException("Got " + ret + " for sumV[ " + f + "]X.\nx = "
                    + Arrays.toString(x));
        }
        return ret;
    }

    /**
     * Partial derivative of the interaction term w.r.t. {@code v_jf}:
     * {@code x_j * (sum_l v_lf x_l - v_jf x_j)}.
     */
    private double gradV(double Xj, float Vjf, double sumVfX) {
        return Xj * (sumVfX - (double) Vjf * Xj);
    }

    /**
     * Validation hook for an input example; the base implementation accepts everything.
     *
     * @throws HiveException in subclasses that reject {@code x}
     */
    public void check(@Nonnull Feature[] x) throws HiveException {
    }

    /**
     * Allocates a new latent-factor vector of length {@code _factor} filled according to
     * {@link #_initScheme}.
     *
     * <p>NOTE(review): relies on {@link VInitScheme#initRandom(int, long)} having been called on the
     * scheme; otherwise {@code rand} is null and this throws a NullPointerException.
     */
    @Nonnull
    protected final float[] initV() {
        final float[] ret = new float[_factor];
        final VInitScheme scheme = _initScheme;
        switch (scheme) {
            case adjustedRandom:
                adjustedRandomFill(ret, scheme.rand[0], scheme.maxInitValue);
                break;
            case libffmRandom:
                libffmRandomFill(ret, scheme.rand[0], scheme.maxInitValue);
                break;
            case random:
                randomFill(ret, scheme.rand[0], scheme.maxInitValue);
                break;
            case gaussian:
                gaussianFill(ret, scheme.rand, scheme.initStdDev);
                break;
            default:
                throw new IllegalStateException("Unsupported V initialization scheme: " + scheme);
        }
        return ret;
    }

    /** Fills {@code a} with uniform values in {@code [0, maxInitValue / a.length)}. */
    protected static final void adjustedRandomFill(@Nonnull float[] a, @Nonnull Random rand, float maxInitValue) {
        final int k = a.length;
        final float basev = maxInitValue / (float) k;
        for (int i = 0; i < k; i++) {
            a[i] = rand.nextFloat() * basev;
        }
    }

    /** Fills {@code a} with uniform values in {@code [0, maxInitValue / sqrt(a.length))}. */
    protected static final void libffmRandomFill(@Nonnull float[] a, @Nonnull Random rand, float maxInitValue) {
        final int k = a.length;
        final float basev = maxInitValue / (float) Math.sqrt(k);
        for (int i = 0; i < k; i++) {
            a[i] = rand.nextFloat() * basev;
        }
    }

    /** Fills {@code a} with uniform values in {@code [0, maxInitValue)}. */
    protected static final void randomFill(@Nonnull float[] a, @Nonnull Random rand, float maxInitValue) {
        for (int i = 0; i < a.length; i++) {
            a[i] = rand.nextFloat() * maxInitValue;
        }
    }

    /**
     * Fills {@code a} with gaussian values with mean 0 and the given standard deviation, drawing
     * element {@code i} from {@code rand[i]}. Requires {@code rand.length >= a.length}.
     */
    protected static final void gaussianFill(@Nonnull float[] a, @Nonnull Random[] rand, double stddev) {
        for (int i = 0; i < a.length; i++) {
            a[i] = (float) MathUtils.gaussian(0.0d, stddev, rand[i]);
        }
    }

    /**
     * Strategy used to initialize latent-factor vectors.
     *
     * <p>NOTE(review): the configuration fields below are instance fields of enum constants, which are
     * JVM-wide singletons. {@link #setMaxInitValue}, {@link #setInitStdDev} and {@link #initRandom}
     * therefore affect every model using the same constant, and the {@code Random} instances are
     * shared between them.
     */
    public static enum VInitScheme {
        adjustedRandom,
        libffmRandom,
        random,
        gaussian;

        /** Upper bound used by the uniform schemes. */
        @Nonnegative
        float maxInitValue;
        /** Standard deviation used by {@link #gaussian}. */
        @Nonnegative
        double initStdDev;
        /** Random sources: one per factor for {@link #gaussian}, a single one otherwise. */
        Random[] rand;

        /** Resolves {@code opt} to a scheme, defaulting to {@link #adjustedRandom}. */
        @Nonnull
        public static VInitScheme resolve(@Nullable String opt) {
            return resolve(opt, adjustedRandom);
        }

        /**
         * Resolves {@code opt} (case-insensitive, accepting snake_case and camelCase aliases) to a
         * scheme. Returns {@code defaultScheme} for null or unrecognized input.
         */
        @Nonnull
        public static VInitScheme resolve(@Nullable String opt, @Nonnull VInitScheme defaultScheme) {
            if (opt == null) {
                return defaultScheme;
            }
            if ("adjusted_random".equalsIgnoreCase(opt) || "adjustedRandom".equalsIgnoreCase(opt)) {
                return adjustedRandom;
            }
            if ("libffm_random".equalsIgnoreCase(opt) || "libffmRandom".equalsIgnoreCase(opt)
                    || "libffm".equalsIgnoreCase(opt)) {
                return libffmRandom;
            }
            if ("random".equalsIgnoreCase(opt)) {
                return random;
            }
            if ("gaussian".equalsIgnoreCase(opt)) {
                return gaussian;
            }
            return defaultScheme;
        }

        /** Sets the upper bound used by the uniform schemes. */
        public void setMaxInitValue(float maxInitValue) {
            this.maxInitValue = maxInitValue;
        }

        /** Sets the standard deviation used by {@link #gaussian}. */
        public void setInitStdDev(double initStdDev) {
            this.initStdDev = initStdDev;
        }

        /**
         * (Re)creates the random sources: {@code factor} of them for {@link #gaussian}, otherwise one.
         * Source {@code i} is seeded with {@code seed + i}.
         */
        public void initRandom(int factor, long seed) {
            final int size = (this == gaussian) ? factor : 1;
            this.rand = new Random[size];
            for (int i = 0; i < size; i++) {
                this.rand[i] = new Random(seed + (long) i);
            }
        }
    }
}

