/*
 * Copyright 2021 OPPO ESA Stack Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.esastack.cabin.container.service.share;

import io.esastack.cabin.api.domain.Module;
import io.esastack.cabin.api.service.share.SharedClassService;
import io.esastack.cabin.common.exception.CabinRuntimeException;
import io.esastack.cabin.common.log.CabinLoggerFactory;
import io.esastack.cabin.common.util.CabinStringUtil;
import io.esastack.cabin.container.domain.LibModule;
import io.esastack.cabin.container.service.loader.LibModuleClassLoader;
import io.esastack.cabin.loader.jar.Handler;
import org.slf4j.Logger;

import java.util.Collections;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicBoolean;

/**
 * {@link SharedClassService} implementation that records which lib module exports each class and package, and
 * lazily loads exported classes through the exporting module's class loader, caching them once loaded.
 * <p>
 * Different modules could export the same packages, but must not export the same classes: registering or resolving
 * one class name through two different modules raises a {@link CabinRuntimeException}.
 */
public class SharedClassServiceImpl implements SharedClassService {

    private static final Logger LOGGER = CabinLoggerFactory.getLogger(SharedClassServiceImpl.class);

    /** Placeholder value for the maps below that are used as concurrent hash sets. */
    private static final Object SENTINEL = new Object();

    /** Ensures {@link #preLoadAllSharedClasses()} does its work at most once. */
    private final AtomicBoolean preLoaded = new AtomicBoolean(false);

    /** Class name -> class already loaded from an exporting module. */
    private final ConcurrentMap<String, Class<?>> cachedClasses;

    /** Explicitly exported class name -> the module that exported it. */
    private final ConcurrentMap<String, LibModule> classToModuleMap;

    // Package name -> exporting modules. Map<LibModule, Object> is used as a concurrent hash set because one package
    // may be exported by multiple modules.
    private final Map<String, Map<LibModule, Object>> packageToModuleMap;

    // Module name -> names of classes cached on its behalf (inner map used as a concurrent hash set), so they can be
    // evicted from cachedClasses when the module is destroyed.
    private final Map<String, Map<String, Object>> moduleExportedClasses;

    public SharedClassServiceImpl() {
        this.cachedClasses = new ConcurrentHashMap<>();
        this.classToModuleMap = new ConcurrentHashMap<>();
        this.packageToModuleMap = new ConcurrentHashMap<>();
        this.moduleExportedClasses = new ConcurrentHashMap<>();
    }

    /**
     * Records that {@code module} exports {@code packageName}. Several modules may export the same package.
     *
     * @param packageName exported package; blank names are ignored
     * @param module      exporting module (expected to be a {@link LibModule}); {@code null} is ignored
     */
    @Override
    public void addSharedPackage(final String packageName, final Module module) {
        if (CabinStringUtil.isBlank(packageName) || module == null) {
            return;
        }
        packageToModuleMap.computeIfAbsent(packageName, name -> new ConcurrentHashMap<>())
                .put((LibModule) module, SENTINEL);
    }

    /**
     * Loads and caches every explicitly exported class; runs only on the first invocation.
     * <p>
     * No Exception should be thrown, because the loading class may extend a superclass or implement an interface from
     * other lib modules or biz modules; in these situations, we just ignore the failure (logged at debug level).
     */
    @Override
    public void preLoadAllSharedClasses() {
        if (!preLoaded.compareAndSet(false, true)) {
            return;
        }
        classToModuleMap.forEach((className, module) -> {
            if (cachedClasses.containsKey(className)) {
                return;
            }
            // Cache the result so getSharedClassMap() actually exposes the preloaded classes.
            final Class<?> clazz = cacheClassFromModule(className, module);
            if (clazz == null && LOGGER.isDebugEnabled()) {
                LOGGER.debug(String.format("Could not load class %s which is exported by module %s from it!",
                        className, module.getName()));
            }
        });
    }

