/**************************************************************************
 * (C) 2019-2021 SAP SE or an SAP affiliate company. All rights reserved. *
 **************************************************************************/
package com.sap.cds.framework.spring.transaction;

import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

import org.springframework.beans.BeansException;
import org.springframework.beans.factory.config.BeanPostProcessor;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.stereotype.Component;
import org.springframework.transaction.PlatformTransactionManager;
import org.springframework.transaction.TransactionDefinition;
import org.springframework.transaction.TransactionStatus;
import org.springframework.transaction.support.TransactionSynchronization;
import org.springframework.transaction.support.TransactionSynchronizationManager;

import com.sap.cds.services.changeset.ChangeSetContext;
import com.sap.cds.services.changeset.ChangeSetContextSPI;
import com.sap.cds.services.impl.changeset.ChangeSetContextImpl;
import com.sap.cds.services.transaction.ChangeSetMemberDelegate;

@Component
@ConditionalOnClass(PlatformTransactionManager.class) // only in case spring-tx is loaded
public class PlatformTransactionManagerPostProcessor implements BeanPostProcessor {

	@Override
	public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException {
		Object proxiedBean = bean;
		if (bean instanceof PlatformTransactionManager) {
			// proxy all interfaces of bean
			Class<?> clazz = bean.getClass();
			Set<Class<?>> interfaces = new HashSet<>();
			do {
				interfaces.addAll(Arrays.asList(clazz.getInterfaces()));
				clazz = clazz.getSuperclass();
			} while (!clazz.equals(Object.class));

			interfaces.add(SpringTransactionManagerGetter.class);

			PTMInvocationHandler handler = new PTMInvocationHandler((PlatformTransactionManager) bean);
			// TODO this removes public non-interface methods from the proxy
			// -> find a better solution to fully support access to
			// AbstractPlatformTransactionManager
			proxiedBean = Proxy.newProxyInstance(bean.getClass().getClassLoader(), interfaces.toArray(new Class<?>[0]), handler);
		}
		return proxiedBean;
	}

	public static interface SpringTransactionManagerGetter {

		SpringTransactionManager getSpringTransactionManager();

	}

	private static class SpringTransactionSynchronization implements TransactionSynchronization {

		private final ChangeSetContextImpl changeSetContext;
		private TransactionStatus txStatus;

		public SpringTransactionSynchronization(ChangeSetContextImpl changeSetContext) {
			this.changeSetContext = changeSetContext;
		}

		public void setTransactionStatus(TransactionStatus txStatus) {
			this.txStatus = txStatus;
		}

		@Override
		public void beforeCompletion() {
			try {
				changeSetContext.triggerBeforeClose();
			} finally {
				if(changeSetContext.isMarkedForCancel() && txStatus != null) {
					txStatus.setRollbackOnly();
				}
			}
		}

		@Override
		public void afterCompletion(int status) {
			if(status != TransactionSynchronization.STATUS_COMMITTED) {
				changeSetContext.markForCancel();
			}
			changeSetContext.close();
		}

	}

	private static class PTMInvocationHandler implements InvocationHandler {

		private final PlatformTransactionManager platformTxMgr;
		private final SpringTransactionManager txMgr;
		private boolean proxyEnabled = false;

		public PTMInvocationHandler(PlatformTransactionManager platformTxMgr) {
			this.platformTxMgr = platformTxMgr;
			this.txMgr = new SpringTransactionManager(platformTxMgr);
		}

		@Override
		public Object invoke(Object obj, Method method, Object[] args) throws Throwable {
			if (proxyEnabled && method.getName().equals("getTransaction") && args.length == 1 && args[0] instanceof TransactionDefinition) {
				TransactionDefinition txDefinition = (TransactionDefinition) args[0];
				ChangeSetContextSPI currentChangeSet = (ChangeSetContextSPI) ChangeSetContext.getCurrent();

				if (txDefinition.getPropagationBehavior() == TransactionDefinition.PROPAGATION_REQUIRED &&
						currentChangeSet != null && !currentChangeSet.hasChangeSetMember(txMgr.getName())) {

					// the transaction is not opened already
					// as full control of the transaction is not requested through "REQUIRED"
					// we can manage the transaction as part of the existing change set
					currentChangeSet.register(new ChangeSetMemberDelegate(txMgr));
					txMgr.begin();

				} else if((txDefinition.getPropagationBehavior() == TransactionDefinition.PROPAGATION_REQUIRED && currentChangeSet == null)
						|| txDefinition.getPropagationBehavior() == TransactionDefinition.PROPAGATION_REQUIRES_NEW) {

					// a new transaction will be created, therefore a new change set needs to be opened (due to commit boundaries)
					// we need to create a change set, but still have the transaction be managed by spring (shadow change set)
					ChangeSetContextImpl newChangeSetContext = ChangeSetContextImpl.attach();

					// no new transaction should be opened as part of this change set -> register respective ChangeSetMember
					newChangeSetContext.register(new ChangeSetMemberDelegate(txMgr));

					// change set listeners need to work properly -> register them in Spring
					SpringTransactionSynchronization sync = new SpringTransactionSynchronization(newChangeSetContext);

					try {
						TransactionStatus txStatus = (TransactionStatus) method.invoke(platformTxMgr, args);

						// keep track of TransactionStatus to allow for setRollbackOnly()
						sync.setTransactionStatus(txStatus);
						TransactionSynchronizationManager.registerSynchronization(sync);

						return txStatus;
					} catch (Exception e) { // NOSONAR
						// close change set, if something went wrong to avoid inconsistent states
						newChangeSetContext.close();
						throw e;
					}
				}
			} else if (method.getName().equals("getSpringTransactionManager") && args == null) {
				// enable the proxy, as it is used by CAP
				this.proxyEnabled = true;
				// return the spring transaction manager
				return txMgr;
			}
			// invoke the original method
			return method.invoke(platformTxMgr, args);
		}

	}

}
