/*
 * Copyright 2023 Salesforce, Inc. All rights reserved.
 */
package org.mule.service.http.netty.impl.client.auth.ntlm;

import static org.slf4j.LoggerFactory.getLogger;

import org.mule.service.http.netty.impl.client.auth.NettyAuthHeaderFactory;

import java.io.IOException;
import java.util.Base64;

import jcifs.ntlmssp.Type1Message;
import jcifs.ntlmssp.Type2Message;
import jcifs.ntlmssp.Type3Message;
import org.slf4j.Logger;

/**
 * Implementation of the client-side NTLM message generation.
 * <p>
 * The NTLM authentication is done once per connection, and the mechanism can be summarized as:
 * <ol>
 * <li>The client sends a request without any kind of authentication.</li>
 * <li>The server responds with a 401 error code, including a WWW-AUTHENTICATE header with the value "NTLM".</li>
 * <li>The client sends another request over the same connection, now containing a Type1 message in the AUTHORIZATION
 * header.</li>
 * <li>The server responds with a 401 error code, including a Type2 message in the WWW-AUTHENTICATE header.</li>
 * <li>The client sends a last request with the Type3 message in the AUTHORIZATION header.</li>
 * </ol>
 * <p>
 * This class has methods to generate the client-side messages (Type1 and Type3) and to create the corresponding AUTHORIZATION
 * header values given a WWW-AUTHENTICATE header value.
 * <p>
 * Instances keep track of the progress of the exchange (see {@link #getNextHeader(String)}); the progress is mutated without
 * any synchronization, so an instance should not be shared between concurrent exchanges.
 */
public class NtlmMessageFactory implements NettyAuthHeaderFactory {

  /**
   * Status of the authentication, used to avoid falling into bad states (such as infinite recursions) if the server is not well
   * implemented.
   */
  enum Status {

    /** Initial state, no headers exchanged. */
    NOT_STARTED,

    /**
     * In some auth methods there are multiple messages exchanged: for example, a 401 was received, we sent a message, and we
     * are now waiting for the challenge from the server.
     */
    WAITING_FOR_CHALLENGE,

    /** The last message was answered; even if it failed, there is nothing else to be done. */
    FINISHED,
  }

  private static final Logger LOGGER = getLogger(NtlmMessageFactory.class);

  // Same flags as we used in the legacy AHC implementation.
  private static final int TYPE_1_MESSAGE_FLAGS = 0xA2088201;
  // The "WWW-Authenticate" and "Authorization" headers use this prefix and encode the NTLM messages using Base64.
  private static final String NTLM_MESSAGES_PREFIX = "NTLM ";
  // Value sent by server in the WWW-AUTHENTICATE to indicate that NTLM Type1 message is required.
  private static final String STARTING_NTLM_WWW_AUTHENTICATE_HEADER = "NTLM";

  private final String domain;
  private final String workstation;
  private final String username;
  private final String password;

  private Status status;

  /**
   * Creates a factory for a single NTLM exchange.
   *
   * @param domain      the NTLM domain; {@code null} is treated as the empty string.
   * @param workstation the workstation name passed to the Type 3 message.
   * @param username    the user name passed to the Type 3 message.
   * @param password    the password used to answer the Type 2 challenge.
   */
  public NtlmMessageFactory(String domain, String workstation, String username, String password) {
    this.domain = null == domain ? "" : domain;
    this.workstation = workstation;
    this.username = username;
    this.password = password;
    this.status = Status.NOT_STARTED;
  }

  /**
   * Generates a raw NTLM Type 1 message. The flags used for this are the same as the ones present in our old implementation using
   * Grizzly AHC. The only thing that changes is the NTLM version.
   *
   * @return A fresh copy of the raw Type 1 message (without Base64 encoding nor the "NTLM " prefix needed in the header).
   */
  public byte[] createType1Message() {
    LOGGER.debug("NTLM MessageFactory creating Type1Message...");
    // The cached array is shared by every instance; hand out a copy so a caller mutating the result cannot corrupt the
    // Type 1 message used by every subsequent authentication.
    return RawType1MessageHolder.RAW_TYPE_1_MESSAGE.clone();
  }

  /**
   * Generates a raw NTLM Type 3 message. The flags used for this are the same as the ones present in the Type 2 message passed as
   * parameter.
   *
   * @param type2Material a raw representation of the Type 2 message received from the server.
   * @return A raw representation of the Type 3 message (without Base64 encoding nor the "NTLM " prefix needed in the header).
   * @throws IOException If an error occurs while parsing the passed type 2 message material, or creating the type 3 message.
   */
  public byte[] createType3Message(byte[] type2Material) throws IOException {
    LOGGER.debug("NTLM MessageFactory creating Type3Message...");
    Type2Message type2Message = new Type2Message(type2Material);
    if (null == type2Message.getChallenge()) {
      // Avoid passing a null challenge to the Type 3 message construction.
      type2Message.setChallenge(new byte[0]);
    }
    Type3Message type3Message = new Type3Message(type2Message, password, domain, username,
                                                 workstation, type2Message.getFlags());
    return type3Message.toByteArray();
  }

