package com.nimbusds.common.oauth2;


import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.List;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;

import com.nimbusds.jose.crypto.utils.ConstantTimeUtils;
import com.nimbusds.oauth2.sdk.ParseException;
import com.nimbusds.oauth2.sdk.token.BearerAccessToken;
import com.nimbusds.oauth2.sdk.token.BearerTokenError;
import net.jcip.annotations.ThreadSafe;
import net.minidev.json.JSONObject;
import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.Logger;


/**
 * Basic access token validator. Supports servlet-based and JAX-RS based web
 * applications.
 *
 * <p>The validator does not retain the expected access tokens themselves;
 * it stores only their salted SHA-256 hashes (the salt is generated per
 * instance) and compares the hash of each received token against them with
 * {@code ConstantTimeUtils.areEqual}.
 */
@ThreadSafe
public class BasicAccessTokenValidator {


	/**
	 * Error response: Missing OAuth 2.0 Bearer access token.
	 */
	public static final ErrorResponse MISSING_BEARER_TOKEN;


	/**
	 * Error response: Invalid OAuth 2.0 Bearer access token.
	 */
	public static final ErrorResponse INVALID_BEARER_TOKEN;


	/**
	 * Error response: Web API disabled.
	 */
	public static final ErrorResponse WEB_API_DISABLED;


	static {
		JSONObject o = new JSONObject();
		o.put("error", "missing_token");
		o.put("error_description", "Unauthorized: Missing Bearer access token");
		MISSING_BEARER_TOKEN = new ErrorResponse(
			BearerTokenError.MISSING_TOKEN.getHTTPStatusCode(),
			BearerTokenError.MISSING_TOKEN.toWWWAuthenticateHeader(),
			o.toJSONString());

		o = new JSONObject();
		o.put("error", BearerTokenError.INVALID_TOKEN.getCode());
		o.put("error_description", "Unauthorized: Invalid Bearer access token");
		INVALID_BEARER_TOKEN = new ErrorResponse(
			BearerTokenError.INVALID_TOKEN.getHTTPStatusCode(),
			BearerTokenError.INVALID_TOKEN.toWWWAuthenticateHeader(),
			o.toJSONString());

		o = new JSONObject();
		o.put("error", "web_api_disabled");
		o.put("error_description", "Forbidden: Web API disabled");
		WEB_API_DISABLED = new ErrorResponse(403, null, o.toJSONString());
	}


	/**
	 * Bearer token error response.
	 */
	public static class ErrorResponse {


		/**
		 * The HTTP status code.
		 */
		private final int statusCode;


		/**
		 * Optional HTTP response {@code WWW-Authenticate} header.
		 */
		private final String wwwAuthHeader;


		/**
		 * The HTTP body.
		 */
		private final String body;


		/**
		 * Creates a new bearer token error response.
		 *
		 * @param statusCode    The HTTP status code.
		 * @param wwwAuthHeader The HTTP response
		 *                      {@code WWW-Authenticate} header,
		 *                      {@code null} if none.
		 * @param body          The HTTP body (application/json).
		 */
		public ErrorResponse(final int statusCode,
				     final String wwwAuthHeader,
				     final String body) {

			this.statusCode = statusCode;
			this.wwwAuthHeader = wwwAuthHeader;
			this.body = body;
		}


		/**
		 * Returns a web application exception for this error response.
		 * The {@code WWW-Authenticate} header is set only when one was
		 * supplied; the body is sent with media type
		 * {@code application/json}.
		 *
		 * @return The web application exception.
		 */
		public WebApplicationException toWebAppException() {

			Response.ResponseBuilder builder = Response.status(statusCode);

			if (wwwAuthHeader != null) {
				builder.header("WWW-Authenticate", wwwAuthHeader);
			}

			return new WebApplicationException(
				builder.entity(body).type(MediaType.APPLICATION_JSON).build());
		}


