/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.embeddings.graphsage;

import java.util.Arrays;
import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.functions.Relu;
import org.neo4j.gds.ml.core.functions.Sigmoid;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.utils.StringFormatting;
import org.neo4j.gds.utils.StringJoining;

/**
 * Activation functions supported by GraphSAGE layers.
 *
 * <p>Each constant supplies the activation applied to a matrix-valued {@link Variable}
 * and a scale used when initializing layer weights of a given shape.
 */
public enum ActivationFunction {
    SIGMOID {
        @Override
        public Function<Variable<Matrix>, Variable<Matrix>> activationFunction() {
            return Sigmoid::new;
        }

        /**
         * Returns {@code sqrt(2 / (rows + cols))}.
         * NOTE(review): this matches Glorot/Xavier-style scaling over both dimensions;
         * confirm against the weight initializer that consumes it.
         */
        @Override
        public double weightInitBound(int rows, int cols) {
            return Math.sqrt(2.0 / (double) (rows + cols));
        }
    },
    RELU {
        @Override
        public Function<Variable<Matrix>, Variable<Matrix>> activationFunction() {
            return Relu::new;
        }

        /**
         * Returns {@code sqrt(2 / cols)}.
         * NOTE(review): this resembles He initialization if {@code cols} is the fan-in;
         * confirm the orientation of the weight matrix with the caller.
         */
        @Override
        public double weightInitBound(int rows, int cols) {
            return Math.sqrt(2.0 / (double) cols);
        }
    };

    /** Names of all constants, in declaration order; used for validation and error messages. */
    private static final List<String> VALUES = Arrays.stream(values())
        .map(Enum::name)
        .collect(Collectors.toList());

    /**
     * Returns a factory that wraps an input variable in this activation.
     *
     * @return function mapping a matrix variable to its activated variable
     */
    public abstract Function<Variable<Matrix>, Variable<Matrix>> activationFunction();

    /**
     * Computes the weight-initialization scale for a weight matrix of the given shape.
     *
     * @param rows number of rows of the weight matrix
     * @param cols number of columns of the weight matrix
     * @return the bound/scale used for initialization (how it is applied is up to the caller)
     */
    public abstract double weightInitBound(int rows, int cols);

    /**
     * Looks up a constant by name, case-insensitively.
     *
     * @param activationFunction the constant name in any case
     * @return the matching constant
     * @throws IllegalArgumentException if no constant has that name (from {@link Enum#valueOf})
     */
    public static ActivationFunction of(String activationFunction) {
        return ActivationFunction.valueOf(StringFormatting.toUpperCaseWithLocale(activationFunction));
    }

    /**
     * Converts user-supplied configuration into an {@code ActivationFunction}.
     *
     * @param input either an {@code ActivationFunction} (returned as-is) or a case-insensitive name
     * @return the corresponding constant
     * @throws IllegalArgumentException if {@code input} is an unknown name, is of any other type,
     *                                  or is {@code null}
     */
    public static ActivationFunction parse(Object input) {
        if (input instanceof ActivationFunction) {
            return (ActivationFunction) input;
        }
        if (input instanceof String) {
            String normalized = StringFormatting.toUpperCaseWithLocale((String) input);
            if (!VALUES.contains(normalized)) {
                throw new IllegalArgumentException(StringFormatting.formatWithLocale(
                    "ActivationFunction `%s` is not supported. Must be one of: %s.",
                    input,
                    StringJoining.join(VALUES)
                ));
            }
            // Already upper-cased and validated, so resolve directly instead of normalizing again via of().
            return ActivationFunction.valueOf(normalized);
        }
        // Guard null explicitly: input.getClass() would otherwise throw a bare NullPointerException.
        String receivedType = input == null ? "null" : input.getClass().getSimpleName();
        throw new IllegalArgumentException(StringFormatting.formatWithLocale(
            "Expected ActivationFunction or String. Got %s.",
            receivedType
        ));
    }

    /**
     * Renders a constant for output; counterpart of {@link #parse(Object)}.
     *
     * @param af the constant to render
     * @return {@code af.toString()}
     */
    public static String toString(ActivationFunction af) {
        return af.toString();
    }
}

