// SPDX-License-Identifier: Apache-2.0
#include <setjmp.h>
#include <stdarg.h>
#include <stddef.h>
#include <stdint.h>
#include <cmocka.h>

#include <aes.h>
#include <block-cipher.h>

/*
 * Known-answer test for AES-128 (aes128_enc/aes128_dec) in ECB mode.
 *
 * The plaintext is two identical 16-byte blocks (00..0F twice).  The pinned
 * ciphertext is likewise one 16-byte block repeated.  Encryption and
 * decryption both run in place (in == out), and decryption reuses the same
 * config, so it must turn the ciphertext back into the original bytes.
 */
static void test_aes_ecb(void **state) {
  (void)state;

  uint8_t buf[] = {
      0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
      0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, /* block 0 */
      0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
      0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, /* block 1 */
  };
  uint8_t aes_key[] = {0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
                       0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F};
  static const uint8_t want[] = {
      0x0a, 0x94, 0x0b, 0xb5, 0x41, 0x6e, 0xf0, 0x45,
      0xf1, 0xc3, 0x94, 0x58, 0xc6, 0x53, 0xea, 0x5a, /* block 0 */
      0x0a, 0x94, 0x0b, 0xb5, 0x41, 0x6e, 0xf0, 0x45,
      0xf1, 0xc3, 0x94, 0x58, 0xc6, 0x53, 0xea, 0x5a, /* block 1 */
  };

  block_cipher_config cfg = {
      .mode = ECB,
      .key = aes_key,
      .iv = NULL, /* ECB takes no IV */
      .in = buf,
      .out = buf, /* in-place */
      .in_size = sizeof(buf),
      .block_size = 16,
      .encrypt = aes128_enc,
      .decrypt = aes128_dec,
  };

  block_cipher_enc(&cfg);
  for (size_t pos = 0; pos < sizeof(want); ++pos) {
    assert_int_equal(buf[pos], want[pos]);
  }

  /* Plaintext byte at position pos is pos mod 16. */
  block_cipher_dec(&cfg);
  for (size_t pos = 0; pos < sizeof(buf); ++pos) {
    assert_int_equal(buf[pos], pos % 16);
  }
}

/*
 * Known-answer test for AES-128 in CBC mode.
 *
 * The input is three identical 16-byte blocks (00..0F), encrypted in place
 * with a fixed key and IV and compared against a pinned 48-byte ciphertext.
 * Decryption then reuses the same config (same IV buffer), and the buffer
 * must come back as the original plaintext.
 *
 * NOTE(review): the round trip assumes block_cipher_enc does not modify the
 * IV buffer. Confirm against block-cipher.h.
 */
static void test_aes_cbc(void **state) {
  (void)state;

  uint8_t msg[] = {
      0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
      0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F,
      0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
      0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F,
      0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
      0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F,
  };
  uint8_t k[] = {0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
                 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F};
  uint8_t chain_iv[] = {0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
                        0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07};
  static const uint8_t want[] = {
      /* block 0 */
      0x60, 0x3d, 0x8c, 0xcd, 0x75, 0x21, 0xe2, 0x96,
      0x15, 0x67, 0xc0, 0x24, 0xdf, 0x33, 0x67, 0x85,
      /* block 1 */
      0x17, 0x2c, 0xb7, 0x7e, 0xc4, 0x97, 0x7d, 0xa0,
      0x38, 0x80, 0xb8, 0x32, 0xd2, 0x0a, 0x79, 0x76,
      /* block 2 */
      0xf4, 0xfb, 0xca, 0x63, 0x59, 0xe4, 0x72, 0x42,
      0x45, 0xfe, 0xc3, 0x70, 0xbb, 0x63, 0x7d, 0x38,
  };

  block_cipher_config cfg = {
      .mode = CBC,
      .encrypt = aes128_enc,
      .decrypt = aes128_dec,
      .block_size = 16,
      .key = k,
      .iv = chain_iv,
      .in = msg,
      .out = msg, /* in-place */
      .in_size = sizeof(msg),
  };

  block_cipher_enc(&cfg);
  for (size_t n = 0; n < sizeof(want); ++n) {
    assert_int_equal(msg[n], want[n]);
  }

  block_cipher_dec(&cfg);
  /* Every recovered block must read 00..0F again. */
  for (size_t blk = 0; blk < sizeof(msg) / 16; ++blk) {
    for (size_t b = 0; b < 16; ++b) {
      assert_int_equal(msg[blk * 16 + b], b);
    }
  }
}

