/*
 * Decompiled with CFR 0.152.
 */
package hivemall.classifier;

import hivemall.GeneralLearnerBaseUDTF;
import hivemall.model.FeatureValue;
import hivemall.optimizer.LossFunctions;
import javax.annotation.Nonnull;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;

/**
 * Hive UDTF {@code train_classifier}: a generic linear classifier built on
 * {@link GeneralLearnerBaseUDTF}.
 *
 * <p>The loss function is chosen through the options argument and defaults to
 * {@code HingeLoss}. Besides the classification losses, the option help text also lists
 * regression losses, and no loss is rejected by {@link #checkLossFunction}.
 *
 * <p>Accepted labels are exactly {@code -1}, {@code 0} and {@code 1}. During training a label
 * greater than zero becomes the target {@code +1}; any other label becomes {@code -1}.
 */
@Description(
    name = "train_classifier",
    value = "_FUNC_(list<string|int|bigint> features, int label [, const string options]) - Returns a relation consists of <string|int|bigint feature, float weight>",
    extended = "Build a prediction model by a generic classifier")
public final class GeneralClassifierUDTF extends GeneralLearnerBaseUDTF {

    /**
     * Help text for the loss-function option.
     *
     * @return a human-readable list of the supported loss names, with the default marked
     */
    @Override
    protected String getLossOptionDescription() {
        return "Loss function [HingeLoss (default), LogLoss, SquaredHingeLoss, ModifiedHuberLoss, or\na regression loss: SquaredLoss, QuantileLoss, EpsilonInsensitiveLoss, SquaredEpsilonInsensitiveLoss, HuberLoss]";
    }

    /**
     * Loss used when the caller does not specify one.
     *
     * @return {@link LossFunctions.LossType#HingeLoss}
     */
    @Override
    protected LossFunctions.LossType getDefaultLossType() {
        return LossFunctions.LossType.HingeLoss;
    }

    /**
     * Accepts every loss function without validation.
     *
     * @param lossFunction the configured loss (unused)
     * @throws UDFArgumentException never thrown here
     */
    @Override
    protected void checkLossFunction(@Nonnull final LossFunctions.LossFunction lossFunction)
            throws UDFArgumentException {
        // Deliberately empty: the option help text advertises regression losses as well as
        // classification losses, so nothing is filtered out here.
    }

    /**
     * Verifies that a label is a legal classification target.
     *
     * @param label the raw label value
     * @throws UDFArgumentException if {@code label} is not exactly -1, 0 or 1 (NaN included)
     */
    @Override
    protected void checkTargetValue(final float label) throws UDFArgumentException {
        if (!isAcceptedLabel(label)) {
            throw new UDFArgumentException("Invalid label value for classification: " + label);
        }
    }

    /**
     * Performs one learning step on a single example.
     *
     * <p>The label is collapsed to a binary target (positive becomes {@code +1}, otherwise
     * {@code -1}). The current model's prediction for {@code features} is then passed, together
     * with that target, to {@code update}.
     *
     * @param features the example's feature vector
     * @param label the example's label
     */
    @Override
    protected void train(@Nonnull final FeatureValue[] features, final float label) {
        final float target = (label > 0.f) ? 1.f : -1.f;
        update(features, target, predict(features));
    }

    /**
     * Reports whether a label equals one of the three accepted values.
     *
     * @param label the label to test
     * @return {@code true} iff {@code label} equals -1, 0 or 1; NaN yields {@code false} because
     *     every {@code ==} comparison against NaN is false
     */
    private static boolean isAcceptedLabel(final float label) {
        return label == -1.f || label == 0.f || label == 1.f;
    }
}

