xref: /optee_os/lib/libutee/tee_api_arith_mpi.c (revision a874dbbdde1f95684861301717d07dadb095172f)
1 // SPDX-License-Identifier: BSD-2-Clause
2 /*
3  * Copyright (c) 2018, Linaro limited
4  */
5 #include <assert.h>
6 #include <mbedtls/bignum.h>
7 #include <mempool.h>
8 #include <stdio.h>
9 #include <string.h>
10 #include <tee_api.h>
11 #include <tee_arith_internal.h>
12 #include <utee_defines.h>
13 #include <utee_syscalls.h>
14 #include <util.h>
15 
16 #define MPI_MEMPOOL_SIZE	(12 * 1024)
17 
18 static void __noreturn api_panic(const char *func, int line, const char *msg)
19 {
20 	printf("Panic function %s, line %d: %s\n", func, line, msg);
21 	TEE_Panic(0xB16127 /*BIGINT*/);
22 	while (1)
23 		; /* Panic will crash the thread */
24 }
25 
26 #define API_PANIC(x) api_panic(__func__, __LINE__, x)
27 
28 static void __noreturn mpi_panic(const char *func, int line, int rc)
29 {
30 	printf("Panic function %s, line %d, code %d\n", func, line, rc);
31 	TEE_Panic(0xB16127 /*BIGINT*/);
32 	while (1)
33 		; /* Panic will crash the thread */
34 }
35 
36 #define MPI_CHECK(x) do { \
37 		int _rc = (x); \
38 		 \
39 		if (_rc) \
40 			mpi_panic(__func__, __LINE__, _rc); \
41 	} while (0)
42 
43 void _TEE_MathAPI_Init(void)
44 {
45 	static uint8_t data[MPI_MEMPOOL_SIZE] __aligned(MEMPOOL_ALIGN);
46 
47 	mbedtls_mpi_mempool = mempool_alloc_pool(data, sizeof(data), NULL);
48 	if (!mbedtls_mpi_mempool)
49 		API_PANIC("Failed to initialize memory pool");
50 }
51 
/*
 * Header occupying the first words of every TEE_BigInt. The limbs
 * (mbedtls_mpi_uint), as copied to and from mbedtls_mpi::p, follow right
 * after it, with the highest index holding the most significant limb.
 */
struct bigint_hdr {
	int32_t sign;		/* copied to/from mbedtls_mpi::s (1 or -1) */
	uint16_t alloc_size;	/* number of limbs the bigint has room for */
	uint16_t nblimbs;	/* number of limbs currently stored */
};

