/*
 * Decompiled with CFR 0.152.
 */
package hivemall.optimizer;

import hivemall.utils.lang.NumberUtils;
import hivemall.utils.lang.Primitives;
import hivemall.utils.lang.StringUtils;
import java.util.Map;
import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.cli.CommandLine;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;

/**
 * Strategy that yields the learning rate (eta) to use at each training step.
 *
 * <p>Concrete schemes are the nested {@link FixedEtaEstimator}, {@link SimpleEtaEstimator},
 * {@link InvscalingEtaEstimator} and {@link AdjustingEtaEstimator}. They are normally built through
 * one of the {@code get(...)} factories, from parsed command-line options or from a string option map.
 */
public abstract class EtaEstimator {
    /** Initial learning rate used when no {@code eta0} option is supplied. */
    public static final float DEFAULT_ETA0 = 0.1f;
    /** Learning rate used by the bold-driver scheme when no {@code eta} option is supplied. */
    public static final float DEFAULT_ETA = 0.3f;
    /** Exponent used by the inverse-scaling scheme when no {@code power_t} option is supplied. */
    public static final double DEFAULT_POWER_T = 0.1;

    /** Base (initial) learning rate of this estimator. */
    protected final float eta0;

    /**
     * @param eta0 base (initial) learning rate
     */
    public EtaEstimator(float eta0) {
        this.eta0 = eta0;
    }

    /**
     * Returns the short scheme name that {@link #getHyperParameters(Map)} records under the
     * {@code "eta"} key.
     */
    @Nonnull
    public abstract String typeName();

    /** Returns the base (initial) learning rate. */
    public float eta0() {
        return this.eta0;
    }

    /**
     * Returns the learning rate to use at step {@code t}.
     *
     * @param t current step counter
     * @return learning rate for that step
     */
    public abstract float eta(long t);

    /**
     * Applies a multiplicative adjustment to the learning rate. This is a no-op by default; only
     * adaptive schemes ({@link AdjustingEtaEstimator}) override it.
     *
     * @param multiplier factor applied to the current learning rate
     */
    public void update(@Nonnegative float multiplier) {
    }

    /**
     * Records this estimator's hyper-parameters into {@code hyperParams}. The base implementation
     * writes the scheme name under {@code "eta"} and the base rate under {@code "eta0"}; subclasses
     * may add more entries.
     *
     * @param hyperParams destination map, mutated in place
     */
    public void getHyperParameters(@Nonnull Map<String, Object> hyperParams) {
        hyperParams.put("eta", this.typeName());
        hyperParams.put("eta0", Float.valueOf(this.eta0()));
    }

    /**
     * Builds an estimator from command-line options, using {@link #DEFAULT_ETA0} as the fallback
     * initial learning rate.
     *
     * @see #get(CommandLine, float)
     */
    @Nonnull
    public static EtaEstimator get(@Nullable CommandLine cl) throws UDFArgumentException {
        return EtaEstimator.get(cl, DEFAULT_ETA0);
    }

    /**
     * Builds an estimator from command-line options. The first matching rule below wins:
     * <ol>
     * <li>{@code cl == null}: inverse scaling with {@code defaultEta0} and {@link #DEFAULT_POWER_T}.</li>
     * <li>{@code -boldDriver}: {@link AdjustingEtaEstimator} starting at {@code -eta}
     * (fallback {@link #DEFAULT_ETA}).</li>
     * <li>{@code -eta X}: {@link FixedEtaEstimator} with rate {@code X}.</li>
     * <li>{@code -t N}: {@link SimpleEtaEstimator} with {@code -eta0} (fallback {@code defaultEta0})
     * over {@code N} steps.</li>
     * <li>otherwise: {@link InvscalingEtaEstimator} with {@code -eta0} and {@code -power_t}
     * (fallback {@link #DEFAULT_POWER_T}).</li>
     * </ol>
     *
     * @param cl parsed options, may be {@code null}
     * @param defaultEta0 initial learning rate used when {@code -eta0} is not given
     * @throws NumberFormatException if {@code -eta} or {@code -t} is not a valid number
     * @throws UDFArgumentException declared for callers; no statement in this method throws it
     *         directly
     */
    @Nonnull
    public static EtaEstimator get(@Nullable CommandLine cl, float defaultEta0) throws UDFArgumentException {
        if (cl == null) {
            return new InvscalingEtaEstimator(defaultEta0, DEFAULT_POWER_T);
        }
        if (cl.hasOption("boldDriver")) {
            float eta = Primitives.parseFloat(cl.getOptionValue("eta"), DEFAULT_ETA);
            return new AdjustingEtaEstimator(eta);
        }
        String etaValue = cl.getOptionValue("eta");
        if (etaValue != null) {
            float eta = Float.parseFloat(etaValue);
            return new FixedEtaEstimator(eta);
        }
        float eta0 = Primitives.parseFloat(cl.getOptionValue("eta0"), defaultEta0);
        if (cl.hasOption("t")) {
            long totalSteps = Long.parseLong(cl.getOptionValue("t"));
            return new SimpleEtaEstimator(eta0, totalSteps);
        }
        double powerT = Primitives.parseDouble(cl.getOptionValue("power_t"), DEFAULT_POWER_T);
        return new InvscalingEtaEstimator(eta0, powerT);
    }

