1 // SPDX-License-Identifier: BSD-2-Clause 2 /* 3 * Copyright (c) 2018, Linaro Limited 4 */ 5 6 #include <crypto/crypto.h> 7 #include <kernel/panic.h> 8 #include <mbedtls/bignum.h> 9 #include <mempool.h> 10 #include <stdlib.h> 11 #include <string.h> 12 #include <tomcrypt.h> 13 #include <tomcrypt_mp.h> 14 #include <util.h> 15 16 #if defined(_CFG_CORE_LTC_PAGER) 17 #include <mm/core_mmu.h> 18 #include <mm/tee_pager.h> 19 #endif 20 21 /* Size needed for xtest to pass reliably on both ARM32 and ARM64 */ 22 #define MPI_MEMPOOL_SIZE (42 * 1024) 23 24 #if defined(_CFG_CORE_LTC_PAGER) 25 /* allocate pageable_zi vmem for mp scratch memory pool */ 26 static struct mempool *get_mp_scratch_memory_pool(void) 27 { 28 size_t size; 29 void *data; 30 31 size = ROUNDUP(MPI_MEMPOOL_SIZE, SMALL_PAGE_SIZE); 32 data = tee_pager_alloc(size, 0); 33 if (!data) 34 panic(); 35 36 return mempool_alloc_pool(data, size, tee_pager_release_phys); 37 } 38 #else /* _CFG_CORE_LTC_PAGER */ 39 static struct mempool *get_mp_scratch_memory_pool(void) 40 { 41 static uint8_t data[MPI_MEMPOOL_SIZE] __aligned(MEMPOOL_ALIGN); 42 43 return mempool_alloc_pool(data, sizeof(data), NULL); 44 } 45 #endif 46 47 void init_mp_tomcrypt(void) 48 { 49 struct mempool *p = get_mp_scratch_memory_pool(); 50 51 if (!p) 52 panic(); 53 mbedtls_mpi_mempool = p; 54 } 55 56 static int init(void **a) 57 { 58 mbedtls_mpi *bn = mempool_alloc(mbedtls_mpi_mempool, sizeof(*bn)); 59 60 if (!bn) 61 return CRYPT_MEM; 62 63 mbedtls_mpi_init_mempool(bn); 64 *a = bn; 65 return CRYPT_OK; 66 } 67 68 static int init_size(int size_bits __unused, void **a) 69 { 70 return init(a); 71 } 72 73 static void deinit(void *a) 74 { 75 mbedtls_mpi_free((mbedtls_mpi *)a); 76 mempool_free(mbedtls_mpi_mempool, a); 77 } 78 79 static int neg(void *a, void *b) 80 { 81 if (mbedtls_mpi_copy(b, a)) 82 return CRYPT_MEM; 83 ((mbedtls_mpi *)b)->s *= -1; 84 return CRYPT_OK; 85 } 86 87 static int copy(void *a, void *b) 88 { 89 if (mbedtls_mpi_copy(b, a)) 90 return CRYPT_MEM; 91 return CRYPT_OK; 92 } 93 94 static int init_copy(void **a, void *b) 95 { 96 if (init(a) != CRYPT_OK) { 97 return CRYPT_MEM; 98 } 99 return copy(b, *a); 100 } 101 102 /* ---- trivial ---- */ 103 static int set_int(void *a, unsigned long b) 104 { 105 uint32_t b32 = b; 106 107 if (b32 != b) 108 return CRYPT_INVALID_ARG; 109 110 mbedtls_mpi_uint p = b32; 111 mbedtls_mpi bn = { .s = 1, .n = 1, .p = &p }; 112 113 if (mbedtls_mpi_copy(a, &bn)) 114 return CRYPT_MEM; 115 return CRYPT_OK; 116 } 117 118 static unsigned long get_int(void *a) 119 { 120 mbedtls_mpi *bn = a; 121 122 if (!bn->n) 123 return 0; 124 125 return bn->p[bn->n - 1]; 126 } 127 128 static ltc_mp_digit get_digit(void *a, int n) 129 { 130 mbedtls_mpi *bn = a; 131 132 COMPILE_TIME_ASSERT(sizeof(ltc_mp_digit) >= sizeof(mbedtls_mpi_uint)); 133 134 if (n < 0 || (size_t)n >= bn->n) 135 return 0; 136 137 return bn->p[n]; 138 } 139 140 static int get_digit_count(void *a) 141 { 142 return ROUNDUP(mbedtls_mpi_size(a), sizeof(mbedtls_mpi_uint)) / 143 sizeof(mbedtls_mpi_uint); 144 } 145 146 static int compare(void *a, void *b) 147 { 148 int ret = mbedtls_mpi_cmp_mpi(a, b); 149 150 if (ret < 0) 151 return LTC_MP_LT; 152 153 if (ret > 0) 154 return LTC_MP_GT; 155 156 return LTC_MP_EQ; 157 } 158 159 static int compare_d(void *a, unsigned long b) 160 { 161 unsigned long v = b; 162 unsigned int shift = 31; 163 uint32_t mask = BIT(shift) - 1; 164 mbedtls_mpi bn; 165 166 mbedtls_mpi_init_mempool(&bn); 167 while (true) { 168 mbedtls_mpi_add_int(&bn, &bn, v & mask); 169 v >>= shift; 170 if (!v) 171 break; 172 mbedtls_mpi_shift_l(&bn, shift); 173 } 174 175 int ret = compare(a, &bn); 176 177 mbedtls_mpi_free(&bn); 178 179 return ret; 180 } 181 182 static int count_bits(void *a) 183 { 184 return mbedtls_mpi_bitlen(a); 185 } 186 187 static int count_lsb_bits(void *a) 188 { 189 return mbedtls_mpi_lsb(a); 190 } 191 192 193 static int twoexpt(void *a, int n) 194 { 195 if (mbedtls_mpi_set_bit(a, n, 1)) 196 return CRYPT_MEM; 197 198 return CRYPT_OK; 199 } 200 201 /* ---- conversions ---- */ 202 203 /* read ascii string */ 204 static int read_radix(void *a, const char *b, int radix) 205 { 206 int res = mbedtls_mpi_read_string(a, radix, b); 207 208 if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED) 209 return CRYPT_MEM; 210 if (res) 211 return CRYPT_ERROR; 212 213 return CRYPT_OK; 214 } 215 216 /* write one */ 217 static int write_radix(void *a, char *b, int radix) 218 { 219 size_t ol = SIZE_MAX; 220 int res = mbedtls_mpi_write_string(a, radix, b, ol, &ol); 221 222 if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED) 223 return CRYPT_MEM; 224 if (res) 225 return CRYPT_ERROR; 226 227 return CRYPT_OK; 228 } 229 230 /* get size as unsigned char string */ 231 static unsigned long unsigned_size(void *a) 232 { 233 return mbedtls_mpi_size(a); 234 } 235 236 /* store */ 237 static int unsigned_write(void *a, unsigned char *b) 238 { 239 int res = mbedtls_mpi_write_binary(a, b, unsigned_size(a)); 240 241 if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED) 242 return CRYPT_MEM; 243 if (res) 244 return CRYPT_ERROR; 245 246 return CRYPT_OK; 247 } 248 249 /* read */ 250 static int unsigned_read(void *a, unsigned char *b, unsigned long len) 251 { 252 int res = mbedtls_mpi_read_binary(a, b, len); 253 254 if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED) 255 return CRYPT_MEM; 256 if (res) 257 return CRYPT_ERROR; 258 259 return CRYPT_OK; 260 } 261 262 /* add */ 263 static int add(void *a, void *b, void *c) 264 { 265 if (mbedtls_mpi_add_mpi(c, a, b)) 266 return CRYPT_MEM; 267 268 return CRYPT_OK; 269 } 270 271 static int addi(void *a, unsigned long b, void *c) 272 { 273 uint32_t b32 = b; 274 275 if (b32 != b) 276 return CRYPT_INVALID_ARG; 277 278 mbedtls_mpi_uint p = b32; 279 mbedtls_mpi bn = { .s = 1, .n = 1, .p = &p }; 280 281 return add(a, &bn, c); 282 } 283 284 /* sub */ 285 static int sub(void *a, void *b, void *c) 286 { 287 if (mbedtls_mpi_sub_mpi(c, a, b)) 288 return CRYPT_MEM; 289 290 return CRYPT_OK; 291 } 292 293 static int subi(void *a, unsigned long b, void *c) 294 { 295 uint32_t b32 = b; 296 297 if (b32 != b) 298 return CRYPT_INVALID_ARG; 299 300 mbedtls_mpi_uint p = b32; 301 mbedtls_mpi bn = { .s = 1, .n = 1, .p = &p }; 302 303 return sub(a, &bn, c); 304 } 305 306 /* mul */ 307 static int mul(void *a, void *b, void *c) 308 { 309 if (mbedtls_mpi_mul_mpi(c, a, b)) 310 return CRYPT_MEM; 311 312 return CRYPT_OK; 313 } 314 315 static int muli(void *a, unsigned long b, void *c) 316 { 317 if (b > (unsigned long) UINT32_MAX) 318 return CRYPT_INVALID_ARG; 319 320 if (mbedtls_mpi_mul_int(c, a, b)) 321 return CRYPT_MEM; 322 323 return CRYPT_OK; 324 } 325 326 /* sqr */ 327 static int sqr(void *a, void *b) 328 { 329 return mul(a, a, b); 330 } 331 332 /* div */ 333 static int divide(void *a, void *b, void *c, void *d) 334 { 335 int res = mbedtls_mpi_div_mpi(c, d, a, b); 336 337 if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED) 338 return CRYPT_MEM; 339 if (res) 340 return CRYPT_ERROR; 341 342 return CRYPT_OK; 343 } 344 345 static int div_2(void *a, void *b) 346 { 347 if (mbedtls_mpi_copy(b, a)) 348 return CRYPT_MEM; 349 350 if (mbedtls_mpi_shift_r(b, 1)) 351 return CRYPT_MEM; 352 353 return CRYPT_OK; 354 } 355 356 /* modi */ 357 static int modi(void *a, unsigned long b, unsigned long *c) 358 { 359 mbedtls_mpi bn_b; 360 mbedtls_mpi bn_c; 361 int res = 0; 362 363 mbedtls_mpi_init_mempool(&bn_b); 364 mbedtls_mpi_init_mempool(&bn_c); 365 366 res = set_int(&bn_b, b); 367 if (res) 368 return res; 369 370 res = mbedtls_mpi_mod_mpi(&bn_c, &bn_b, a); 371 if (!res) 372 *c = get_int(&bn_c); 373 374 mbedtls_mpi_free(&bn_b); 375 mbedtls_mpi_free(&bn_c); 376 377 if (res) 378 return CRYPT_MEM; 379 380 return CRYPT_OK; 381 } 382 383 /* gcd */ 384 static int gcd(void *a, void *b, void *c) 385 { 386 if (mbedtls_mpi_gcd(c, a, b)) 387 return CRYPT_MEM; 388 389 return CRYPT_OK; 390 } 391 392 /* lcm */ 393 static int lcm(void *a, void *b, void *c) 394 { 395 int res = CRYPT_MEM; 396 mbedtls_mpi tmp; 397 398 mbedtls_mpi_init_mempool(&tmp); 399 if (mbedtls_mpi_mul_mpi(&tmp, a, b)) 400 goto out; 401 402 if (mbedtls_mpi_gcd(c, a, b)) 403 goto out; 404 405 /* We use the following equality: gcd(a, b) * lcm(a, b) = a * b */ 406 res = divide(&tmp, c, c, NULL); 407 out: 408 mbedtls_mpi_free(&tmp); 409 return res; 410 } 411 412 static int mod(void *a, void *b, void *c) 413 { 414 int res = mbedtls_mpi_mod_mpi(c, a, b); 415 416 if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED) 417 return CRYPT_MEM; 418 if (res) 419 return CRYPT_ERROR; 420 421 return CRYPT_OK; 422 } 423 424 static int mulmod(void *a, void *b, void *c, void *d) 425 { 426 int res; 427 mbedtls_mpi ta; 428 mbedtls_mpi tb; 429 430 mbedtls_mpi_init_mempool(&ta); 431 mbedtls_mpi_init_mempool(&tb); 432 433 res = mod(a, c, &ta); 434 if (res) 435 goto out; 436 res = mod(b, c, &tb); 437 if (res) 438 goto out; 439 res = mul(&ta, &tb, d); 440 if (res) 441 goto out; 442 res = mod(d, c, d); 443 out: 444 mbedtls_mpi_free(&ta); 445 mbedtls_mpi_free(&tb); 446 return res; 447 } 448 449 static int sqrmod(void *a, void *b, void *c) 450 { 451 return mulmod(a, a, b, c); 452 } 453 454 /* invmod */ 455 static int invmod(void *a, void *b, void *c) 456 { 457 int res = mbedtls_mpi_inv_mod(c, a, b); 458 459 if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED) 460 return CRYPT_MEM; 461 if (res) 462 return CRYPT_ERROR; 463 464 return CRYPT_OK; 465 } 466 467 468 /* setup */ 469 static int montgomery_setup(void *a, void **b) 470 { 471 *b = malloc(sizeof(mbedtls_mpi_uint)); 472 if (!*b) 473 return CRYPT_MEM; 474 475 mbedtls_mpi_montg_init(*b, a); 476 477 return CRYPT_OK; 478 } 479 480 /* get normalization value */ 481 static int montgomery_normalization(void *a, void *b) 482 { 483 size_t c = ROUNDUP(mbedtls_mpi_size(b), sizeof(mbedtls_mpi_uint)) * 8; 484 485 if (mbedtls_mpi_lset(a, 1)) 486 return CRYPT_MEM; 487 if (mbedtls_mpi_shift_l(a, c)) 488 return CRYPT_MEM; 489 if (mbedtls_mpi_mod_mpi(a, a, b)) 490 return CRYPT_MEM; 491 492 return CRYPT_OK; 493 } 494 495 /* reduce */ 496 static int montgomery_reduce(void *a, void *b, void *c) 497 { 498 mbedtls_mpi A; 499 mbedtls_mpi *N = b; 500 mbedtls_mpi_uint *mm = c; 501 mbedtls_mpi T; 502 int ret = CRYPT_MEM; 503 504 mbedtls_mpi_init_mempool(&T); 505 mbedtls_mpi_init_mempool(&A); 506 507 if (mbedtls_mpi_grow(&T, (N->n + 1) * 2)) 508 goto out; 509 510 if (mbedtls_mpi_cmp_mpi(a, N) > 0) { 511 if (mbedtls_mpi_mod_mpi(&A, a, N)) 512 goto out; 513 } else { 514 if (mbedtls_mpi_copy(&A, a)) 515 goto out; 516 } 517 518 if (mbedtls_mpi_grow(&A, N->n + 1)) 519 goto out; 520 521 if (mbedtls_mpi_montred(&A, N, *mm, &T)) 522 goto out; 523 524 if (mbedtls_mpi_copy(a, &A)) 525 goto out; 526 527 ret = CRYPT_OK; 528 out: 529 mbedtls_mpi_free(&A); 530 mbedtls_mpi_free(&T); 531 532 return ret; 533 } 534 535 /* clean up */ 536 static void montgomery_deinit(void *a) 537 { 538 free(a); 539 } 540 541 /* 542 * This function calculates: 543 * d = a^b mod c 544 * 545 * @a: base 546 * @b: exponent 547 * @c: modulus 548 * @d: destination 549 */ 550 static int exptmod(void *a, void *b, void *c, void *d) 551 { 552 int res; 553 554 if (d == a || d == b || d == c) { 555 mbedtls_mpi dest; 556 557 mbedtls_mpi_init_mempool(&dest); 558 res = mbedtls_mpi_exp_mod(&dest, a, b, c, NULL); 559 if (!res) 560 res = mbedtls_mpi_copy(d, &dest); 561 mbedtls_mpi_free(&dest); 562 } else { 563 res = mbedtls_mpi_exp_mod(d, a, b, c, NULL); 564 } 565 566 if (res) 567 return CRYPT_MEM; 568 else 569 return CRYPT_OK; 570 } 571 572 static int rng_read(void *ignored __unused, unsigned char *buf, size_t blen) 573 { 574 if (crypto_rng_read(buf, blen)) 575 return MBEDTLS_ERR_MPI_FILE_IO_ERROR; 576 return 0; 577 } 578 579 static int isprime(void *a, int b __unused, int *c) 580 { 581 int res = mbedtls_mpi_is_prime(a, rng_read, NULL); 582 583 if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED) 584 return CRYPT_MEM; 585 586 if (res) 587 *c = LTC_MP_NO; 588 else 589 *c = LTC_MP_YES; 590 591 return CRYPT_OK; 592 } 593 594 static int mpa_rand(void *a, int size) 595 { 596 if (mbedtls_mpi_fill_random(a, size, rng_read, NULL)) 597 return CRYPT_MEM; 598 599 return CRYPT_OK; 600 } 601 602 ltc_math_descriptor ltc_mp = { 603 .name = "MPI", 604 .bits_per_digit = sizeof(mbedtls_mpi_uint) * 8, 605 606 .init = &init, 607 .init_size = &init_size, 608 .init_copy = &init_copy, 609 .deinit = &deinit, 610 611 .neg = &neg, 612 .copy = ©, 613 614 .set_int = &set_int, 615 .get_int = &get_int, 616 .get_digit = &get_digit, 617 .get_digit_count = &get_digit_count, 618 .compare = &compare, 619 .compare_d = &compare_d, 620 .count_bits = &count_bits, 621 .count_lsb_bits = &count_lsb_bits, 622 .twoexpt = &twoexpt, 623 624 .read_radix = &read_radix, 625 .write_radix = &write_radix, 626 .unsigned_size = &unsigned_size, 627 .unsigned_write = &unsigned_write, 628 .unsigned_read = &unsigned_read, 629 630 .add = &add, 631 .addi = &addi, 632 .sub = &sub, 633 .subi = &subi, 634 .mul = &mul, 635 .muli = &muli, 636 .sqr = &sqr, 637 .mpdiv = ÷, 638 .div_2 = &div_2, 639 .modi = &modi, 640 .gcd = &gcd, 641 .lcm = &lcm, 642 643 .mod = &mod, 644 .mulmod = &mulmod, 645 .sqrmod = &sqrmod, 646 .invmod = &invmod, 647 648 .montgomery_setup = &montgomery_setup, 649 .montgomery_normalization = &montgomery_normalization, 650 .montgomery_reduce = &montgomery_reduce, 651 .montgomery_deinit = &montgomery_deinit, 652 653 .exptmod = &exptmod, 654 .isprime = &isprime, 655 656 #ifdef LTC_MECC 657 #ifdef LTC_MECC_FP 658 .ecc_ptmul = <c_ecc_fp_mulmod, 659 #else 660 .ecc_ptmul = <c_ecc_mulmod, 661 #endif /* LTC_MECC_FP */ 662 .ecc_ptadd = <c_ecc_projective_add_point, 663 .ecc_ptdbl = <c_ecc_projective_dbl_point, 664 .ecc_map = <c_ecc_map, 665 #ifdef LTC_ECC_SHAMIR 666 #ifdef LTC_MECC_FP 667 .ecc_mul2add = <c_ecc_fp_mul2add, 668 #else 669 .ecc_mul2add = <c_ecc_mul2add, 670 #endif /* LTC_MECC_FP */ 671 #endif /* LTC_ECC_SHAMIR */ 672 #endif /* LTC_MECC */ 673 674 #ifdef LTC_MRSA 675 .rsa_keygen = &rsa_make_key, 676 .rsa_me = &rsa_exptmod, 677 #endif 678 .rand = &mpa_rand, 679 680 }; 681 682 size_t crypto_bignum_num_bytes(struct bignum *a) 683 { 684 return mbedtls_mpi_size((mbedtls_mpi *)a); 685 } 686 687 size_t crypto_bignum_num_bits(struct bignum *a) 688 { 689 return mbedtls_mpi_bitlen((mbedtls_mpi *)a); 690 } 691 692 int32_t crypto_bignum_compare(struct bignum *a, struct bignum *b) 693 { 694 return mbedtls_mpi_cmp_mpi((mbedtls_mpi *)a, (mbedtls_mpi *)b); 695 } 696 697 void crypto_bignum_bn2bin(const struct bignum *from, uint8_t *to) 698 { 699 mbedtls_mpi_write_binary((const mbedtls_mpi *)from, (void *)to, 700 mbedtls_mpi_size((const mbedtls_mpi *)from)); 701 } 702 703 TEE_Result crypto_bignum_bin2bn(const uint8_t *from, size_t fromsize, 704 struct bignum *to) 705 { 706 if (mbedtls_mpi_read_binary((mbedtls_mpi *)to, (const void *)from, 707 fromsize)) 708 return TEE_ERROR_BAD_PARAMETERS; 709 return TEE_SUCCESS; 710 } 711 712 void crypto_bignum_copy(struct bignum *to, const struct bignum *from) 713 { 714 mbedtls_mpi_copy((mbedtls_mpi *)to, (const mbedtls_mpi *)from); 715 } 716 717 struct bignum *crypto_bignum_allocate(size_t size_bits __unused) 718 { 719 mbedtls_mpi *bn = malloc(sizeof(*bn)); 720 721 if (bn) 722 mbedtls_mpi_init(bn); 723 724 return (struct bignum *)bn; 725 } 726 727 void crypto_bignum_free(struct bignum *s) 728 { 729 mbedtls_mpi_free((mbedtls_mpi *)s); 730 free(s); 731 } 732 733 void crypto_bignum_clear(struct bignum *s) 734 { 735 mbedtls_mpi *bn = (mbedtls_mpi *)s; 736 737 bn->s = 1; 738 if (bn->p) 739 memset(bn->p, 0, sizeof(*bn->p) * bn->n); 740 } 741