/*
 * Decompiled with CFR 0.152.
 */
package hivemall;

import hivemall.LearnerBaseUDTF;
import hivemall.annotations.VisibleForTesting;
import hivemall.common.ConversionState;
import hivemall.model.FeatureValue;
import hivemall.model.PredictionModel;
import hivemall.model.WeightValue;
import hivemall.optimizer.LossFunctions;
import hivemall.optimizer.Optimizer;
import hivemall.optimizer.OptimizerOptions;
import hivemall.utils.collections.IMapIterator;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.io.FileUtils;
import hivemall.utils.io.NIOUtils;
import hivemall.utils.io.NioStatefulSegment;
import hivemall.utils.lang.FloatAccumulator;
import hivemall.utils.lang.NumberUtils;
import hivemall.utils.lang.Primitives;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.IntObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaIntObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector;
import org.apache.hadoop.io.FloatWritable;
import org.apache.hadoop.mapred.Counters;
import org.apache.hadoop.mapred.Reporter;

/**
 * Base class of general-purpose learners implemented as Hive UDTFs.
 *
 * <p>Each input row {@code (features, target)} is trained online in {@link #process(Object[])}.
 * When more than one iteration is requested, every training example is also recorded into a 2MB
 * direct buffer (spilled to a temporary file whenever the buffer fills up) and replayed in
 * {@link #close()} until the convergence check succeeds or the iteration limit is reached.
 * Finally the model is forwarded as {@code (feature, weight[, covar])} rows.
 */
