package com.nimbusds.common.oauth2;


import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.ws.rs.WebApplicationException;

import com.nimbusds.jose.crypto.utils.ConstantTimeUtils;
import com.nimbusds.oauth2.sdk.ParseException;
import com.nimbusds.oauth2.sdk.token.BearerAccessToken;
import net.jcip.annotations.ThreadSafe;
import org.apache.commons.codec.DecoderException;
import org.apache.commons.codec.binary.Hex;
import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.Logger;


/**
 * SHA-256 based access token validator. The expected access tokens are
 * configured as their SHA-256 hashes, to prevent accidental leaks into logs,
 * etc. Supports servlet-based and JAX-RS based web applications.
 *
 * <p>A received bearer token is rejected if it is shorter than
 * {@link #MIN_TOKEN_LENGTH} characters. Otherwise its SHA-256 hash is
 * compared in constant time against each configured hash.
 */
@ThreadSafe
public class SHA256BasedAccessTokenValidator implements MasterAccessTokenValidator {
	
	
	/**
	 * The minimum acceptable access token length.
	 */
	public static final int MIN_TOKEN_LENGTH = 32;
	
	
	/**
	 * The expected access token hashes, empty list if access to the web
	 * API is disabled. Populated only by the constructor.
	 */
	private final List<byte[]> expectedTokenHashes = new ArrayList<>();
	
	
	/**
	 * Optional logger, {@code null} if not set. Volatile so that a logger
	 * set after construction is visible to threads validating requests.
	 */
	private volatile Logger log;
	
	
	/**
	 * Creates a new basic access token validator.
	 *
	 * @param tokenHash The Bearer access token SHA-256 hash (in hex). If
	 *                  {@code null} access to the web API will be
	 *                  disabled.
	 *
	 * @throws IllegalArgumentException If the hash is not valid hex.
	 */
	public SHA256BasedAccessTokenValidator(final String tokenHash) {
		
		this(new String[]{tokenHash});
	}
	
