/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.embeddings.graphsage;

import java.util.List;
import java.util.Map;
import org.jetbrains.annotations.NotNull;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.api.nodeproperties.DoubleArrayNodeProperties;
import org.neo4j.gds.config.AlgoBaseConfig;
import org.neo4j.gds.core.CypherMapWrapper;
import org.neo4j.gds.core.model.Model;
import org.neo4j.gds.core.model.ModelCatalog;
import org.neo4j.gds.core.utils.paged.HugeObjectArray;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSage;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageBaseConfig;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageModelResolver;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageTrainConfig;
import org.neo4j.gds.executor.ComputationResult;
import org.neo4j.gds.executor.GraphStoreValidation;
import org.neo4j.gds.executor.validation.AfterLoadValidation;
import org.neo4j.gds.executor.validation.ValidationConfiguration;
import org.neo4j.gds.utils.StringFormatting;

public final class GraphSageCompanion {
    public static final String GRAPHSAGE_DESCRIPTION = "The GraphSage algorithm inductively computes embeddings for nodes based on a their features and neighborhoods.";

    private GraphSageCompanion() {
    }

    @NotNull
    public static <T extends GraphSageBaseConfig> DoubleArrayNodeProperties getNodeProperties(ComputationResult<GraphSage, GraphSage.GraphSageResult, T> computationResult) {
        final long size = computationResult.graph().nodeCount();
        final HugeObjectArray embeddings = ((GraphSage.GraphSageResult)computationResult.result()).embeddings();
        return new DoubleArrayNodeProperties(){

            public long size() {
                return size;
            }

            public double[] doubleArrayValue(long nodeId) {
                return (double[])embeddings.get(nodeId);
            }
        };
    }

    static <CONFIG extends GraphSageBaseConfig> ValidationConfiguration<CONFIG> getValidationConfig(final ModelCatalog catalog, final String username) {
        return new ValidationConfiguration<CONFIG>(){

            public List<AfterLoadValidation<CONFIG>> afterLoadValidations() {
                return List.of((graphStore, graphProjectConfig, graphSageConfig) -> {
                    Model model = GraphSageModelResolver.resolveModel((ModelCatalog)catalog, (String)username, (String)graphSageConfig.modelName());
                    GraphStoreValidation.validate((GraphStore)graphStore, (AlgoBaseConfig)((AlgoBaseConfig)model.trainConfig()));
                });
            }
        };
    }

    static Map<String, Object> getActualConfig(Object graphNameOrConfig, Map<String, Object> maybeConfig) {
        return graphNameOrConfig instanceof Map ? (Map)graphNameOrConfig : maybeConfig;
    }

    public static void injectRelationshipWeightPropertyFromModel(Map<String, Object> configuration, ModelCatalog modelCatalog, String username) {
        if (configuration.containsKey("relationshipWeightProperty")) {
            throw new IllegalArgumentException(StringFormatting.formatWithLocale((String)"The parameter `%s` cannot be overwritten during embedding computation. Instead, specify this parameter in the configuration of the model training.", (Object[])new Object[]{"relationshipWeightProperty"}));
        }
        String modelName = CypherMapWrapper.create(configuration).requireString("modelName");
        String trainProperty = ((GraphSageTrainConfig)GraphSageModelResolver.resolveModel((ModelCatalog)modelCatalog, (String)username, (String)modelName).trainConfig()).relationshipWeightProperty();
        configuration.put("relationshipWeightProperty", trainProperty);
    }
}

