/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.interop.tensorflow;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import org.tensorflow.Graph;
import org.tensorflow.Operand;
import org.tensorflow.framework.optimizers.AdaDelta;
import org.tensorflow.framework.optimizers.AdaGrad;
import org.tensorflow.framework.optimizers.AdaGradDA;
import org.tensorflow.framework.optimizers.Adam;
import org.tensorflow.framework.optimizers.Adamax;
import org.tensorflow.framework.optimizers.Ftrl;
import org.tensorflow.framework.optimizers.GradientDescent;
import org.tensorflow.framework.optimizers.Momentum;
import org.tensorflow.framework.optimizers.Nadam;
import org.tensorflow.framework.optimizers.Optimizer;
import org.tensorflow.framework.optimizers.RMSProp;
import org.tensorflow.op.Op;
import org.tensorflow.types.family.TNumber;

/**
 * The TensorFlow-Java gradient optimisers that can be applied to a loss, together with the
 * exact set of hyperparameter names each one requires.
 * <p>
 * The parameter map passed to {@link #applyOptimiser(Graph, Operand, Map)} must contain
 * exactly the names returned by {@link #getParameterNames()} for the chosen constant.
 */
public enum GradientOptimiser {
    /** {@link AdaDelta}; requires learningRate, rho and epsilon. */
    ADADELTA("learningRate", "rho", "epsilon"),
    /** {@link AdaGrad}; requires learningRate and initialAccumulatorValue. */
    ADAGRAD("learningRate", "initialAccumulatorValue"),
    /** {@link AdaGradDA}; requires learningRate, initialAccumulatorValue, l1Strength and l2Strength. */
    ADAGRADDA("learningRate", "initialAccumulatorValue", "l1Strength", "l2Strength"),
    /** {@link Adam}; requires learningRate, betaOne, betaTwo and epsilon. */
    ADAM("learningRate", "betaOne", "betaTwo", "epsilon"),
    /** {@link Adamax}; requires learningRate, betaOne, betaTwo and epsilon. */
    ADAMAX("learningRate", "betaOne", "betaTwo", "epsilon"),
    /**
     * {@link Ftrl}; requires learningRate, learningRatePower, initialAccumulatorValue,
     * l1Strength, l2Strength and l2ShrinkageRegularizationStrength.
     */
    FTRL("learningRate", "learningRatePower", "initialAccumulatorValue", "l1Strength", "l2Strength", "l2ShrinkageRegularizationStrength"),
    /** {@link GradientDescent}; requires learningRate. */
    GRADIENT_DESCENT("learningRate"),
    /** {@link Momentum} with Nesterov momentum disabled; requires learningRate and momentum. */
    MOMENTUM("learningRate", "momentum"),
    /** {@link Momentum} with Nesterov momentum enabled; requires learningRate and momentum. */
    NESTEROV("learningRate", "momentum"),
    /** {@link Nadam}; requires learningRate, betaOne, betaTwo and epsilon. */
    NADAM("learningRate", "betaOne", "betaTwo", "epsilon"),
    /** Non-centred {@link RMSProp}; requires learningRate, decay, momentum and epsilon. */
    RMSPROP("learningRate", "decay", "momentum", "epsilon");

    /** The parameter names this optimiser requires (unmodifiable). */
    private final Set<String> args;

    private GradientOptimiser(String ... args) {
        this.args = Collections.unmodifiableSet(new HashSet<String>(Arrays.asList(args)));
    }

    /**
     * Returns the names of the hyperparameters this optimiser requires.
     *
     * @return An unmodifiable set of parameter names.
     */
    public Set<String> getParameterNames() {
        return this.args;
    }

    /**
     * Checks that the supplied names are exactly the required parameter names, with none
     * missing and none extra.
     *
     * @param paramNames The parameter names to check; {@code null} is treated as invalid.
     * @return True if {@code paramNames} equals {@link #getParameterNames()} as a set.
     */
    public boolean validateParamNames(Set<String> paramNames) {
        // Set.equals performs the same size + containsAll comparison as before, but returns
        // false for a null argument instead of throwing.
        return this.args.equals(paramNames);
    }

