/*
 * Decompiled with CFR 0.152.
 */
package hivemall.optimizer;

import hivemall.utils.lang.Primitives;
import java.util.Map;
import javax.annotation.Nonnull;

/**
 * Base class for weight regularizers.
 *
 * <p>A regularizer augments a loss gradient with a penalty term:
 * {@code regularize(w, g) = g + lambda * getRegularizer(w)}, where {@link #getRegularizer(float)}
 * returns the (sub)gradient of the unscaled penalty with respect to the weight {@code w}.
 * Instances are created from string options via {@link #get(Map)}.
 */
public abstract class Regularization {
    /** Regularization strength used when no {@code "lambda"} option is supplied. */
    private static final float DEFAULT_LAMBDA = 1.0E-4f;

    /** Regularization strength that scales {@link #getRegularizer(float)}. */
    protected final float lambda;

    /**
     * Reads the regularization strength from {@code options}.
     *
     * @param options hyper-parameter options; the {@code "lambda"} entry is parsed as a float via
     *        {@code Primitives.parseFloat}, presumably falling back to {@link #DEFAULT_LAMBDA} when
     *        the entry is missing (confirm against {@code Primitives.parseFloat})
     */
    public Regularization(@Nonnull Map<String, String> options) {
        this.lambda = Primitives.parseFloat(options.get("lambda"), DEFAULT_LAMBDA);
    }

    /**
     * Writes this regularizer's hyper-parameters into {@code hyperParams}. The base
     * implementation stores {@code "lambda"}; subclasses add their own entries.
     *
     * @param hyperParams destination map, mutated in place
     */
    public void getHyperParameters(@Nonnull Map<String, Object> hyperParams) {
        hyperParams.put("lambda", Float.valueOf(this.lambda));
    }

    /**
     * Adds the scaled regularization term to a loss gradient.
     *
     * @param weight current value of the weight
     * @param gradient loss gradient for that weight
     * @return {@code gradient + lambda * getRegularizer(weight)}
     */
    public float regularize(float weight, float gradient) {
        return gradient + this.lambda * this.getRegularizer(weight);
    }

    /**
     * Returns the (sub)gradient of the unscaled penalty at {@code weight}. The result is scaled by
     * {@link #lambda} in {@link #regularize(float, float)}.
     *
     * @param weight current value of the weight
     * @return penalty (sub)gradient at {@code weight}
     */
    abstract float getRegularizer(float weight);

    /**
     * Creates the regularizer selected by the {@code "regularization"} option (case-insensitive).
     *
     * <ul>
     * <li>absent, {@code "no"} or {@code "rda"}: {@link PassThrough}</li>
     * <li>{@code "l1"}: {@link L1}</li>
     * <li>{@code "l2"}: {@link L2}</li>
     * <li>{@code "elasticnet"}: {@link ElasticNet}</li>
     * </ul>
     *
     * @param options hyper-parameter options, also forwarded to the chosen regularizer
     * @return a new regularizer instance
     * @throws IllegalArgumentException if the name is not one of the values above
     */
    @Nonnull
    public static Regularization get(@Nonnull Map<String, String> options)
            throws IllegalArgumentException {
        final String regName = options.get("regularization");
        // "rda" maps to a no-op here as well. NOTE(review): presumably the RDA optimizer applies
        // its own regularization internally — confirm against the optimizer implementation.
        if (regName == null || "no".equalsIgnoreCase(regName)
                || "rda".equalsIgnoreCase(regName)) {
            return new PassThrough(options);
        }
        if ("l1".equalsIgnoreCase(regName)) {
            return new L1(options);
        }
        if ("l2".equalsIgnoreCase(regName)) {
            return new L2(options);
        }
        if ("elasticnet".equalsIgnoreCase(regName)) {
            return new ElasticNet(options);
        }
        throw new IllegalArgumentException("Unsupported regularization name: " + regName);
    }

