/*
 * Copyright 2023 Salesforce, Inc. All rights reserved.
 */
package org.mule.service.http.netty.impl.client.auth.ntlm.av;

import static org.mule.service.http.netty.impl.client.auth.ntlm.smb.SMBUtil.readInt2;
import static org.mule.service.http.netty.impl.client.auth.ntlm.smb.SMBUtil.writeInt2;
import static org.mule.service.http.netty.impl.client.auth.ntlm.av.AvPair.MsvAvChannelBindings;
import static org.mule.service.http.netty.impl.client.auth.ntlm.av.AvPair.MsvAvEOL;
import static org.mule.service.http.netty.impl.client.auth.ntlm.av.AvPair.MsvAvFlags;
import static org.mule.service.http.netty.impl.client.auth.ntlm.av.AvPair.MsvAvSingleHost;
import static org.mule.service.http.netty.impl.client.auth.ntlm.av.AvPair.MsvAvTargetName;
import static org.mule.service.http.netty.impl.client.auth.ntlm.av.AvPair.MsvAvTimestamp;
import static java.lang.System.arraycopy;

import java.util.LinkedList;
import java.util.List;

/**
 * Utility class for encoding, decoding, and manipulating lists of {@link AvPair} objects in NTLM authentication. Provides methods
 * to decode byte arrays into lists of {@code AvPair}, encode lists back into byte arrays, and search or modify individual pairs
 * within these lists.
 *
 * <p>
 * In the encoded form, each AV_PAIR is a 2-byte id, followed by a 2-byte value length, followed by the value bytes. The sequence
 * is terminated by an {@code MsvAvEOL} pair whose length is 0.
 * </p>
 *
 * <p>
 * This implementation is based on the jcifs library, available at:
 * <a href="https://github.com/codelibs/jcifs">https://github.com/codelibs/jcifs</a>
 * </p>
 *
 * @see <a href="https://github.com/codelibs/jcifs">jcifs on GitHub</a>
 */
public class AvPairs {

  /** Size in bytes of an AV_PAIR header: a 2-byte id followed by a 2-byte value length. */
  private static final int HEADER_SIZE = 4;

  /** Largest value length that fits in the 2-byte length field of an AV_PAIR header. */
  private static final int MAX_VALUE_LENGTH = 0xFFFF;

  /**
   * Decodes a byte array of AV_PAIR data into a list of {@code AvPair} objects.
   *
   * <p>
   * Decoding stops at the first {@code MsvAvEOL} pair; any bytes after it are ignored. The terminator itself is not included in
   * the result. The returned list is mutable.
   * </p>
   *
   * @param data the byte array containing the encoded AV_PAIR structures.
   * @return a list of decoded {@code AvPair} objects, in encoded order.
   * @throws RuntimeException if the AV_EOL pair is missing or has an invalid length, or if a pair declares a value length that
   *                          runs past the end of {@code data}.
   */
  public static List<AvPair> decode(byte[] data) {
    List<AvPair> pairs = new LinkedList<>();
    int pos = 0;
    while (pos + HEADER_SIZE <= data.length) {
      int avId = readInt2(data, pos);
      int avLen = readInt2(data, pos + 2);
      pos += HEADER_SIZE;

      if (avId == MsvAvEOL) {
        if (avLen != 0) {
          throw new RuntimeException("Invalid avLen for AvEOL");
        }
        return pairs;
      }

      // Validate the declared length up front so truncated/malformed input yields a descriptive
      // error instead of an ArrayIndexOutOfBoundsException from arraycopy.
      int remaining = data.length - pos;
      if (avLen < 0 || avLen > remaining) {
        throw new RuntimeException("Invalid avLen " + avLen + " for AvPair " + avId + ": only " + remaining
            + " bytes remain");
      }

      byte[] raw = new byte[avLen];
      arraycopy(data, pos, raw, 0, avLen);
      pairs.add(parseAvPair(avId, raw));
      pos += avLen;
    }
    // Ran out of complete headers without seeing the terminator.
    throw new RuntimeException("Missing AvEOL");
  }

