/*
 * Decompiled with CFR 0.152.
 */
package hivemall.regression;

import hivemall.optimizer.EtaEstimator;
import hivemall.optimizer.LossFunctions;
import hivemall.regression.RegressionBaseUDTF;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;

/**
 * Hive UDTF registered as {@code logress} that learns per-feature weights using the logistic loss
 * from {@link LossFunctions#logisticLoss(float, float)}, scaled by a learning rate produced by an
 * {@link EtaEstimator}.
 *
 * <p>Per its {@code @Description}, it takes a feature array, a float target in {@code [0, 1]},
 * and an optional constant options string, and emits {@code <feature, weight>} rows.
 *
 * @deprecated This UDTF is marked deprecated. NOTE(review): reference the replacement UDTF here.
 */
@Deprecated
@Description(name="logress", value="_FUNC_(array<int|bigint|string> features, float target [, constant string options]) - Returns a relation consists of <{int|bigint|string} feature, float weight>")
public final class LogressUDTF
extends RegressionBaseUDTF {
    /** Learning-rate schedule; built from the parsed options in {@link #processOptions}. */
    private EtaEstimator etaEstimator;

    /**
     * Checks the argument count (features, target, optional options string), then delegates to
     * the base class.
     *
     * @param argOIs object inspectors for the UDTF arguments
     * @return the output struct inspector produced by the base class
     * @throws UDFArgumentException if not called with exactly 2 or 3 arguments, or if the base
     *     class rejects the arguments
     */
    @Override
    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
        final int numArgs = argOIs.length;
        if (numArgs != 2 && numArgs != 3) {
            throw new UDFArgumentException("LogressUDTF takes 2 or 3 arguments: List<Text|Int|BigInt> features, float target [, constant string options]");
        }
        return super.initialize(argOIs);
    }

    /**
     * Extends the base-class options with learning-rate settings: {@code -t/--total_steps},
     * {@code -power_t} and {@code -eta0}. NOTE(review): these are presumably consumed by
     * {@link EtaEstimator#get}; confirm against that helper.
     *
     * @return the option set including the learning-rate options
     */
    @Override
    protected Options getOptions() {
        final Options opts = super.getOptions();
        opts.addOption("t", "total_steps", true, "a total of n_samples * epochs time steps");
        opts.addOption("power_t", true, "The exponent for inverse scaling learning rate [default: 0.1]");
        opts.addOption("eta0", true, "The initial learning rate [default: 0.1]");
        return opts;
    }

    /**
     * Parses options with the base class, then builds the learning-rate estimator from the parsed
     * command line.
     *
     * @param argOIs object inspectors for the UDTF arguments
     * @return the parsed command line returned by the base class
     * @throws UDFArgumentException if option parsing or estimator construction fails
     */
    @Override
    protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
        final CommandLine cl = super.processOptions(argOIs);
        this.etaEstimator = EtaEstimator.get(cl);
        return cl;
    }

    /**
     * Requires the target to lie in the closed interval {@code [0, 1]}.
     *
     * @param target the training target for the current row
     * @throws UDFArgumentException if {@code target} is NaN or outside {@code [0, 1]}
     */
    @Override
    protected void checkTargetValue(final float target) throws UDFArgumentException {
        // NaN must be rejected explicitly: every ordered comparison with NaN is false, so a
        // plain range check would let it through and it would propagate into the weights.
        if (Float.isNaN(target) || target < 0.0f || target > 1.0f) {
            throw new UDFArgumentException("target must be in range 0 to 1: " + target);
        }
    }

    /**
     * Returns the update term for one example: the value of
     * {@link LossFunctions#logisticLoss(float, float)} scaled by the current learning rate.
     *
     * @param target the training target
     * @param predicted the model's current prediction for this example
     * @return {@code eta(count) * logisticLoss(target, predicted)}
     */
    @Override
    protected float computeGradient(final float target, final float predicted) {
        // `count` is inherited from the base class; presumably the number of examples processed
        // so far, which drives the learning-rate schedule. NOTE(review): confirm in base class.
        final float eta = this.etaEstimator.eta(this.count);
        final float gradient = LossFunctions.logisticLoss(target, predicted);
        return eta * gradient;
    }
}

