/*
 * Decompiled with CFR 0.152.
 */
package hivemall.factorization.fm;

import hivemall.factorization.fm.FactorizationMachineModel;
import hivemall.optimizer.EtaEstimator;
import hivemall.utils.lang.Primitives;

import java.util.Locale;

import javax.annotation.Nonnull;

import org.apache.commons.cli.CommandLine;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;

/**
 * Hyper-parameters for factorization machines (FM), populated from command-line options by
 * {@link #processOptions(CommandLine)}.
 *
 * <p>Fields are package-private and mutable: they are written by {@code processOptions} and are
 * presumably read directly by other classes in this package (not visible here).
 */
class FMHyperParameters {
    /** Default initial learning rate handed to the {@link EtaEstimator}. */
    protected static final float DEFAULT_ETA0 = 0.1f;
    /** Default value of {@link #lambda}. */
    protected static final float DEFAULT_LAMBDA = 1.0E-4f;

    /** {@code true} when {@code -classification} is given. */
    boolean classification = false;
    /** Number of latent factors ({@code -factors}, alias {@code -factor}); must be positive. */
    int factors = 5;
    /** Base regularization coefficient ({@code -lambda}); fallback for the three lambdas below. */
    float lambda = DEFAULT_LAMBDA;
    /** Regularization coefficient for w0 ({@code -lambda_w0}); defaults to {@link #lambda}. */
    float lambdaW0;
    /** Regularization coefficient for wi ({@code -lambda_wi}); defaults to {@link #lambda}. */
    float lambdaW;
    /** Regularization coefficient for V ({@code -lambda_v}); defaults to {@link #lambda}. */
    float lambdaV;
    /** Value of {@code -sigma}; how it is consumed is outside this class. */
    double sigma = 0.1;
    /** Random seed ({@code -seed}); -1 means "use {@link System#nanoTime()}". */
    long seed = -1L;
    /** Initialization scheme for the factor matrix V ({@code -init_v}). */
    @Nonnull
    FactorizationMachineModel.VInitScheme vInit;
    /**
     * Minimum target value ({@code -min_target}). Defaults to the most negative finite double so
     * the default range is effectively unbounded. Note that {@link Double#MIN_VALUE} is the
     * smallest <em>positive</em> double and must not be used as a "no lower bound" sentinel.
     */
    double minTarget = -Double.MAX_VALUE;
    /** Maximum target value ({@code -max_target}). */
    double maxTarget = Double.MAX_VALUE;
    /** Learning-rate (eta) estimator. */
    @Nonnull
    EtaEstimator eta;
    /** Number of features ({@code -num_features}); -1 means "not specified". */
    int numFeatures = -1;
    /** Normalization flag: set by {@code -enable_norm} for FM; FFM enables it by default. */
    boolean l2norm;
    /** Number of iterations ({@code -iterations}, alias {@code -iter}). */
    int iters = 10;
    /** Convergence check flag; disabled by {@code -disable_cvtest}. */
    boolean conversionCheck = true;
    /** Convergence rate threshold ({@code -cv_rate}). */
    double convergenceRate = 0.005;
    /** {@code true} when {@code -early_stopping} is given. */
    boolean earlyStopping = false;
    /** {@code true} when {@code -adaptive_regularization} is given. */
    boolean adaptiveRegularization = false;
    /** Fraction of data held out for validation ({@code -validation_ratio}); in [0, 1). */
    float validationRatio = 0.05f;
    /** Value of {@code -validation_threshold}. */
    int validationThreshold = 1000;
    /** {@code true} when {@code -int_feature} is given. */
    boolean parseFeatureAsInt = false;

    /**
     * Creates parameters with default values.
     *
     * <p>This calls the overridable {@link #getDefaultVinitScheme(boolean)}; subclasses must keep
     * that override independent of their own (not yet initialized) fields.
     */
    FMHyperParameters() {
        this.vInit = instantiateVInit();
        this.eta = new EtaEstimator.InvscalingEtaEstimator(DEFAULT_ETA0, 0.1);
    }

