/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.modelimport.keras.utils;

import java.util.Map;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.nd4j.linalg.learning.config.AdaDelta;
import org.nd4j.linalg.learning.config.AdaGrad;
import org.nd4j.linalg.learning.config.AdaMax;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.learning.config.Nadam;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.learning.config.RmsProp;
import org.nd4j.linalg.schedule.ISchedule;
import org.nd4j.linalg.schedule.InverseSchedule;
import org.nd4j.linalg.schedule.ScheduleType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class KerasOptimizerUtils {
    private static final Logger log = LoggerFactory.getLogger(KerasOptimizerUtils.class);
    protected static final String LR = "lr";
    protected static final String LR2 = "learning_rate";
    protected static final String EPSILON = "epsilon";
    protected static final String MOMENTUM = "momentum";
    protected static final String BETA_1 = "beta_1";
    protected static final String BETA_2 = "beta_2";
    protected static final String DECAY = "decay";
    protected static final String RHO = "rho";
    protected static final String SCHEDULE_DECAY = "schedule_decay";

    public static IUpdater mapOptimizer(Map<String, Object> optimizerConfig) throws UnsupportedKerasConfigurationException, InvalidKerasConfigurationException {
        AdaDelta dl4jOptimizer;
        if (!optimizerConfig.containsKey("class_name")) {
            throw new InvalidKerasConfigurationException("Optimizer config does not contain a name field.");
        }
        String optimizerName = (String)optimizerConfig.get("class_name");
        if (!optimizerConfig.containsKey("config")) {
            throw new InvalidKerasConfigurationException("Field config missing from layer config");
        }
        Map optimizerParameters = (Map)optimizerConfig.get("config");
        switch (optimizerName) {
            case "Adam": {
                double lr = (Double)(optimizerParameters.containsKey(LR) ? optimizerParameters.get(LR) : optimizerParameters.get(LR2));
                double beta1 = (Double)optimizerParameters.get(BETA_1);
                double beta2 = (Double)optimizerParameters.get(BETA_2);
                double epsilon = (Double)optimizerParameters.get(EPSILON);
                double decay = (Double)optimizerParameters.get(DECAY);
                dl4jOptimizer = new Adam.Builder().beta1(beta1).beta2(beta2).epsilon(epsilon).learningRate(lr).learningRateSchedule((ISchedule)(decay == 0.0 ? null : new InverseSchedule(ScheduleType.ITERATION, lr, decay, 1.0))).build();
                break;
            }
            case "Adadelta": {
                double rho = (Double)optimizerParameters.get(RHO);
                double epsilon = (Double)optimizerParameters.get(EPSILON);
                dl4jOptimizer = new AdaDelta.Builder().epsilon(epsilon).rho(rho).build();
                break;
            }
            case "Adgrad": {
                double lr = (Double)(optimizerParameters.containsKey(LR) ? optimizerParameters.get(LR) : optimizerParameters.get(LR2));
                double epsilon = (Double)optimizerParameters.get(EPSILON);
                double decay = (Double)optimizerParameters.get(DECAY);
                dl4jOptimizer = new AdaGrad.Builder().epsilon(epsilon).learningRate(lr).learningRateSchedule((ISchedule)(decay == 0.0 ? null : new InverseSchedule(ScheduleType.ITERATION, lr, decay, 1.0))).build();
                break;
            }
            case "Adamax": {
                double lr = (Double)(optimizerParameters.containsKey(LR) ? optimizerParameters.get(LR) : optimizerParameters.get(LR2));
                double beta1 = (Double)optimizerParameters.get(BETA_1);
                double beta2 = (Double)optimizerParameters.get(BETA_2);
                double epsilon = (Double)optimizerParameters.get(EPSILON);
                dl4jOptimizer = new AdaMax(lr, beta1, beta2, epsilon);
                break;
            }
            case "Nadam": {
                double lr = (Double)(optimizerParameters.containsKey(LR) ? optimizerParameters.get(LR) : optimizerParameters.get(LR2));
                double beta1 = (Double)optimizerParameters.get(BETA_1);
                double beta2 = (Double)optimizerParameters.get(BETA_2);
                double epsilon = (Double)optimizerParameters.get(EPSILON);
                double scheduleDecay = optimizerParameters.getOrDefault(SCHEDULE_DECAY, 0.0);
                dl4jOptimizer = new Nadam.Builder().beta1(beta1).beta2(beta2).epsilon(epsilon).learningRate(lr).learningRateSchedule((ISchedule)(scheduleDecay == 0.0 ? null : new InverseSchedule(ScheduleType.ITERATION, lr, scheduleDecay, 1.0))).build();
                break;
            }
            case "SGD": {
                double lr = (Double)(optimizerParameters.containsKey(LR) ? optimizerParameters.get(LR) : optimizerParameters.get(LR2));
                double momentum = (Double)(optimizerParameters.containsKey(EPSILON) ? optimizerParameters.get(EPSILON) : optimizerParameters.get(MOMENTUM));
                double decay = (Double)optimizerParameters.get(DECAY);
                dl4jOptimizer = new Nesterovs.Builder().momentum(momentum).learningRate(lr).learningRateSchedule((ISchedule)(decay == 0.0 ? null : new InverseSchedule(ScheduleType.ITERATION, lr, decay, 1.0))).build();
                break;
            }
            case "RMSprop": {
                double lr = (Double)(optimizerParameters.containsKey(LR) ? optimizerParameters.get(LR) : optimizerParameters.get(LR2));
                double rho = (Double)optimizerParameters.get(RHO);
                double epsilon = (Double)optimizerParameters.get(EPSILON);
                double decay = (Double)optimizerParameters.get(DECAY);
                dl4jOptimizer = new RmsProp.Builder().epsilon(epsilon).rmsDecay(rho).learningRate(lr).learningRateSchedule((ISchedule)(decay == 0.0 ? null : new InverseSchedule(ScheduleType.ITERATION, lr, decay, 1.0))).build();
                break;
            }
            default: {
                throw new UnsupportedKerasConfigurationException("Optimizer with name " + optimizerName + "can not bematched to a DL4J optimizer. Note that custom TFOptimizers are not supported by model import");
            }
        }
        return dl4jOptimizer;
    }
}

