1 // SPDX-License-Identifier: BSD-2-Clause 2 /* 3 * Copyright (c) 2018, Linaro Limited 4 */ 5 6 #include <crypto/crypto.h> 7 #include <kernel/panic.h> 8 #include <mbedtls/bignum.h> 9 #include <mempool.h> 10 #include <stdlib.h> 11 #include <string.h> 12 #include <tomcrypt_private.h> 13 #include <tomcrypt_mp.h> 14 #include <util.h> 15 16 #if defined(_CFG_CORE_LTC_PAGER) 17 #include <mm/core_mmu.h> 18 #include <mm/tee_pager.h> 19 #endif 20 21 /* Size needed for xtest to pass reliably on both ARM32 and ARM64 */ 22 #define MPI_MEMPOOL_SIZE (46 * 1024) 23 24 /* From mbedtls/library/bignum.c */ 25 #define ciL (sizeof(mbedtls_mpi_uint)) /* chars in limb */ 26 #define biL (ciL << 3) /* bits in limb */ 27 #define BITS_TO_LIMBS(i) ((i) / biL + ((i) % biL != 0)) 28 29 #if defined(_CFG_CORE_LTC_PAGER) 30 /* allocate pageable_zi vmem for mp scratch memory pool */ 31 static struct mempool *get_mp_scratch_memory_pool(void) 32 { 33 size_t size; 34 void *data; 35 36 size = ROUNDUP(MPI_MEMPOOL_SIZE, SMALL_PAGE_SIZE); 37 data = tee_pager_alloc(size); 38 if (!data) 39 panic(); 40 41 return mempool_alloc_pool(data, size, tee_pager_release_phys); 42 } 43 #else /* _CFG_CORE_LTC_PAGER */ 44 static struct mempool *get_mp_scratch_memory_pool(void) 45 { 46 static uint8_t data[MPI_MEMPOOL_SIZE] __aligned(MEMPOOL_ALIGN); 47 48 return mempool_alloc_pool(data, sizeof(data), NULL); 49 } 50 #endif 51 52 void init_mp_tomcrypt(void) 53 { 54 struct mempool *p = get_mp_scratch_memory_pool(); 55 56 if (!p) 57 panic(); 58 mbedtls_mpi_mempool = p; 59 assert(!mempool_default); 60 mempool_default = p; 61 } 62 63 static int init(void **a) 64 { 65 mbedtls_mpi *bn = mempool_alloc(mbedtls_mpi_mempool, sizeof(*bn)); 66 67 if (!bn) 68 return CRYPT_MEM; 69 70 mbedtls_mpi_init_mempool(bn); 71 *a = bn; 72 return CRYPT_OK; 73 } 74 75 static int init_size(int size_bits __unused, void **a) 76 { 77 return init(a); 78 } 79 80 static void deinit(void *a) 81 { 82 mbedtls_mpi_free((mbedtls_mpi *)a); 83 mempool_free(mbedtls_mpi_mempool, a); 84 } 85 86 static int neg(void *a, void *b) 87 { 88 if (mbedtls_mpi_copy(b, a)) 89 return CRYPT_MEM; 90 ((mbedtls_mpi *)b)->s *= -1; 91 return CRYPT_OK; 92 } 93 94 static int copy(void *a, void *b) 95 { 96 if (mbedtls_mpi_copy(b, a)) 97 return CRYPT_MEM; 98 return CRYPT_OK; 99 } 100 101 static int init_copy(void **a, void *b) 102 { 103 if (init(a) != CRYPT_OK) { 104 return CRYPT_MEM; 105 } 106 return copy(b, *a); 107 } 108 109 /* ---- trivial ---- */ 110 static int set_int(void *a, ltc_mp_digit b) 111 { 112 uint32_t b32 = b; 113 114 if (b32 != b) 115 return CRYPT_INVALID_ARG; 116 117 mbedtls_mpi_uint p = b32; 118 mbedtls_mpi bn = { .s = 1, .n = 1, .p = &p }; 119 120 if (mbedtls_mpi_copy(a, &bn)) 121 return CRYPT_MEM; 122 return CRYPT_OK; 123 } 124 125 static unsigned long get_int(void *a) 126 { 127 mbedtls_mpi *bn = a; 128 129 if (!bn->n) 130 return 0; 131 132 return bn->p[bn->n - 1]; 133 } 134 135 static ltc_mp_digit get_digit(void *a, int n) 136 { 137 mbedtls_mpi *bn = a; 138 139 COMPILE_TIME_ASSERT(sizeof(ltc_mp_digit) >= sizeof(mbedtls_mpi_uint)); 140 141 if (n < 0 || (size_t)n >= bn->n) 142 return 0; 143 144 return bn->p[n]; 145 } 146 147 static int get_digit_count(void *a) 148 { 149 return ROUNDUP_DIV(mbedtls_mpi_size(a), sizeof(mbedtls_mpi_uint)); 150 } 151 152 static int compare(void *a, void *b) 153 { 154 int ret = mbedtls_mpi_cmp_mpi(a, b); 155 156 if (ret < 0) 157 return LTC_MP_LT; 158 159 if (ret > 0) 160 return LTC_MP_GT; 161 162 return LTC_MP_EQ; 163 } 164 165 static int compare_d(void *a, ltc_mp_digit b) 166 { 167 unsigned long v = b; 168 unsigned int shift = 31; 169 uint32_t mask = BIT(shift) - 1; 170 mbedtls_mpi bn; 171 172 mbedtls_mpi_init_mempool(&bn); 173 while (true) { 174 mbedtls_mpi_add_int(&bn, &bn, v & mask); 175 v >>= shift; 176 if (!v) 177 break; 178 mbedtls_mpi_shift_l(&bn, shift); 179 } 180 181 int ret = compare(a, &bn); 182 183 mbedtls_mpi_free(&bn); 184 185 return ret; 186 } 187 188 static int count_bits(void *a) 189 { 190 return mbedtls_mpi_bitlen(a); 191 } 192 193 static int count_lsb_bits(void *a) 194 { 195 return mbedtls_mpi_lsb(a); 196 } 197 198 199 static int twoexpt(void *a, int n) 200 { 201 if (mbedtls_mpi_set_bit(a, n, 1)) 202 return CRYPT_MEM; 203 204 return CRYPT_OK; 205 } 206 207 /* ---- conversions ---- */ 208 209 /* read ascii string */ 210 static int read_radix(void *a, const char *b, int radix) 211 { 212 int res = mbedtls_mpi_read_string(a, radix, b); 213 214 if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED) 215 return CRYPT_MEM; 216 if (res) 217 return CRYPT_ERROR; 218 219 return CRYPT_OK; 220 } 221 222 /* write one */ 223 static int write_radix(void *a, char *b, int radix) 224 { 225 size_t ol = SIZE_MAX; 226 int res = mbedtls_mpi_write_string(a, radix, b, ol, &ol); 227 228 if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED) 229 return CRYPT_MEM; 230 if (res) 231 return CRYPT_ERROR; 232 233 return CRYPT_OK; 234 } 235 236 /* get size as unsigned char string */ 237 static unsigned long unsigned_size(void *a) 238 { 239 return mbedtls_mpi_size(a); 240 } 241 242 /* store */ 243 static int unsigned_write(void *a, unsigned char *b) 244 { 245 int res = mbedtls_mpi_write_binary(a, b, unsigned_size(a)); 246 247 if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED) 248 return CRYPT_MEM; 249 if (res) 250 return CRYPT_ERROR; 251 252 return CRYPT_OK; 253 } 254 255 /* read */ 256 static int unsigned_read(void *a, unsigned char *b, unsigned long len) 257 { 258 int res = mbedtls_mpi_read_binary(a, b, len); 259 260 if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED) 261 return CRYPT_MEM; 262 if (res) 263 return CRYPT_ERROR; 264 265 return CRYPT_OK; 266 } 267 268 /* add */ 269 static int add(void *a, void *b, void *c) 270 { 271 if (mbedtls_mpi_add_mpi(c, a, b)) 272 return CRYPT_MEM; 273 274 return CRYPT_OK; 275 } 276 277 static int addi(void *a, ltc_mp_digit b, void *c) 278 { 279 uint32_t b32 = b; 280 281 if (b32 != b) 282 return CRYPT_INVALID_ARG; 283 284 mbedtls_mpi_uint p = b32; 285 mbedtls_mpi bn = { .s = 1, .n = 1, .p = &p }; 286 287 return add(a, &bn, c); 288 } 289 290 /* sub */ 291 static int sub(void *a, void *b, void *c) 292 { 293 if (mbedtls_mpi_sub_mpi(c, a, b)) 294 return CRYPT_MEM; 295 296 return CRYPT_OK; 297 } 298 299 static int subi(void *a, ltc_mp_digit b, void *c) 300 { 301 uint32_t b32 = b; 302 303 if (b32 != b) 304 return CRYPT_INVALID_ARG; 305 306 mbedtls_mpi_uint p = b32; 307 mbedtls_mpi bn = { .s = 1, .n = 1, .p = &p }; 308 309 return sub(a, &bn, c); 310 } 311 312 /* mul */ 313 static int mul(void *a, void *b, void *c) 314 { 315 if (mbedtls_mpi_mul_mpi(c, a, b)) 316 return CRYPT_MEM; 317 318 return CRYPT_OK; 319 } 320 321 static int muli(void *a, ltc_mp_digit b, void *c) 322 { 323 if (b > (unsigned long) UINT32_MAX) 324 return CRYPT_INVALID_ARG; 325 326 if (mbedtls_mpi_mul_int(c, a, b)) 327 return CRYPT_MEM; 328 329 return CRYPT_OK; 330 } 331 332 /* sqr */ 333 static int sqr(void *a, void *b) 334 { 335 return mul(a, a, b); 336 } 337 338 /* div */ 339 static int divide(void *a, void *b, void *c, void *d) 340 { 341 int res = mbedtls_mpi_div_mpi(c, d, a, b); 342 343 if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED) 344 return CRYPT_MEM; 345 if (res) 346 return CRYPT_ERROR; 347 348 return CRYPT_OK; 349 } 350 351 static int div_2(void *a, void *b) 352 { 353 if (mbedtls_mpi_copy(b, a)) 354 return CRYPT_MEM; 355 356 if (mbedtls_mpi_shift_r(b, 1)) 357 return CRYPT_MEM; 358 359 return CRYPT_OK; 360 } 361 362 /* modi */ 363 static int modi(void *a, ltc_mp_digit b, ltc_mp_digit *c) 364 { 365 mbedtls_mpi bn_b; 366 mbedtls_mpi bn_c; 367 int res = 0; 368 369 mbedtls_mpi_init_mempool(&bn_b); 370 mbedtls_mpi_init_mempool(&bn_c); 371 372 res = set_int(&bn_b, b); 373 if (res) 374 return res; 375 376 res = mbedtls_mpi_mod_mpi(&bn_c, &bn_b, a); 377 if (!res) 378 *c = get_int(&bn_c); 379 380 mbedtls_mpi_free(&bn_b); 381 mbedtls_mpi_free(&bn_c); 382 383 if (res) 384 return CRYPT_MEM; 385 386 return CRYPT_OK; 387 } 388 389 /* gcd */ 390 static int gcd(void *a, void *b, void *c) 391 { 392 if (mbedtls_mpi_gcd(c, a, b)) 393 return CRYPT_MEM; 394 395 return CRYPT_OK; 396 } 397 398 /* lcm */ 399 static int lcm(void *a, void *b, void *c) 400 { 401 int res = CRYPT_MEM; 402 mbedtls_mpi tmp; 403 404 mbedtls_mpi_init_mempool(&tmp); 405 if (mbedtls_mpi_mul_mpi(&tmp, a, b)) 406 goto out; 407 408 if (mbedtls_mpi_gcd(c, a, b)) 409 goto out; 410 411 /* We use the following equality: gcd(a, b) * lcm(a, b) = a * b */ 412 res = divide(&tmp, c, c, NULL); 413 out: 414 mbedtls_mpi_free(&tmp); 415 return res; 416 } 417 418 static int mod(void *a, void *b, void *c) 419 { 420 int res = mbedtls_mpi_mod_mpi(c, a, b); 421 422 if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED) 423 return CRYPT_MEM; 424 if (res) 425 return CRYPT_ERROR; 426 427 return CRYPT_OK; 428 } 429 430 static int addmod(void *a, void *b, void *c, void *d) 431 { 432 int res = add(a, b, d); 433 434 if (res) 435 return res; 436 437 return mod(d, c, d); 438 } 439 440 static int submod(void *a, void *b, void *c, void *d) 441 { 442 int res = sub(a, b, d); 443 444 if (res) 445 return res; 446 447 return mod(d, c, d); 448 } 449 450 static int mulmod(void *a, void *b, void *c, void *d) 451 { 452 int res; 453 mbedtls_mpi ta; 454 mbedtls_mpi tb; 455 456 mbedtls_mpi_init_mempool(&ta); 457 mbedtls_mpi_init_mempool(&tb); 458 459 res = mod(a, c, &ta); 460 if (res) 461 goto out; 462 res = mod(b, c, &tb); 463 if (res) 464 goto out; 465 res = mul(&ta, &tb, d); 466 if (res) 467 goto out; 468 res = mod(d, c, d); 469 out: 470 mbedtls_mpi_free(&ta); 471 mbedtls_mpi_free(&tb); 472 return res; 473 } 474 475 static int sqrmod(void *a, void *b, void *c) 476 { 477 return mulmod(a, a, b, c); 478 } 479 480 /* invmod */ 481 static int invmod(void *a, void *b, void *c) 482 { 483 int res = mbedtls_mpi_inv_mod(c, a, b); 484 485 if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED) 486 return CRYPT_MEM; 487 if (res) 488 return CRYPT_ERROR; 489 490 return CRYPT_OK; 491 } 492 493 494 /* setup */ 495 static int montgomery_setup(void *a, void **b) 496 { 497 *b = mempool_alloc(mbedtls_mpi_mempool, sizeof(mbedtls_mpi_uint)); 498 if (!*b) 499 return CRYPT_MEM; 500 501 mbedtls_mpi_montg_init(*b, a); 502 503 return CRYPT_OK; 504 } 505 506 /* get normalization value */ 507 static int montgomery_normalization(void *a, void *b) 508 { 509 size_t c = ROUNDUP(mbedtls_mpi_size(b), sizeof(mbedtls_mpi_uint)) * 8; 510 511 if (mbedtls_mpi_lset(a, 1)) 512 return CRYPT_MEM; 513 if (mbedtls_mpi_shift_l(a, c)) 514 return CRYPT_MEM; 515 if (mbedtls_mpi_mod_mpi(a, a, b)) 516 return CRYPT_MEM; 517 518 return CRYPT_OK; 519 } 520 521 /* reduce */ 522 static int montgomery_reduce(void *a, void *b, void *c) 523 { 524 mbedtls_mpi A; 525 mbedtls_mpi *N = b; 526 mbedtls_mpi_uint *mm = c; 527 mbedtls_mpi T; 528 int ret = CRYPT_MEM; 529 530 mbedtls_mpi_init_mempool(&T); 531 mbedtls_mpi_init_mempool(&A); 532 533 if (mbedtls_mpi_grow(&T, (N->n + 1) * 2)) 534 goto out; 535 536 if (mbedtls_mpi_cmp_mpi(a, N) > 0) { 537 if (mbedtls_mpi_mod_mpi(&A, a, N)) 538 goto out; 539 } else { 540 if (mbedtls_mpi_copy(&A, a)) 541 goto out; 542 } 543 544 if (mbedtls_mpi_grow(&A, N->n + 1)) 545 goto out; 546 547 mbedtls_mpi_montred(&A, N, *mm, &T); 548 549 if (mbedtls_mpi_copy(a, &A)) 550 goto out; 551 552 ret = CRYPT_OK; 553 out: 554 mbedtls_mpi_free(&A); 555 mbedtls_mpi_free(&T); 556 557 return ret; 558 } 559 560 /* clean up */ 561 static void montgomery_deinit(void *a) 562 { 563 mempool_free(mbedtls_mpi_mempool, a); 564 } 565 566 /* 567 * This function calculates: 568 * d = a^b mod c 569 * 570 * @a: base 571 * @b: exponent 572 * @c: modulus 573 * @d: destination 574 */ 575 static int exptmod(void *a, void *b, void *c, void *d) 576 { 577 int res; 578 579 if (d == a || d == b || d == c) { 580 mbedtls_mpi dest; 581 582 mbedtls_mpi_init_mempool(&dest); 583 res = mbedtls_mpi_exp_mod(&dest, a, b, c, NULL); 584 if (!res) 585 res = mbedtls_mpi_copy(d, &dest); 586 mbedtls_mpi_free(&dest); 587 } else { 588 res = mbedtls_mpi_exp_mod(d, a, b, c, NULL); 589 } 590 591 if (res) 592 return CRYPT_MEM; 593 else 594 return CRYPT_OK; 595 } 596 597 static int rng_read(void *ignored __unused, unsigned char *buf, size_t blen) 598 { 599 if (crypto_rng_read(buf, blen)) 600 return MBEDTLS_ERR_MPI_FILE_IO_ERROR; 601 return 0; 602 } 603 604 static int isprime(void *a, int b, int *c) 605 { 606 int res = mbedtls_mpi_is_prime_ext(a, b, rng_read, NULL); 607 608 if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED) 609 return CRYPT_MEM; 610 611 if (res) 612 *c = LTC_MP_NO; 613 else 614 *c = LTC_MP_YES; 615 616 return CRYPT_OK; 617 } 618 619 static int mpi_rand(void *a, int size) 620 { 621 if (mbedtls_mpi_fill_random(a, size, rng_read, NULL)) 622 return CRYPT_MEM; 623 624 return CRYPT_OK; 625 } 626 627 ltc_math_descriptor ltc_mp = { 628 .name = "MPI", 629 .bits_per_digit = sizeof(mbedtls_mpi_uint) * 8, 630 631 .init = init, 632 .init_size = init_size, 633 .init_copy = init_copy, 634 .deinit = deinit, 635 636 .neg = neg, 637 .copy = copy, 638 639 .set_int = set_int, 640 .get_int = get_int, 641 .get_digit = get_digit, 642 .get_digit_count = get_digit_count, 643 .compare = compare, 644 .compare_d = compare_d, 645 .count_bits = count_bits, 646 .count_lsb_bits = count_lsb_bits, 647 .twoexpt = twoexpt, 648 649 .read_radix = read_radix, 650 .write_radix = write_radix, 651 .unsigned_size = unsigned_size, 652 .unsigned_write = unsigned_write, 653 .unsigned_read = unsigned_read, 654 655 .add = add, 656 .addi = addi, 657 .sub = sub, 658 .subi = subi, 659 .mul = mul, 660 .muli = muli, 661 .sqr = sqr, 662 .mpdiv = divide, 663 .div_2 = div_2, 664 .modi = modi, 665 .gcd = gcd, 666 .lcm = lcm, 667 668 .mulmod = mulmod, 669 .sqrmod = sqrmod, 670 .invmod = invmod, 671 672 .montgomery_setup = montgomery_setup, 673 .montgomery_normalization = montgomery_normalization, 674 .montgomery_reduce = montgomery_reduce, 675 .montgomery_deinit = montgomery_deinit, 676 677 .exptmod = exptmod, 678 .isprime = isprime, 679 680 #ifdef LTC_MECC 681 #ifdef LTC_MECC_FP 682 .ecc_ptmul = ltc_ecc_fp_mulmod, 683 #else 684 .ecc_ptmul = ltc_ecc_mulmod, 685 #endif /* LTC_MECC_FP */ 686 .ecc_ptadd = ltc_ecc_projective_add_point, 687 .ecc_ptdbl = ltc_ecc_projective_dbl_point, 688 .ecc_map = ltc_ecc_map, 689 #ifdef LTC_ECC_SHAMIR 690 #ifdef LTC_MECC_FP 691 .ecc_mul2add = ltc_ecc_fp_mul2add, 692 #else 693 .ecc_mul2add = ltc_ecc_mul2add, 694 #endif /* LTC_MECC_FP */ 695 #endif /* LTC_ECC_SHAMIR */ 696 #endif /* LTC_MECC */ 697 698 #ifdef LTC_MRSA 699 .rsa_keygen = rsa_make_key, 700 .rsa_me = rsa_exptmod, 701 #endif 702 .addmod = addmod, 703 .submod = submod, 704 .rand = mpi_rand, 705 706 }; 707 708 size_t crypto_bignum_num_bytes(struct bignum *a) 709 { 710 return mbedtls_mpi_size((mbedtls_mpi *)a); 711 } 712 713 size_t crypto_bignum_num_bits(struct bignum *a) 714 { 715 return mbedtls_mpi_bitlen((mbedtls_mpi *)a); 716 } 717 718 int32_t crypto_bignum_compare(struct bignum *a, struct bignum *b) 719 { 720 return mbedtls_mpi_cmp_mpi((mbedtls_mpi *)a, (mbedtls_mpi *)b); 721 } 722 723 void crypto_bignum_bn2bin(const struct bignum *from, uint8_t *to) 724 { 725 const mbedtls_mpi *f = (const mbedtls_mpi *)from; 726 int rc __maybe_unused = 0; 727 728 rc = mbedtls_mpi_write_binary(f, (void *)to, mbedtls_mpi_size(f)); 729 assert(!rc); 730 } 731 732 TEE_Result crypto_bignum_bin2bn(const uint8_t *from, size_t fromsize, 733 struct bignum *to) 734 { 735 if (mbedtls_mpi_read_binary((mbedtls_mpi *)to, (const void *)from, 736 fromsize)) 737 return TEE_ERROR_BAD_PARAMETERS; 738 return TEE_SUCCESS; 739 } 740 741 void crypto_bignum_copy(struct bignum *to, const struct bignum *from) 742 { 743 int rc __maybe_unused = 0; 744 745 rc = mbedtls_mpi_copy((mbedtls_mpi *)to, (const mbedtls_mpi *)from); 746 assert(!rc); 747 } 748 749 struct bignum *crypto_bignum_allocate(size_t size_bits) 750 { 751 mbedtls_mpi *bn = malloc(sizeof(*bn)); 752 753 if (!bn) 754 return NULL; 755 756 mbedtls_mpi_init(bn); 757 if (mbedtls_mpi_grow(bn, BITS_TO_LIMBS(size_bits))) { 758 free(bn); 759 return NULL; 760 } 761 762 return (struct bignum *)bn; 763 } 764 765 void crypto_bignum_free(struct bignum **s) 766 { 767 assert(s); 768 769 mbedtls_mpi_free((mbedtls_mpi *)*s); 770 free(*s); 771 *s = NULL; 772 } 773 774 void crypto_bignum_clear(struct bignum *s) 775 { 776 mbedtls_mpi *bn = (mbedtls_mpi *)s; 777 778 bn->s = 1; 779 if (bn->p) 780 memset(bn->p, 0, sizeof(*bn->p) * bn->n); 781 } 782