1 // SPDX-License-Identifier: BSD-2-Clause 2 /* 3 * Copyright (c) 2018, Linaro Limited 4 */ 5 6 #include <crypto/crypto.h> 7 #include <kernel/panic.h> 8 #include <mbedtls/bignum.h> 9 #include <mempool.h> 10 #include <stdlib.h> 11 #include <string.h> 12 #include <tomcrypt_private.h> 13 #include <tomcrypt_mp.h> 14 #include <util.h> 15 16 #if defined(_CFG_CORE_LTC_PAGER) 17 #include <mm/core_mmu.h> 18 #include <mm/tee_pager.h> 19 #endif 20 21 /* Size needed for xtest to pass reliably on both ARM32 and ARM64 */ 22 #define MPI_MEMPOOL_SIZE (42 * 1024) 23 24 /* From mbedtls/library/bignum.c */ 25 #define ciL (sizeof(mbedtls_mpi_uint)) /* chars in limb */ 26 #define biL (ciL << 3) /* bits in limb */ 27 #define BITS_TO_LIMBS(i) ((i) / biL + ((i) % biL != 0)) 28 29 #if defined(_CFG_CORE_LTC_PAGER) 30 /* allocate pageable_zi vmem for mp scratch memory pool */ 31 static struct mempool *get_mp_scratch_memory_pool(void) 32 { 33 size_t size; 34 void *data; 35 36 size = ROUNDUP(MPI_MEMPOOL_SIZE, SMALL_PAGE_SIZE); 37 data = tee_pager_alloc(size); 38 if (!data) 39 panic(); 40 41 return mempool_alloc_pool(data, size, tee_pager_release_phys); 42 } 43 #else /* _CFG_CORE_LTC_PAGER */ 44 static struct mempool *get_mp_scratch_memory_pool(void) 45 { 46 static uint8_t data[MPI_MEMPOOL_SIZE] __aligned(MEMPOOL_ALIGN); 47 48 return mempool_alloc_pool(data, sizeof(data), NULL); 49 } 50 #endif 51 52 void init_mp_tomcrypt(void) 53 { 54 struct mempool *p = get_mp_scratch_memory_pool(); 55 56 if (!p) 57 panic(); 58 mbedtls_mpi_mempool = p; 59 assert(!mempool_default); 60 mempool_default = p; 61 } 62 63 static int init(void **a) 64 { 65 mbedtls_mpi *bn = mempool_alloc(mbedtls_mpi_mempool, sizeof(*bn)); 66 67 if (!bn) 68 return CRYPT_MEM; 69 70 mbedtls_mpi_init_mempool(bn); 71 *a = bn; 72 return CRYPT_OK; 73 } 74 75 static int init_size(int size_bits __unused, void **a) 76 { 77 return init(a); 78 } 79 80 static void deinit(void *a) 81 { 82 mbedtls_mpi_free((mbedtls_mpi *)a); 83 mempool_free(mbedtls_mpi_mempool, a); 84 } 85 86 static int neg(void *a, void *b) 87 { 88 if (mbedtls_mpi_copy(b, a)) 89 return CRYPT_MEM; 90 ((mbedtls_mpi *)b)->s *= -1; 91 return CRYPT_OK; 92 } 93 94 static int copy(void *a, void *b) 95 { 96 if (mbedtls_mpi_copy(b, a)) 97 return CRYPT_MEM; 98 return CRYPT_OK; 99 } 100 101 static int init_copy(void **a, void *b) 102 { 103 if (init(a) != CRYPT_OK) { 104 return CRYPT_MEM; 105 } 106 return copy(b, *a); 107 } 108 109 /* ---- trivial ---- */ 110 static int set_int(void *a, ltc_mp_digit b) 111 { 112 uint32_t b32 = b; 113 114 if (b32 != b) 115 return CRYPT_INVALID_ARG; 116 117 mbedtls_mpi_uint p = b32; 118 mbedtls_mpi bn = { .s = 1, .n = 1, .p = &p }; 119 120 if (mbedtls_mpi_copy(a, &bn)) 121 return CRYPT_MEM; 122 return CRYPT_OK; 123 } 124 125 static unsigned long get_int(void *a) 126 { 127 mbedtls_mpi *bn = a; 128 129 if (!bn->n) 130 return 0; 131 132 return bn->p[bn->n - 1]; 133 } 134 135 static ltc_mp_digit get_digit(void *a, int n) 136 { 137 mbedtls_mpi *bn = a; 138 139 COMPILE_TIME_ASSERT(sizeof(ltc_mp_digit) >= sizeof(mbedtls_mpi_uint)); 140 141 if (n < 0 || (size_t)n >= bn->n) 142 return 0; 143 144 return bn->p[n]; 145 } 146 147 static int get_digit_count(void *a) 148 { 149 return ROUNDUP(mbedtls_mpi_size(a), sizeof(mbedtls_mpi_uint)) / 150 sizeof(mbedtls_mpi_uint); 151 } 152 153 static int compare(void *a, void *b) 154 { 155 int ret = mbedtls_mpi_cmp_mpi(a, b); 156 157 if (ret < 0) 158 return LTC_MP_LT; 159 160 if (ret > 0) 161 return LTC_MP_GT; 162 163 return LTC_MP_EQ; 164 } 165 166 static int compare_d(void *a, ltc_mp_digit b) 167 { 168 unsigned long v = b; 169 unsigned int shift = 31; 170 uint32_t mask = BIT(shift) - 1; 171 mbedtls_mpi bn; 172 173 mbedtls_mpi_init_mempool(&bn); 174 while (true) { 175 mbedtls_mpi_add_int(&bn, &bn, v & mask); 176 v >>= shift; 177 if (!v) 178 break; 179 mbedtls_mpi_shift_l(&bn, shift); 180 } 181 182 int ret = compare(a, &bn); 183 184 mbedtls_mpi_free(&bn); 185 186 return ret; 187 } 188 189 static int count_bits(void *a) 190 { 191 return mbedtls_mpi_bitlen(a); 192 } 193 194 static int count_lsb_bits(void *a) 195 { 196 return mbedtls_mpi_lsb(a); 197 } 198 199 200 static int twoexpt(void *a, int n) 201 { 202 if (mbedtls_mpi_set_bit(a, n, 1)) 203 return CRYPT_MEM; 204 205 return CRYPT_OK; 206 } 207 208 /* ---- conversions ---- */ 209 210 /* read ascii string */ 211 static int read_radix(void *a, const char *b, int radix) 212 { 213 int res = mbedtls_mpi_read_string(a, radix, b); 214 215 if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED) 216 return CRYPT_MEM; 217 if (res) 218 return CRYPT_ERROR; 219 220 return CRYPT_OK; 221 } 222 223 /* write one */ 224 static int write_radix(void *a, char *b, int radix) 225 { 226 size_t ol = SIZE_MAX; 227 int res = mbedtls_mpi_write_string(a, radix, b, ol, &ol); 228 229 if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED) 230 return CRYPT_MEM; 231 if (res) 232 return CRYPT_ERROR; 233 234 return CRYPT_OK; 235 } 236 237 /* get size as unsigned char string */ 238 static unsigned long unsigned_size(void *a) 239 { 240 return mbedtls_mpi_size(a); 241 } 242 243 /* store */ 244 static int unsigned_write(void *a, unsigned char *b) 245 { 246 int res = mbedtls_mpi_write_binary(a, b, unsigned_size(a)); 247 248 if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED) 249 return CRYPT_MEM; 250 if (res) 251 return CRYPT_ERROR; 252 253 return CRYPT_OK; 254 } 255 256 /* read */ 257 static int unsigned_read(void *a, unsigned char *b, unsigned long len) 258 { 259 int res = mbedtls_mpi_read_binary(a, b, len); 260 261 if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED) 262 return CRYPT_MEM; 263 if (res) 264 return CRYPT_ERROR; 265 266 return CRYPT_OK; 267 } 268 269 /* add */ 270 static int add(void *a, void *b, void *c) 271 { 272 if (mbedtls_mpi_add_mpi(c, a, b)) 273 return CRYPT_MEM; 274 275 return CRYPT_OK; 276 } 277 278 static int addi(void *a, ltc_mp_digit b, void *c) 279 { 280 uint32_t b32 = b; 281 282 if (b32 != b) 283 return CRYPT_INVALID_ARG; 284 285 mbedtls_mpi_uint p = b32; 286 mbedtls_mpi bn = { .s = 1, .n = 1, .p = &p }; 287 288 return add(a, &bn, c); 289 } 290 291 /* sub */ 292 static int sub(void *a, void *b, void *c) 293 { 294 if (mbedtls_mpi_sub_mpi(c, a, b)) 295 return CRYPT_MEM; 296 297 return CRYPT_OK; 298 } 299 300 static int subi(void *a, ltc_mp_digit b, void *c) 301 { 302 uint32_t b32 = b; 303 304 if (b32 != b) 305 return CRYPT_INVALID_ARG; 306 307 mbedtls_mpi_uint p = b32; 308 mbedtls_mpi bn = { .s = 1, .n = 1, .p = &p }; 309 310 return sub(a, &bn, c); 311 } 312 313 /* mul */ 314 static int mul(void *a, void *b, void *c) 315 { 316 if (mbedtls_mpi_mul_mpi(c, a, b)) 317 return CRYPT_MEM; 318 319 return CRYPT_OK; 320 } 321 322 static int muli(void *a, ltc_mp_digit b, void *c) 323 { 324 if (b > (unsigned long) UINT32_MAX) 325 return CRYPT_INVALID_ARG; 326 327 if (mbedtls_mpi_mul_int(c, a, b)) 328 return CRYPT_MEM; 329 330 return CRYPT_OK; 331 } 332 333 /* sqr */ 334 static int sqr(void *a, void *b) 335 { 336 return mul(a, a, b); 337 } 338 339 /* div */ 340 static int divide(void *a, void *b, void *c, void *d) 341 { 342 int res = mbedtls_mpi_div_mpi(c, d, a, b); 343 344 if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED) 345 return CRYPT_MEM; 346 if (res) 347 return CRYPT_ERROR; 348 349 return CRYPT_OK; 350 } 351 352 static int div_2(void *a, void *b) 353 { 354 if (mbedtls_mpi_copy(b, a)) 355 return CRYPT_MEM; 356 357 if (mbedtls_mpi_shift_r(b, 1)) 358 return CRYPT_MEM; 359 360 return CRYPT_OK; 361 } 362 363 /* modi */ 364 static int modi(void *a, ltc_mp_digit b, ltc_mp_digit *c) 365 { 366 mbedtls_mpi bn_b; 367 mbedtls_mpi bn_c; 368 int res = 0; 369 370 mbedtls_mpi_init_mempool(&bn_b); 371 mbedtls_mpi_init_mempool(&bn_c); 372 373 res = set_int(&bn_b, b); 374 if (res) 375 return res; 376 377 res = mbedtls_mpi_mod_mpi(&bn_c, &bn_b, a); 378 if (!res) 379 *c = get_int(&bn_c); 380 381 mbedtls_mpi_free(&bn_b); 382 mbedtls_mpi_free(&bn_c); 383 384 if (res) 385 return CRYPT_MEM; 386 387 return CRYPT_OK; 388 } 389 390 /* gcd */ 391 static int gcd(void *a, void *b, void *c) 392 { 393 if (mbedtls_mpi_gcd(c, a, b)) 394 return CRYPT_MEM; 395 396 return CRYPT_OK; 397 } 398 399 /* lcm */ 400 static int lcm(void *a, void *b, void *c) 401 { 402 int res = CRYPT_MEM; 403 mbedtls_mpi tmp; 404 405 mbedtls_mpi_init_mempool(&tmp); 406 if (mbedtls_mpi_mul_mpi(&tmp, a, b)) 407 goto out; 408 409 if (mbedtls_mpi_gcd(c, a, b)) 410 goto out; 411 412 /* We use the following equality: gcd(a, b) * lcm(a, b) = a * b */ 413 res = divide(&tmp, c, c, NULL); 414 out: 415 mbedtls_mpi_free(&tmp); 416 return res; 417 } 418 419 static int mod(void *a, void *b, void *c) 420 { 421 int res = mbedtls_mpi_mod_mpi(c, a, b); 422 423 if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED) 424 return CRYPT_MEM; 425 if (res) 426 return CRYPT_ERROR; 427 428 return CRYPT_OK; 429 } 430 431 static int mulmod(void *a, void *b, void *c, void *d) 432 { 433 int res; 434 mbedtls_mpi ta; 435 mbedtls_mpi tb; 436 437 mbedtls_mpi_init_mempool(&ta); 438 mbedtls_mpi_init_mempool(&tb); 439 440 res = mod(a, c, &ta); 441 if (res) 442 goto out; 443 res = mod(b, c, &tb); 444 if (res) 445 goto out; 446 res = mul(&ta, &tb, d); 447 if (res) 448 goto out; 449 res = mod(d, c, d); 450 out: 451 mbedtls_mpi_free(&ta); 452 mbedtls_mpi_free(&tb); 453 return res; 454 } 455 456 static int sqrmod(void *a, void *b, void *c) 457 { 458 return mulmod(a, a, b, c); 459 } 460 461 /* invmod */ 462 static int invmod(void *a, void *b, void *c) 463 { 464 int res = mbedtls_mpi_inv_mod(c, a, b); 465 466 if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED) 467 return CRYPT_MEM; 468 if (res) 469 return CRYPT_ERROR; 470 471 return CRYPT_OK; 472 } 473 474 475 /* setup */ 476 static int montgomery_setup(void *a, void **b) 477 { 478 *b = malloc(sizeof(mbedtls_mpi_uint)); 479 if (!*b) 480 return CRYPT_MEM; 481 482 mbedtls_mpi_montg_init(*b, a); 483 484 return CRYPT_OK; 485 } 486 487 /* get normalization value */ 488 static int montgomery_normalization(void *a, void *b) 489 { 490 size_t c = ROUNDUP(mbedtls_mpi_size(b), sizeof(mbedtls_mpi_uint)) * 8; 491 492 if (mbedtls_mpi_lset(a, 1)) 493 return CRYPT_MEM; 494 if (mbedtls_mpi_shift_l(a, c)) 495 return CRYPT_MEM; 496 if (mbedtls_mpi_mod_mpi(a, a, b)) 497 return CRYPT_MEM; 498 499 return CRYPT_OK; 500 } 501 502 /* reduce */ 503 static int montgomery_reduce(void *a, void *b, void *c) 504 { 505 mbedtls_mpi A; 506 mbedtls_mpi *N = b; 507 mbedtls_mpi_uint *mm = c; 508 mbedtls_mpi T; 509 int ret = CRYPT_MEM; 510 511 mbedtls_mpi_init_mempool(&T); 512 mbedtls_mpi_init_mempool(&A); 513 514 if (mbedtls_mpi_grow(&T, (N->n + 1) * 2)) 515 goto out; 516 517 if (mbedtls_mpi_cmp_mpi(a, N) > 0) { 518 if (mbedtls_mpi_mod_mpi(&A, a, N)) 519 goto out; 520 } else { 521 if (mbedtls_mpi_copy(&A, a)) 522 goto out; 523 } 524 525 if (mbedtls_mpi_grow(&A, N->n + 1)) 526 goto out; 527 528 if (mbedtls_mpi_montred(&A, N, *mm, &T)) 529 goto out; 530 531 if (mbedtls_mpi_copy(a, &A)) 532 goto out; 533 534 ret = CRYPT_OK; 535 out: 536 mbedtls_mpi_free(&A); 537 mbedtls_mpi_free(&T); 538 539 return ret; 540 } 541 542 /* clean up */ 543 static void montgomery_deinit(void *a) 544 { 545 free(a); 546 } 547 548 /* 549 * This function calculates: 550 * d = a^b mod c 551 * 552 * @a: base 553 * @b: exponent 554 * @c: modulus 555 * @d: destination 556 */ 557 static int exptmod(void *a, void *b, void *c, void *d) 558 { 559 int res; 560 561 if (d == a || d == b || d == c) { 562 mbedtls_mpi dest; 563 564 mbedtls_mpi_init_mempool(&dest); 565 res = mbedtls_mpi_exp_mod(&dest, a, b, c, NULL); 566 if (!res) 567 res = mbedtls_mpi_copy(d, &dest); 568 mbedtls_mpi_free(&dest); 569 } else { 570 res = mbedtls_mpi_exp_mod(d, a, b, c, NULL); 571 } 572 573 if (res) 574 return CRYPT_MEM; 575 else 576 return CRYPT_OK; 577 } 578 579 static int rng_read(void *ignored __unused, unsigned char *buf, size_t blen) 580 { 581 if (crypto_rng_read(buf, blen)) 582 return MBEDTLS_ERR_MPI_FILE_IO_ERROR; 583 return 0; 584 } 585 586 static int isprime(void *a, int b __unused, int *c) 587 { 588 int res = mbedtls_mpi_is_prime(a, rng_read, NULL); 589 590 if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED) 591 return CRYPT_MEM; 592 593 if (res) 594 *c = LTC_MP_NO; 595 else 596 *c = LTC_MP_YES; 597 598 return CRYPT_OK; 599 } 600 601 static int mpa_rand(void *a, int size) 602 { 603 if (mbedtls_mpi_fill_random(a, size, rng_read, NULL)) 604 return CRYPT_MEM; 605 606 return CRYPT_OK; 607 } 608 609 ltc_math_descriptor ltc_mp = { 610 .name = "MPI", 611 .bits_per_digit = sizeof(mbedtls_mpi_uint) * 8, 612 613 .init = &init, 614 .init_size = &init_size, 615 .init_copy = &init_copy, 616 .deinit = &deinit, 617 618 .neg = &neg, 619 .copy = ©, 620 621 .set_int = &set_int, 622 .get_int = &get_int, 623 .get_digit = &get_digit, 624 .get_digit_count = &get_digit_count, 625 .compare = &compare, 626 .compare_d = &compare_d, 627 .count_bits = &count_bits, 628 .count_lsb_bits = &count_lsb_bits, 629 .twoexpt = &twoexpt, 630 631 .read_radix = &read_radix, 632 .write_radix = &write_radix, 633 .unsigned_size = &unsigned_size, 634 .unsigned_write = &unsigned_write, 635 .unsigned_read = &unsigned_read, 636 637 .add = &add, 638 .addi = &addi, 639 .sub = &sub, 640 .subi = &subi, 641 .mul = &mul, 642 .muli = &muli, 643 .sqr = &sqr, 644 .mpdiv = ÷, 645 .div_2 = &div_2, 646 .modi = &modi, 647 .gcd = &gcd, 648 .lcm = &lcm, 649 650 .mulmod = &mulmod, 651 .sqrmod = &sqrmod, 652 .invmod = &invmod, 653 654 .montgomery_setup = &montgomery_setup, 655 .montgomery_normalization = &montgomery_normalization, 656 .montgomery_reduce = &montgomery_reduce, 657 .montgomery_deinit = &montgomery_deinit, 658 659 .exptmod = &exptmod, 660 .isprime = &isprime, 661 662 #ifdef LTC_MECC 663 #ifdef LTC_MECC_FP 664 .ecc_ptmul = <c_ecc_fp_mulmod, 665 #else 666 .ecc_ptmul = <c_ecc_mulmod, 667 #endif /* LTC_MECC_FP */ 668 .ecc_ptadd = <c_ecc_projective_add_point, 669 .ecc_ptdbl = <c_ecc_projective_dbl_point, 670 .ecc_map = <c_ecc_map, 671 #ifdef LTC_ECC_SHAMIR 672 #ifdef LTC_MECC_FP 673 .ecc_mul2add = <c_ecc_fp_mul2add, 674 #else 675 .ecc_mul2add = <c_ecc_mul2add, 676 #endif /* LTC_MECC_FP */ 677 #endif /* LTC_ECC_SHAMIR */ 678 #endif /* LTC_MECC */ 679 680 #ifdef LTC_MRSA 681 .rsa_keygen = &rsa_make_key, 682 .rsa_me = &rsa_exptmod, 683 #endif 684 .rand = &mpa_rand, 685 686 }; 687 688 size_t crypto_bignum_num_bytes(struct bignum *a) 689 { 690 return mbedtls_mpi_size((mbedtls_mpi *)a); 691 } 692 693 size_t crypto_bignum_num_bits(struct bignum *a) 694 { 695 return mbedtls_mpi_bitlen((mbedtls_mpi *)a); 696 } 697 698 int32_t crypto_bignum_compare(struct bignum *a, struct bignum *b) 699 { 700 return mbedtls_mpi_cmp_mpi((mbedtls_mpi *)a, (mbedtls_mpi *)b); 701 } 702 703 void crypto_bignum_bn2bin(const struct bignum *from, uint8_t *to) 704 { 705 mbedtls_mpi_write_binary((const mbedtls_mpi *)from, (void *)to, 706 mbedtls_mpi_size((const mbedtls_mpi *)from)); 707 } 708 709 TEE_Result crypto_bignum_bin2bn(const uint8_t *from, size_t fromsize, 710 struct bignum *to) 711 { 712 if (mbedtls_mpi_read_binary((mbedtls_mpi *)to, (const void *)from, 713 fromsize)) 714 return TEE_ERROR_BAD_PARAMETERS; 715 return TEE_SUCCESS; 716 } 717 718 void crypto_bignum_copy(struct bignum *to, const struct bignum *from) 719 { 720 mbedtls_mpi_copy((mbedtls_mpi *)to, (const mbedtls_mpi *)from); 721 } 722 723 struct bignum *crypto_bignum_allocate(size_t size_bits) 724 { 725 mbedtls_mpi *bn = malloc(sizeof(*bn)); 726 727 if (!bn) 728 return NULL; 729 730 mbedtls_mpi_init(bn); 731 if (mbedtls_mpi_grow(bn, BITS_TO_LIMBS(size_bits))) { 732 free(bn); 733 return NULL; 734 } 735 736 return (struct bignum *)bn; 737 } 738 739 void crypto_bignum_free(struct bignum *s) 740 { 741 mbedtls_mpi_free((mbedtls_mpi *)s); 742 free(s); 743 } 744 745 void crypto_bignum_clear(struct bignum *s) 746 { 747 mbedtls_mpi *bn = (mbedtls_mpi *)s; 748 749 bn->s = 1; 750 if (bn->p) 751 memset(bn->p, 0, sizeof(*bn->p) * bn->n); 752 } 753