xref: /optee_os/lib/libutee/tee_api_arith_mpi.c (revision e6e7781f4ba7fe73ee312f4eed640e7ba1d04752)
1 // SPDX-License-Identifier: BSD-2-Clause
2 /*
3  * Copyright (c) 2018, Linaro limited
4  */
5 #include <assert.h>
6 #include <mbedtls/bignum.h>
7 #include <mempool.h>
8 #include <stdio.h>
9 #include <string.h>
10 #include <tee_api.h>
11 #include <tee_arith_internal.h>
12 #include <utee_defines.h>
13 #include <utee_syscalls.h>
14 #include <util.h>
15 
/*
 * Size in bytes of the static buffer handed to mempool_alloc_pool() by
 * _TEE_MathAPI_Init(). The resulting pool is published through
 * mbedtls_mpi_mempool, which TEE_BigIntMul() also allocates from.
 */
#define MPI_MEMPOOL_SIZE	(12 * 1024)
17 
18 static void __noreturn api_panic(const char *func, int line, const char *msg)
19 {
20 	printf("Panic function %s, line %d: %s\n", func, line, msg);
21 	TEE_Panic(0xB16127 /*BIGINT*/);
22 	while (1)
23 		; /* Panic will crash the thread */
24 }
25 
26 #define API_PANIC(x) api_panic(__func__, __LINE__, x)
27 
28 static void __noreturn mpi_panic(const char *func, int line, int rc)
29 {
30 	printf("Panic function %s, line %d, code %d\n", func, line, rc);
31 	TEE_Panic(0xB16127 /*BIGINT*/);
32 	while (1)
33 		; /* Panic will crash the thread */
34 }
35 
36 #define MPI_CHECK(x) do { \
37 		int _rc = (x); \
38 		 \
39 		if (_rc) \
40 			mpi_panic(__func__, __LINE__, _rc); \
41 	} while (0)
42 
43 void _TEE_MathAPI_Init(void)
44 {
45 	static uint8_t data[MPI_MEMPOOL_SIZE] __aligned(MEMPOOL_ALIGN);
46 
47 	mbedtls_mpi_mempool = mempool_alloc_pool(data, sizeof(data), NULL);
48 	if (!mbedtls_mpi_mempool)
49 		API_PANIC("Failed to initialize memory pool");
50 }
51 
/*
 * Header stored at the start of every TEE_BigInt handled by this file.
 * The limbs (mbedtls_mpi_uint, asserted to be 32 bits in get_mpi())
 * follow immediately after it, least significant limb first as in
 * mbedtls_mpi::p.
 */
struct bigint_hdr {
	int32_t sign;		/* copied to/from mbedtls_mpi::s (1 or -1) */
	uint16_t alloc_size;	/* capacity in limbs, header excluded */
	uint16_t nblimbs;	/* limbs currently holding the value */
};