    /**
     * Resolves a shared class by name: first from the cache, then from the module that explicitly exported it, and
     * finally by probing the modules that export its package or any enclosing package.
     * <p>
     * This method should be called only after all modules have been exported, so that all class/package to module
     * mappings are established. Loading a class while a module is still being exported may fail if it implements an
     * interface or extends a superclass exported by another module that has not been exported yet.
     *
     * @param className fully qualified class name; {@code null} yields {@code null}
     * @return the shared class, or {@code null} if no exporting module can load it
     * @throws CabinRuntimeException if more than one exporting module (or class loader) provides the class
     */
    @Override
    public Class<?> getSharedClass(final String className) {
        if (className == null) {
            return null;
        }

        final Class<?> cached = getCachedClass(className);
        if (cached != null) {
            return cached;
        }

        final Class<?> exported = getAndCacheClassFromModule(className);
        if (exported != null) {
            return exported;
        }

        return getClassFromExportedPackages(className);
    }

    /**
     * Intentionally does nothing: classes are resolved lazily from their exporting modules.
     * NOTE(review): the supplied {@code clazz} is silently ignored; confirm no caller relies on it being cached.
     */
    @Override
    public void addSharedClass(final String className, final Class<?> clazz) {

    }

    /**
     * Records that {@code module} exports {@code className}.
     *
     * @param className fully qualified class name; {@code null} is ignored
     * @param module    exporting module (expected to be a {@link LibModule}); {@code null} is ignored
     * @throws CabinRuntimeException if the class was already exported by a different module
     */
    @Override
    public void addSharedClass(final String className, final Module module) {
        if (className == null || module == null) {
            return;
        }
        final Module prevModule = classToModuleMap.putIfAbsent(className, (LibModule) module);
        if (prevModule != null && prevModule != module) {
            throw new CabinRuntimeException(String.format("Class export conflicted, %s is exported by module" +
                    " %s and %s", className, prevModule.getName(), module.getName()));
        }
    }

    /**
     * Preloads all explicitly exported classes (first call only) and returns a read-only live view of the cache.
     */
    @Override
    public Map<String, Class<?>> getSharedClassMap() {
        preLoadAllSharedClasses();
        return Collections.unmodifiableMap(cachedClasses);
    }

    /**
     * @return the number of explicitly exported class names (classes found only via package scanning are not counted)
     */
    @Override
    public int getSharedClassCount() {
        return classToModuleMap.size();
    }

    /**
     * @return whether {@code className} was explicitly exported by some module
     */
    @Override
    public boolean containsClass(final String className) {
        return classToModuleMap.containsKey(className);
    }

    /**
     * Removes every class/package mapping of the module and evicts the classes cached on its behalf.
     * Removes from the class/package maps first, so classes are not re-added to cachedClasses after the module is
     * destroyed.
     *
     * @param moduleName module to destroy
     */
    @Override
    public void destroyModuleClasses(final String moduleName) {
        classToModuleMap.entrySet().removeIf(entry -> entry.getValue().getName().equals(moduleName));
        packageToModuleMap.entrySet().removeIf(entry -> {
            entry.getValue().entrySet().removeIf(en -> en.getKey().getName().equals(moduleName));
            return entry.getValue().isEmpty();
        });
        final Map<String, Object> classesMap = moduleExportedClasses.remove(moduleName);
        if (classesMap == null || classesMap.isEmpty()) {
            return;
        }
        for (String clazzName : classesMap.keySet()) {
            cachedClasses.remove(clazzName);
        }
    }

    private Class<?> getCachedClass(final String className) {
        return cachedClasses.get(className);
    }

    /** Loads and caches a class from the module that explicitly exported it, or returns null. */
    private Class<?> getAndCacheClassFromModule(final String className) {
        final LibModule module = classToModuleMap.get(className);
        return module == null ? null : cacheClassFromModule(className, module);
    }

