xref: /optee_os/core/lib/libtomcrypt/mpi_desc.c (revision e1b46449b6be0e20dc33030df1a31b23b1397d6f)
1 // SPDX-License-Identifier: BSD-2-Clause
2 /*
3  * Copyright (c) 2018, Linaro Limited
4  */
5 
6 #include <crypto/crypto.h>
7 #include <kernel/panic.h>
8 #include <mbedtls/bignum.h>
9 #include <mempool.h>
10 #include <stdlib.h>
11 #include <string.h>
12 #include <tomcrypt_private.h>
13 #include <tomcrypt_mp.h>
14 #include <util.h>
15 
16 #if defined(_CFG_CORE_LTC_PAGER)
17 #include <mm/core_mmu.h>
18 #include <mm/tee_pager.h>
19 #endif
20 
/*
 * Size of the scratch memory pool backing all MPI temporaries.
 * Size needed for xtest to pass reliably on both ARM32 and ARM64.
 */
#define MPI_MEMPOOL_SIZE	(42 * 1024)

/* Limb helpers, taken from mbedtls/library/bignum.c */
#define ciL		(sizeof(mbedtls_mpi_uint))	/* chars in limb */
#define biL		(ciL * 8)			/* bits in limb */
/* Number of limbs needed to hold i bits (rounded up) */
#define BITS_TO_LIMBS(i)	(((i) / biL) + (((i) % biL) ? 1 : 0))
28 
29 #if defined(_CFG_CORE_LTC_PAGER)
30 /* allocate pageable_zi vmem for mp scratch memory pool */
31 static struct mempool *get_mp_scratch_memory_pool(void)
32 {
33 	size_t size;
34 	void *data;
35 
36 	size = ROUNDUP(MPI_MEMPOOL_SIZE, SMALL_PAGE_SIZE);
37 	data = tee_pager_alloc(size);
38 	if (!data)
39 		panic();
40 
41 	return mempool_alloc_pool(data, size, tee_pager_release_phys);
42 }
43 #else /* _CFG_CORE_LTC_PAGER */
44 static struct mempool *get_mp_scratch_memory_pool(void)
45 {
46 	static uint8_t data[MPI_MEMPOOL_SIZE] __aligned(MEMPOOL_ALIGN);
47 
48 	return mempool_alloc_pool(data, sizeof(data), NULL);
49 }
50 #endif
51 
52 void init_mp_tomcrypt(void)
53 {
54 	struct mempool *p = get_mp_scratch_memory_pool();
55 
56 	if (!p)
57 		panic();
58 	mbedtls_mpi_mempool = p;
59 	assert(!mempool_default);
60 	mempool_default = p;
61 }
62 
63 static int init(void **a)
64 {
65 	mbedtls_mpi *bn = mempool_alloc(mbedtls_mpi_mempool, sizeof(*bn));
66 
67 	if (!bn)
68 		return CRYPT_MEM;
69 
70 	mbedtls_mpi_init_mempool(bn);
71 	*a = bn;
72 	return CRYPT_OK;
73 }
74 
75 static int init_size(int size_bits __unused, void **a)
76 {
77 	return init(a);
78 }
79 
80 static void deinit(void *a)
81 {
82 	mbedtls_mpi_free((mbedtls_mpi *)a);
83 	mempool_free(mbedtls_mpi_mempool, a);
84 }
85 
86 static int neg(void *a, void *b)
87 {
88 	if (mbedtls_mpi_copy(b, a))
89 		return CRYPT_MEM;
90 	((mbedtls_mpi *)b)->s *= -1;
91 	return CRYPT_OK;
92 }
93 
94 static int copy(void *a, void *b)
95 {
96 	if (mbedtls_mpi_copy(b, a))
97 		return CRYPT_MEM;
98 	return CRYPT_OK;
99 }
100 
101 static int init_copy(void **a, void *b)
102 {
103 	if (init(a) != CRYPT_OK) {
104 		return CRYPT_MEM;
105 	}
106 	return copy(b, *a);
107 }
108 
109 /* ---- trivial ---- */
110 static int set_int(void *a, ltc_mp_digit b)
111 {
112 	uint32_t b32 = b;
113 
114 	if (b32 != b)
115 		return CRYPT_INVALID_ARG;
116 
117 	mbedtls_mpi_uint p = b32;
118 	mbedtls_mpi bn = { .s = 1, .n = 1, .p = &p };
119 
120 	if (mbedtls_mpi_copy(a, &bn))
121 		return CRYPT_MEM;
122 	return CRYPT_OK;
123 }
124 
125 static unsigned long get_int(void *a)
126 {
127 	mbedtls_mpi *bn = a;
128 
129 	if (!bn->n)
130 		return 0;
131 
132 	return bn->p[bn->n - 1];
133 }
134 
135 static ltc_mp_digit get_digit(void *a, int n)
136 {
137 	mbedtls_mpi *bn = a;
138 
139 	COMPILE_TIME_ASSERT(sizeof(ltc_mp_digit) >= sizeof(mbedtls_mpi_uint));
140 
141 	if (n < 0 || (size_t)n >= bn->n)
142 		return 0;
143 
144 	return bn->p[n];
145 }
146 
147 static int get_digit_count(void *a)
148 {
149 	return ROUNDUP(mbedtls_mpi_size(a), sizeof(mbedtls_mpi_uint)) /
150 	       sizeof(mbedtls_mpi_uint);
151 }
152 
153 static int compare(void *a, void *b)
154 {
155 	int ret = mbedtls_mpi_cmp_mpi(a, b);
156 
157 	if (ret < 0)
158 		return LTC_MP_LT;
159 
160 	if (ret > 0)
161 		return LTC_MP_GT;
162 
163 	return LTC_MP_EQ;
164 }
165 
166 static int compare_d(void *a, ltc_mp_digit b)
167 {
168 	unsigned long v = b;
169 	unsigned int shift = 31;
170 	uint32_t mask = BIT(shift) - 1;
171 	mbedtls_mpi bn;
172 
173 	mbedtls_mpi_init_mempool(&bn);
174 	while (true) {
175 		mbedtls_mpi_add_int(&bn, &bn, v & mask);
176 		v >>= shift;
177 		if (!v)
178 			break;
179 		mbedtls_mpi_shift_l(&bn, shift);
180 	}
181 
182 	int ret = compare(a, &bn);
183 
184 	mbedtls_mpi_free(&bn);
185 
186 	return ret;
187 }
188 
/* Number of significant bits in |a| (0 for zero). */
static int count_bits(void *a)
{
	const size_t nbits = mbedtls_mpi_bitlen(a);

	return nbits;
}
193 
/* Number of zero bits below the least significant set bit of a. */
static int count_lsb_bits(void *a)
{
	const size_t nzeros = mbedtls_mpi_lsb(a);

	return nzeros;
}
198 
199 
200 static int twoexpt(void *a, int n)
201 {
202 	if (mbedtls_mpi_set_bit(a, n, 1))
203 		return CRYPT_MEM;
204 
205 	return CRYPT_OK;
206 }
207 
208 /* ---- conversions ---- */
209 
210 /* read ascii string */
211 static int read_radix(void *a, const char *b, int radix)
212 {
213 	int res = mbedtls_mpi_read_string(a, radix, b);
214 
215 	if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED)
216 		return CRYPT_MEM;
217 	if (res)
218 		return CRYPT_ERROR;
219 
220 	return CRYPT_OK;
221 }
222 
223 /* write one */
224 static int write_radix(void *a, char *b, int radix)
225 {
226 	size_t ol = SIZE_MAX;
227 	int res = mbedtls_mpi_write_string(a, radix, b, ol, &ol);
228 
229 	if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED)
230 		return CRYPT_MEM;
231 	if (res)
232 		return CRYPT_ERROR;
233 
234 	return CRYPT_OK;
235 }
236 
/* Size in bytes of |a| encoded as an unsigned big-endian byte string. */
static unsigned long unsigned_size(void *a)
{
	const size_t nbytes = mbedtls_mpi_size(a);

	return nbytes;
}
242 
243 /* store */
244 static int unsigned_write(void *a, unsigned char *b)
245 {
246 	int res = mbedtls_mpi_write_binary(a, b, unsigned_size(a));
247 
248 	if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED)
249 		return CRYPT_MEM;
250 	if (res)
251 		return CRYPT_ERROR;
252 
253 	return CRYPT_OK;
254 }
255 
256 /* read */
257 static int unsigned_read(void *a, unsigned char *b, unsigned long len)
258 {
259 	int res = mbedtls_mpi_read_binary(a, b, len);
260 
261 	if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED)
262 		return CRYPT_MEM;
263 	if (res)
264 		return CRYPT_ERROR;
265 
266 	return CRYPT_OK;
267 }
268 
269 /* add */
270 static int add(void *a, void *b, void *c)
271 {
272 	if (mbedtls_mpi_add_mpi(c, a, b))
273 		return CRYPT_MEM;
274 
275 	return CRYPT_OK;
276 }
277 
278 static int addi(void *a, ltc_mp_digit b, void *c)
279 {
280 	uint32_t b32 = b;
281 
282 	if (b32 != b)
283 		return CRYPT_INVALID_ARG;
284 
285 	mbedtls_mpi_uint p = b32;
286 	mbedtls_mpi bn = { .s = 1, .n = 1, .p = &p };
287 
288 	return add(a, &bn, c);
289 }
290 
291 /* sub */
292 static int sub(void *a, void *b, void *c)
293 {
294 	if (mbedtls_mpi_sub_mpi(c, a, b))
295 		return CRYPT_MEM;
296 
297 	return CRYPT_OK;
298 }
299 
300 static int subi(void *a, ltc_mp_digit b, void *c)
301 {
302 	uint32_t b32 = b;
303 
304 	if (b32 != b)
305 		return CRYPT_INVALID_ARG;
306 
307 	mbedtls_mpi_uint p = b32;
308 	mbedtls_mpi bn = { .s = 1, .n = 1, .p = &p };
309 
310 	return sub(a, &bn, c);
311 }
312 
313 /* mul */
314 static int mul(void *a, void *b, void *c)
315 {
316 	if (mbedtls_mpi_mul_mpi(c, a, b))
317 		return CRYPT_MEM;
318 
319 	return CRYPT_OK;
320 }
321 
322 static int muli(void *a, ltc_mp_digit b, void *c)
323 {
324 	if (b > (unsigned long) UINT32_MAX)
325 		return CRYPT_INVALID_ARG;
326 
327 	if (mbedtls_mpi_mul_int(c, a, b))
328 		return CRYPT_MEM;
329 
330 	return CRYPT_OK;
331 }
332 
/* b = a * a, delegated to mul(). */
static int sqr(void *a, void *b)
{
	void *factor = a;

	return mul(factor, factor, b);
}
338 
339 /* div */
340 static int divide(void *a, void *b, void *c, void *d)
341 {
342 	int res = mbedtls_mpi_div_mpi(c, d, a, b);
343 
344 	if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED)
345 		return CRYPT_MEM;
346 	if (res)
347 		return CRYPT_ERROR;
348 
349 	return CRYPT_OK;
350 }
351 
352 static int div_2(void *a, void *b)
353 {
354 	if (mbedtls_mpi_copy(b, a))
355 		return CRYPT_MEM;
356 
357 	if (mbedtls_mpi_shift_r(b, 1))
358 		return CRYPT_MEM;
359 
360 	return CRYPT_OK;
361 }
362 
363 /* modi */
364 static int modi(void *a, ltc_mp_digit b, ltc_mp_digit *c)
365 {
366 	mbedtls_mpi bn_b;
367 	mbedtls_mpi bn_c;
368 	int res = 0;
369 
370 	mbedtls_mpi_init_mempool(&bn_b);
371 	mbedtls_mpi_init_mempool(&bn_c);
372 
373 	res = set_int(&bn_b, b);
374 	if (res)
375 		return res;
376 
377 	res = mbedtls_mpi_mod_mpi(&bn_c, &bn_b, a);
378 	if (!res)
379 		*c = get_int(&bn_c);
380 
381 	mbedtls_mpi_free(&bn_b);
382 	mbedtls_mpi_free(&bn_c);
383 
384 	if (res)
385 		return CRYPT_MEM;
386 
387 	return CRYPT_OK;
388 }
389 
390 /* gcd */
391 static int gcd(void *a, void *b, void *c)
392 {
393 	if (mbedtls_mpi_gcd(c, a, b))
394 		return CRYPT_MEM;
395 
396 	return CRYPT_OK;
397 }
398 
399 /* lcm */
400 static int lcm(void *a, void *b, void *c)
401 {
402 	int res = CRYPT_MEM;
403 	mbedtls_mpi tmp;
404 
405 	mbedtls_mpi_init_mempool(&tmp);
406 	if (mbedtls_mpi_mul_mpi(&tmp, a, b))
407 		goto out;
408 
409 	if (mbedtls_mpi_gcd(c, a, b))
410 		goto out;
411 
412 	/* We use the following equality: gcd(a, b) * lcm(a, b) = a * b */
413 	res = divide(&tmp, c, c, NULL);
414 out:
415 	mbedtls_mpi_free(&tmp);
416 	return res;
417 }
418 
419 static int mod(void *a, void *b, void *c)
420 {
421 	int res = mbedtls_mpi_mod_mpi(c, a, b);
422 
423 	if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED)
424 		return CRYPT_MEM;
425 	if (res)
426 		return CRYPT_ERROR;
427 
428 	return CRYPT_OK;
429 }
430 
/* d = (a + b) mod c. Returns the first failing step's error, else CRYPT_OK. */
static int addmod(void *a, void *b, void *c, void *d)
{
	const int res = add(a, b, d);

	return res ? res : mod(d, c, d);
}
440 
/* d = (a - b) mod c. Returns the first failing step's error, else CRYPT_OK. */
static int submod(void *a, void *b, void *c, void *d)
{
	const int res = sub(a, b, d);

	return res ? res : mod(d, c, d);
}
450 
451 static int mulmod(void *a, void *b, void *c, void *d)
452 {
453 	int res;
454 	mbedtls_mpi ta;
455 	mbedtls_mpi tb;
456 
457 	mbedtls_mpi_init_mempool(&ta);
458 	mbedtls_mpi_init_mempool(&tb);
459 
460 	res = mod(a, c, &ta);
461 	if (res)
462 		goto out;
463 	res = mod(b, c, &tb);
464 	if (res)
465 		goto out;
466 	res = mul(&ta, &tb, d);
467 	if (res)
468 		goto out;
469 	res = mod(d, c, d);
470 out:
471 	mbedtls_mpi_free(&ta);
472 	mbedtls_mpi_free(&tb);
473 	return res;
474 }
475 
/* c = (a * a) mod b, delegated to mulmod(). */
static int sqrmod(void *a, void *b, void *c)
{
	void *base = a;

	return mulmod(base, base, b, c);
}
480 
481 /* invmod */
482 static int invmod(void *a, void *b, void *c)
483 {
484 	int res = mbedtls_mpi_inv_mod(c, a, b);
485 
486 	if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED)
487 		return CRYPT_MEM;
488 	if (res)
489 		return CRYPT_ERROR;
490 
491 	return CRYPT_OK;
492 }
493 
494 
495 /* setup */
496 static int montgomery_setup(void *a, void **b)
497 {
498 	*b = malloc(sizeof(mbedtls_mpi_uint));
499 	if (!*b)
500 		return CRYPT_MEM;
501 
502 	mbedtls_mpi_montg_init(*b, a);
503 
504 	return CRYPT_OK;
505 }
506 
507 /* get normalization value */
508 static int montgomery_normalization(void *a, void *b)
509 {
510 	size_t c = ROUNDUP(mbedtls_mpi_size(b), sizeof(mbedtls_mpi_uint)) * 8;
511 
512 	if (mbedtls_mpi_lset(a, 1))
513 		return CRYPT_MEM;
514 	if (mbedtls_mpi_shift_l(a, c))
515 		return CRYPT_MEM;
516 	if (mbedtls_mpi_mod_mpi(a, a, b))
517 		return CRYPT_MEM;
518 
519 	return CRYPT_OK;
520 }
521 
522 /* reduce */
523 static int montgomery_reduce(void *a, void *b, void *c)
524 {
525 	mbedtls_mpi A;
526 	mbedtls_mpi *N = b;
527 	mbedtls_mpi_uint *mm = c;
528 	mbedtls_mpi T;
529 	int ret = CRYPT_MEM;
530 
531 	mbedtls_mpi_init_mempool(&T);
532 	mbedtls_mpi_init_mempool(&A);
533 
534 	if (mbedtls_mpi_grow(&T, (N->n + 1) * 2))
535 		goto out;
536 
537 	if (mbedtls_mpi_cmp_mpi(a, N) > 0) {
538 		if (mbedtls_mpi_mod_mpi(&A, a, N))
539 			goto out;
540 	} else {
541 		if (mbedtls_mpi_copy(&A, a))
542 			goto out;
543 	}
544 
545 	if (mbedtls_mpi_grow(&A, N->n + 1))
546 		goto out;
547 
548 	if (mbedtls_mpi_montred(&A, N, *mm, &T))
549 		goto out;
550 
551 	if (mbedtls_mpi_copy(a, &A))
552 		goto out;
553 
554 	ret = CRYPT_OK;
555 out:
556 	mbedtls_mpi_free(&A);
557 	mbedtls_mpi_free(&T);
558 
559 	return ret;
560 }
561 
562 /* clean up */
563 static void montgomery_deinit(void *a)
564 {
565 	free(a);
566 }
567 
568 /*
569  * This function calculates:
570  *  d = a^b mod c
571  *
572  * @a: base
573  * @b: exponent
574  * @c: modulus
575  * @d: destination
576  */
577 static int exptmod(void *a, void *b, void *c, void *d)
578 {
579 	int res;
580 
581 	if (d == a || d == b || d == c) {
582 		mbedtls_mpi dest;
583 
584 		mbedtls_mpi_init_mempool(&dest);
585 		res = mbedtls_mpi_exp_mod(&dest, a, b, c, NULL);
586 		if (!res)
587 			res = mbedtls_mpi_copy(d, &dest);
588 		mbedtls_mpi_free(&dest);
589 	} else {
590 		res = mbedtls_mpi_exp_mod(d, a, b, c, NULL);
591 	}
592 
593 	if (res)
594 		return CRYPT_MEM;
595 	else
596 		return CRYPT_OK;
597 }
598 
599 static int rng_read(void *ignored __unused, unsigned char *buf, size_t blen)
600 {
601 	if (crypto_rng_read(buf, blen))
602 		return MBEDTLS_ERR_MPI_FILE_IO_ERROR;
603 	return 0;
604 }
605 
606 static int isprime(void *a, int b __unused, int *c)
607 {
608 	int res = mbedtls_mpi_is_prime(a, rng_read, NULL);
609 
610 	if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED)
611 		return CRYPT_MEM;
612 
613 	if (res)
614 		*c = LTC_MP_NO;
615 	else
616 		*c = LTC_MP_YES;
617 
618 	return CRYPT_OK;
619 }
620 
621 static int mpi_rand(void *a, int size)
622 {
623 	if (mbedtls_mpi_fill_random(a, size, rng_read, NULL))
624 		return CRYPT_MEM;
625 
626 	return CRYPT_OK;
627 }
628 
629 ltc_math_descriptor ltc_mp = {
630 	.name = "MPI",
631 	.bits_per_digit = sizeof(mbedtls_mpi_uint) * 8,
632 
633 	.init = init,
634 	.init_size = init_size,
635 	.init_copy = init_copy,
636 	.deinit = deinit,
637 
638 	.neg = neg,
639 	.copy = copy,
640 
641 	.set_int = set_int,
642 	.get_int = get_int,
643 	.get_digit = get_digit,
644 	.get_digit_count = get_digit_count,
645 	.compare = compare,
646 	.compare_d = compare_d,
647 	.count_bits = count_bits,
648 	.count_lsb_bits = count_lsb_bits,
649 	.twoexpt = twoexpt,
650 
651 	.read_radix = read_radix,
652 	.write_radix = write_radix,
653 	.unsigned_size = unsigned_size,
654 	.unsigned_write = unsigned_write,
655 	.unsigned_read = unsigned_read,
656 
657 	.add = add,
658 	.addi = addi,
659 	.sub = sub,
660 	.subi = subi,
661 	.mul = mul,
662 	.muli = muli,
663 	.sqr = sqr,
664 	.mpdiv = divide,
665 	.div_2 = div_2,
666 	.modi = modi,
667 	.gcd = gcd,
668 	.lcm = lcm,
669 
670 	.mulmod = mulmod,
671 	.sqrmod = sqrmod,
672 	.invmod = invmod,
673 
674 	.montgomery_setup = montgomery_setup,
675 	.montgomery_normalization = montgomery_normalization,
676 	.montgomery_reduce = montgomery_reduce,
677 	.montgomery_deinit = montgomery_deinit,
678 
679 	.exptmod = exptmod,
680 	.isprime = isprime,
681 
682 #ifdef LTC_MECC
683 #ifdef LTC_MECC_FP
684 	.ecc_ptmul = ltc_ecc_fp_mulmod,
685 #else
686 	.ecc_ptmul = ltc_ecc_mulmod,
687 #endif /* LTC_MECC_FP */
688 	.ecc_ptadd = ltc_ecc_projective_add_point,
689 	.ecc_ptdbl = ltc_ecc_projective_dbl_point,
690 	.ecc_map = ltc_ecc_map,
691 #ifdef LTC_ECC_SHAMIR
692 #ifdef LTC_MECC_FP
693 	.ecc_mul2add = ltc_ecc_fp_mul2add,
694 #else
695 	.ecc_mul2add = ltc_ecc_mul2add,
696 #endif /* LTC_MECC_FP */
697 #endif /* LTC_ECC_SHAMIR */
698 #endif /* LTC_MECC */
699 
700 #ifdef LTC_MRSA
701 	.rsa_keygen = rsa_make_key,
702 	.rsa_me = rsa_exptmod,
703 #endif
704 	.addmod = addmod,
705 	.submod = submod,
706 	.rand = mpi_rand,
707 
708 };
709 
710 size_t crypto_bignum_num_bytes(struct bignum *a)
711 {
712 	return mbedtls_mpi_size((mbedtls_mpi *)a);
713 }
714 
715 size_t crypto_bignum_num_bits(struct bignum *a)
716 {
717 	return mbedtls_mpi_bitlen((mbedtls_mpi *)a);
718 }
719 
720 int32_t crypto_bignum_compare(struct bignum *a, struct bignum *b)
721 {
722 	return mbedtls_mpi_cmp_mpi((mbedtls_mpi *)a, (mbedtls_mpi *)b);
723 }
724 
725 void crypto_bignum_bn2bin(const struct bignum *from, uint8_t *to)
726 {
727 	const mbedtls_mpi *f = (const mbedtls_mpi *)from;
728 	int rc __maybe_unused = 0;
729 
730 	rc = mbedtls_mpi_write_binary(f, (void *)to, mbedtls_mpi_size(f));
731 	assert(!rc);
732 }
733 
734 TEE_Result crypto_bignum_bin2bn(const uint8_t *from, size_t fromsize,
735 			 struct bignum *to)
736 {
737 	if (mbedtls_mpi_read_binary((mbedtls_mpi *)to, (const void *)from,
738 				    fromsize))
739 		return TEE_ERROR_BAD_PARAMETERS;
740 	return TEE_SUCCESS;
741 }
742 
743 void crypto_bignum_copy(struct bignum *to, const struct bignum *from)
744 {
745 	mbedtls_mpi_copy((mbedtls_mpi *)to, (const mbedtls_mpi *)from);
746 }
747 
748 struct bignum *crypto_bignum_allocate(size_t size_bits)
749 {
750 	mbedtls_mpi *bn = malloc(sizeof(*bn));
751 
752 	if (!bn)
753 		return NULL;
754 
755 	mbedtls_mpi_init(bn);
756 	if (mbedtls_mpi_grow(bn, BITS_TO_LIMBS(size_bits))) {
757 		free(bn);
758 		return NULL;
759 	}
760 
761 	return (struct bignum *)bn;
762 }
763 
764 void crypto_bignum_free(struct bignum *s)
765 {
766 	mbedtls_mpi_free((mbedtls_mpi *)s);
767 	free(s);
768 }
769 
770 void crypto_bignum_clear(struct bignum *s)
771 {
772 	mbedtls_mpi *bn = (mbedtls_mpi *)s;
773 
774 	bn->s = 1;
775 	if (bn->p)
776 		memset(bn->p, 0, sizeof(*bn->p) * bn->n);
777 }
778