1 // SPDX-License-Identifier: BSD-2-Clause 2 /* 3 * Copyright (c) 2018, Linaro Limited 4 */ 5 6 #include <crypto/crypto.h> 7 #include <kernel/panic.h> 8 #include <mbedtls/bignum.h> 9 #include <mempool.h> 10 #include <stdlib.h> 11 #include <string.h> 12 #include <tomcrypt.h> 13 #include <tomcrypt_mp.h> 14 #include <util.h> 15 16 #if defined(_CFG_CORE_LTC_PAGER) 17 #include <mm/core_mmu.h> 18 #include <mm/tee_pager.h> 19 #endif 20 21 /* Size needed for xtest to pass reliably on both ARM32 and ARM64 */ 22 #define MPI_MEMPOOL_SIZE (42 * 1024) 23 24 /* From mbedtls/library/bignum.c */ 25 #define ciL (sizeof(mbedtls_mpi_uint)) /* chars in limb */ 26 #define biL (ciL << 3) /* bits in limb */ 27 #define BITS_TO_LIMBS(i) ((i) / biL + ((i) % biL != 0)) 28 29 #if defined(_CFG_CORE_LTC_PAGER) 30 /* allocate pageable_zi vmem for mp scratch memory pool */ 31 static struct mempool *get_mp_scratch_memory_pool(void) 32 { 33 size_t size; 34 void *data; 35 36 size = ROUNDUP(MPI_MEMPOOL_SIZE, SMALL_PAGE_SIZE); 37 data = tee_pager_alloc(size, 0); 38 if (!data) 39 panic(); 40 41 return mempool_alloc_pool(data, size, tee_pager_release_phys); 42 } 43 #else /* _CFG_CORE_LTC_PAGER */ 44 static struct mempool *get_mp_scratch_memory_pool(void) 45 { 46 static uint8_t data[MPI_MEMPOOL_SIZE] __aligned(MEMPOOL_ALIGN); 47 48 return mempool_alloc_pool(data, sizeof(data), NULL); 49 } 50 #endif 51 52 void init_mp_tomcrypt(void) 53 { 54 struct mempool *p = get_mp_scratch_memory_pool(); 55 56 if (!p) 57 panic(); 58 mbedtls_mpi_mempool = p; 59 } 60 61 static int init(void **a) 62 { 63 mbedtls_mpi *bn = mempool_alloc(mbedtls_mpi_mempool, sizeof(*bn)); 64 65 if (!bn) 66 return CRYPT_MEM; 67 68 mbedtls_mpi_init_mempool(bn); 69 *a = bn; 70 return CRYPT_OK; 71 } 72 73 static int init_size(int size_bits __unused, void **a) 74 { 75 return init(a); 76 } 77 78 static void deinit(void *a) 79 { 80 mbedtls_mpi_free((mbedtls_mpi *)a); 81 mempool_free(mbedtls_mpi_mempool, a); 82 } 83 84 static int neg(void *a, void *b) 85 { 86 if (mbedtls_mpi_copy(b, a)) 87 return CRYPT_MEM; 88 ((mbedtls_mpi *)b)->s *= -1; 89 return CRYPT_OK; 90 } 91 92 static int copy(void *a, void *b) 93 { 94 if (mbedtls_mpi_copy(b, a)) 95 return CRYPT_MEM; 96 return CRYPT_OK; 97 } 98 99 static int init_copy(void **a, void *b) 100 { 101 if (init(a) != CRYPT_OK) { 102 return CRYPT_MEM; 103 } 104 return copy(b, *a); 105 } 106 107 /* ---- trivial ---- */ 108 static int set_int(void *a, unsigned long b) 109 { 110 uint32_t b32 = b; 111 112 if (b32 != b) 113 return CRYPT_INVALID_ARG; 114 115 mbedtls_mpi_uint p = b32; 116 mbedtls_mpi bn = { .s = 1, .n = 1, .p = &p }; 117 118 if (mbedtls_mpi_copy(a, &bn)) 119 return CRYPT_MEM; 120 return CRYPT_OK; 121 } 122 123 static unsigned long get_int(void *a) 124 { 125 mbedtls_mpi *bn = a; 126 127 if (!bn->n) 128 return 0; 129 130 return bn->p[bn->n - 1]; 131 } 132 133 static ltc_mp_digit get_digit(void *a, int n) 134 { 135 mbedtls_mpi *bn = a; 136 137 COMPILE_TIME_ASSERT(sizeof(ltc_mp_digit) >= sizeof(mbedtls_mpi_uint)); 138 139 if (n < 0 || (size_t)n >= bn->n) 140 return 0; 141 142 return bn->p[n]; 143 } 144 145 static int get_digit_count(void *a) 146 { 147 return ROUNDUP(mbedtls_mpi_size(a), sizeof(mbedtls_mpi_uint)) / 148 sizeof(mbedtls_mpi_uint); 149 } 150 151 static int compare(void *a, void *b) 152 { 153 int ret = mbedtls_mpi_cmp_mpi(a, b); 154 155 if (ret < 0) 156 return LTC_MP_LT; 157 158 if (ret > 0) 159 return LTC_MP_GT; 160 161 return LTC_MP_EQ; 162 } 163 164 static int compare_d(void *a, unsigned long b) 165 { 166 unsigned long v = b; 167 unsigned int shift = 31; 168 uint32_t mask = BIT(shift) - 1; 169 mbedtls_mpi bn; 170 171 mbedtls_mpi_init_mempool(&bn); 172 while (true) { 173 mbedtls_mpi_add_int(&bn, &bn, v & mask); 174 v >>= shift; 175 if (!v) 176 break; 177 mbedtls_mpi_shift_l(&bn, shift); 178 } 179 180 int ret = compare(a, &bn); 181 182 mbedtls_mpi_free(&bn); 183 184 return ret; 185 } 186 187 static int count_bits(void *a) 188 { 189 return mbedtls_mpi_bitlen(a); 190 } 191 192 static int count_lsb_bits(void *a) 193 { 194 return mbedtls_mpi_lsb(a); 195 } 196 197 198 static int twoexpt(void *a, int n) 199 { 200 if (mbedtls_mpi_set_bit(a, n, 1)) 201 return CRYPT_MEM; 202 203 return CRYPT_OK; 204 } 205 206 /* ---- conversions ---- */ 207 208 /* read ascii string */ 209 static int read_radix(void *a, const char *b, int radix) 210 { 211 int res = mbedtls_mpi_read_string(a, radix, b); 212 213 if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED) 214 return CRYPT_MEM; 215 if (res) 216 return CRYPT_ERROR; 217 218 return CRYPT_OK; 219 } 220 221 /* write one */ 222 static int write_radix(void *a, char *b, int radix) 223 { 224 size_t ol = SIZE_MAX; 225 int res = mbedtls_mpi_write_string(a, radix, b, ol, &ol); 226 227 if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED) 228 return CRYPT_MEM; 229 if (res) 230 return CRYPT_ERROR; 231 232 return CRYPT_OK; 233 } 234 235 /* get size as unsigned char string */ 236 static unsigned long unsigned_size(void *a) 237 { 238 return mbedtls_mpi_size(a); 239 } 240 241 /* store */ 242 static int unsigned_write(void *a, unsigned char *b) 243 { 244 int res = mbedtls_mpi_write_binary(a, b, unsigned_size(a)); 245 246 if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED) 247 return CRYPT_MEM; 248 if (res) 249 return CRYPT_ERROR; 250 251 return CRYPT_OK; 252 } 253 254 /* read */ 255 static int unsigned_read(void *a, unsigned char *b, unsigned long len) 256 { 257 int res = mbedtls_mpi_read_binary(a, b, len); 258 259 if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED) 260 return CRYPT_MEM; 261 if (res) 262 return CRYPT_ERROR; 263 264 return CRYPT_OK; 265 } 266 267 /* add */ 268 static int add(void *a, void *b, void *c) 269 { 270 if (mbedtls_mpi_add_mpi(c, a, b)) 271 return CRYPT_MEM; 272 273 return CRYPT_OK; 274 } 275 276 static int addi(void *a, unsigned long b, void *c) 277 { 278 uint32_t b32 = b; 279 280 if (b32 != b) 281 return CRYPT_INVALID_ARG; 282 283 mbedtls_mpi_uint p = b32; 284 mbedtls_mpi bn = { .s = 1, .n = 1, .p = &p }; 285 286 return add(a, &bn, c); 287 } 288 289 /* sub */ 290 static int sub(void *a, void *b, void *c) 291 { 292 if (mbedtls_mpi_sub_mpi(c, a, b)) 293 return CRYPT_MEM; 294 295 return CRYPT_OK; 296 } 297 298 static int subi(void *a, unsigned long b, void *c) 299 { 300 uint32_t b32 = b; 301 302 if (b32 != b) 303 return CRYPT_INVALID_ARG; 304 305 mbedtls_mpi_uint p = b32; 306 mbedtls_mpi bn = { .s = 1, .n = 1, .p = &p }; 307 308 return sub(a, &bn, c); 309 } 310 311 /* mul */ 312 static int mul(void *a, void *b, void *c) 313 { 314 if (mbedtls_mpi_mul_mpi(c, a, b)) 315 return CRYPT_MEM; 316 317 return CRYPT_OK; 318 } 319 320 static int muli(void *a, unsigned long b, void *c) 321 { 322 if (b > (unsigned long) UINT32_MAX) 323 return CRYPT_INVALID_ARG; 324 325 if (mbedtls_mpi_mul_int(c, a, b)) 326 return CRYPT_MEM; 327 328 return CRYPT_OK; 329 } 330 331 /* sqr */ 332 static int sqr(void *a, void *b) 333 { 334 return mul(a, a, b); 335 } 336 337 /* div */ 338 static int divide(void *a, void *b, void *c, void *d) 339 { 340 int res = mbedtls_mpi_div_mpi(c, d, a, b); 341 342 if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED) 343 return CRYPT_MEM; 344 if (res) 345 return CRYPT_ERROR; 346 347 return CRYPT_OK; 348 } 349 350 static int div_2(void *a, void *b) 351 { 352 if (mbedtls_mpi_copy(b, a)) 353 return CRYPT_MEM; 354 355 if (mbedtls_mpi_shift_r(b, 1)) 356 return CRYPT_MEM; 357 358 return CRYPT_OK; 359 } 360 361 /* modi */ 362 static int modi(void *a, unsigned long b, unsigned long *c) 363 { 364 mbedtls_mpi bn_b; 365 mbedtls_mpi bn_c; 366 int res = 0; 367 368 mbedtls_mpi_init_mempool(&bn_b); 369 mbedtls_mpi_init_mempool(&bn_c); 370 371 res = set_int(&bn_b, b); 372 if (res) 373 return res; 374 375 res = mbedtls_mpi_mod_mpi(&bn_c, &bn_b, a); 376 if (!res) 377 *c = get_int(&bn_c); 378 379 mbedtls_mpi_free(&bn_b); 380 mbedtls_mpi_free(&bn_c); 381 382 if (res) 383 return CRYPT_MEM; 384 385 return CRYPT_OK; 386 } 387 388 /* gcd */ 389 static int gcd(void *a, void *b, void *c) 390 { 391 if (mbedtls_mpi_gcd(c, a, b)) 392 return CRYPT_MEM; 393 394 return CRYPT_OK; 395 } 396 397 /* lcm */ 398 static int lcm(void *a, void *b, void *c) 399 { 400 int res = CRYPT_MEM; 401 mbedtls_mpi tmp; 402 403 mbedtls_mpi_init_mempool(&tmp); 404 if (mbedtls_mpi_mul_mpi(&tmp, a, b)) 405 goto out; 406 407 if (mbedtls_mpi_gcd(c, a, b)) 408 goto out; 409 410 /* We use the following equality: gcd(a, b) * lcm(a, b) = a * b */ 411 res = divide(&tmp, c, c, NULL); 412 out: 413 mbedtls_mpi_free(&tmp); 414 return res; 415 } 416 417 static int mod(void *a, void *b, void *c) 418 { 419 int res = mbedtls_mpi_mod_mpi(c, a, b); 420 421 if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED) 422 return CRYPT_MEM; 423 if (res) 424 return CRYPT_ERROR; 425 426 return CRYPT_OK; 427 } 428 429 static int mulmod(void *a, void *b, void *c, void *d) 430 { 431 int res; 432 mbedtls_mpi ta; 433 mbedtls_mpi tb; 434 435 mbedtls_mpi_init_mempool(&ta); 436 mbedtls_mpi_init_mempool(&tb); 437 438 res = mod(a, c, &ta); 439 if (res) 440 goto out; 441 res = mod(b, c, &tb); 442 if (res) 443 goto out; 444 res = mul(&ta, &tb, d); 445 if (res) 446 goto out; 447 res = mod(d, c, d); 448 out: 449 mbedtls_mpi_free(&ta); 450 mbedtls_mpi_free(&tb); 451 return res; 452 } 453 454 static int sqrmod(void *a, void *b, void *c) 455 { 456 return mulmod(a, a, b, c); 457 } 458 459 /* invmod */ 460 static int invmod(void *a, void *b, void *c) 461 { 462 int res = mbedtls_mpi_inv_mod(c, a, b); 463 464 if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED) 465 return CRYPT_MEM; 466 if (res) 467 return CRYPT_ERROR; 468 469 return CRYPT_OK; 470 } 471 472 473 /* setup */ 474 static int montgomery_setup(void *a, void **b) 475 { 476 *b = malloc(sizeof(mbedtls_mpi_uint)); 477 if (!*b) 478 return CRYPT_MEM; 479 480 mbedtls_mpi_montg_init(*b, a); 481 482 return CRYPT_OK; 483 } 484 485 /* get normalization value */ 486 static int montgomery_normalization(void *a, void *b) 487 { 488 size_t c = ROUNDUP(mbedtls_mpi_size(b), sizeof(mbedtls_mpi_uint)) * 8; 489 490 if (mbedtls_mpi_lset(a, 1)) 491 return CRYPT_MEM; 492 if (mbedtls_mpi_shift_l(a, c)) 493 return CRYPT_MEM; 494 if (mbedtls_mpi_mod_mpi(a, a, b)) 495 return CRYPT_MEM; 496 497 return CRYPT_OK; 498 } 499 500 /* reduce */ 501 static int montgomery_reduce(void *a, void *b, void *c) 502 { 503 mbedtls_mpi A; 504 mbedtls_mpi *N = b; 505 mbedtls_mpi_uint *mm = c; 506 mbedtls_mpi T; 507 int ret = CRYPT_MEM; 508 509 mbedtls_mpi_init_mempool(&T); 510 mbedtls_mpi_init_mempool(&A); 511 512 if (mbedtls_mpi_grow(&T, (N->n + 1) * 2)) 513 goto out; 514 515 if (mbedtls_mpi_cmp_mpi(a, N) > 0) { 516 if (mbedtls_mpi_mod_mpi(&A, a, N)) 517 goto out; 518 } else { 519 if (mbedtls_mpi_copy(&A, a)) 520 goto out; 521 } 522 523 if (mbedtls_mpi_grow(&A, N->n + 1)) 524 goto out; 525 526 if (mbedtls_mpi_montred(&A, N, *mm, &T)) 527 goto out; 528 529 if (mbedtls_mpi_copy(a, &A)) 530 goto out; 531 532 ret = CRYPT_OK; 533 out: 534 mbedtls_mpi_free(&A); 535 mbedtls_mpi_free(&T); 536 537 return ret; 538 } 539 540 /* clean up */ 541 static void montgomery_deinit(void *a) 542 { 543 free(a); 544 } 545 546 /* 547 * This function calculates: 548 * d = a^b mod c 549 * 550 * @a: base 551 * @b: exponent 552 * @c: modulus 553 * @d: destination 554 */ 555 static int exptmod(void *a, void *b, void *c, void *d) 556 { 557 int res; 558 559 if (d == a || d == b || d == c) { 560 mbedtls_mpi dest; 561 562 mbedtls_mpi_init_mempool(&dest); 563 res = mbedtls_mpi_exp_mod(&dest, a, b, c, NULL); 564 if (!res) 565 res = mbedtls_mpi_copy(d, &dest); 566 mbedtls_mpi_free(&dest); 567 } else { 568 res = mbedtls_mpi_exp_mod(d, a, b, c, NULL); 569 } 570 571 if (res) 572 return CRYPT_MEM; 573 else 574 return CRYPT_OK; 575 } 576 577 static int rng_read(void *ignored __unused, unsigned char *buf, size_t blen) 578 { 579 if (crypto_rng_read(buf, blen)) 580 return MBEDTLS_ERR_MPI_FILE_IO_ERROR; 581 return 0; 582 } 583 584 static int isprime(void *a, int b __unused, int *c) 585 { 586 int res = mbedtls_mpi_is_prime(a, rng_read, NULL); 587 588 if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED) 589 return CRYPT_MEM; 590 591 if (res) 592 *c = LTC_MP_NO; 593 else 594 *c = LTC_MP_YES; 595 596 return CRYPT_OK; 597 } 598 599 static int mpa_rand(void *a, int size) 600 { 601 if (mbedtls_mpi_fill_random(a, size, rng_read, NULL)) 602 return CRYPT_MEM; 603 604 return CRYPT_OK; 605 } 606 607 ltc_math_descriptor ltc_mp = { 608 .name = "MPI", 609 .bits_per_digit = sizeof(mbedtls_mpi_uint) * 8, 610 611 .init = &init, 612 .init_size = &init_size, 613 .init_copy = &init_copy, 614 .deinit = &deinit, 615 616 .neg = &neg, 617 .copy = ©, 618 619 .set_int = &set_int, 620 .get_int = &get_int, 621 .get_digit = &get_digit, 622 .get_digit_count = &get_digit_count, 623 .compare = &compare, 624 .compare_d = &compare_d, 625 .count_bits = &count_bits, 626 .count_lsb_bits = &count_lsb_bits, 627 .twoexpt = &twoexpt, 628 629 .read_radix = &read_radix, 630 .write_radix = &write_radix, 631 .unsigned_size = &unsigned_size, 632 .unsigned_write = &unsigned_write, 633 .unsigned_read = &unsigned_read, 634 635 .add = &add, 636 .addi = &addi, 637 .sub = &sub, 638 .subi = &subi, 639 .mul = &mul, 640 .muli = &muli, 641 .sqr = &sqr, 642 .mpdiv = ÷, 643 .div_2 = &div_2, 644 .modi = &modi, 645 .gcd = &gcd, 646 .lcm = &lcm, 647 648 .mod = &mod, 649 .mulmod = &mulmod, 650 .sqrmod = &sqrmod, 651 .invmod = &invmod, 652 653 .montgomery_setup = &montgomery_setup, 654 .montgomery_normalization = &montgomery_normalization, 655 .montgomery_reduce = &montgomery_reduce, 656 .montgomery_deinit = &montgomery_deinit, 657 658 .exptmod = &exptmod, 659 .isprime = &isprime, 660 661 #ifdef LTC_MECC 662 #ifdef LTC_MECC_FP 663 .ecc_ptmul = <c_ecc_fp_mulmod, 664 #else 665 .ecc_ptmul = <c_ecc_mulmod, 666 #endif /* LTC_MECC_FP */ 667 .ecc_ptadd = <c_ecc_projective_add_point, 668 .ecc_ptdbl = <c_ecc_projective_dbl_point, 669 .ecc_map = <c_ecc_map, 670 #ifdef LTC_ECC_SHAMIR 671 #ifdef LTC_MECC_FP 672 .ecc_mul2add = <c_ecc_fp_mul2add, 673 #else 674 .ecc_mul2add = <c_ecc_mul2add, 675 #endif /* LTC_MECC_FP */ 676 #endif /* LTC_ECC_SHAMIR */ 677 #endif /* LTC_MECC */ 678 679 #ifdef LTC_MRSA 680 .rsa_keygen = &rsa_make_key, 681 .rsa_me = &rsa_exptmod, 682 #endif 683 .rand = &mpa_rand, 684 685 }; 686 687 size_t crypto_bignum_num_bytes(struct bignum *a) 688 { 689 return mbedtls_mpi_size((mbedtls_mpi *)a); 690 } 691 692 size_t crypto_bignum_num_bits(struct bignum *a) 693 { 694 return mbedtls_mpi_bitlen((mbedtls_mpi *)a); 695 } 696 697 int32_t crypto_bignum_compare(struct bignum *a, struct bignum *b) 698 { 699 return mbedtls_mpi_cmp_mpi((mbedtls_mpi *)a, (mbedtls_mpi *)b); 700 } 701 702 void crypto_bignum_bn2bin(const struct bignum *from, uint8_t *to) 703 { 704 mbedtls_mpi_write_binary((const mbedtls_mpi *)from, (void *)to, 705 mbedtls_mpi_size((const mbedtls_mpi *)from)); 706 } 707 708 TEE_Result crypto_bignum_bin2bn(const uint8_t *from, size_t fromsize, 709 struct bignum *to) 710 { 711 if (mbedtls_mpi_read_binary((mbedtls_mpi *)to, (const void *)from, 712 fromsize)) 713 return TEE_ERROR_BAD_PARAMETERS; 714 return TEE_SUCCESS; 715 } 716 717 void crypto_bignum_copy(struct bignum *to, const struct bignum *from) 718 { 719 mbedtls_mpi_copy((mbedtls_mpi *)to, (const mbedtls_mpi *)from); 720 } 721 722 struct bignum *crypto_bignum_allocate(size_t size_bits) 723 { 724 mbedtls_mpi *bn = malloc(sizeof(*bn)); 725 726 if (!bn) 727 return NULL; 728 729 mbedtls_mpi_init(bn); 730 if (mbedtls_mpi_grow(bn, BITS_TO_LIMBS(size_bits))) { 731 free(bn); 732 return NULL; 733 } 734 735 return (struct bignum *)bn; 736 } 737 738 void crypto_bignum_free(struct bignum *s) 739 { 740 mbedtls_mpi_free((mbedtls_mpi *)s); 741 free(s); 742 } 743 744 void crypto_bignum_clear(struct bignum *s) 745 { 746 mbedtls_mpi *bn = (mbedtls_mpi *)s; 747 748 bn->s = 1; 749 if (bn->p) 750 memset(bn->p, 0, sizeof(*bn->p) * bn->n); 751 } 752