/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.embeddings.graphsage;

import java.util.Map;
import java.util.stream.Stream;
import org.neo4j.gds.GraphAlgorithmFactory;
import org.neo4j.gds.TrainProc;
import org.neo4j.gds.core.CypherMapWrapper;
import org.neo4j.gds.core.model.Model;
import org.neo4j.gds.embeddings.graphsage.GraphSageModelTrainer;
import org.neo4j.gds.embeddings.graphsage.ModelData;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageTrain;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageTrainAlgorithmFactory;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageTrainConfig;
import org.neo4j.gds.executor.ComputationResult;
import org.neo4j.gds.executor.ExecutionMode;
import org.neo4j.gds.executor.GdsCallable;
import org.neo4j.gds.results.MemoryEstimateResult;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Mode;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;

/**
 * Neo4j procedures for training a GraphSage model and for estimating the memory that such a
 * training run would require.
 *
 * <p>{@code gds.beta.graphSage.train} computes {@link GraphSageTrain} (created by
 * {@link GraphSageTrainAlgorithmFactory}) on the named graph and hands the computation result to
 * {@code trainAndStoreModelWithResult}; {@code gds.beta.graphSage.train.estimate} delegates to
 * {@code computeEstimate}. Both helpers are inherited from {@link TrainProc}.
 */
@GdsCallable(
    name = GraphSageTrainProc.TRAIN_PROCEDURE_NAME,
    description = GraphSageTrainProc.DESCRIPTION,
    executionMode = ExecutionMode.TRAIN
)
public class GraphSageTrainProc
extends TrainProc<GraphSageTrain, Model<ModelData, GraphSageTrainConfig, GraphSageModelTrainer.GraphSageTrainMetrics>, GraphSageTrainConfig, TrainProc.TrainResult> {

    /**
     * Name under which the training procedure is registered. It is shared by {@code @GdsCallable}
     * and {@code @Procedure} so the two registrations cannot drift apart.
     */
    static final String TRAIN_PROCEDURE_NAME = "gds.beta.graphSage.train";

    // NOTE(review): "based on a their" looks like a typo. It is kept verbatim because this text is
    // user-visible and may be compared against elsewhere; confirm before correcting it.
    /**
     * User-visible description of the training procedure. It is shared by {@code @GdsCallable}
     * and {@code @Description} so the two copies cannot drift apart.
     */
    static final String DESCRIPTION =
        "The GraphSage algorithm inductively computes embeddings for nodes based on a their features and neighborhoods.";

    /**
     * Trains a GraphSage model on the given graph and reports the outcome.
     *
     * @param graphName     name of the graph to train on
     * @param configuration procedure configuration; defaults to an empty map
     * @return a stream with the training result produced by {@code trainAndStoreModelWithResult}
     */
    @Description(value = DESCRIPTION)
    @Procedure(name = TRAIN_PROCEDURE_NAME, mode = Mode.READ)
    public Stream<TrainProc.TrainResult> train(@Name(value = "graphName") String graphName, @Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration) {
        return this.trainAndStoreModelWithResult(this.compute(graphName, configuration));
    }

    /**
     * Estimates the memory consumption of {@link #train}.
     *
     * @param graphNameOrConfiguration a graph name, or a configuration describing the graph
     *                                 (interpretation is left to {@code computeEstimate})
     * @param algoConfiguration        algorithm configuration
     * @return a stream of memory estimation results
     */
    @Description(value = "Returns an estimation of the memory consumption for that procedure.")
    @Procedure(name = "gds.beta.graphSage.train.estimate", mode = Mode.READ)
    public Stream<MemoryEstimateResult> estimate(@Name(value = "graphNameOrConfiguration") Object graphNameOrConfiguration, @Name(value = "algoConfiguration") Map<String, Object> algoConfiguration) {
        return this.computeEstimate(graphNameOrConfiguration, algoConfiguration);
    }

    /** Builds the typed training configuration for {@code username} from the raw Cypher map. */
    @Override
    protected GraphSageTrainConfig newConfig(String username, CypherMapWrapper config) {
        return GraphSageTrainConfig.of(username, config);
    }

    /** Returns a fresh factory for the GraphSage training algorithm. */
    @Override
    public GraphAlgorithmFactory<GraphSageTrain, GraphSageTrainConfig> algorithmFactory() {
        return new GraphSageTrainAlgorithmFactory();
    }

    /** Returns the model type identifier, {@code "graphSage"}. */
    @Override
    protected String modelType() {
        return "graphSage";
    }

    /**
     * Wraps the trained model together with compute time and the node and relationship counts
     * of the input graph.
     */
    @Override
    protected TrainProc.TrainResult constructProcResult(ComputationResult<GraphSageTrain, Model<ModelData, GraphSageTrainConfig, GraphSageModelTrainer.GraphSageTrainMetrics>, GraphSageTrainConfig> computationResult) {
        // result() is already typed as the trained Model, so no (raw) cast is needed.
        return new TrainProc.TrainResult(
            computationResult.result(),
            computationResult.computeMillis(),
            computationResult.graph().nodeCount(),
            computationResult.graph().relationshipCount()
        );
    }

    /** The algorithm result already is the model, so it is returned unchanged. */
    @Override
    protected Model<?, ?, ?> extractModel(Model<ModelData, GraphSageTrainConfig, GraphSageModelTrainer.GraphSageTrainMetrics> model) {
        return model;
    }
}