		/**
		 * Applies this error response to the specified HTTP servlet
		 * response: sets the status code, the optional
		 * {@code WWW-Authenticate} header and, if present, writes the
		 * JSON body.
		 *
		 * @param servletResponse The HTTP servlet response. Must not
		 *                        be {@code null}.
		 *
		 * @throws IOException If the error response couldn't be
		 *                     written.
		 */
		public void apply(final HttpServletResponse servletResponse)
			throws IOException {

			servletResponse.setStatus(statusCode);

			if (wwwAuthHeader != null) {
				servletResponse.setHeader("WWW-Authenticate", wwwAuthHeader);
			}

			if (body != null) {
				// The content type must be set before obtaining the writer
				servletResponse.setContentType("application/json");
				servletResponse.getWriter().print(body);
			}
		}
	}
	
	
	/**
	 * The expected access token hashes, empty list if access to the web
	 * API is disabled. Populated only in the constructor.
	 */
	private final List<byte[]> expectedTokenHashes = new ArrayList<>();
	
	
	/**
	 * Salt for computing the SHA-256 hashes.
	 */
	private final byte[] hashSalt;
	
	
	/**
	 * Optional logger. Volatile so that a logger set via
	 * {@link #setLogger} is visible to threads calling the validate
	 * methods, as required by the class's {@code @ThreadSafe} contract.
	 */
	private volatile Logger log;


	/**
	 * Creates a new basic access token validator.
	 *
	 * @param accessToken The Bearer access token. If {@code null} access
	 *                    to the web API will be disabled.
	 */
	public BasicAccessTokenValidator(final BearerAccessToken accessToken) {

		this(new BearerAccessToken[]{accessToken});
	}
	
	/**
	 * Creates a new basic access token validator.
	 *
	 * @param accessTokens The Bearer access tokens. If {@code null} (or
	 *                     containing only {@code null} elements) access
	 *                     to the web API will be disabled. Individual
	 *                     {@code null} elements are ignored.
	 */
	public BasicAccessTokenValidator(final BearerAccessToken ... accessTokens) {
		
		hashSalt = generate32ByteSalt();
		
		// A null array means no tokens: leave the hash list empty so that
		// accessIsDisabled() returns true (previously this threw an NPE)
		if (accessTokens != null) {
			for (BearerAccessToken t: accessTokens) {
				if (t == null) continue;
				expectedTokenHashes.add(computeSHA256(hashSalt, t));
			}
		}
	}
	
	
	/**
	 * Returns {@code true} if access is disabled (no access token
	 * configured).
	 *
	 * @return {@code true} if access is disabled, else {@code false}.
	 */
	public boolean accessIsDisabled() {
		
		return expectedTokenHashes.isEmpty();
	}
	
	
	/**
	 * Gets the optional logger.
	 *
	 * @return The logger, {@code null} if not specified.
	 */
	public Logger getLogger() {
		return log;
	}
	
	
	/**
	 * Sets the optional logger.
	 *
	 * @param log The logger, {@code null} if not specified.
	 */
	public void setLogger(final Logger log) {
		this.log = log;
	}
	
	
	/**
	 * Generates a 32 byte salt.
	 *
	 * @return The 32 byte salt.
	 */
	private static byte[] generate32ByteSalt() {
		
		byte[] hashSalt = new byte[32];
		new SecureRandom().nextBytes(hashSalt);
		return hashSalt;
	}
	
	
	/**
	 * Computes the salted SHA-256 hash of the specified Bearer access
	 * token: SHA-256(salt || UTF-8 bytes of the token value).
	 *
	 * @param salt  The salt to use. Must not be {@code null}.
	 * @param token The Bearer access token. Must not be {@code null}.
	 *
	 * @return The computed SHA-256 hash.
	 */
	private static byte[] computeSHA256(final byte[] salt, final BearerAccessToken token) {
		
		try {
			MessageDigest sha256 = MessageDigest.getInstance("SHA-256");
			sha256.update(salt);
			return sha256.digest(token.getValue().getBytes(StandardCharsets.UTF_8));
		} catch (NoSuchAlgorithmException e) {
			// Every Java platform implementation is required to support
			// SHA-256, so this indicates a broken runtime
			throw new IllegalStateException("SHA-256 not supported: " + e.getMessage(), e);
		}
	}
	
	
	/**
	 * Logs the received token (abbreviated, at trace level, if a logger
	 * is set) and checks it against the expected token hashes.
	 *
	 * @param receivedToken The received Bearer access token. Must not be
	 *                      {@code null}.
	 *
	 * @return {@code true} if the token matches one of the expected
	 *         tokens, else {@code false}.
	 */
	private boolean matchesExpectedToken(final BearerAccessToken receivedToken) {
		
		// Read the volatile field once so the null check and use agree
		final Logger logger = log;
		
		if (null != logger) {
			logger.trace("[CM3000] Validating bearer access token: {}", TokenAbbreviator.abbreviate(receivedToken));
		}
		
		final byte[] receivedTokenHash = computeSHA256(hashSalt, receivedToken);
		
		for (byte[] expectedHash: expectedTokenHashes) {
			if (ConstantTimeUtils.areEqual(receivedTokenHash, expectedHash)) {
				return true;
			}
		}
		
		return false;
	}


