/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.embeddings.graphsage.algo;

import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.Collectors;
import org.neo4j.gds.NodeLabel;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.schema.GraphSchema;
import org.neo4j.gds.config.ToMapConvertible;
import org.neo4j.gds.core.model.Model;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.embeddings.graphsage.GraphSageHelper;
import org.neo4j.gds.embeddings.graphsage.GraphSageModelTrainer;
import org.neo4j.gds.embeddings.graphsage.LayerFactory;
import org.neo4j.gds.embeddings.graphsage.ModelData;
import org.neo4j.gds.embeddings.graphsage.MultiLabelFeatureFunction;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageTrain;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageTrainConfig;
import org.neo4j.gds.embeddings.graphsage.algo.MultiLabelFeatureExtractors;
import org.neo4j.gds.ml.core.functions.Weights;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.model.ModelConfig;

public class MultiLabelGraphSageTrain
extends GraphSageTrain {
    private static final double WEIGHT_BOUND = 1.0;
    private final Graph graph;
    private final GraphSageTrainConfig config;
    private final ExecutorService executor;

    public MultiLabelGraphSageTrain(Graph graph, GraphSageTrainConfig config, ExecutorService executor, ProgressTracker progressTracker) {
        super(progressTracker);
        this.graph = graph;
        this.config = config;
        this.executor = executor;
    }

    public Model<ModelData, GraphSageTrainConfig, GraphSageModelTrainer.GraphSageTrainMetrics> compute() {
        this.progressTracker.beginSubTask("GraphSageTrain");
        MultiLabelFeatureExtractors multiLabelFeatureExtractors = GraphSageHelper.multiLabelFeatureExtractors(this.graph, this.config);
        Map<NodeLabel, Weights<Matrix>> weightsByLabel = MultiLabelGraphSageTrain.makeWeightsByLabel(this.config, multiLabelFeatureExtractors);
        Integer projectedFeatureDimension = this.config.projectedFeatureDimension().orElseThrow();
        MultiLabelFeatureFunction multiLabelFeatureFunction = new MultiLabelFeatureFunction(weightsByLabel, projectedFeatureDimension);
        GraphSageModelTrainer trainer = new GraphSageModelTrainer(this.config, this.executor, this.progressTracker, multiLabelFeatureFunction, multiLabelFeatureFunction.weightsByLabel().values());
        GraphSageModelTrainer.ModelTrainResult trainResult = trainer.train(this.graph, GraphSageHelper.initializeMultiLabelFeatures(this.graph, multiLabelFeatureExtractors));
        this.progressTracker.endSubTask("GraphSageTrain");
        return Model.of((String)"graphSage", (GraphSchema)this.graph.schema(), (Object)ModelData.of(trainResult.layers(), multiLabelFeatureFunction), (ModelConfig)this.config, (ToMapConvertible)trainResult.metrics());
    }

    public void release() {
    }

    private static Map<NodeLabel, Weights<Matrix>> makeWeightsByLabel(GraphSageTrainConfig config, MultiLabelFeatureExtractors multiLabelFeatureExtractors) {
        return multiLabelFeatureExtractors.featureCountPerLabel().entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, e -> LayerFactory.generateWeights(config.projectedFeatureDimension().orElseThrow(), (Integer)e.getValue(), 1.0, config.randomSeed().orElseGet(() -> ThreadLocalRandom.current().nextLong()))));
    }
}

