package com.sap.cds.transaction.impl;

import java.sql.Connection;
import java.sql.SQLException;
import java.util.function.Supplier;

import org.slf4j.Logger;

import com.sap.cds.transaction.RollbackException;
import com.sap.cds.transaction.SystemException;
import com.sap.cds.transaction.TransactionException;
import com.sap.cds.transaction.TransactionRequiredException;
import com.sap.cds.transaction.impl.SQLProxyBuilder.Call;
import com.sap.cds.transaction.spi.ContainerTransactionManager;

/**
 * {@link ContainerTransactionManager} that runs transactions directly on a JDBC {@link Connection}
 * bound to the current thread.
 * <p>
 * {@link #begin()} obtains a connection from the supplied connection source, switches it to manual
 * commit mode and binds it to the current thread. {@link #commit()} and {@link #rollback()} finish
 * the transaction, unbind it and close that connection.
 * <p>
 * The thread binding is kept in a {@code static} {@link ThreadLocal}. It is therefore shared by all
 * instances of this class that are used on the same thread.
 */
public class LocalTransactionManager implements ContainerTransactionManager {

	/** Transaction bound to the current thread, or {@code null} when none is active. */
	private static final ThreadLocal<Tx> txHolder = new ThreadLocal<>();
	/** Unmanaged source of connections; used by {@link #begin()}. */
	private final Supplier<Connection> ds;
	private final Logger logger;
	/** Transaction-aware view of {@link #ds}, exposed via {@link #getConnectionSupplier()}. */
	private final Supplier<Connection> managedDS;

	/**
	 * @param logger logger used for diagnostics
	 * @param ds     supplier of JDBC connections that transactions are run on
	 */
	public LocalTransactionManager(Logger logger, Supplier<Connection> ds) {
		this.logger = logger;
		this.ds = ds;
		this.managedDS = createManagedDataSource(ds);
	}

	/**
	 * Builds a proxy around {@code ds2} whose {@code get()} is transaction aware.
	 * <ul>
	 * <li>Without an active transaction on the current thread, it returns a connection from
	 * {@code ds2} switched to auto-commit. The manager does not track or close that connection.</li>
	 * <li>With an active transaction, it returns the transaction's connection wrapped so that
	 * {@code close()} is routed to {@code SQLProxyBuilder.NOP}. The real connection is closed when
	 * the transaction ends.</li>
	 * </ul>
	 */
	private static Supplier<Connection> createManagedDataSource(Supplier<Connection> ds2) {
		SQLProxyBuilder<Supplier<Connection>> builder = SQLProxyBuilder.create(Supplier.class, ds2);

		Call<Connection> getConnection = () -> {
			Tx tx = txHolder.get();
			if (tx == null) {
				Connection conn = ds2.get();
				try {
					conn.setAutoCommit(true);
				} catch (SQLException | RuntimeException e) {
					// don't leak a connection we failed to configure
					closeAfterFailure(conn, e);
					throw e;
				}
				return conn;
			}
			return wrapConnection(tx.getConnection());
		};

		return builder.handle("get", getConnection).build();
	}

	/** Wraps {@code conn} so that calls to {@code close()} are handled by {@code SQLProxyBuilder.NOP}. */
	private static Connection wrapConnection(final Connection conn) {
		return SQLProxyBuilder.create(Connection.class, conn).handle("close", SQLProxyBuilder.NOP).build();
	}

	/**
	 * Closes {@code conn} after {@code failure} occurred while using it.
	 * <p>
	 * If closing also fails, that error is attached to {@code failure} as a suppressed exception.
	 * The original error is therefore the one that propagates.
	 */
	private static void closeAfterFailure(Connection conn, Throwable failure) {
		try {
			conn.close();
		} catch (SQLException | RuntimeException closeEx) {
			failure.addSuppressed(closeEx);
		}
	}

	/**
	 * Starts a transaction on a new connection from the connection source and binds it to the
	 * current thread.
	 *
	 * @throws TransactionException if a transaction is already active on the current thread
	 * @throws SystemException      if the connection cannot be switched to manual commit mode; the
	 *                              connection is closed in that case
	 */
	@Override
	public void begin() {
		if (txHolder.get() != null) {
			throw new TransactionException("there is an active transaction");
		}

		Connection conn = ds.get();
		try {
			try {
				conn.setClientInfo("LOCALE", null);
			} catch (SQLException ex) {
				// not supported by all DBs
				logger.info("setClientInfo not supported", ex);
			}
			conn.setAutoCommit(false);
		} catch (SQLException e) {
			closeAfterFailure(conn, e);
			throw new SystemException("exception during setAutoCommit", e);
		} catch (RuntimeException e) {
			closeAfterFailure(conn, e);
			throw e;
		}
		txHolder.set(new Tx(conn));
	}

