/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.executor;

import java.io.BufferedInputStream;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.UncheckedIOException;
import java.lang.reflect.InvocationTargetException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.immutables.value.Value;
import org.jetbrains.annotations.NotNull;
import org.neo4j.gds.Algorithm;
import org.neo4j.gds.AlgorithmFactory;
import org.neo4j.gds.annotation.ValueClass;
import org.neo4j.gds.config.AlgoBaseConfig;
import org.neo4j.gds.executor.AlgorithmSpec;
import org.neo4j.gds.executor.ExecutionMode;
import org.neo4j.gds.executor.GdsCallable;
import org.neo4j.gds.executor.ImmutableGdsCallableDefinition;
import org.reflections.Configuration;
import org.reflections.Reflections;
import org.reflections.scanners.Scanner;
import org.reflections.scanners.Scanners;
import org.reflections.util.ConfigurationBuilder;
import org.reflections.util.FilterBuilder;

public final class GdsCallableFinder {
    /** Package-name prefixes excluded by the overloads without a deny list; empty, so nothing is excluded. */
    private static final List<String> DEFAULT_PACKAGE_DENY_LIST = List.of();

    /**
     * Returns every discovered callable definition, sorted case-insensitively by name.
     *
     * @return a stream of all discovered callable definitions
     */
    public static Stream<GdsCallableDefinition> findAll() {
        return findAll(DEFAULT_PACKAGE_DENY_LIST);
    }

    /**
     * Returns the discovered callable definitions that are not excluded by {@code denyList},
     * sorted case-insensitively by name.
     *
     * @param denyList package-name prefixes; a definition is dropped when the package of its
     *                 {@link GdsCallableDefinition#algorithmSpecClass()} starts with any of them
     * @return a stream of the remaining definitions
     */
    public static Stream<GdsCallableDefinition> findAll(Collection<String> denyList) {
        return allGdsCallables(denyList)
            .sorted(Comparator.comparing(GdsCallableDefinition::name, String.CASE_INSENSITIVE_ORDER));
    }

    /**
     * Looks up a callable definition by name, ignoring case.
     *
     * @param name the callable name; must not be {@code null}
     * @return the matching definition, or empty if none was discovered
     */
    public static Optional<GdsCallableDefinition> findByName(String name) {
        return findByName(name, DEFAULT_PACKAGE_DENY_LIST);
    }

    /**
     * Looks up a callable definition by name, ignoring case, unless it is excluded by {@code denyList}.
     *
     * @param name     the callable name; must not be {@code null}
     * @param denyList package-name prefixes to exclude (see {@link #findAll(Collection)})
     * @return the matching, non-excluded definition, or empty otherwise
     */
    public static Optional<GdsCallableDefinition> findByName(String name, Collection<String> denyList) {
        // Keys are stored lower-cased with Locale.ENGLISH, so normalise the query the same way.
        GdsCallableDefinition definition = ClassesHolder.CALLABLE_CLASSES.get(name.toLowerCase(Locale.ENGLISH));
        return Optional.ofNullable(definition).filter(block(denyList));
    }

    /** Streams all discovered definitions (unsorted) that pass the deny-list filter. */
    @NotNull
    private static Stream<GdsCallableDefinition> allGdsCallables(Collection<String> denyList) {
        return ClassesHolder.CALLABLE_CLASSES.values().stream().filter(block(denyList));
    }

    /**
     * Builds a predicate that accepts a definition only if the package of its spec class
     * does not start with any prefix in {@code denyList}.
     */
    private static Predicate<GdsCallableDefinition> block(Collection<String> denyList) {
        return def -> {
            String packageName = def.algorithmSpecClass().getPackageName();
            return denyList.stream().noneMatch(packageName::startsWith);
        };
    }

    private GdsCallableFinder() {
    }

    /** Metadata describing one {@link GdsCallable}-annotated algorithm specification class. */
    @ValueClass
    public interface GdsCallableDefinition {
        /** The annotated class that implements {@link AlgorithmSpec}. */
        Class<AlgorithmSpec<Algorithm<Object>, Object, AlgoBaseConfig, Object, AlgorithmFactory<?, Algorithm<Object>, AlgoBaseConfig>>> algorithmSpecClass();