/* sizeof(struct bigint_hdr) in uint32_t words; asserted in get_mpi() */
#define BIGINT_HDR_SIZE_IN_U32	2
59 
60 static TEE_Result copy_mpi_to_bigint(mbedtls_mpi *mpi, TEE_BigInt *bigInt)
61 {
62 	struct bigint_hdr *hdr = (struct bigint_hdr *)bigInt;
63 	size_t n = mpi->n;
64 
65 	/* Trim of eventual insignificant zeroes */
66 	while (n && !mpi->p[n - 1])
67 		n--;
68 
69 	if (hdr->alloc_size < n)
70 		return TEE_ERROR_OVERFLOW;
71 
72 	hdr->nblimbs = n;
73 	hdr->sign = mpi->s;
74 	memcpy(hdr + 1, mpi->p, mpi->n * sizeof(mbedtls_mpi_uint));
75 
76 	return TEE_SUCCESS;
77 }
78 
79 /*
80  * Initializes a MPI.
81  *
82  * A temporary MPI is allocated and if a bigInt is supplied the MPI is
83  * initialized with the value of the bigInt.
84  */
85 static void get_mpi(mbedtls_mpi *mpi, const TEE_BigInt *bigInt)
86 {
87 	/*
88 	 * The way the GP spec is defining the bignums it's
89 	 * difficult/tricky to do it using 64-bit arithmetics given that
90 	 * we'd need 64-bit alignment of the data as well.
91 	 */
92 	COMPILE_TIME_ASSERT(sizeof(mbedtls_mpi_uint) == sizeof(uint32_t));
93 
94 	/*
95 	 * The struct bigint_hdr is the overhead added to the bigint and
96 	 * is required to take exactly 2 uint32_t.
97 	 */
98 	COMPILE_TIME_ASSERT(sizeof(struct bigint_hdr) ==
99 			    sizeof(uint32_t) * BIGINT_HDR_SIZE_IN_U32);
100 
101 	mbedtls_mpi_init_mempool(mpi);
102 
103 	if (bigInt) {
104 		const struct bigint_hdr *hdr = (struct bigint_hdr *)bigInt;
105 		const mbedtls_mpi_uint *p = (const mbedtls_mpi_uint *)(hdr + 1);
106 		size_t n = hdr->nblimbs;
107 
108 		/* Trim of eventual insignificant zeroes */
109 		while (n && !p[n - 1])
110 			n--;
111 
112 		MPI_CHECK(mbedtls_mpi_grow(mpi, n));
113 		mpi->s = hdr->sign;
114 		memcpy(mpi->p, p, n * sizeof(mbedtls_mpi_uint));
115 	}
116 }
117 
118 void TEE_BigIntInit(TEE_BigInt *bigInt, uint32_t len)
119 {
120 	struct bigint_hdr *hdr = (struct bigint_hdr *)bigInt;
121 
122 	memset(bigInt, 0, len * sizeof(uint32_t));
123 	hdr->sign = 1;
124 	if ((len - BIGINT_HDR_SIZE_IN_U32) > MBEDTLS_MPI_MAX_LIMBS)
125 		API_PANIC("Too large bigint");
126 	hdr->alloc_size = len - BIGINT_HDR_SIZE_IN_U32;
127 }
128 
129 TEE_Result TEE_BigIntConvertFromOctetString(TEE_BigInt *dest,
130 					    const uint8_t *buffer,
131 					    uint32_t bufferLen, int32_t sign)
132 {
133 	TEE_Result res;
134 	mbedtls_mpi mpi_dest;
135 
136 	get_mpi(&mpi_dest, NULL);
137 
138 	if (mbedtls_mpi_read_binary(&mpi_dest,  buffer, bufferLen))
139 		res = TEE_ERROR_OVERFLOW;
140 	else
141 		res = TEE_SUCCESS;
142 
143 	if (sign < 0)
144 		mpi_dest.s = -1;
145 
146 	if (!res)
147 		res = copy_mpi_to_bigint(&mpi_dest, dest);
148 
149 	mbedtls_mpi_free(&mpi_dest);
150 
151 	return res;
152 }
153 
154 TEE_Result TEE_BigIntConvertToOctetString(uint8_t *buffer, uint32_t *bufferLen,
155 					  const TEE_BigInt *bigInt)
156 {
157 	TEE_Result res = TEE_SUCCESS;
158 	mbedtls_mpi mpi;
159 	size_t sz;
160 
161 	get_mpi(&mpi, bigInt);
162 
163 	sz = mbedtls_mpi_size(&mpi);
164 	if (sz <= *bufferLen)
165 		MPI_CHECK(mbedtls_mpi_write_binary(&mpi, buffer, sz));
166 	else
167 		res = TEE_ERROR_SHORT_BUFFER;
168 
169 	*bufferLen = sz;
170 
171 	mbedtls_mpi_free(&mpi);
172 
173 	return res;
174 }
175 
176 void TEE_BigIntConvertFromS32(TEE_BigInt *dest, int32_t shortVal)
177 {
178 	mbedtls_mpi mpi;
179 
180 	get_mpi(&mpi, dest);
181 
182 	MPI_CHECK(mbedtls_mpi_lset(&mpi, shortVal));
183 
184 	MPI_CHECK(copy_mpi_to_bigint(&mpi, dest));
185 	mbedtls_mpi_free(&mpi);
186 }
187 
188 TEE_Result TEE_BigIntConvertToS32(int32_t *dest, const TEE_BigInt *src)
189 {
190 	TEE_Result res = TEE_SUCCESS;
191 	mbedtls_mpi mpi;
192 	uint32_t v;
193 
194 	get_mpi(&mpi, src);
195 
196 	if (mbedtls_mpi_write_binary(&mpi, (void *)&v, sizeof(v))) {
197 		res = TEE_ERROR_OVERFLOW;
198 		goto out;
199 	}
200 
201 	if (mpi.s > 0) {
202 		if (ADD_OVERFLOW(0, TEE_U32_FROM_BIG_ENDIAN(v), dest))
203 			res = TEE_ERROR_OVERFLOW;
204 	} else {
205 		if (SUB_OVERFLOW(0, TEE_U32_FROM_BIG_ENDIAN(v), dest))
206 			res = TEE_ERROR_OVERFLOW;
207 	}
208 
209 out:
210 	mbedtls_mpi_free(&mpi);
211 
212 	return res;
213 }
214 
215 int32_t TEE_BigIntCmp(const TEE_BigInt *op1, const TEE_BigInt *op2)
216 {
217 	mbedtls_mpi mpi1;
218 	mbedtls_mpi mpi2;
219 	int32_t rc;
220 
221 	get_mpi(&mpi1, op1);
222 	get_mpi(&mpi2, op2);
223 
224 	rc = mbedtls_mpi_cmp_mpi(&mpi1, &mpi2);
225 
226 	mbedtls_mpi_free(&mpi1);
227 	mbedtls_mpi_free(&mpi2);
228 
229 	return rc;
230 }
231 
232 int32_t TEE_BigIntCmpS32(const TEE_BigInt *op, int32_t shortVal)
233 {
234 	mbedtls_mpi mpi;
235 	int32_t rc;
236 
237 	get_mpi(&mpi, op);
238 
239 	rc = mbedtls_mpi_cmp_int(&mpi, shortVal);
240 
241 	mbedtls_mpi_free(&mpi);
242 
243 	return rc;
244 }
245 
/*
 * dest = op shifted right by @bits (mbedtls_mpi_shift_r()).
 *
 * @dest may be the same bigint as @op, in which case the shift is done in
 * place. Panics if the result does not fit in @dest.
 */
