/**************************************************************************
 * (C) 2019-2021 SAP SE or an SAP affiliate company. All rights reserved. *
 **************************************************************************/
package com.sap.cds.framework.spring.config.datasource;

import java.util.Arrays;
import java.util.List;
import java.util.Objects;

import javax.sql.DataSource;

import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.BeanFactoryAware;
import org.springframework.beans.factory.FactoryBean;
import org.springframework.boot.context.properties.bind.BindResult;
import org.springframework.boot.context.properties.bind.Bindable;
import org.springframework.boot.context.properties.bind.Binder;
import org.springframework.boot.jdbc.DataSourceBuilder;
import org.springframework.context.EnvironmentAware;
import org.springframework.core.env.Environment;

import com.sap.cds.framework.spring.config.runtime.BootstrapCache;
import com.sap.cds.services.datasource.DataSourceDescriptor;
import com.sap.cds.services.utils.datasource.DataSourceUtils;

/**
 * Auto-configures a DataSource in Spring, based on a given {@link DataSourceDescriptor}
 */
public class DataSourceDescriptorBeanFactory implements FactoryBean<DataSource>, EnvironmentAware, BeanFactoryAware {

	private final String descriptorName;
	private Environment environment;
	private BeanFactory beanFactory;
	// created on the first call to getObject() and reused afterwards
	private DataSource dataSource;

	/**
	 * Creates a factory for the {@link DataSource} described by the descriptor with the given name.
	 *
	 * @param descriptorName the name of the {@link DataSourceDescriptor} (as found in the {@link BootstrapCache})
	 *                       to create the {@link DataSource} for
	 */
	public DataSourceDescriptorBeanFactory(String descriptorName) {
		this.descriptorName = descriptorName;
	}

	/**
	 * Returns the {@link DataSource} for the configured descriptor, creating it on the first call.
	 * <p>
	 * The data source is built from the descriptor's driver class name, URL, username and password. Pool-specific
	 * properties from the environment are then bound onto it. The created instance is cached in this factory; the
	 * cache access itself is not synchronized.
	 *
	 * @return the data source for the configured descriptor
	 * @throws IllegalStateException if the {@link BootstrapCache} holds no descriptor with the configured name
	 */
	@Override
	public DataSource getObject() throws Exception {
		if (dataSource == null) {
			DataSourceDescriptor descriptor = findDescriptor();
			dataSource = bindPoolProperties(descriptor.getName(), buildDataSource(descriptor));
		}
		return dataSource;
	}

	/** Looks up the descriptor whose name equals {@link #descriptorName} in the {@link BootstrapCache} bean. */
	private DataSourceDescriptor findDescriptor() {
		BootstrapCache bootstrapCache = beanFactory.getBean(BootstrapCache.class);
		return bootstrapCache.getDataSourceDescriptors()
				.stream()
				// null-safe: a descriptor without a name must not abort the lookup with an NPE
				.filter(d -> Objects.equals(d.getName(), descriptorName))
				.findFirst()
				.orElseThrow(() -> new IllegalStateException(
						"No DataSource descriptor found with name '" + descriptorName + "'"));
	}

	/** Builds a plain data source from the descriptor's driver class name, URL, username and password. */
	private static DataSource buildDataSource(DataSourceDescriptor descriptor) {
		DataSourceBuilder<?> builder = DataSourceBuilder.create();
		builder.driverClassName(descriptor.getDriverClassName());
		builder.url(descriptor.getUrl());
		builder.username(descriptor.getUsername());
		builder.password(descriptor.getPassword());
		return builder.build();
	}

	/**
	 * Binds pool-specific environment properties onto the given data source.
	 * <p>
	 * The sections are tried in the order Hikari, Tomcat, DBCP2. The first section the binder reports as bound wins,
	 * and the remaining sections are not consulted. If no section binds, the data source is returned unchanged.
	 *
	 * @param name the descriptor name used to derive the configuration section names
	 * @param ds   the data source to bind properties onto
	 * @return the (possibly) bound data source
	 */
	private DataSource bindPoolProperties(String name, DataSource ds) {
		Binder binder = Binder.get(environment);
		Bindable<DataSource> bindable = Bindable.ofInstance(ds);
		for (String section : getDataSourceSections(name)) {
			BindResult<DataSource> bindResult = binder.bind(section, bindable);
			if (bindResult.isBound()) {
				return bindResult.get();
			}
		}
		return ds;
	}

	/**
	 * Returns the candidate configuration sections for the given descriptor name, in the order Hikari, Tomcat, DBCP2.
	 */
	private static List<String> getDataSourceSections(String name) {
		return Arrays.asList(
				DataSourceUtils.getDataSourceSection(name, DataSourceUtils.PoolType.HIKARI),
				DataSourceUtils.getDataSourceSection(name, DataSourceUtils.PoolType.TOMCAT),
				DataSourceUtils.getDataSourceSection(name, DataSourceUtils.PoolType.DBCP2));
	}

	/** {@inheritDoc} Always {@link DataSource}. */
	@Override
	public Class<?> getObjectType() {
		return DataSource.class;
	}

	/** Receives the bean factory used to look up the {@link BootstrapCache}. */
	@Override
	public void setBeanFactory(BeanFactory beanFactory) {
		this.beanFactory = beanFactory;
	}

	/** Receives the environment from which pool-specific data source properties are bound. */
	@Override
	public void setEnvironment(Environment environment) {
		this.environment = environment;
	}

}
