/*
 * Decompiled with CFR 0.152.
 */
package hivemall.factorization.fm;

import com.google.common.base.Preconditions;
import hivemall.UDTFWithOptions;
import hivemall.annotations.VisibleForTesting;
import hivemall.common.ConversionState;
import hivemall.factorization.fm.FMArrayModel;
import hivemall.factorization.fm.FMHyperParameters;
import hivemall.factorization.fm.FMIntFeatureMapModel;
import hivemall.factorization.fm.FMStringFeatureMapModel;
import hivemall.factorization.fm.FactorizationMachineModel;
import hivemall.factorization.fm.Feature;
import hivemall.factorization.fm.IntFeature;
import hivemall.factorization.fm.StringFeature;
import hivemall.optimizer.EtaEstimator;
import hivemall.optimizer.LossFunctions;
import hivemall.utils.collections.Fastutil;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.io.FileUtils;
import hivemall.utils.io.NioStatefulSegment;
import hivemall.utils.lang.NumberUtils;
import hivemall.utils.lang.Primitives;
import hivemall.utils.math.MathUtils;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Map;
import java.util.Random;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.io.FloatWritable;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.Counters;
import org.apache.hadoop.mapred.Reporter;

/**
 * Hive UDTF ({@code train_fm}) that trains a Factorization Machine model from
 * {@code (array<string> x, double y [, options])} rows and emits the learned parameters.
 *
 * <p>Each input row is trained on as it arrives. When more than one iteration is requested, every
 * row is also serialized into a direct byte buffer, which is spilled to a temporary file whenever it
 * fills up. The recorded rows are replayed in {@link #close()} for the remaining iterations.
 *
 * <p>Output rows are {@code (feature, W_i, V_if)}. The first emitted row carries the global bias
 * W0 under feature {@code 0}, with a null {@code V_if}.
 */
@Description(name="train_fm", value="_FUNC_(array<string> x, double y [, const string options]) - Returns a prediction model")
public class FactorizationMachineUDTF extends UDTFWithOptions {
    private static final Log LOG = LogFactory.getLog(FactorizationMachineUDTF.class);

    // ---- input object inspectors
    protected ListObjectInspector _xOI;
    protected PrimitiveObjectInspector _yOI;
    /** Reusable feature array handed back to {@link Feature#parseFeatures}; cleared on close. */
    @Nullable
    protected Feature[] _probes;

    // ---- hyper-parameters (copied from _params in processOptions)
    protected FMHyperParameters _params;
    protected boolean _classification;
    protected int _iterations;
    protected int _factors;
    protected boolean _parseFeatureAsInt;
    protected boolean _earlyStopping;
    protected ConversionState _validationState;
    protected boolean _adaptiveRegularization;
    /** Non-null only when early stopping or adaptive regularization needs held-out examples. */
    @Nullable
    protected Random _va_rand;
    protected float _validationRatio;
    protected int _validationThreshold;
    protected LossFunctions.LossFunction _lossFunction;
    protected EtaEstimator _etaEstimator;
    protected ConversionState _cvState;

    // ---- training state
    protected transient FactorizationMachineModel _model;
    /** Number of training (non-validation) updates performed so far, across all iterations. */
    protected long _t;
    protected long _numValidations;

    // ---- record buffer / spill file used for iterations 2..N
    private transient ByteBuffer _inputBuf;
    private transient NioStatefulSegment _fileIO;