	/**
	 * Validates a bearer access token passed in the specified HTTP
	 * Authorization header value.
	 *
	 * @param authzHeader The HTTP Authorization header value, {@code null}
	 *                    if not specified.
	 *
	 * @throws WebApplicationException If the web API is disabled, the
	 *                                 header value is {@code null} or
	 *                                 blank, or the Bearer access token
	 *                                 is missing or invalid.
	 */
	public void validateBearerAccessToken(final String authzHeader)
		throws WebApplicationException {
		
		// Web API disabled?
		if (accessIsDisabled()) {
			throw WEB_API_DISABLED.toWebAppException();
		}

		if (StringUtils.isBlank(authzHeader)) {
			throw MISSING_BEARER_TOKEN.toWebAppException();
		}

		final BearerAccessToken receivedToken;

		try {
			receivedToken = BearerAccessToken.parse(authzHeader);

		} catch (ParseException e) {
			throw MISSING_BEARER_TOKEN.toWebAppException();
		}
		
		if (! matchesExpectedToken(receivedToken)) {
			throw INVALID_BEARER_TOKEN.toWebAppException();
		}
	}


	/**
	 * Validates a bearer access token passed in the specified HTTP servlet
	 * request, either in the {@code Authorization} header or, if that
	 * header is absent, in the {@code access_token} request parameter. On
	 * failure the matching error response is applied to the servlet
	 * response.
	 *
	 * @param servletRequest  The HTTP servlet request. Must not be
	 *                        {@code null}.
	 * @param servletResponse The HTTP servlet response. Must not be
	 *                        {@code null}.
	 *
	 * @return {@code true} if the bearer access token was successfully
	 *         validated, else {@code false}.
	 *
	 * @throws IOException If the response couldn't be written.
	 */
	public boolean validateBearerAccessToken(final HttpServletRequest servletRequest,
						 final HttpServletResponse servletResponse)
		throws IOException {
		
		// Web API disabled?
		if (accessIsDisabled()) {
			WEB_API_DISABLED.apply(servletResponse);
			return false;
		}
		
		final BearerAccessToken receivedToken;
		
		final String authzHeaderValue = servletRequest.getHeader("Authorization");
		
		if (authzHeaderValue != null) {
			
			if (StringUtils.isBlank(authzHeaderValue)) {
				MISSING_BEARER_TOKEN.apply(servletResponse);
				return false;
			}
			
			try {
				receivedToken = BearerAccessToken.parse(authzHeaderValue);
				
			} catch (ParseException e) {
				MISSING_BEARER_TOKEN.apply(servletResponse);
				return false;
			}
			
		} else {
			
			// Only consult the request parameters when there is no
			// Authorization header, as in the original lookup order
			final String accessTokenValue = servletRequest.getParameter("access_token");
			
			if (accessTokenValue == null || StringUtils.isBlank(accessTokenValue)) {
				MISSING_BEARER_TOKEN.apply(servletResponse);
				return false;
			}
			
			receivedToken = new BearerAccessToken(accessTokenValue);
		}
		
		if (matchesExpectedToken(receivedToken)) {
			return true; // pass
		}
		
		INVALID_BEARER_TOKEN.apply(servletResponse);
		return false;
	}
}