void TEE_BigIntShiftRight(TEE_BigInt *dest, const TEE_BigInt *op, size_t bits)
{
	mbedtls_mpi mpi_dest;
	mbedtls_mpi mpi_op;

	get_mpi(&mpi_dest, dest);

	/* In-place: mpi_dest already holds the value of op */
	if (dest == op) {
		MPI_CHECK(mbedtls_mpi_shift_r(&mpi_dest, bits));
		goto out;
	}

	get_mpi(&mpi_op, op);

	/*
	 * Both branches produce the same value in mpi_dest. They differ only
	 * in whether the unshifted op is ever copied into mpi_dest.
	 * NOTE(review): mbedtls_mpi_size() is the byte size of the current
	 * value held in dest, not dest's allocated capacity.
	 */
	if (mbedtls_mpi_size(&mpi_dest) >= mbedtls_mpi_size(&mpi_op)) {
		MPI_CHECK(mbedtls_mpi_copy(&mpi_dest, &mpi_op));
		MPI_CHECK(mbedtls_mpi_shift_r(&mpi_dest, bits));
	} else {
		mbedtls_mpi mpi_t;

		get_mpi(&mpi_t, NULL);

		/*
		 * We're using a temporary buffer to avoid the corner case
		 * where destination is unexpectedly overflowed by up to
		 * @bits number of bits.
		 */
		MPI_CHECK(mbedtls_mpi_copy(&mpi_t, &mpi_op));
		MPI_CHECK(mbedtls_mpi_shift_r(&mpi_t, bits));
		MPI_CHECK(mbedtls_mpi_copy(&mpi_dest, &mpi_t));

		mbedtls_mpi_free(&mpi_t);
	}

	mbedtls_mpi_free(&mpi_op);

out:
	/* Panics if the shifted value does not fit in dest */
	MPI_CHECK(copy_mpi_to_bigint(&mpi_dest, dest));
	mbedtls_mpi_free(&mpi_dest);
}
286 
287 bool TEE_BigIntGetBit(const TEE_BigInt *src, uint32_t bitIndex)
288 {
289 	bool rc;
290 	mbedtls_mpi mpi;
291 
292 	get_mpi(&mpi, src);
293 
294 	rc = mbedtls_mpi_get_bit(&mpi, bitIndex);
295 
296 	mbedtls_mpi_free(&mpi);
297 
298 	return rc;
299 }
300 
301 uint32_t TEE_BigIntGetBitCount(const TEE_BigInt *src)
302 {
303 	uint32_t rc;
304 	mbedtls_mpi mpi;
305 
306 	get_mpi(&mpi, src);
307 
308 	rc = mbedtls_mpi_bitlen(&mpi);
309 
310 	mbedtls_mpi_free(&mpi);
311 
312 	return rc;
313 }
314 
315 static void bigint_binary(TEE_BigInt *dest, const TEE_BigInt *op1,
316 			  const TEE_BigInt *op2,
317 			  int (*func)(mbedtls_mpi *X, const mbedtls_mpi *A,
318 				      const mbedtls_mpi *B))
319 {
320 	mbedtls_mpi mpi_dest;
321 	mbedtls_mpi mpi_op1;
322 	mbedtls_mpi mpi_op2;
323 	mbedtls_mpi *pop1 = &mpi_op1;
324 	mbedtls_mpi *pop2 = &mpi_op2;
325 
326 	get_mpi(&mpi_dest, dest);
327 
328 	if (op1 == dest)
329 		pop1 = &mpi_dest;
330 	else
331 		get_mpi(&mpi_op1, op1);
332 
333 	if (op2 == dest)
334 		pop2 = &mpi_dest;
335 	else if (op2 == op1)
336 		pop2 = pop1;
337 	else
338 		get_mpi(&mpi_op2, op2);
339 
340 	MPI_CHECK(func(&mpi_dest, pop1, pop2));
341 
342 	MPI_CHECK(copy_mpi_to_bigint(&mpi_dest, dest));
343 	mbedtls_mpi_free(&mpi_dest);
344 	if (pop1 == &mpi_op1)
345 		mbedtls_mpi_free(&mpi_op1);
346 	if (pop2 == &mpi_op2)
347 		mbedtls_mpi_free(&mpi_op2);
348 }
349 
350 static void bigint_binary_mod(TEE_BigInt *dest, const TEE_BigInt *op1,
351 			      const TEE_BigInt *op2, const TEE_BigInt *n,
352 			      int (*func)(mbedtls_mpi *X, const mbedtls_mpi *A,
353 					  const mbedtls_mpi *B))
354 {
355 	if (TEE_BigIntCmpS32(n, 2) < 0)
356 		API_PANIC("Modulus is too short");
357 
358 	mbedtls_mpi mpi_dest;
359 	mbedtls_mpi mpi_op1;
360 	mbedtls_mpi mpi_op2;
361 	mbedtls_mpi mpi_n;
362 	mbedtls_mpi *pop1 = &mpi_op1;
363 	mbedtls_mpi *pop2 = &mpi_op2;
364 	mbedtls_mpi mpi_t;
365 
366 	get_mpi(&mpi_dest, dest);
367 	get_mpi(&mpi_n, n);
368 
369 	if (op1 == dest)
370 		pop1 = &mpi_dest;
371 	else
372 		get_mpi(&mpi_op1, op1);
373 
374 	if (op2 == dest)
375 		pop2 = &mpi_dest;
376 	else if (op2 == op1)
377 		pop2 = pop1;
378 	else
379 		get_mpi(&mpi_op2, op2);
380 
381 	get_mpi(&mpi_t, NULL);
382 
383 	MPI_CHECK(func(&mpi_t, pop1, pop2));
384 	MPI_CHECK(mbedtls_mpi_mod_mpi(&mpi_dest, &mpi_t, &mpi_n));
385 
386 	MPI_CHECK(copy_mpi_to_bigint(&mpi_dest, dest));
387 	mbedtls_mpi_free(&mpi_dest);
388 	if (pop1 == &mpi_op1)
389 		mbedtls_mpi_free(&mpi_op1);
390 	if (pop2 == &mpi_op2)
391 		mbedtls_mpi_free(&mpi_op2);
392 	mbedtls_mpi_free(&mpi_t);
393 }
394 
/*
 * dest = op1 + op2. Operands may be the same bigint as dest.
 * Panics if the sum does not fit in dest.
 */