public abstract class GeneralLearnerBaseUDTF
extends LearnerBaseUDTF {
    private static final Log logger = LogFactory.getLog(GeneralLearnerBaseUDTF.class);

    /** Derivatives of the loss are clamped into [MIN_DLOSS, MAX_DLOSS] before being applied. */
    private static final float MAX_DLOSS = 1.0E12f;
    private static final float MIN_DLOSS = -1.0E12f;

    // byte sizes used for the layout of recorded training examples
    private static final int SIZE_OF_INT = 4;
    private static final int SIZE_OF_FLOAT = 4;
    private static final int SIZE_OF_DOUBLE = 8;
    private static final int SIZE_OF_CHAR = 2;

    /** Capacity of the direct buffer used to record training examples (2MB). */
    private static final int INPUT_BUFFER_BYTES = 2 * 1024 * 1024;

    private ListObjectInspector featureListOI;
    private PrimitiveObjectInspector targetOI;
    private FeatureType featureType;
    @Nonnull
    private final Map<String, String> optimizerOptions = OptimizerOptions.create();
    private Optimizer optimizer;
    private LossFunctions.LossFunction lossFunction;
    private PredictionModel model;
    /** Number of training examples seen by {@link #process(Object[])}. */
    private long count;
    /** Per-feature accumulated weights for mini-batch updates (lazily created). */
    @Nullable
    private transient Map<Object, FloatAccumulator> accumulated;
    /** Number of examples accumulated in the current mini-batch. */
    private int sampled;
    /** Temporary file used when recorded examples do not fit in {@link #inputBuf}. */
    @Nullable
    protected transient NioStatefulSegment fileIO;
    /** Buffer recording training examples for iterations 2..n. */
    @Nullable
    protected transient ByteBuffer inputBuf;
    private int iterations;
    protected ConversionState cvState;

    public GeneralLearnerBaseUDTF() {
        this(true);
    }

    public GeneralLearnerBaseUDTF(boolean enableNewModel) {
        super(enableNewModel);
    }

    /** @return the description of the {@code -loss} option shown in the usage */
    @Nonnull
    protected abstract String getLossOptionDescription();

    /** @return the loss type used when {@code -loss} is not given */
    @Nonnull
    protected abstract LossFunctions.LossType getDefaultLossType();

    /**
     * Validates that the given loss function is supported by the concrete learner.
     *
     * @throws UDFArgumentException if the loss function is not supported
     */
    protected abstract void checkLossFunction(@Nonnull LossFunctions.LossFunction lossFunction) throws UDFArgumentException;

    /**
     * Validates a target value of a training example.
     *
     * @throws UDFArgumentException if the target value is invalid
     */
    protected abstract void checkTargetValue(float target) throws UDFArgumentException;

    /**
     * Trains the model with a single example.
     *
     * @param features feature vector; elements may be {@code null} for NULL entries of the input list
     * @param target target value
     */
    protected abstract void train(@Nonnull FeatureValue[] features, float target);

    /**
     * Resolves the argument object inspectors, parses options and creates the model and optimizer.
     *
     * @param argOIs {@code features, target [, options]}
     * @return the output row inspector ({@code feature, weight[, covar]})
     */
    @Override
    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
        if (argOIs.length < 2) {
            showHelp("_FUNC_ takes two or three arguments: List<Int|BigInt|Text> features, float target [, constant string options]");
        }
        this.featureListOI = HiveUtils.asListOI(argOIs, 0);
        this.featureType = getFeatureType(featureListOI);
        this.targetOI = HiveUtils.asDoubleCompatibleOI(argOIs, 1);

        processOptions(argOIs);

        this.model = createModel();
        try {
            this.optimizer = createOptimizer(optimizerOptions);
        } catch (Throwable e) {
            throw new UDFArgumentException(e);
        }
        this.count = 0L;
        this.sampled = 0;
        return getReturnOI(getFeatureOutputOI(featureType));
    }

    @Override
    protected Options getOptions() {
        Options opts = super.getOptions();
        opts.addOption("inspect_opts", false, "Inspect Optimizer options");
        opts.addOption("loss", "loss_function", true, getLossOptionDescription());
        opts.addOption("iter", "iterations", true, "The maximum number of iterations [default: 10]");
        // alias of -iter
        opts.addOption("iters", "iterations", true, "The maximum number of iterations [default: 10]");
        opts.addOption("disable_cv", "disable_cvtest", false, "Whether to disable convergence check [default: OFF]");
        opts.addOption("cv_rate", "convergence_rate", true, "Threshold to determine convergence [default: 0.005]");
        OptimizerOptions.setup(opts);
        return opts;
    }

    /**
     * Parses learner options (loss function, iterations, convergence check) and optimizer options.
     * With {@code -inspect_opts}, throws an exception whose message lists the resolved
     * hyper-parameters.
     */
    @Override
    protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
        CommandLine cl = super.processOptions(argOIs);

        LossFunctions.LossFunction lossFunction = LossFunctions.getLossFunction(getDefaultLossType());
        int iterations = 10;
        boolean conversionCheck = true;
        double convergenceRate = 0.005;
        if (cl != null) {
            if (cl.hasOption("loss_function")) {
                try {
                    lossFunction = LossFunctions.getLossFunction(cl.getOptionValue("loss_function"));
                } catch (Throwable e) {
                    throw new UDFArgumentException(e.getMessage());
                }
            }
            checkLossFunction(lossFunction);
            iterations = Primitives.parseInt(cl.getOptionValue("iterations"), iterations);
            if (iterations < 1) {
                throw new UDFArgumentException("'-iterations' must be greater than or equals to 1: " + iterations);
            }
            conversionCheck = !cl.hasOption("disable_cvtest");
            convergenceRate = Primitives.parseDouble(cl.getOptionValue("cv_rate"), convergenceRate);
        }

        this.lossFunction = lossFunction;
        this.iterations = iterations;
        this.cvState = new ConversionState(conversionCheck, convergenceRate);

        OptimizerOptions.processOptions(cl, optimizerOptions);

        if (cl != null && cl.hasOption("inspect_opts")) {
            Optimizer optimizer = createOptimizer(optimizerOptions);
            Map<String, Object> params = optimizer.getHyperParameters();
            params.put("loss_function", lossFunction.getType().toString());
            params.put("iterations", iterations);
            params.put("disable_cvtest", !conversionCheck);
            params.put("cv_rate", convergenceRate);
            throw new UDFArgumentException(String.format("Inspected Optimizer options ...\n%s", params.toString()));
        }
        return cl;
    }

    /** Maps the element inspector of the feature list to a {@link FeatureType}. */
    @Nonnull
    private static FeatureType getFeatureType(@Nonnull ListObjectInspector featureListOI) throws UDFArgumentException {
        ObjectInspector featureOI = featureListOI.getListElementObjectInspector();
        if (featureOI instanceof StringObjectInspector) {
            return FeatureType.STRING;
        }
        if (featureOI instanceof IntObjectInspector) {
            return FeatureType.INT;
        }
        if (featureOI instanceof LongObjectInspector) {
            return FeatureType.LONG;
        }
        throw new UDFArgumentException("Feature object inspector must be one of [StringObjectInspector, IntObjectInspector, LongObjectInspector]: " + featureOI.toString());
    }

    /**
     * Returns the inspector of the {@code feature} output column. A dense model always outputs
     * int features; otherwise the output follows the input feature type.
     */
    @Nonnull
    protected final ObjectInspector getFeatureOutputOI(@Nonnull FeatureType featureType) throws UDFArgumentException {
        if (dense_model) {
            return PrimitiveObjectInspectorFactory.javaIntObjectInspector;
        }
        switch (featureType) {
            case STRING:
                return PrimitiveObjectInspectorFactory.javaStringObjectInspector;
            case INT:
                return PrimitiveObjectInspectorFactory.javaIntObjectInspector;
            case LONG:
                return PrimitiveObjectInspectorFactory.javaLongObjectInspector;
            default:
                throw new IllegalStateException("Unexpected feature type: " + featureType);
        }
    }

    /** Builds the output row inspector: {@code feature, weight} plus {@code covar} when covariance is used. */
    @Nonnull
    protected StructObjectInspector getReturnOI(@Nonnull ObjectInspector featureOutputOI) {
        List<String> fieldNames = new ArrayList<String>();
        List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
        fieldNames.add("feature");
        fieldOIs.add(featureOutputOI);
        fieldNames.add("weight");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
        if (useCovariance()) {
            fieldNames.add("covar");
            fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
        }
        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    /**
     * Trains the model with one row and records it for later iterations. Rows whose feature list
     * is NULL or empty are skipped.
     */
    @Override
    public void process(Object[] args) throws HiveException {
        if (is_mini_batch && accumulated == null) {
            this.accumulated = new HashMap<Object, FloatAccumulator>(1024);
        }

        List<?> features = featureListOI.getList(args[0]);
        if (features == null) {
            return; // NULL feature list: nothing to learn from
        }
        FeatureValue[] featureVector = parseFeatures(features);
        if (featureVector == null) {
            return;
        }
        float target = PrimitiveObjectInspectorUtils.getFloat(args[1], targetOI);
        checkTargetValue(target);

        ++count;
        train(featureVector, target);
        recordTrainSampleToTempFile(featureVector, target);
    }

    /**
     * Appends a training example to {@link #inputBuf}, spilling the buffer to a temporary file
     * when it is full. Does nothing when only a single iteration is requested.
     *
     * <p>Record layout: {@code [recordBytes:int][numFeatures:int]([feature:string][value:double])*[target:float]}.
     * {@code null} features are not recorded.
     *
     * @throws HiveException if the temporary file cannot be created/written, or the record does
     *         not fit in the 2MB buffer
     */
    protected void recordTrainSampleToTempFile(@Nonnull FeatureValue[] featureVector, float target) throws HiveException {
        if (iterations == 1) {
            return; // no replay will happen
        }

        ByteBuffer buf = inputBuf;
        NioStatefulSegment dst = fileIO;
        if (buf == null) {
            final File file;
            try {
                file = File.createTempFile("hivemall_general_learner", ".sgmt");
                file.deleteOnExit();
                if (!file.canWrite()) {
                    throw new UDFArgumentException("Cannot write a temporary file: " + file.getAbsolutePath());
                }
                logger.info("Record training samples to a file: " + file.getAbsolutePath());
            } catch (UDFArgumentException e) {
                throw e;
            } catch (Throwable e) {
                throw new UDFArgumentException(e);
            }
            // create the segment first so that a failure does not leave a buffer without a segment
            this.fileIO = dst = new NioStatefulSegment(file, false);
            this.inputBuf = buf = ByteBuffer.allocateDirect(INPUT_BUFFER_BYTES);
        }

        int numFeatures = 0;
        int featureVectorBytes = 0;
        for (FeatureValue f : featureVector) {
            if (f == null) {
                continue;
            }
            numFeatures++;
            // feature string: int length prefix + 2 bytes per char (see writeFeatureValue)
            featureVectorBytes += SIZE_OF_INT + SIZE_OF_CHAR * f.getFeatureAsString().length();
            // feature value
            featureVectorBytes += SIZE_OF_DOUBLE;
        }

        // numFeatures, features..., target
        int recordBytes = SIZE_OF_INT + featureVectorBytes + SIZE_OF_FLOAT;
        // the recordBytes prefix itself also needs space
        int requiredBytes = SIZE_OF_INT + recordBytes;

        if (buf.remaining() < requiredBytes) {
            writeBuffer(buf, dst);
        }
        if (requiredBytes > buf.remaining()) {
            throw new HiveException("Buffer size (2MB) for writing training example is not enough: " + NumberUtils.prettySize(requiredBytes));
        }

        buf.putInt(recordBytes);
        buf.putInt(numFeatures); // must match the number of features actually written below
        for (FeatureValue f : featureVector) {
            if (f != null) {
                writeFeatureValue(buf, f);
            }
        }
        buf.putFloat(target);
    }

    /** Writes a feature as its string form followed by its value as a double. */
    private static void writeFeatureValue(@Nonnull ByteBuffer buf, @Nonnull FeatureValue f) {
        NIOUtils.putString(f.getFeatureAsString(), buf);
        buf.putDouble(f.getValue());
    }

    /** Reads a feature written by {@link #writeFeatureValue} and converts it back to {@code featureType}. */
    @Nonnull
    private static FeatureValue readFeatureValue(@Nonnull ByteBuffer buf, @Nonnull FeatureType featureType) {
        String featureStr = NIOUtils.getString(buf);
        final Object feature;
        switch (featureType) {
            case STRING:
                feature = featureStr;
                break;
            case INT:
                feature = Integer.valueOf(featureStr);
                break;
            case LONG:
                feature = Long.valueOf(featureStr);
                break;
            default:
                throw new IllegalStateException("Unexpected feature type " + featureType + " for feature: " + featureStr);
        }
        double value = buf.getDouble();
        return new FeatureValue(feature, value);
    }

    /**
     * Converts a Hive feature list into feature values. String features are parsed with
     * {@link FeatureValue#parseFeatureAsString(String)}; int/long features get value 1.
     *
     * @return {@code null} when the list is empty; otherwise an array of the same size whose
     *         elements are {@code null} where the input element was {@code null}
     */
    @Nullable
    public final FeatureValue[] parseFeatures(@Nonnull List<?> features) {
        int size = features.size();
        if (size == 0) {
            return null;
        }
        ObjectInspector featureInspector = featureListOI.getListElementObjectInspector();
        FeatureValue[] featureVector = new FeatureValue[size];
        for (int i = 0; i < size; ++i) {
            Object f = features.get(i);
            if (f == null) {
                continue;
            }
            final FeatureValue fv;
            if (featureType == FeatureType.STRING) {
                fv = FeatureValue.parseFeatureAsString(f.toString());
            } else {
                Object k = ObjectInspectorUtils.copyToStandardObject(f, featureInspector, ObjectInspectorUtils.ObjectInspectorCopyOption.JAVA);
                fv = new FeatureValue(k, 1.0f);
            }
            featureVector[i] = fv;
        }
        return featureVector;
    }

    /** Writes the readable content of {@code srcBuf} to {@code dst} and clears the buffer. */
    private static void writeBuffer(@Nonnull ByteBuffer srcBuf, @Nonnull NioStatefulSegment dst) throws HiveException {
        srcBuf.flip();
        try {
            dst.write(srcBuf);
        } catch (IOException e) {
            throw new HiveException("Exception causes while writing a buffer to file", e);
        }
        srcBuf.clear();
    }

    /** Returns the dot product of the current weights and {@code features}, skipping {@code null} entries. */
    public float predict(@Nonnull FeatureValue[] features) {
        float score = 0.0f;
        for (FeatureValue f : features) {
            if (f == null) {
                continue;
            }
            Object k = f.getFeature();
            float v = f.getValueAsFloat();
            float old_w = model.getWeight(k);
            if (old_w == 0.0f) {
                continue;
            }
            score += old_w * v;
        }
        return score;
    }

    /**
     * Accumulates the loss for the convergence check and updates the weights (online, or via
     * mini-batch accumulation) using the clamped loss derivative.
     */
    protected void update(@Nonnull FeatureValue[] features, float target, float predicted) {
        optimizer.proceedStep();

        float loss = lossFunction.loss(predicted, target);
        cvState.incrLoss(loss);

        float dloss = lossFunction.dloss(predicted, target);
        if (dloss == 0.0f) {
            return; // no gradient
        }
        if (dloss < MIN_DLOSS) {
            dloss = MIN_DLOSS;
        } else if (dloss > MAX_DLOSS) {
            dloss = MAX_DLOSS;
        }

        if (is_mini_batch) {
            accumulateUpdate(features, loss, dloss);
            if (sampled >= mini_batch_size) {
                batchUpdate();
            }
        } else {
            onlineUpdate(features, loss, dloss);
        }
    }

    /** Accumulates the updated weight of each (non-null) feature into the current mini-batch. */
    protected void accumulateUpdate(@Nonnull FeatureValue[] features, float loss, float dloss) {
        if (accumulated == null) {
            this.accumulated = new HashMap<Object, FloatAccumulator>(1024);
        }
        for (FeatureValue f : features) {
            if (f == null) {
                continue;
            }
            Object feature = f.getFeature();
            float xi = f.getValueAsFloat();
            float weight = model.getWeight(feature);
            float gradient = dloss * xi;
            float new_weight = optimizer.update(feature, weight, loss, gradient);

            FloatAccumulator acc = accumulated.get(feature);
            if (acc == null) {
                accumulated.put(feature, new FloatAccumulator(new_weight));
            } else {
                acc.add(new_weight);
            }
        }
        ++sampled;
    }

    /** Applies the accumulated mini-batch weights to the model (deleting zero weights) and resets the batch. */
    protected void batchUpdate() {
        if (accumulated == null || accumulated.isEmpty()) {
            this.sampled = 0;
            return;
        }
        for (Map.Entry<Object, FloatAccumulator> e : accumulated.entrySet()) {
            Object feature = e.getKey();
            float new_weight = e.getValue().get();
            if (new_weight == 0.0f) {
                model.delete(feature);
            } else {
                model.setWeight(feature, new_weight);
            }
        }
        accumulated.clear();
        this.sampled = 0;
    }

    /** Updates the weight of each (non-null) feature immediately, deleting weights that become zero. */
    protected void onlineUpdate(@Nonnull FeatureValue[] features, float loss, float dloss) {
        for (FeatureValue f : features) {
            if (f == null) {
                continue;
            }
            Object feature = f.getFeature();
            float xi = f.getValueAsFloat();
            float weight = model.getWeight(feature);
            float gradient = dloss * xi;
            float new_weight = optimizer.update(feature, weight, loss, gradient);
            if (new_weight == 0.0f) {
                model.delete(feature);
            } else {
                model.setWeight(feature, new_weight);
            }
        }
    }

    @Override
    public final void close() throws HiveException {
        super.close();
        finalizeTraining();
        forwardModel();
        this.accumulated = null;
        this.model = null;
    }

    /**
     * Flushes the pending mini-batch and runs iterations 2..n. When no training example was
     * seen, the model is dropped.
     */
    @VisibleForTesting
    public void finalizeTraining() throws HiveException {
        if (count == 0L) {
            this.model = null;
            return;
        }
        if (is_mini_batch) {
            batchUpdate();
        }
        if (iterations > 1) {
            runIterativeTraining(iterations);
        }
    }

    /**
     * Replays the recorded training examples for iterations 2..{@code iterations}, stopping early
     * when the convergence check succeeds. The temporary segment is closed afterwards and the
     * recording state is reset.
     *
     * @throws HiveException if reading/writing/closing the temporary file fails or training throws
     */
    protected final void runIterativeTraining(@Nonnegative int iterations) throws HiveException {
        final ByteBuffer buf = this.inputBuf;
        final NioStatefulSegment dst = this.fileIO;
        if (buf == null || dst == null) {
            logger.warn("No recorded training examples were found; skip iterative training");
            return;
        }

        final long numTrainingExamples = this.count;
        final Reporter reporter = getReporter();
        final Counters.Counter iterCounter = (reporter == null) ? null
                : reporter.getCounter("hivemall.GeneralLearnerBase$Counter", "iteration");

        try {
            if (dst.getPosition() == 0L) {
                // nothing was spilled: all recorded examples are still in the buffer
                if (buf.position() == 0) {
                    return; // no training example was recorded
                }
                runIterativeTrainingOnMemory(buf, iterations, numTrainingExamples, reporter, iterCounter);
            } else {
                runIterativeTrainingOnFile(buf, dst, iterations, numTrainingExamples, reporter, iterCounter);
            }
        } catch (Throwable e) {
            throw new HiveException("Exception caused in the iterative training", e);
        } finally {
            // reset the recording state even if closing the segment fails
            this.inputBuf = null;
            this.fileIO = null;
            try {
                dst.close(true);
            } catch (IOException e) {
                throw new HiveException("Failed to close a file: " + dst.getFile().getAbsolutePath(), e);
            }
        }
    }

    /** Replays the examples held entirely in {@code buf} (which is in write mode on entry). */
    private void runIterativeTrainingOnMemory(@Nonnull ByteBuffer buf, int iterations,
            long numTrainingExamples, @Nullable Reporter reporter,
            @Nullable Counters.Counter iterCounter) throws HiveException {
        buf.flip();
        for (int iter = 2; iter <= iterations; ++iter) {
            cvState.next();
            reportProgress(reporter);
            setCounterValue(iterCounter, iter);

            while (buf.remaining() > 0) {
                int recordBytes = buf.getInt();
                assert (recordBytes > 0) : recordBytes;
                trainRecord(buf);
            }
            buf.rewind();

            if (is_mini_batch) {
                batchUpdate();
            }
            if (cvState.isConverged(numTrainingExamples)) {
                break;
            }
        }
        logger.info("Performed " + cvState.getCurrentIteration() + " iterations of " + NumberUtils.formatNumber(numTrainingExamples) + " training examples on memory (thus " + NumberUtils.formatNumber(numTrainingExamples * (long) cvState.getCurrentIteration()) + " training updates in total) ");
    }

    /**
     * Writes the examples still held in {@code buf} to {@code dst}, then replays all examples by
     * streaming the file through {@code buf} on every iteration.
     */
    private void runIterativeTrainingOnFile(@Nonnull ByteBuffer buf, @Nonnull NioStatefulSegment dst,
            int iterations, long numTrainingExamples, @Nullable Reporter reporter,
            @Nullable Counters.Counter iterCounter) throws HiveException {
        // buf is in write mode: position() > 0 means there are pending records
        // (remaining() would miss an exactly-full buffer)
        if (buf.position() > 0) {
            writeBuffer(buf, dst);
        }
        try {
            dst.flush();
        } catch (IOException e) {
            throw new HiveException("Failed to flush a file: " + dst.getFile().getAbsolutePath(), e);
        }
        if (logger.isInfoEnabled()) {
            File tmpFile = dst.getFile();
            logger.info("Wrote " + numTrainingExamples + " records to a temporary file for iterative training: " + tmpFile.getAbsolutePath() + " (" + FileUtils.prettyFileSize(tmpFile) + ")");
        }

        for (int iter = 2; iter <= iterations; ++iter) {
            cvState.next();
            setCounterValue(iterCounter, iter);

            buf.clear();
            dst.resetPosition();
            while (true) {
                reportProgress(reporter);
                final int bytesRead;
                try {
                    bytesRead = dst.read(buf);
                } catch (IOException e) {
                    throw new HiveException("Failed to read a file: " + dst.getFile().getAbsolutePath(), e);
                }
                if (bytesRead == 0) {
                    break; // nothing more to read
                }
                assert (bytesRead > 0) : bytesRead;

                buf.flip();
                int remain = buf.remaining();
                if (remain < SIZE_OF_INT) {
                    throw new HiveException("Illegal file format was detected");
                }
                while (remain >= SIZE_OF_INT) {
                    int pos = buf.position();
                    int recordBytes = buf.getInt();
                    remain -= SIZE_OF_INT;
                    if (remain < recordBytes) {
                        // partial record: rewind to its prefix and read more from the file
                        buf.position(pos);
                        break;
                    }
                    trainRecord(buf);
                    remain -= recordBytes;
                }
                // keep the partial record (if any) at the head of the buffer
                buf.compact();
            }

            if (is_mini_batch) {
                batchUpdate();
            }
            if (cvState.isConverged(numTrainingExamples)) {
                break;
            }
        }
        logger.info("Performed " + cvState.getCurrentIteration() + " iterations of " + NumberUtils.formatNumber(numTrainingExamples) + " training examples on a secondary storage (thus " + NumberUtils.formatNumber(numTrainingExamples * (long) cvState.getCurrentIteration()) + " training updates in total)");
    }

    /** Reads one recorded example (positioned just after its length prefix) and trains on it. */
    private void trainRecord(@Nonnull ByteBuffer buf) {
        int featureVectorLength = buf.getInt();
        FeatureValue[] featureVector = new FeatureValue[featureVectorLength];
        for (int j = 0; j < featureVectorLength; ++j) {
            featureVector[j] = readFeatureValue(buf, featureType);
        }
        float target = buf.getFloat();
        train(featureVector, target);
    }

    /**
     * Forwards every touched, non-zero model entry as a {@code (feature, weight[, covar])} row.
     * Does nothing when the model was dropped because no training example was seen.
     */
    protected void forwardModel() throws HiveException {
        if (model == null) {
            logger.info("No prediction model to forward (trained with " + count + " training examples)");
            return;
        }

        int numForwarded = 0;
        if (useCovariance()) {
            WeightValue.WeightValueWithCovar probe = new WeightValue.WeightValueWithCovar();
            Object[] forwardMapObj = new Object[3];
            FloatWritable fv = new FloatWritable();
            FloatWritable cov = new FloatWritable();
            IMapIterator itor = model.entries();
            while (itor.next() != -1) {
                itor.getValue(probe);
                if (!probe.isTouched()) {
                    continue;
                }
                float v = probe.get();
                float cv = probe.getCovariance();
                if (v == 0.0f && cv == 0.0f) {
                    continue;
                }
                fv.set(v);
                cov.set(cv);
                forwardMapObj[0] = itor.getKey();
                forwardMapObj[1] = fv;
                forwardMapObj[2] = cov;
                forward(forwardMapObj);
                ++numForwarded;
            }
        } else {
            WeightValue probe = new WeightValue();
            Object[] forwardMapObj = new Object[2];
            FloatWritable fv = new FloatWritable();
            IMapIterator itor = model.entries();
            while (itor.next() != -1) {
                itor.getValue(probe);
                if (!probe.isTouched()) {
                    continue;
                }
                float v = probe.get();
                if (v == 0.0f) {
                    continue;
                }
                fv.set(v);
                forwardMapObj[0] = itor.getKey();
                forwardMapObj[1] = fv;
                forward(forwardMapObj);
                ++numForwarded;
            }
        }
        long numMixed = model.getNumMixed();
        logger.info("Trained a prediction model using " + count + " training examples" + (numMixed > 0L ? "( numMixed: " + numMixed + " )" : ""));
        logger.info("Forwarded the prediction model of " + numForwarded + " rows");
    }

    /** @return the cumulative loss tracked by the convergence state, or NaN before initialization */
    @VisibleForTesting
    public double getCumulativeLoss() {
        return cvState == null ? Double.NaN : cvState.getCumulativeLoss();
    }

    /** Type of the input features. */
    public static enum FeatureType {
        STRING,
        INT,
        LONG;
    }
}