    @Override
    public String toString() {
        return "FMHyperParameters [classification=" + classification + ", factors=" + factors
                + ", lambda=" + lambda + ", lambdaW0=" + lambdaW0 + ", lambdaW=" + lambdaW
                + ", lambdaV=" + lambdaV + ", sigma=" + sigma + ", seed=" + seed + ", vInit="
                + vInit + ", minTarget=" + minTarget + ", maxTarget=" + maxTarget + ", eta="
                + eta + ", numFeatures=" + numFeatures + ", l2norm=" + l2norm + ", iters="
                + iters + ", conversionCheck=" + conversionCheck + ", convergenceRate="
                + convergenceRate + ", earlyStopping=" + earlyStopping
                + ", adaptiveRegularization=" + adaptiveRegularization + ", validationRatio="
                + validationRatio + ", validationThreshold=" + validationThreshold
                + ", parseFeatureAsInt=" + parseFeatureAsInt + "]";
    }

    /**
     * Overwrites the fields of this object from the parsed command line.
     *
     * @param cl parsed command-line options
     * @throws UDFArgumentException if {@code -factors} is not positive or
     *         {@code -validation_ratio} is outside [0, 1)
     */
    void processOptions(@Nonnull CommandLine cl) throws UDFArgumentException {
        this.classification = cl.hasOption("classification");

        // "-factor" is an alias of "-factors" and wins when present.
        final String factorsOpt = cl.hasOption("factor") ? cl.getOptionValue("factor")
                : cl.getOptionValue("factors");
        this.factors = Primitives.parseInt(factorsOpt, this.factors);
        if (this.factors < 1) {
            // instantiateVInit divides by the factor count; a non-positive value is meaningless.
            throw new UDFArgumentException("-factors MUST be greater than 0: " + this.factors);
        }

        this.lambda = Primitives.parseFloat(cl.getOptionValue("lambda"), this.lambda);
        // The specific lambdas fall back to the (possibly user-supplied) base lambda.
        this.lambdaW0 = Primitives.parseFloat(cl.getOptionValue("lambda_w0"), this.lambda);
        this.lambdaW = Primitives.parseFloat(cl.getOptionValue("lambda_wi"), this.lambda);
        this.lambdaV = Primitives.parseFloat(cl.getOptionValue("lambda_v"), this.lambda);
        this.sigma = Primitives.parseDouble(cl.getOptionValue("sigma"), this.sigma);

        this.seed = Primitives.parseLong(cl.getOptionValue("seed"), this.seed);
        if (this.seed == -1L) {
            this.seed = System.nanoTime();
        }
        this.vInit = instantiateVInit(cl, this.factors, this.seed, this.classification);

        this.minTarget = Primitives.parseDouble(cl.getOptionValue("min_target"), this.minTarget);
        this.maxTarget = Primitives.parseDouble(cl.getOptionValue("max_target"), this.maxTarget);
        this.eta = EtaEstimator.get(cl, DEFAULT_ETA0);
        this.numFeatures = Primitives.parseInt(cl.getOptionValue("num_features"), this.numFeatures);
        this.l2norm = cl.hasOption("enable_norm");

        // "-iter" is an alias of "-iterations" and wins when present.
        final String itersOpt = cl.hasOption("iter") ? cl.getOptionValue("iter")
                : cl.getOptionValue("iterations");
        this.iters = Primitives.parseInt(itersOpt, this.iters);

        this.conversionCheck = !cl.hasOption("disable_cvtest");
        this.convergenceRate = Primitives.parseDouble(cl.getOptionValue("cv_rate"), this.convergenceRate);
        this.earlyStopping = cl.hasOption("early_stopping");
        this.adaptiveRegularization = cl.hasOption("adaptive_regularization");

        this.validationRatio = Primitives.parseFloat(cl.getOptionValue("validation_ratio"), this.validationRatio);
        if (this.validationRatio < 0.0f || this.validationRatio >= 1.0f) {
            throw new UDFArgumentException("validation_ratio should be in range [0, 1): " + this.validationRatio);
        }
        this.validationThreshold = Primitives.parseInt(cl.getOptionValue("validation_threshold"), this.validationThreshold);
        this.parseFeatureAsInt = cl.hasOption("int_feature");
    }