/* sizeof(struct bigint_hdr) expressed in uint32_t words */
#define BIGINT_HDR_SIZE_IN_U32	2
59 
60 static TEE_Result copy_mpi_to_bigint(mbedtls_mpi *mpi, TEE_BigInt *bigInt)
61 {
62 	struct bigint_hdr *hdr = (struct bigint_hdr *)bigInt;
63 	size_t n = mpi->n;
64 
65 	/* Trim of eventual insignificant zeroes */
66 	while (n && !mpi->p[n - 1])
67 		n--;
68 
69 	if (hdr->alloc_size < n)
70 		return TEE_ERROR_OVERFLOW;
71 
72 	hdr->nblimbs = n;
73 	hdr->sign = mpi->s;
74 	memcpy(hdr + 1, mpi->p, mpi->n * sizeof(mbedtls_mpi_uint));
75 
76 	return TEE_SUCCESS;
77 }
78 
79 /*
80  * Initializes a MPI.
81  *
82  * A temporary MPI is allocated and if a bigInt is supplied the MPI is
83  * initialized with the value of the bigInt.
84  */
85 static void get_mpi(mbedtls_mpi *mpi, const TEE_BigInt *bigInt)
86 {
87 	/*
88 	 * The way the GP spec is defining the bignums it's
89 	 * difficult/tricky to do it using 64-bit arithmetics given that
90 	 * we'd need 64-bit alignment of the data as well.
91 	 */
92 	COMPILE_TIME_ASSERT(sizeof(mbedtls_mpi_uint) == sizeof(uint32_t));
93 
94 	/*
95 	 * The struct bigint_hdr is the overhead added to the bigint and
96 	 * is required to take exactly 2 uint32_t.
97 	 */
98 	COMPILE_TIME_ASSERT(sizeof(struct bigint_hdr) ==
99 			    sizeof(uint32_t) * BIGINT_HDR_SIZE_IN_U32);
100 
101 	mbedtls_mpi_init_mempool(mpi);
102 
103 	if (bigInt) {
104 		const struct bigint_hdr *hdr = (struct bigint_hdr *)bigInt;
105 		const mbedtls_mpi_uint *p = (const mbedtls_mpi_uint *)(hdr + 1);
106 		size_t n = hdr->nblimbs;
107 
108 		/* Trim of eventual insignificant zeroes */
109 		while (n && !p[n - 1])
110 			n--;
111 
112 		MPI_CHECK(mbedtls_mpi_grow(mpi, n));
113 		mpi->s = hdr->sign;
114 		memcpy(mpi->p, p, n * sizeof(mbedtls_mpi_uint));
115 	}
116 }
117 
118 void TEE_BigIntInit(TEE_BigInt *bigInt, uint32_t len)
119 {
120 	memset(bigInt, 0, TEE_BigIntSizeInU32(len) * sizeof(uint32_t));
121 
122 	struct bigint_hdr *hdr = (struct bigint_hdr *)bigInt;
123 
124 
125 	hdr->sign = 1;
126 	if ((len - BIGINT_HDR_SIZE_IN_U32) > MBEDTLS_MPI_MAX_LIMBS)
127 		API_PANIC("Too large bigint");
128 	hdr->alloc_size = len - BIGINT_HDR_SIZE_IN_U32;
129 }
130 
131 TEE_Result TEE_BigIntConvertFromOctetString(TEE_BigInt *dest,
132 					    const uint8_t *buffer,
133 					    uint32_t bufferLen, int32_t sign)
134 {
135 	TEE_Result res;
136 	mbedtls_mpi mpi_dest;
137 
138 	get_mpi(&mpi_dest, NULL);
139 
140 	if (mbedtls_mpi_read_binary(&mpi_dest,  buffer, bufferLen))
141 		res = TEE_ERROR_OVERFLOW;
142 	else
143 		res = TEE_SUCCESS;
144 
145 	if (sign < 0)
146 		mpi_dest.s = -1;
147 
148 	if (!res)
149 		res = copy_mpi_to_bigint(&mpi_dest, dest);
150 
151 	mbedtls_mpi_free(&mpi_dest);
152 
153 	return res;
154 }
155 
156 TEE_Result TEE_BigIntConvertToOctetString(uint8_t *buffer, uint32_t *bufferLen,
157 					  const TEE_BigInt *bigInt)
158 {
159 	TEE_Result res = TEE_SUCCESS;
160 	mbedtls_mpi mpi;
161 	size_t sz;
162 
163 	get_mpi(&mpi, bigInt);
164 
165 	sz = mbedtls_mpi_size(&mpi);
166 	if (sz <= *bufferLen)
167 		MPI_CHECK(mbedtls_mpi_write_binary(&mpi, buffer, sz));
168 	else
169 		res = TEE_ERROR_SHORT_BUFFER;
170 
171 	*bufferLen = sz;
172 
173 	mbedtls_mpi_free(&mpi);
174 
175 	return res;
176 }
177 
178 void TEE_BigIntConvertFromS32(TEE_BigInt *dest, int32_t shortVal)
179 {
180 	mbedtls_mpi mpi;
181 
182 	get_mpi(&mpi, dest);
183 
184 	MPI_CHECK(mbedtls_mpi_lset(&mpi, shortVal));
185 
186 	MPI_CHECK(copy_mpi_to_bigint(&mpi, dest));
187 	mbedtls_mpi_free(&mpi);
188 }
189 
190 TEE_Result TEE_BigIntConvertToS32(int32_t *dest, const TEE_BigInt *src)
191 {
192 	TEE_Result res = TEE_SUCCESS;
193 	mbedtls_mpi mpi;
194 	uint32_t v;
195 
196 	get_mpi(&mpi, src);
197 
198 	if (mbedtls_mpi_write_binary(&mpi, (void *)&v, sizeof(v))) {
199 		res = TEE_ERROR_OVERFLOW;
200 		goto out;
201 	}
202 
203 	if (mpi.s > 0) {
204 		if (ADD_OVERFLOW(0, TEE_U32_FROM_BIG_ENDIAN(v), dest))
205 			res = TEE_ERROR_OVERFLOW;
206 	} else {
207 		if (SUB_OVERFLOW(0, TEE_U32_FROM_BIG_ENDIAN(v), dest))
208 			res = TEE_ERROR_OVERFLOW;
209 	}
210 
211 out:
212 	mbedtls_mpi_free(&mpi);
213 
214 	return res;
215 }
216 
217 int32_t TEE_BigIntCmp(const TEE_BigInt *op1, const TEE_BigInt *op2)
218 {
219 	mbedtls_mpi mpi1;
220 	mbedtls_mpi mpi2;
221 	int32_t rc;
222 
223 	get_mpi(&mpi1, op1);
224 	get_mpi(&mpi2, op2);
225 
226 	rc = mbedtls_mpi_cmp_mpi(&mpi1, &mpi2);
227 
228 	mbedtls_mpi_free(&mpi1);
229 	mbedtls_mpi_free(&mpi2);
230 
231 	return rc;
232 }
233 
234 int32_t TEE_BigIntCmpS32(const TEE_BigInt *op, int32_t shortVal)
235 {
236 	mbedtls_mpi mpi;
237 	int32_t rc;
238 
239 	get_mpi(&mpi, op);
240 
241 	rc = mbedtls_mpi_cmp_int(&mpi, shortVal);
242 
243 	mbedtls_mpi_free(&mpi);
244 
245 	return rc;
246 }
247 
248 void TEE_BigIntShiftRight(TEE_BigInt *dest, const TEE_BigInt *op, size_t bits)
249 {
250 	mbedtls_mpi mpi_dest;
251 	mbedtls_mpi mpi_op;
252 
253 	get_mpi(&mpi_dest, dest);
254 
255 	if (dest == op) {
256 		MPI_CHECK(mbedtls_mpi_shift_r(&mpi_dest, bits));
257 		goto out;
258 	}
259 
260 	get_mpi(&mpi_op, op);
261 
262 	if (mbedtls_mpi_size(&mpi_dest) >= mbedtls_mpi_size(&mpi_op)) {
263 		MPI_CHECK(mbedtls_mpi_copy(&mpi_dest, &mpi_op));
264 		MPI_CHECK(mbedtls_mpi_shift_r(&mpi_dest, bits));
265 	} else {
266 		mbedtls_mpi mpi_t;
267 
268 		get_mpi(&mpi_t, NULL);
269 
270 		/*
271 		 * We're using a temporary buffer to avoid the corner case
272 		 * where destination is unexpectedly overflowed by up to
273 		 * @bits number of bits.
274 		 */
275 		MPI_CHECK(mbedtls_mpi_copy(&mpi_t, &mpi_op));
276 		MPI_CHECK(mbedtls_mpi_shift_r(&mpi_t, bits));
277 		MPI_CHECK(mbedtls_mpi_copy(&mpi_dest, &mpi_t));
278 
279 		mbedtls_mpi_free(&mpi_t);
280 	}
281 
282 	mbedtls_mpi_free(&mpi_op);
283 
284 out:
285 	MPI_CHECK(copy_mpi_to_bigint(&mpi_dest, dest));
286 	mbedtls_mpi_free(&mpi_dest);
287 }
288 
289 bool TEE_BigIntGetBit(const TEE_BigInt *src, uint32_t bitIndex)
290 {
291 	bool rc;
292 	mbedtls_mpi mpi;
293 
294 	get_mpi(&mpi, src);
295 
296 	rc = mbedtls_mpi_get_bit(&mpi, bitIndex);
297 
298 	mbedtls_mpi_free(&mpi);
299 
300 	return rc;
301 }
302 
303 uint32_t TEE_BigIntGetBitCount(const TEE_BigInt *src)
304 {
305 	uint32_t rc;
306 	mbedtls_mpi mpi;
307 
308 	get_mpi(&mpi, src);
309 
310 	rc = mbedtls_mpi_bitlen(&mpi);
311 
312 	mbedtls_mpi_free(&mpi);
313 
314 	return rc;
315 }
316 
317 static void bigint_binary(TEE_BigInt *dest, const TEE_BigInt *op1,
318 			  const TEE_BigInt *op2,
319 			  int (*func)(mbedtls_mpi *X, const mbedtls_mpi *A,
320 				      const mbedtls_mpi *B))
321 {
322 	mbedtls_mpi mpi_dest;
323 	mbedtls_mpi mpi_op1;
324 	mbedtls_mpi mpi_op2;
325 	mbedtls_mpi *pop1 = &mpi_op1;
326 	mbedtls_mpi *pop2 = &mpi_op2;
327 
328 	get_mpi(&mpi_dest, dest);
329 
330 	if (op1 == dest)
331 		pop1 = &mpi_dest;
332 	else
333 		get_mpi(&mpi_op1, op1);
334 
335 	if (op2 == dest)
336 		pop2 = &mpi_dest;
337 	else if (op2 == op1)
338 		pop2 = pop1;
339 	else
340 		get_mpi(&mpi_op2, op2);
341 
342 	MPI_CHECK(func(&mpi_dest, pop1, pop2));
343 
344 	MPI_CHECK(copy_mpi_to_bigint(&mpi_dest, dest));
345 	mbedtls_mpi_free(&mpi_dest);
346 	if (pop1 == &mpi_op1)
347 		mbedtls_mpi_free(&mpi_op1);
348 	if (pop2 == &mpi_op2)
349 		mbedtls_mpi_free(&mpi_op2);
350 }
351 
352 static void bigint_binary_mod(TEE_BigInt *dest, const TEE_BigInt *op1,
353 			      const TEE_BigInt *op2, const TEE_BigInt *n,
354 			      int (*func)(mbedtls_mpi *X, const mbedtls_mpi *A,
355 					  const mbedtls_mpi *B))
356 {
357 	if (TEE_BigIntCmpS32(n, 2) < 0)
358 		API_PANIC("Modulus is too short");
359 
360 	mbedtls_mpi mpi_dest;
361 	mbedtls_mpi mpi_op1;
362 	mbedtls_mpi mpi_op2;
363 	mbedtls_mpi mpi_n;
364 	mbedtls_mpi *pop1 = &mpi_op1;
365 	mbedtls_mpi *pop2 = &mpi_op2;
366 	mbedtls_mpi mpi_t;
367 
368 	get_mpi(&mpi_dest, dest);
369 	get_mpi(&mpi_n, n);
370 
371 	if (op1 == dest)
372 		pop1 = &mpi_dest;
373 	else
374 		get_mpi(&mpi_op1, op1);
375 
376 	if (op2 == dest)
377 		pop2 = &mpi_dest;
378 	else if (op2 == op1)
379 		pop2 = pop1;
380 	else
381 		get_mpi(&mpi_op2, op2);
382 
383 	get_mpi(&mpi_t, NULL);
384 
385 	MPI_CHECK(func(&mpi_t, pop1, pop2));
386 	MPI_CHECK(mbedtls_mpi_mod_mpi(&mpi_dest, &mpi_t, &mpi_n));
387 
388 	MPI_CHECK(copy_mpi_to_bigint(&mpi_dest, dest));
389 	mbedtls_mpi_free(&mpi_dest);
390 	if (pop1 == &mpi_op1)
391 		mbedtls_mpi_free(&mpi_op1);
392 	if (pop2 == &mpi_op2)
393 		mbedtls_mpi_free(&mpi_op2);
394 	mbedtls_mpi_free(&mpi_t);
395 }
396 
397 void TEE_BigIntAdd(TEE_BigInt *dest, const TEE_BigInt *op1,
398 		   const TEE_BigInt *op2)
399 {
400 	bigint_binary(dest, op1, op2, mbedtls_mpi_add_mpi);
401 }
402 
403 void TEE_BigIntSub(TEE_BigInt *dest, const TEE_BigInt *op1,
404 		   const TEE_BigInt *op2)
405 {
406 	bigint_binary(dest, op1, op2, mbedtls_mpi_sub_mpi);
407 }
408 
409 void TEE_BigIntNeg(TEE_BigInt *dest, const TEE_BigInt *src)
410 {
411 	mbedtls_mpi mpi_dest;
412 
413 	get_mpi(&mpi_dest, dest);
414 
415 	if (dest != src) {
416 		mbedtls_mpi mpi_src;
417 
418 		get_mpi(&mpi_src, src);
419 
420 		MPI_CHECK(mbedtls_mpi_copy(&mpi_dest, &mpi_src));
421 
422 		mbedtls_mpi_free(&mpi_src);
423 	}
424 
425 	mpi_dest.s *= -1;
426 
427 	MPI_CHECK(copy_mpi_to_bigint(&mpi_dest, dest));
428 	mbedtls_mpi_free(&mpi_dest);
429 }
430 
431 void TEE_BigIntMul(TEE_BigInt *dest, const TEE_BigInt *op1,
432 		   const TEE_BigInt *op2)
433 {
434 	size_t bs1 = TEE_BigIntGetBitCount(op1);
435 	size_t bs2 = TEE_BigIntGetBitCount(op2);
436 	size_t s = TEE_BigIntSizeInU32(bs1) + TEE_BigIntSizeInU32(bs2);
437 	TEE_BigInt zero[TEE_BigIntSizeInU32(1)] = { 0 };
438 	TEE_BigInt *tmp = NULL;
439 
440 	tmp = mempool_alloc(mbedtls_mpi_mempool, sizeof(uint32_t) * s);
441 	if (!tmp)
442 		TEE_Panic(TEE_ERROR_OUT_OF_MEMORY);
443 
444 	TEE_BigIntInit(tmp, s);
445 	TEE_BigIntInit(zero, TEE_BigIntSizeInU32(1));
446 
447 	bigint_binary(tmp, op1, op2, mbedtls_mpi_mul_mpi);
448 
449 	TEE_BigIntAdd(dest, tmp, zero);
450 
451 	mempool_free(mbedtls_mpi_mempool, tmp);
452 }
453 
454 void TEE_BigIntSquare(TEE_BigInt *dest, const TEE_BigInt *op)
455 {
456 	TEE_BigIntMul(dest, op, op);
457 }
458 
459 void TEE_BigIntDiv(TEE_BigInt *dest_q, TEE_BigInt *dest_r,
460 		   const TEE_BigInt *op1, const TEE_BigInt *op2)
461 {
462 	mbedtls_mpi mpi_dest_q;
463 	mbedtls_mpi mpi_dest_r;
464 	mbedtls_mpi mpi_op1;
465 	mbedtls_mpi mpi_op2;
466 	mbedtls_mpi *pop1 = &mpi_op1;
467 	mbedtls_mpi *pop2 = &mpi_op2;
468 
469 	get_mpi(&mpi_dest_q, dest_q);
470 	get_mpi(&mpi_dest_r, dest_r);
471 
472 	if (op1 == dest_q)
473 		pop1 = &mpi_dest_q;
474 	else if (op1 == dest_r)
475 		pop1 = &mpi_dest_r;
476 	else
477 		get_mpi(&mpi_op1, op1);
478 
479 	if (op2 == dest_q)
480 		pop2 = &mpi_dest_q;
481 	else if (op2 == dest_r)
482 		pop2 = &mpi_dest_r;
483 	else if (op2 == op1)
484 		pop2 = pop1;
485 	else
486 		get_mpi(&mpi_op2, op2);
487 
488 	MPI_CHECK(mbedtls_mpi_div_mpi(&mpi_dest_q, &mpi_dest_r, pop1, pop2));
489 
490 	if (dest_q)
491 		MPI_CHECK(copy_mpi_to_bigint(&mpi_dest_q, dest_q));
492 	if (dest_r)
493 		MPI_CHECK(copy_mpi_to_bigint(&mpi_dest_r, dest_r));
494 	mbedtls_mpi_free(&mpi_dest_q);
495 	mbedtls_mpi_free(&mpi_dest_r);
496 	if (pop1 == &mpi_op1)
497 		mbedtls_mpi_free(&mpi_op1);
498 	if (pop2 == &mpi_op2)
499 		mbedtls_mpi_free(&mpi_op2);
500 }
501 
502 void TEE_BigIntMod(TEE_BigInt *dest, const TEE_BigInt *op, const TEE_BigInt *n)
503 {
504 	if (TEE_BigIntCmpS32(n, 2) < 0)
505 		API_PANIC("Modulus is too short");
506 
507 	bigint_binary(dest, op, n, mbedtls_mpi_mod_mpi);
508 }
509 
510 void TEE_BigIntAddMod(TEE_BigInt *dest, const TEE_BigInt *op1,
511 		      const TEE_BigInt *op2, const TEE_BigInt *n)
512 {
513 	bigint_binary_mod(dest, op1, op2, n, mbedtls_mpi_add_mpi);
514 }
515 
516 void TEE_BigIntSubMod(TEE_BigInt *dest, const TEE_BigInt *op1,
517 		      const TEE_BigInt *op2, const TEE_BigInt *n)
518 {
519 	bigint_binary_mod(dest, op1, op2, n, mbedtls_mpi_sub_mpi);
520 }
521 
522 void TEE_BigIntMulMod(TEE_BigInt *dest, const TEE_BigInt *op1,
523 		      const TEE_BigInt *op2, const TEE_BigInt *n)
524 {
525 	bigint_binary_mod(dest, op1, op2, n, mbedtls_mpi_mul_mpi);
526 }
527 
528 void TEE_BigIntSquareMod(TEE_BigInt *dest, const TEE_BigInt *op,
529 			 const TEE_BigInt *n)
530 {
531 	TEE_BigIntMulMod(dest, op, op, n);
532 }
533 
534 void TEE_BigIntInvMod(TEE_BigInt *dest, const TEE_BigInt *op,
535 		      const TEE_BigInt *n)
536 {
537 	if (TEE_BigIntCmpS32(n, 2) < 0 || TEE_BigIntCmpS32(op, 0) == 0)
538 		API_PANIC("too small modulus or trying to invert zero");
539 
540 	mbedtls_mpi mpi_dest;
541 	mbedtls_mpi mpi_op;
542 	mbedtls_mpi mpi_n;
543 	mbedtls_mpi *pop = &mpi_op;
544 
545 	get_mpi(&mpi_dest, dest);
546 	get_mpi(&mpi_n, n);
547 
548 	if (op == dest)
549 		pop = &mpi_dest;
550 	else
551 		get_mpi(&mpi_op, op);
552 
553 	MPI_CHECK(mbedtls_mpi_inv_mod(&mpi_dest, pop, &mpi_n));
554 
555 	MPI_CHECK(copy_mpi_to_bigint(&mpi_dest, dest));
556 	mbedtls_mpi_free(&mpi_dest);
557 	mbedtls_mpi_free(&mpi_n);
558 	if (pop == &mpi_op)
559 		mbedtls_mpi_free(&mpi_op);
560 }
561 
562 bool TEE_BigIntRelativePrime(const TEE_BigInt *op1, const TEE_BigInt *op2)
563 {
564 	bool rc;
565 	mbedtls_mpi mpi_op1;
566 	mbedtls_mpi mpi_op2;
567 	mbedtls_mpi *pop2 = &mpi_op2;
568 	mbedtls_mpi gcd;
569 
570 	get_mpi(&mpi_op1, op1);
571 
572 	if (op2 == op1)
573 		pop2 = &mpi_op1;
574 	else
575 		get_mpi(&mpi_op2, op2);
576 
577 	get_mpi(&gcd, NULL);
578 
579 	MPI_CHECK(mbedtls_mpi_gcd(&gcd, &mpi_op1, &mpi_op2));
580 
581 	rc = !mbedtls_mpi_cmp_int(&gcd, 1);
582 
583 	mbedtls_mpi_free(&gcd);
584 	mbedtls_mpi_free(&mpi_op1);
585 	if (pop2 == &mpi_op2)
586 		mbedtls_mpi_free(&mpi_op2);
587 
588 	return rc;
589 }
590 
591 static bool mpi_is_odd(mbedtls_mpi *x)
592 {
593 	return mbedtls_mpi_get_bit(x, 0);
594 }
595 
596 static bool mpi_is_even(mbedtls_mpi *x)
597 {
598 	return !mpi_is_odd(x);
599 }
600 
/*
 * Based on libmpa implementation __mpa_egcd(), modified to work with MPI
 * instead.
 *
 * Binary extended Euclidean algorithm. On return:
 *   @gcd = gcd(@x_in, @y_in)
 *   @a * @x_in + @b * @y_in = @gcd
 *
 * Assumes @x_in > @y_in >= 0. The caller in this file,
 * TEE_BigIntComputeExtendedGcd(), passes non-negative operands with the
 * larger one first. NOTE(review): the steps appear to follow HAC
 * Algorithm 14.61 (binary extended gcd); confirm before modifying.
 *
 * Every MPI operation is wrapped in MPI_CHECK(), so any failure panics.
 */
