xref: /optee_os/lib/libutee/tee_api_arith_mpi.c (revision 420232951deba3e79232a90b8f0fa03da68fc3db)
1 // SPDX-License-Identifier: BSD-2-Clause
2 /*
3  * Copyright (c) 2018, Linaro limited
4  */
5 #include <assert.h>
6 #include <mbedtls/bignum.h>
7 #include <mempool.h>
8 #include <stdio.h>
9 #include <string.h>
10 #include <tee_api.h>
11 #include <tee_arith_internal.h>
12 #include <utee_defines.h>
13 #include <utee_syscalls.h>
14 #include <util.h>
15 
/* Size in bytes of the static buffer backing the MPI mempool (12 KiB) */
#define MPI_MEMPOOL_SIZE	(12 * 1024)
17 
18 static void __noreturn api_panic(const char *func, int line, const char *msg)
19 {
20 	printf("Panic function %s, line %d: %s\n", func, line, msg);
21 	TEE_Panic(0xB16127 /*BIGINT*/);
22 	while (1)
23 		; /* Panic will crash the thread */
24 }
25 
26 #define API_PANIC(x) api_panic(__func__, __LINE__, x)
27 
28 static void __noreturn mpi_panic(const char *func, int line, int rc)
29 {
30 	printf("Panic function %s, line %d, code %d\n", func, line, rc);
31 	TEE_Panic(0xB16127 /*BIGINT*/);
32 	while (1)
33 		; /* Panic will crash the thread */
34 }
35 
36 #define MPI_CHECK(x) do { \
37 		int _rc = (x); \
38 		 \
39 		if (_rc) \
40 			mpi_panic(__func__, __LINE__, _rc); \
41 	} while (0)
42 
43 void _TEE_MathAPI_Init(void)
44 {
45 	static uint8_t data[MPI_MEMPOOL_SIZE] __aligned(MEMPOOL_ALIGN);
46 
47 	mbedtls_mpi_mempool = mempool_alloc_pool(data, sizeof(data), NULL);
48 	if (!mbedtls_mpi_mempool)
49 		API_PANIC("Failed to initialize memory pool");
50 }
51 
/*
 * Header occupying the first words of every TEE_BigInt handled here. The
 * 32-bit limbs of the value follow it directly (see get_mpi() and
 * copy_mpi_to_bigint()).
 */
struct bigint_hdr {
	int32_t sign;		/* copied to/from mbedtls_mpi::s */
	uint16_t alloc_size;	/* limbs the buffer can hold, set at init */
	uint16_t nblimbs;	/* limbs currently used by the value */
};

