1 // SPDX-License-Identifier: BSD-2-Clause 2 /* 3 * Copyright (c) 2018, Linaro Limited 4 */ 5 6 #include <crypto/crypto.h> 7 #include <kernel/panic.h> 8 #include <mbedtls/bignum.h> 9 #include <mempool.h> 10 #include <stdlib.h> 11 #include <string.h> 12 #include <tomcrypt_private.h> 13 #include <tomcrypt_mp.h> 14 #include <util.h> 15 16 #if defined(_CFG_CORE_LTC_PAGER) 17 #include <mm/core_mmu.h> 18 #include <mm/tee_pager.h> 19 #endif 20 21 /* Size needed for xtest to pass reliably on both ARM32 and ARM64 */ 22 #define MPI_MEMPOOL_SIZE (42 * 1024) 23 24 /* From mbedtls/library/bignum.c */ 25 #define ciL (sizeof(mbedtls_mpi_uint)) /* chars in limb */ 26 #define biL (ciL << 3) /* bits in limb */ 27 #define BITS_TO_LIMBS(i) ((i) / biL + ((i) % biL != 0)) 28 29 #if defined(_CFG_CORE_LTC_PAGER) 30 /* allocate pageable_zi vmem for mp scratch memory pool */ 31 static struct mempool *get_mp_scratch_memory_pool(void) 32 { 33 size_t size; 34 void *data; 35 36 size = ROUNDUP(MPI_MEMPOOL_SIZE, SMALL_PAGE_SIZE); 37 data = tee_pager_alloc(size); 38 if (!data) 39 panic(); 40 41 return mempool_alloc_pool(data, size, tee_pager_release_phys); 42 } 43 #else /* _CFG_CORE_LTC_PAGER */ 44 static struct mempool *get_mp_scratch_memory_pool(void) 45 { 46 static uint8_t data[MPI_MEMPOOL_SIZE] __aligned(MEMPOOL_ALIGN); 47 48 return mempool_alloc_pool(data, sizeof(data), NULL); 49 } 50 #endif 51 52 void init_mp_tomcrypt(void) 53 { 54 struct mempool *p = get_mp_scratch_memory_pool(); 55 56 if (!p) 57 panic(); 58 mbedtls_mpi_mempool = p; 59 assert(!mempool_default); 60 mempool_default = p; 61 } 62 63 static int init(void **a) 64 { 65 mbedtls_mpi *bn = mempool_alloc(mbedtls_mpi_mempool, sizeof(*bn)); 66 67 if (!bn) 68 return CRYPT_MEM; 69 70 mbedtls_mpi_init_mempool(bn); 71 *a = bn; 72 return CRYPT_OK; 73 } 74 75 static int init_size(int size_bits __unused, void **a) 76 { 77 return init(a); 78 } 79 80 static void deinit(void *a) 81 { 82 mbedtls_mpi_free((mbedtls_mpi *)a); 83 mempool_free(mbedtls_mpi_mempool, a); 84 } 85 86 static int neg(void *a, void *b) 87 { 88 if (mbedtls_mpi_copy(b, a)) 89 return CRYPT_MEM; 90 ((mbedtls_mpi *)b)->s *= -1; 91 return CRYPT_OK; 92 } 93 94 static int copy(void *a, void *b) 95 { 96 if (mbedtls_mpi_copy(b, a)) 97 return CRYPT_MEM; 98 return CRYPT_OK; 99 } 100 101 static int init_copy(void **a, void *b) 102 { 103 if (init(a) != CRYPT_OK) { 104 return CRYPT_MEM; 105 } 106 return copy(b, *a); 107 } 108 109 /* ---- trivial ---- */ 110 static int set_int(void *a, ltc_mp_digit b) 111 { 112 uint32_t b32 = b; 113 114 if (b32 != b) 115 return CRYPT_INVALID_ARG; 116 117 mbedtls_mpi_uint p = b32; 118 mbedtls_mpi bn = { .s = 1, .n = 1, .p = &p }; 119 120 if (mbedtls_mpi_copy(a, &bn)) 121 return CRYPT_MEM; 122 return CRYPT_OK; 123 } 124 125 static unsigned long get_int(void *a) 126 { 127 mbedtls_mpi *bn = a; 128 129 if (!bn->n) 130 return 0; 131 132 return bn->p[bn->n - 1]; 133 } 134 135 static ltc_mp_digit get_digit(void *a, int n) 136 { 137 mbedtls_mpi *bn = a; 138 139 COMPILE_TIME_ASSERT(sizeof(ltc_mp_digit) >= sizeof(mbedtls_mpi_uint)); 140 141 if (n < 0 || (size_t)n >= bn->n) 142 return 0; 143 144 return bn->p[n]; 145 } 146 147 static int get_digit_count(void *a) 148 { 149 return ROUNDUP(mbedtls_mpi_size(a), sizeof(mbedtls_mpi_uint)) / 150 sizeof(mbedtls_mpi_uint); 151 } 152 153 static int compare(void *a, void *b) 154 { 155 int ret = mbedtls_mpi_cmp_mpi(a, b); 156 157 if (ret < 0) 158 return LTC_MP_LT; 159 160 if (ret > 0) 161 return LTC_MP_GT; 162 163 return LTC_MP_EQ; 164 } 165 166 static int compare_d(void *a, ltc_mp_digit b) 167 { 168 unsigned long v = b; 169 unsigned int shift = 31; 170 uint32_t mask = BIT(shift) - 1; 171 mbedtls_mpi bn; 172 173 mbedtls_mpi_init_mempool(&bn); 174 while (true) { 175 mbedtls_mpi_add_int(&bn, &bn, v & mask); 176 v >>= shift; 177 if (!v) 178 break; 179 mbedtls_mpi_shift_l(&bn, shift); 180 } 181 182 int ret = compare(a, &bn); 183 184 mbedtls_mpi_free(&bn); 185 186 return ret; 187 } 188 189 static int count_bits(void *a) 190 { 191 return mbedtls_mpi_bitlen(a); 192 } 193 194 static int count_lsb_bits(void *a) 195 { 196 return mbedtls_mpi_lsb(a); 197 } 198 199 200 static int twoexpt(void *a, int n) 201 { 202 if (mbedtls_mpi_set_bit(a, n, 1)) 203 return CRYPT_MEM; 204 205 return CRYPT_OK; 206 } 207 208 /* ---- conversions ---- */ 209 210 /* read ascii string */ 211 static int read_radix(void *a, const char *b, int radix) 212 { 213 int res = mbedtls_mpi_read_string(a, radix, b); 214 215 if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED) 216 return CRYPT_MEM; 217 if (res) 218 return CRYPT_ERROR; 219 220 return CRYPT_OK; 221 } 222 223 /* write one */ 224 static int write_radix(void *a, char *b, int radix) 225 { 226 size_t ol = SIZE_MAX; 227 int res = mbedtls_mpi_write_string(a, radix, b, ol, &ol); 228 229 if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED) 230 return CRYPT_MEM; 231 if (res) 232 return CRYPT_ERROR; 233 234 return CRYPT_OK; 235 } 236 237 /* get size as unsigned char string */ 238 static unsigned long unsigned_size(void *a) 239 { 240 return mbedtls_mpi_size(a); 241 } 242 243 /* store */ 244 static int unsigned_write(void *a, unsigned char *b) 245 { 246 int res = mbedtls_mpi_write_binary(a, b, unsigned_size(a)); 247 248 if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED) 249 return CRYPT_MEM; 250 if (res) 251 return CRYPT_ERROR; 252 253 return CRYPT_OK; 254 } 255 256 /* read */ 257 static int unsigned_read(void *a, unsigned char *b, unsigned long len) 258 { 259 int res = mbedtls_mpi_read_binary(a, b, len); 260 261 if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED) 262 return CRYPT_MEM; 263 if (res) 264 return CRYPT_ERROR; 265 266 return CRYPT_OK; 267 } 268 269 /* add */ 270 static int add(void *a, void *b, void *c) 271 { 272 if (mbedtls_mpi_add_mpi(c, a, b)) 273 return CRYPT_MEM; 274 275 return CRYPT_OK; 276 } 277 278 static int addi(void *a, ltc_mp_digit b, void *c) 279 { 280 uint32_t b32 = b; 281 282 if (b32 != b) 283 return CRYPT_INVALID_ARG; 284 285 mbedtls_mpi_uint p = b32; 286 mbedtls_mpi bn = { .s = 1, .n = 1, .p = &p }; 287 288 return add(a, &bn, c); 289 } 290 291 /* sub */ 292 static int sub(void *a, void *b, void *c) 293 { 294 if (mbedtls_mpi_sub_mpi(c, a, b)) 295 return CRYPT_MEM; 296 297 return CRYPT_OK; 298 } 299 300 static int subi(void *a, ltc_mp_digit b, void *c) 301 { 302 uint32_t b32 = b; 303 304 if (b32 != b) 305 return CRYPT_INVALID_ARG; 306 307 mbedtls_mpi_uint p = b32; 308 mbedtls_mpi bn = { .s = 1, .n = 1, .p = &p }; 309 310 return sub(a, &bn, c); 311 } 312 313 /* mul */ 314 static int mul(void *a, void *b, void *c) 315 { 316 if (mbedtls_mpi_mul_mpi(c, a, b)) 317 return CRYPT_MEM; 318 319 return CRYPT_OK; 320 } 321 322 static int muli(void *a, ltc_mp_digit b, void *c) 323 { 324 if (b > (unsigned long) UINT32_MAX) 325 return CRYPT_INVALID_ARG; 326 327 if (mbedtls_mpi_mul_int(c, a, b)) 328 return CRYPT_MEM; 329 330 return CRYPT_OK; 331 } 332 333 /* sqr */ 334 static int sqr(void *a, void *b) 335 { 336 return mul(a, a, b); 337 } 338 339 /* div */ 340 static int divide(void *a, void *b, void *c, void *d) 341 { 342 int res = mbedtls_mpi_div_mpi(c, d, a, b); 343 344 if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED) 345 return CRYPT_MEM; 346 if (res) 347 return CRYPT_ERROR; 348 349 return CRYPT_OK; 350 } 351 352 static int div_2(void *a, void *b) 353 { 354 if (mbedtls_mpi_copy(b, a)) 355 return CRYPT_MEM; 356 357 if (mbedtls_mpi_shift_r(b, 1)) 358 return CRYPT_MEM; 359 360 return CRYPT_OK; 361 } 362 363 /* modi */ 364 static int modi(void *a, ltc_mp_digit b, ltc_mp_digit *c) 365 { 366 mbedtls_mpi bn_b; 367 mbedtls_mpi bn_c; 368 int res = 0; 369 370 mbedtls_mpi_init_mempool(&bn_b); 371 mbedtls_mpi_init_mempool(&bn_c); 372 373 res = set_int(&bn_b, b); 374 if (res) 375 return res; 376 377 res = mbedtls_mpi_mod_mpi(&bn_c, &bn_b, a); 378 if (!res) 379 *c = get_int(&bn_c); 380 381 mbedtls_mpi_free(&bn_b); 382 mbedtls_mpi_free(&bn_c); 383 384 if (res) 385 return CRYPT_MEM; 386 387 return CRYPT_OK; 388 } 389 390 /* gcd */ 391 static int gcd(void *a, void *b, void *c) 392 { 393 if (mbedtls_mpi_gcd(c, a, b)) 394 return CRYPT_MEM; 395 396 return CRYPT_OK; 397 } 398 399 /* lcm */ 400 static int lcm(void *a, void *b, void *c) 401 { 402 int res = CRYPT_MEM; 403 mbedtls_mpi tmp; 404 405 mbedtls_mpi_init_mempool(&tmp); 406 if (mbedtls_mpi_mul_mpi(&tmp, a, b)) 407 goto out; 408 409 if (mbedtls_mpi_gcd(c, a, b)) 410 goto out; 411 412 /* We use the following equality: gcd(a, b) * lcm(a, b) = a * b */ 413 res = divide(&tmp, c, c, NULL); 414 out: 415 mbedtls_mpi_free(&tmp); 416 return res; 417 } 418 419 static int mod(void *a, void *b, void *c) 420 { 421 int res = mbedtls_mpi_mod_mpi(c, a, b); 422 423 if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED) 424 return CRYPT_MEM; 425 if (res) 426 return CRYPT_ERROR; 427 428 return CRYPT_OK; 429 } 430 431 static int addmod(void *a, void *b, void *c, void *d) 432 { 433 int res = add(a, b, d); 434 435 if (res) 436 return res; 437 438 return mod(d, c, d); 439 } 440 441 static int submod(void *a, void *b, void *c, void *d) 442 { 443 int res = sub(a, b, d); 444 445 if (res) 446 return res; 447 448 return mod(d, c, d); 449 } 450 451 static int mulmod(void *a, void *b, void *c, void *d) 452 { 453 int res; 454 mbedtls_mpi ta; 455 mbedtls_mpi tb; 456 457 mbedtls_mpi_init_mempool(&ta); 458 mbedtls_mpi_init_mempool(&tb); 459 460 res = mod(a, c, &ta); 461 if (res) 462 goto out; 463 res = mod(b, c, &tb); 464 if (res) 465 goto out; 466 res = mul(&ta, &tb, d); 467 if (res) 468 goto out; 469 res = mod(d, c, d); 470 out: 471 mbedtls_mpi_free(&ta); 472 mbedtls_mpi_free(&tb); 473 return res; 474 } 475 476 static int sqrmod(void *a, void *b, void *c) 477 { 478 return mulmod(a, a, b, c); 479 } 480 481 /* invmod */ 482 static int invmod(void *a, void *b, void *c) 483 { 484 int res = mbedtls_mpi_inv_mod(c, a, b); 485 486 if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED) 487 return CRYPT_MEM; 488 if (res) 489 return CRYPT_ERROR; 490 491 return CRYPT_OK; 492 } 493 494 495 /* setup */ 496 static int montgomery_setup(void *a, void **b) 497 { 498 *b = malloc(sizeof(mbedtls_mpi_uint)); 499 if (!*b) 500 return CRYPT_MEM; 501 502 mbedtls_mpi_montg_init(*b, a); 503 504 return CRYPT_OK; 505 } 506 507 /* get normalization value */ 508 static int montgomery_normalization(void *a, void *b) 509 { 510 size_t c = ROUNDUP(mbedtls_mpi_size(b), sizeof(mbedtls_mpi_uint)) * 8; 511 512 if (mbedtls_mpi_lset(a, 1)) 513 return CRYPT_MEM; 514 if (mbedtls_mpi_shift_l(a, c)) 515 return CRYPT_MEM; 516 if (mbedtls_mpi_mod_mpi(a, a, b)) 517 return CRYPT_MEM; 518 519 return CRYPT_OK; 520 } 521 522 /* reduce */ 523 static int montgomery_reduce(void *a, void *b, void *c) 524 { 525 mbedtls_mpi A; 526 mbedtls_mpi *N = b; 527 mbedtls_mpi_uint *mm = c; 528 mbedtls_mpi T; 529 int ret = CRYPT_MEM; 530 531 mbedtls_mpi_init_mempool(&T); 532 mbedtls_mpi_init_mempool(&A); 533 534 if (mbedtls_mpi_grow(&T, (N->n + 1) * 2)) 535 goto out; 536 537 if (mbedtls_mpi_cmp_mpi(a, N) > 0) { 538 if (mbedtls_mpi_mod_mpi(&A, a, N)) 539 goto out; 540 } else { 541 if (mbedtls_mpi_copy(&A, a)) 542 goto out; 543 } 544 545 if (mbedtls_mpi_grow(&A, N->n + 1)) 546 goto out; 547 548 if (mbedtls_mpi_montred(&A, N, *mm, &T)) 549 goto out; 550 551 if (mbedtls_mpi_copy(a, &A)) 552 goto out; 553 554 ret = CRYPT_OK; 555 out: 556 mbedtls_mpi_free(&A); 557 mbedtls_mpi_free(&T); 558 559 return ret; 560 } 561 562 /* clean up */ 563 static void montgomery_deinit(void *a) 564 { 565 free(a); 566 } 567 568 /* 569 * This function calculates: 570 * d = a^b mod c 571 * 572 * @a: base 573 * @b: exponent 574 * @c: modulus 575 * @d: destination 576 */ 577 static int exptmod(void *a, void *b, void *c, void *d) 578 { 579 int res; 580 581 if (d == a || d == b || d == c) { 582 mbedtls_mpi dest; 583 584 mbedtls_mpi_init_mempool(&dest); 585 res = mbedtls_mpi_exp_mod(&dest, a, b, c, NULL); 586 if (!res) 587 res = mbedtls_mpi_copy(d, &dest); 588 mbedtls_mpi_free(&dest); 589 } else { 590 res = mbedtls_mpi_exp_mod(d, a, b, c, NULL); 591 } 592 593 if (res) 594 return CRYPT_MEM; 595 else 596 return CRYPT_OK; 597 } 598 599 static int rng_read(void *ignored __unused, unsigned char *buf, size_t blen) 600 { 601 if (crypto_rng_read(buf, blen)) 602 return MBEDTLS_ERR_MPI_FILE_IO_ERROR; 603 return 0; 604 } 605 606 static int isprime(void *a, int b __unused, int *c) 607 { 608 int res = mbedtls_mpi_is_prime(a, rng_read, NULL); 609 610 if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED) 611 return CRYPT_MEM; 612 613 if (res) 614 *c = LTC_MP_NO; 615 else 616 *c = LTC_MP_YES; 617 618 return CRYPT_OK; 619 } 620 621 static int mpi_rand(void *a, int size) 622 { 623 if (mbedtls_mpi_fill_random(a, size, rng_read, NULL)) 624 return CRYPT_MEM; 625 626 return CRYPT_OK; 627 } 628 629 ltc_math_descriptor ltc_mp = { 630 .name = "MPI", 631 .bits_per_digit = sizeof(mbedtls_mpi_uint) * 8, 632 633 .init = init, 634 .init_size = init_size, 635 .init_copy = init_copy, 636 .deinit = deinit, 637 638 .neg = neg, 639 .copy = copy, 640 641 .set_int = set_int, 642 .get_int = get_int, 643 .get_digit = get_digit, 644 .get_digit_count = get_digit_count, 645 .compare = compare, 646 .compare_d = compare_d, 647 .count_bits = count_bits, 648 .count_lsb_bits = count_lsb_bits, 649 .twoexpt = twoexpt, 650 651 .read_radix = read_radix, 652 .write_radix = write_radix, 653 .unsigned_size = unsigned_size, 654 .unsigned_write = unsigned_write, 655 .unsigned_read = unsigned_read, 656 657 .add = add, 658 .addi = addi, 659 .sub = sub, 660 .subi = subi, 661 .mul = mul, 662 .muli = muli, 663 .sqr = sqr, 664 .mpdiv = divide, 665 .div_2 = div_2, 666 .modi = modi, 667 .gcd = gcd, 668 .lcm = lcm, 669 670 .mulmod = mulmod, 671 .sqrmod = sqrmod, 672 .invmod = invmod, 673 674 .montgomery_setup = montgomery_setup, 675 .montgomery_normalization = montgomery_normalization, 676 .montgomery_reduce = montgomery_reduce, 677 .montgomery_deinit = montgomery_deinit, 678 679 .exptmod = exptmod, 680 .isprime = isprime, 681 682 #ifdef LTC_MECC 683 #ifdef LTC_MECC_FP 684 .ecc_ptmul = ltc_ecc_fp_mulmod, 685 #else 686 .ecc_ptmul = ltc_ecc_mulmod, 687 #endif /* LTC_MECC_FP */ 688 .ecc_ptadd = ltc_ecc_projective_add_point, 689 .ecc_ptdbl = ltc_ecc_projective_dbl_point, 690 .ecc_map = ltc_ecc_map, 691 #ifdef LTC_ECC_SHAMIR 692 #ifdef LTC_MECC_FP 693 .ecc_mul2add = ltc_ecc_fp_mul2add, 694 #else 695 .ecc_mul2add = ltc_ecc_mul2add, 696 #endif /* LTC_MECC_FP */ 697 #endif /* LTC_ECC_SHAMIR */ 698 #endif /* LTC_MECC */ 699 700 #ifdef LTC_MRSA 701 .rsa_keygen = rsa_make_key, 702 .rsa_me = rsa_exptmod, 703 #endif 704 .addmod = addmod, 705 .submod = submod, 706 .rand = mpi_rand, 707 708 }; 709 710 size_t crypto_bignum_num_bytes(struct bignum *a) 711 { 712 return mbedtls_mpi_size((mbedtls_mpi *)a); 713 } 714 715 size_t crypto_bignum_num_bits(struct bignum *a) 716 { 717 return mbedtls_mpi_bitlen((mbedtls_mpi *)a); 718 } 719 720 int32_t crypto_bignum_compare(struct bignum *a, struct bignum *b) 721 { 722 return mbedtls_mpi_cmp_mpi((mbedtls_mpi *)a, (mbedtls_mpi *)b); 723 } 724 725 void crypto_bignum_bn2bin(const struct bignum *from, uint8_t *to) 726 { 727 const mbedtls_mpi *f = (const mbedtls_mpi *)from; 728 int rc __maybe_unused = 0; 729 730 rc = mbedtls_mpi_write_binary(f, (void *)to, mbedtls_mpi_size(f)); 731 assert(!rc); 732 } 733 734 TEE_Result crypto_bignum_bin2bn(const uint8_t *from, size_t fromsize, 735 struct bignum *to) 736 { 737 if (mbedtls_mpi_read_binary((mbedtls_mpi *)to, (const void *)from, 738 fromsize)) 739 return TEE_ERROR_BAD_PARAMETERS; 740 return TEE_SUCCESS; 741 } 742 743 void crypto_bignum_copy(struct bignum *to, const struct bignum *from) 744 { 745 mbedtls_mpi_copy((mbedtls_mpi *)to, (const mbedtls_mpi *)from); 746 } 747 748 struct bignum *crypto_bignum_allocate(size_t size_bits) 749 { 750 mbedtls_mpi *bn = malloc(sizeof(*bn)); 751 752 if (!bn) 753 return NULL; 754 755 mbedtls_mpi_init(bn); 756 if (mbedtls_mpi_grow(bn, BITS_TO_LIMBS(size_bits))) { 757 free(bn); 758 return NULL; 759 } 760 761 return (struct bignum *)bn; 762 } 763 764 void crypto_bignum_free(struct bignum *s) 765 { 766 mbedtls_mpi_free((mbedtls_mpi *)s); 767 free(s); 768 } 769 770 void crypto_bignum_clear(struct bignum *s) 771 { 772 mbedtls_mpi *bn = (mbedtls_mpi *)s; 773 774 bn->s = 1; 775 if (bn->p) 776 memset(bn->p, 0, sizeof(*bn->p) * bn->n); 777 } 778