/*
 * Decompiled with CFR 0.152.
 */
package hivemall.classifier;

import hivemall.annotations.Experimental;
import hivemall.annotations.VisibleForTesting;
import hivemall.classifier.BinaryOnlineClassifierUDTF;
import hivemall.model.FeatureValue;
import hivemall.model.PredictionModel;
import hivemall.model.PredictionResult;
import hivemall.optimizer.LossFunctions;
import hivemall.utils.collections.Fastutil;
import hivemall.utils.hashing.HashFunction;
import hivemall.utils.lang.Preconditions;
import it.unimi.dsi.fastutil.ints.Int2FloatMap;
import it.unimi.dsi.fastutil.ints.Int2FloatOpenHashMap;
import java.util.ArrayList;
import java.util.List;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.io.FloatWritable;
import org.apache.hadoop.io.IntWritable;

/**
 * Online Passive-Aggressive binary classifier whose degree-2 polynomial kernel
 * {@code K(x, z) = (dot(x, z) + c)^2} is kept in explicitly expanded form.
 *
 * <p>Instead of storing support vectors, every update is folded into four groups of weights:
 * <ul>
 *   <li>{@code w0} - bias, accumulates {@code alpha * c^2}</li>
 *   <li>{@code w1[h]} - linear term, accumulates {@code 2 * c * alpha * z_h}</li>
 *   <li>{@code w2[h]} - squared term, accumulates {@code alpha * z_h^2}</li>
 *   <li>{@code w3[hash(h, k)]} - pairwise term, accumulates {@code 2 * alpha * z_h * z_k}</li>
 * </ul>
 *
 * <p>On {@link #close()} the model is forwarded as rows with the columns
 * {@code (h, w0, w1, w2, hk, w3)}; columns that do not apply to a row are {@code null}.
 */