    /**
     * Builds an estimator from a string option map. The {@code "eta"} entry selects the scheme:
     * <ul>
     * <li>absent, {@code inv}, {@code inverse} or {@code invscaling} (case-insensitive): inverse
     * scaling with {@code eta0} and {@code power_t}.</li>
     * <li>{@code fixed}: constant rate {@code eta0}.</li>
     * <li>{@code simple}: {@link SimpleEtaEstimator} over {@code total_steps}, which is required.</li>
     * <li>a number (per {@code StringUtils.isNumber}): constant rate equal to that number.</li>
     * </ul>
     * {@code eta0} falls back to {@link #DEFAULT_ETA0} and {@code power_t} to
     * {@link #DEFAULT_POWER_T}.
     *
     * @param options option name to raw string value
     * @throws IllegalArgumentException if the scheme is unsupported, if {@code simple} is requested
     *         without a non-null {@code total_steps}, or if a numeric value cannot be parsed
     */
    @Nonnull
    public static EtaEstimator get(@Nonnull Map<String, String> options) throws IllegalArgumentException {
        float eta0 = Primitives.parseFloat(options.get("eta0"), DEFAULT_ETA0);
        double powerT = Primitives.parseDouble(options.get("power_t"), DEFAULT_POWER_T);
        String etaScheme = options.get("eta");
        if (etaScheme == null) {
            return new InvscalingEtaEstimator(eta0, powerT);
        }
        if ("fixed".equalsIgnoreCase(etaScheme)) {
            return new FixedEtaEstimator(eta0);
        }
        if ("simple".equalsIgnoreCase(etaScheme)) {
            // Check the value rather than just the key: a key mapped to null would otherwise
            // surface as an opaque NumberFormatException from Long.parseLong(null).
            String totalSteps = options.get("total_steps");
            if (totalSteps == null) {
                throw new IllegalArgumentException("-total_steps MUST be provided when `-eta simple` is specified");
            }
            return new SimpleEtaEstimator(eta0, Long.parseLong(totalSteps));
        }
        if ("inv".equalsIgnoreCase(etaScheme) || "inverse".equalsIgnoreCase(etaScheme) || "invscaling".equalsIgnoreCase(etaScheme)) {
            return new InvscalingEtaEstimator(eta0, powerT);
        }
        if (StringUtils.isNumber(etaScheme)) {
            float eta = Float.parseFloat(etaScheme);
            return new FixedEtaEstimator(eta);
        }
        throw new IllegalArgumentException("Unsupported ETA name: " + etaScheme);
    }