static void mpi_egcd(mbedtls_mpi *gcd, mbedtls_mpi *a, mbedtls_mpi *b,
		     mbedtls_mpi *x_in, mbedtls_mpi *y_in)
{
	mbedtls_mpi_uint k;
	mbedtls_mpi A;
	mbedtls_mpi B;
	mbedtls_mpi C;
	mbedtls_mpi D;
	mbedtls_mpi x;
	mbedtls_mpi y;
	mbedtls_mpi u;

	get_mpi(&A, NULL);
	get_mpi(&B, NULL);
	get_mpi(&C, NULL);
	get_mpi(&D, NULL);
	get_mpi(&x, NULL);
	get_mpi(&y, NULL);
	get_mpi(&u, NULL);

	/* have y < x from assumption */
	/* gcd(x, 0) = x, satisfied by a = 1, b = 0 */
	if (!mbedtls_mpi_cmp_int(y_in, 0)) {
		MPI_CHECK(mbedtls_mpi_lset(a, 1));
		MPI_CHECK(mbedtls_mpi_lset(b, 0));
		MPI_CHECK(mbedtls_mpi_copy(gcd, x_in));
		goto out;
	}

	MPI_CHECK(mbedtls_mpi_copy(&x, x_in));
	MPI_CHECK(mbedtls_mpi_copy(&y, y_in));

	/*
	 * Strip the common factor 2^k from x and y. It is multiplied back
	 * into gcd after the main loop.
	 */
	k = 0;
	while (mpi_is_even(&x) && mpi_is_even(&y)) {
		k++;
		MPI_CHECK(mbedtls_mpi_shift_r(&x, 1));
		MPI_CHECK(mbedtls_mpi_shift_r(&y, 1));
	}

	/*
	 * u and gcd play the roles of u and v in the binary algorithm. The
	 * coefficients are kept so that A*x + B*y = u and C*x + D*y = gcd.
	 */
	MPI_CHECK(mbedtls_mpi_copy(&u, &x));
	MPI_CHECK(mbedtls_mpi_copy(gcd, &y));
	MPI_CHECK(mbedtls_mpi_lset(&A, 1));
	MPI_CHECK(mbedtls_mpi_lset(&B, 0));
	MPI_CHECK(mbedtls_mpi_lset(&C, 0));
	MPI_CHECK(mbedtls_mpi_lset(&D, 1));

	while (mbedtls_mpi_cmp_int(&u, 0)) {
		/*
		 * Halve u until it is odd. When A or B is odd, A + y and
		 * B - x are added in first so that A and B can be halved
		 * while A*x + B*y = u keeps holding.
		 */
		while (mpi_is_even(&u)) {
			MPI_CHECK(mbedtls_mpi_shift_r(&u, 1));
			if (mpi_is_odd(&A) || mpi_is_odd(&B)) {
				MPI_CHECK(mbedtls_mpi_add_mpi(&A, &A, &y));
				MPI_CHECK(mbedtls_mpi_sub_mpi(&B, &B, &x));
			}
			MPI_CHECK(mbedtls_mpi_shift_r(&A, 1));
			MPI_CHECK(mbedtls_mpi_shift_r(&B, 1));
		}

		/* Same for gcd with its coefficients C and D */
		while (mpi_is_even(gcd)) {
			MPI_CHECK(mbedtls_mpi_shift_r(gcd, 1));
			if (mpi_is_odd(&C) || mpi_is_odd(&D)) {
				MPI_CHECK(mbedtls_mpi_add_mpi(&C, &C, &y));
				MPI_CHECK(mbedtls_mpi_sub_mpi(&D, &D, &x));
			}
			MPI_CHECK(mbedtls_mpi_shift_r(&C, 1));
			MPI_CHECK(mbedtls_mpi_shift_r(&D, 1));

		}

		/* Subtract the smaller of u and gcd from the larger one */
		if (mbedtls_mpi_cmp_mpi(&u, gcd) >= 0) {
			MPI_CHECK(mbedtls_mpi_sub_mpi(&u, &u, gcd));
			MPI_CHECK(mbedtls_mpi_sub_mpi(&A, &A, &C));
			MPI_CHECK(mbedtls_mpi_sub_mpi(&B, &B, &D));
		} else {
			MPI_CHECK(mbedtls_mpi_sub_mpi(gcd, gcd, &u));
			MPI_CHECK(mbedtls_mpi_sub_mpi(&C, &C, &A));
			MPI_CHECK(mbedtls_mpi_sub_mpi(&D, &D, &B));
		}
	}

	/*
	 * u reached zero: C and D are the coefficients for gcd. Restore the
	 * common factor 2^k stripped above.
	 */
	MPI_CHECK(mbedtls_mpi_copy(a, &C));
	MPI_CHECK(mbedtls_mpi_copy(b, &D));
	MPI_CHECK(mbedtls_mpi_shift_l(gcd, k));

out:
	mbedtls_mpi_free(&A);
	mbedtls_mpi_free(&B);
	mbedtls_mpi_free(&C);
	mbedtls_mpi_free(&D);
	mbedtls_mpi_free(&x);
	mbedtls_mpi_free(&y);
	mbedtls_mpi_free(&u);
}
696 
697 void TEE_BigIntComputeExtendedGcd(TEE_BigInt *gcd, TEE_BigInt *u,
698 				  TEE_BigInt *v, const TEE_BigInt *op1,
699 				  const TEE_BigInt *op2)
700 {
701 	mbedtls_mpi mpi_gcd_res;
702 	mbedtls_mpi mpi_op1;
703 	mbedtls_mpi mpi_op2;
704 	mbedtls_mpi *pop2 = &mpi_op2;
705 
706 	get_mpi(&mpi_gcd_res, gcd);
707 	get_mpi(&mpi_op1, op1);
708 
709 	if (op2 == op1)
710 		pop2 = &mpi_op1;
711 	else
712 		get_mpi(&mpi_op2, op2);
713 
714 	if (!u && !v) {
715 		if (gcd)
716 			MPI_CHECK(mbedtls_mpi_gcd(&mpi_gcd_res, &mpi_op1,
717 						  pop2));
718 	} else {
719 		mbedtls_mpi mpi_u;
720 		mbedtls_mpi mpi_v;
721 		int8_t s1 = mpi_op1.s;
722 		int8_t s2 = pop2->s;
723 		int cmp;
724 
725 		mpi_op1.s = 1;
726 		pop2->s = 1;
727 
728 		get_mpi(&mpi_u, u);
729 		get_mpi(&mpi_v, v);
730 
731 		cmp = mbedtls_mpi_cmp_abs(&mpi_op1, pop2);
732 		if (cmp == 0) {
733 			MPI_CHECK(mbedtls_mpi_copy(&mpi_gcd_res, &mpi_op1));
734 			MPI_CHECK(mbedtls_mpi_lset(&mpi_u, 1));
735 			MPI_CHECK(mbedtls_mpi_lset(&mpi_v, 0));
736 		} else if (cmp > 0) {
737 			mpi_egcd(&mpi_gcd_res, &mpi_u, &mpi_v, &mpi_op1, pop2);
738 		} else {
739 			mpi_egcd(&mpi_gcd_res, &mpi_v, &mpi_u, pop2, &mpi_op1);
740 		}
741 
742 		mpi_u.s *= s1;
743 		mpi_v.s *= s2;
744 
745 		MPI_CHECK(copy_mpi_to_bigint(&mpi_u, u));
746 		MPI_CHECK(copy_mpi_to_bigint(&mpi_v, v));
747 		mbedtls_mpi_free(&mpi_u);
748 		mbedtls_mpi_free(&mpi_v);
749 	}
750 
751 	MPI_CHECK(copy_mpi_to_bigint(&mpi_gcd_res, gcd));
752 	mbedtls_mpi_free(&mpi_gcd_res);
753 	mbedtls_mpi_free(&mpi_op1);
754 	if (pop2 == &mpi_op2)
755 		mbedtls_mpi_free(&mpi_op2);
756 }
757 
758 static int rng_read(void *ignored __unused, unsigned char *buf, size_t blen)
759 {
760 	if (_utee_cryp_random_number_generate(buf, blen))
761 		return MBEDTLS_ERR_MPI_FILE_IO_ERROR;
762 	return 0;
763 }
764 
765 int32_t TEE_BigIntIsProbablePrime(const TEE_BigInt *op,
766 				  uint32_t confidenceLevel __unused)
767 {
768 	int rc;
769 	mbedtls_mpi mpi_op;
770 
771 	get_mpi(&mpi_op, op);
772 
773 	rc = mbedtls_mpi_is_prime(&mpi_op, rng_read, NULL);
774 
775 	mbedtls_mpi_free(&mpi_op);
776 
777 	if (rc)
778 		return 0;
779 
780 	return 1;
781 }
782 
783 /*
784  * Not so fast FMM implementation based on the normal big int functions.
785  *
 * Note that these functions (along with all the other functions in this
 * file) are only used directly by TAs doing bigint arithmetic on their
 * own. The performance of RSA operations in the TEE Internal API is not
 * affected by this.
790  */
791 void TEE_BigIntInitFMM(TEE_BigIntFMM *bigIntFMM, uint32_t len)
792 {
793 	TEE_BigIntInit(bigIntFMM, len);
794 }
795 
796 void TEE_BigIntInitFMMContext(TEE_BigIntFMMContext *context __unused,
797 			      uint32_t len __unused,
798 			      const TEE_BigInt *modulus __unused)
799 {
800 }
801 
/*
 * An FMM bigint takes as many uint32_t words as an ordinary bigint with
 * @modulusSizeInBits bits.
 */
