xref: /optee_os/core/lib/libtomcrypt/mpi_desc.c (revision 78887e6053f32b2a050ae5871d9c287f7d51d686)
1 // SPDX-License-Identifier: BSD-2-Clause
2 /*
3  * Copyright (c) 2018, Linaro Limited
4  */
5 
6 #include <crypto/crypto.h>
7 #include <kernel/panic.h>
8 #include <mbedtls/bignum.h>
9 #include <mempool.h>
10 #include <stdlib.h>
11 #include <string.h>
12 #include <tomcrypt.h>
13 #include <tomcrypt_mp.h>
14 #include <util.h>
15 
16 #if defined(CFG_WITH_PAGER)
17 #include <mm/core_mmu.h>
18 #include <mm/tee_pager.h>
19 #endif
20 
21 /* Size needed for xtest to pass reliably on both ARM32 and ARM64 */
22 #define MPI_MEMPOOL_SIZE	(42 * 1024)
23 
24 #if defined(CFG_WITH_PAGER)
25 /* allocate pageable_zi vmem for mp scratch memory pool */
26 static struct mempool *get_mp_scratch_memory_pool(void)
27 {
28 	size_t size;
29 	void *data;
30 
31 	size = ROUNDUP(MPI_MEMPOOL_SIZE, SMALL_PAGE_SIZE);
32 	data = tee_pager_alloc(size, 0);
33 	if (!data)
34 		panic();
35 
36 	return mempool_alloc_pool(data, size, tee_pager_release_phys);
37 }
38 #else /* CFG_WITH_PAGER */
39 static struct mempool *get_mp_scratch_memory_pool(void)
40 {
41 	static uint8_t data[MPI_MEMPOOL_SIZE] __aligned(MEMPOOL_ALIGN);
42 
43 	return mempool_alloc_pool(data, sizeof(data), NULL);
44 }
45 #endif
46 
47 void init_mp_tomcrypt(void)
48 {
49 	struct mempool *p = get_mp_scratch_memory_pool();
50 
51 	if (!p)
52 		panic();
53 	mbedtls_mpi_mempool = p;
54 }
55 
56 static int init(void **a)
57 {
58 	mbedtls_mpi *bn = mempool_alloc(mbedtls_mpi_mempool, sizeof(*bn));
59 
60 	if (!bn)
61 		return CRYPT_MEM;
62 
63 	mbedtls_mpi_init_mempool(bn);
64 	*a = bn;
65 	return CRYPT_OK;
66 }
67 
68 static int init_size(int size_bits __unused, void **a)
69 {
70 	return init(a);
71 }
72 
73 static void deinit(void *a)
74 {
75 	mbedtls_mpi_free((mbedtls_mpi *)a);
76 	mempool_free(mbedtls_mpi_mempool, a);
77 }
78 
79 static int neg(void *a, void *b)
80 {
81 	if (mbedtls_mpi_copy(b, a))
82 		return CRYPT_MEM;
83 	((mbedtls_mpi *)b)->s *= -1;
84 	return CRYPT_OK;
85 }
86 
87 static int copy(void *a, void *b)
88 {
89 	if (mbedtls_mpi_copy(b, a))
90 		return CRYPT_MEM;
91 	return CRYPT_OK;
92 }
93 
94 static int init_copy(void **a, void *b)
95 {
96 	if (init(a) != CRYPT_OK) {
97 		return CRYPT_MEM;
98 	}
99 	return copy(b, *a);
100 }
101 
102 /* ---- trivial ---- */
103 static int set_int(void *a, unsigned long b)
104 {
105 	uint32_t b32 = b;
106 
107 	if (b32 != b)
108 		return CRYPT_INVALID_ARG;
109 
110 	mbedtls_mpi_uint p = b32;
111 	mbedtls_mpi bn = { .s = 1, .n = 1, .p = &p };
112 
113 	if (mbedtls_mpi_copy(a, &bn))
114 		return CRYPT_MEM;
115 	return CRYPT_OK;
116 }
117 
118 static unsigned long get_int(void *a)
119 {
120 	mbedtls_mpi *bn = a;
121 
122 	if (!bn->n)
123 		return 0;
124 
125 	return bn->p[bn->n - 1];
126 }
127 
128 static ltc_mp_digit get_digit(void *a, int n)
129 {
130 	mbedtls_mpi *bn = a;
131 
132 	COMPILE_TIME_ASSERT(sizeof(ltc_mp_digit) >= sizeof(mbedtls_mpi_uint));
133 
134 	if (n < 0 || (size_t)n >= bn->n)
135 		return 0;
136 
137 	return bn->p[n];
138 }
139 
140 static int get_digit_count(void *a)
141 {
142 	return ROUNDUP(mbedtls_mpi_size(a), sizeof(mbedtls_mpi_uint)) /
143 	       sizeof(mbedtls_mpi_uint);
144 }
145 
146 static int compare(void *a, void *b)
147 {
148 	int ret = mbedtls_mpi_cmp_mpi(a, b);
149 
150 	if (ret < 0)
151 		return LTC_MP_LT;
152 
153 	if (ret > 0)
154 		return LTC_MP_GT;
155 
156 	return LTC_MP_EQ;
157 }
158 
159 static int compare_d(void *a, unsigned long b)
160 {
161 	unsigned long v = b;
162 	unsigned int shift = 31;
163 	uint32_t mask = BIT(shift) - 1;
164 	mbedtls_mpi bn;
165 
166 	mbedtls_mpi_init_mempool(&bn);
167 	while (true) {
168 		mbedtls_mpi_add_int(&bn, &bn, v & mask);
169 		v >>= shift;
170 		if (!v)
171 			break;
172 		mbedtls_mpi_shift_l(&bn, shift);
173 	}
174 
175 	int ret = compare(a, &bn);
176 
177 	mbedtls_mpi_free(&bn);
178 
179 	return ret;
180 }
181 
/* Bit length of a, as reported by mbedtls_mpi_bitlen(). */
static int count_bits(void *a)
{
	size_t nbits = mbedtls_mpi_bitlen(a);

	return nbits;
}
186 
/* Number of trailing zero bits of a, as reported by mbedtls_mpi_lsb(). */
static int count_lsb_bits(void *a)
{
	size_t nzeros = mbedtls_mpi_lsb(a);

	return nzeros;
}
191 
192 
193 static int twoexpt(void *a, int n)
194 {
195 	if (mbedtls_mpi_set_bit(a, n, 1))
196 		return CRYPT_MEM;
197 
198 	return CRYPT_OK;
199 }
200 
201 /* ---- conversions ---- */
202 
203 /* read ascii string */
204 static int read_radix(void *a, const char *b, int radix)
205 {
206 	int res = mbedtls_mpi_read_string(a, radix, b);
207 
208 	if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED)
209 		return CRYPT_MEM;
210 	if (res)
211 		return CRYPT_ERROR;
212 
213 	return CRYPT_OK;
214 }
215 
216 /* write one */
217 static int write_radix(void *a, char *b, int radix)
218 {
219 	size_t ol = SIZE_MAX;
220 	int res = mbedtls_mpi_write_string(a, radix, b, ol, &ol);
221 
222 	if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED)
223 		return CRYPT_MEM;
224 	if (res)
225 		return CRYPT_ERROR;
226 
227 	return CRYPT_OK;
228 }
229 
/*
 * Length in bytes of a as an unsigned byte string (mbedtls_mpi_size());
 * this is the number of bytes unsigned_write() produces.
 */
