/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.embeddings.graphsage;

import java.util.List;
import java.util.function.Function;
import org.neo4j.gds.embeddings.graphsage.ActivationFunction;
import org.neo4j.gds.embeddings.graphsage.Aggregator;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.functions.MatrixMultiplyWithTransposedSecondOperand;
import org.neo4j.gds.ml.core.functions.MultiMean;
import org.neo4j.gds.ml.core.functions.Weights;
import org.neo4j.gds.ml.core.subgraph.BatchNeighbors;
import org.neo4j.gds.ml.core.subgraph.SubGraph;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.ml.core.tensor.Tensor;

/**
 * GraphSAGE aggregator that combines neighbour representations by averaging them.
 *
 * <p>{@link #aggregate(Variable, SubGraph)} builds a small computation graph with three steps:
 * <ol>
 *   <li>a {@link MultiMean} over the previous layer's representations, driven by the
 *       sub-graph's batch neighbours;</li>
 *   <li>a multiplication of that result with the transposed weight matrix
 *       ({@link MatrixMultiplyWithTransposedSecondOperand});</li>
 *   <li>the activation function supplied at construction time.</li>
 * </ol>
 *
 * <p>This aggregator holds a single weight matrix and no separate bias term, so
 * {@link #weights()} and {@link #weightsWithoutBias()} return the same list.
 */
public class MeanAggregator
implements Aggregator {
    /** The trainable weight matrix; it is the second (transposed) operand of the product. */
    private final Weights<Matrix> weights;
    /** Function obtained from {@link #activation}, applied to the weighted mean. */
    private final Function<Variable<Matrix>, Variable<Matrix>> activationFunction;
    /** The activation descriptor, kept so {@link #activationFunction()} can report it. */
    private final ActivationFunction activation;

    /**
     * Creates a mean aggregator.
     *
     * @param weights            weight matrix multiplied (transposed) with the neighbour means
     * @param activationFunction activation applied to the product; must not be {@code null},
     *                           because its function is resolved eagerly here
     */
    public MeanAggregator(Weights<Matrix> weights, ActivationFunction activationFunction) {
        this.weights = weights;
        this.activation = activationFunction;
        this.activationFunction = activationFunction.activationFunction();
    }

    /**
     * Builds the aggregation sub-graph for one layer.
     *
     * @param previousLayerRepresentations representations produced by the previous layer
     * @param subGraph                     sub-graph whose batch neighbours drive the mean
     * @return {@code activation(mean(previous) x weights^T)} as a lazily evaluated variable
     */
    @Override
    public Variable<Matrix> aggregate(Variable<Matrix> previousLayerRepresentations, SubGraph subGraph) {
        // NOTE(review): the BatchNeighbors cast is kept from the decompiled source; drop it if
        // SubGraph is confirmed to implement BatchNeighbors.
        Variable<Matrix> means = new MultiMean(previousLayerRepresentations, (BatchNeighbors) subGraph);
        Variable<Matrix> product = MatrixMultiplyWithTransposedSecondOperand.of(means, this.weights);
        return this.activationFunction.apply(product);
    }

    /**
     * Returns all trainable weights of this aggregator.
     *
     * @return an immutable single-element list containing the weight matrix
     */
    @Override
    public List<Weights<? extends Tensor<?>>> weights() {
        return List.of(this.weights);
    }

    /**
     * Returns the trainable weights excluding bias terms.
     *
     * <p>There is no bias term here, so this is identical to {@link #weights()}.
     *
     * @return an immutable single-element list containing the weight matrix
     */
    @Override
    public List<Weights<? extends Tensor<?>>> weightsWithoutBias() {
        return List.of(this.weights);
    }

    /**
     * Identifies this aggregator kind.
     *
     * @return {@link Aggregator.AggregatorType#MEAN}
     */
    @Override
    public Aggregator.AggregatorType type() {
        return Aggregator.AggregatorType.MEAN;
    }

    /**
     * Returns the activation function this aggregator was constructed with.
     *
     * @return the activation descriptor passed to the constructor
     */
    @Override
    public ActivationFunction activationFunction() {
        return this.activation;
    }

    /**
     * Returns the current contents of the weight matrix.
     *
     * @return the matrix held by the weights variable (not a copy)
     */
    public Matrix weightsData() {
        return this.weights.data();
    }
}