        /**
         * Instantiates {@link #algorithmSpecClass()} through its public no-argument constructor.
         * Marked {@link Value.Lazy}, so the Immutables-generated implementation memoizes the result.
         *
         * @throws RuntimeException wrapping any reflective failure (missing constructor,
         *                          inaccessible/abstract class, or a constructor that threw)
         */
        @Value.Lazy
        default AlgorithmSpec<Algorithm<Object>, Object, AlgoBaseConfig, Object, AlgorithmFactory<?, Algorithm<Object>, AlgoBaseConfig>> algorithmSpec() {
            try {
                return this.algorithmSpecClass().getConstructor().newInstance();
            } catch (ReflectiveOperationException e) {
                throw new RuntimeException(e);
            }
        }

        /** The name taken from the class's {@link GdsCallable} annotation. */
        String name();

        /** The description taken from the class's {@link GdsCallable} annotation. */
        String description();

        /** The execution mode taken from the class's {@link GdsCallable} annotation. */
        ExecutionMode executionMode();
    }

    /**
     * Holds the discovered callables. Because this is a separate nested class, its static
     * initializer (and therefore class discovery) runs only on first access of {@link #CALLABLE_CLASSES}.
     */
    private static final class ClassesHolder {
        /** Fully-qualified name of the entry point for which classpath scanning is skipped. */
        private static final String ESTIMATION_CLI_CLASS_NAME = "com.neo4j.gds.estimation.cli.EstimationCli";

        /** Discovered definitions keyed by {@code name().toLowerCase(Locale.ENGLISH)}. */
        private static final Map<String, GdsCallableDefinition> CALLABLE_CLASSES = loadPossibleClasses();

        private ClassesHolder() {
        }

        /**
         * Discovers all {@link GdsCallable} classes and indexes them by lower-cased name.
         *
         * <p>Classpath scanning is tried first, unless the outermost stack frame belongs to the
         * estimation CLI. If scanning is skipped or finds nothing, the service files are read instead.
         * When two classes share a lower-cased name, the first one encountered wins.
         */
        @NotNull
        private static Map<String, GdsCallableDefinition> loadPossibleClasses() {
            ClassLoader classLoader = Objects.requireNonNullElse(
                Thread.currentThread().getContextClassLoader(),
                ClassesHolder.class.getClassLoader()
            );

            List<Class<?>> classes = new ArrayList<>();
            // NOTE(review): the estimation CLI opts out of classpath scanning, presumably because
            // scanning is unavailable or undesirable in that tool. TODO confirm the reason.
            if (!isInvokedFromEstimationCli()) {
                classes.addAll(loadPossibleClassesViaClasspathScanning(classLoader));
            }
            // Fall back to the service files when scanning was skipped or found no classes.
            if (classes.isEmpty()) {
                classes.addAll(loadPossibleClassesFromJar(classLoader));
                classes.addAll(loadPossibleClassesFromResourcesFolder(classLoader));
            }

            assert assertAllAreAlgoSpec(classes);

            return classes.stream()
                .map(ClassesHolder::toDefinition)
                .collect(Collectors.toMap(
                    def -> def.name().toLowerCase(Locale.ENGLISH),
                    Function.identity(),
                    (first, second) -> first
                ));
        }

        /**
         * Returns whether the outermost (bottom-most) frame of the current thread's stack
         * belongs to the estimation CLI class.
         */
        private static boolean isInvokedFromEstimationCli() {
            // The walked stream runs from the current frame outwards, so its last element is the outermost frame.
            Optional<String> outermostClassName = StackWalker.getInstance()
                .walk(frames -> frames.reduce((previous, next) -> next))
                .map(StackWalker.StackFrame::getClassName);
            return outermostClassName.filter(ESTIMATION_CLI_CLASS_NAME::equals).isPresent();
        }

        /**
         * Builds the definition for a {@link GdsCallable}-annotated class.
         *
         * @throws NullPointerException if {@code clazz} lacks the {@link GdsCallable} annotation
         */
        private static GdsCallableDefinition toDefinition(Class<?> clazz) {
            GdsCallable gdsCallable = Objects.requireNonNull(
                clazz.getAnnotation(GdsCallable.class),
                () -> "Class " + clazz.getName() + " is listed as a GDS callable but is not annotated with @GdsCallable"
            );
            // Unchecked but expected: discovered classes are asserted to implement AlgorithmSpec.
            @SuppressWarnings("unchecked")
            Class<AlgorithmSpec<Algorithm<Object>, Object, AlgoBaseConfig, Object, AlgorithmFactory<?, Algorithm<Object>, AlgoBaseConfig>>> specClass =
                (Class<AlgorithmSpec<Algorithm<Object>, Object, AlgoBaseConfig, Object, AlgorithmFactory<?, Algorithm<Object>, AlgoBaseConfig>>>) clazz;
            return ImmutableGdsCallableDefinition.builder()
                .name(gdsCallable.name())
                .description(gdsCallable.description())
                .executionMode(gdsCallable.executionMode())
                .algorithmSpecClass(specClass)
                .build();
        }