void TEE_BigIntAdd(TEE_BigInt *dest, const TEE_BigInt *op1,
		   const TEE_BigInt *op2)
{
	bigint_binary(dest, op1, op2, mbedtls_mpi_add_mpi);
}

/*
 * dest = op1 - op2. Operands may be the same bigint as dest.
 * Panics if the difference does not fit in dest.
 */
void TEE_BigIntSub(TEE_BigInt *dest, const TEE_BigInt *op1,
		   const TEE_BigInt *op2)
{
	bigint_binary(dest, op1, op2, mbedtls_mpi_sub_mpi);
}
406 
407 void TEE_BigIntNeg(TEE_BigInt *dest, const TEE_BigInt *src)
408 {
409 	mbedtls_mpi mpi_dest;
410 
411 	get_mpi(&mpi_dest, dest);
412 
413 	if (dest != src) {
414 		mbedtls_mpi mpi_src;
415 
416 		get_mpi(&mpi_src, src);
417 
418 		MPI_CHECK(mbedtls_mpi_copy(&mpi_dest, &mpi_src));
419 
420 		mbedtls_mpi_free(&mpi_src);
421 	}
422 
423 	mpi_dest.s *= -1;
424 
425 	MPI_CHECK(copy_mpi_to_bigint(&mpi_dest, dest));
426 	mbedtls_mpi_free(&mpi_dest);
427 }
428 
429 void TEE_BigIntMul(TEE_BigInt *dest, const TEE_BigInt *op1,
430 		   const TEE_BigInt *op2)
431 {
432 	size_t bs1 = TEE_BigIntGetBitCount(op1);
433 	size_t bs2 = TEE_BigIntGetBitCount(op2);
434 	size_t s = TEE_BigIntSizeInU32(bs1) + TEE_BigIntSizeInU32(bs2);
435 	TEE_BigInt zero[TEE_BigIntSizeInU32(1)] = { 0 };
436 	TEE_BigInt *tmp = NULL;
437 
438 	tmp = mempool_alloc(mbedtls_mpi_mempool, sizeof(uint32_t) * s);
439 	if (!tmp)
440 		TEE_Panic(TEE_ERROR_OUT_OF_MEMORY);
441 
442 	TEE_BigIntInit(tmp, s);
443 	TEE_BigIntInit(zero, TEE_BigIntSizeInU32(1));
444 
445 	bigint_binary(tmp, op1, op2, mbedtls_mpi_mul_mpi);
446 
447 	TEE_BigIntAdd(dest, tmp, zero);
448 
449 	mempool_free(mbedtls_mpi_mempool, tmp);
450 }
451 
/* dest = op * op, computed with TEE_BigIntMul() */
void TEE_BigIntSquare(TEE_BigInt *dest, const TEE_BigInt *op)
{
	TEE_BigIntMul(dest, op, op);
}
456 
457 void TEE_BigIntDiv(TEE_BigInt *dest_q, TEE_BigInt *dest_r,
458 		   const TEE_BigInt *op1, const TEE_BigInt *op2)
459 {
460 	mbedtls_mpi mpi_dest_q;
461 	mbedtls_mpi mpi_dest_r;
462 	mbedtls_mpi mpi_op1;
463 	mbedtls_mpi mpi_op2;
464 	mbedtls_mpi *pop1 = &mpi_op1;
465 	mbedtls_mpi *pop2 = &mpi_op2;
466 
467 	get_mpi(&mpi_dest_q, dest_q);
468 	get_mpi(&mpi_dest_r, dest_r);
469 
470 	if (op1 == dest_q)
471 		pop1 = &mpi_dest_q;
472 	else if (op1 == dest_r)
473 		pop1 = &mpi_dest_r;
474 	else
475 		get_mpi(&mpi_op1, op1);
476 
477 	if (op2 == dest_q)
478 		pop2 = &mpi_dest_q;
479 	else if (op2 == dest_r)
480 		pop2 = &mpi_dest_r;
481 	else if (op2 == op1)
482 		pop2 = pop1;
483 	else
484 		get_mpi(&mpi_op2, op2);
485 
486 	MPI_CHECK(mbedtls_mpi_div_mpi(&mpi_dest_q, &mpi_dest_r, pop1, pop2));
487 
488 	if (dest_q)
489 		MPI_CHECK(copy_mpi_to_bigint(&mpi_dest_q, dest_q));
490 	if (dest_r)
491 		MPI_CHECK(copy_mpi_to_bigint(&mpi_dest_r, dest_r));
492 	mbedtls_mpi_free(&mpi_dest_q);
493 	mbedtls_mpi_free(&mpi_dest_r);
494 	if (pop1 == &mpi_op1)
495 		mbedtls_mpi_free(&mpi_op1);
496 	if (pop2 == &mpi_op2)
497 		mbedtls_mpi_free(&mpi_op2);
498 }
499 
500 void TEE_BigIntMod(TEE_BigInt *dest, const TEE_BigInt *op, const TEE_BigInt *n)
501 {
502 	if (TEE_BigIntCmpS32(n, 2) < 0)
503 		API_PANIC("Modulus is too short");
504 
505 	bigint_binary(dest, op, n, mbedtls_mpi_mod_mpi);
506 }
507 
/* dest = (op1 + op2) mod n; panics if n < 2 (see bigint_binary_mod()) */
void TEE_BigIntAddMod(TEE_BigInt *dest, const TEE_BigInt *op1,
		      const TEE_BigInt *op2, const TEE_BigInt *n)
{
	bigint_binary_mod(dest, op1, op2, n, mbedtls_mpi_add_mpi);
}