	/**
	 * Commits the current thread's transaction, then unbinds it and closes its connection.
	 *
	 * @throws TransactionRequiredException if no transaction is active on the current thread
	 * @throws RollbackException            if the transaction was marked rollback-only or the commit
	 *                                      failed; the transaction is rolled back in both cases
	 * @throws SystemException              if rolling back a rollback-only transaction fails
	 */
	@Override
	public void commit() {
		Tx tx = getActiveTransaction();
		try {
			tx.commit();
		} finally {
			close(tx);
		}
	}

	/**
	 * Rolls back the current thread's transaction, then unbinds it and closes its connection.
	 *
	 * @throws TransactionRequiredException if no transaction is active on the current thread
	 * @throws SystemException              if the rollback fails
	 */
	@Override
	public void rollback() {
		Tx tx = getActiveTransaction();
		try {
			tx.rollback();
		} finally {
			close(tx);
		}
	}

	/**
	 * @return the transaction bound to the current thread
	 * @throws TransactionRequiredException if there is none
	 */
	private Tx getActiveTransaction() {
		Tx tx = txHolder.get();
		if (tx == null) {
			throw new TransactionRequiredException("no transaction is active");
		}
		return tx;
	}

	/**
	 * Unbinds {@code tx} from the current thread and closes its connection.
	 * <p>
	 * A failure to close is logged, not thrown, so it cannot mask the commit/rollback outcome.
	 */
	private void close(Tx tx) {
		txHolder.remove();
		try {
			tx.getConnection().close();
		} catch (SQLException e) {
			logger.error("Exception while closing connection", e);
		}
	}

	/** @return {@code true} if a transaction is bound to the current thread */
	@Override
	public boolean isActive() {
		return txHolder.get() != null;
	}

	/** State of one transaction: its connection and the rollback-only flag. */
	private static class Tx {

		private final Connection conn;
		private boolean rollbackOnly = false;

		public Tx(Connection conn) {
			this.conn = conn;
		}

		/**
		 * Commits the connection, or rolls it back if the transaction is marked rollback-only.
		 *
		 * @throws RollbackException if marked rollback-only (after rolling back), or if the commit
		 *                           failed (after an explicit rollback attempt)
		 * @throws SystemException   if rolling back a rollback-only transaction fails
		 */
		public void commit() {
			if (rollbackOnly) {
				rollback();
				throw new RollbackException(
						"the transaction was marked for rollback only and has been rolled back");
			}
			try {
				conn.commit();
			} catch (SQLException e) {
				// Roll back explicitly: the connection is closed next, and JDBC leaves the outcome
				// of close() on a connection with an active transaction implementation-defined.
				try {
					conn.rollback();
				} catch (SQLException rollbackEx) {
					e.addSuppressed(rollbackEx);
				}
				throw new RollbackException(e);
			}
		}

		/**
		 * Rolls back the connection.
		 *
		 * @throws SystemException if the rollback fails
		 */
		public void rollback() {
			try {
				conn.rollback();
			} catch (SQLException e) {
				throw new SystemException("exception during rollback", e);
			}
		}

		public Connection getConnection() {
			return conn;
		}

		public void setRollbackOnly() {
			rollbackOnly = true;
		}

		public boolean isRollbackOnly() {
			return rollbackOnly;
		}

	}

	/**
	 * Returns a transaction-aware connection supplier.
	 * <p>
	 * Inside a transaction, it supplies the transaction's connection, with {@code close()} routed to
	 * {@code SQLProxyBuilder.NOP}. Outside a transaction, it supplies an auto-commit connection
	 * from the underlying source.
	 */
	public Supplier<Connection> getConnectionSupplier() {
		return managedDS;
	}

	/**
	 * Marks the current thread's transaction so that {@link #commit()} rolls it back instead.
	 *
	 * @throws TransactionRequiredException if no transaction is active on the current thread
	 */
	@Override
	public void setRollbackOnly() {
		getActiveTransaction().setRollbackOnly();
	}

	/**
	 * @return whether the current thread's transaction is marked rollback-only
	 * @throws TransactionRequiredException if no transaction is active on the current thread
	 */
	@Override
	public boolean isRollbackOnly() {
		return getActiveTransaction().isRollbackOnly();
	}

}