    /**
     * "Bold driver" scheme: a mutable learning rate that is rescaled via {@link #update(float)} and
     * capped at the initial value {@code eta0}. The rate does not depend on the step counter.
     * State is mutated without any synchronization.
     */
    public static final class AdjustingEtaEstimator
    extends EtaEstimator {
        /** Current learning rate; starts at {@code eta0} and is only changed by {@link #update(float)}. */
        private float eta;

        public AdjustingEtaEstimator(float eta) {
            super(eta);
            this.eta = eta;
        }

        @Override
        @Nonnull
        public String typeName() {
            return "boldDriver";
        }

        @Override
        public float eta(long t) {
            return this.eta;
        }

        /**
         * Multiplies the current rate by {@code multiplier} and caps it at {@code eta0}. A product
         * that is not finite is ignored and the previous rate is kept.
         */
        @Override
        public void update(@Nonnegative float multiplier) {
            float newEta = this.eta * multiplier;
            if (!NumberUtils.isFinite(newEta)) {
                return;
            }
            this.eta = Math.min(this.eta0, newEta);
        }

        @Override
        public String toString() {
            return "AdjustingEtaEstimator [ eta0 = " + this.eta0 + ", eta = " + this.eta + " ]";
        }
    }

    /**
     * Inverse-scaling scheme: {@code eta(t) = eta0 / t^power_t}.
     *
     * <p>NOTE(review): at {@code t == 0} with a positive {@code power_t},
     * {@code Math.pow(0, power_t)} is 0 and the result is Infinity. Callers presumably start
     * counting steps at 1; confirm.
     */
    public static final class InvscalingEtaEstimator
    extends EtaEstimator {
        /** Decay exponent applied to the step counter. */
        private final double powerT;

        public InvscalingEtaEstimator(float eta0, double power_t) {
            super(eta0);
            this.powerT = power_t;
        }

        @Override
        @Nonnull
        public String typeName() {
            return "Invscaling";
        }

        @Override
        public float eta(long t) {
            return (float)((double)this.eta0 / Math.pow(t, this.powerT));
        }

        @Override
        public String toString() {
            return "InvscalingEtaEstimator [ eta0 = " + this.eta0 + ", power_t = " + this.powerT + " ]";
        }

        /** Adds {@code "power_t"} on top of the base hyper-parameters. */
        @Override
        public void getHyperParameters(@Nonnull Map<String, Object> hyperParams) {
            super.getHyperParameters(hyperParams);
            hyperParams.put("power_t", this.powerT);
        }
    }

    /**
     * Simple decay: {@code eta(t) = eta0 / (1 + t / total_steps)} while {@code t <= total_steps},
     * then held at {@code eta0 / 2}. That is the formula's value at {@code t == total_steps}, so
     * the curve is continuous.
     *
     * <p>NOTE(review): with {@code total_steps == 0}, {@code eta(0)} evaluates 0.0/0.0 and returns
     * NaN.
     */
    public static final class SimpleEtaEstimator
    extends EtaEstimator {
        /** Rate returned once the step counter exceeds {@link #totalSteps}. */
        private final float finalEta;
        /** Decay horizon; stored as double so {@code t / totalSteps} is floating-point division. */
        private final double totalSteps;

        public SimpleEtaEstimator(float eta0, long total_steps) {
            super(eta0);
            this.finalEta = (float)((double)eta0 / 2.0);
            this.totalSteps = total_steps;
        }

        @Override
        @Nonnull
        public String typeName() {
            return "Simple";
        }

        @Override
        public float eta(long t) {
            if ((double)t > this.totalSteps) {
                return this.finalEta;
            }
            return (float)((double)this.eta0 / (1.0 + (double)t / this.totalSteps));
        }

        @Override
        public String toString() {
            return "SimpleEtaEstimator [ eta0 = " + this.eta0 + ", totalSteps = " + this.totalSteps + ", finalEta = " + this.finalEta + " ]";
        }

        /** Adds {@code "total_steps"} (boxed as a {@code Double}) on top of the base hyper-parameters. */
        @Override
        public void getHyperParameters(@Nonnull Map<String, Object> hyperParams) {
            super.getHyperParameters(hyperParams);
            hyperParams.put("total_steps", this.totalSteps);
        }
    }

    /** Constant learning rate equal to {@code eta0} for every step. */
    public static final class FixedEtaEstimator
    extends EtaEstimator {
        public FixedEtaEstimator(float eta) {
            super(eta);
        }

        @Override
        @Nonnull
        public String typeName() {
            return "Fixed";
        }

        @Override
        public float eta(long t) {
            return this.eta0;
        }

        @Override
        public String toString() {
            return "FixedEtaEstimator [ eta0 = " + this.eta0 + " ]";
        }
    }
}