        /** Loads the classes listed in the service file, resolved as a relative class-loader resource name. */
        private static List<Class<?>> loadPossibleClassesFromJar(ClassLoader classLoader) {
            return loadPossibleClassesFrom(classLoader, "META-INF/services/" + GdsCallable.class.getCanonicalName());
        }

        /**
         * Loads the classes listed in the service file, looked up with a leading {@code '/'}.
         * NOTE(review): ClassLoader resource names are normally given without a leading slash.
         * Confirm which class loaders this variant is meant to serve.
         */
        private static List<Class<?>> loadPossibleClassesFromResourcesFolder(ClassLoader classLoader) {
            return loadPossibleClassesFrom(classLoader, "/META-INF/services/" + GdsCallable.class.getCanonicalName());
        }

        /**
         * Reads a service-provider file and loads every class named in it.
         *
         * <p>Lines follow the {@link java.util.ServiceLoader} provider-configuration format:
         * <ul>
         *   <li>the file is UTF-8, with one binary class name per line;</li>
         *   <li>surrounding whitespace and blank lines are ignored;</li>
         *   <li>everything after the first {@code '#'} on a line is a comment.</li>
         * </ul>
         *
         * @return the loaded classes in file order, or an empty list if the resource does not exist
         * @throws UncheckedIOException if the resource cannot be read
         * @throws RuntimeException     if a listed class cannot be loaded
         */
        private static List<Class<?>> loadPossibleClassesFrom(ClassLoader classLoader, String path) {
            InputStream callablesStream = classLoader.getResourceAsStream(path);
            if (callablesStream == null) {
                return List.of();
            }
            // The stream is listed as its own resource so it is closed even if building the reader fails.
            try (InputStream in = callablesStream;
                 BufferedReader callables = new BufferedReader(new InputStreamReader(in, StandardCharsets.UTF_8))) {
                return callables.lines()
                    .map(ClassesHolder::stripServiceFileComment)
                    .filter(className -> !className.isEmpty())
                    .map(className -> loadCallableClass(classLoader, className))
                    .collect(Collectors.toList());
            } catch (IOException e) {
                throw new UncheckedIOException(e);
            }
        }

        /** Drops everything from the first {@code '#'} onwards and trims surrounding whitespace. */
        private static String stripServiceFileComment(String line) {
            int commentStart = line.indexOf('#');
            String content = commentStart >= 0 ? line.substring(0, commentStart) : line;
            return content.trim();
        }

        /** Loads {@code className}, rethrowing a missing class as an unchecked exception that names it. */
        private static Class<?> loadCallableClass(ClassLoader classLoader, String className) {
            try {
                return classLoader.loadClass(className);
            } catch (ClassNotFoundException e) {
                throw new RuntimeException("Could not load GDS callable class " + className, e);
            }
        }

        /** Uses Reflections to find classes annotated with {@link GdsCallable} under {@code org.neo4j.gds}. */
        private static List<Class<?>> loadPossibleClassesViaClasspathScanning(ClassLoader classLoader) {
            return Stream.of("org.neo4j.gds")
                .map(pkg -> new Reflections(
                    new ConfigurationBuilder()
                        .addClassLoaders(classLoader)
                        .forPackage(pkg, classLoader)
                        .addScanners(Scanners.TypesAnnotated)
                        .filterInputsBy(new FilterBuilder().includePackage(pkg))
                ))
                .flatMap(reflections -> reflections.getTypesAnnotatedWith(GdsCallable.class).stream())
                .collect(Collectors.toList());
        }

        /**
         * Asserts (when assertions are enabled) that every class implements {@link AlgorithmSpec}.
         * Always returns {@code true} so it can be wrapped in an {@code assert} statement.
         */
        private static boolean assertAllAreAlgoSpec(Iterable<Class<?>> classes) {
            for (Class<?> clazz : classes) {
                assert AlgorithmSpec.class.isAssignableFrom(clazz)
                    : clazz.getName() + " is annotated with @GdsCallable but does not implement AlgorithmSpec";
            }
            return true;
        }
    }
}

