/*
 * Decompiled with CFR 0.152.
 */
package hivemall.regression;

import hivemall.GeneralLearnerBaseUDTF;
import hivemall.model.FeatureValue;
import hivemall.optimizer.LossFunctions;
import javax.annotation.Nonnull;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;

/**
 * Hive UDTF {@code train_regressor}: a general-purpose regression learner on top of
 * {@link GeneralLearnerBaseUDTF}.
 *
 * <p>This subclass only fills in the regression-specific hooks:
 * <ul>
 *   <li>the loss-option help text;</li>
 *   <li>the default loss type ({@code SquaredLoss});</li>
 *   <li>rejection of loss functions that do not report {@code forRegression()};</li>
 *   <li>target-value validation (none is performed);</li>
 *   <li>the per-example training step (predict, then update).</li>
 * </ul>
 */
@Description(
        name = "train_regressor",
        value = "_FUNC_(list<string|int|bigint> features, double label [, const string options]) - Returns a relation consists of <string|int|bigint feature, float weight>",
        extended = "Build a prediction model by a generic regressor")
public final class GeneralRegressorUDTF extends GeneralLearnerBaseUDTF {

    /**
     * Returns the help text for the loss option.
     *
     * @return a human-readable list of the loss functions advertised for regression
     */
    @Override
    protected String getLossOptionDescription() {
        return "Loss function [SquaredLoss (default), QuantileLoss, EpsilonInsensitiveLoss, SquaredEpsilonInsensitiveLoss, HuberLoss]";
    }

    /**
     * Returns the loss used when the caller does not choose one.
     *
     * @return {@link LossFunctions.LossType#SquaredLoss}
     */
    @Override
    protected LossFunctions.LossType getDefaultLossType() {
        return LossFunctions.LossType.SquaredLoss;
    }

    /**
     * Ensures the configured loss function is usable for regression.
     *
     * @param lossFunction the loss selected for this learner
     * @throws UDFArgumentException if {@code lossFunction.forRegression()} is {@code false}
     */
    @Override
    protected void checkLossFunction(@Nonnull LossFunctions.LossFunction lossFunction)
            throws UDFArgumentException {
        if (lossFunction.forRegression()) {
            return; // acceptable loss; nothing to report
        }
        throw new UDFArgumentException(
                "The loss function `" + lossFunction.getType() + "` is not designed for regression");
    }

    /**
     * Accepts every label: regression targets are arbitrary real values, so no check is done.
     *
     * <p>NOTE(review): non-finite labels (NaN/Infinity) are not rejected here either. Confirm
     * whether they are filtered upstream.
     *
     * @param label the training target (unused)
     */
    @Override
    protected void checkTargetValue(float label) throws UDFArgumentException {
        // intentionally empty
    }

    /**
     * Performs one online training step.
     *
     * <p>The step computes a prediction with the current model and then passes it to the
     * inherited {@code update} together with the true target.
     *
     * @param features the example's feature vector
     * @param target the true regression target
     */
    @Override
    protected void train(@Nonnull FeatureValue[] features, float target) {
        // Arguments are evaluated left to right, so the prediction is computed before update runs.
        this.update(features, target, this.predict(features));
    }
}

