/*
 * Copyright 2019 https://www.ifengxue.com
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.ifengxue.http.integration.spring;

import static java.util.stream.Collectors.toSet;

import com.ifengxue.http.annotation.Rest;
import com.ifengxue.http.proxy.HttpClientConfig;
import com.ifengxue.http.proxy.Interceptor;
import com.ifengxue.http.proxy.ProxyBuilder;
import java.beans.Introspector;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ClassUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.BeanClassLoaderAware;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.BeanFactoryAware;
import org.springframework.beans.factory.NoSuchBeanDefinitionException;
import org.springframework.beans.factory.annotation.AnnotatedBeanDefinition;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.ConfigurableBeanFactory;
import org.springframework.beans.factory.config.EmbeddedValueResolver;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.GenericBeanDefinition;
import org.springframework.context.EnvironmentAware;
import org.springframework.context.ResourceLoaderAware;
import org.springframework.context.annotation.ClassPathScanningCandidateComponentProvider;
import org.springframework.context.annotation.ImportBeanDefinitionRegistrar;
import org.springframework.core.env.Environment;
import org.springframework.core.io.ResourceLoader;
import org.springframework.core.type.AnnotationMetadata;
import org.springframework.core.type.filter.AnnotationTypeFilter;
import org.springframework.util.StringValueResolver;

/**
 * 注册被{@link com.ifengxue.http.annotation.Rest}标记的接口为 Spring bean
 */