    @Override
    protected Options getOptions() {
        Options opts = new Options();
        opts.addOption("c", "classification", false, "Act as classification");
        opts.addOption("seed", true, "Seed value [default: -1 (random)]");
        opts.addOption("iters", "iterations", true, "The number of iterations [default: 10]");
        opts.addOption("iter", true, "The number of iterations [default: 10]. Note this is alias of `iters` for backward compatibility");
        opts.addOption("p", "num_features", true, "The size of feature dimensions [default: -1]");
        opts.addOption("f", "factors", true, "The number of the latent variables [default: 5]");
        opts.addOption("k", "factor", true, "The number of the latent variables [default: 5] Alias of `-factors` option");
        opts.addOption("sigma", true, "The standard deviation for initializing V [default: 0.1]");
        opts.addOption("lambda0", "lambda", true, "The initial lambda value for regularization [default: 1.0E-4]");
        opts.addOption("lambdaW0", "lambda_w0", true, "The initial lambda value for W0 regularization [default: 1.0E-4]");
        opts.addOption("lambdaWi", "lambda_wi", true, "The initial lambda value for Wi regularization [default: 1.0E-4]");
        opts.addOption("lambdaV", "lambda_v", true, "The initial lambda value for V regularization [default: 1.0E-4]");
        opts.addOption("min", "min_target", true, "The minimum value of target variable");
        opts.addOption("max", "max_target", true, "The maximum value of target variable");
        opts.addOption("eta", true, "The initial learning rate [default: 0.3]");
        opts.addOption("eta0", true, "The initial learning rate [default: 0.1]");
        opts.addOption("t", "total_steps", true, "The total number of training examples");
        opts.addOption("power_t", true, "The exponent for inverse scaling learning rate [default: 0.1]");
        opts.addOption("disable_cv", "disable_cvtest", false, "Whether to disable convergence check [default: OFF]");
        opts.addOption("cv_rate", "convergence_rate", true, "Threshold to determine convergence [default: 0.005]");
        opts.addOption("early_stopping", false, "Stop at the iteration that achieves the best validation on partial samples [default: OFF]");
        opts.addOption("va_ratio", "validation_ratio", true, "Ratio of training data used for validation [default: 0.05f]");
        opts.addOption("va_threshold", "validation_threshold", true, "Threshold to start validation. At least N training examples are used before validation [default: 1000]");
        if (this.isAdaptiveRegularizationSupported()) {
            opts.addOption("adareg", "adaptive_regularization", false, "Whether to enable adaptive regularization [default: OFF]");
        }
        opts.addOption("init_v", true, "Initialization strategy of matrix V [adjusted_random, libffm, random, gaussian](FM default: 'adjusted_random' for regression, 'gaussian' for classification, FFM default: random)");
        opts.addOption("maxval", "max_init_value", true, "The maximum initial value in the matrix V [default: 0.5]");
        opts.addOption("min_init_stddev", true, "The minimum standard deviation of initial matrix V [default: 0.1]");
        opts.addOption("int_feature", "feature_as_integer", false, "Parse a feature as integer [default: OFF]");
        opts.addOption("enable_norm", "l2norm", false, "Enable instance-wise L2 normalization");
        return opts;
    }

    /**
     * Reports whether the {@code -adareg} option is offered. Subclasses may override this to hide it.
     *
     * @return {@code true} in this class
     */
    protected boolean isAdaptiveRegularizationSupported() {
        return true;
    }

    /**
     * Parses the optional third (constant string) argument into {@link #_params} and copies the
     * resulting settings into this UDTF's fields.
     *
     * @param argOIs the UDTF argument inspectors; index 2, when present, holds the option string
     * @return the parsed command line, or {@code null} when no option string was given
     * @throws UDFArgumentException if the options are invalid
     */
    @Override
    protected CommandLine processOptions(@Nonnull ObjectInspector[] argOIs) throws UDFArgumentException {
        final FMHyperParameters params = this._params;
        CommandLine cl = null;
        if (argOIs.length >= 3) {
            String rawArgs = HiveUtils.getConstString(argOIs, 2);
            cl = this.parseOptions(rawArgs);
            params.processOptions(cl);
        }
        this._classification = params.classification;
        this._iterations = params.iters;
        this._factors = params.factors;
        this._parseFeatureAsInt = params.parseFeatureAsInt;
        this._earlyStopping = params.earlyStopping;
        this._adaptiveRegularization = params.adaptiveRegularization;
        if (this._earlyStopping || this._adaptiveRegularization) {
            // Held-out examples are drawn only when something consumes them.
            this._va_rand = new Random(params.seed + 31L);
        }
        this._validationState = new ConversionState();
        this._validationRatio = params.validationRatio;
        this._validationThreshold = params.validationThreshold;
        this._lossFunction = params.classification
                ? LossFunctions.getLossFunction(LossFunctions.LossType.LogLoss)
                : LossFunctions.getLossFunction(LossFunctions.LossType.SquaredLoss);
        this._etaEstimator = params.eta;
        this._cvState = new ConversionState(params.conversionCheck, params.convergenceRate);
        return cl;
    }

