/*
 * Copyright (c) MuleSoft, Inc.  All rights reserved.  http://www.mulesoft.com
 * The software in this package is published under the terms of the CPAL v1.0
 * license, a copy of which has been included with this distribution in the
 * LICENSE.txt file.
 */
package org.mule.encryption;

import static java.nio.charset.StandardCharsets.UTF_8;
import static junit.framework.TestCase.assertEquals;
import static org.apache.commons.lang.RandomStringUtils.randomAlphabetic;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.Assume.assumeFalse;
import static org.junit.rules.ExpectedException.none;
import static org.mule.encryption.jce.JCE.isJCEInstalled;

import org.mule.encryption.exception.MuleEncryptionException;
import org.mule.encryption.jce.JCEEncrypter;
import org.mule.encryption.key.EncryptionKeyFactory;
import org.mule.encryption.key.SymmetricKeyFactory;

import java.util.Base64;

import javax.crypto.Cipher;
import javax.crypto.spec.SecretKeySpec;

import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;

/**
 * Unit tests for {@link JCEEncrypter} backed by AES keys supplied through a {@link SymmetricKeyFactory}.
 *
 * <p>
 * All text/byte conversions use UTF-8 explicitly so that the known-answer vector {@link #BASE64_ENCRYPTED_PAYLOAD}
 * does not depend on the JVM's default charset.
 */
public class JCEEncrypterTestCase {

  private static final String PAYLOAD = "Payload to Encrypt";
  private static final String KEY = "posofj00posofj00";
  private static final String ANOTHER_KEY = "differentKey1234";
  private static final String AES_CBC_PKCS5 = "AES/CBC/PKCS5Padding";

  /**
   * Expected Base64 ciphertext of {@link #PAYLOAD} under {@link #KEY}, produced by the encrypter created without the
   * random-IV flag.
   */
  private static final String BASE64_ENCRYPTED_PAYLOAD = "nHWo5JhNAYM+TzxqeHdRDXx15Q5R56YVGiQgXCoBJew=";

  @Rule
  public ExpectedException expectedException = none();

  private JCEEncrypter encrypter;
  private JCEEncrypter encrypterRandomIv;

  @Before
  public void setup() {
    EncryptionKeyFactory keyFactory = keyFactory(KEY);

    encrypter = new JCEEncrypter(AES_CBC_PKCS5, keyFactory);
    // Per the field name, the extra {@code true} flag selects a random IV; its output is therefore only
    // checked via an encrypt/decrypt round trip, never against the fixed vector.
    encrypterRandomIv = new JCEEncrypter(AES_CBC_PKCS5, keyFactory, true);
  }

  @Test
  public void encryptByteArray() throws MuleEncryptionException {
    byte[] encrypted = encrypter.encrypt(PAYLOAD.getBytes(UTF_8));

    assertThat(Base64.getEncoder().encodeToString(encrypted), is(BASE64_ENCRYPTED_PAYLOAD));
  }

  @Test
  public void decryptByteArray() throws Exception {
    byte[] encrypted = Base64.getDecoder().decode(BASE64_ENCRYPTED_PAYLOAD);

    byte[] decrypted = encrypter.decrypt(encrypted);

    assertThat(new String(decrypted, UTF_8), is(PAYLOAD));
  }

  @Test
  public void encryptAndDecryptUsingRandomIvByteArray() throws MuleEncryptionException {
    byte[] encrypted = encrypterRandomIv.encrypt(PAYLOAD.getBytes(UTF_8));

    byte[] decrypted = encrypterRandomIv.decrypt(encrypted);

    assertThat(new String(decrypted, UTF_8), is(PAYLOAD));
  }

  @Test
  public void encryptUsingECBMode() throws MuleEncryptionException {
    // Local instance: do not clobber the shared CBC fixture.
    JCEEncrypter ecbEncrypter = new JCEEncrypter("AES/ECB/PKCS5Padding", keyFactory(KEY));

    byte[] encrypted = ecbEncrypter.encrypt(PAYLOAD.getBytes(UTF_8));
    byte[] decrypted = ecbEncrypter.decrypt(encrypted);

    assertThat(new String(decrypted, UTF_8), is(PAYLOAD));
  }

  /**
   * Despite its name, this exercises the decrypt path: data encrypted under {@link #KEY} is decrypted with
   * {@link #ANOTHER_KEY}.
   */
  @Test
  public void errorEncrypting() throws Exception {
    JCEEncrypter anotherEncrypter = new JCEEncrypter(AES_CBC_PKCS5, keyFactory(ANOTHER_KEY));

    assertDecryptionFails(anotherEncrypter, "Could not encrypt or decrypt the data.");
  }

