/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.embeddings.graphsage;

import java.util.List;
import java.util.Objects;
import java.util.function.Function;
import org.neo4j.gds.embeddings.graphsage.ActivationFunction;
import org.neo4j.gds.embeddings.graphsage.Aggregator;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.functions.ElementWiseMax;
import org.neo4j.gds.ml.core.functions.MatrixMultiplyWithTransposedSecondOperand;
import org.neo4j.gds.ml.core.functions.MatrixSum;
import org.neo4j.gds.ml.core.functions.MatrixVectorSum;
import org.neo4j.gds.ml.core.functions.Slice;
import org.neo4j.gds.ml.core.functions.Weights;
import org.neo4j.gds.ml.core.subgraph.BatchNeighbors;
import org.neo4j.gds.ml.core.subgraph.SubGraph;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.ml.core.tensor.Tensor;
import org.neo4j.gds.ml.core.tensor.Vector;

/**
 * GraphSAGE "pool" aggregator: combines the batch nodes' own representations with an
 * element-wise max over transformed representations from the sub-graph.
 *
 * <p>Each {@link #aggregate} call performs the following steps:
 * <ul>
 *   <li>Multiplies the previous-layer representations by the transposed {@code poolWeights},
 *       adds {@code bias} and applies the activation.</li>
 *   <li>Reduces that result with {@link ElementWiseMax} using the sub-graph's neighbor
 *       information.</li>
 *   <li>Multiplies the batch rows of the previous layer (selected via
 *       {@link SubGraph#batchIds()}) by the transposed {@code selfWeights}.</li>
 *   <li>Multiplies the pooled values by the transposed {@code neighborsWeights}.</li>
 *   <li>Sums the two products and applies the activation again.</li>
 * </ul>
 */
public class MaxPoolingAggregator
implements Aggregator {
    private final Weights<Matrix> poolWeights;
    private final Weights<Matrix> selfWeights;
    private final Weights<Matrix> neighborsWeights;
    private final Weights<Vector> bias;
    /** Computation-graph form of {@link #activation}; applied twice per {@link #aggregate} call. */
    private final Function<Variable<Matrix>, Variable<Matrix>> activationFunction;
    /** The configured activation, as returned by {@link #activationFunction()}. */
    private final ActivationFunction activation;

    /**
     * Creates a max-pooling aggregator from its trainable weights and activation.
     *
     * @param poolWeights        weights multiplied (transposed) with the previous-layer
     *                           representations before pooling
     * @param selfWeights        weights multiplied (transposed) with the batch nodes' own
     *                           previous-layer representations
     * @param neighborsWeights   weights multiplied (transposed) with the pooled representation
     * @param bias               vector added to the pool-weighted representations before activation
     * @param activationFunction activation applied before pooling and to the final sum
     * @throws NullPointerException if any argument is {@code null}
     */
    public MaxPoolingAggregator(Weights<Matrix> poolWeights, Weights<Matrix> selfWeights, Weights<Matrix> neighborsWeights, Weights<Vector> bias, ActivationFunction activationFunction) {
        // Fail fast: weights() builds its result with List.of, which rejects nulls, so a null
        // weight would otherwise only surface later, far from where it was supplied.
        this.poolWeights = Objects.requireNonNull(poolWeights, "poolWeights");
        this.selfWeights = Objects.requireNonNull(selfWeights, "selfWeights");
        this.neighborsWeights = Objects.requireNonNull(neighborsWeights, "neighborsWeights");
        this.bias = Objects.requireNonNull(bias, "bias");
        this.activation = Objects.requireNonNull(activationFunction, "activationFunction");
        this.activationFunction = activationFunction.activationFunction();
    }

    @Override
    public Variable<Matrix> aggregate(Variable<Matrix> previousLayerRepresentations, SubGraph subGraph) {
        // activation(previous · poolWeightsᵀ + bias), computed for all previous-layer rows.
        Variable<Matrix> weightedPreviousLayer =
            MatrixMultiplyWithTransposedSecondOperand.of(previousLayerRepresentations, poolWeights);
        Variable<Matrix> biasedWeightedPreviousLayer = new MatrixVectorSum(weightedPreviousLayer, bias);
        Variable<Matrix> neighborhoodActivations = activationFunction.apply(biasedWeightedPreviousLayer);

        // Element-wise max reduction driven by the sub-graph's neighbor structure.
        // NOTE(review): cast kept from the original; presumably redundant if SubGraph implements
        // BatchNeighbors — confirm before removing.
        Variable<Matrix> elementwiseMax = new ElementWiseMax(neighborhoodActivations, (BatchNeighbors) subGraph);

        // Only the batch nodes' own previous-layer rows contribute the "self" term.
        Variable<Matrix> selfPreviousLayer = new Slice(previousLayerRepresentations, subGraph.batchIds());
        Variable<Matrix> self = MatrixMultiplyWithTransposedSecondOperand.of(selfPreviousLayer, selfWeights);
        Variable<Matrix> neighbors = MatrixMultiplyWithTransposedSecondOperand.of(elementwiseMax, neighborsWeights);

        Variable<Matrix> sum = new MatrixSum(List.of(self, neighbors));
        return activationFunction.apply(sum);
    }

    @Override
    public List<Weights<? extends Tensor<?>>> weights() {
        return List.of(this.poolWeights, this.selfWeights, this.neighborsWeights, this.bias);
    }

    @Override
    public Aggregator.AggregatorType type() {
        return Aggregator.AggregatorType.POOL;
    }

    @Override
    public ActivationFunction activationFunction() {
        return this.activation;
    }

    /** Returns the current data of the pooling weights. */
    public Matrix poolWeights() {
        return this.poolWeights.data();
    }

    /** Returns the current data of the self weights. */
    public Matrix selfWeights() {
        return this.selfWeights.data();
    }

    /** Returns the current data of the neighbor weights. */
    public Matrix neighborsWeights() {
        return this.neighborsWeights.data();
    }

    /** Returns the current data of the bias vector. */
    public Vector bias() {
        return this.bias.data();
    }
}

