/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.embeddings.graphsage;

import java.util.List;
import org.jetbrains.annotations.NotNull;
import org.neo4j.gds.api.properties.nodes.DoubleArrayNodePropertyValues;
import org.neo4j.gds.core.model.Model;
import org.neo4j.gds.core.model.ModelCatalog;
import org.neo4j.gds.core.utils.paged.HugeObjectArray;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSage;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageBaseConfig;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageModelResolver;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageTrainConfig;
import org.neo4j.gds.executor.ComputationResult;
import org.neo4j.gds.executor.validation.AfterLoadValidation;
import org.neo4j.gds.executor.validation.ValidationConfiguration;

public final class GraphSageCompanion {
    public static final String GRAPHSAGE_DESCRIPTION = "The GraphSage algorithm inductively computes embeddings for nodes based on a their features and neighborhoods.";

    private GraphSageCompanion() {
    }

    @NotNull
    public static <T extends GraphSageBaseConfig> DoubleArrayNodePropertyValues getNodeProperties(ComputationResult<GraphSage, GraphSage.GraphSageResult, T> computationResult) {
        final long size = computationResult.graph().nodeCount();
        final HugeObjectArray embeddings = ((GraphSage.GraphSageResult)computationResult.result()).embeddings();
        return new DoubleArrayNodePropertyValues(){

            public long nodeCount() {
                return size;
            }

            public double[] doubleArrayValue(long nodeId) {
                return (double[])embeddings.get(nodeId);
            }
        };
    }

    static <CONFIG extends GraphSageBaseConfig> ValidationConfiguration<CONFIG> getValidationConfig(final ModelCatalog catalog) {
        return new ValidationConfiguration<CONFIG>(){

            public List<AfterLoadValidation<CONFIG>> afterLoadValidations() {
                return List.of((graphStore, graphProjectConfig, graphSageConfig) -> {
                    Model model = GraphSageModelResolver.resolveModel((ModelCatalog)catalog, (String)graphSageConfig.username(), (String)graphSageConfig.modelName());
                    GraphSageTrainConfig trainConfig = (GraphSageTrainConfig)model.trainConfig();
                    trainConfig.graphStoreValidation(graphStore, graphSageConfig.nodeLabelIdentifiers(graphStore), graphSageConfig.internalRelationshipTypes(graphStore));
                });
            }
        };
    }
}