    /**
     * Builds the default V initialization scheme: max init value 0.5, stddev 0.2, seeded from
     * {@link System#nanoTime()}.
     *
     * <p>NOTE(review): {@link #getDefaultVinitScheme(boolean)} returns static
     * {@code VInitScheme} constants that are then mutated through setters. If
     * {@code VInitScheme} is an enum, this configuration is shared by every instance. TODO
     * confirm.
     */
    @Nonnull
    private FactorizationMachineModel.VInitScheme instantiateVInit() {
        final FactorizationMachineModel.VInitScheme scheme = getDefaultVinitScheme(this.classification);
        scheme.setMaxInitValue(0.5f);
        scheme.setInitStdDev(0.2);
        scheme.initRandom(this.factors, System.nanoTime());
        return scheme;
    }

    /**
     * Builds the V initialization scheme from {@code -init_v}, {@code -max_init_value} and
     * {@code -min_init_stddev}.
     *
     * @param cl parsed command line
     * @param factor number of latent factors; must be positive
     * @param seed random seed for the scheme
     * @param classification selects the default scheme when {@code -init_v} is absent
     * @return the configured scheme
     */
    @Nonnull
    private FactorizationMachineModel.VInitScheme instantiateVInit(@Nonnull CommandLine cl,
            int factor, long seed, boolean classification) {
        final String vInitOpt = cl.getOptionValue("init_v");
        final float maxInitValue = Primitives.parseFloat(cl.getOptionValue("max_init_value"), 0.5f);
        final double requestedStdDev = Primitives.parseDouble(cl.getOptionValue("min_init_stddev"), 0.1);

        final FactorizationMachineModel.VInitScheme scheme = FactorizationMachineModel.VInitScheme.resolve(
            vInitOpt, getDefaultVinitScheme(classification));
        scheme.setMaxInitValue(maxInitValue);
        // The effective stddev is never smaller than 1/factor.
        scheme.setInitStdDev(Math.max(requestedStdDev, 1.0 / (double) factor));
        scheme.initRandom(factor, seed);
        return scheme;
    }

    /**
     * Returns the V initialization scheme used when {@code -init_v} is not given.
     *
     * @param classification whether the task is classification
     * @return {@code gaussian} for classification, {@code adjustedRandom} otherwise
     */
    @Nonnull
    protected FactorizationMachineModel.VInitScheme getDefaultVinitScheme(boolean classification) {
        return classification ? FactorizationMachineModel.VInitScheme.gaussian
                : FactorizationMachineModel.VInitScheme.adjustedRandom;
    }

    /**
     * Hyper-parameters for field-aware factorization machines (FFM). They extend the FM options
     * with field, bias/linear-term and optimizer (FTRL / AdaGrad) settings.
     */
    public static final class FFMHyperParameters extends FMHyperParameters {
        /** {@code true} when {@code -global_bias} is given. */
        boolean globalBias = false;
        /** {@code true} when {@code -linear_term} is given. */
        boolean linearCoeff = false;
        /** Number of fields ({@code -num_fields}); must be greater than 1. */
        int numFields = 256;
        /** {@code true} when {@code -optimizer adagrad} is selected. */
        boolean useAdaGrad = false;
        /** AdaGrad epsilon ({@code -eps}). */
        float eps = 1.0f;
        /** {@code true} when {@code -optimizer ftrl} (the default) is selected. */
        boolean useFTRL = false;
        /** FTRL alpha ({@code -alphaFTRL}); must be non-zero. */
        float alphaFTRL = 0.5f;
        /** FTRL beta ({@code -betaFTRL}). */
        float betaFTRL = 1.0f;
        /** FTRL L1 coefficient ({@code -lambda1}). */
        float lambda1 = 2.0E-4f;
        /** FTRL L2 coefficient ({@code -lambda2}). */
        float lambda2 = 1.0E-4f;

