package com.sap.cds.services.impl.mt;

import java.util.Objects;

import com.sap.cds.services.mt.TenantProviderService;
import com.sap.cds.services.request.UserInfo;
import com.sap.cds.services.runtime.CdsRuntimeConfiguration;
import com.sap.cds.services.runtime.CdsRuntimeConfigurer;
import com.sap.cds.services.runtime.UserInfoProvider;

/**
 * Runtime configuration that, when provider-tenant normalization is enabled in the security
 * authentication properties, registers a {@link UserInfoProvider} which replaces the provider
 * tenant of the current user with a {@code null} tenant.
 * <p>
 * If the provider tenant cannot be determined ({@code null}), no provider is registered.
 */
public class ProviderTenantNormalizationConfiguration implements CdsRuntimeConfiguration {

	/** Order of this configuration; chosen so it runs after the default user providers. */
	private static final int ORDER = 100;

	@Override
	public int order() {
		return ORDER;
	}

	/**
	 * Registers the {@link ProviderTenantNormalizer} if normalization is switched on and a
	 * provider tenant is known.
	 *
	 * @param configurer the runtime configurer to register the provider with
	 */
	@Override
	public void providers(CdsRuntimeConfigurer configurer) {
		if (!configurer.getCdsRuntime().getEnvironment().getCdsProperties().getSecurity().getAuthentication()
				.isNormalizeProviderTenant()) {
			return;
		}

		// read the provider tenant as system user, so that the event runs with the default model
		String tenantOfProvider = configurer.getCdsRuntime().requestContext().systemUserProvider()
				.run(context -> {
					// keep the block body with an explicit return: it makes the lambda value-returning
					// and not void-compatible (presumably needed to select a Function-style overload of run)
					return configurer.getCdsRuntime().getServiceCatalog()
							.getService(TenantProviderService.class, TenantProviderService.DEFAULT_NAME)
							.readProviderTenant();
				});

		if (tenantOfProvider == null) {
			return;
		}
		configurer.provider(new ProviderTenantNormalizer(tenantOfProvider));
	}

	/**
	 * {@link UserInfoProvider} that wraps the previously registered provider. If the user it yields
	 * belongs to the provider tenant, a copy with the tenant set to {@code null} is returned instead.
	 */
	private static final class ProviderTenantNormalizer implements UserInfoProvider {

		private final String providerTenant;

		// provider registered before this one; may stay null if setPrevious is never called
		private UserInfoProvider delegate;

		ProviderTenantNormalizer(String providerTenant) {
			this.providerTenant = providerTenant;
		}

		/**
		 * @return the delegate's user info, with the tenant cleared if it equals the provider tenant;
		 *         {@code null} if there is no delegate or the delegate yields {@code null}
		 */
		@Override
		public UserInfo get() {
			if (delegate == null) {
				return null;
			}
			UserInfo current = delegate.get();
			boolean belongsToProvider = current != null && Objects.equals(current.getTenant(), providerTenant);
			return belongsToProvider ? current.copy().setTenant(null) : current;
		}

		@Override
		public void setPrevious(UserInfoProvider previous) {
			this.delegate = previous;
		}

	}

}