	/**
	 * Creates a new basic access token validator.
	 *
	 * @param tokenHashes The Bearer access token SHA-256 hashes (in hex).
	 *                    If {@code null} access to the web API will be
	 *                    disabled. {@code null} elements are skipped.
	 *
	 * @throws IllegalArgumentException If a hash is not valid hex.
	 */
	public SHA256BasedAccessTokenValidator(final String ... tokenHashes) {
		
		if (tokenHashes == null) {
			// No hashes configured: leave the list empty, which disables
			// access (see accessIsDisabled)
			return;
		}
		
		for (String hash: tokenHashes) {
			if (hash == null) continue;
			try {
				expectedTokenHashes.add(Hex.decodeHex(hash));
			} catch (DecoderException e) {
				throw new IllegalArgumentException("Invalid hex: " + hash, e);
			}
		}
	}
	
	
	@Override
	public boolean accessIsDisabled() {
		
		return expectedTokenHashes.isEmpty();
	}
	
	
	/**
	 * Returns the optional logger.
	 *
	 * @return The logger, {@code null} if not set.
	 */
	@Override
	public Logger getLogger() {
		return log;
	}
	
	
	/**
	 * Sets the optional logger, used to trace token validation.
	 *
	 * @param log The logger, {@code null} to disable logging.
	 */
	@Override
	public void setLogger(final Logger log) {
		this.log = log;
	}
	
	
	/**
	 * Validates the bearer access token in the specified
	 * {@code Authorization} header value (JAX-RS style).
	 *
	 * @param authzHeader The {@code Authorization} header value, may be
	 *                    {@code null}.
	 *
	 * @throws WebApplicationException If access is disabled, the bearer
	 *                                 token is missing / unparseable,
	 *                                 or the token is invalid.
	 */
	@Override
	public void validateBearerAccessToken(final String authzHeader)
		throws WebApplicationException {
		
		// Web API disabled?
		if (accessIsDisabled()) {
			throw WEB_API_DISABLED.toWebAppException();
		}
		
		if (StringUtils.isBlank(authzHeader)) {
			throw MISSING_BEARER_TOKEN.toWebAppException();
		}
		
		final BearerAccessToken receivedToken;
		
		try {
			receivedToken = BearerAccessToken.parse(authzHeader);
			
		} catch (ParseException e) {
			throw MISSING_BEARER_TOKEN.toWebAppException();
		}
		
		if (! isExpectedToken(receivedToken)) {
			throw INVALID_BEARER_TOKEN.toWebAppException();
		}
	}
	
	
	/**
	 * Validates the bearer access token of the specified servlet request.
	 * The token is taken from the {@code Authorization} header if present,
	 * else from the {@code access_token} parameter. On failure the matching
	 * error is applied to the servlet response.
	 *
	 * @param servletRequest  The servlet request.
	 * @param servletResponse The servlet response, receives the error
	 *                        on failure.
	 *
	 * @return {@code true} if the token is valid, {@code false} if access
	 *         was denied and the response populated.
	 *
	 * @throws IOException If applying the error to the response failed.
	 */
	@Override
	public boolean validateBearerAccessToken(final HttpServletRequest servletRequest,
						 final HttpServletResponse servletResponse)
		throws IOException {
		
		// Web API disabled?
		if (accessIsDisabled()) {
			WEB_API_DISABLED.apply(servletResponse);
			return false;
		}
		
		final BearerAccessToken receivedToken;
		
		final String authzHeaderValue = servletRequest.getHeader("Authorization");
		
		if (authzHeaderValue != null) {
			
			if (StringUtils.isBlank(authzHeaderValue)) {
				MISSING_BEARER_TOKEN.apply(servletResponse);
				return false;
			}
			
			try {
				receivedToken = BearerAccessToken.parse(authzHeaderValue);
				
			} catch (ParseException e) {
				MISSING_BEARER_TOKEN.apply(servletResponse);
				return false;
			}
			
		} else {
			// The header takes precedence; the parameter is only
			// consulted when the header is absent
			final String accessTokenValue = servletRequest.getParameter("access_token");
			
			// isBlank also covers an absent (null) parameter
			if (StringUtils.isBlank(accessTokenValue)) {
				MISSING_BEARER_TOKEN.apply(servletResponse);
				return false;
			}
			
			receivedToken = new BearerAccessToken(accessTokenValue);
		}
		
		if (! isExpectedToken(receivedToken)) {
			INVALID_BEARER_TOKEN.apply(servletResponse);
			return false;
		}
		
		return true; // pass
	}
	
	
	/**
	 * Checks a received bearer token against the expected hashes. Traces
	 * the abbreviated token if a logger is set.
	 *
	 * @param receivedToken The received token. Must not be {@code null}.
	 *
	 * @return {@code true} if the token is at least
	 *         {@link #MIN_TOKEN_LENGTH} characters long and its SHA-256
	 *         hash equals one of the expected hashes, else {@code false}.
	 */
	private boolean isExpectedToken(final BearerAccessToken receivedToken) {
		
		// Read the logger once, so a concurrent setLogger(null) cannot
		// null it between the check and the call
		final Logger logger = log;
		
		if (logger != null) {
			logger.trace("[CM3000] Validating bearer access token: {}", TokenAbbreviator.abbreviate(receivedToken));
		}
		
		// Check min length
		if (receivedToken.getValue().length() < MIN_TOKEN_LENGTH) {
			return false;
		}
		
		// Compare hashes in constant time
		final byte[] receivedTokenHash = MasterAccessTokenValidator.computeSHA256(receivedToken, null);
		
		for (byte[] expectedHash: expectedTokenHashes) {
			if (ConstantTimeUtils.areEqual(receivedTokenHash, expectedHash)) {
				return true;
			}
		}
		
		return false;
	}
}