/*
 * Known-answer test for AES-128 in CFB mode.
 *
 * The input is three identical 16-byte blocks (00..0F), with the same key
 * and IV as the CBC/OFB tests.  It is encrypted in place and compared with
 * the pinned ciphertext.  The buffer is then decrypted in place with the
 * unchanged config and must match the original plaintext.
 *
 * NOTE(review): the round trip assumes block_cipher_enc does not modify the
 * IV buffer. Confirm against block-cipher.h.
 */
static void test_aes_cfb(void **state) {
  (void)state;

  uint8_t buf[] = {
      0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
      0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F,
      0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
      0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F,
      0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
      0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F,
  };
  uint8_t cipher_key[] = {0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
                          0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F};
  uint8_t init_vec[] = {0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
                        0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07};
  static const uint8_t want[] = {
      /* block 0 */
      0x3d, 0xfa, 0xde, 0xc7, 0x14, 0xdf, 0x6d, 0x0c,
      0xb6, 0xa9, 0x1b, 0xd1, 0x1e, 0xa1, 0xec, 0xf6,
      /* block 1 */
      0x29, 0x64, 0x0e, 0x6e, 0xfb, 0x38, 0x3d, 0x58,
      0x10, 0x7e, 0x07, 0x41, 0xc5, 0x7c, 0x32, 0xc1,
      /* block 2 */
      0xc0, 0xd6, 0xba, 0x3c, 0x45, 0x10, 0xc1, 0x13,
      0xe3, 0x6a, 0xc9, 0x7b, 0x6f, 0x9a, 0x32, 0xa6,
  };

  block_cipher_config cfg = {
      .mode = CFB,
      .in = buf,
      .out = buf, /* in-place */
      .in_size = sizeof(buf),
      .block_size = 16,
      .key = cipher_key,
      .iv = init_vec,
      .encrypt = aes128_enc,
      .decrypt = aes128_dec,
  };

  block_cipher_enc(&cfg);
  const uint8_t *got = buf;
  for (const uint8_t *w = want; w != want + sizeof(want); ++w, ++got) {
    assert_int_equal(*got, *w);
  }

  block_cipher_dec(&cfg);
  for (size_t idx = 0; idx < sizeof(buf); ++idx) {
    assert_int_equal(buf[idx], idx & 0x0F); /* low nibble == plaintext byte */
  }
}

/*
 * Known-answer test for AES-128 in OFB mode.
 *
 * The input, key and IV are the same as in the CFB test.  The first pinned
 * ciphertext block is identical to the CFB vector's first block, and the
 * later blocks differ.  Encryption and decryption both run in place with
 * one shared config.
 *
 * NOTE(review): the round trip assumes block_cipher_enc does not modify the
 * IV buffer. Confirm against block-cipher.h.
 */
