/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.embeddings.graphsage;

import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.embeddings.graphsage.ActivationFunction;
import org.neo4j.gds.mem.MemoryUsage;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.functions.Weights;
import org.neo4j.gds.ml.core.subgraph.SubGraph;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.ml.core.tensor.Tensor;
import org.neo4j.gds.utils.StringFormatting;
import org.neo4j.gds.utils.StringJoining;

/**
 * Contract for a GraphSAGE aggregator: it turns a matrix variable plus a
 * {@link SubGraph} into a new matrix variable and exposes the weights,
 * type and activation function it is associated with.
 */
public interface Aggregator {

    /**
     * Aggregates the given matrix variable with respect to the given subgraph.
     *
     * @param input    the matrix variable to aggregate
     * @param subGraph the subgraph the aggregation is performed on
     * @return the aggregated matrix variable
     */
    Variable<Matrix> aggregate(Variable<Matrix> input, SubGraph subGraph);

    /**
     * @return the weights of this aggregator
     */
    List<Weights<? extends Tensor<?>>> weights();

    /**
     * @return the {@link AggregatorType} of this aggregator
     */
    AggregatorType type();

    /**
     * @return the activation function of this aggregator
     */
    ActivationFunction activationFunction();

    /**
     * The supported aggregator kinds. Each constant supplies its own memory estimation.
     */
    enum AggregatorType {
        MEAN {
            /**
             * Estimates, per bound, one double array of {@code nodeCount * inputDimension}
             * entries plus two double arrays of {@code nodeCount * embeddingDimension}
             * entries. The previous-node counts are not used by this estimation.
             */
            @Override
            public MemoryRange memoryEstimation(
                long minNodeCount,
                long maxNodeCount,
                long minPreviousNodeCount,
                long maxPreviousNodeCount,
                int inputDimension,
                int embeddingDimension
            ) {
                long lowerBound = bound(minNodeCount, inputDimension, embeddingDimension);
                long upperBound = bound(maxNodeCount, inputDimension, embeddingDimension);
                return MemoryRange.of(lowerBound, upperBound);
            }

            // Shared formula for the lower and the upper bound.
            private long bound(long nodeCount, int inputDimension, int embeddingDimension) {
                return MemoryUsage.sizeOfDoubleArray(nodeCount * inputDimension)
                    + 2L * MemoryUsage.sizeOfDoubleArray(nodeCount * embeddingDimension);
            }
        },
        POOL {
            /**
             * Estimates, per bound, three double arrays of
             * {@code previousNodeCount * embeddingDimension} entries plus six double arrays
             * of {@code nodeCount * embeddingDimension} entries. The input dimension is not
             * used by this estimation.
             */
            @Override
            public MemoryRange memoryEstimation(
                long minNodeCount,
                long maxNodeCount,
                long minPreviousNodeCount,
                long maxPreviousNodeCount,
                int inputDimension,
                int embeddingDimension
            ) {
                long lowerBound = bound(minPreviousNodeCount, minNodeCount, embeddingDimension);
                long upperBound = bound(maxPreviousNodeCount, maxNodeCount, embeddingDimension);
                return MemoryRange.of(lowerBound, upperBound);
            }

            // Shared formula for the lower and the upper bound.
            private long bound(long previousNodeCount, long nodeCount, int embeddingDimension) {
                return 3L * MemoryUsage.sizeOfDoubleArray(previousNodeCount * embeddingDimension)
                    + 6L * MemoryUsage.sizeOfDoubleArray(nodeCount * embeddingDimension);
            }
        };

        /** Names of all constants, used to validate and to list the accepted values. */
        private static final List<String> VALUES = Arrays
            .stream(AggregatorType.values())
            .map(AggregatorType::name)
            .collect(Collectors.toList());

        /**
         * Looks up a constant by name after upper-casing the argument.
         *
         * @param aggregatorType the name of the constant, in any letter case
         * @return the matching constant
         * @throws IllegalArgumentException if no constant has the upper-cased name
         *                                  (thrown by {@link Enum#valueOf})
         */
        public static AggregatorType of(String aggregatorType) {
            return AggregatorType.valueOf(StringFormatting.toUpperCaseWithLocale(aggregatorType));
        }

        /**
         * Converts a loosely typed value into an {@code AggregatorType}.
         * <p>
         * A {@code String} is upper-cased and matched against the constant names; an
         * {@code AggregatorType} is returned unchanged.
         *
         * @param input a {@code String} naming a constant, or an {@code AggregatorType}
         * @return the corresponding constant
         * @throws IllegalArgumentException if {@code input} is a string that does not name a
         *                                  constant, is {@code null}, or is of any other type
         */
        public static AggregatorType parse(Object input) {
            if (input instanceof String) {
                String upperCased = StringFormatting.toUpperCaseWithLocale((String) input);
                if (!VALUES.contains(upperCased)) {
                    throw new IllegalArgumentException(StringFormatting.formatWithLocale(
                        "Aggregator `%s` is not supported. Must be one of: %s.",
                        input,
                        StringJoining.join(VALUES)
                    ));
                }
                return AggregatorType.of(upperCased);
            }
            if (input instanceof AggregatorType) {
                return (AggregatorType) input;
            }
            // A null input has no class to report; name it explicitly so the caller gets the
            // same IllegalArgumentException as for any other unsupported input instead of a
            // NullPointerException.
            String actualType = input == null ? "null" : input.getClass().getSimpleName();
            throw new IllegalArgumentException(StringFormatting.formatWithLocale(
                "Expected Aggregator or String. Got %s.",
                actualType
            ));
        }

        /**
         * @param af the constant to render
         * @return the result of {@code af.toString()}
         */
        public static String toString(AggregatorType af) {
            return af.toString();
        }

        /**
         * Estimates the memory range required by an aggregator of this type.
         *
         * @param minNodeCount         node count used for the lower bound
         * @param maxNodeCount         node count used for the upper bound
         * @param minPreviousNodeCount previous node count used for the lower bound
         * @param maxPreviousNodeCount previous node count used for the upper bound
         * @param inputDimension       input dimension
         * @param embeddingDimension   embedding dimension
         * @return the estimated range, from lower bound to upper bound
         */
        public abstract MemoryRange memoryEstimation(
            long minNodeCount,
            long maxNodeCount,
            long minPreviousNodeCount,
            long maxPreviousNodeCount,
            int inputDimension,
            int embeddingDimension
        );
    }
}

