/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.multilabel.sgd.fm;

import com.oracle.labs.mlrg.olcut.config.ArgumentException;
import com.oracle.labs.mlrg.olcut.config.Option;
import com.oracle.labs.mlrg.olcut.config.Options;
import java.util.logging.Logger;
import org.tribuo.math.optimisers.GradientOptimiserOptions;
import org.tribuo.multilabel.sgd.MultiLabelObjective;
import org.tribuo.multilabel.sgd.fm.FMMultiLabelTrainer;
import org.tribuo.multilabel.sgd.objectives.BinaryCrossEntropy;
import org.tribuo.multilabel.sgd.objectives.Hinge;

/**
 * Configurable options for building an {@link FMMultiLabelTrainer}.
 * <p>
 * The {@code fm-} prefixed fields are bound through OLCUT's {@link Option} annotation.
 * The gradient optimiser is configured separately through the nested {@link #sgoOptions}.
 */
public class FMMultiLabelOptions implements Options {
    private static final Logger logger = Logger.getLogger(FMMultiLabelOptions.class.getName());

    /** Options supplying the gradient optimiser passed to the trainer in {@link #getTrainer()}. */
    public GradientOptimiserOptions sgoOptions;

    /** Number of SGD epochs (default 5). */
    @Option(longName = "fm-epochs", usage = "Number of SGD epochs.")
    public int fmEpochs = 5;

    /** Loss function selector (default {@link LossEnum#SIGMOID}); resolved by {@link #getLoss()}. */
    @Option(longName = "fm-objective", usage = "Loss function.")
    public LossEnum fmObjective = LossEnum.SIGMOID;

    /** Log the objective after this many examples (default 100). */
    @Option(longName = "fm-logging-interval", usage = "Log the objective after <int> examples.")
    public int fmLoggingInterval = 100;

    /** Minibatch size (default 1). */
    @Option(longName = "fm-minibatch-size", usage = "Minibatch size.")
    public int fmMinibatchSize = 1;

    /** Random seed passed to the {@link FMMultiLabelTrainer} (default 12345). */
    @Option(longName = "fm-seed", usage = "Sets the random seed for the FMMultiLabelTrainer.")
    private long fmSeed = 12345L;

    /** Factor size (default 5). */
    @Option(longName = "fm-factor-size", usage = "Factor size.")
    public int fmFactorSize = 5;

    /** Variance of the initialization Gaussian (default 0.5). */
    @Option(longName = "fm-variance", usage = "Variance of the initialization gaussian.")
    public double fmVariance = 0.5;

    /**
     * Returns the multi-label objective selected by {@link #fmObjective}.
     *
     * @return a new {@link Hinge} for {@link LossEnum#HINGE}, or a new
     *         {@link BinaryCrossEntropy} for {@link LossEnum#SIGMOID}.
     * @throws ArgumentException if {@code fmObjective} is {@code null} or not a recognised value;
     *         the exception names the {@code fm-objective} option.
     */
    public MultiLabelObjective getLoss() {
        // Switching on a null enum would throw a bare NullPointerException; report the option instead.
        if (fmObjective == null) {
            throw new ArgumentException("fm-objective", "No loss function specified");
        }
        switch (fmObjective) {
            case HINGE:
                return new Hinge();
            case SIGMOID:
                return new BinaryCrossEntropy();
            default:
                // Use the option's real long name so users can find the flag to fix.
                throw new ArgumentException("fm-objective", "Unknown loss function " + fmObjective);
        }
    }

    /**
     * Builds an {@link FMMultiLabelTrainer} from the current option values.
     * <p>
     * Logs the configured logging interval at INFO level, then constructs the trainer.
     * NOTE(review): assumes {@link #sgoOptions} has been populated before this call; a
     * {@code null} value results in a {@link NullPointerException}.
     *
     * @return a new trainer built from {@link #getLoss()}, the optimiser obtained from
     *         {@link #sgoOptions}, and the epoch, logging-interval, minibatch, seed,
     *         factor-size and variance options, in that argument order.
     * @throws ArgumentException if the configured loss function is missing or not recognised.
     */
    public FMMultiLabelTrainer getTrainer() {
        logger.info(String.format("Set logging interval to %d", fmLoggingInterval));
        return new FMMultiLabelTrainer(getLoss(), sgoOptions.getOptimiser(), fmEpochs,
                fmLoggingInterval, fmMinibatchSize, fmSeed, fmFactorSize, fmVariance);
    }

    /** Loss functions selectable through the {@code fm-objective} option. */
    public enum LossEnum {
        /** Resolved to {@link Hinge} by {@link #getLoss()}. */
        HINGE,
        /** Resolved to {@link BinaryCrossEntropy} by {@link #getLoss()}. */
        SIGMOID
    }
}