    /**
     * Builds this optimiser inside {@code graph} and returns the op that minimises {@code loss}.
     *
     * @param graph The graph to build the optimiser in.
     * @param loss The loss to minimise.
     * @param optimiserParams The hyperparameter values, keyed by exactly the names in
     *                        {@link #getParameterNames()}.
     * @param <T> The numeric type of the loss.
     * @return The minimisation op, named {@code "tribuo-" + this + "-minimize"}.
     * @throws NullPointerException If {@code optimiserParams} is null.
     * @throws IllegalArgumentException If the parameter names do not exactly match
     *                                  {@link #getParameterNames()}, or a parameter value is null.
     */
    public <T extends TNumber> Op applyOptimiser(Graph graph, Operand<T> loss, Map<String, Float> optimiserParams) {
        Objects.requireNonNull(optimiserParams, "optimiserParams");
        if (!this.validateParamNames(optimiserParams.keySet())) {
            throw new IllegalArgumentException("Invalid optimiser parameters, expected " + this.args.toString() + ", found " + optimiserParams.keySet().toString());
        }
        // Declared as the common base class: the concrete optimisers are siblings, not
        // subclasses of GradientDescent, so a narrower type would not accept them all.
        Optimizer optimiser;
        switch (this) {
            case ADADELTA:
                optimiser = new AdaDelta(graph, "tribuo-adadelta",
                        param(optimiserParams, "learningRate"),
                        param(optimiserParams, "rho"),
                        param(optimiserParams, "epsilon"));
                break;
            case ADAGRAD:
                optimiser = new AdaGrad(graph, "tribuo-adagrad",
                        param(optimiserParams, "learningRate"),
                        param(optimiserParams, "initialAccumulatorValue"));
                break;
            case ADAGRADDA:
                optimiser = new AdaGradDA(graph, "tribuo-adagradda",
                        param(optimiserParams, "learningRate"),
                        param(optimiserParams, "initialAccumulatorValue"),
                        param(optimiserParams, "l1Strength"),
                        param(optimiserParams, "l2Strength"));
                break;
            case ADAM:
                optimiser = new Adam(graph, "tribuo-adam",
                        param(optimiserParams, "learningRate"),
                        param(optimiserParams, "betaOne"),
                        param(optimiserParams, "betaTwo"),
                        param(optimiserParams, "epsilon"));
                break;
            case ADAMAX:
                optimiser = new Adamax(graph, "tribuo-adamax",
                        param(optimiserParams, "learningRate"),
                        param(optimiserParams, "betaOne"),
                        param(optimiserParams, "betaTwo"),
                        param(optimiserParams, "epsilon"));
                break;
            case FTRL:
                optimiser = new Ftrl(graph, "tribuo-ftrl",
                        param(optimiserParams, "learningRate"),
                        param(optimiserParams, "learningRatePower"),
                        param(optimiserParams, "initialAccumulatorValue"),
                        param(optimiserParams, "l1Strength"),
                        param(optimiserParams, "l2Strength"),
                        param(optimiserParams, "l2ShrinkageRegularizationStrength"));
                break;
            case GRADIENT_DESCENT:
                optimiser = new GradientDescent(graph, "tribuo-sgd",
                        param(optimiserParams, "learningRate"));
                break;
            case MOMENTUM:
                optimiser = new Momentum(graph, "tribuo-momentum",
                        param(optimiserParams, "learningRate"),
                        param(optimiserParams, "momentum"),
                        false);
                break;
            case NESTEROV:
                // Same optimiser as MOMENTUM, with the Nesterov flag switched on.
                optimiser = new Momentum(graph, "tribuo-nesterov",
                        param(optimiserParams, "learningRate"),
                        param(optimiserParams, "momentum"),
                        true);
                break;
            case NADAM:
                optimiser = new Nadam(graph, "tribuo-nadam",
                        param(optimiserParams, "learningRate"),
                        param(optimiserParams, "betaOne"),
                        param(optimiserParams, "betaTwo"),
                        param(optimiserParams, "epsilon"));
                break;
            case RMSPROP:
                optimiser = new RMSProp(graph, "tribuo-rmsprop",
                        param(optimiserParams, "learningRate"),
                        param(optimiserParams, "decay"),
                        param(optimiserParams, "momentum"),
                        param(optimiserParams, "epsilon"),
                        false);
                break;
            default:
                // Unreachable while every constant has a case above; it satisfies the
                // definite-assignment check and catches constants added without a branch.
                throw new IllegalStateException("Unimplemented switch branch " + this.toString());
        }
        return optimiser.minimize(loss, "tribuo-" + this.toString() + "-minimize");
    }

    /**
     * Reads one hyperparameter from the map, rejecting a present-but-null value.
     *
     * @param optimiserParams The parameter map (keys already validated by the caller).
     * @param name The parameter name to read.
     * @return The parameter value as a primitive float.
     * @throws IllegalArgumentException If the value stored under {@code name} is null.
     */
    private static float param(Map<String, Float> optimiserParams, String name) {
        Float value = optimiserParams.get(name);
        if (value == null) {
            throw new IllegalArgumentException("Optimiser parameter '" + name + "' must not be null");
        }
        return value;
    }
}

