xref: /optee_os/core/lib/libtomcrypt/mpi_desc.c (revision 4b5c81cc18db44f317d1b67646c3efb32153133c)
1 // SPDX-License-Identifier: BSD-2-Clause
2 /*
3  * Copyright (c) 2018, Linaro Limited
4  */
5 
6 #include <crypto/crypto.h>
7 #include <kernel/panic.h>
8 #include <mbedtls/bignum.h>
9 #include <mempool.h>
10 #include <stdlib.h>
11 #include <string.h>
12 #include <tomcrypt.h>
13 #include <tomcrypt_mp.h>
14 #include <util.h>
15 
16 #if defined(_CFG_CORE_LTC_PAGER)
17 #include <mm/core_mmu.h>
18 #include <mm/tee_pager.h>
19 #endif
20 
/*
 * Size of the scratch memory pool that backs every MPI allocated through
 * this descriptor. Chosen empirically so that xtest passes reliably on
 * both ARM32 and ARM64.
 */
#define MPI_MEMPOOL_SIZE	(42 * 1024)

/* Limb helpers, mirroring the definitions in mbedtls/library/bignum.c */
#define ciL		(sizeof(mbedtls_mpi_uint))	/* bytes per limb */
#define biL		(ciL << 3)			/* bits per limb */
#define BITS_TO_LIMBS(i)	(((i) / biL) + (((i) % biL) != 0))
28 
29 #if defined(_CFG_CORE_LTC_PAGER)
30 /* allocate pageable_zi vmem for mp scratch memory pool */
31 static struct mempool *get_mp_scratch_memory_pool(void)
32 {
33 	size_t size;
34 	void *data;
35 
36 	size = ROUNDUP(MPI_MEMPOOL_SIZE, SMALL_PAGE_SIZE);
37 	data = tee_pager_alloc(size, 0);
38 	if (!data)
39 		panic();
40 
41 	return mempool_alloc_pool(data, size, tee_pager_release_phys);
42 }
43 #else /* _CFG_CORE_LTC_PAGER */
44 static struct mempool *get_mp_scratch_memory_pool(void)
45 {
46 	static uint8_t data[MPI_MEMPOOL_SIZE] __aligned(MEMPOOL_ALIGN);
47 
48 	return mempool_alloc_pool(data, sizeof(data), NULL);
49 }
50 #endif
51 
52 void init_mp_tomcrypt(void)
53 {
54 	struct mempool *p = get_mp_scratch_memory_pool();
55 
56 	if (!p)
57 		panic();
58 	mbedtls_mpi_mempool = p;
59 }
60 
61 static int init(void **a)
62 {
63 	mbedtls_mpi *bn = mempool_alloc(mbedtls_mpi_mempool, sizeof(*bn));
64 
65 	if (!bn)
66 		return CRYPT_MEM;
67 
68 	mbedtls_mpi_init_mempool(bn);
69 	*a = bn;
70 	return CRYPT_OK;
71 }
72 
73 static int init_size(int size_bits __unused, void **a)
74 {
75 	return init(a);
76 }
77 
78 static void deinit(void *a)
79 {
80 	mbedtls_mpi_free((mbedtls_mpi *)a);
81 	mempool_free(mbedtls_mpi_mempool, a);
82 }
83 
84 static int neg(void *a, void *b)
85 {
86 	if (mbedtls_mpi_copy(b, a))
87 		return CRYPT_MEM;
88 	((mbedtls_mpi *)b)->s *= -1;
89 	return CRYPT_OK;
90 }
91 
92 static int copy(void *a, void *b)
93 {
94 	if (mbedtls_mpi_copy(b, a))
95 		return CRYPT_MEM;
96 	return CRYPT_OK;
97 }
98 
99 static int init_copy(void **a, void *b)
100 {
101 	if (init(a) != CRYPT_OK) {
102 		return CRYPT_MEM;
103 	}
104 	return copy(b, *a);
105 }
106 
107 /* ---- trivial ---- */
108 static int set_int(void *a, unsigned long b)
109 {
110 	uint32_t b32 = b;
111 
112 	if (b32 != b)
113 		return CRYPT_INVALID_ARG;
114 
115 	mbedtls_mpi_uint p = b32;
116 	mbedtls_mpi bn = { .s = 1, .n = 1, .p = &p };
117 
118 	if (mbedtls_mpi_copy(a, &bn))
119 		return CRYPT_MEM;
120 	return CRYPT_OK;
121 }
122 
123 static unsigned long get_int(void *a)
124 {
125 	mbedtls_mpi *bn = a;
126 
127 	if (!bn->n)
128 		return 0;
129 
130 	return bn->p[bn->n - 1];
131 }
132 
133 static ltc_mp_digit get_digit(void *a, int n)
134 {
135 	mbedtls_mpi *bn = a;
136 
137 	COMPILE_TIME_ASSERT(sizeof(ltc_mp_digit) >= sizeof(mbedtls_mpi_uint));
138 
139 	if (n < 0 || (size_t)n >= bn->n)
140 		return 0;
141 
142 	return bn->p[n];
143 }
144 
145 static int get_digit_count(void *a)
146 {
147 	return ROUNDUP(mbedtls_mpi_size(a), sizeof(mbedtls_mpi_uint)) /
148 	       sizeof(mbedtls_mpi_uint);
149 }
150 
151 static int compare(void *a, void *b)
152 {
153 	int ret = mbedtls_mpi_cmp_mpi(a, b);
154 
155 	if (ret < 0)
156 		return LTC_MP_LT;
157 
158 	if (ret > 0)
159 		return LTC_MP_GT;
160 
161 	return LTC_MP_EQ;
162 }
163 
164 static int compare_d(void *a, unsigned long b)
165 {
166 	unsigned long v = b;
167 	unsigned int shift = 31;
168 	uint32_t mask = BIT(shift) - 1;
169 	mbedtls_mpi bn;
170 
171 	mbedtls_mpi_init_mempool(&bn);
172 	while (true) {
173 		mbedtls_mpi_add_int(&bn, &bn, v & mask);
174 		v >>= shift;
175 		if (!v)
176 			break;
177 		mbedtls_mpi_shift_l(&bn, shift);
178 	}
179 
180 	int ret = compare(a, &bn);
181 
182 	mbedtls_mpi_free(&bn);
183 
184 	return ret;
185 }
186 
187 static int count_bits(void *a)
188 {
189 	return mbedtls_mpi_bitlen(a);
190 }
191 
192 static int count_lsb_bits(void *a)
193 {
194 	return mbedtls_mpi_lsb(a);
195 }
196 
197 
198 static int twoexpt(void *a, int n)
199 {
200 	if (mbedtls_mpi_set_bit(a, n, 1))
201 		return CRYPT_MEM;
202 
203 	return CRYPT_OK;
204 }
205 
206 /* ---- conversions ---- */
207 
208 /* read ascii string */
209 static int read_radix(void *a, const char *b, int radix)
210 {
211 	int res = mbedtls_mpi_read_string(a, radix, b);
212 
213 	if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED)
214 		return CRYPT_MEM;
215 	if (res)
216 		return CRYPT_ERROR;
217 
218 	return CRYPT_OK;
219 }
220 
221 /* write one */
222 static int write_radix(void *a, char *b, int radix)
223 {
224 	size_t ol = SIZE_MAX;
225 	int res = mbedtls_mpi_write_string(a, radix, b, ol, &ol);
226 
227 	if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED)
228 		return CRYPT_MEM;
229 	if (res)
230 		return CRYPT_ERROR;
231 
232 	return CRYPT_OK;
233 }
234 
235 /* get size as unsigned char string */
236 static unsigned long unsigned_size(void *a)
237 {
238 	return mbedtls_mpi_size(a);
239 }
240 
241 /* store */
242 static int unsigned_write(void *a, unsigned char *b)
243 {
244 	int res = mbedtls_mpi_write_binary(a, b, unsigned_size(a));
245 
246 	if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED)
247 		return CRYPT_MEM;
248 	if (res)
249 		return CRYPT_ERROR;
250 
251 	return CRYPT_OK;
252 }
253 
254 /* read */
255 static int unsigned_read(void *a, unsigned char *b, unsigned long len)
256 {
257 	int res = mbedtls_mpi_read_binary(a, b, len);
258 
259 	if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED)
260 		return CRYPT_MEM;
261 	if (res)
262 		return CRYPT_ERROR;
263 
264 	return CRYPT_OK;
265 }
266 
267 /* add */
268 static int add(void *a, void *b, void *c)
269 {
270 	if (mbedtls_mpi_add_mpi(c, a, b))
271 		return CRYPT_MEM;
272 
273 	return CRYPT_OK;
274 }
275 
276 static int addi(void *a, unsigned long b, void *c)
277 {
278 	uint32_t b32 = b;
279 
280 	if (b32 != b)
281 		return CRYPT_INVALID_ARG;
282 
283 	mbedtls_mpi_uint p = b32;
284 	mbedtls_mpi bn = { .s = 1, .n = 1, .p = &p };
285 
286 	return add(a, &bn, c);
287 }
288 
289 /* sub */
290 static int sub(void *a, void *b, void *c)
291 {
292 	if (mbedtls_mpi_sub_mpi(c, a, b))
293 		return CRYPT_MEM;
294 
295 	return CRYPT_OK;
296 }
297 
298 static int subi(void *a, unsigned long b, void *c)
299 {
300 	uint32_t b32 = b;
301 
302 	if (b32 != b)
303 		return CRYPT_INVALID_ARG;
304 
305 	mbedtls_mpi_uint p = b32;
306 	mbedtls_mpi bn = { .s = 1, .n = 1, .p = &p };
307 
308 	return sub(a, &bn, c);
309 }
310 
311 /* mul */
312 static int mul(void *a, void *b, void *c)
313 {
314 	if (mbedtls_mpi_mul_mpi(c, a, b))
315 		return CRYPT_MEM;
316 
317 	return CRYPT_OK;
318 }
319 
320 static int muli(void *a, unsigned long b, void *c)
321 {
322 	if (b > (unsigned long) UINT32_MAX)
323 		return CRYPT_INVALID_ARG;
324 
325 	if (mbedtls_mpi_mul_int(c, a, b))
326 		return CRYPT_MEM;
327 
328 	return CRYPT_OK;
329 }
330 
331 /* sqr */
332 static int sqr(void *a, void *b)
333 {
334 	return mul(a, a, b);
335 }
336 
337 /* div */
338 static int divide(void *a, void *b, void *c, void *d)
339 {
340 	int res = mbedtls_mpi_div_mpi(c, d, a, b);
341 
342 	if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED)
343 		return CRYPT_MEM;
344 	if (res)
345 		return CRYPT_ERROR;
346 
347 	return CRYPT_OK;
348 }
349 
350 static int div_2(void *a, void *b)
351 {
352 	if (mbedtls_mpi_copy(b, a))
353 		return CRYPT_MEM;
354 
355 	if (mbedtls_mpi_shift_r(b, 1))
356 		return CRYPT_MEM;
357 
358 	return CRYPT_OK;
359 }
360 
361 /* modi */
362 static int modi(void *a, unsigned long b, unsigned long *c)
363 {
364 	mbedtls_mpi bn_b;
365 	mbedtls_mpi bn_c;
366 	int res = 0;
367 
368 	mbedtls_mpi_init_mempool(&bn_b);
369 	mbedtls_mpi_init_mempool(&bn_c);
370 
371 	res = set_int(&bn_b, b);
372 	if (res)
373 		return res;
374 
375 	res = mbedtls_mpi_mod_mpi(&bn_c, &bn_b, a);
376 	if (!res)
377 		*c = get_int(&bn_c);
378 
379 	mbedtls_mpi_free(&bn_b);
380 	mbedtls_mpi_free(&bn_c);
381 
382 	if (res)
383 		return CRYPT_MEM;
384 
385 	return CRYPT_OK;
386 }
387 
388 /* gcd */
389 static int gcd(void *a, void *b, void *c)
390 {
391 	if (mbedtls_mpi_gcd(c, a, b))
392 		return CRYPT_MEM;
393 
394 	return CRYPT_OK;
395 }
396 
397 /* lcm */
398 static int lcm(void *a, void *b, void *c)
399 {
400 	int res = CRYPT_MEM;
401 	mbedtls_mpi tmp;
402 
403 	mbedtls_mpi_init_mempool(&tmp);
404 	if (mbedtls_mpi_mul_mpi(&tmp, a, b))
405 		goto out;
406 
407 	if (mbedtls_mpi_gcd(c, a, b))
408 		goto out;
409 
410 	/* We use the following equality: gcd(a, b) * lcm(a, b) = a * b */
411 	res = divide(&tmp, c, c, NULL);
412 out:
413 	mbedtls_mpi_free(&tmp);
414 	return res;
415 }
416 
417 static int mod(void *a, void *b, void *c)
418 {
419 	int res = mbedtls_mpi_mod_mpi(c, a, b);
420 
421 	if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED)
422 		return CRYPT_MEM;
423 	if (res)
424 		return CRYPT_ERROR;
425 
426 	return CRYPT_OK;
427 }
428 
429 static int mulmod(void *a, void *b, void *c, void *d)
430 {
431 	int res;
432 	mbedtls_mpi ta;
433 	mbedtls_mpi tb;
434 
435 	mbedtls_mpi_init_mempool(&ta);
436 	mbedtls_mpi_init_mempool(&tb);
437 
438 	res = mod(a, c, &ta);
439 	if (res)
440 		goto out;
441 	res = mod(b, c, &tb);
442 	if (res)
443 		goto out;
444 	res = mul(&ta, &tb, d);
445 	if (res)
446 		goto out;
447 	res = mod(d, c, d);
448 out:
449 	mbedtls_mpi_free(&ta);
450 	mbedtls_mpi_free(&tb);
451 	return res;
452 }
453 
/* sqrmod: c = a^2 mod b, done as a modular multiplication of a by itself */
static int sqrmod(void *a, void *b, void *c)
{
	void *base = a;

	return mulmod(base, base, b, c);
}
458 
459 /* invmod */
460 static int invmod(void *a, void *b, void *c)
461 {
462 	int res = mbedtls_mpi_inv_mod(c, a, b);
463 
464 	if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED)
465 		return CRYPT_MEM;
466 	if (res)
467 		return CRYPT_ERROR;
468 
469 	return CRYPT_OK;
470 }
471 
472 
473 /* setup */
474 static int montgomery_setup(void *a, void **b)
475 {
476 	*b = malloc(sizeof(mbedtls_mpi_uint));
477 	if (!*b)
478 		return CRYPT_MEM;
479 
480 	mbedtls_mpi_montg_init(*b, a);
481 
482 	return CRYPT_OK;
483 }
484 
485 /* get normalization value */
486 static int montgomery_normalization(void *a, void *b)
487 {
488 	size_t c = ROUNDUP(mbedtls_mpi_size(b), sizeof(mbedtls_mpi_uint)) * 8;
489 
490 	if (mbedtls_mpi_lset(a, 1))
491 		return CRYPT_MEM;
492 	if (mbedtls_mpi_shift_l(a, c))
493 		return CRYPT_MEM;
494 	if (mbedtls_mpi_mod_mpi(a, a, b))
495 		return CRYPT_MEM;
496 
497 	return CRYPT_OK;
498 }
499 
500 /* reduce */
501 static int montgomery_reduce(void *a, void *b, void *c)
502 {
503 	mbedtls_mpi A;
504 	mbedtls_mpi *N = b;
505 	mbedtls_mpi_uint *mm = c;
506 	mbedtls_mpi T;
507 	int ret = CRYPT_MEM;
508 
509 	mbedtls_mpi_init_mempool(&T);
510 	mbedtls_mpi_init_mempool(&A);
511 
512 	if (mbedtls_mpi_grow(&T, (N->n + 1) * 2))
513 		goto out;
514 
515 	if (mbedtls_mpi_cmp_mpi(a, N) > 0) {
516 		if (mbedtls_mpi_mod_mpi(&A, a, N))
517 			goto out;
518 	} else {
519 		if (mbedtls_mpi_copy(&A, a))
520 			goto out;
521 	}
522 
523 	if (mbedtls_mpi_grow(&A, N->n + 1))
524 		goto out;
525 
526 	if (mbedtls_mpi_montred(&A, N, *mm, &T))
527 		goto out;
528 
529 	if (mbedtls_mpi_copy(a, &A))
530 		goto out;
531 
532 	ret = CRYPT_OK;
533 out:
534 	mbedtls_mpi_free(&A);
535 	mbedtls_mpi_free(&T);
536 
537 	return ret;
538 }
539 
540 /* clean up */
541 static void montgomery_deinit(void *a)
542 {
543 	free(a);
544 }
545 
546 /*
547  * This function calculates:
548  *  d = a^b mod c
549  *
550  * @a: base
551  * @b: exponent
552  * @c: modulus
553  * @d: destination
554  */
555 static int exptmod(void *a, void *b, void *c, void *d)
556 {
557 	int res;
558 
559 	if (d == a || d == b || d == c) {
560 		mbedtls_mpi dest;
561 
562 		mbedtls_mpi_init_mempool(&dest);
563 		res = mbedtls_mpi_exp_mod(&dest, a, b, c, NULL);
564 		if (!res)
565 			res = mbedtls_mpi_copy(d, &dest);
566 		mbedtls_mpi_free(&dest);
567 	} else {
568 		res = mbedtls_mpi_exp_mod(d, a, b, c, NULL);
569 	}
570 
571 	if (res)
572 		return CRYPT_MEM;
573 	else
574 		return CRYPT_OK;
575 }
576 
577 static int rng_read(void *ignored __unused, unsigned char *buf, size_t blen)
578 {
579 	if (crypto_rng_read(buf, blen))
580 		return MBEDTLS_ERR_MPI_FILE_IO_ERROR;
581 	return 0;
582 }
583 
584 static int isprime(void *a, int b __unused, int *c)
585 {
586 	int res = mbedtls_mpi_is_prime(a, rng_read, NULL);
587 
588 	if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED)
589 		return CRYPT_MEM;
590 
591 	if (res)
592 		*c = LTC_MP_NO;
593 	else
594 		*c = LTC_MP_YES;
595 
596 	return CRYPT_OK;
597 }
598 
599 static int mpa_rand(void *a, int size)
600 {
601 	if (mbedtls_mpi_fill_random(a, size, rng_read, NULL))
602 		return CRYPT_MEM;
603 
604 	return CRYPT_OK;
605 }
606 
607 ltc_math_descriptor ltc_mp = {
608 	.name = "MPI",
609 	.bits_per_digit = sizeof(mbedtls_mpi_uint) * 8,
610 
611 	.init = &init,
612 	.init_size = &init_size,
613 	.init_copy = &init_copy,
614 	.deinit = &deinit,
615 
616 	.neg = &neg,
617 	.copy = &copy,
618 
619 	.set_int = &set_int,
620 	.get_int = &get_int,
621 	.get_digit = &get_digit,
622 	.get_digit_count = &get_digit_count,
623 	.compare = &compare,
624 	.compare_d = &compare_d,
625 	.count_bits = &count_bits,
626 	.count_lsb_bits = &count_lsb_bits,
627 	.twoexpt = &twoexpt,
628 
629 	.read_radix = &read_radix,
630 	.write_radix = &write_radix,
631 	.unsigned_size = &unsigned_size,
632 	.unsigned_write = &unsigned_write,
633 	.unsigned_read = &unsigned_read,
634 
635 	.add = &add,
636 	.addi = &addi,
637 	.sub = &sub,
638 	.subi = &subi,
639 	.mul = &mul,
640 	.muli = &muli,
641 	.sqr = &sqr,
642 	.mpdiv = &divide,
643 	.div_2 = &div_2,
644 	.modi = &modi,
645 	.gcd = &gcd,
646 	.lcm = &lcm,
647 
648 	.mod = &mod,
649 	.mulmod = &mulmod,
650 	.sqrmod = &sqrmod,
651 	.invmod = &invmod,
652 
653 	.montgomery_setup = &montgomery_setup,
654 	.montgomery_normalization = &montgomery_normalization,
655 	.montgomery_reduce = &montgomery_reduce,
656 	.montgomery_deinit = &montgomery_deinit,
657 
658 	.exptmod = &exptmod,
659 	.isprime = &isprime,
660 
661 #ifdef LTC_MECC
662 #ifdef LTC_MECC_FP
663 	.ecc_ptmul = &ltc_ecc_fp_mulmod,
664 #else
665 	.ecc_ptmul = &ltc_ecc_mulmod,
666 #endif /* LTC_MECC_FP */
667 	.ecc_ptadd = &ltc_ecc_projective_add_point,
668 	.ecc_ptdbl = &ltc_ecc_projective_dbl_point,
669 	.ecc_map = &ltc_ecc_map,
670 #ifdef LTC_ECC_SHAMIR
671 #ifdef LTC_MECC_FP
672 	.ecc_mul2add = &ltc_ecc_fp_mul2add,
673 #else
674 	.ecc_mul2add = &ltc_ecc_mul2add,
675 #endif /* LTC_MECC_FP */
676 #endif /* LTC_ECC_SHAMIR */
677 #endif /* LTC_MECC */
678 
679 #ifdef LTC_MRSA
680 	.rsa_keygen = &rsa_make_key,
681 	.rsa_me = &rsa_exptmod,
682 #endif
683 	.rand = &mpa_rand,
684 
685 };
686 
687 size_t crypto_bignum_num_bytes(struct bignum *a)
688 {
689 	return mbedtls_mpi_size((mbedtls_mpi *)a);
690 }
691 
692 size_t crypto_bignum_num_bits(struct bignum *a)
693 {
694 	return mbedtls_mpi_bitlen((mbedtls_mpi *)a);
695 }
696 
697 int32_t crypto_bignum_compare(struct bignum *a, struct bignum *b)
698 {
699 	return mbedtls_mpi_cmp_mpi((mbedtls_mpi *)a, (mbedtls_mpi *)b);
700 }
701 
702 void crypto_bignum_bn2bin(const struct bignum *from, uint8_t *to)
703 {
704 	mbedtls_mpi_write_binary((const mbedtls_mpi *)from, (void *)to,
705 				 mbedtls_mpi_size((const mbedtls_mpi *)from));
706 }
707 
708 TEE_Result crypto_bignum_bin2bn(const uint8_t *from, size_t fromsize,
709 			 struct bignum *to)
710 {
711 	if (mbedtls_mpi_read_binary((mbedtls_mpi *)to, (const void *)from,
712 				    fromsize))
713 		return TEE_ERROR_BAD_PARAMETERS;
714 	return TEE_SUCCESS;
715 }
716 
717 void crypto_bignum_copy(struct bignum *to, const struct bignum *from)
718 {
719 	mbedtls_mpi_copy((mbedtls_mpi *)to, (const mbedtls_mpi *)from);
720 }
721 
722 struct bignum *crypto_bignum_allocate(size_t size_bits)
723 {
724 	mbedtls_mpi *bn = malloc(sizeof(*bn));
725 
726 	if (!bn)
727 		return NULL;
728 
729 	mbedtls_mpi_init(bn);
730 	if (mbedtls_mpi_grow(bn, BITS_TO_LIMBS(size_bits))) {
731 		free(bn);
732 		return NULL;
733 	}
734 
735 	return (struct bignum *)bn;
736 }
737 
738 void crypto_bignum_free(struct bignum *s)
739 {
740 	mbedtls_mpi_free((mbedtls_mpi *)s);
741 	free(s);
742 }
743 
744 void crypto_bignum_clear(struct bignum *s)
745 {
746 	mbedtls_mpi *bn = (mbedtls_mpi *)s;
747 
748 	bn->s = 1;
749 	if (bn->p)
750 		memset(bn->p, 0, sizeof(*bn->p) * bn->n);
751 }
752