@Description(name="train_kpa", value="_FUNC_(array<string|int|bigint> features, int label [, const string options]) - returns a relation <h int, hk int, float w0, float w1, float w2, float w3>")
@Experimental
public final class KernelExpansionPassiveAggressiveUDTF
extends BinaryOnlineClassifierUDTF {
    /** Constant {@code c} inside the polynomial kernel (option {@code -pkc}). */
    private float _pkc;
    /** Step-size rule selected by option {@code -algo}. */
    private Algorithm _algo;
    /** Bias weight. */
    private float _w0;
    /** Linear weights keyed by feature id. */
    private Int2FloatMap _w1;
    /** Squared-term weights keyed by feature id. */
    private Int2FloatMap _w2;
    /** Pairwise weights keyed by the hash of a feature-id pair. */
    private Int2FloatMap _w3;
    /** Hinge loss of the most recently trained example. */
    private float _loss;

    /**
     * Returns the hinge loss computed by the latest call to {@link #train}.
     */
    @VisibleForTesting
    float getLoss() {
        return this._loss;
    }

    /**
     * Declares the command line options understood by this function.
     */
    @Override
    protected Options getOptions() {
        Options opts = new Options();
        opts.addOption("pkc", true, "Constant c inside polynomial kernel K = (dot(xi,xj) + c)^2 [default 1.0]");
        opts.addOption("algo", "algorithm", true, "Algorithm for calculating loss [pa, pa1 (default), pa2]");
        opts.addOption("c", "aggressiveness", true, "Aggressiveness parameter C for PA-1 and PA-2 [default 1.0]");
        return opts;
    }

    /**
     * Parses the option string and configures the kernel constant and the step-size rule.
     *
     * @param argOIs argument inspectors, passed through to the superclass parser
     * @return the parsed command line as returned by the superclass (may be {@code null})
     * @throws UDFArgumentException if {@code C <= 0} or the algorithm name is unknown
     */
    @Override
    protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
        float pkc = 1.0f;
        float c = 1.0f;
        String algo = "pa1";
        CommandLine cl = super.processOptions(argOIs);
        if (cl != null) {
            String pkcStr = cl.getOptionValue("pkc");
            if (pkcStr != null) {
                pkc = Float.parseFloat(pkcStr);
            }
            String cStr = cl.getOptionValue("c");
            if (cStr != null) {
                c = Float.parseFloat(cStr);
                if (c <= 0.0f) {
                    throw new UDFArgumentException("Aggressiveness parameter C must be C > 0: " + c);
                }
            }
            algo = cl.getOptionValue("algo", algo);
        }
        if ("pa1".equalsIgnoreCase(algo)) {
            this._algo = new PA1(c);
        } else if ("pa2".equalsIgnoreCase(algo)) {
            this._algo = new PA2(c);
        } else if ("pa".equalsIgnoreCase(algo)) {
            this._algo = new PA();
        } else {
            throw new UDFArgumentException("Unsupported algorithm: " + algo);
        }
        this._pkc = pkc;
        return cl;
    }

    /**
     * Resets the expanded-kernel weights.
     *
     * @return always {@code null}; the model lives in this object's own fields rather than
     *         in a {@link PredictionModel}
     */
    @Override
    protected PredictionModel createModel() {
        this._w0 = 0.0f;
        this._w1 = newWeightMap();
        this._w2 = newWeightMap();
        this._w3 = newWeightMap();
        return null;
    }

    /** Creates an empty weight map that reports {@code 0} for absent keys. */
    private static Int2FloatMap newWeightMap() {
        Int2FloatMap map = new Int2FloatOpenHashMap(16384);
        map.defaultReturnValue(0.0f);
        return map;
    }

    /**
     * Describes the forwarded rows: {@code (h int, w0 float, w1 float, w2 float, hk int, w3 float)}.
     */
    @Override
    protected StructObjectInspector getReturnOI(ObjectInspector featureRawOI) {
        List<String> fieldNames = new ArrayList<String>();
        // Typed as ObjectInspector so the list matches the factory's parameter type.
        List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
        fieldNames.add("h");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
        fieldNames.add("w0");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
        fieldNames.add("w1");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
        fieldNames.add("w2");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
        fieldNames.add("hk");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
        fieldNames.add("w3");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    /**
     * Converts the raw feature list into a feature vector.
     *
     * @param features raw feature objects; individual elements may be {@code null}
     * @return {@code null} for an empty list, otherwise an array of the same length in which
     *         {@code null} inputs stay {@code null}
     */
    @Override
    @Nullable
    FeatureValue[] parseFeatures(@Nonnull List<?> features) {
        int size = features.size();
        if (size == 0) {
            return null;
        }
        FeatureValue[] featureVector = new FeatureValue[size];
        for (int i = 0; i < size; ++i) {
            Object f = features.get(i);
            if (f == null) {
                continue;
            }
            // NOTE(review): the boolean flag presumably asks for integer (hashed) feature
            // ids, which getFeatureAsInt() relies on later - confirm against FeatureValue.
            featureVector[i] = FeatureValue.parse(f, true);
        }
        return featureVector;
    }

    /**
     * Performs one online step: scores the example, records its hinge loss and, when the
     * loss is positive, folds the example into the expanded kernel weights.
     *
     * @param features feature vector; {@code null} entries are skipped
     * @param label    class label, mapped to {@code +1} when positive and {@code -1} otherwise
     */
    @Override
    protected void train(@Nonnull FeatureValue[] features, int label) {
        float y = label > 0 ? 1.0f : -1.0f;
        PredictionResult margin = this.calcScoreWithKernelAndNorm(features);
        float loss = LossFunctions.hingeLoss(margin.getScore(), y);
        this._loss = loss;
        if (loss > 0.0f) {
            this.updateKernel(y, loss, margin, features);
        }
    }

    /**
     * Scores a feature vector with the current model.
     *
     * <p>Delegates to {@link #calcScoreWithKernelAndNorm} so that the result includes the
     * bias {@code w0} and matches the score used during training.
     *
     * @param features feature vector; {@code null} entries are skipped
     * @return the decision value
     */
    @Override
    float predict(@Nonnull FeatureValue[] features) {
        return this.calcScoreWithKernelAndNorm(features).getScore();
    }

    /**
     * Computes the decision value and the squared L2 norm of the feature values.
     *
     * @param features feature vector; {@code null} entries are skipped
     * @return the score {@code w0 + sum(w1*x) + sum(w2*x^2) + sum(w3*xi*xj)} with the squared
     *         norm {@code sum(x^2)} attached
     */
    @Nonnull
    final PredictionResult calcScoreWithKernelAndNorm(@Nonnull FeatureValue[] features) {
        float score = this._w0;
        float norm = 0.0f;
        for (int i = 0; i < features.length; ++i) {
            FeatureValue fi = features[i];
            if (fi == null) {
                continue;
            }
            int h = fi.getFeatureAsInt();
            float w1 = this._w1.get(h);
            float w2 = this._w2.get(h);
            double xi = fi.getValue();
            double xx = xi * xi;
            score += w1 * xi;
            score += w2 * xx;
            norm += xx;
            for (int j = i + 1; j < features.length; ++j) {
                FeatureValue fj = features[j];
                if (fj == null) {
                    // parseFeatures() leaves null slots; they contribute nothing.
                    continue;
                }
                int k = fj.getFeatureAsInt();
                int hk = HashFunction.hash(h, k, true);
                float w3 = this._w3.get(hk);
                double xj = fj.getValue();
                score += xi * xj * w3;
            }
        }
        return new PredictionResult(score).squaredNorm(norm);
    }

    /**
     * Applies one Passive-Aggressive update for a misclassified (or low-margin) example.
     *
     * @param label    {@code +1} or {@code -1}
     * @param loss     positive hinge loss of the example
     * @param margin   score and squared norm of the example
     * @param features the example itself; {@code null} entries are skipped
     */
    protected void updateKernel(float label, float loss, @Nonnull PredictionResult margin, @Nonnull FeatureValue[] features) {
        float eta = this._algo.eta(loss, margin);
        if (Float.isInfinite(eta) || Float.isNaN(eta)) {
            // A zero squared norm makes the plain PA step size infinite; applying it would
            // turn w0 into Infinity and the other weights into NaN, so skip the example.
            return;
        }
        float coeff = eta * label;
        this.expandKernel(features, coeff);
    }

    /**
     * Adds {@code alpha * K(., supportVector)} to the model in expanded form.
     *
     * @param supportVector example to absorb; {@code null} entries are skipped
     * @param alpha         signed step size
     */
    private void expandKernel(@Nonnull FeatureValue[] supportVector, float alpha) {
        float pkc = this._pkc;
        this._w0 += alpha * pkc * pkc;
        for (int i = 0; i < supportVector.length; ++i) {
            FeatureValue si = supportVector[i];
            if (si == null) {
                continue;
            }
            int h = si.getFeatureAsInt();
            float Zih = si.getValueAsFloat();
            float alphaZih = alpha * Zih;
            float alphaZih2 = alphaZih * 2.0f;
            this._w1.put(h, this._w1.get(h) + pkc * alphaZih2);
            this._w2.put(h, this._w2.get(h) + alphaZih * Zih);
            for (int j = i + 1; j < supportVector.length; ++j) {
                FeatureValue sj = supportVector[j];
                if (sj == null) {
                    continue;
                }
                int k = sj.getFeatureAsInt();
                int hk = HashFunction.hash(h, k, true);
                float Zjk = sj.getValueAsFloat();
                this._w3.put(hk, this._w3.get(hk) + alphaZih2 * Zjk);
            }
        }
    }

    /**
     * Forwards the trained model and releases the weight maps.
     *
     * <p>Rows are emitted in three phases, reusing one row array and its writables:
     * <ol>
     *   <li>one row {@code (0, w0, null, null, null, null)} carrying the bias,</li>
     *   <li>one row {@code (h, null, w1, w2, null, null)} per key of the linear weights,</li>
     *   <li>one row {@code (null, null, null, null, hk, w3)} per pairwise weight.</li>
     * </ol>
     *
     * @throws HiveException if forwarding fails; the key checks below are also declared
     *         with {@code HiveException} and presumably raise it for a non-positive key
     */
    @Override
    public void close() throws HiveException {
        IntWritable h = new IntWritable(0);
        FloatWritable w0 = new FloatWritable(this._w0);
        FloatWritable w1 = new FloatWritable();
        FloatWritable w2 = new FloatWritable();
        IntWritable hk = new IntWritable(0);
        FloatWritable w3 = new FloatWritable();

        // Phase 1: bias row.
        Object[] row = new Object[]{h, w0, null, null, null, null};
        this.forward(row);

        // Phase 2: per-feature rows. Iteration is driven by the keys of w1; w2 is looked up.
        row[1] = null;
        row[2] = w1;
        row[3] = w2;
        Int2FloatMap w2map = this._w2;
        for (Int2FloatMap.Entry e : Fastutil.fastIterable(this._w1)) {
            int key = e.getIntKey();
            Preconditions.checkArgument(key > 0, HiveException.class);
            h.set(key);
            w1.set(e.getFloatValue());
            w2.set(w2map.get(key));
            this.forward(row);
        }
        this._w1 = null;
        this._w2 = null;

        // Phase 3: pairwise rows.
        row[0] = null;
        row[2] = null;
        row[3] = null;
        row[4] = hk;
        row[5] = w3;
        for (Int2FloatMap.Entry e : Fastutil.fastIterable(this._w3)) {
            int key = e.getIntKey();
            Preconditions.checkArgument(key > 0, HiveException.class);
            hk.set(key);
            w3.set(e.getFloatValue());
            this.forward(row);
        }
        this._w3 = null;
    }

    /**
     * PA-II step size: {@code loss / (||x||^2 + 1 / (2C))}.
     */
    static class PA2
    implements Algorithm {
        private final float c;

        PA2(float c) {
            this.c = c;
        }

        @Override
        public float eta(float loss, PredictionResult margin) {
            float squared_norm = margin.getSquaredNorm();
            return loss / (squared_norm + 0.5f / this.c);
        }
    }

    /**
     * PA-I step size: {@code min(C, loss / ||x||^2)}.
     */
    static class PA1
    implements Algorithm {
        private final float c;

        PA1(float c) {
            this.c = c;
        }

        @Override
        public float eta(float loss, PredictionResult margin) {
            float squared_norm = margin.getSquaredNorm();
            float eta = loss / squared_norm;
            return Math.min(this.c, eta);
        }
    }

    /**
     * Plain PA step size: {@code loss / ||x||^2}. The result is not finite when the squared
     * norm is zero.
     */
    static class PA
    implements Algorithm {
        PA() {
        }

        @Override
        public float eta(float loss, PredictionResult margin) {
            return loss / margin.getSquaredNorm();
        }
    }

    /**
     * Strategy that turns the loss of an example into an update step size.
     */
    static interface Algorithm {
        /**
         * @param loss   hinge loss of the current example
         * @param margin prediction result carrying the squared norm of the example
         * @return the step size
         */
        public float eta(float loss, @Nonnull PredictionResult margin);
    }
}