    /**
     * Loads {@code className} from {@code module} and, on success, caches it and records it against the module.
     * Returns the cached instance so concurrent callers observe the same class.
     */
    private Class<?> cacheClassFromModule(final String className, final LibModule module) {
        final Class<?> loaded = getClassFromModule(className, module);
        if (loaded == null) {
            return null;
        }
        final Class<?> existing = cachedClasses.putIfAbsent(className, loaded);
        storeModuleExportedClass(module.getName(), className);
        return existing != null ? existing : loaded;
    }

    /**
     * Finds classes not scanned while re-packaging the lib modules (such as classes generated by cglib for Spring)
     * by probing the modules that export the class's package, then each enclosing package, innermost first.
     */
    private Class<?> getClassFromExportedPackages(final String className) {
        int index = className.lastIndexOf('.');
        while (index > 0) {
            final String packageName = className.substring(0, index);
            final Map<LibModule, Object> modules = packageToModuleMap.get(packageName);
            if (modules != null && !modules.isEmpty()) {
                final Class<?> clazz = loadFromSingleExportingModule(className, modules);
                if (clazz != null) {
                    return clazz;
                }
            }
            index = packageName.lastIndexOf('.');
        }
        return null;
    }

    /**
     * Tries every module in {@code modules}; requires that at most one of them can load the class, then caches it.
     *
     * @throws CabinRuntimeException if two modules, or two class loaders, provide the class
     */
    private Class<?> loadFromSingleExportingModule(final String className, final Map<LibModule, Object> modules) {
        Class<?> found = null;
        LibModule foundIn = null;
        for (LibModule libModule : modules.keySet()) {
            final Class<?> clazz = getClassFromModule(className, libModule);
            if (clazz == null) {
                continue;
            }
            if (found != null) {
                throw new CabinRuntimeException(
                        String.format("Class export conflicted, %s is exported by module %s and %s",
                                className, foundIn.getName(), libModule.getName()));
            }
            found = clazz;
            foundIn = libModule;
        }
        if (found == null) {
            return null;
        }

        LOGGER.info("Trying to add class {} exported by Module {} to sharedClassService!",
                className, foundIn.getName());
        final Class<?> prevClazz = cachedClasses.putIfAbsent(className, found);
        if (prevClazz != null && prevClazz != found) {
            throw new CabinRuntimeException(
                    String.format("Class export conflicted, %s is exported by ClassLoader %s and %s",
                            className, prevClazz.getClassLoader(), found.getClassLoader()));
        }
        storeModuleExportedClass(foundIn.getName(), className);
        return found;
    }

    private void storeModuleExportedClass(final String moduleName, final String clazzName) {
        // Single atomic call: a separate computeIfAbsent + get could NPE if destroyModuleClasses removed the entry
        // in between.
        moduleExportedClasses.computeIfAbsent(moduleName, name -> new ConcurrentHashMap<>())
                .put(clazzName, SENTINEL);
    }

    /**
     * Best-effort load of {@code className} through the module's {@link LibModuleClassLoader}.
     *
     * @return the loaded class, or {@code null} if the module has no class loader or loading fails for any reason
     */
    private Class<?> getClassFromModule(final String className, final LibModule module) {
        try {
            // NOTE(review): presumably makes jar-connection failures cheap while probing modules; confirm with Handler.
            Handler.setUseFastConnectionExceptions(true);
            final LibModuleClassLoader libModuleClassLoader = (LibModuleClassLoader) module.getClassLoader();
            if (libModuleClassLoader != null) {
                return libModuleClassLoader.loadClassFromClasspath(className);
            }
        } catch (Throwable ignored) {
            // Deliberately ignored: probing modules that do not contain the class (or whose dependencies are not yet
            // resolvable) is expected; callers treat null as "not found here".
        } finally {
            Handler.setUseFastConnectionExceptions(false);
        }
        return null;
    }
}