uint32_t TEE_BigIntFMMSizeInU32(uint32_t modulusSizeInBits)
{
	uint32_t words = TEE_BigIntSizeInU32(modulusSizeInBits);

	return words;
}
806 
uint32_t TEE_BigIntFMMContextSizeInU32(uint32_t modulusSizeInBits)
{
	(void)modulusSizeInBits;

	/*
	 * The context is never used, but report a non-zero size to keep
	 * malloc() and friends happy.
	 */
	return 1;
}
812 
813 void TEE_BigIntConvertToFMM(TEE_BigIntFMM *dest, const TEE_BigInt *src,
814 			    const TEE_BigInt *n,
815 			    const TEE_BigIntFMMContext *context __unused)
816 {
817 	TEE_BigIntMod(dest, src, n);
818 }
819 
820 void TEE_BigIntConvertFromFMM(TEE_BigInt *dest, const TEE_BigIntFMM *src,
821 			      const TEE_BigInt *n __unused,
822 			      const TEE_BigIntFMMContext *context __unused)
823 {
824 	mbedtls_mpi mpi_dst;
825 	mbedtls_mpi mpi_src;
826 
827 	get_mpi(&mpi_dst, dest);
828 	get_mpi(&mpi_src, src);
829 
830 	MPI_CHECK(mbedtls_mpi_copy(&mpi_dst, &mpi_src));
831 
832 	MPI_CHECK(copy_mpi_to_bigint(&mpi_dst, dest));
833 	mbedtls_mpi_free(&mpi_dst);
834 	mbedtls_mpi_free(&mpi_src);
835 }
836 
837 void TEE_BigIntComputeFMM(TEE_BigIntFMM *dest, const TEE_BigIntFMM *op1,
838 			  const TEE_BigIntFMM *op2, const TEE_BigInt *n,
839 			  const TEE_BigIntFMMContext *context __unused)
840 {
841 	mbedtls_mpi mpi_dst;
842 	mbedtls_mpi mpi_op1;
843 	mbedtls_mpi mpi_op2;
844 	mbedtls_mpi mpi_n;
845 	mbedtls_mpi mpi_t;
846 
847 	get_mpi(&mpi_dst, dest);
848 	get_mpi(&mpi_op1, op1);
849 	get_mpi(&mpi_op2, op2);
850 	get_mpi(&mpi_n, n);
851 	get_mpi(&mpi_t, NULL);
852 
853 	MPI_CHECK(mbedtls_mpi_mul_mpi(&mpi_t, &mpi_op1, &mpi_op2));
854 	MPI_CHECK(mbedtls_mpi_mod_mpi(&mpi_dst, &mpi_t, &mpi_n));
855 
856 	mbedtls_mpi_free(&mpi_t);
857 	mbedtls_mpi_free(&mpi_n);
858 	mbedtls_mpi_free(&mpi_op2);
859 	mbedtls_mpi_free(&mpi_op1);
860 	MPI_CHECK(copy_mpi_to_bigint(&mpi_dst, dest));
861 	mbedtls_mpi_free(&mpi_dst);
862 }
863