xref: /optee_os/lib/libmbedtls/core/rsa.c (revision f5c3d85a579c9594ee7592af0c0783891c21d9e0)
1 // SPDX-License-Identifier: BSD-2-Clause
2 /*
3  * Copyright (C) 2018, ARM Limited
4  * Copyright (C) 2019, Linaro Limited
5  */
6 
7 #include <assert.h>
8 #include <crypto/crypto.h>
9 #include <crypto/crypto_impl.h>
10 #include <mbedtls/ctr_drbg.h>
11 #include <mbedtls/entropy.h>
12 #include <mbedtls/pk.h>
13 #include <mbedtls/pk_internal.h>
14 #include <stdlib.h>
15 #include <string.h>
16 #include <tee/tee_cryp_utl.h>
17 #include <utee_defines.h>
18 #include <fault_mitigation.h>
19 
20 #include "mbed_helpers.h"
21 
22 static TEE_Result get_tee_result(int lmd_res)
23 {
24 	switch (lmd_res) {
25 	case 0:
26 		return TEE_SUCCESS;
27 	case MBEDTLS_ERR_RSA_PRIVATE_FAILED +
28 		MBEDTLS_ERR_MPI_BAD_INPUT_DATA:
29 	case MBEDTLS_ERR_RSA_BAD_INPUT_DATA:
30 	case MBEDTLS_ERR_RSA_INVALID_PADDING:
31 	case MBEDTLS_ERR_PK_TYPE_MISMATCH:
32 		return TEE_ERROR_BAD_PARAMETERS;
33 	case MBEDTLS_ERR_RSA_OUTPUT_TOO_LARGE:
34 		return TEE_ERROR_SHORT_BUFFER;
35 	default:
36 		return TEE_ERROR_BAD_STATE;
37 	}
38 }
39 
40 static uint32_t tee_algo_to_mbedtls_hash_algo(uint32_t algo)
41 {
42 	switch (algo) {
43 #if defined(CFG_CRYPTO_SHA1)
44 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA1:
45 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA1:
46 	case TEE_ALG_RSAES_PKCS1_OAEP_MGF1_SHA1:
47 	case TEE_ALG_SHA1:
48 	case TEE_ALG_DSA_SHA1:
49 	case TEE_ALG_HMAC_SHA1:
50 		return MBEDTLS_MD_SHA1;
51 #endif
52 #if defined(CFG_CRYPTO_MD5)
53 	case TEE_ALG_RSASSA_PKCS1_V1_5_MD5:
54 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_MD5:
55 	case TEE_ALG_RSAES_PKCS1_OAEP_MGF1_MD5:
56 	case TEE_ALG_MD5:
57 	case TEE_ALG_HMAC_MD5:
58 		return MBEDTLS_MD_MD5;
59 #endif
60 #if defined(CFG_CRYPTO_SHA224)
61 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA224:
62 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA224:
63 	case TEE_ALG_RSAES_PKCS1_OAEP_MGF1_SHA224:
64 	case TEE_ALG_SHA224:
65 	case TEE_ALG_DSA_SHA224:
66 	case TEE_ALG_HMAC_SHA224:
67 		return MBEDTLS_MD_SHA224;
68 #endif
69 #if defined(CFG_CRYPTO_SHA256)
70 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA256:
71 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA256:
72 	case TEE_ALG_RSAES_PKCS1_OAEP_MGF1_SHA256:
73 	case TEE_ALG_SHA256:
74 	case TEE_ALG_DSA_SHA256:
75 	case TEE_ALG_HMAC_SHA256:
76 		return MBEDTLS_MD_SHA256;
77 #endif
78 #if defined(CFG_CRYPTO_SHA384)
79 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA384:
80 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA384:
81 	case TEE_ALG_RSAES_PKCS1_OAEP_MGF1_SHA384:
82 	case TEE_ALG_SHA384:
83 	case TEE_ALG_HMAC_SHA384:
84 		return MBEDTLS_MD_SHA384;
85 #endif
86 #if defined(CFG_CRYPTO_SHA512)
87 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA512:
88 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA512:
89 	case TEE_ALG_RSAES_PKCS1_OAEP_MGF1_SHA512:
90 	case TEE_ALG_SHA512:
91 	case TEE_ALG_HMAC_SHA512:
92 		return MBEDTLS_MD_SHA512;
93 #endif
94 	default:
95 		return MBEDTLS_MD_NONE;
96 	}
97 }
98 
99 static void rsa_init_from_key_pair(mbedtls_rsa_context *rsa,
100 				struct rsa_keypair *key)
101 {
102 	mbedtls_rsa_init(rsa, 0, 0);
103 
104 	rsa->E = *(mbedtls_mpi *)key->e;
105 	rsa->N = *(mbedtls_mpi *)key->n;
106 	rsa->D = *(mbedtls_mpi *)key->d;
107 	if (key->p && crypto_bignum_num_bytes(key->p)) {
108 		rsa->P = *(mbedtls_mpi *)key->p;
109 		rsa->Q = *(mbedtls_mpi *)key->q;
110 		rsa->QP = *(mbedtls_mpi *)key->qp;
111 		rsa->DP = *(mbedtls_mpi *)key->dp;
112 		rsa->DQ = *(mbedtls_mpi *)key->dq;
113 	}
114 	rsa->len = mbedtls_mpi_size(&rsa->N);
115 }
116 
117 static void mbd_rsa_free(mbedtls_rsa_context *rsa)
118 {
119 	/* Reset mpi to skip freeing here, those mpis will be freed with key */
120 	mbedtls_mpi_init(&rsa->E);
121 	mbedtls_mpi_init(&rsa->N);
122 	mbedtls_mpi_init(&rsa->D);
123 	if (mbedtls_mpi_size(&rsa->P)) {
124 		mbedtls_mpi_init(&rsa->P);
125 		mbedtls_mpi_init(&rsa->Q);
126 		mbedtls_mpi_init(&rsa->QP);
127 		mbedtls_mpi_init(&rsa->DP);
128 		mbedtls_mpi_init(&rsa->DQ);
129 	}
130 	mbedtls_rsa_free(rsa);
131 }
132 
133 TEE_Result crypto_acipher_alloc_rsa_keypair(struct rsa_keypair *s,
134 					    size_t key_size_bits)
135 __weak __alias("sw_crypto_acipher_alloc_rsa_keypair");
136 
137 TEE_Result sw_crypto_acipher_alloc_rsa_keypair(struct rsa_keypair *s,
138 					       size_t key_size_bits)
139 {
140 	memset(s, 0, sizeof(*s));
141 	s->e = crypto_bignum_allocate(key_size_bits);
142 	if (!s->e)
143 		goto err;
144 	s->d = crypto_bignum_allocate(key_size_bits);
145 	if (!s->d)
146 		goto err;
147 	s->n = crypto_bignum_allocate(key_size_bits);
148 	if (!s->n)
149 		goto err;
150 	s->p = crypto_bignum_allocate(key_size_bits);
151 	if (!s->p)
152 		goto err;
153 	s->q = crypto_bignum_allocate(key_size_bits);
154 	if (!s->q)
155 		goto err;
156 	s->qp = crypto_bignum_allocate(key_size_bits);
157 	if (!s->qp)
158 		goto err;
159 	s->dp = crypto_bignum_allocate(key_size_bits);
160 	if (!s->dp)
161 		goto err;
162 	s->dq = crypto_bignum_allocate(key_size_bits);
163 	if (!s->dq)
164 		goto err;
165 
166 	return TEE_SUCCESS;
167 err:
168 	crypto_acipher_free_rsa_keypair(s);
169 	return TEE_ERROR_OUT_OF_MEMORY;
170 }
171 
172 TEE_Result crypto_acipher_alloc_rsa_public_key(struct rsa_public_key *s,
173 					       size_t key_size_bits)
174 __weak __alias("sw_crypto_acipher_alloc_rsa_public_key");
175 
176 TEE_Result sw_crypto_acipher_alloc_rsa_public_key(struct rsa_public_key *s,
177 						  size_t key_size_bits)
178 {
179 	memset(s, 0, sizeof(*s));
180 	s->e = crypto_bignum_allocate(key_size_bits);
181 	if (!s->e)
182 		return TEE_ERROR_OUT_OF_MEMORY;
183 	s->n = crypto_bignum_allocate(key_size_bits);
184 	if (!s->n)
185 		goto err;
186 	return TEE_SUCCESS;
187 err:
188 	crypto_bignum_free(s->e);
189 	return TEE_ERROR_OUT_OF_MEMORY;
190 }
191 
192 void crypto_acipher_free_rsa_public_key(struct rsa_public_key *s)
193 __weak __alias("sw_crypto_acipher_free_rsa_public_key");
194 
195 void sw_crypto_acipher_free_rsa_public_key(struct rsa_public_key *s)
196 {
197 	if (!s)
198 		return;
199 	crypto_bignum_free(s->n);
200 	crypto_bignum_free(s->e);
201 }
202 
203 void crypto_acipher_free_rsa_keypair(struct rsa_keypair *s)
204 __weak __alias("sw_crypto_acipher_free_rsa_keypair");
205 
206 void sw_crypto_acipher_free_rsa_keypair(struct rsa_keypair *s)
207 {
208 	if (!s)
209 		return;
210 	crypto_bignum_free(s->e);
211 	crypto_bignum_free(s->d);
212 	crypto_bignum_free(s->n);
213 	crypto_bignum_free(s->p);
214 	crypto_bignum_free(s->q);
215 	crypto_bignum_free(s->qp);
216 	crypto_bignum_free(s->dp);
217 	crypto_bignum_free(s->dq);
218 }
219 
220 TEE_Result crypto_acipher_gen_rsa_key(struct rsa_keypair *key,
221 				      size_t key_size)
222 __weak __alias("sw_crypto_acipher_gen_rsa_key");
223 
224 TEE_Result sw_crypto_acipher_gen_rsa_key(struct rsa_keypair *key,
225 					 size_t key_size)
226 {
227 	TEE_Result res = TEE_SUCCESS;
228 	mbedtls_rsa_context rsa;
229 	mbedtls_ctr_drbg_context rngctx;
230 	int lmd_res = 0;
231 	uint32_t e = 0;
232 
233 	mbedtls_ctr_drbg_init(&rngctx);
234 	if (mbedtls_ctr_drbg_seed(&rngctx, mbd_rand, NULL, NULL, 0))
235 		return TEE_ERROR_BAD_STATE;
236 
237 	memset(&rsa, 0, sizeof(rsa));
238 	mbedtls_rsa_init(&rsa, 0, 0);
239 
240 	/* get the public exponent */
241 	mbedtls_mpi_write_binary((mbedtls_mpi *)key->e,
242 				 (unsigned char *)&e, sizeof(uint32_t));
243 
244 	e = TEE_U32_FROM_BIG_ENDIAN(e);
245 	lmd_res = mbedtls_rsa_gen_key(&rsa, mbedtls_ctr_drbg_random, &rngctx,
246 				      key_size, (int)e);
247 	mbedtls_ctr_drbg_free(&rngctx);
248 	if (lmd_res != 0) {
249 		res = get_tee_result(lmd_res);
250 	} else if ((size_t)mbedtls_mpi_bitlen(&rsa.N) != key_size) {
251 		res = TEE_ERROR_BAD_PARAMETERS;
252 	} else {
253 		/* Copy the key */
254 		crypto_bignum_copy(key->e, (void *)&rsa.E);
255 		crypto_bignum_copy(key->d, (void *)&rsa.D);
256 		crypto_bignum_copy(key->n, (void *)&rsa.N);
257 		crypto_bignum_copy(key->p, (void *)&rsa.P);
258 
259 		crypto_bignum_copy(key->q, (void *)&rsa.Q);
260 		crypto_bignum_copy(key->qp, (void *)&rsa.QP);
261 		crypto_bignum_copy(key->dp, (void *)&rsa.DP);
262 		crypto_bignum_copy(key->dq, (void *)&rsa.DQ);
263 
264 		res = TEE_SUCCESS;
265 	}
266 
267 	mbedtls_rsa_free(&rsa);
268 
269 	return res;
270 }
271 
272 TEE_Result crypto_acipher_rsanopad_encrypt(struct rsa_public_key *key,
273 					   const uint8_t *src,
274 					   size_t src_len, uint8_t *dst,
275 					   size_t *dst_len)
276 __weak __alias("sw_crypto_acipher_rsanopad_encrypt");
277 
278 TEE_Result sw_crypto_acipher_rsanopad_encrypt(struct rsa_public_key *key,
279 					      const uint8_t *src,
280 					      size_t src_len, uint8_t *dst,
281 					      size_t *dst_len)
282 {
283 	TEE_Result res = TEE_SUCCESS;
284 	mbedtls_rsa_context rsa;
285 	int lmd_res = 0;
286 	uint8_t *buf = NULL;
287 	unsigned long blen = 0;
288 	unsigned long offset = 0;
289 
290 	memset(&rsa, 0, sizeof(rsa));
291 	mbedtls_rsa_init(&rsa, 0, 0);
292 
293 	rsa.E = *(mbedtls_mpi *)key->e;
294 	rsa.N = *(mbedtls_mpi *)key->n;
295 
296 	rsa.len = crypto_bignum_num_bytes((void *)&rsa.N);
297 
298 	blen = CFG_CORE_BIGNUM_MAX_BITS / 8;
299 	buf = malloc(blen);
300 	if (!buf) {
301 		res = TEE_ERROR_OUT_OF_MEMORY;
302 		goto out;
303 	}
304 
305 	memset(buf, 0, blen);
306 	memcpy(buf + rsa.len - src_len, src, src_len);
307 
308 	lmd_res = mbedtls_rsa_public(&rsa, buf, buf);
309 	if (lmd_res != 0) {
310 		FMSG("mbedtls_rsa_public() returned 0x%x", -lmd_res);
311 		res = get_tee_result(lmd_res);
312 		goto out;
313 	}
314 
315 	/* Remove the zero-padding (leave one zero if buff is all zeroes) */
316 	offset = 0;
317 	while ((offset < rsa.len - 1) && (buf[offset] == 0))
318 		offset++;
319 
320 	if (*dst_len < rsa.len - offset) {
321 		*dst_len = rsa.len - offset;
322 		res = TEE_ERROR_SHORT_BUFFER;
323 		goto out;
324 	}
325 	*dst_len = rsa.len - offset;
326 	memcpy(dst, buf + offset, *dst_len);
327 out:
328 	free(buf);
329 	/* Reset mpi to skip freeing here, those mpis will be freed with key */
330 	mbedtls_mpi_init(&rsa.E);
331 	mbedtls_mpi_init(&rsa.N);
332 	mbedtls_rsa_free(&rsa);
333 
334 	return res;
335 }
336 
337 TEE_Result crypto_acipher_rsanopad_decrypt(struct rsa_keypair *key,
338 					   const uint8_t *src,
339 					   size_t src_len, uint8_t *dst,
340 					   size_t *dst_len)
341 __weak __alias("sw_crypto_acipher_rsanopad_decrypt");
342 
343 TEE_Result sw_crypto_acipher_rsanopad_decrypt(struct rsa_keypair *key,
344 					      const uint8_t *src,
345 					      size_t src_len, uint8_t *dst,
346 					      size_t *dst_len)
347 {
348 	TEE_Result res = TEE_SUCCESS;
349 	mbedtls_rsa_context rsa;
350 	int lmd_res = 0;
351 	uint8_t *buf = NULL;
352 	unsigned long blen = 0;
353 	unsigned long offset = 0;
354 
355 	memset(&rsa, 0, sizeof(rsa));
356 	rsa_init_from_key_pair(&rsa, key);
357 
358 	blen = CFG_CORE_BIGNUM_MAX_BITS / 8;
359 	buf = malloc(blen);
360 	if (!buf) {
361 		res = TEE_ERROR_OUT_OF_MEMORY;
362 		goto out;
363 	}
364 
365 	memset(buf, 0, blen);
366 	memcpy(buf + rsa.len - src_len, src, src_len);
367 
368 	lmd_res = mbedtls_rsa_private(&rsa, NULL, NULL, buf, buf);
369 	if (lmd_res != 0) {
370 		FMSG("mbedtls_rsa_private() returned 0x%x", -lmd_res);
371 		res = get_tee_result(lmd_res);
372 		goto out;
373 	}
374 
375 	/* Remove the zero-padding (leave one zero if buff is all zeroes) */
376 	offset = 0;
377 	while ((offset < rsa.len - 1) && (buf[offset] == 0))
378 		offset++;
379 
380 	if (*dst_len < rsa.len - offset) {
381 		*dst_len = rsa.len - offset;
382 		res = TEE_ERROR_SHORT_BUFFER;
383 		goto out;
384 	}
385 	*dst_len = rsa.len - offset;
386 	memcpy(dst, (char *)buf + offset, *dst_len);
387 out:
388 	if (buf)
389 		free(buf);
390 	mbd_rsa_free(&rsa);
391 	return res;
392 }
393 
394 TEE_Result crypto_acipher_rsaes_decrypt(uint32_t algo,
395 					struct rsa_keypair *key,
396 					const uint8_t *label __unused,
397 					size_t label_len __unused,
398 					const uint8_t *src, size_t src_len,
399 					uint8_t *dst, size_t *dst_len)
400 __weak __alias("sw_crypto_acipher_rsaes_decrypt");
401 
402 TEE_Result sw_crypto_acipher_rsaes_decrypt(uint32_t algo,
403 					   struct rsa_keypair *key,
404 					   const uint8_t *label __unused,
405 					   size_t label_len __unused,
406 					   const uint8_t *src, size_t src_len,
407 					   uint8_t *dst, size_t *dst_len)
408 {
409 	TEE_Result res = TEE_SUCCESS;
410 	int lmd_res = 0;
411 	int lmd_padding = 0;
412 	size_t blen = 0;
413 	size_t mod_size = 0;
414 	void *buf = NULL;
415 	mbedtls_rsa_context rsa;
416 	const mbedtls_pk_info_t *pk_info = NULL;
417 	uint32_t md_algo = MBEDTLS_MD_NONE;
418 
419 	memset(&rsa, 0, sizeof(rsa));
420 	rsa_init_from_key_pair(&rsa, key);
421 
422 	/*
423 	 * Use a temporary buffer since we don't know exactly how large
424 	 * the required size of the out buffer without doing a partial
425 	 * decrypt. We know the upper bound though.
426 	 */
427 	if (algo == TEE_ALG_RSAES_PKCS1_V1_5) {
428 		mod_size = crypto_bignum_num_bytes(key->n);
429 		blen = mod_size - 11;
430 		lmd_padding = MBEDTLS_RSA_PKCS_V15;
431 	} else {
432 		/* Decoded message is always shorter than encrypted message */
433 		blen = src_len;
434 		lmd_padding = MBEDTLS_RSA_PKCS_V21;
435 	}
436 
437 	buf = malloc(blen);
438 	if (!buf) {
439 		res = TEE_ERROR_OUT_OF_MEMORY;
440 		goto out;
441 	}
442 
443 	pk_info = mbedtls_pk_info_from_type(MBEDTLS_PK_RSA);
444 	if (!pk_info) {
445 		res = TEE_ERROR_NOT_SUPPORTED;
446 		goto out;
447 	}
448 
449 	/*
450 	 * TEE_ALG_RSAES_PKCS1_V1_5 is invalid in hash. But its hash algo will
451 	 * not be used in rsa, so skip it here.
452 	 */
453 	if (algo != TEE_ALG_RSAES_PKCS1_V1_5) {
454 		md_algo = tee_algo_to_mbedtls_hash_algo(algo);
455 		if (md_algo == MBEDTLS_MD_NONE) {
456 			res = TEE_ERROR_NOT_SUPPORTED;
457 			goto out;
458 		}
459 	}
460 
461 	mbedtls_rsa_set_padding(&rsa, lmd_padding, md_algo);
462 
463 	if (lmd_padding == MBEDTLS_RSA_PKCS_V15)
464 		lmd_res = pk_info->decrypt_func(&rsa, src, src_len, buf, &blen,
465 						blen, NULL, NULL);
466 	else
467 		lmd_res = pk_info->decrypt_func(&rsa, src, src_len, buf, &blen,
468 						blen, mbd_rand, NULL);
469 	if (lmd_res != 0) {
470 		FMSG("decrypt_func() returned 0x%x", -lmd_res);
471 		res = get_tee_result(lmd_res);
472 		goto out;
473 	}
474 
475 	if (*dst_len < blen) {
476 		*dst_len = blen;
477 		res = TEE_ERROR_SHORT_BUFFER;
478 		goto out;
479 	}
480 
481 	res = TEE_SUCCESS;
482 	*dst_len = blen;
483 	memcpy(dst, buf, blen);
484 out:
485 	if (buf)
486 		free(buf);
487 	mbd_rsa_free(&rsa);
488 	return res;
489 }
490 
491 TEE_Result crypto_acipher_rsaes_encrypt(uint32_t algo,
492 					struct rsa_public_key *key,
493 					const uint8_t *label __unused,
494 					size_t label_len __unused,
495 					const uint8_t *src, size_t src_len,
496 					uint8_t *dst, size_t *dst_len)
497 __weak __alias("sw_crypto_acipher_rsaes_encrypt");
498 
499 TEE_Result sw_crypto_acipher_rsaes_encrypt(uint32_t algo,
500 					   struct rsa_public_key *key,
501 					   const uint8_t *label __unused,
502 					   size_t label_len __unused,
503 					   const uint8_t *src, size_t src_len,
504 					   uint8_t *dst, size_t *dst_len)
505 {
506 	TEE_Result res = TEE_SUCCESS;
507 	int lmd_res = 0;
508 	int lmd_padding = 0;
509 	size_t mod_size = 0;
510 	mbedtls_rsa_context rsa;
511 	const mbedtls_pk_info_t *pk_info = NULL;
512 	uint32_t md_algo = MBEDTLS_MD_NONE;
513 
514 	memset(&rsa, 0, sizeof(rsa));
515 	mbedtls_rsa_init(&rsa, 0, 0);
516 
517 	rsa.E = *(mbedtls_mpi *)key->e;
518 	rsa.N = *(mbedtls_mpi *)key->n;
519 
520 	mod_size = crypto_bignum_num_bytes(key->n);
521 	if (*dst_len < mod_size) {
522 		*dst_len = mod_size;
523 		res = TEE_ERROR_SHORT_BUFFER;
524 		goto out;
525 	}
526 	*dst_len = mod_size;
527 	rsa.len = mod_size;
528 
529 	if (algo == TEE_ALG_RSAES_PKCS1_V1_5)
530 		lmd_padding = MBEDTLS_RSA_PKCS_V15;
531 	else
532 		lmd_padding = MBEDTLS_RSA_PKCS_V21;
533 
534 	pk_info = mbedtls_pk_info_from_type(MBEDTLS_PK_RSA);
535 	if (!pk_info) {
536 		res = TEE_ERROR_NOT_SUPPORTED;
537 		goto out;
538 	}
539 
540 	/*
541 	 * TEE_ALG_RSAES_PKCS1_V1_5 is invalid in hash. But its hash algo will
542 	 * not be used in rsa, so skip it here.
543 	 */
544 	if (algo != TEE_ALG_RSAES_PKCS1_V1_5) {
545 		md_algo = tee_algo_to_mbedtls_hash_algo(algo);
546 		if (md_algo == MBEDTLS_MD_NONE) {
547 			res = TEE_ERROR_NOT_SUPPORTED;
548 			goto out;
549 		}
550 	}
551 
552 	mbedtls_rsa_set_padding(&rsa, lmd_padding, md_algo);
553 
554 	lmd_res = pk_info->encrypt_func(&rsa, src, src_len, dst, dst_len,
555 					*dst_len, mbd_rand, NULL);
556 	if (lmd_res != 0) {
557 		FMSG("encrypt_func() returned 0x%x", -lmd_res);
558 		res = get_tee_result(lmd_res);
559 		goto out;
560 	}
561 	res = TEE_SUCCESS;
562 out:
563 	/* Reset mpi to skip freeing here, those mpis will be freed with key */
564 	mbedtls_mpi_init(&rsa.E);
565 	mbedtls_mpi_init(&rsa.N);
566 	mbedtls_rsa_free(&rsa);
567 	return res;
568 }
569 
570 TEE_Result crypto_acipher_rsassa_sign(uint32_t algo, struct rsa_keypair *key,
571 				      int salt_len __unused,
572 				      const uint8_t *msg, size_t msg_len,
573 				      uint8_t *sig, size_t *sig_len)
574 __weak __alias("sw_crypto_acipher_rsassa_sign");
575 
576 TEE_Result sw_crypto_acipher_rsassa_sign(uint32_t algo, struct rsa_keypair *key,
577 					 int salt_len __unused,
578 					 const uint8_t *msg, size_t msg_len,
579 					 uint8_t *sig, size_t *sig_len)
580 {
581 	TEE_Result res = TEE_SUCCESS;
582 	int lmd_res = 0;
583 	int lmd_padding = 0;
584 	size_t mod_size = 0;
585 	size_t hash_size = 0;
586 	mbedtls_rsa_context rsa;
587 	const mbedtls_pk_info_t *pk_info = NULL;
588 	uint32_t md_algo = 0;
589 
590 	memset(&rsa, 0, sizeof(rsa));
591 	rsa_init_from_key_pair(&rsa, key);
592 
593 	switch (algo) {
594 	case TEE_ALG_RSASSA_PKCS1_V1_5_MD5:
595 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA1:
596 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA224:
597 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA256:
598 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA384:
599 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA512:
600 		lmd_padding = MBEDTLS_RSA_PKCS_V15;
601 		break;
602 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_MD5:
603 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA1:
604 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA224:
605 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA256:
606 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA384:
607 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA512:
608 		lmd_padding = MBEDTLS_RSA_PKCS_V21;
609 		break;
610 	default:
611 		res = TEE_ERROR_BAD_PARAMETERS;
612 		goto err;
613 	}
614 
615 	res = tee_alg_get_digest_size(TEE_DIGEST_HASH_TO_ALGO(algo),
616 				      &hash_size);
617 	if (res != TEE_SUCCESS)
618 		goto err;
619 
620 	if (msg_len != hash_size) {
621 		res = TEE_ERROR_BAD_PARAMETERS;
622 		goto err;
623 	}
624 
625 	mod_size = crypto_bignum_num_bytes(key->n);
626 	if (*sig_len < mod_size) {
627 		*sig_len = mod_size;
628 		res = TEE_ERROR_SHORT_BUFFER;
629 		goto err;
630 	}
631 	rsa.len = mod_size;
632 
633 	md_algo = tee_algo_to_mbedtls_hash_algo(algo);
634 	if (md_algo == MBEDTLS_MD_NONE) {
635 		res = TEE_ERROR_NOT_SUPPORTED;
636 		goto err;
637 	}
638 
639 	pk_info = mbedtls_pk_info_from_type(MBEDTLS_PK_RSA);
640 	if (!pk_info) {
641 		res = TEE_ERROR_NOT_SUPPORTED;
642 		goto err;
643 	}
644 
645 	mbedtls_rsa_set_padding(&rsa, lmd_padding, md_algo);
646 
647 	if (lmd_padding == MBEDTLS_RSA_PKCS_V15)
648 		lmd_res = pk_info->sign_func(&rsa, md_algo, msg, msg_len, sig,
649 					     sig_len, NULL, NULL);
650 	else
651 		lmd_res = pk_info->sign_func(&rsa, md_algo, msg, msg_len, sig,
652 					     sig_len, mbd_rand, NULL);
653 	if (lmd_res != 0) {
654 		FMSG("sign_func failed, returned 0x%x", -lmd_res);
655 		res = get_tee_result(lmd_res);
656 		goto err;
657 	}
658 	res = TEE_SUCCESS;
659 err:
660 	mbd_rsa_free(&rsa);
661 	return res;
662 }
663 
664 TEE_Result crypto_acipher_rsassa_verify(uint32_t algo,
665 					struct rsa_public_key *key,
666 					int salt_len __unused,
667 					const uint8_t *msg,
668 					size_t msg_len, const uint8_t *sig,
669 					size_t sig_len)
670 __weak __alias("sw_crypto_acipher_rsassa_verify");
671 
672 TEE_Result sw_crypto_acipher_rsassa_verify(uint32_t algo,
673 					   struct rsa_public_key *key,
674 					   int salt_len __unused,
675 					   const uint8_t *msg,
676 					   size_t msg_len, const uint8_t *sig,
677 					   size_t sig_len)
678 {
679 	TEE_Result res = TEE_SUCCESS;
680 	int lmd_res = 0;
681 	int lmd_padding = 0;
682 	size_t hash_size = 0;
683 	size_t bigint_size = 0;
684 	mbedtls_rsa_context rsa;
685 	const mbedtls_pk_info_t *pk_info = NULL;
686 	uint32_t md_algo = 0;
687 	struct ftmn ftmn = { };
688 	unsigned long arg_hash = 0;
689 
690 	/*
691 	 * The caller expects to call crypto_acipher_rsassa_verify(),
692 	 * update the hash as needed.
693 	 */
694 	FTMN_CALLEE_SWAP_HASH(FTMN_FUNC_HASH("crypto_acipher_rsassa_verify"));
695 
696 	memset(&rsa, 0, sizeof(rsa));
697 	mbedtls_rsa_init(&rsa, 0, 0);
698 
699 	rsa.E = *(mbedtls_mpi *)key->e;
700 	rsa.N = *(mbedtls_mpi *)key->n;
701 
702 	res = tee_alg_get_digest_size(TEE_DIGEST_HASH_TO_ALGO(algo),
703 				      &hash_size);
704 	if (res != TEE_SUCCESS)
705 		goto err;
706 
707 	if (msg_len != hash_size) {
708 		res = TEE_ERROR_BAD_PARAMETERS;
709 		goto err;
710 	}
711 
712 	bigint_size = crypto_bignum_num_bytes(key->n);
713 	if (sig_len < bigint_size) {
714 		res = TEE_ERROR_SIGNATURE_INVALID;
715 		goto err;
716 	}
717 
718 	rsa.len = bigint_size;
719 
720 	switch (algo) {
721 	case TEE_ALG_RSASSA_PKCS1_V1_5_MD5:
722 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA1:
723 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA224:
724 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA256:
725 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA384:
726 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA512:
727 		arg_hash = FTMN_FUNC_HASH("mbedtls_rsa_rsassa_pkcs1_v15_verify");
728 		lmd_padding = MBEDTLS_RSA_PKCS_V15;
729 		break;
730 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_MD5:
731 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA1:
732 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA224:
733 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA256:
734 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA384:
735 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA512:
736 		arg_hash = FTMN_FUNC_HASH("mbedtls_rsa_rsassa_pss_verify_ext");
737 		lmd_padding = MBEDTLS_RSA_PKCS_V21;
738 		break;
739 	default:
740 		res = TEE_ERROR_BAD_PARAMETERS;
741 		goto err;
742 	}
743 
744 	md_algo = tee_algo_to_mbedtls_hash_algo(algo);
745 	if (md_algo == MBEDTLS_MD_NONE) {
746 		res = TEE_ERROR_NOT_SUPPORTED;
747 		goto err;
748 	}
749 
750 	pk_info = mbedtls_pk_info_from_type(MBEDTLS_PK_RSA);
751 	if (!pk_info) {
752 		res = TEE_ERROR_NOT_SUPPORTED;
753 		goto err;
754 	}
755 
756 	mbedtls_rsa_set_padding(&rsa, lmd_padding, md_algo);
757 
758 	FTMN_PUSH_LINKED_CALL(&ftmn, arg_hash);
759 	lmd_res = pk_info->verify_func(&rsa, md_algo, msg, msg_len,
760 				       sig, sig_len);
761 	if (!lmd_res)
762 		FTMN_SET_CHECK_RES_FROM_CALL(&ftmn, FTMN_INCR0, lmd_res);
763 	FTMN_POP_LINKED_CALL(&ftmn);
764 	if (lmd_res != 0) {
765 		FMSG("verify_func failed, returned 0x%x", -lmd_res);
766 		res = TEE_ERROR_SIGNATURE_INVALID;
767 		goto err;
768 	}
769 	res = TEE_SUCCESS;
770 	goto out;
771 
772 err:
773 	FTMN_SET_CHECK_RES_NOT_ZERO(&ftmn, FTMN_INCR0, res);
774 out:
775 	FTMN_CALLEE_DONE_CHECK(&ftmn, FTMN_INCR0, FTMN_STEP_COUNT(1), res);
776 	/* Reset mpi to skip freeing here, those mpis will be freed with key */
777 	mbedtls_mpi_init(&rsa.E);
778 	mbedtls_mpi_init(&rsa.N);
779 	mbedtls_rsa_free(&rsa);
780 	return res;
781 }
782