    /**
     * Validates the argument inspectors, parses options, and resets the training state.
     *
     * @param argOIs {@code x} (list), {@code y} (double-compatible) and optional constant options
     * @return the output row inspector, see {@link #getOutputOI(FMHyperParameters)}
     * @throws UDFArgumentException on a wrong argument count or types
     */
    @Override
    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
        if (argOIs.length != 2 && argOIs.length != 3) {
            this.showHelp(String.format("%s takes 2 or 3 arguments: array<string> x, double y [, CONSTANT string options]: %s", this.getClass().getSimpleName(), Arrays.toString(argOIs)));
        }
        this._xOI = HiveUtils.asListOI(argOIs, 0);
        HiveUtils.validateFeatureOI(this._xOI.getListElementObjectInspector());
        this._yOI = HiveUtils.asDoubleCompatibleOI(argOIs, 1);
        this._params = this.newHyperParameters();
        this.processOptions(argOIs);
        this._model = null;
        this._t = 0L;
        this._numValidations = 0L;
        if (LOG.isInfoEnabled()) {
            LOG.info(this._params);
        }
        return this.getOutputOI(this._params);
    }

    /** Creates the hyper-parameter holder; subclasses may return a specialized one. */
    @Nonnull
    protected FMHyperParameters newHyperParameters() {
        return new FMHyperParameters();
    }

    /**
     * Builds the output schema {@code (feature int|string, W_i float, V_if array<float>)}.
     *
     * @param params hyper-parameters; {@code parseFeatureAsInt} selects the feature column type
     * @return the struct inspector describing each forwarded row
     */
    @Nonnull
    protected StructObjectInspector getOutputOI(@Nonnull FMHyperParameters params) {
        ArrayList<String> fieldNames = new ArrayList<String>();
        // Must be typed as ObjectInspector: getStandardStructObjectInspector takes List<ObjectInspector>.
        ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
        fieldNames.add("feature");
        if (params.parseFeatureAsInt) {
            fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
        } else {
            fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
        }
        fieldNames.add("W_i");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
        fieldNames.add("V_if");
        fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableFloatObjectInspector));
        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    /**
     * Instantiates the model and stores it in {@link #_model}.
     *
     * <p>Integer features use an array-backed model when {@code numFeatures} is given, and a map
     * otherwise. String features always use a string-keyed map model.
     *
     * @param params hyper-parameters selecting the model representation
     * @return the newly created model (also assigned to {@link #_model})
     * @throws UDFArgumentException if the model cannot be created
     */
    @Nonnull
    protected FactorizationMachineModel initModel(@Nonnull FMHyperParameters params) throws UDFArgumentException {
        final FactorizationMachineModel model;
        if (!params.parseFeatureAsInt) {
            model = new FMStringFeatureMapModel(params);
        } else if (params.numFeatures == -1) {
            model = new FMIntFeatureMapModel(params);
        } else {
            model = new FMArrayModel(params);
        }
        this._model = model;
        return model;
    }

    /**
     * Trains on one input row.
     *
     * <p>The row is recorded for later iterations when required, then used either as a training
     * update or as a held-out validation example.
     *
     * @param args {@code [x, y(, options)]}
     * @throws HiveException if parsing or training fails
     */
    @Override
    public void process(Object[] args) throws HiveException {
        if (this._model == null) {
            this._model = this.initModel(this._params);
        }
        final Feature[] x = this.parseFeatures(args[0]);
        if (x == null) {
            return; // rows whose features parse to null are skipped
        }
        this._probes = x;
        this._model.check(x);
        double y = PrimitiveObjectInspectorUtils.getDouble(args[1], this._yOI);
        if (this._classification) {
            y = y > 0.0 ? 1.0 : -1.0; // map labels to {-1, +1}
        }
        final boolean validation = this.isValidationExample();
        this.recordTrain(x, y, validation);
        this.train(x, y, validation);
    }

    /**
     * Decides whether the next row is held out for validation.
     *
     * <p>A row is held out with probability {@code _validationRatio}, but only once at least
     * {@code _validationThreshold} training updates have been made and a validation RNG exists.
     */
    private boolean isValidationExample() {
        if (this._va_rand != null && this._t >= (long) this._validationThreshold) {
            return this._va_rand.nextFloat() < this._validationRatio;
        }
        return false;
    }

    /**
     * Parses the feature vector of one row, L2-normalizing it when {@code -l2norm} is set.
     *
     * @param arg the raw {@code x} argument
     * @return the parsed features, or {@code null} when {@link Feature#parseFeatures} yields none
     * @throws HiveException if a feature cannot be parsed
     */
    @Nullable
    protected Feature[] parseFeatures(@Nonnull Object arg) throws HiveException {
        Feature[] features = Feature.parseFeatures(arg, this._xOI, this._probes, this._parseFeatureAsInt);
        // Guard against null: the caller skips such rows, so normalizing them would only NPE.
        if (features != null && this._params.l2norm) {
            Feature.l2normalize(features);
        }
        return features;
    }

    /**
     * Serializes one example for replay in iterations 2..N. This is a no-op when only one
     * iteration is configured.
     *
     * <p>Record layout:
     * {@code [int recordBytes][int x.length][features...][double y][byte validation]}.
     * The buffer is spilled to a lazily created temporary file whenever the next record would not
     * fit.
     */
    private void recordTrain(@Nonnull Feature[] x, double y, boolean validation) throws HiveException {
        if (this._iterations <= 1) {
            return;
        }
        ByteBuffer inputBuf = this._inputBuf;
        NioStatefulSegment dst = this._fileIO;
        if (inputBuf == null) {
            final File file;
            try {
                file = File.createTempFile("hivemall_fm", ".sgmt");
                file.deleteOnExit();
                if (!file.canWrite()) {
                    throw new UDFArgumentException("Cannot write a temporary file: " + file.getAbsolutePath());
                }
                LOG.info("Record training examples to a file: " + file.getAbsolutePath());
            } catch (IOException ioe) {
                throw new UDFArgumentException(ioe);
            } catch (RuntimeException e) {
                // e.g. a SecurityException from createTempFile/deleteOnExit
                throw new UDFArgumentException(e);
            }
            this._inputBuf = inputBuf = ByteBuffer.allocateDirect(0x100000); // 1 MiB
            this._fileIO = dst = new NioStatefulSegment(file, false);
        }
        final int xBytes = Feature.requiredBytes(x);
        // x.length (int) + features + y (double) + validation flag (byte)
        final int recordBytes = 4 + xBytes + 8 + 1;
        // plus the leading recordBytes header (int)
        final int requiredBytes = 4 + recordBytes;
        if (inputBuf.remaining() < requiredBytes) {
            writeBuffer(inputBuf, dst);
        }
        inputBuf.putInt(recordBytes);
        inputBuf.putInt(x.length);
        for (Feature f : x) {
            f.writeTo(inputBuf);
        }
        inputBuf.putDouble(y);
        if (validation) {
            ++this._numValidations;
            inputBuf.put(Primitives.TRUE_BYTE);
        } else {
            inputBuf.put(Primitives.FALSE_BYTE);
        }
    }

    /**
     * Writes the filled part of {@code srcBuf} to {@code dst}, then clears the buffer for reuse.
     *
     * @throws HiveException wrapping an {@link IOException} from the write
     */
    private static void writeBuffer(@Nonnull ByteBuffer srcBuf, @Nonnull NioStatefulSegment dst) throws HiveException {
        srcBuf.flip();
        try {
            dst.write(srcBuf);
        } catch (IOException e) {
            throw new HiveException("Exception causes while writing a buffer to file", e);
        }
        srcBuf.clear();
    }

    /**
     * Routes an example to validation or to a training update. Training updates increment
     * {@link #_t}. Any exception is wrapped with the current update count.
     */
    private void train(@Nonnull Feature[] x, double y, boolean validation) throws HiveException {
        try {
            if (validation) {
                this.processValidationSample(x, y);
            } else {
                ++this._t;
                this.trainTheta(x, y);
            }
        } catch (Exception ex) {
            throw new HiveException("Exception caused in the " + this._t + "-th call of train()", ex);
        }
    }

    /**
     * Consumes a held-out example.
     *
     * <p>With early stopping, its loss is accumulated into the validation state. With adaptive
     * regularization, the lambda parameters are updated from it.
     */
    protected void processValidationSample(@Nonnull Feature[] x, double y) throws HiveException {
        if (this._earlyStopping) {
            double p = this._model.predict(x);
            double loss = this._lossFunction.loss(p, y);
            this._validationState.incrLoss(loss);
        }
        if (this._adaptiveRegularization) {
            this.trainLambda(x, y);
        }
    }

    /**
     * Performs one gradient update of W0, W_i and V_if on a training example.
     *
     * <p>The update is skipped when the loss gradient is within 1e-9 of zero. The example's loss is
     * always accumulated into the convergence state.
     */
    protected void trainTheta(Feature[] x, double y) throws HiveException {
        final float eta = this._etaEstimator.eta(this._t);
        final double p = this._model.predict(x);
        final double lossGrad = this._model.dloss(p, y);
        final double loss = this._lossFunction.loss(p, y);
        this._cvState.incrLoss(loss);
        if (MathUtils.closeToZero(lossGrad, 1.0E-9)) {
            return;
        }
        this._model.updateW0(lossGrad, eta);
        final double[] sumVfx = this._model.sumVfX(x);
        final int k = this._factors;
        for (Feature xi : x) {
            this._model.updateWi(lossGrad, xi, eta);
            for (int f = 0; f < k; ++f) {
                this._model.updateV(lossGrad, xi, f, sumVfx[f], eta);
            }
        }
    }

    /** Updates the regularization lambdas (W0, W, V) from the loss gradient of a validation example. */
    private void trainLambda(Feature[] x, double y) throws HiveException {
        final float eta = this._etaEstimator.eta(this._t);
        final double p = this._model.predict(x);
        final double lossGrad = this._model.dloss(p, y);
        this._model.updateLambdaW0(lossGrad, eta);
        this._model.updateLambdaW(x, lossGrad, eta);
        this._model.updateLambdaV(x, lossGrad, eta);
    }

    /**
     * Runs the remaining iterations (if any), forwards the model rows, and releases resources.
     *
     * <p>When no training update happened, nothing is emitted. Any temporary file opened by
     * {@link #recordTrain} is still released in that case, so it does not leak.
     */
    @Override
    public void close() throws HiveException {
        this._probes = null;
        if (this._t == 0L) {
            this._model = null;
            this.releaseTemporaryFile();
            return;
        }
        if (this._iterations > 1) {
            this.runTrainingIteration(this._iterations);
        }
        final int P = this._model.getSize();
        if (P <= 0) {
            LOG.warn("Model size P was less than zero: " + P);
            this._model = null;
            return;
        }
        this.forwardModel();
        this._model = null;
    }

    /** Test hook: runs iterations 2..N without forwarding the model. */
    @VisibleForTesting
    void finalizeTraining() throws HiveException {
        if (this._iterations > 1) {
            this.runTrainingIteration(this._iterations);
        }
    }

    /** Emits the model rows using the representation matching the feature type. */
    protected void forwardModel() throws HiveException {
        if (this._parseFeatureAsInt) {
            this.forwardAsIntFeature(this._model, this._factors);
        } else {
            FMStringFeatureMapModel strModel = (FMStringFeatureMapModel) this._model;
            this.forwardAsStringFeature(strModel, this._factors);
        }
    }

    /**
     * Emits W0 as feature {@code 0}, then one row per index in {@code [minIndex, maxIndex]} that
     * has a V vector. The output writables are reused across rows.
     */
    private void forwardAsIntFeature(@Nonnull FactorizationMachineModel model, int factors) throws HiveException {
        final IntWritable f_idx = new IntWritable(0);
        final FloatWritable f_Wi = new FloatWritable(0.0f);
        final FloatWritable[] f_Vi = HiveUtils.newFloatArray(factors, 0.0f);
        final Object[] forwardObjs = new Object[] {f_idx, f_Wi, null};

        // bias row: V_if stays null
        f_idx.set(0);
        f_Wi.set(model.getW0());
        this.forward(forwardObjs);

        forwardObjs[2] = Arrays.asList(f_Vi);
        final int maxIdx = model.getMaxIndex();
        for (int i = model.getMinIndex(); i <= maxIdx; ++i) {
            final float[] vi = model.getV(i, false);
            if (vi == null) {
                continue;
            }
            f_idx.set(i);
            f_Wi.set(model.getW(i));
            for (int f = 0; f < factors; ++f) {
                f_Vi[f].set(vi[f]);
            }
            this.forward(forwardObjs);
        }
    }

    /**
     * Emits W0 as feature {@code "0"}, then one row per entry in the string-keyed model map. The
     * output writables are reused across rows.
     */
    private void forwardAsStringFeature(@Nonnull FMStringFeatureMapModel model, int factors) throws HiveException {
        final Text feature = new Text();
        final FloatWritable f_Wi = new FloatWritable(0.0f);
        final FloatWritable[] f_Vi = HiveUtils.newFloatArray(factors, 0.0f);
        final Object[] forwardObjs = new Object[] {feature, f_Wi, null};

        // bias row: V_if stays null
        feature.set("0");
        f_Wi.set(model.getW0());
        this.forward(forwardObjs);

        forwardObjs[2] = Arrays.asList(f_Vi);
        for (Map.Entry entry : Fastutil.fastIterable(model.getMap())) {
            final String i = (String) entry.getKey();
            assert (i != null);
            feature.set(i);
            final FMStringFeatureMapModel.Entry value = (FMStringFeatureMapModel.Entry) entry.getValue();
            f_Wi.set(value.W);
            final float[] Vi = value.Vf;
            for (int f = 0; f < factors; ++f) {
                f_Vi[f].set(Vi[f]);
            }
            this.forward(forwardObjs);
        }
    }

    /**
     * Replays the recorded examples for iterations {@code 2..iterations}.
     *
     * <p>If nothing was spilled to the temporary file, the in-memory buffer is replayed directly.
     * Otherwise the buffered tail is appended to the file, and the file is streamed back on each
     * iteration.
     *
     * <p>Training stops early when either condition holds:
     * <ul>
     *   <li>the validation loss increased in two consecutive iterations;</li>
     *   <li>the convergence check reports convergence.</li>
     * </ul>
     * The temporary file is always released afterwards.
     *
     * @param iterations the total number of iterations, including the initial online pass
     * @throws HiveException if file I/O or training fails
     */
    protected void runTrainingIteration(int iterations) throws HiveException {
        final ByteBuffer inputBuf = Preconditions.checkNotNull(this._inputBuf);
        final NioStatefulSegment fileIO = Preconditions.checkNotNull(this._fileIO);
        final long numTrainingExamples = this._t;
        final Reporter reporter = this.getReporter();
        final Counters.Counter iterCounter = reporter == null
                ? null
                : reporter.getCounter("hivemall.factorization.fm.FactorizationMachines$Counter", "iteration");
        try {
            if (fileIO.getPosition() == 0L) {
                // nothing was spilled: every recorded example is still in inputBuf
                if (inputBuf.position() == 0) {
                    return;
                }
                this.iterateOnMemory(inputBuf, iterations, numTrainingExamples, reporter, iterCounter);
            } else {
                this.iterateOnFile(inputBuf, fileIO, iterations, numTrainingExamples, reporter, iterCounter);
            }
        } finally {
            this.releaseTemporaryFile();
        }
    }

    /** Replays the records held entirely in {@code inputBuf} once per remaining iteration. */
    private void iterateOnMemory(@Nonnull ByteBuffer inputBuf, int iterations, long numTrainingExamples,
            @Nullable Reporter reporter, @Nullable Counters.Counter iterCounter) throws HiveException {
        inputBuf.flip();
        boolean lossIncreasedLastIter = false;
        for (int iter = 2; iter <= iterations; ++iter) {
            this._validationState.next();
            this._cvState.next();
            reportProgress(reporter);
            setCounterValue(iterCounter, iter);
            while (inputBuf.remaining() > 0) {
                int bytes = inputBuf.getInt(); // record length header; not needed in memory
                assert (bytes > 0) : bytes;
                this.trainRecord(inputBuf);
            }
            final boolean lossIncreased = this._validationState.isLossIncreased();
            if ((lossIncreasedLastIter && lossIncreased) || this._cvState.isConverged(numTrainingExamples)) {
                break;
            }
            lossIncreasedLastIter = lossIncreased;
            inputBuf.rewind();
        }
        LOG.info("Performed " + this._cvState.getCurrentIteration() + " iterations of " + NumberUtils.formatNumber(numTrainingExamples) + " training examples on memory (thus " + NumberUtils.formatNumber(this._t) + " training updates in total), used " + this._numValidations + " validation examples");
    }

    /**
     * Flushes the buffered tail to {@code fileIO}, then streams the file back through
     * {@code inputBuf} once per remaining iteration.
     */
    private void iterateOnFile(@Nonnull ByteBuffer inputBuf, @Nonnull NioStatefulSegment fileIO, int iterations,
            long numTrainingExamples, @Nullable Reporter reporter, @Nullable Counters.Counter iterCounter)
            throws HiveException {
        // Check position(), not remaining(): a buffer filled exactly to its limit has
        // remaining() == 0, yet still holds unwritten records.
        if (inputBuf.position() > 0) {
            writeBuffer(inputBuf, fileIO);
        }
        try {
            fileIO.flush();
        } catch (IOException e) {
            throw new HiveException("Failed to flush a file: " + fileIO.getFile().getAbsolutePath(), e);
        }
        if (LOG.isInfoEnabled()) {
            File tmpFile = fileIO.getFile();
            LOG.info("Wrote " + numTrainingExamples + " records to a temporary file for iterative training: " + tmpFile.getAbsolutePath() + " (" + FileUtils.prettyFileSize(tmpFile) + ")");
        }

        boolean lossIncreasedLastIter = false;
        for (int iter = 2; iter <= iterations; ++iter) {
            this._validationState.next();
            this._cvState.next();
            setCounterValue(iterCounter, iter);
            inputBuf.clear();
            fileIO.resetPosition();
            while (true) {
                reportProgress(reporter);
                final int bytesRead;
                try {
                    bytesRead = fileIO.read(inputBuf);
                } catch (IOException e) {
                    throw new HiveException("Failed to read a file: " + fileIO.getFile().getAbsolutePath(), e);
                }
                if (bytesRead == 0) {
                    break; // end of the spill file
                }
                assert (bytesRead > 0) : bytesRead;
                inputBuf.flip();
                int remain = inputBuf.remaining();
                if (remain < 4) {
                    throw new HiveException("Illegal file format was detected");
                }
                while (remain >= 4) {
                    final int pos = inputBuf.position();
                    final int recordBytes = inputBuf.getInt();
                    remain -= 4;
                    if (remain < recordBytes) {
                        // Incomplete record: rewind to its header so compact() keeps it, and the
                        // next read appends the rest.
                        inputBuf.position(pos);
                        break;
                    }
                    this.trainRecord(inputBuf);
                    remain -= recordBytes;
                }
                inputBuf.compact();
            }
            final boolean lossIncreased = this._validationState.isLossIncreased();
            if ((lossIncreasedLastIter && lossIncreased) || this._cvState.isConverged(numTrainingExamples)) {
                break;
            }
            lossIncreasedLastIter = lossIncreased;
        }
        LOG.info("Performed " + this._cvState.getCurrentIteration() + " iterations of " + NumberUtils.formatNumber(numTrainingExamples) + " training examples on a secondary storage (thus " + NumberUtils.formatNumber(this._t) + " training updates in total), used " + this._numValidations + " validation examples");
    }

    /**
     * Decodes one record body from {@code buf} and trains on it. The leading length header must
     * already have been consumed. The body is
     * {@code [int x.length][features...][double y][byte validation]}.
     */
    private void trainRecord(@Nonnull ByteBuffer buf) throws HiveException {
        final int xLength = buf.getInt();
        final Feature[] x = new Feature[xLength];
        for (int j = 0; j < xLength; ++j) {
            x[j] = this.instantiateFeature(buf);
        }
        final double y = buf.getDouble();
        final boolean validation = buf.get() == Primitives.TRUE_BYTE.byteValue();
        this.train(x, y, validation);
    }

    /**
     * Closes the spill segment via {@code close(true)} and drops the record buffer.
     * This is a no-op when none was created.
     */
    private void releaseTemporaryFile() throws HiveException {
        final NioStatefulSegment fileIO = this._fileIO;
        this._inputBuf = null;
        this._fileIO = null;
        if (fileIO == null) {
            return;
        }
        try {
            fileIO.close(true);
        } catch (IOException e) {
            throw new HiveException("Failed to close a file: " + fileIO.getFile().getAbsolutePath(), e);
        }
    }

    /** Deserializes one feature from {@code input} using the configured feature type. */
    @Nonnull
    protected Feature instantiateFeature(@Nonnull ByteBuffer input) {
        if (this._parseFeatureAsInt) {
            return new IntFeature(input);
        }
        return new StringFeature(input);
    }
}