/* sizeof(struct bigint_hdr) in uint32_t words, asserted in get_mpi() */
#define BIGINT_HDR_SIZE_IN_U32	2
59 
60 static TEE_Result copy_mpi_to_bigint(mbedtls_mpi *mpi, TEE_BigInt *bigInt)
61 {
62 	struct bigint_hdr *hdr = (struct bigint_hdr *)bigInt;
63 	size_t n = mpi->n;
64 
65 	/* Trim of eventual insignificant zeroes */
66 	while (n && !mpi->p[n - 1])
67 		n--;
68 
69 	if (hdr->alloc_size < n)
70 		return TEE_ERROR_OVERFLOW;
71 
72 	hdr->nblimbs = n;
73 	hdr->sign = mpi->s;
74 	memcpy(hdr + 1, mpi->p, mpi->n * sizeof(mbedtls_mpi_uint));
75 
76 	return TEE_SUCCESS;
77 }
78 
79 /*
80  * Initializes a MPI.
81  *
82  * A temporary MPI is allocated and if a bigInt is supplied the MPI is
83  * initialized with the value of the bigInt.
84  */
85 static void get_mpi(mbedtls_mpi *mpi, const TEE_BigInt *bigInt)
86 {
87 	/*
88 	 * The way the GP spec is defining the bignums it's
89 	 * difficult/tricky to do it using 64-bit arithmetics given that
90 	 * we'd need 64-bit alignment of the data as well.
91 	 */
92 	COMPILE_TIME_ASSERT(sizeof(mbedtls_mpi_uint) == sizeof(uint32_t));
93 
94 	/*
95 	 * The struct bigint_hdr is the overhead added to the bigint and
96 	 * is required to take exactly 2 uint32_t.
97 	 */
98 	COMPILE_TIME_ASSERT(sizeof(struct bigint_hdr) ==
99 			    sizeof(uint32_t) * BIGINT_HDR_SIZE_IN_U32);
100 
101 	mbedtls_mpi_init_mempool(mpi);
102 
103 	if (bigInt) {
104 		const struct bigint_hdr *hdr = (struct bigint_hdr *)bigInt;
105 		const mbedtls_mpi_uint *p = (const mbedtls_mpi_uint *)(hdr + 1);
106 		size_t n = hdr->nblimbs;
107 
108 		/* Trim of eventual insignificant zeroes */
109 		while (n && !p[n - 1])
110 			n--;
111 
112 		MPI_CHECK(mbedtls_mpi_grow(mpi, n));
113 		mpi->s = hdr->sign;
114 		memcpy(mpi->p, p, n * sizeof(mbedtls_mpi_uint));
115 	}
116 }
117 
118 void TEE_BigIntInit(TEE_BigInt *bigInt, size_t len)
119 {
120 	struct bigint_hdr *hdr = (struct bigint_hdr *)bigInt;
121 
122 	static_assert(MBEDTLS_MPI_MAX_LIMBS + BIGINT_HDR_SIZE_IN_U32 >=
123 		      CFG_TA_BIGNUM_MAX_BITS / 32);
124 
125 	memset(bigInt, 0, len * sizeof(uint32_t));
126 	hdr->sign = 1;
127 
128 	/* "gpd.tee.arith.maxBigIntSize" is assigned CFG_TA_BIGNUM_MAX_BITS */
129 	if (len > CFG_TA_BIGNUM_MAX_BITS / 4)
130 		API_PANIC("Too large bigint");
131 	hdr->alloc_size = len - BIGINT_HDR_SIZE_IN_U32;
132 }
133 
134 void __GP11_TEE_BigIntInit(TEE_BigInt *bigInt, uint32_t len)
135 {
136 	TEE_BigIntInit(bigInt, len);
137 }
138 
139 TEE_Result TEE_BigIntConvertFromOctetString(TEE_BigInt *dest,
140 					    const uint8_t *buffer,
141 					    size_t bufferLen, int32_t sign)
142 {
143 	TEE_Result res;
144 	mbedtls_mpi mpi_dest;
145 
146 	get_mpi(&mpi_dest, NULL);
147 
148 	if (mbedtls_mpi_read_binary(&mpi_dest,  buffer, bufferLen))
149 		res = TEE_ERROR_OVERFLOW;
150 	else
151 		res = TEE_SUCCESS;
152 
153 	if (sign < 0)
154 		mpi_dest.s = -1;
155 
156 	if (!res)
157 		res = copy_mpi_to_bigint(&mpi_dest, dest);
158 
159 	mbedtls_mpi_free(&mpi_dest);
160 
161 	return res;
162 }
163 
164 TEE_Result __GP11_TEE_BigIntConvertFromOctetString(TEE_BigInt *dest,
165 						   const uint8_t *buffer,
166 						   uint32_t bufferLen,
167 						   int32_t sign)
168 {
169 	return TEE_BigIntConvertFromOctetString(dest, buffer, bufferLen, sign);
170 }
171 
172 TEE_Result TEE_BigIntConvertToOctetString(uint8_t *buffer, size_t *bufferLen,
173 					  const TEE_BigInt *bigInt)
174 {
175 	TEE_Result res = TEE_SUCCESS;
176 	mbedtls_mpi mpi;
177 	size_t sz;
178 
179 	get_mpi(&mpi, bigInt);
180 
181 	sz = mbedtls_mpi_size(&mpi);
182 	if (sz <= *bufferLen)
183 		MPI_CHECK(mbedtls_mpi_write_binary(&mpi, buffer, sz));
184 	else
185 		res = TEE_ERROR_SHORT_BUFFER;
186 
187 	*bufferLen = sz;
188 
189 	mbedtls_mpi_free(&mpi);
190 
191 	return res;
192 }
193 
194 TEE_Result __GP11_TEE_BigIntConvertToOctetString(uint8_t *buffer,
195 						 uint32_t *bufferLen,
196 						 const TEE_BigInt *bigInt)
197 {
198 	TEE_Result res = TEE_SUCCESS;
199 	size_t l = *bufferLen;
200 
201 	res = TEE_BigIntConvertToOctetString(buffer, &l, bigInt);
202 	*bufferLen = l;
203 	return res;
204 }
205 
206 void TEE_BigIntConvertFromS32(TEE_BigInt *dest, int32_t shortVal)
207 {
208 	mbedtls_mpi mpi;
209 
210 	get_mpi(&mpi, dest);
211 
212 	MPI_CHECK(mbedtls_mpi_lset(&mpi, shortVal));
213 
214 	MPI_CHECK(copy_mpi_to_bigint(&mpi, dest));
215 	mbedtls_mpi_free(&mpi);
216 }
217 
218 TEE_Result TEE_BigIntConvertToS32(int32_t *dest, const TEE_BigInt *src)
219 {
220 	TEE_Result res = TEE_SUCCESS;
221 	mbedtls_mpi mpi;
222 	uint32_t v;
223 
224 	get_mpi(&mpi, src);
225 
226 	if (mbedtls_mpi_write_binary(&mpi, (void *)&v, sizeof(v))) {
227 		res = TEE_ERROR_OVERFLOW;
228 		goto out;
229 	}
230 
231 	if (mpi.s > 0) {
232 		if (ADD_OVERFLOW(0, TEE_U32_FROM_BIG_ENDIAN(v), dest))
233 			res = TEE_ERROR_OVERFLOW;
234 	} else {
235 		if (SUB_OVERFLOW(0, TEE_U32_FROM_BIG_ENDIAN(v), dest))
236 			res = TEE_ERROR_OVERFLOW;
237 	}
238 
239 out:
240 	mbedtls_mpi_free(&mpi);
241 
242 	return res;
243 }
244 
245 int32_t TEE_BigIntCmp(const TEE_BigInt *op1, const TEE_BigInt *op2)
246 {
247 	mbedtls_mpi mpi1;
248 	mbedtls_mpi mpi2;
249 	int32_t rc;
250 
251 	get_mpi(&mpi1, op1);
252 	get_mpi(&mpi2, op2);
253 
254 	rc = mbedtls_mpi_cmp_mpi(&mpi1, &mpi2);
255 
256 	mbedtls_mpi_free(&mpi1);
257 	mbedtls_mpi_free(&mpi2);
258 
259 	return rc;
260 }
261 
262 int32_t TEE_BigIntCmpS32(const TEE_BigInt *op, int32_t shortVal)
263 {
264 	mbedtls_mpi mpi;
265 	int32_t rc;
266 
267 	get_mpi(&mpi, op);
268 
269 	rc = mbedtls_mpi_cmp_int(&mpi, shortVal);
270 
271 	mbedtls_mpi_free(&mpi);
272 
273 	return rc;
274 }
275 
276 void TEE_BigIntShiftRight(TEE_BigInt *dest, const TEE_BigInt *op, size_t bits)
277 {
278 	mbedtls_mpi mpi_dest;
279 	mbedtls_mpi mpi_op;
280 
281 	get_mpi(&mpi_dest, dest);
282 
283 	if (dest == op) {
284 		MPI_CHECK(mbedtls_mpi_shift_r(&mpi_dest, bits));
285 		goto out;
286 	}
287 
288 	get_mpi(&mpi_op, op);
289 
290 	if (mbedtls_mpi_size(&mpi_dest) >= mbedtls_mpi_size(&mpi_op)) {
291 		MPI_CHECK(mbedtls_mpi_copy(&mpi_dest, &mpi_op));
292 		MPI_CHECK(mbedtls_mpi_shift_r(&mpi_dest, bits));
293 	} else {
294 		mbedtls_mpi mpi_t;
295 
296 		get_mpi(&mpi_t, NULL);
297 
298 		/*
299 		 * We're using a temporary buffer to avoid the corner case
300 		 * where destination is unexpectedly overflowed by up to
301 		 * @bits number of bits.
302 		 */
303 		MPI_CHECK(mbedtls_mpi_copy(&mpi_t, &mpi_op));
304 		MPI_CHECK(mbedtls_mpi_shift_r(&mpi_t, bits));
305 		MPI_CHECK(mbedtls_mpi_copy(&mpi_dest, &mpi_t));
306 
307 		mbedtls_mpi_free(&mpi_t);
308 	}
309 
310 	mbedtls_mpi_free(&mpi_op);
311 
312 out:
313 	MPI_CHECK(copy_mpi_to_bigint(&mpi_dest, dest));
314 	mbedtls_mpi_free(&mpi_dest);
315 }
316 
317 void __GP11_TEE_BigIntShiftRight(TEE_BigInt *dest, const TEE_BigInt *op,
318 				 uint32_t bits)
319 {
320 	TEE_BigIntShiftRight(dest, op, bits);
321 }
322 
323 bool TEE_BigIntGetBit(const TEE_BigInt *src, uint32_t bitIndex)
324 {
325 	bool rc;
326 	mbedtls_mpi mpi;
327 
328 	get_mpi(&mpi, src);
329 
330 	rc = mbedtls_mpi_get_bit(&mpi, bitIndex);
331 
332 	mbedtls_mpi_free(&mpi);
333 
334 	return rc;
335 }
336 
337 uint32_t TEE_BigIntGetBitCount(const TEE_BigInt *src)
338 {
339 	uint32_t rc;
340 	mbedtls_mpi mpi;
341 
342 	get_mpi(&mpi, src);
343 
344 	rc = mbedtls_mpi_bitlen(&mpi);
345 
346 	mbedtls_mpi_free(&mpi);
347 
348 	return rc;
349 }
350 
351 TEE_Result TEE_BigIntSetBit(TEE_BigInt *op, uint32_t bitIndex, bool value)
352 {
353 	TEE_Result res = TEE_SUCCESS;
354 	mbedtls_mpi mpi = { };
355 	int rc = 0;
356 
357 	get_mpi(&mpi, op);
358 
359 	rc = mbedtls_mpi_set_bit(&mpi, bitIndex, value);
360 	if (rc)
361 		res = TEE_ERROR_OVERFLOW;
362 	else
363 		res = copy_mpi_to_bigint(&mpi, op);
364 
365 	mbedtls_mpi_free(&mpi);
366 
367 	return res;
368 }
369 
370 TEE_Result TEE_BigIntAssign(TEE_BigInt *dest, const TEE_BigInt *src)
371 {
372 	const struct bigint_hdr *src_hdr = (struct bigint_hdr *)src;
373 	struct bigint_hdr *dst_hdr = (struct bigint_hdr *)dest;
374 
375 	if (dst_hdr == src_hdr)
376 		return TEE_SUCCESS;
377 
378 	if (dst_hdr->alloc_size < src_hdr->nblimbs)
379 		return TEE_ERROR_OVERFLOW;
380 
381 	dst_hdr->nblimbs = src_hdr->nblimbs;
382 	dst_hdr->sign = src_hdr->sign;
383 	memcpy(dst_hdr + 1, src_hdr + 1, src_hdr->nblimbs * sizeof(uint32_t));
384 
385 	return TEE_SUCCESS;
386 }
387 
388 static void bigint_binary(TEE_BigInt *dest, const TEE_BigInt *op1,
389 			  const TEE_BigInt *op2,
390 			  int (*func)(mbedtls_mpi *X, const mbedtls_mpi *A,
391 				      const mbedtls_mpi *B))
392 {
393 	mbedtls_mpi mpi_dest;
394 	mbedtls_mpi mpi_op1;
395 	mbedtls_mpi mpi_op2;
396 	mbedtls_mpi *pop1 = &mpi_op1;
397 	mbedtls_mpi *pop2 = &mpi_op2;
398 
399 	get_mpi(&mpi_dest, dest);
400 
401 	if (op1 == dest)
402 		pop1 = &mpi_dest;
403 	else
404 		get_mpi(&mpi_op1, op1);
405 
406 	if (op2 == dest)
407 		pop2 = &mpi_dest;
408 	else if (op2 == op1)
409 		pop2 = pop1;
410 	else
411 		get_mpi(&mpi_op2, op2);
412 
413 	MPI_CHECK(func(&mpi_dest, pop1, pop2));
414 
415 	MPI_CHECK(copy_mpi_to_bigint(&mpi_dest, dest));
416 	mbedtls_mpi_free(&mpi_dest);
417 	if (pop1 == &mpi_op1)
418 		mbedtls_mpi_free(&mpi_op1);
419 	if (pop2 == &mpi_op2)
420 		mbedtls_mpi_free(&mpi_op2);
421 }
422 
423 static void bigint_binary_mod(TEE_BigInt *dest, const TEE_BigInt *op1,
424 			      const TEE_BigInt *op2, const TEE_BigInt *n,
425 			      int (*func)(mbedtls_mpi *X, const mbedtls_mpi *A,
426 					  const mbedtls_mpi *B))
427 {
428 	mbedtls_mpi mpi_dest;
429 	mbedtls_mpi mpi_op1;
430 	mbedtls_mpi mpi_op2;
431 	mbedtls_mpi mpi_n;
432 	mbedtls_mpi *pop1 = &mpi_op1;
433 	mbedtls_mpi *pop2 = &mpi_op2;
434 	mbedtls_mpi mpi_t;
435 
436 	if (TEE_BigIntCmpS32(n, 2) < 0)
437 		API_PANIC("Modulus is too short");
438 
439 	get_mpi(&mpi_dest, dest);
440 	get_mpi(&mpi_n, n);
441 
442 	if (op1 == dest)
443 		pop1 = &mpi_dest;
444 	else
445 		get_mpi(&mpi_op1, op1);
446 
447 	if (op2 == dest)
448 		pop2 = &mpi_dest;
449 	else if (op2 == op1)
450 		pop2 = pop1;
451 	else
452 		get_mpi(&mpi_op2, op2);
453 
454 	get_mpi(&mpi_t, NULL);
455 
456 	MPI_CHECK(func(&mpi_t, pop1, pop2));
457 	MPI_CHECK(mbedtls_mpi_mod_mpi(&mpi_dest, &mpi_t, &mpi_n));
458 
459 	MPI_CHECK(copy_mpi_to_bigint(&mpi_dest, dest));
460 	mbedtls_mpi_free(&mpi_dest);
461 	if (pop1 == &mpi_op1)
462 		mbedtls_mpi_free(&mpi_op1);
463 	if (pop2 == &mpi_op2)
464 		mbedtls_mpi_free(&mpi_op2);
465 	mbedtls_mpi_free(&mpi_t);
466 	mbedtls_mpi_free(&mpi_n);
467 }
468 
469 void TEE_BigIntAdd(TEE_BigInt *dest, const TEE_BigInt *op1,
470 		   const TEE_BigInt *op2)
471 {
472 	bigint_binary(dest, op1, op2, mbedtls_mpi_add_mpi);
473 }
474 
475 void TEE_BigIntSub(TEE_BigInt *dest, const TEE_BigInt *op1,
476 		   const TEE_BigInt *op2)
477 {
478 	bigint_binary(dest, op1, op2, mbedtls_mpi_sub_mpi);
479 }
480 
481 void TEE_BigIntNeg(TEE_BigInt *dest, const TEE_BigInt *src)
482 {
483 	mbedtls_mpi mpi_dest;
484 
485 	get_mpi(&mpi_dest, dest);
486 
487 	if (dest != src) {
488 		mbedtls_mpi mpi_src;
489 
490 		get_mpi(&mpi_src, src);
491 
492 		MPI_CHECK(mbedtls_mpi_copy(&mpi_dest, &mpi_src));
493 
494 		mbedtls_mpi_free(&mpi_src);
495 	}
496 
497 	mpi_dest.s *= -1;
498 
499 	MPI_CHECK(copy_mpi_to_bigint(&mpi_dest, dest));
500 	mbedtls_mpi_free(&mpi_dest);
501 }
502 
503 void TEE_BigIntMul(TEE_BigInt *dest, const TEE_BigInt *op1,
504 		   const TEE_BigInt *op2)
505 {
506 	size_t bs1 = TEE_BigIntGetBitCount(op1);
507 	size_t bs2 = TEE_BigIntGetBitCount(op2);
508 	size_t s = TEE_BigIntSizeInU32(bs1) + TEE_BigIntSizeInU32(bs2);
509 	TEE_BigInt zero[TEE_BigIntSizeInU32(1)] = { 0 };
510 	TEE_BigInt *tmp = NULL;
511 
512 	tmp = mempool_alloc(mbedtls_mpi_mempool, sizeof(uint32_t) * s);
513 	if (!tmp)
514 		TEE_Panic(TEE_ERROR_OUT_OF_MEMORY);
515 
516 	TEE_BigIntInit(tmp, s);
517 	TEE_BigIntInit(zero, TEE_BigIntSizeInU32(1));
518 
519 	bigint_binary(tmp, op1, op2, mbedtls_mpi_mul_mpi);
520 
521 	TEE_BigIntAdd(dest, tmp, zero);
522 
523 	mempool_free(mbedtls_mpi_mempool, tmp);
524 }
525 
526 void TEE_BigIntSquare(TEE_BigInt *dest, const TEE_BigInt *op)
527 {
528 	TEE_BigIntMul(dest, op, op);
529 }
530 
531 void TEE_BigIntDiv(TEE_BigInt *dest_q, TEE_BigInt *dest_r,
532 		   const TEE_BigInt *op1, const TEE_BigInt *op2)
533 {
534 	mbedtls_mpi mpi_dest_q;
535 	mbedtls_mpi mpi_dest_r;
536 	mbedtls_mpi mpi_op1;
537 	mbedtls_mpi mpi_op2;
538 	mbedtls_mpi *pop1 = &mpi_op1;
539 	mbedtls_mpi *pop2 = &mpi_op2;
540 
541 	get_mpi(&mpi_dest_q, dest_q);
542 	get_mpi(&mpi_dest_r, dest_r);
543 
544 	if (op1 == dest_q)
545 		pop1 = &mpi_dest_q;
546 	else if (op1 == dest_r)
547 		pop1 = &mpi_dest_r;
548 	else
549 		get_mpi(&mpi_op1, op1);
550 
551 	if (op2 == dest_q)
552 		pop2 = &mpi_dest_q;
553 	else if (op2 == dest_r)
554 		pop2 = &mpi_dest_r;
555 	else if (op2 == op1)
556 		pop2 = pop1;
557 	else
558 		get_mpi(&mpi_op2, op2);
559 
560 	MPI_CHECK(mbedtls_mpi_div_mpi(&mpi_dest_q, &mpi_dest_r, pop1, pop2));
561 
562 	if (dest_q)
563 		MPI_CHECK(copy_mpi_to_bigint(&mpi_dest_q, dest_q));
564 	if (dest_r)
565 		MPI_CHECK(copy_mpi_to_bigint(&mpi_dest_r, dest_r));
566 	mbedtls_mpi_free(&mpi_dest_q);
567 	mbedtls_mpi_free(&mpi_dest_r);
568 	if (pop1 == &mpi_op1)
569 		mbedtls_mpi_free(&mpi_op1);
570 	if (pop2 == &mpi_op2)
571 		mbedtls_mpi_free(&mpi_op2);
572 }
573 
574 void TEE_BigIntMod(TEE_BigInt *dest, const TEE_BigInt *op, const TEE_BigInt *n)
575 {
576 	if (TEE_BigIntCmpS32(n, 2) < 0)
577 		API_PANIC("Modulus is too short");
578 
579 	bigint_binary(dest, op, n, mbedtls_mpi_mod_mpi);
580 }
581 
582 void TEE_BigIntAddMod(TEE_BigInt *dest, const TEE_BigInt *op1,
583 		      const TEE_BigInt *op2, const TEE_BigInt *n)
584 {
585 	bigint_binary_mod(dest, op1, op2, n, mbedtls_mpi_add_mpi);
586 }
587 
588 void TEE_BigIntSubMod(TEE_BigInt *dest, const TEE_BigInt *op1,
589 		      const TEE_BigInt *op2, const TEE_BigInt *n)
590 {
591 	bigint_binary_mod(dest, op1, op2, n, mbedtls_mpi_sub_mpi);
592 }
593 
594 void TEE_BigIntMulMod(TEE_BigInt *dest, const TEE_BigInt *op1,
595 		      const TEE_BigInt *op2, const TEE_BigInt *n)
596 {
597 	bigint_binary_mod(dest, op1, op2, n, mbedtls_mpi_mul_mpi);
598 }
599 
600 void TEE_BigIntSquareMod(TEE_BigInt *dest, const TEE_BigInt *op,
601 			 const TEE_BigInt *n)
602 {
603 	TEE_BigIntMulMod(dest, op, op, n);
604 }
605 
606 void TEE_BigIntInvMod(TEE_BigInt *dest, const TEE_BigInt *op,
607 		      const TEE_BigInt *n)
608 {
609 	mbedtls_mpi mpi_dest;
610 	mbedtls_mpi mpi_op;
611 	mbedtls_mpi mpi_n;
612 	mbedtls_mpi *pop = &mpi_op;
613 
614 	if (TEE_BigIntCmpS32(n, 2) < 0 || TEE_BigIntCmpS32(op, 0) == 0)
615 		API_PANIC("too small modulus or trying to invert zero");
616 
617 	get_mpi(&mpi_dest, dest);
618 	get_mpi(&mpi_n, n);
619 
620 	if (op == dest)
621 		pop = &mpi_dest;
622 	else
623 		get_mpi(&mpi_op, op);
624 
625 	MPI_CHECK(mbedtls_mpi_inv_mod(&mpi_dest, pop, &mpi_n));
626 
627 	MPI_CHECK(copy_mpi_to_bigint(&mpi_dest, dest));
628 	mbedtls_mpi_free(&mpi_dest);
629 	mbedtls_mpi_free(&mpi_n);
630 	if (pop == &mpi_op)
631 		mbedtls_mpi_free(&mpi_op);
632 }
633 
634 bool TEE_BigIntRelativePrime(const TEE_BigInt *op1, const TEE_BigInt *op2)
635 {
636 	bool rc;
637 	mbedtls_mpi mpi_op1;
638 	mbedtls_mpi mpi_op2;
639 	mbedtls_mpi *pop2 = &mpi_op2;
640 	mbedtls_mpi gcd;
641 
642 	get_mpi(&mpi_op1, op1);
643 
644 	if (op2 == op1)
645 		pop2 = &mpi_op1;
646 	else
647 		get_mpi(&mpi_op2, op2);
648 
649 	get_mpi(&gcd, NULL);
650 
651 	MPI_CHECK(mbedtls_mpi_gcd(&gcd, &mpi_op1, &mpi_op2));
652 
653 	rc = !mbedtls_mpi_cmp_int(&gcd, 1);
654 
655 	mbedtls_mpi_free(&gcd);
656 	mbedtls_mpi_free(&mpi_op1);
657 	if (pop2 == &mpi_op2)
658 		mbedtls_mpi_free(&mpi_op2);
659 
660 	return rc;
661 }
662 
663 static bool mpi_is_odd(mbedtls_mpi *x)
664 {
665 	return mbedtls_mpi_get_bit(x, 0);
666 }
667 
668 static bool mpi_is_even(mbedtls_mpi *x)
669 {
670 	return !mpi_is_odd(x);
671 }
672 
673 /*
674  * Based on libmpa implementation __mpa_egcd(), modified to work with MPI
675  * instead.
676  */
677 static void mpi_egcd(mbedtls_mpi *gcd, mbedtls_mpi *a, mbedtls_mpi *b,
678 		     mbedtls_mpi *x_in, mbedtls_mpi *y_in)
679 {
680 	mbedtls_mpi_uint k;
681 	mbedtls_mpi A;
682 	mbedtls_mpi B;
683 	mbedtls_mpi C;
684 	mbedtls_mpi D;
685 	mbedtls_mpi x;
686 	mbedtls_mpi y;
687 	mbedtls_mpi u;
688 
689 	get_mpi(&A, NULL);
690 	get_mpi(&B, NULL);
691 	get_mpi(&C, NULL);
692 	get_mpi(&D, NULL);
693 	get_mpi(&x, NULL);
694 	get_mpi(&y, NULL);
695 	get_mpi(&u, NULL);
696 
697 	/* have y < x from assumption */
698 	if (!mbedtls_mpi_cmp_int(y_in, 0)) {
699 		MPI_CHECK(mbedtls_mpi_lset(a, 1));
700 		MPI_CHECK(mbedtls_mpi_lset(b, 0));
701 		MPI_CHECK(mbedtls_mpi_copy(gcd, x_in));
702 		goto out;
703 	}
704 
705 	MPI_CHECK(mbedtls_mpi_copy(&x, x_in));
706 	MPI_CHECK(mbedtls_mpi_copy(&y, y_in));
707 
708 	k = 0;
709 	while (mpi_is_even(&x) && mpi_is_even(&y)) {
710 		k++;
711 		MPI_CHECK(mbedtls_mpi_shift_r(&x, 1));
712 		MPI_CHECK(mbedtls_mpi_shift_r(&y, 1));
713 	}
714 
715 	MPI_CHECK(mbedtls_mpi_copy(&u, &x));
716 	MPI_CHECK(mbedtls_mpi_copy(gcd, &y));
717 	MPI_CHECK(mbedtls_mpi_lset(&A, 1));
718 	MPI_CHECK(mbedtls_mpi_lset(&B, 0));
719 	MPI_CHECK(mbedtls_mpi_lset(&C, 0));
720 	MPI_CHECK(mbedtls_mpi_lset(&D, 1));
721 
722 	while (mbedtls_mpi_cmp_int(&u, 0)) {
723 		while (mpi_is_even(&u)) {
724 			MPI_CHECK(mbedtls_mpi_shift_r(&u, 1));
725 			if (mpi_is_odd(&A) || mpi_is_odd(&B)) {
726 				MPI_CHECK(mbedtls_mpi_add_mpi(&A, &A, &y));
727 				MPI_CHECK(mbedtls_mpi_sub_mpi(&B, &B, &x));
728 			}
729 			MPI_CHECK(mbedtls_mpi_shift_r(&A, 1));
730 			MPI_CHECK(mbedtls_mpi_shift_r(&B, 1));
731 		}
732 
733 		while (mpi_is_even(gcd)) {
734 			MPI_CHECK(mbedtls_mpi_shift_r(gcd, 1));
735 			if (mpi_is_odd(&C) || mpi_is_odd(&D)) {
736 				MPI_CHECK(mbedtls_mpi_add_mpi(&C, &C, &y));
737 				MPI_CHECK(mbedtls_mpi_sub_mpi(&D, &D, &x));
738 			}
739 			MPI_CHECK(mbedtls_mpi_shift_r(&C, 1));
740 			MPI_CHECK(mbedtls_mpi_shift_r(&D, 1));
741 
742 		}
743 
744 		if (mbedtls_mpi_cmp_mpi(&u, gcd) >= 0) {
745 			MPI_CHECK(mbedtls_mpi_sub_mpi(&u, &u, gcd));
746 			MPI_CHECK(mbedtls_mpi_sub_mpi(&A, &A, &C));
747 			MPI_CHECK(mbedtls_mpi_sub_mpi(&B, &B, &D));
748 		} else {
749 			MPI_CHECK(mbedtls_mpi_sub_mpi(gcd, gcd, &u));
750 			MPI_CHECK(mbedtls_mpi_sub_mpi(&C, &C, &A));
751 			MPI_CHECK(mbedtls_mpi_sub_mpi(&D, &D, &B));
752 		}
753 	}
754 
755 	MPI_CHECK(mbedtls_mpi_copy(a, &C));
756 	MPI_CHECK(mbedtls_mpi_copy(b, &D));
757 	MPI_CHECK(mbedtls_mpi_shift_l(gcd, k));
758 
759 out:
760 	mbedtls_mpi_free(&A);
761 	mbedtls_mpi_free(&B);
762 	mbedtls_mpi_free(&C);
763 	mbedtls_mpi_free(&D);
764 	mbedtls_mpi_free(&x);
765 	mbedtls_mpi_free(&y);
766 	mbedtls_mpi_free(&u);
767 }
768 
769 void TEE_BigIntComputeExtendedGcd(TEE_BigInt *gcd, TEE_BigInt *u,
770 				  TEE_BigInt *v, const TEE_BigInt *op1,
771 				  const TEE_BigInt *op2)
772 {
773 	mbedtls_mpi mpi_gcd_res;
774 	mbedtls_mpi mpi_op1;
775 	mbedtls_mpi mpi_op2;
776 	mbedtls_mpi *pop2 = &mpi_op2;
777 
778 	get_mpi(&mpi_gcd_res, gcd);
779 	get_mpi(&mpi_op1, op1);
780 
781 	if (op2 == op1)
782 		pop2 = &mpi_op1;
783 	else
784 		get_mpi(&mpi_op2, op2);
785 
786 	if (!u && !v) {
787 		MPI_CHECK(mbedtls_mpi_gcd(&mpi_gcd_res, &mpi_op1, pop2));
788 	} else {
789 		mbedtls_mpi mpi_u;
790 		mbedtls_mpi mpi_v;
791 		int8_t s1 = mpi_op1.s;
792 		int8_t s2 = pop2->s;
793 		int cmp;
794 
795 		mpi_op1.s = 1;
796 		pop2->s = 1;
797 
798 		get_mpi(&mpi_u, u);
799 		get_mpi(&mpi_v, v);
800 
801 		cmp = mbedtls_mpi_cmp_abs(&mpi_op1, pop2);
802 		if (cmp == 0) {
803 			MPI_CHECK(mbedtls_mpi_copy(&mpi_gcd_res, &mpi_op1));
804 			MPI_CHECK(mbedtls_mpi_lset(&mpi_u, 1));
805 			MPI_CHECK(mbedtls_mpi_lset(&mpi_v, 0));
806 		} else if (cmp > 0) {
807 			mpi_egcd(&mpi_gcd_res, &mpi_u, &mpi_v, &mpi_op1, pop2);
808 		} else {
809 			mpi_egcd(&mpi_gcd_res, &mpi_v, &mpi_u, pop2, &mpi_op1);
810 		}
811 
812 		mpi_u.s *= s1;
813 		mpi_v.s *= s2;
814 
815 		MPI_CHECK(copy_mpi_to_bigint(&mpi_u, u));
816 		MPI_CHECK(copy_mpi_to_bigint(&mpi_v, v));
817 		mbedtls_mpi_free(&mpi_u);
818 		mbedtls_mpi_free(&mpi_v);
819 	}
820 
821 	MPI_CHECK(copy_mpi_to_bigint(&mpi_gcd_res, gcd));
822 	mbedtls_mpi_free(&mpi_gcd_res);
823 	mbedtls_mpi_free(&mpi_op1);
824 	if (pop2 == &mpi_op2)
825 		mbedtls_mpi_free(&mpi_op2);
826 }
827 
828 static int rng_read(void *ignored __unused, unsigned char *buf, size_t blen)
829 {
830 	if (_utee_cryp_random_number_generate(buf, blen))
831 		return MBEDTLS_ERR_MPI_FILE_IO_ERROR;
832 	return 0;
833 }
834 
835 int32_t TEE_BigIntIsProbablePrime(const TEE_BigInt *op,
836 				  uint32_t confidenceLevel __unused)
837 {
838 	int rc;
839 	mbedtls_mpi mpi_op;
840 
841 	get_mpi(&mpi_op, op);
842 
843 	rc = mbedtls_mpi_is_prime(&mpi_op, rng_read, NULL);
844 
845 	mbedtls_mpi_free(&mpi_op);
846 
847 	if (rc)
848 		return 0;
849 
850 	return 1;
851 }
852 
853 /*
854  * Not so fast FMM implementation based on the normal big int functions.
855  *
856  * Note that these functions (along with all the other functions in this
857  * file) only are used directly by the TA doing bigint arithmetics on its
858  * own. Performance of RSA operations in TEE Internal API are not affected
859  * by this.
860  */
861 void TEE_BigIntInitFMM(TEE_BigIntFMM *bigIntFMM, size_t len)
862 {
863 	TEE_BigIntInit(bigIntFMM, len);
864 }
865 
866 void __GP11_TEE_BigIntInitFMM(TEE_BigIntFMM *bigIntFMM, uint32_t len)
867 {
868 	TEE_BigIntInitFMM(bigIntFMM, len);
869 }
870 
871 void TEE_BigIntInitFMMContext(TEE_BigIntFMMContext *context __unused,
872 			      size_t len __unused,
873 			      const TEE_BigInt *modulus __unused)
874 {
875 }
876 
877 void __GP11_TEE_BigIntInitFMMContext(TEE_BigIntFMMContext *context,
878 				     uint32_t len, const TEE_BigInt *modulus)
879 {
880 	TEE_BigIntInitFMMContext(context, len, modulus);
881 }
882 
883 TEE_Result TEE_BigIntInitFMMContext1(TEE_BigIntFMMContext *context __unused,
884 				     size_t len __unused,
885 				     const TEE_BigInt *modulus __unused)
886 {
887 	return TEE_SUCCESS;
888 }
889 
890 size_t TEE_BigIntFMMSizeInU32(size_t modulusSizeInBits)
891 {
892 	return TEE_BigIntSizeInU32(modulusSizeInBits);
893 }
894 
895 uint32_t __GP11_TEE_BigIntFMMSizeInU32(uint32_t modulusSizeInBits)
896 {
897 	return TEE_BigIntFMMSizeInU32(modulusSizeInBits);
898 }
899 
900 size_t TEE_BigIntFMMContextSizeInU32(size_t modulusSizeInBits __unused)
901 {
902 	/* Return something larger than 0 to keep malloc() and friends happy */
903 	return 1;
904 }
905 
906 uint32_t __GP11_TEE_BigIntFMMContextSizeInU32(uint32_t modulusSizeInBits)
907 {
908 	return TEE_BigIntFMMContextSizeInU32(modulusSizeInBits);
909 }
910 
911 void TEE_BigIntConvertToFMM(TEE_BigIntFMM *dest, const TEE_BigInt *src,
912 			    const TEE_BigInt *n,
913 			    const TEE_BigIntFMMContext *context __unused)
914 {
915 	TEE_BigIntMod(dest, src, n);
916 }
917 
918 void TEE_BigIntConvertFromFMM(TEE_BigInt *dest, const TEE_BigIntFMM *src,
919 			      const TEE_BigInt *n __unused,
920 			      const TEE_BigIntFMMContext *context __unused)
921 {
922 	mbedtls_mpi mpi_dst;
923 	mbedtls_mpi mpi_src;
924 
925 	get_mpi(&mpi_dst, dest);
926 	get_mpi(&mpi_src, src);
927 
928 	MPI_CHECK(mbedtls_mpi_copy(&mpi_dst, &mpi_src));
929 
930 	MPI_CHECK(copy_mpi_to_bigint(&mpi_dst, dest));
931 	mbedtls_mpi_free(&mpi_dst);
932 	mbedtls_mpi_free(&mpi_src);
933 }
934 
935 void TEE_BigIntComputeFMM(TEE_BigIntFMM *dest, const TEE_BigIntFMM *op1,
936 			  const TEE_BigIntFMM *op2, const TEE_BigInt *n,
937 			  const TEE_BigIntFMMContext *context __unused)
938 {
939 	mbedtls_mpi mpi_dst;
940 	mbedtls_mpi mpi_op1;
941 	mbedtls_mpi mpi_op2;
942 	mbedtls_mpi mpi_n;
943 	mbedtls_mpi mpi_t;
944 
945 	get_mpi(&mpi_dst, dest);
946 	get_mpi(&mpi_op1, op1);
947 	get_mpi(&mpi_op2, op2);
948 	get_mpi(&mpi_n, n);
949 	get_mpi(&mpi_t, NULL);
950 
951 	MPI_CHECK(mbedtls_mpi_mul_mpi(&mpi_t, &mpi_op1, &mpi_op2));
952 	MPI_CHECK(mbedtls_mpi_mod_mpi(&mpi_dst, &mpi_t, &mpi_n));
953 
954 	mbedtls_mpi_free(&mpi_t);
955 	mbedtls_mpi_free(&mpi_n);
956 	mbedtls_mpi_free(&mpi_op2);
957 	mbedtls_mpi_free(&mpi_op1);
958 	MPI_CHECK(copy_mpi_to_bigint(&mpi_dst, dest));
959 	mbedtls_mpi_free(&mpi_dst);
960 }
961