  @Test
  public void invalidAlgorithm() throws Exception {
    JCEEncrypter anotherEncrypter = new JCEEncrypter("invalid/CBC/PKCS5Padding", keyFactory(ANOTHER_KEY));

    assertDecryptionFails(anotherEncrypter, "Cipher 'invalid/CBC/PKCS5Padding' not found");
  }

  @Test
  public void invalidMode() throws Exception {
    JCEEncrypter anotherEncrypter = new JCEEncrypter("AES/invalid/PKCS5Padding", keyFactory(ANOTHER_KEY));

    assertDecryptionFails(anotherEncrypter, "Cipher 'AES/invalid/PKCS5Padding' not found");
  }

  @Test
  public void invalidPadding() throws Exception {
    JCEEncrypter anotherEncrypter = new JCEEncrypter("AES/CBC/invalid", keyFactory(ANOTHER_KEY));

    assertDecryptionFails(anotherEncrypter, "Cipher 'AES/CBC/invalid' not found");
  }

  @Test
  public void invalidProvider() throws Exception {
    JCEEncrypter anotherEncrypter = new JCEEncrypter(AES_CBC_PKCS5, "invalid", keyFactory(ANOTHER_KEY));

    assertDecryptionFails(anotherEncrypter, "Provider 'invalid' not found");
  }

  @Test
  public void shortKey() throws Exception {
    // "shortKey" is 8 bytes, which the expected message reports as the actual size.
    JCEEncrypter anotherEncrypter = new JCEEncrypter(AES_CBC_PKCS5, keyFactory("shortKey"));

    expectedException.expect(MuleEncryptionException.class);
    expectedException.expectMessage("The key is invalid, please make sure it's of a supported size (actual is 8)");

    anotherEncrypter.encrypt(PAYLOAD.getBytes(UTF_8));
  }

  @Test
  public void longKey() throws Exception {
    // Only meaningful under a restricted crypto policy: when unrestricted, Cipher.getMaxAllowedKeyLength returns
    // Integer.MAX_VALUE and no "too long" key can be built. isJCEInstalled presumably reports that case.
    assumeFalse(isJCEInstalled());
    int maxKeySize = Cipher.getMaxAllowedKeyLength("AES") / 8;
    // randomAlphabetic yields ASCII letters, so the UTF-8 key length in bytes equals maxKeySize + 1.
    JCEEncrypter anotherEncrypter = new JCEEncrypter(AES_CBC_PKCS5, keyFactory(randomAlphabetic(maxKeySize + 1)));

    expectedException.expect(MuleEncryptionException.class);
    expectedException
        .expectMessage("The key is invalid, please make sure it's of a supported size (actual is " + (maxKeySize + 1) + ")");

    anotherEncrypter.encrypt(PAYLOAD.getBytes(UTF_8));
  }

  /**
   * Encrypts {@link #PAYLOAD} with the shared CBC fixture, then expects {@code failingEncrypter} to throw a
   * {@link MuleEncryptionException} whose message contains {@code expectedMessage} when decrypting it.
   *
   * <p>
   * The fixture encryption runs before the expectation is armed, so a failure there still fails the test instead
   * of being mistaken for the expected exception.
   *
   * @param failingEncrypter encrypter whose {@code decrypt} call is expected to fail
   * @param expectedMessage substring expected in the exception message
   * @throws MuleEncryptionException the expected failure, consumed by the {@link ExpectedException} rule
   */
  private void assertDecryptionFails(JCEEncrypter failingEncrypter, String expectedMessage)
      throws MuleEncryptionException {
    byte[] encrypted = encrypter.encrypt(PAYLOAD.getBytes(UTF_8));

    expectedException.expect(MuleEncryptionException.class);
    expectedException.expectMessage(expectedMessage);

    failingEncrypter.decrypt(encrypted);
  }

  /**
   * Builds a key factory returning an AES {@link SecretKeySpec} over the UTF-8 bytes of {@code key}.
   * {@code SecretKeySpec} does not validate key length for the algorithm, so invalid sizes surface from
   * {@link JCEEncrypter} itself.
   */
  private EncryptionKeyFactory keyFactory(String key) {
    return (SymmetricKeyFactory) () -> new SecretKeySpec(key.getBytes(UTF_8), "AES");
  }
}