  /**
   * Parses an AV_PAIR from its ID and raw byte data.
   *
   * @param avId the type ID of the AV_PAIR.
   * @param raw  the raw byte data for the AV_PAIR.
   * @return an {@code AvPair} subclass instance for the known IDs handled below, or a plain {@code AvPair} for any other ID.
   */
  private static AvPair parseAvPair(int avId, byte[] raw) {
    return switch (avId) {
      case MsvAvFlags -> new AvFlags(raw);
      case MsvAvTimestamp -> new AvTimestamp(raw);
      case MsvAvTargetName -> new AvTargetName(raw);
      case MsvAvSingleHost -> new AvSingleHost(raw);
      case MsvAvChannelBindings -> new AvChannelBindings(raw);
      default -> new AvPair(avId, raw);
    };
  }

  /**
   * Checks if a list of {@code AvPair} objects contains a pair with the specified type.
   *
   * @param pairs the list of {@code AvPair} objects; may be {@code null}, which is treated as empty.
   * @param type  the type to search for.
   * @return {@code true} if a pair with the specified type exists in the list; {@code false} otherwise.
   */
  public static boolean contains(List<AvPair> pairs, int type) {
    return get(pairs, type) != null;
  }

  /**
   * Retrieves the first {@code AvPair} with the specified type from a list.
   *
   * @param pairs the list of {@code AvPair} objects; may be {@code null}, which is treated as empty.
   * @param type  the type of the {@code AvPair} to retrieve.
   * @return the first {@code AvPair} with the specified type, or {@code null} if not found.
   */
  public static AvPair get(List<AvPair> pairs, int type) {
    if (pairs == null) {
      return null;
    }
    for (AvPair p : pairs) {
      if (p.getType() == type) {
        return p;
      }
    }
    return null;
  }

  /**
   * Removes every {@code AvPair} with the same type as {@code rep} from the list, then appends {@code rep} to the end.
   *
   * @param pairs the mutable list of {@code AvPair} objects.
   * @param rep   the {@code AvPair} to replace or add.
   */
  public static void replace(List<AvPair> pairs, AvPair rep) {
    remove(pairs, rep.getType());
    pairs.add(rep);
  }

  /**
   * Removes every {@code AvPair} with the specified type from the list.
   *
   * @param pairs the mutable list of {@code AvPair} objects.
   * @param type  the type of the {@code AvPair} entries to remove.
   */
  public static void remove(List<AvPair> pairs, int type) {
    pairs.removeIf(p -> p.getType() == type);
  }

  /**
   * Encodes a list of {@code AvPair} objects into a byte array, appending a terminating {@code MsvAvEOL} pair.
   *
   * @param pairs the list of {@code AvPair} objects to encode. It should not itself contain an {@code MsvAvEOL} pair, otherwise
   *              the pairs after it will be ignored by {@link #decode(byte[])}.
   * @return the encoded byte array representing the list of AV_PAIR structures.
   * @throws IllegalArgumentException if a pair's value is longer than the 2-byte length field can represent.
   */
  public static byte[] encode(List<AvPair> pairs) {
    // Start with room for the trailing MsvAvEOL header.
    int size = HEADER_SIZE;
    for (AvPair p : pairs) {
      int len = p.getRaw().length;
      // writeInt2 only stores two bytes, so a longer length would be silently truncated.
      if (len > MAX_VALUE_LENGTH) {
        throw new IllegalArgumentException("AvPair " + p.getType() + " value too long: " + len + " bytes (max "
            + MAX_VALUE_LENGTH + ")");
      }
      size += HEADER_SIZE + len;
    }

    byte[] enc = new byte[size];
    int pos = 0;
    for (AvPair p : pairs) {
      byte[] raw = p.getRaw();
      writeInt2(p.getType(), enc, pos);
      writeInt2(raw.length, enc, pos + 2);
      arraycopy(raw, 0, enc, pos + HEADER_SIZE, raw.length);
      pos += HEADER_SIZE + raw.length;
    }

    // Add MsvAvEOL at the end
    writeInt2(MsvAvEOL, enc, pos);
    writeInt2(0, enc, pos + 2);
    return enc;
  }
}