  /**
   * Builds the AUTHORIZATION header value carrying the Type 3 message that answers the Type 2 challenge contained in the given
   * WWW-AUTHENTICATE header value.
   *
   * @param wwwAuthenticateHeader a header value of the form {@code "NTLM <base64 Type 2 message>"}, or {@code null}.
   * @return the {@code "NTLM <base64 Type 3 message>"} header value, or {@code null} if {@code wwwAuthenticateHeader} is
   *         {@code null}.
   * @throws IOException if the header value does not start with the NTLM prefix, or if the Type 2 message can't be parsed or
   *                     the Type 3 message can't be created.
   */
  protected String secondChallenge(String wwwAuthenticateHeader) throws IOException {
    if (null == wwwAuthenticateHeader) {
      return null;
    }

    byte[] type2Material = Base64.getDecoder().decode(extractMessagePayload(wwwAuthenticateHeader));
    return createHeaderValue(createType3Message(type2Material));
  }

  private boolean mustSendType1(String wwwAuthenticateHeader) {
    // Server sent a bare "WWW-Authenticate: NTLM" header. Auth-scheme names are case-insensitive (RFC 7235, section 2.1).
    return STARTING_NTLM_WWW_AUTHENTICATE_HEADER.equalsIgnoreCase(wwwAuthenticateHeader.trim());
  }

  private boolean mustSendType3(String wwwAuthenticateHeader) {
    // Server sent a "WWW-Authenticate: NTLM <something, we assume a type 2 challenge>" header.
    // We don't check if the "something" is a type 2 message here because in that case the createType3Message will throw
    // an exception.
    return startsWithNtlmPrefix(wwwAuthenticateHeader);
  }

  // True when the header value starts with "NTLM " (scheme compared case-insensitively, as mandated by RFC 7235).
  private static boolean startsWithNtlmPrefix(String headerValue) {
    return headerValue.regionMatches(true, 0, NTLM_MESSAGES_PREFIX, 0, NTLM_MESSAGES_PREFIX.length());
  }

  // Returns the Base64 token following the "NTLM " prefix, without surrounding whitespace (the Base64 decoder rejects it).
  private static String extractMessagePayload(String headerValue) throws IOException {
    if (!startsWithNtlmPrefix(headerValue)) {
      throw new IOException("WWW-Authenticate header value is not an NTLM challenge: " + headerValue);
    }
    return headerValue.substring(NTLM_MESSAGES_PREFIX.length()).trim();
  }

  private String createHeaderValue(byte[] rawNtlmMessageContent) {
    // Transform a raw NTLM message to a valid "Authorization" header value.
    return NTLM_MESSAGES_PREFIX + Base64.getEncoder().encodeToString(rawNtlmMessageContent);
  }

  @Override
  public boolean hasFinished() {
    return Status.FINISHED == status;
  }

  /**
   * Advances the NTLM exchange given the server's WWW-AUTHENTICATE header value.
   * <p>
   * From {@link Status#NOT_STARTED} it moves to {@link Status#WAITING_FOR_CHALLENGE} (returning a Type 1 header only if the
   * server sent a bare "NTLM"). From {@link Status#WAITING_FOR_CHALLENGE} it moves to {@link Status#FINISHED} (returning a
   * Type 3 header only if the server sent an "NTLM &lt;challenge&gt;" value). In {@link Status#FINISHED} nothing is sent.
   *
   * @param wwwAuthenticateHeader the WWW-AUTHENTICATE header value received from the server, may be {@code null}.
   * @return the AUTHORIZATION header value to send next, or {@code null} if there is nothing to send.
   * @throws Exception if the Type 3 message can't be built from the server's challenge.
   */
  @Override
  public String getNextHeader(String wwwAuthenticateHeader) throws Exception {
    if (null == wwwAuthenticateHeader) {
      return null;
    }

    String authHeader = null;
    if (Status.NOT_STARTED == status) {
      if (mustSendType1(wwwAuthenticateHeader)) {
        // The server told us that we have to send a first challenge message to start authentication.
        authHeader = createHeaderValue(createType1Message());
      }
      status = Status.WAITING_FOR_CHALLENGE;
    } else if (Status.WAITING_FOR_CHALLENGE == status) {
      if (mustSendType3(wwwAuthenticateHeader)) {
        // The server requests a second challenge message
        authHeader = secondChallenge(wwwAuthenticateHeader);
      }
      status = Status.FINISHED;
    }

    return authHeader;
  }

  // Lazy initialization holder: the Type 1 message is computed on first access to this class, and the JVM's class
  // initialization guarantees make it safe without explicit locking. It is the linter-compliant alternative to the
  // classic double-checked locking.
  private static final class RawType1MessageHolder {

    private static final byte[] RAW_TYPE_1_MESSAGE = doCalculateType1Message();

    private static byte[] doCalculateType1Message() {
      Type1Message type1Message = new Type1Message(TYPE_1_MESSAGE_FLAGS, null, null);
      return type1Message.toByteArray();
    }
  }
}