    /**
     * Elastic-net penalty: a convex mix of the L1 and L2 (sub)gradients, weighted by
     * {@code l1_ratio}:
     * {@code l1Ratio * L1.getRegularizer(w) + (1 - l1Ratio) * L2.getRegularizer(w)}.
     */
    public static final class ElasticNet extends Regularization {
        /** Mixing ratio used when no {@code "l1_ratio"} option is supplied. */
        private static final float DEFAULT_L1_RATIO = 0.5f;

        @Nonnull
        private final L1 l1;
        @Nonnull
        private final L2 l2;
        /** Weight of the L1 term; the L2 term gets {@code 1 - l1Ratio}. */
        private final float l1Ratio;

        /**
         * @param options hyper-parameter options; reads {@code "lambda"} and {@code "l1_ratio"}
         * @throws IllegalArgumentException if {@code l1_ratio} is outside [0.0, 1.0] or NaN
         */
        public ElasticNet(@Nonnull Map<String, String> options) {
            super(options);
            this.l1 = new L1(options);
            this.l2 = new L2(options);
            this.l1Ratio = Primitives.parseFloat(options.get("l1_ratio"), DEFAULT_L1_RATIO);
            // Written as a negated range test so that NaN (which fails every comparison) is
            // rejected too.
            if (!(this.l1Ratio >= 0.0f && this.l1Ratio <= 1.0f)) {
                throw new IllegalArgumentException(
                    "L1 ratio should be in [0.0, 1.0], but got " + this.l1Ratio);
            }
        }

        @Override
        public float getRegularizer(float weight) {
            return this.l1Ratio * this.l1.getRegularizer(weight)
                    + (1.0f - this.l1Ratio) * this.l2.getRegularizer(weight);
        }

        @Override
        public void getHyperParameters(@Nonnull Map<String, Object> hyperParams) {
            super.getHyperParameters(hyperParams);
            hyperParams.put("regularization", "ElasticNet");
            hyperParams.put("l1_ratio", Float.valueOf(this.l1Ratio));
        }
    }

    /** L2 penalty: the (sub)gradient of {@code w^2 / 2} is {@code w}. */
    public static final class L2 extends Regularization {
        public L2(Map<String, String> options) {
            super(options);
        }

        @Override
        public float getRegularizer(float weight) {
            return weight;
        }

        @Override
        public void getHyperParameters(@Nonnull Map<String, Object> hyperParams) {
            super.getHyperParameters(hyperParams);
            hyperParams.put("regularization", "L2");
        }
    }

    /** L1 penalty: a subgradient of {@code |w|} is {@code sign(w)}. */
    public static final class L1 extends Regularization {
        public L1(Map<String, String> options) {
            super(options);
        }

        /**
         * Returns {@code +1} for positive weights and {@code -1} for negative weights. It returns
         * {@code 0} at exactly zero, the standard subgradient choice. Returning {@code -1} there
         * would push zero-valued weights in one direction on every update even when the loss
         * gradient is zero.
         */
        @Override
        public float getRegularizer(float weight) {
            return Math.signum(weight);
        }

        @Override
        public void getHyperParameters(@Nonnull Map<String, Object> hyperParams) {
            super.getHyperParameters(hyperParams);
            hyperParams.put("regularization", "L1");
        }
    }

    /** No-op regularizer: gradients are returned unchanged. */
    public static final class PassThrough extends Regularization {
        public PassThrough(Map<String, String> options) {
            super(options);
        }

        @Override
        public float getRegularizer(float weight) {
            return 0.0f;
        }

        /** Returns {@code gradient} unchanged; {@code lambda} is ignored. */
        @Override
        public float regularize(float weight, float gradient) {
            return gradient;
        }

        @Override
        public void getHyperParameters(@Nonnull Map<String, Object> hyperParams) {
            super.getHyperParameters(hyperParams);
            hyperParams.put("regularization", "no");
        }
    }
}

