/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.embeddings.graphsage;

import java.util.Map;
import org.neo4j.gds.NodeLabel;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.IdMap;
import org.neo4j.gds.core.utils.paged.HugeObjectArray;
import org.neo4j.gds.embeddings.graphsage.FeatureFunction;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.functions.LabelwiseFeatureProjection;
import org.neo4j.gds.ml.core.functions.Weights;
import org.neo4j.gds.ml.core.tensor.Matrix;

/**
 * {@link FeatureFunction} that resolves one {@link NodeLabel} per requested node and
 * hands the labels, together with the per-label weights, to a
 * {@link LabelwiseFeatureProjection}.
 */
public class MultiLabelFeatureFunction
implements FeatureFunction {
    private final Map<NodeLabel, Weights<Matrix>> weightsByLabel;
    private final int projectedFeatureDimension;

    /**
     * @param weightsByLabel            weight matrices keyed by node label; stored by reference (not copied)
     * @param projectedFeatureDimension dimension value forwarded to the projection built in {@link #apply}
     */
    public MultiLabelFeatureFunction(Map<NodeLabel, Weights<Matrix>> weightsByLabel, int projectedFeatureDimension) {
        this.weightsByLabel = weightsByLabel;
        this.projectedFeatureDimension = projectedFeatureDimension;
    }

    /**
     * @return the same map instance that was passed to the constructor
     */
    public Map<NodeLabel, Weights<Matrix>> weightsByLabel() {
        return this.weightsByLabel;
    }

    /**
     * Looks up a label for every entry of {@code nodeIds} via
     * {@link Graph#forEachNodeLabel} and wraps the result in a new
     * {@link LabelwiseFeatureProjection}.
     *
     * @param graph    graph queried for the node labels
     * @param nodeIds  ids of the nodes to project; one label is resolved per entry, in order
     * @param features node features, passed through unchanged to the projection
     * @return a new {@link LabelwiseFeatureProjection} over the given nodes; the label entry of a
     *         node is {@code null} when the graph did not report any label for it
     */
    @Override
    public Variable<Matrix> apply(Graph graph, long[] nodeIds, HugeObjectArray<double[]> features) {
        NodeLabel[] labels = new NodeLabel[nodeIds.length];
        SingleNodeLabelConsumer consumer = new SingleNodeLabelConsumer();
        for (int i = 0; i < nodeIds.length; ++i) {
            // The consumer is shared across iterations: clear the previous result so that a node
            // for which the callback is never invoked does not inherit its predecessor's label.
            consumer.nodeLabel = null;
            graph.forEachNodeLabel(nodeIds[i], consumer);
            labels[i] = consumer.nodeLabel;
        }
        return new LabelwiseFeatureProjection(nodeIds, features, this.weightsByLabel, this.projectedFeatureDimension, labels);
    }

    /**
     * @return the dimension value that was passed to the constructor
     */
    public int projectedFeatureDimension() {
        return this.projectedFeatureDimension;
    }

    /**
     * Label callback that remembers the most recent label it was handed.
     */
    private static class SingleNodeLabelConsumer
    implements IdMap.NodeLabelConsumer {
        /** Last label received, or {@code null} if none has been received since the last reset. */
        private NodeLabel nodeLabel;

        private SingleNodeLabelConsumer() {
        }

        /**
         * Records {@code nodeLabel} and returns {@code false}.
         * NOTE(review): {@code false} presumably tells the caller to stop after the first label;
         * confirm against the {@code IdMap.NodeLabelConsumer} contract.
         */
        public boolean accept(NodeLabel nodeLabel) {
            this.nodeLabel = nodeLabel;
            return false;
        }
    }
}