@Slf4j
public class RestBeanRegister implements ImportBeanDefinitionRegistrar, BeanClassLoaderAware,
    BeanFactoryAware, EnvironmentAware, ResourceLoaderAware {

  public static final String REST_ENABLED_NAME = "rest.enabled";
  private ClassLoader classLoader;
  private BeanFactory beanFactory;
  private Environment environment;
  private ResourceLoader resourceLoader;
  private StringValueResolver valueResolver;

  @Override
  @SuppressWarnings("unchecked")
  public void registerBeanDefinitions(AnnotationMetadata importingClassMetadata, BeanDefinitionRegistry registry) {
    Map<String, Object> attributes = Optional.ofNullable(importingClassMetadata
        .getAnnotationAttributes(EnableRest.class.getName()))
        .orElseThrow(() -> new IllegalStateException("未发现@EnableRest标记的类"));
    if (!environment.getProperty(REST_ENABLED_NAME, boolean.class, true)) {
      log.debug("未启用@Rest自动注册");
      return;
    }
    Set<String> basePackages = Arrays.stream((String[]) attributes.get("basePackages")).collect(toSet());
    for (Class<?> basePackageClass : (Class<?>[]) attributes.get("basePackageClasses")) {
      basePackages.add(basePackageClass.getPackage().getName());
    }
    if (basePackages.isEmpty()) {
      basePackages.add(lookupBasePackage());
    }

    Set<String> excludePackages = Arrays.stream((String[]) attributes.get("excludePackages")).collect(toSet());

    List<Interceptor> interceptors = new ArrayList<>();
    for (Class<? extends Interceptor> interceptorClass : (Class<? extends Interceptor>[]) attributes
        .get("interceptorClasses")) {
      interceptors.add(BeanUtils.instantiateClass(interceptorClass));
    }

    Set<String> excludeClassNames = Arrays.stream((Class<?>[]) attributes.get("excludeClasses")).map(Class::getName)
        .collect(toSet());
    ClassPathScanningCandidateComponentProvider classScanner = createClassScanner();
    classScanner.setResourceLoader(resourceLoader);
    classScanner.addIncludeFilter(new AnnotationTypeFilter(SpringRestIntegration.class));
    classScanner.addExcludeFilter((metadataReader, metadataReaderFactory) -> {
      for (String excludePackage : excludePackages) {
        if (metadataReader.getClassMetadata().getClassName().startsWith(excludePackage)) {
          return true;
        }
      }
      return false;
    });
    classScanner.addExcludeFilter((metadataReader, metadataReaderFactory) ->
        excludeClassNames.contains(metadataReader.getClassMetadata().getClassName()));

    Set<BeanDefinition> beanDefinitions = new HashSet<>();
    for (String basePackage : basePackages) {
      beanDefinitions.addAll(classScanner.findCandidateComponents(basePackage));
    }

    // parse global client config
    ClientConfigHolder globalClientConfigHolder = ClientConfigHolder.from(attributes, valueResolver);
    for (BeanDefinition beanDefinition : beanDefinitions) {
      Class<?> clazz;
      try {
        clazz = ClassUtils
            .getClass(classLoader, ((AnnotatedBeanDefinition) beanDefinition).getMetadata().getClassName());
      } catch (ClassNotFoundException e) {
        log.debug("load class error.", e);
        continue;
      }
      GenericBeanDefinition genericBeanDefinition;
      if (beanDefinition instanceof GenericBeanDefinition) {
        genericBeanDefinition = (GenericBeanDefinition) beanDefinition;
      } else {
        genericBeanDefinition = new GenericBeanDefinition(beanDefinition);
      }
      genericBeanDefinition.setBeanClass(clazz);

      String[] interceptorsRef = (String[]) attributes.get("interceptorsRef");
      if (interceptorsRef.length > 0) {
        genericBeanDefinition.setDependsOn(interceptorsRef);
      }
      genericBeanDefinition.setInstanceSupplier(
          () -> createProxyBean(globalClientConfigHolder, clazz, interceptors, interceptorsRef));

      SpringRestIntegration annotation = clazz.getAnnotation(SpringRestIntegration.class);
      String beanName =
          StringUtils.isNotBlank(annotation.name()) ? annotation.name()
              : Introspector.decapitalize(clazz.getSimpleName());
      registry.registerBeanDefinition(beanName, genericBeanDefinition);
    }
  }

  private String lookupBasePackage() {
    Throwable e = new RuntimeException();
    while (e != null) {
      for (StackTraceElement stackTraceElement : e.getStackTrace()) {
        if (stackTraceElement.getMethodName().equals("main")) {
          return ClassUtils.getPackageName(stackTraceElement.getClassName());
        }
      }
      e = e.getCause();
    }

    throw new IllegalStateException("can't find matched base package");
  }

  private Object createProxyBean(ClientConfigHolder globalClientConfigHolder, Class<?> clazz,
      List<Interceptor> interceptors, String[] interceptorsRef) {
    SpringRestIntegration springRestIntegration = clazz.getAnnotation(SpringRestIntegration.class);
    // parse interceptors
    List<Interceptor> newInterceptors = new ArrayList<>();
    if (springRestIntegration.interceptorsRef().length > 0 || springRestIntegration.interceptorClasses().length > 0) {
      for (String interceptor : springRestIntegration.interceptorsRef()) {
        newInterceptors.add(beanFactory.getBean(interceptor, Interceptor.class));
      }
      for (Class<? extends Interceptor> interceptorClass : springRestIntegration.interceptorClasses()) {
        Interceptor interceptor = null;
        try {
          interceptor = beanFactory.getBean(interceptorClass);
          log.debug("find matched bean [{}]", interceptorClass.getName());
        } catch (NoSuchBeanDefinitionException e) {
          if (log.isTraceEnabled()) {
            log.debug("can't find matched bean " + interceptorClass.getName(), e);
          }
        }
        if (interceptor == null) {
          interceptor = BeanUtils.instantiateClass(interceptorClass);
          log.debug("can't find matched bean [{}], instantiate class success.", interceptorClass.getName());
        }
        newInterceptors.add(interceptor);
      }
    }

    for (String interceptorRef : interceptorsRef) {
      newInterceptors.add(beanFactory.getBean(interceptorRef, Interceptor.class));
    }
    newInterceptors.addAll(interceptors);

    String host = Optional.ofNullable(clazz.getAnnotation(Rest.class))
        .map(Rest::value)
        .orElse("");
    Object proxyBean = ProxyBuilder.newBuilder()
        .classLoader(classLoader)
        .proxyInterface(clazz)
        .host(valueResolver.resolveStringValue(host))
        .addInterceptors(newInterceptors)
        .build();

    ClientConfigHolder clientConfigHolder = globalClientConfigHolder.merge(springRestIntegration, valueResolver);
    HttpClientConfig httpClientConfig = (HttpClientConfig) proxyBean;
    if (StringUtils.isNotEmpty(clientConfigHolder.getCharset())) {
      httpClientConfig.setCharset(Charset.forName(clientConfigHolder.getCharset()));
    }
    if (StringUtils.isNotEmpty(clientConfigHolder.getUserAgent())) {
      httpClientConfig.setUserAgent(clientConfigHolder.getUserAgent());
    }
    if (clientConfigHolder.getSocketTimeout() != null && clientConfigHolder.getConnectTimeout() != null) {
      httpClientConfig.setTimeout(clientConfigHolder.getSocketTimeout(), clientConfigHolder.getConnectTimeout());
    }
    return proxyBean;
  }

  private ClassPathScanningCandidateComponentProvider createClassScanner() {
    return new ClassPathScanningCandidateComponentProvider(false, environment) {
      @Override
      protected boolean isCandidateComponent(AnnotatedBeanDefinition beanDefinition) {
        return beanDefinition.getMetadata().isInterface();
      }
    };
  }

  @Override
  public void setBeanClassLoader(ClassLoader classLoader) {
    this.classLoader = classLoader;
  }

  @Override
  public void setBeanFactory(BeanFactory beanFactory) throws BeansException {
    this.beanFactory = beanFactory;
    this.valueResolver = new EmbeddedValueResolver((ConfigurableBeanFactory) beanFactory);
  }

  @Override
  public void setEnvironment(Environment environment) {
    this.environment = environment;
  }

  @Override
  public void setResourceLoader(ResourceLoader resourceLoader) {
    this.resourceLoader = resourceLoader;
  }

  @Data
  private static class ClientConfigHolder {

    private String userAgent;
    private String charset;
    private Integer socketTimeout;
    private Integer connectTimeout;

    private static ClientConfigHolder from(Map<String, Object> attributes, StringValueResolver valueResolver) {
      ClientConfigHolder holder = new ClientConfigHolder();
      holder.setUserAgent(SpringIntegrationUtil
          .parseString(valueResolver, (String) attributes.get("userAgent"), (String) attributes.get("userAgentString"),
              "userAgent and userAgentString cannot be set at the same time"));
      holder.setCharset(SpringIntegrationUtil
          .parseString(valueResolver, (String) attributes.get("charset"), (String) attributes.get("charsetString"),
              "charset and charsetString cannot be set at the same time"));
      Integer socketTimeout = SpringIntegrationUtil.parseInt(valueResolver, (Integer) attributes.get("socketTimeout"),
          (String) attributes.get("socketTimeoutString"),
          "socketTimeout and socketTimeoutString cannot be set at the same time");
      if (socketTimeout != null) {
        holder.setSocketTimeout(socketTimeout);
      }
      Integer connectTimeout = SpringIntegrationUtil.parseInt(valueResolver, (Integer) attributes.get("connectTimeout"),
          (String) attributes.get("connectTimeoutString"),
          "connectTimeout and connectTimeoutString cannot be set at the same time");
      if (connectTimeout != null) {
        holder.setConnectTimeout(connectTimeout);
      }
      return holder;
    }

    private ClientConfigHolder merge(SpringRestIntegration annotation, StringValueResolver valueResolver) {
      ClientConfigHolder clone = new ClientConfigHolder();
      clone.setUserAgent(this.userAgent);
      clone.setCharset(this.charset);
      clone.setSocketTimeout(this.socketTimeout);
      clone.setConnectTimeout(this.connectTimeout);

      String newUserAgent = SpringIntegrationUtil
          .parseString(valueResolver, annotation.userAgent(), annotation.userAgentString(),
              "userAgent and userAgentString cannot be set at the same time");
      if (StringUtils.isNotBlank(newUserAgent)) {
        clone.setUserAgent(newUserAgent);
      }

      String newCharset = SpringIntegrationUtil
          .parseString(valueResolver, annotation.charset(), annotation.charsetString(),
              "charset and charsetString cannot be set at the same time");
      if (StringUtils.isNotBlank(newCharset)) {
        clone.setCharset(newCharset);
      }

      Integer newSocketTimeout = SpringIntegrationUtil
          .parseInt(valueResolver, annotation.socketTimeout(), annotation.socketTimeoutString(),
              "socketTimeout and socketTimeoutString cannot be set at the same time");
      if (newSocketTimeout != null) {
        clone.setSocketTimeout(newSocketTimeout);
      }

      Integer newConnectTimeout = SpringIntegrationUtil
          .parseInt(valueResolver, annotation.connectTimeout(), annotation.connectTimeoutString(),
              "connectTimeout and connectTimeoutString cannot be set at the same time");
      if (newConnectTimeout != null) {
        clone.setConnectTimeout(newConnectTimeout);
      }
      return clone;
    }
  }
}
