/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.embeddings.graphsage;

import org.neo4j.gds.embeddings.graphsage.ActivationFunction;
import org.neo4j.gds.embeddings.graphsage.Aggregator;
import org.neo4j.gds.embeddings.graphsage.Layer;
import org.neo4j.gds.embeddings.graphsage.MaxPoolingAggregator;
import org.neo4j.gds.ml.core.functions.Weights;
import org.neo4j.gds.ml.core.subgraph.NeighborhoodSampler;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.ml.core.tensor.Vector;

/**
 * A GraphSAGE {@link Layer} that combines a fixed neighbourhood sample size with a
 * max-pooling aggregator.
 *
 * <p>Instances only hold configuration:
 * <ul>
 *   <li>The weight tensors, bias and activation function are passed unchanged to a
 *       newly constructed {@link MaxPoolingAggregator} each time {@link #aggregator()}
 *       is called.</li>
 *   <li>A single {@link NeighborhoodSampler} is created once, in the constructor, from
 *       the supplied {@code randomState}. That same sampler is returned by every call
 *       to {@link #sampler()}.</li>
 * </ul>
 */
public class MaxPoolAggregatingLayer implements Layer {

    // Neighbour-sampling configuration.
    private final int sampleSize;
    private final NeighborhoodSampler sampler;

    // Parameters handed as-is to every aggregator this layer creates.
    private final Weights<Matrix> poolWeights;
    private final Weights<Matrix> selfWeights;
    private final Weights<Matrix> neighborsWeights;
    private final Weights<Vector> bias;
    private final ActivationFunction activationFunction;

    /**
     * Creates a layer from its sampling configuration and aggregator parameters.
     *
     * @param sampleSize         value reported by {@link #sampleSize()}
     * @param poolWeights        pooling weights passed to the aggregator
     * @param selfWeights        self weights passed to the aggregator
     * @param neighborsWeights   neighbour weights passed to the aggregator
     * @param bias               bias vector passed to the aggregator
     * @param activationFunction activation function passed to the aggregator
     * @param randomState        passed to the {@link NeighborhoodSampler} constructor
     *                           (presumably used as its random seed — confirm in
     *                           {@code NeighborhoodSampler})
     */
    public MaxPoolAggregatingLayer(
        int sampleSize,
        Weights<Matrix> poolWeights,
        Weights<Matrix> selfWeights,
        Weights<Matrix> neighborsWeights,
        Weights<Vector> bias,
        ActivationFunction activationFunction,
        long randomState
    ) {
        this.sampleSize = sampleSize;
        this.sampler = new NeighborhoodSampler(randomState);

        this.poolWeights = poolWeights;
        this.selfWeights = selfWeights;
        this.neighborsWeights = neighborsWeights;
        this.bias = bias;
        this.activationFunction = activationFunction;
    }

    /**
     * Returns the sample size given at construction.
     *
     * @return the configured sample size
     */
    @Override
    public int sampleSize() {
        return sampleSize;
    }

    /**
     * Builds a new max-pooling aggregator from this layer's parameters.
     *
     * <p>A new object is created on every call. The weight and bias objects themselves
     * are shared, not copied.
     *
     * @return a freshly constructed {@link MaxPoolingAggregator}
     */
    @Override
    public Aggregator aggregator() {
        Aggregator freshAggregator = new MaxPoolingAggregator(
            poolWeights,
            selfWeights,
            neighborsWeights,
            bias,
            activationFunction
        );
        return freshAggregator;
    }

    /**
     * Returns the neighbourhood sampler created in the constructor.
     *
     * @return the same {@link NeighborhoodSampler} instance on every call
     */
    @Override
    public NeighborhoodSampler sampler() {
        return sampler;
    }
}