static void test_aes_ofb(void **state) {
  (void)state;

  uint8_t stream[] = {
      0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
      0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F,
      0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
      0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F,
      0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
      0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F,
  };
  uint8_t secret[] = {0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
                      0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F};
  uint8_t seed_iv[] = {0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
                       0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07};
  static const uint8_t want[] = {
      /* block 0 */
      0x3d, 0xfa, 0xde, 0xc7, 0x14, 0xdf, 0x6d, 0x0c,
      0xb6, 0xa9, 0x1b, 0xd1, 0x1e, 0xa1, 0xec, 0xf6,
      /* block 1 */
      0x2e, 0x0f, 0x5b, 0x85, 0x51, 0xf0, 0x25, 0xcf,
      0x02, 0xca, 0x9f, 0x55, 0xb0, 0x6b, 0x50, 0x75,
      /* block 2 */
      0xf7, 0x6f, 0xf6, 0x85, 0x7b, 0xa3, 0xc8, 0xb5,
      0x20, 0x07, 0xad, 0xaf, 0x2e, 0xd9, 0x86, 0xab,
  };

  block_cipher_config cfg = {
      .mode = OFB,
      .decrypt = aes128_dec,
      .encrypt = aes128_enc,
      .iv = seed_iv,
      .key = secret,
      .block_size = 16,
      .in_size = sizeof(stream),
      .out = stream, /* in-place */
      .in = stream,
  };

  block_cipher_enc(&cfg);
  for (size_t at = 0; at != sizeof(want); ++at) {
    assert_int_equal(stream[at], want[at]);
  }

  block_cipher_dec(&cfg);
  for (size_t base = 0; base < sizeof(stream); base += 16) {
    for (size_t off = 0; off < 16; ++off) {
      assert_int_equal(stream[base + off], off);
    }
  }
}

/*
 * AES-128 in CTR mode.  The input is three identical 16-byte plaintext
 * blocks (00..0F), encrypted and then decrypted in place with the same
 * config.
 *
 * No reference ciphertext is pinned for this mode.  A round trip alone
 * cannot tell a working cipher from an encrypt/decrypt pair that leaves the
 * data unchanged, or from a CTR that never advances its counter.  So the
 * encryption step is also checked for properties that any correct CTR
 * implementation has, whatever its counter-increment convention:
 *   - no ciphertext block equals its plaintext block (encryption ran);
 *   - the three ciphertext blocks are pairwise distinct.  The plaintext
 *     blocks are equal, so distinct outputs can only come from XORing with
 *     encryptions of different counter values.
 *
 * NOTE(review): the round trip assumes block_cipher_enc does not modify the
 * IV (initial counter) buffer. Confirm against block-cipher.h.
 */
static void test_aes_ctr(void **state) {
  (void)state;

  /* One plaintext block; `data` below is three copies of it. */
  static const uint8_t block[16] = {0x00, 0x01, 0x02, 0x03, 0x04, 0x05,
                                    0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B,
                                    0x0C, 0x0D, 0x0E, 0x0F};

  uint8_t data[] = {0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09,
                    0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x00, 0x01, 0x02, 0x03,
                    0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D,
                    0x0E, 0x0F, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
                    0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F};
  uint8_t key[] = {0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
                   0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F};
  uint8_t iv[] = {0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
                  0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07};

  block_cipher_config cfg = {.mode = CTR,
                             .in = data,
                             .in_size = sizeof(data),
                             .out = data,
                             .key = key,
                             .iv = iv,
                             .block_size = 16,
                             .encrypt = aes128_enc,
                             .decrypt = aes128_dec};
  block_cipher_enc(&cfg);

  /* Encryption must actually transform every block. */
  for (size_t off = 0; off < sizeof(data); off += sizeof(block)) {
    assert_memory_not_equal(data + off, block, sizeof(block));
  }
  /* Equal plaintext blocks must yield distinct ciphertext blocks in CTR. */
  assert_memory_not_equal(data, data + 16, 16);
  assert_memory_not_equal(data, data + 32, 16);
  assert_memory_not_equal(data + 16, data + 32, 16);

  block_cipher_dec(&cfg);
  for (int i = 0; i != 16; ++i) {
    assert_int_equal(data[i], i);
    assert_int_equal(data[i + 16], i);
    assert_int_equal(data[i + 32], i);
  }
}

/*
 * Registers every AES mode test with cmocka and runs them as one group with
 * no group-level setup or teardown.  cmocka_run_group_tests returns 0 when
 * all tests pass and otherwise the number of failed tests, so the process
 * exit status is non-zero on any failure.
 */
int main(void) {
  const struct CMUnitTest tests[] = {
      cmocka_unit_test(test_aes_ecb), cmocka_unit_test(test_aes_cbc),
      cmocka_unit_test(test_aes_cfb), cmocka_unit_test(test_aes_ofb),
      cmocka_unit_test(test_aes_ctr),
  };

  return cmocka_run_group_tests(tests, NULL, NULL);
}