/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.extension;

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Stream;
import org.junit.jupiter.api.extension.BeforeEachCallback;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.neo4j.dbms.api.DatabaseManagementService;
import org.neo4j.gds.QueryRunner;
import org.neo4j.gds.extension.ExtensionUtil;
import org.neo4j.gds.extension.IdFunction;
import org.neo4j.gds.extension.Neo4jGraph;
import org.neo4j.gds.extension.NodeFunction;
import org.neo4j.gds.utils.StringFormatting;
import org.neo4j.graphdb.GraphDatabaseService;
import org.neo4j.graphdb.Node;
import org.neo4j.graphdb.Result;
import org.neo4j.kernel.impl.core.NodeEntity;
import org.neo4j.kernel.internal.GraphDatabaseAPI;

public class Neo4jSupportExtension
implements BeforeEachCallback {
    private static final String RETURN_STATEMENT = "RETURN *";
    private static final ExtensionContext.Namespace DBMS_NAMESPACE = ExtensionContext.Namespace.create((Object[])new Object[]{"org", "neo4j", "dbms"});
    private static final String DBMS_KEY = "service";

    public void beforeEach(ExtensionContext context) throws Exception {
        GraphDatabaseAPI db = (GraphDatabaseAPI)this.getDbms(context).map(dbms -> dbms.database("neo4j")).orElseThrow(() -> new IllegalStateException("No database was found."));
        Class requiredTestClass = context.getRequiredTestClass();
        Optional<String> createQuery = this.graphProjectQuery(requiredTestClass);
        Map<String, Node> idMap = this.neo4jGraphSetup((GraphDatabaseService)db, createQuery);
        this.injectFields(context, db, idMap);
    }

    private Optional<DatabaseManagementService> getDbms(ExtensionContext context) {
        return Optional.ofNullable((DatabaseManagementService)context.getStore(DBMS_NAMESPACE).get((Object)DBMS_KEY, DatabaseManagementService.class));
    }

    private Optional<String> graphProjectQuery(Class<?> testClass) {
        return Stream.iterate(testClass, c -> c.getSuperclass() != null, Class::getSuperclass).flatMap(clazz -> Arrays.stream(clazz.getDeclaredFields())).filter(field -> field.isAnnotationPresent(Neo4jGraph.class)).findFirst().map(ExtensionUtil::getStringValueOfField);
    }

    private Map<String, Node> neo4jGraphSetup(GraphDatabaseService db, Optional<String> createQuery) {
        return createQuery.map(query -> StringFormatting.formatWithLocale((String)"%s %s", (Object[])new Object[]{query, RETURN_STATEMENT})).map(query -> QueryRunner.runQuery(db, query, Neo4jSupportExtension::extractVariableIds)).orElseGet(Map::of);
    }

    private static Map<String, Node> extractVariableIds(Result result) {
        if (!result.hasNext()) {
            throw new IllegalArgumentException("Result of create query was empty");
        }
        List columns = result.columns();
        Map row = result.next();
        HashMap<String, Node> idMap = new HashMap<String, Node>();
        columns.forEach(column -> {
            Object value = row.get(column);
            if (value instanceof NodeEntity) {
                idMap.put((String)column, (Node)((NodeEntity)value));
            }
        });
        return idMap;
    }

    private void injectFields(ExtensionContext context, GraphDatabaseAPI db, Map<String, Node> idMap) {
        NodeFunction nodeFunction = idMap::get;
        IdFunction idFunction = variable -> nodeFunction.of(variable).getId();
        context.getRequiredTestInstances().getAllInstances().forEach(testInstance -> {
            ExtensionUtil.injectInstance((Object)testInstance, (Object)nodeFunction, NodeFunction.class);
            ExtensionUtil.injectInstance((Object)testInstance, (Object)idFunction, IdFunction.class);
            ExtensionUtil.injectInstance((Object)testInstance, (Object)db, GraphDatabaseAPI.class);
        });
    }
}