static unsigned long unsigned_size(void *a)
{
	size_t nbytes = mbedtls_mpi_size(a);

	return nbytes;
}
235 
236 /* store */
237 static int unsigned_write(void *a, unsigned char *b)
238 {
239 	int res = mbedtls_mpi_write_binary(a, b, unsigned_size(a));
240 
241 	if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED)
242 		return CRYPT_MEM;
243 	if (res)
244 		return CRYPT_ERROR;
245 
246 	return CRYPT_OK;
247 }
248 
249 /* read */
250 static int unsigned_read(void *a, unsigned char *b, unsigned long len)
251 {
252 	int res = mbedtls_mpi_read_binary(a, b, len);
253 
254 	if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED)
255 		return CRYPT_MEM;
256 	if (res)
257 		return CRYPT_ERROR;
258 
259 	return CRYPT_OK;
260 }
261 
262 /* add */
263 static int add(void *a, void *b, void *c)
264 {
265 	if (mbedtls_mpi_add_mpi(c, a, b))
266 		return CRYPT_MEM;
267 
268 	return CRYPT_OK;
269 }
270 
271 static int addi(void *a, unsigned long b, void *c)
272 {
273 	uint32_t b32 = b;
274 
275 	if (b32 != b)
276 		return CRYPT_INVALID_ARG;
277 
278 	mbedtls_mpi_uint p = b32;
279 	mbedtls_mpi bn = { .s = 1, .n = 1, .p = &p };
280 
281 	return add(a, &bn, c);
282 }
283 
284 /* sub */
285 static int sub(void *a, void *b, void *c)
286 {
287 	if (mbedtls_mpi_sub_mpi(c, a, b))
288 		return CRYPT_MEM;
289 
290 	return CRYPT_OK;
291 }
292 
293 static int subi(void *a, unsigned long b, void *c)
294 {
295 	uint32_t b32 = b;
296 
297 	if (b32 != b)
298 		return CRYPT_INVALID_ARG;
299 
300 	mbedtls_mpi_uint p = b32;
301 	mbedtls_mpi bn = { .s = 1, .n = 1, .p = &p };
302 
303 	return sub(a, &bn, c);
304 }
305 
306 /* mul */
307 static int mul(void *a, void *b, void *c)
308 {
309 	if (mbedtls_mpi_mul_mpi(c, a, b))
310 		return CRYPT_MEM;
311 
312 	return CRYPT_OK;
313 }
314 
315 static int muli(void *a, unsigned long b, void *c)
316 {
317 	if (b > (unsigned long) UINT32_MAX)
318 		return CRYPT_INVALID_ARG;
319 
320 	if (mbedtls_mpi_mul_int(c, a, b))
321 		return CRYPT_MEM;
322 
323 	return CRYPT_OK;
324 }
325 
/* b = a * a, implemented as a plain multiplication of a by itself. */
static int sqr(void *a, void *b)
{
	void *factor = a;

	return mul(factor, factor, b);
}
331 
332 /* div */
333 static int divide(void *a, void *b, void *c, void *d)
334 {
335 	int res = mbedtls_mpi_div_mpi(c, d, a, b);
336 
337 	if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED)
338 		return CRYPT_MEM;
339 	if (res)
340 		return CRYPT_ERROR;
341 
342 	return CRYPT_OK;
343 }
344 
345 static int div_2(void *a, void *b)
346 {
347 	if (mbedtls_mpi_copy(b, a))
348 		return CRYPT_MEM;
349 
350 	if (mbedtls_mpi_shift_r(b, 1))
351 		return CRYPT_MEM;
352 
353 	return CRYPT_OK;
354 }
355 
356 /* modi */
357 static int modi(void *a, unsigned long b, unsigned long *c)
358 {
359 	mbedtls_mpi bn_b;
360 	mbedtls_mpi bn_c;
361 	int res = 0;
362 
363 	mbedtls_mpi_init_mempool(&bn_b);
364 	mbedtls_mpi_init_mempool(&bn_c);
365 
366 	res = set_int(&bn_b, b);
367 	if (res)
368 		return res;
369 
370 	res = mbedtls_mpi_mod_mpi(&bn_c, &bn_b, a);
371 	if (!res)
372 		*c = get_int(&bn_c);
373 
374 	mbedtls_mpi_free(&bn_b);
375 	mbedtls_mpi_free(&bn_c);
376 
377 	if (res)
378 		return CRYPT_MEM;
379 
380 	return CRYPT_OK;
381 }
382 
383 /* gcd */
384 static int gcd(void *a, void *b, void *c)
385 {
386 	if (mbedtls_mpi_gcd(c, a, b))
387 		return CRYPT_MEM;
388 
389 	return CRYPT_OK;
390 }
391 
392 /* lcm */
393 static int lcm(void *a, void *b, void *c)
394 {
395 	int res = CRYPT_MEM;
396 	mbedtls_mpi tmp;
397 
398 	mbedtls_mpi_init_mempool(&tmp);
399 	if (mbedtls_mpi_mul_mpi(&tmp, a, b))
400 		goto out;
401 
402 	if (mbedtls_mpi_gcd(c, a, b))
403 		goto out;
404 
405 	/* We use the following equality: gcd(a, b) * lcm(a, b) = a * b */
406 	res = divide(&tmp, c, c, NULL);
407 out:
408 	mbedtls_mpi_free(&tmp);
409 	return res;
410 }
411 
412 static int mod(void *a, void *b, void *c)
413 {
414 	int res = mbedtls_mpi_mod_mpi(c, a, b);
415 
416 	if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED)
417 		return CRYPT_MEM;
418 	if (res)
419 		return CRYPT_ERROR;
420 
421 	return CRYPT_OK;
422 }
423 
424 static int mulmod(void *a, void *b, void *c, void *d)
425 {
426 	int res;
427 	mbedtls_mpi ta;
428 	mbedtls_mpi tb;
429 
430 	mbedtls_mpi_init_mempool(&ta);
431 	mbedtls_mpi_init_mempool(&tb);
432 
433 	res = mod(a, c, &ta);
434 	if (res)
435 		goto out;
436 	res = mod(b, c, &tb);
437 	if (res)
438 		goto out;
439 	res = mul(&ta, &tb, d);
440 	if (res)
441 		goto out;
442 	res = mod(d, c, d);
443 out:
444 	mbedtls_mpi_free(&ta);
445 	mbedtls_mpi_free(&tb);
446 	return res;
447 }
448 
/* c = (a * a) mod b, expressed through mulmod(). */
static int sqrmod(void *a, void *b, void *c)
{
	void *base = a;

	return mulmod(base, base, b, c);
}
453 
454 /* invmod */
455 static int invmod(void *a, void *b, void *c)
456 {
457 	int res = mbedtls_mpi_inv_mod(c, a, b);
458 
459 	if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED)
460 		return CRYPT_MEM;
461 	if (res)
462 		return CRYPT_ERROR;
463 
464 	return CRYPT_OK;
465 }
466 
467 
468 /* setup */
469 static int montgomery_setup(void *a, void **b)
470 {
471 	*b = malloc(sizeof(mbedtls_mpi_uint));
472 	if (!*b)
473 		return CRYPT_MEM;
474 
475 	mbedtls_mpi_montg_init(*b, a);
476 
477 	return CRYPT_OK;
478 }
479 
480 /* get normalization value */
481 static int montgomery_normalization(void *a, void *b)
482 {
483 	size_t c = ROUNDUP(mbedtls_mpi_size(b), sizeof(mbedtls_mpi_uint)) * 8;
484 
485 	if (mbedtls_mpi_lset(a, 1))
486 		return CRYPT_MEM;
487 	if (mbedtls_mpi_shift_l(a, c))
488 		return CRYPT_MEM;
489 	if (mbedtls_mpi_mod_mpi(a, a, b))
490 		return CRYPT_MEM;
491 
492 	return CRYPT_OK;
493 }
494 
495 /* reduce */
496 static int montgomery_reduce(void *a, void *b, void *c)
497 {
498 	mbedtls_mpi A;
499 	mbedtls_mpi *N = b;
500 	mbedtls_mpi_uint *mm = c;
501 	mbedtls_mpi T;
502 	int ret = CRYPT_MEM;
503 
504 	mbedtls_mpi_init_mempool(&T);
505 	mbedtls_mpi_init_mempool(&A);
506 
507 	if (mbedtls_mpi_grow(&T, (N->n + 1) * 2))
508 		goto out;
509 
510 	if (mbedtls_mpi_cmp_mpi(a, N) > 0) {
511 		if (mbedtls_mpi_mod_mpi(&A, a, N))
512 			goto out;
513 	} else {
514 		if (mbedtls_mpi_copy(&A, a))
515 			goto out;
516 	}
517 
518 	if (mbedtls_mpi_grow(&A, N->n + 1))
519 		goto out;
520 
521 	if (mbedtls_mpi_montred(&A, N, *mm, &T))
522 		goto out;
523 
524 	if (mbedtls_mpi_copy(a, &A))
525 		goto out;
526 
527 	ret = CRYPT_OK;
528 out:
529 	mbedtls_mpi_free(&A);
530 	mbedtls_mpi_free(&T);
531 
532 	return ret;
533 }
534 
535 /* clean up */
536 static void montgomery_deinit(void *a)
537 {
538 	free(a);
539 }
540 
541 /*
542  * This function calculates:
543  *  d = a^b mod c
544  *
545  * @a: base
546  * @b: exponent
547  * @c: modulus
548  * @d: destination
549  */
550 static int exptmod(void *a, void *b, void *c, void *d)
551 {
552 	int res;
553 
554 	if (d == a || d == b || d == c) {
555 		mbedtls_mpi dest;
556 
557 		mbedtls_mpi_init_mempool(&dest);
558 		res = mbedtls_mpi_exp_mod(&dest, a, b, c, NULL);
559 		if (!res)
560 			res = mbedtls_mpi_copy(d, &dest);
561 		mbedtls_mpi_free(&dest);
562 	} else {
563 		res = mbedtls_mpi_exp_mod(d, a, b, c, NULL);
564 	}
565 
566 	if (res)
567 		return CRYPT_MEM;
568 	else
569 		return CRYPT_OK;
570 }
571 
572 static int rng_read(void *ignored __unused, unsigned char *buf, size_t blen)
573 {
574 	if (crypto_rng_read(buf, blen))
575 		return MBEDTLS_ERR_MPI_FILE_IO_ERROR;
576 	return 0;
577 }
578 
579 static int isprime(void *a, int b __unused, int *c)
580 {
581 	int res = mbedtls_mpi_is_prime(a, rng_read, NULL);
582 
583 	if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED)
584 		return CRYPT_MEM;
585 
586 	if (res)
587 		*c = LTC_MP_NO;
588 	else
589 		*c = LTC_MP_YES;
590 
591 	return CRYPT_OK;
592 }
593 
594 static int mpa_rand(void *a, int size)
595 {
596 	if (mbedtls_mpi_fill_random(a, size, rng_read, NULL))
597 		return CRYPT_MEM;
598 
599 	return CRYPT_OK;
600 }
601 
602 ltc_math_descriptor ltc_mp = {
603 	.name = "MPI",
604 	.bits_per_digit = sizeof(mbedtls_mpi_uint) * 8,
605 
606 	.init = &init,
607 	.init_size = &init_size,
608 	.init_copy = &init_copy,
609 	.deinit = &deinit,
610 
611 	.neg = &neg,
612 	.copy = &copy,
613 
614 	.set_int = &set_int,
615 	.get_int = &get_int,
616 	.get_digit = &get_digit,
617 	.get_digit_count = &get_digit_count,
618 	.compare = &compare,
619 	.compare_d = &compare_d,
620 	.count_bits = &count_bits,
621 	.count_lsb_bits = &count_lsb_bits,
622 	.twoexpt = &twoexpt,
623 
624 	.read_radix = &read_radix,
625 	.write_radix = &write_radix,
626 	.unsigned_size = &unsigned_size,
627 	.unsigned_write = &unsigned_write,
628 	.unsigned_read = &unsigned_read,
629 
630 	.add = &add,
631 	.addi = &addi,
632 	.sub = &sub,
633 	.subi = &subi,
634 	.mul = &mul,
635 	.muli = &muli,
636 	.sqr = &sqr,
637 	.mpdiv = &divide,
638 	.div_2 = &div_2,
639 	.modi = &modi,
640 	.gcd = &gcd,
641 	.lcm = &lcm,
642 
643 	.mod = &mod,
644 	.mulmod = &mulmod,
645 	.sqrmod = &sqrmod,
646 	.invmod = &invmod,
647 
648 	.montgomery_setup = &montgomery_setup,
649 	.montgomery_normalization = &montgomery_normalization,
650 	.montgomery_reduce = &montgomery_reduce,
651 	.montgomery_deinit = &montgomery_deinit,
652 
653 	.exptmod = &exptmod,
654 	.isprime = &isprime,
655 
656 #ifdef LTC_MECC
657 #ifdef LTC_MECC_FP
658 	.ecc_ptmul = &ltc_ecc_fp_mulmod,
659 #else
660 	.ecc_ptmul = &ltc_ecc_mulmod,
661 #endif /* LTC_MECC_FP */
662 	.ecc_ptadd = &ltc_ecc_projective_add_point,
663 	.ecc_ptdbl = &ltc_ecc_projective_dbl_point,
664 	.ecc_map = &ltc_ecc_map,
665 #ifdef LTC_ECC_SHAMIR
666 #ifdef LTC_MECC_FP
667 	.ecc_mul2add = &ltc_ecc_fp_mul2add,
668 #else
669 	.ecc_mul2add = &ltc_ecc_mul2add,
670 #endif /* LTC_MECC_FP */
671 #endif /* LTC_ECC_SHAMIR */
672 #endif /* LTC_MECC */
673 
674 #ifdef LTC_MRSA
675 	.rsa_keygen = &rsa_make_key,
676 	.rsa_me = &rsa_exptmod,
677 #endif
678 	.rand = &mpa_rand,
679 
680 };
681 
682 size_t crypto_bignum_num_bytes(struct bignum *a)
683 {
684 	return mbedtls_mpi_size((mbedtls_mpi *)a);
685 }
686 
687 size_t crypto_bignum_num_bits(struct bignum *a)
688 {
689 	return mbedtls_mpi_bitlen((mbedtls_mpi *)a);
690 }
691 
692 int32_t crypto_bignum_compare(struct bignum *a, struct bignum *b)
693 {
694 	return mbedtls_mpi_cmp_mpi((mbedtls_mpi *)a, (mbedtls_mpi *)b);
695 }
696 
697 void crypto_bignum_bn2bin(const struct bignum *from, uint8_t *to)
698 {
699 	mbedtls_mpi_write_binary((const mbedtls_mpi *)from, (void *)to,
700 				 mbedtls_mpi_size((const mbedtls_mpi *)from));
701 }
702 
703 TEE_Result crypto_bignum_bin2bn(const uint8_t *from, size_t fromsize,
704 			 struct bignum *to)
705 {
706 	if (mbedtls_mpi_read_binary((mbedtls_mpi *)to, (const void *)from,
707 				    fromsize))
708 		return TEE_ERROR_BAD_PARAMETERS;
709 	return TEE_SUCCESS;
710 }
711 
712 void crypto_bignum_copy(struct bignum *to, const struct bignum *from)
713 {
714 	mbedtls_mpi_copy((mbedtls_mpi *)to, (const mbedtls_mpi *)from);
715 }
716 
717 struct bignum *crypto_bignum_allocate(size_t size_bits __unused)
718 {
719 	mbedtls_mpi *bn = malloc(sizeof(*bn));
720 
721 	if (bn)
722 		mbedtls_mpi_init(bn);
723 
724 	return (struct bignum *)bn;
725 }
726 
727 void crypto_bignum_free(struct bignum *s)
728 {
729 	mbedtls_mpi_free((mbedtls_mpi *)s);
730 	free(s);
731 }
732 
733 void crypto_bignum_clear(struct bignum *s)
734 {
735 	mbedtls_mpi *bn = (mbedtls_mpi *)s;
736 
737 	bn->s = 1;
738 	if (bn->p)
739 		memset(bn->p, 0, sizeof(*bn->p) * bn->n);
740 }
741