/* dest = (op1 - op2) mod n; panics if n < 2 (see bigint_binary_mod()) */
void TEE_BigIntSubMod(TEE_BigInt *dest, const TEE_BigInt *op1,
		      const TEE_BigInt *op2, const TEE_BigInt *n)
{
	bigint_binary_mod(dest, op1, op2, n, mbedtls_mpi_sub_mpi);
}

/* dest = (op1 * op2) mod n; panics if n < 2 (see bigint_binary_mod()) */
void TEE_BigIntMulMod(TEE_BigInt *dest, const TEE_BigInt *op1,
		      const TEE_BigInt *op2, const TEE_BigInt *n)
{
	bigint_binary_mod(dest, op1, op2, n, mbedtls_mpi_mul_mpi);
}

/* dest = (op * op) mod n, computed with TEE_BigIntMulMod() */
void TEE_BigIntSquareMod(TEE_BigInt *dest, const TEE_BigInt *op,
			 const TEE_BigInt *n)
{
	TEE_BigIntMulMod(dest, op, op, n);
}
531 
532 void TEE_BigIntInvMod(TEE_BigInt *dest, const TEE_BigInt *op,
533 		      const TEE_BigInt *n)
534 {
535 	if (TEE_BigIntCmpS32(n, 2) < 0 || TEE_BigIntCmpS32(op, 0) == 0)
536 		API_PANIC("too small modulus or trying to invert zero");
537 
538 	mbedtls_mpi mpi_dest;
539 	mbedtls_mpi mpi_op;
540 	mbedtls_mpi mpi_n;
541 	mbedtls_mpi *pop = &mpi_op;
542 
543 	get_mpi(&mpi_dest, dest);
544 	get_mpi(&mpi_n, n);
545 
546 	if (op == dest)
547 		pop = &mpi_dest;
548 	else
549 		get_mpi(&mpi_op, op);
550 
551 	MPI_CHECK(mbedtls_mpi_inv_mod(&mpi_dest, pop, &mpi_n));
552 
553 	MPI_CHECK(copy_mpi_to_bigint(&mpi_dest, dest));
554 	mbedtls_mpi_free(&mpi_dest);
555 	mbedtls_mpi_free(&mpi_n);
556 	if (pop == &mpi_op)
557 		mbedtls_mpi_free(&mpi_op);
558 }
559 
560 bool TEE_BigIntRelativePrime(const TEE_BigInt *op1, const TEE_BigInt *op2)
561 {
562 	bool rc;
563 	mbedtls_mpi mpi_op1;
564 	mbedtls_mpi mpi_op2;
565 	mbedtls_mpi *pop2 = &mpi_op2;
566 	mbedtls_mpi gcd;
567 
568 	get_mpi(&mpi_op1, op1);
569 
570 	if (op2 == op1)
571 		pop2 = &mpi_op1;
572 	else
573 		get_mpi(&mpi_op2, op2);
574 
575 	get_mpi(&gcd, NULL);
576 
577 	MPI_CHECK(mbedtls_mpi_gcd(&gcd, &mpi_op1, &mpi_op2));
578 
579 	rc = !mbedtls_mpi_cmp_int(&gcd, 1);
580 
581 	mbedtls_mpi_free(&gcd);
582 	mbedtls_mpi_free(&mpi_op1);
583 	if (pop2 == &mpi_op2)
584 		mbedtls_mpi_free(&mpi_op2);
585 
586 	return rc;
587 }
588 
589 static bool mpi_is_odd(mbedtls_mpi *x)
590 {
591 	return mbedtls_mpi_get_bit(x, 0);
592 }
593 
594 static bool mpi_is_even(mbedtls_mpi *x)
595 {
596 	return !mpi_is_odd(x);
597 }
598 
/*
 * Extended binary GCD. Based on libmpa implementation __mpa_egcd(),
 * modified to work with MPI instead.
 *
 * On return:
 *     gcd = gcd(x_in, y_in)
 *     a * x_in + b * y_in == gcd
 *
 * The only caller, TEE_BigIntComputeExtendedGcd(), passes non-negative
 * operands with x_in > y_in. It passes output MPIs distinct from the
 * inputs; note that the y_in == 0 path writes a and b before reading
 * x_in, so aliasing them would be unsafe. Panics (via MPI_CHECK) on any
 * mbedtls failure.
 */
