/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.embeddings.graphsage;

import java.util.Random;
import org.neo4j.gds.embeddings.graphsage.ActivationFunction;
import org.neo4j.gds.embeddings.graphsage.Layer;
import org.neo4j.gds.embeddings.graphsage.LayerConfig;
import org.neo4j.gds.embeddings.graphsage.MaxPoolAggregatingLayer;
import org.neo4j.gds.embeddings.graphsage.MeanAggregatingLayer;
import org.neo4j.gds.ml.core.functions.Weights;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.ml.core.tensor.Tensor;
import org.neo4j.gds.ml.core.tensor.Vector;
import org.neo4j.gds.utils.StringFormatting;

/**
 * Static factory that builds GraphSAGE aggregating {@link Layer}s from a {@link LayerConfig}.
 *
 * <p>Not instantiable; all members are static.
 */
public final class LayerFactory {
    private LayerFactory() {
    }

    /**
     * Creates a layer for the aggregator type configured in {@code layerConfig}.
     *
     * <p>Every weight matrix is produced by {@link #generateWeights(int, int, double, long)} using
     * the bound returned by {@link ActivationFunction#weightInitBound(int, int)} for its shape:
     * <ul>
     *   <li>{@code MEAN}: a {@link MeanAggregatingLayer} with a single {@code rows x cols} matrix
     *       seeded with {@code randomSeed}.</li>
     *   <li>{@code POOL}: a {@link MaxPoolAggregatingLayer} with pool weights ({@code rows x cols},
     *       seed {@code randomSeed}), self weights ({@code rows x cols}, seed {@code randomSeed + 1}),
     *       neighbor weights ({@code rows x rows}, seed {@code randomSeed + 2}) and a bias vector of
     *       length {@code rows} filled with {@code 0.0}.</li>
     * </ul>
     *
     * @param layerConfig supplies rows, cols, activation function, random seed, sample size and
     *                    aggregator type
     * @return the newly created layer
     * @throws IllegalArgumentException if the aggregator type is neither {@code MEAN} nor {@code POOL}
     */
    public static Layer createLayer(LayerConfig layerConfig) {
        int rows = layerConfig.rows();
        int cols = layerConfig.cols();
        ActivationFunction activationFunction = layerConfig.activationFunction();
        long randomSeed = layerConfig.randomSeed();

        // Generated before dispatch: used as the MEAN weights and as the POOL "pool" weights.
        Weights<Matrix> weights = generateWeights(
            rows,
            cols,
            activationFunction.weightInitBound(rows, cols),
            randomSeed
        );

        switch (layerConfig.aggregatorType()) {
            case MEAN: {
                return new MeanAggregatingLayer(weights, layerConfig.sampleSize(), activationFunction);
            }
            case POOL: {
                Weights<Matrix> poolWeights = weights;
                // Offset seeds: with the same seed, shape and bound, the self weights would be an
                // exact copy of the pool weights.
                Weights<Matrix> selfWeights = generateWeights(
                    rows,
                    cols,
                    activationFunction.weightInitBound(rows, cols),
                    randomSeed + 1L
                );
                Weights<Matrix> neighborsWeights = generateWeights(
                    rows,
                    rows,
                    activationFunction.weightInitBound(rows, rows),
                    randomSeed + 2L
                );
                Weights<Vector> bias = new Weights<>(Vector.create(0.0, rows));
                return new MaxPoolAggregatingLayer(
                    layerConfig.sampleSize(),
                    poolWeights,
                    selfWeights,
                    neighborsWeights,
                    bias,
                    activationFunction
                );
            }
        }
        throw new IllegalArgumentException(StringFormatting.formatWithLocale(
            "Aggregator: %s is unknown",
            layerConfig.aggregatorType()
        ));
    }

    /**
     * Creates a {@code rows x cols} weight matrix whose {@code rows * cols} entries are drawn
     * uniformly from {@code [-weightBound, weightBound)}.
     *
     * <p>Values come from a {@link Random} seeded with {@code randomSeed}, so the same arguments
     * always produce the same matrix.
     *
     * @param rows        number of matrix rows
     * @param cols        number of matrix columns
     * @param weightBound absolute bound of the sampled values; must be strictly positive
     * @param randomSeed  seed for the random generator
     * @return the initialised weights
     * @throws ArithmeticException      if {@code rows * cols} overflows an {@code int}
     * @throws IllegalArgumentException if {@code rows * cols} is negative or {@code weightBound}
     *                                  is not strictly positive (raised by {@link Random#doubles})
     */
    public static Weights<Matrix> generateWeights(int rows, int cols, double weightBound, long randomSeed) {
        double[] data = new Random(randomSeed)
            .doubles(Math.multiplyExact(rows, cols), -weightBound, weightBound)
            .toArray();
        return new Weights<>(new Matrix(data, rows, cols));
    }
}