        FFMHyperParameters() {}

        /** FFM always defaults to the {@code random} V initialization scheme. */
        @Override
        @Nonnull
        protected FactorizationMachineModel.VInitScheme getDefaultVinitScheme(boolean classification) {
            return FactorizationMachineModel.VInitScheme.random;
        }

        /**
         * Applies the FM options, then the FFM-specific ones.
         *
         * @throws UDFArgumentException on unsupported or conflicting options, or on an
         *         out-of-range {@code -feature_hashing}, {@code -num_fields} or
         *         {@code -alphaFTRL}
         */
        @Override
        void processOptions(@Nonnull CommandLine cl) throws UDFArgumentException {
            super.processOptions(cl);

            if (cl.hasOption("int_feature")) {
                throw new UDFArgumentException("int_feature option is not supported yet for FFM");
            }
            this.globalBias = cl.hasOption("global_bias");
            this.linearCoeff = cl.hasOption("linear_term");

            if (cl.hasOption("enable_norm") && cl.hasOption("disable_norm")) {
                throw new UDFArgumentException("-enable_norm and -disable_norm MUST NOT be used simultaneously");
            }
            // Unlike plain FM, FFM enables normalization unless -disable_norm is given.
            this.l2norm = !cl.hasOption("disable_norm");

            // -feature_hashing only applies when -num_features was not given explicitly.
            if (this.numFeatures == -1) {
                final int hashbits = Primitives.parseInt(cl.getOptionValue("feature_hashing"), -1);
                if (hashbits != -1) {
                    // 1 << 31 overflows int to Integer.MIN_VALUE, so 30 is the widest usable size.
                    if (hashbits < 18 || hashbits > 30) {
                        throw new UDFArgumentException("-feature_hashing MUST be in range [18,30]: " + hashbits);
                    }
                    this.numFeatures = 1 << hashbits;
                }
            }

            this.numFields = Primitives.parseInt(cl.getOptionValue("num_fields"), this.numFields);
            if (this.numFields <= 1) {
                throw new UDFArgumentException("-num_fields MUST be greater than 1: " + this.numFields);
            }

            final String optimizer = cl.getOptionValue("optimizer", "ftrl").toLowerCase(Locale.ROOT);
            switch (optimizer) {
                case "ftrl": {
                    this.useFTRL = true;
                    this.useAdaGrad = false;
                    this.alphaFTRL = Primitives.parseFloat(cl.getOptionValue("alphaFTRL"), this.alphaFTRL);
                    if (this.alphaFTRL == 0.0f) {
                        throw new UDFArgumentException("-alphaFTRL SHOULD NOT be 0");
                    }
                    this.betaFTRL = Primitives.parseFloat(cl.getOptionValue("betaFTRL"), this.betaFTRL);
                    this.lambda1 = Primitives.parseFloat(cl.getOptionValue("lambda1"), this.lambda1);
                    this.lambda2 = Primitives.parseFloat(cl.getOptionValue("lambda2"), this.lambda2);
                    break;
                }
                case "adagrad": {
                    this.useAdaGrad = true;
                    this.useFTRL = false;
                    this.eps = Primitives.parseFloat(cl.getOptionValue("eps"), this.eps);
                    break;
                }
                default: {
                    // Any other name disables both adaptive optimizers. Presumably this means
                    // plain SGD — TODO confirm against the trainer.
                    this.useFTRL = false;
                    this.useAdaGrad = false;
                    break;
                }
            }
        }

        @Override
        public String toString() {
            return "FFMHyperParameters [globalBias=" + globalBias + ", linearCoeff=" + linearCoeff
                    + ", numFields=" + numFields + ", useAdaGrad=" + useAdaGrad + ", eps=" + eps
                    + ", useFTRL=" + useFTRL + ", alphaFTRL=" + alphaFTRL + ", betaFTRL="
                    + betaFTRL + ", lambda1=" + lambda1 + ", lambda2=" + lambda2 + "], "
                    + super.toString();
        }
    }
}