static void mpi_egcd(mbedtls_mpi *gcd, mbedtls_mpi *a, mbedtls_mpi *b,
		     mbedtls_mpi *x_in, mbedtls_mpi *y_in)
{
	mbedtls_mpi_uint k;
	mbedtls_mpi A;
	mbedtls_mpi B;
	mbedtls_mpi C;
	mbedtls_mpi D;
	mbedtls_mpi x;
	mbedtls_mpi y;
	mbedtls_mpi u;

	get_mpi(&A, NULL);
	get_mpi(&B, NULL);
	get_mpi(&C, NULL);
	get_mpi(&D, NULL);
	get_mpi(&x, NULL);
	get_mpi(&y, NULL);
	get_mpi(&u, NULL);

	/*
	 * have y < x from assumption.
	 * gcd(x, 0) = x = 1 * x + 0 * 0.
	 */
	if (!mbedtls_mpi_cmp_int(y_in, 0)) {
		MPI_CHECK(mbedtls_mpi_lset(a, 1));
		MPI_CHECK(mbedtls_mpi_lset(b, 0));
		MPI_CHECK(mbedtls_mpi_copy(gcd, x_in));
		goto out;
	}

	MPI_CHECK(mbedtls_mpi_copy(&x, x_in));
	MPI_CHECK(mbedtls_mpi_copy(&y, y_in));

	/*
	 * Strip the common factor 2^k from x and y. It is restored on gcd
	 * by the final shift left.
	 */
	k = 0;
	while (mpi_is_even(&x) && mpi_is_even(&y)) {
		k++;
		MPI_CHECK(mbedtls_mpi_shift_r(&x, 1));
		MPI_CHECK(mbedtls_mpi_shift_r(&y, 1));
	}

	/*
	 * u and gcd (playing the role of "v") start as x and y. The loop
	 * maintains A * x + B * y == u and C * x + D * y == gcd for the
	 * reduced x and y.
	 */
	MPI_CHECK(mbedtls_mpi_copy(&u, &x));
	MPI_CHECK(mbedtls_mpi_copy(gcd, &y));
	MPI_CHECK(mbedtls_mpi_lset(&A, 1));
	MPI_CHECK(mbedtls_mpi_lset(&B, 0));
	MPI_CHECK(mbedtls_mpi_lset(&C, 0));
	MPI_CHECK(mbedtls_mpi_lset(&D, 1));

	while (mbedtls_mpi_cmp_int(&u, 0)) {
		/*
		 * Halve u while it is even. To keep A * x + B * y == u, A and
		 * B are halved too. If either is odd, y is first added to A
		 * and x subtracted from B, which leaves A * x + B * y
		 * unchanged and makes both even.
		 */
		while (mpi_is_even(&u)) {
			MPI_CHECK(mbedtls_mpi_shift_r(&u, 1));
			if (mpi_is_odd(&A) || mpi_is_odd(&B)) {
				MPI_CHECK(mbedtls_mpi_add_mpi(&A, &A, &y));
				MPI_CHECK(mbedtls_mpi_sub_mpi(&B, &B, &x));
			}
			MPI_CHECK(mbedtls_mpi_shift_r(&A, 1));
			MPI_CHECK(mbedtls_mpi_shift_r(&B, 1));
		}

		/* Same reduction for gcd with coefficients C and D */
		while (mpi_is_even(gcd)) {
			MPI_CHECK(mbedtls_mpi_shift_r(gcd, 1));
			if (mpi_is_odd(&C) || mpi_is_odd(&D)) {
				MPI_CHECK(mbedtls_mpi_add_mpi(&C, &C, &y));
				MPI_CHECK(mbedtls_mpi_sub_mpi(&D, &D, &x));
			}
			MPI_CHECK(mbedtls_mpi_shift_r(&C, 1));
			MPI_CHECK(mbedtls_mpi_shift_r(&D, 1));

		}

		/* Subtract the smaller value (and its coefficients) */
		if (mbedtls_mpi_cmp_mpi(&u, gcd) >= 0) {
			MPI_CHECK(mbedtls_mpi_sub_mpi(&u, &u, gcd));
			MPI_CHECK(mbedtls_mpi_sub_mpi(&A, &A, &C));
			MPI_CHECK(mbedtls_mpi_sub_mpi(&B, &B, &D));
		} else {
			MPI_CHECK(mbedtls_mpi_sub_mpi(gcd, gcd, &u));
			MPI_CHECK(mbedtls_mpi_sub_mpi(&C, &C, &A));
			MPI_CHECK(mbedtls_mpi_sub_mpi(&D, &D, &B));
		}
	}

	/* u reached 0: gcd holds the odd part and C, D its coefficients */
	MPI_CHECK(mbedtls_mpi_copy(a, &C));
	MPI_CHECK(mbedtls_mpi_copy(b, &D));
	MPI_CHECK(mbedtls_mpi_shift_l(gcd, k));

out:
	mbedtls_mpi_free(&A);
	mbedtls_mpi_free(&B);
	mbedtls_mpi_free(&C);
	mbedtls_mpi_free(&D);
	mbedtls_mpi_free(&x);
	mbedtls_mpi_free(&y);
	mbedtls_mpi_free(&u);
}
694 
695 void TEE_BigIntComputeExtendedGcd(TEE_BigInt *gcd, TEE_BigInt *u,
696 				  TEE_BigInt *v, const TEE_BigInt *op1,
697 				  const TEE_BigInt *op2)
698 {
699 	mbedtls_mpi mpi_gcd_res;
700 	mbedtls_mpi mpi_op1;
701 	mbedtls_mpi mpi_op2;
702 	mbedtls_mpi *pop2 = &mpi_op2;
703 
704 	get_mpi(&mpi_gcd_res, gcd);
705 	get_mpi(&mpi_op1, op1);
706 
707 	if (op2 == op1)
708 		pop2 = &mpi_op1;
709 	else
710 		get_mpi(&mpi_op2, op2);
711 
712 	if (!u && !v) {
713 		if (gcd)
714 			MPI_CHECK(mbedtls_mpi_gcd(&mpi_gcd_res, &mpi_op1,
715 						  pop2));
716 	} else {
717 		mbedtls_mpi mpi_u;
718 		mbedtls_mpi mpi_v;
719 		int8_t s1 = mpi_op1.s;
720 		int8_t s2 = pop2->s;
721 		int cmp;
722 
723 		mpi_op1.s = 1;
724 		pop2->s = 1;
725 
726 		get_mpi(&mpi_u, u);
727 		get_mpi(&mpi_v, v);
728 
729 		cmp = mbedtls_mpi_cmp_abs(&mpi_op1, pop2);
730 		if (cmp == 0) {
731 			MPI_CHECK(mbedtls_mpi_copy(&mpi_gcd_res, &mpi_op1));
732 			MPI_CHECK(mbedtls_mpi_lset(&mpi_u, 1));
733 			MPI_CHECK(mbedtls_mpi_lset(&mpi_v, 0));
734 		} else if (cmp > 0) {
735 			mpi_egcd(&mpi_gcd_res, &mpi_u, &mpi_v, &mpi_op1, pop2);
736 		} else {
737 			mpi_egcd(&mpi_gcd_res, &mpi_v, &mpi_u, pop2, &mpi_op1);
738 		}
739 
740 		mpi_u.s *= s1;
741 		mpi_v.s *= s2;
742 
743 		MPI_CHECK(copy_mpi_to_bigint(&mpi_u, u));
744 		MPI_CHECK(copy_mpi_to_bigint(&mpi_v, v));
745 		mbedtls_mpi_free(&mpi_u);
746 		mbedtls_mpi_free(&mpi_v);
747 	}
748 
749 	MPI_CHECK(copy_mpi_to_bigint(&mpi_gcd_res, gcd));
750 	mbedtls_mpi_free(&mpi_gcd_res);
751 	mbedtls_mpi_free(&mpi_op1);
752 	if (pop2 == &mpi_op2)
753 		mbedtls_mpi_free(&mpi_op2);
754 }
755 
756 static int rng_read(void *ignored __unused, unsigned char *buf, size_t blen)
757 {
758 	if (_utee_cryp_random_number_generate(buf, blen))
759 		return MBEDTLS_ERR_MPI_FILE_IO_ERROR;
760 	return 0;
761 }
762 
763 int32_t TEE_BigIntIsProbablePrime(const TEE_BigInt *op,
764 				  uint32_t confidenceLevel __unused)
765 {
766 	int rc;
767 	mbedtls_mpi mpi_op;
768 
769 	get_mpi(&mpi_op, op);
770 
771 	rc = mbedtls_mpi_is_prime(&mpi_op, rng_read, NULL);
772 
773 	mbedtls_mpi_free(&mpi_op);
774 
775 	if (rc)
776 		return 0;
777 
778 	return 1;
779 }
780 
/*
 * Not-so-fast FMM (fast modular multiplication) implementation built on
 * the regular big integer functions. An FMM value is simply an ordinary
 * bigint holding a residue modulo n, and no FMM context state is needed.
 *
 * Note that these functions (along with all the other functions in this
 * file) are only used directly by TAs doing big integer arithmetic on
 * their own. Performance of RSA operations in the TEE Internal API is not
 * affected by this.
 */
void TEE_BigIntInitFMM(TEE_BigIntFMM *bigIntFMM, uint32_t len)
{
	/* An FMM value uses the same layout as a plain TEE_BigInt */
	TEE_BigIntInit(bigIntFMM, len);
}
793 
/* This implementation keeps no FMM context state; nothing to initialize */
void TEE_BigIntInitFMMContext(TEE_BigIntFMMContext *context __unused,
			      uint32_t len __unused,
			      const TEE_BigInt *modulus __unused)
{
}

/* An FMM value needs exactly as many words as a plain bigint of that size */
uint32_t TEE_BigIntFMMSizeInU32(uint32_t modulusSizeInBits)
{
	return TEE_BigIntSizeInU32(modulusSizeInBits);
}

/* The FMM context is never used, so a minimal non-zero size is reported */
uint32_t TEE_BigIntFMMContextSizeInU32(uint32_t modulusSizeInBits __unused)
{
	/* Return something larger than 0 to keep malloc() and friends happy */
	return 1;
}
810 
/*
 * The FMM form is just the residue: dest = src mod n.
 * Panics if n < 2 (see TEE_BigIntMod()).
 */
void TEE_BigIntConvertToFMM(TEE_BigIntFMM *dest, const TEE_BigInt *src,
			    const TEE_BigInt *n,
			    const TEE_BigIntFMMContext *context __unused)
{
	TEE_BigIntMod(dest, src, n);
}
817 
818 void TEE_BigIntConvertFromFMM(TEE_BigInt *dest, const TEE_BigIntFMM *src,
819 			      const TEE_BigInt *n __unused,
820 			      const TEE_BigIntFMMContext *context __unused)
821 {
822 	mbedtls_mpi mpi_dst;
823 	mbedtls_mpi mpi_src;
824 
825 	get_mpi(&mpi_dst, dest);
826 	get_mpi(&mpi_src, src);
827 
828 	MPI_CHECK(mbedtls_mpi_copy(&mpi_dst, &mpi_src));
829 
830 	MPI_CHECK(copy_mpi_to_bigint(&mpi_dst, dest));
831 	mbedtls_mpi_free(&mpi_dst);
832 	mbedtls_mpi_free(&mpi_src);
833 }
834 
835 void TEE_BigIntComputeFMM(TEE_BigIntFMM *dest, const TEE_BigIntFMM *op1,
836 			  const TEE_BigIntFMM *op2, const TEE_BigInt *n,
837 			  const TEE_BigIntFMMContext *context __unused)
838 {
839 	mbedtls_mpi mpi_dst;
840 	mbedtls_mpi mpi_op1;
841 	mbedtls_mpi mpi_op2;
842 	mbedtls_mpi mpi_n;
843 	mbedtls_mpi mpi_t;
844 
845 	get_mpi(&mpi_dst, dest);
846 	get_mpi(&mpi_op1, op1);
847 	get_mpi(&mpi_op2, op2);
848 	get_mpi(&mpi_n, n);
849 	get_mpi(&mpi_t, NULL);
850 
851 	MPI_CHECK(mbedtls_mpi_mul_mpi(&mpi_t, &mpi_op1, &mpi_op2));
852 	MPI_CHECK(mbedtls_mpi_mod_mpi(&mpi_dst, &mpi_t, &mpi_n));
853 
854 	mbedtls_mpi_free(&mpi_t);
855 	mbedtls_mpi_free(&mpi_n);
856 	mbedtls_mpi_free(&mpi_op2);
857 	mbedtls_mpi_free(&mpi_op1);
858 	MPI_CHECK(copy_mpi_to_bigint(&mpi_dst, dest));
859 	mbedtls_mpi_free(&mpi_dst);
860 }
861