xref: /optee_os/lib/libmbedtls/core/rsa.c (revision 23ef3871cb5814f45010171373add4d339285616)
1 // SPDX-License-Identifier: BSD-2-Clause
2 /*
3  * Copyright (C) 2018, ARM Limited
4  * Copyright (C) 2019, Linaro Limited
5  */
6 
7 #include <assert.h>
8 #include <crypto/crypto.h>
9 #include <crypto/crypto_impl.h>
10 #include <mbedtls/ctr_drbg.h>
11 #include <mbedtls/entropy.h>
12 #include <mbedtls/pk.h>
13 #include <mbedtls/pk_internal.h>
14 #include <stdlib.h>
15 #include <string.h>
16 #include <tee/tee_cryp_utl.h>
17 #include <utee_defines.h>
18 
19 #include "mbed_helpers.h"
20 
21 static TEE_Result get_tee_result(int lmd_res)
22 {
23 	switch (lmd_res) {
24 	case 0:
25 		return TEE_SUCCESS;
26 	case MBEDTLS_ERR_RSA_PRIVATE_FAILED +
27 		MBEDTLS_ERR_MPI_BAD_INPUT_DATA:
28 	case MBEDTLS_ERR_RSA_BAD_INPUT_DATA:
29 	case MBEDTLS_ERR_RSA_INVALID_PADDING:
30 	case MBEDTLS_ERR_PK_TYPE_MISMATCH:
31 		return TEE_ERROR_BAD_PARAMETERS;
32 	case MBEDTLS_ERR_RSA_OUTPUT_TOO_LARGE:
33 		return TEE_ERROR_SHORT_BUFFER;
34 	default:
35 		return TEE_ERROR_BAD_STATE;
36 	}
37 }
38 
39 static uint32_t tee_algo_to_mbedtls_hash_algo(uint32_t algo)
40 {
41 	switch (algo) {
42 #if defined(CFG_CRYPTO_SHA1)
43 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA1:
44 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA1:
45 	case TEE_ALG_RSAES_PKCS1_OAEP_MGF1_SHA1:
46 	case TEE_ALG_SHA1:
47 	case TEE_ALG_DSA_SHA1:
48 	case TEE_ALG_HMAC_SHA1:
49 		return MBEDTLS_MD_SHA1;
50 #endif
51 #if defined(CFG_CRYPTO_MD5)
52 	case TEE_ALG_RSASSA_PKCS1_V1_5_MD5:
53 	case TEE_ALG_MD5:
54 	case TEE_ALG_HMAC_MD5:
55 		return MBEDTLS_MD_MD5;
56 #endif
57 #if defined(CFG_CRYPTO_SHA224)
58 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA224:
59 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA224:
60 	case TEE_ALG_RSAES_PKCS1_OAEP_MGF1_SHA224:
61 	case TEE_ALG_SHA224:
62 	case TEE_ALG_DSA_SHA224:
63 	case TEE_ALG_HMAC_SHA224:
64 		return MBEDTLS_MD_SHA224;
65 #endif
66 #if defined(CFG_CRYPTO_SHA256)
67 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA256:
68 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA256:
69 	case TEE_ALG_RSAES_PKCS1_OAEP_MGF1_SHA256:
70 	case TEE_ALG_SHA256:
71 	case TEE_ALG_DSA_SHA256:
72 	case TEE_ALG_HMAC_SHA256:
73 		return MBEDTLS_MD_SHA256;
74 #endif
75 #if defined(CFG_CRYPTO_SHA384)
76 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA384:
77 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA384:
78 	case TEE_ALG_RSAES_PKCS1_OAEP_MGF1_SHA384:
79 	case TEE_ALG_SHA384:
80 	case TEE_ALG_HMAC_SHA384:
81 		return MBEDTLS_MD_SHA384;
82 #endif
83 #if defined(CFG_CRYPTO_SHA512)
84 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA512:
85 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA512:
86 	case TEE_ALG_RSAES_PKCS1_OAEP_MGF1_SHA512:
87 	case TEE_ALG_SHA512:
88 	case TEE_ALG_HMAC_SHA512:
89 		return MBEDTLS_MD_SHA512;
90 #endif
91 	default:
92 		return MBEDTLS_MD_NONE;
93 	}
94 }
95 
96 static void rsa_init_from_key_pair(mbedtls_rsa_context *rsa,
97 				struct rsa_keypair *key)
98 {
99 	mbedtls_rsa_init(rsa, 0, 0);
100 
101 	rsa->E = *(mbedtls_mpi *)key->e;
102 	rsa->N = *(mbedtls_mpi *)key->n;
103 	rsa->D = *(mbedtls_mpi *)key->d;
104 	if (key->p && crypto_bignum_num_bytes(key->p)) {
105 		rsa->P = *(mbedtls_mpi *)key->p;
106 		rsa->Q = *(mbedtls_mpi *)key->q;
107 		rsa->QP = *(mbedtls_mpi *)key->qp;
108 		rsa->DP = *(mbedtls_mpi *)key->dp;
109 		rsa->DQ = *(mbedtls_mpi *)key->dq;
110 	}
111 	rsa->len = mbedtls_mpi_size(&rsa->N);
112 }
113 
114 static void mbd_rsa_free(mbedtls_rsa_context *rsa)
115 {
116 	/* Reset mpi to skip freeing here, those mpis will be freed with key */
117 	mbedtls_mpi_init(&rsa->E);
118 	mbedtls_mpi_init(&rsa->N);
119 	mbedtls_mpi_init(&rsa->D);
120 	if (mbedtls_mpi_size(&rsa->P)) {
121 		mbedtls_mpi_init(&rsa->P);
122 		mbedtls_mpi_init(&rsa->Q);
123 		mbedtls_mpi_init(&rsa->QP);
124 		mbedtls_mpi_init(&rsa->DP);
125 		mbedtls_mpi_init(&rsa->DQ);
126 	}
127 	mbedtls_rsa_free(rsa);
128 }
129 
130 TEE_Result crypto_acipher_alloc_rsa_keypair(struct rsa_keypair *s,
131 					    size_t key_size_bits)
132 __weak __alias("sw_crypto_acipher_alloc_rsa_keypair");
133 
134 TEE_Result sw_crypto_acipher_alloc_rsa_keypair(struct rsa_keypair *s,
135 					       size_t key_size_bits)
136 {
137 	memset(s, 0, sizeof(*s));
138 	s->e = crypto_bignum_allocate(key_size_bits);
139 	if (!s->e)
140 		goto err;
141 	s->d = crypto_bignum_allocate(key_size_bits);
142 	if (!s->d)
143 		goto err;
144 	s->n = crypto_bignum_allocate(key_size_bits);
145 	if (!s->n)
146 		goto err;
147 	s->p = crypto_bignum_allocate(key_size_bits);
148 	if (!s->p)
149 		goto err;
150 	s->q = crypto_bignum_allocate(key_size_bits);
151 	if (!s->q)
152 		goto err;
153 	s->qp = crypto_bignum_allocate(key_size_bits);
154 	if (!s->qp)
155 		goto err;
156 	s->dp = crypto_bignum_allocate(key_size_bits);
157 	if (!s->dp)
158 		goto err;
159 	s->dq = crypto_bignum_allocate(key_size_bits);
160 	if (!s->dq)
161 		goto err;
162 
163 	return TEE_SUCCESS;
164 err:
165 	crypto_acipher_free_rsa_keypair(s);
166 	return TEE_ERROR_OUT_OF_MEMORY;
167 }
168 
169 TEE_Result crypto_acipher_alloc_rsa_public_key(struct rsa_public_key *s,
170 					       size_t key_size_bits)
171 __weak __alias("sw_crypto_acipher_alloc_rsa_public_key");
172 
173 TEE_Result sw_crypto_acipher_alloc_rsa_public_key(struct rsa_public_key *s,
174 						  size_t key_size_bits)
175 {
176 	memset(s, 0, sizeof(*s));
177 	s->e = crypto_bignum_allocate(key_size_bits);
178 	if (!s->e)
179 		return TEE_ERROR_OUT_OF_MEMORY;
180 	s->n = crypto_bignum_allocate(key_size_bits);
181 	if (!s->n)
182 		goto err;
183 	return TEE_SUCCESS;
184 err:
185 	crypto_bignum_free(s->e);
186 	return TEE_ERROR_OUT_OF_MEMORY;
187 }
188 
189 void crypto_acipher_free_rsa_public_key(struct rsa_public_key *s)
190 __weak __alias("sw_crypto_acipher_free_rsa_public_key");
191 
192 void sw_crypto_acipher_free_rsa_public_key(struct rsa_public_key *s)
193 {
194 	if (!s)
195 		return;
196 	crypto_bignum_free(s->n);
197 	crypto_bignum_free(s->e);
198 }
199 
200 void crypto_acipher_free_rsa_keypair(struct rsa_keypair *s)
201 __weak __alias("sw_crypto_acipher_free_rsa_keypair");
202 
203 void sw_crypto_acipher_free_rsa_keypair(struct rsa_keypair *s)
204 {
205 	if (!s)
206 		return;
207 	crypto_bignum_free(s->e);
208 	crypto_bignum_free(s->d);
209 	crypto_bignum_free(s->n);
210 	crypto_bignum_free(s->p);
211 	crypto_bignum_free(s->q);
212 	crypto_bignum_free(s->qp);
213 	crypto_bignum_free(s->dp);
214 	crypto_bignum_free(s->dq);
215 }
216 
217 TEE_Result crypto_acipher_gen_rsa_key(struct rsa_keypair *key,
218 				      size_t key_size)
219 __weak __alias("sw_crypto_acipher_gen_rsa_key");
220 
221 TEE_Result sw_crypto_acipher_gen_rsa_key(struct rsa_keypair *key,
222 					 size_t key_size)
223 {
224 	TEE_Result res = TEE_SUCCESS;
225 	mbedtls_rsa_context rsa;
226 	mbedtls_ctr_drbg_context rngctx;
227 	int lmd_res = 0;
228 	uint32_t e = 0;
229 
230 	mbedtls_ctr_drbg_init(&rngctx);
231 	if (mbedtls_ctr_drbg_seed(&rngctx, mbd_rand, NULL, NULL, 0))
232 		return TEE_ERROR_BAD_STATE;
233 
234 	memset(&rsa, 0, sizeof(rsa));
235 	mbedtls_rsa_init(&rsa, 0, 0);
236 
237 	/* get the public exponent */
238 	mbedtls_mpi_write_binary((mbedtls_mpi *)key->e,
239 				 (unsigned char *)&e, sizeof(uint32_t));
240 
241 	e = TEE_U32_FROM_BIG_ENDIAN(e);
242 	lmd_res = mbedtls_rsa_gen_key(&rsa, mbedtls_ctr_drbg_random, &rngctx,
243 				      key_size, (int)e);
244 	mbedtls_ctr_drbg_free(&rngctx);
245 	if (lmd_res != 0) {
246 		res = get_tee_result(lmd_res);
247 	} else if ((size_t)mbedtls_mpi_bitlen(&rsa.N) != key_size) {
248 		res = TEE_ERROR_BAD_PARAMETERS;
249 	} else {
250 		/* Copy the key */
251 		crypto_bignum_copy(key->e, (void *)&rsa.E);
252 		crypto_bignum_copy(key->d, (void *)&rsa.D);
253 		crypto_bignum_copy(key->n, (void *)&rsa.N);
254 		crypto_bignum_copy(key->p, (void *)&rsa.P);
255 
256 		crypto_bignum_copy(key->q, (void *)&rsa.Q);
257 		crypto_bignum_copy(key->qp, (void *)&rsa.QP);
258 		crypto_bignum_copy(key->dp, (void *)&rsa.DP);
259 		crypto_bignum_copy(key->dq, (void *)&rsa.DQ);
260 
261 		res = TEE_SUCCESS;
262 	}
263 
264 	mbedtls_rsa_free(&rsa);
265 
266 	return res;
267 }
268 
269 TEE_Result crypto_acipher_rsanopad_encrypt(struct rsa_public_key *key,
270 					   const uint8_t *src,
271 					   size_t src_len, uint8_t *dst,
272 					   size_t *dst_len)
273 __weak __alias("sw_crypto_acipher_rsanopad_encrypt");
274 
275 TEE_Result sw_crypto_acipher_rsanopad_encrypt(struct rsa_public_key *key,
276 					      const uint8_t *src,
277 					      size_t src_len, uint8_t *dst,
278 					      size_t *dst_len)
279 {
280 	TEE_Result res = TEE_SUCCESS;
281 	mbedtls_rsa_context rsa;
282 	int lmd_res = 0;
283 	uint8_t *buf = NULL;
284 	unsigned long blen = 0;
285 	unsigned long offset = 0;
286 
287 	memset(&rsa, 0, sizeof(rsa));
288 	mbedtls_rsa_init(&rsa, 0, 0);
289 
290 	rsa.E = *(mbedtls_mpi *)key->e;
291 	rsa.N = *(mbedtls_mpi *)key->n;
292 
293 	rsa.len = crypto_bignum_num_bytes((void *)&rsa.N);
294 
295 	blen = CFG_CORE_BIGNUM_MAX_BITS / 8;
296 	buf = malloc(blen);
297 	if (!buf) {
298 		res = TEE_ERROR_OUT_OF_MEMORY;
299 		goto out;
300 	}
301 
302 	memset(buf, 0, blen);
303 	memcpy(buf + rsa.len - src_len, src, src_len);
304 
305 	lmd_res = mbedtls_rsa_public(&rsa, buf, buf);
306 	if (lmd_res != 0) {
307 		FMSG("mbedtls_rsa_public() returned 0x%x", -lmd_res);
308 		res = get_tee_result(lmd_res);
309 		goto out;
310 	}
311 
312 	/* Remove the zero-padding (leave one zero if buff is all zeroes) */
313 	offset = 0;
314 	while ((offset < rsa.len - 1) && (buf[offset] == 0))
315 		offset++;
316 
317 	if (*dst_len < rsa.len - offset) {
318 		*dst_len = rsa.len - offset;
319 		res = TEE_ERROR_SHORT_BUFFER;
320 		goto out;
321 	}
322 	*dst_len = rsa.len - offset;
323 	memcpy(dst, buf + offset, *dst_len);
324 out:
325 	free(buf);
326 	/* Reset mpi to skip freeing here, those mpis will be freed with key */
327 	mbedtls_mpi_init(&rsa.E);
328 	mbedtls_mpi_init(&rsa.N);
329 	mbedtls_rsa_free(&rsa);
330 
331 	return res;
332 }
333 
334 TEE_Result crypto_acipher_rsanopad_decrypt(struct rsa_keypair *key,
335 					   const uint8_t *src,
336 					   size_t src_len, uint8_t *dst,
337 					   size_t *dst_len)
338 __weak __alias("sw_crypto_acipher_rsanopad_decrypt");
339 
340 TEE_Result sw_crypto_acipher_rsanopad_decrypt(struct rsa_keypair *key,
341 					      const uint8_t *src,
342 					      size_t src_len, uint8_t *dst,
343 					      size_t *dst_len)
344 {
345 	TEE_Result res = TEE_SUCCESS;
346 	mbedtls_rsa_context rsa;
347 	int lmd_res = 0;
348 	uint8_t *buf = NULL;
349 	unsigned long blen = 0;
350 	unsigned long offset = 0;
351 
352 	memset(&rsa, 0, sizeof(rsa));
353 	rsa_init_from_key_pair(&rsa, key);
354 
355 	blen = CFG_CORE_BIGNUM_MAX_BITS / 8;
356 	buf = malloc(blen);
357 	if (!buf) {
358 		res = TEE_ERROR_OUT_OF_MEMORY;
359 		goto out;
360 	}
361 
362 	memset(buf, 0, blen);
363 	memcpy(buf + rsa.len - src_len, src, src_len);
364 
365 	lmd_res = mbedtls_rsa_private(&rsa, NULL, NULL, buf, buf);
366 	if (lmd_res != 0) {
367 		FMSG("mbedtls_rsa_private() returned 0x%x", -lmd_res);
368 		res = get_tee_result(lmd_res);
369 		goto out;
370 	}
371 
372 	/* Remove the zero-padding (leave one zero if buff is all zeroes) */
373 	offset = 0;
374 	while ((offset < rsa.len - 1) && (buf[offset] == 0))
375 		offset++;
376 
377 	if (*dst_len < rsa.len - offset) {
378 		*dst_len = rsa.len - offset;
379 		res = TEE_ERROR_SHORT_BUFFER;
380 		goto out;
381 	}
382 	*dst_len = rsa.len - offset;
383 	memcpy(dst, (char *)buf + offset, *dst_len);
384 out:
385 	if (buf)
386 		free(buf);
387 	mbd_rsa_free(&rsa);
388 	return res;
389 }
390 
391 TEE_Result crypto_acipher_rsaes_decrypt(uint32_t algo,
392 					struct rsa_keypair *key,
393 					const uint8_t *label __unused,
394 					size_t label_len __unused,
395 					const uint8_t *src, size_t src_len,
396 					uint8_t *dst, size_t *dst_len)
397 __weak __alias("sw_crypto_acipher_rsaes_decrypt");
398 
399 TEE_Result sw_crypto_acipher_rsaes_decrypt(uint32_t algo,
400 					   struct rsa_keypair *key,
401 					   const uint8_t *label __unused,
402 					   size_t label_len __unused,
403 					   const uint8_t *src, size_t src_len,
404 					   uint8_t *dst, size_t *dst_len)
405 {
406 	TEE_Result res = TEE_SUCCESS;
407 	int lmd_res = 0;
408 	int lmd_padding = 0;
409 	size_t blen = 0;
410 	size_t mod_size = 0;
411 	void *buf = NULL;
412 	mbedtls_rsa_context rsa;
413 	const mbedtls_pk_info_t *pk_info = NULL;
414 	uint32_t md_algo = MBEDTLS_MD_NONE;
415 
416 	memset(&rsa, 0, sizeof(rsa));
417 	rsa_init_from_key_pair(&rsa, key);
418 
419 	/*
420 	 * Use a temporary buffer since we don't know exactly how large
421 	 * the required size of the out buffer without doing a partial
422 	 * decrypt. We know the upper bound though.
423 	 */
424 	if (algo == TEE_ALG_RSAES_PKCS1_V1_5) {
425 		mod_size = crypto_bignum_num_bytes(key->n);
426 		blen = mod_size - 11;
427 		lmd_padding = MBEDTLS_RSA_PKCS_V15;
428 	} else {
429 		/* Decoded message is always shorter than encrypted message */
430 		blen = src_len;
431 		lmd_padding = MBEDTLS_RSA_PKCS_V21;
432 	}
433 
434 	buf = malloc(blen);
435 	if (!buf) {
436 		res = TEE_ERROR_OUT_OF_MEMORY;
437 		goto out;
438 	}
439 
440 	pk_info = mbedtls_pk_info_from_type(MBEDTLS_PK_RSA);
441 	if (!pk_info) {
442 		res = TEE_ERROR_NOT_SUPPORTED;
443 		goto out;
444 	}
445 
446 	/*
447 	 * TEE_ALG_RSAES_PKCS1_V1_5 is invalid in hash. But its hash algo will
448 	 * not be used in rsa, so skip it here.
449 	 */
450 	if (algo != TEE_ALG_RSAES_PKCS1_V1_5) {
451 		md_algo = tee_algo_to_mbedtls_hash_algo(algo);
452 		if (md_algo == MBEDTLS_MD_NONE) {
453 			res = TEE_ERROR_NOT_SUPPORTED;
454 			goto out;
455 		}
456 	}
457 
458 	mbedtls_rsa_set_padding(&rsa, lmd_padding, md_algo);
459 
460 	if (lmd_padding == MBEDTLS_RSA_PKCS_V15)
461 		lmd_res = pk_info->decrypt_func(&rsa, src, src_len, buf, &blen,
462 						blen, NULL, NULL);
463 	else
464 		lmd_res = pk_info->decrypt_func(&rsa, src, src_len, buf, &blen,
465 						blen, mbd_rand, NULL);
466 	if (lmd_res != 0) {
467 		FMSG("decrypt_func() returned 0x%x", -lmd_res);
468 		res = get_tee_result(lmd_res);
469 		goto out;
470 	}
471 
472 	if (*dst_len < blen) {
473 		*dst_len = blen;
474 		res = TEE_ERROR_SHORT_BUFFER;
475 		goto out;
476 	}
477 
478 	res = TEE_SUCCESS;
479 	*dst_len = blen;
480 	memcpy(dst, buf, blen);
481 out:
482 	if (buf)
483 		free(buf);
484 	mbd_rsa_free(&rsa);
485 	return res;
486 }
487 
488 TEE_Result crypto_acipher_rsaes_encrypt(uint32_t algo,
489 					struct rsa_public_key *key,
490 					const uint8_t *label __unused,
491 					size_t label_len __unused,
492 					const uint8_t *src, size_t src_len,
493 					uint8_t *dst, size_t *dst_len)
494 __weak __alias("sw_crypto_acipher_rsaes_encrypt");
495 
496 TEE_Result sw_crypto_acipher_rsaes_encrypt(uint32_t algo,
497 					   struct rsa_public_key *key,
498 					   const uint8_t *label __unused,
499 					   size_t label_len __unused,
500 					   const uint8_t *src, size_t src_len,
501 					   uint8_t *dst, size_t *dst_len)
502 {
503 	TEE_Result res = TEE_SUCCESS;
504 	int lmd_res = 0;
505 	int lmd_padding = 0;
506 	size_t mod_size = 0;
507 	mbedtls_rsa_context rsa;
508 	const mbedtls_pk_info_t *pk_info = NULL;
509 	uint32_t md_algo = MBEDTLS_MD_NONE;
510 
511 	memset(&rsa, 0, sizeof(rsa));
512 	mbedtls_rsa_init(&rsa, 0, 0);
513 
514 	rsa.E = *(mbedtls_mpi *)key->e;
515 	rsa.N = *(mbedtls_mpi *)key->n;
516 
517 	mod_size = crypto_bignum_num_bytes(key->n);
518 	if (*dst_len < mod_size) {
519 		*dst_len = mod_size;
520 		res = TEE_ERROR_SHORT_BUFFER;
521 		goto out;
522 	}
523 	*dst_len = mod_size;
524 	rsa.len = mod_size;
525 
526 	if (algo == TEE_ALG_RSAES_PKCS1_V1_5)
527 		lmd_padding = MBEDTLS_RSA_PKCS_V15;
528 	else
529 		lmd_padding = MBEDTLS_RSA_PKCS_V21;
530 
531 	pk_info = mbedtls_pk_info_from_type(MBEDTLS_PK_RSA);
532 	if (!pk_info) {
533 		res = TEE_ERROR_NOT_SUPPORTED;
534 		goto out;
535 	}
536 
537 	/*
538 	 * TEE_ALG_RSAES_PKCS1_V1_5 is invalid in hash. But its hash algo will
539 	 * not be used in rsa, so skip it here.
540 	 */
541 	if (algo != TEE_ALG_RSAES_PKCS1_V1_5) {
542 		md_algo = tee_algo_to_mbedtls_hash_algo(algo);
543 		if (md_algo == MBEDTLS_MD_NONE) {
544 			res = TEE_ERROR_NOT_SUPPORTED;
545 			goto out;
546 		}
547 	}
548 
549 	mbedtls_rsa_set_padding(&rsa, lmd_padding, md_algo);
550 
551 	lmd_res = pk_info->encrypt_func(&rsa, src, src_len, dst, dst_len,
552 					*dst_len, mbd_rand, NULL);
553 	if (lmd_res != 0) {
554 		FMSG("encrypt_func() returned 0x%x", -lmd_res);
555 		res = get_tee_result(lmd_res);
556 		goto out;
557 	}
558 	res = TEE_SUCCESS;
559 out:
560 	/* Reset mpi to skip freeing here, those mpis will be freed with key */
561 	mbedtls_mpi_init(&rsa.E);
562 	mbedtls_mpi_init(&rsa.N);
563 	mbedtls_rsa_free(&rsa);
564 	return res;
565 }
566 
567 TEE_Result crypto_acipher_rsassa_sign(uint32_t algo, struct rsa_keypair *key,
568 				      int salt_len __unused,
569 				      const uint8_t *msg, size_t msg_len,
570 				      uint8_t *sig, size_t *sig_len)
571 __weak __alias("sw_crypto_acipher_rsassa_sign");
572 
573 TEE_Result sw_crypto_acipher_rsassa_sign(uint32_t algo, struct rsa_keypair *key,
574 					 int salt_len __unused,
575 					 const uint8_t *msg, size_t msg_len,
576 					 uint8_t *sig, size_t *sig_len)
577 {
578 	TEE_Result res = TEE_SUCCESS;
579 	int lmd_res = 0;
580 	int lmd_padding = 0;
581 	size_t mod_size = 0;
582 	size_t hash_size = 0;
583 	mbedtls_rsa_context rsa;
584 	const mbedtls_pk_info_t *pk_info = NULL;
585 	uint32_t md_algo = 0;
586 
587 	memset(&rsa, 0, sizeof(rsa));
588 	rsa_init_from_key_pair(&rsa, key);
589 
590 	switch (algo) {
591 	case TEE_ALG_RSASSA_PKCS1_V1_5_MD5:
592 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA1:
593 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA224:
594 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA256:
595 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA384:
596 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA512:
597 		lmd_padding = MBEDTLS_RSA_PKCS_V15;
598 		break;
599 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA1:
600 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA224:
601 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA256:
602 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA384:
603 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA512:
604 		lmd_padding = MBEDTLS_RSA_PKCS_V21;
605 		break;
606 	default:
607 		res = TEE_ERROR_BAD_PARAMETERS;
608 		goto err;
609 	}
610 
611 	res = tee_alg_get_digest_size(TEE_DIGEST_HASH_TO_ALGO(algo),
612 				      &hash_size);
613 	if (res != TEE_SUCCESS)
614 		goto err;
615 
616 	if (msg_len != hash_size) {
617 		res = TEE_ERROR_BAD_PARAMETERS;
618 		goto err;
619 	}
620 
621 	mod_size = crypto_bignum_num_bytes(key->n);
622 	if (*sig_len < mod_size) {
623 		*sig_len = mod_size;
624 		res = TEE_ERROR_SHORT_BUFFER;
625 		goto err;
626 	}
627 	rsa.len = mod_size;
628 
629 	md_algo = tee_algo_to_mbedtls_hash_algo(algo);
630 	if (md_algo == MBEDTLS_MD_NONE) {
631 		res = TEE_ERROR_NOT_SUPPORTED;
632 		goto err;
633 	}
634 
635 	pk_info = mbedtls_pk_info_from_type(MBEDTLS_PK_RSA);
636 	if (!pk_info) {
637 		res = TEE_ERROR_NOT_SUPPORTED;
638 		goto err;
639 	}
640 
641 	mbedtls_rsa_set_padding(&rsa, lmd_padding, md_algo);
642 
643 	if (lmd_padding == MBEDTLS_RSA_PKCS_V15)
644 		lmd_res = pk_info->sign_func(&rsa, md_algo, msg, msg_len, sig,
645 					     sig_len, NULL, NULL);
646 	else
647 		lmd_res = pk_info->sign_func(&rsa, md_algo, msg, msg_len, sig,
648 					     sig_len, mbd_rand, NULL);
649 	if (lmd_res != 0) {
650 		FMSG("sign_func failed, returned 0x%x", -lmd_res);
651 		res = get_tee_result(lmd_res);
652 		goto err;
653 	}
654 	res = TEE_SUCCESS;
655 err:
656 	mbd_rsa_free(&rsa);
657 	return res;
658 }
659 
660 TEE_Result crypto_acipher_rsassa_verify(uint32_t algo,
661 					struct rsa_public_key *key,
662 					int salt_len __unused,
663 					const uint8_t *msg,
664 					size_t msg_len, const uint8_t *sig,
665 					size_t sig_len)
666 __weak __alias("sw_crypto_acipher_rsassa_verify");
667 
668 TEE_Result sw_crypto_acipher_rsassa_verify(uint32_t algo,
669 					   struct rsa_public_key *key,
670 					   int salt_len __unused,
671 					   const uint8_t *msg,
672 					   size_t msg_len, const uint8_t *sig,
673 					   size_t sig_len)
674 {
675 	TEE_Result res = TEE_SUCCESS;
676 	int lmd_res = 0;
677 	int lmd_padding = 0;
678 	size_t hash_size = 0;
679 	size_t bigint_size = 0;
680 	mbedtls_rsa_context rsa;
681 	const mbedtls_pk_info_t *pk_info = NULL;
682 	uint32_t md_algo = 0;
683 
684 	memset(&rsa, 0, sizeof(rsa));
685 	mbedtls_rsa_init(&rsa, 0, 0);
686 
687 	rsa.E = *(mbedtls_mpi *)key->e;
688 	rsa.N = *(mbedtls_mpi *)key->n;
689 
690 	res = tee_alg_get_digest_size(TEE_DIGEST_HASH_TO_ALGO(algo),
691 				      &hash_size);
692 	if (res != TEE_SUCCESS)
693 		goto err;
694 
695 	if (msg_len != hash_size) {
696 		res = TEE_ERROR_BAD_PARAMETERS;
697 		goto err;
698 	}
699 
700 	bigint_size = crypto_bignum_num_bytes(key->n);
701 	if (sig_len < bigint_size) {
702 		res = TEE_ERROR_SIGNATURE_INVALID;
703 		goto err;
704 	}
705 
706 	rsa.len = bigint_size;
707 
708 	switch (algo) {
709 	case TEE_ALG_RSASSA_PKCS1_V1_5_MD5:
710 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA1:
711 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA224:
712 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA256:
713 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA384:
714 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA512:
715 		lmd_padding = MBEDTLS_RSA_PKCS_V15;
716 		break;
717 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA1:
718 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA224:
719 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA256:
720 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA384:
721 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA512:
722 		lmd_padding = MBEDTLS_RSA_PKCS_V21;
723 		break;
724 	default:
725 		res = TEE_ERROR_BAD_PARAMETERS;
726 		goto err;
727 	}
728 
729 	md_algo = tee_algo_to_mbedtls_hash_algo(algo);
730 	if (md_algo == MBEDTLS_MD_NONE) {
731 		res = TEE_ERROR_NOT_SUPPORTED;
732 		goto err;
733 	}
734 
735 	pk_info = mbedtls_pk_info_from_type(MBEDTLS_PK_RSA);
736 	if (!pk_info) {
737 		res = TEE_ERROR_NOT_SUPPORTED;
738 		goto err;
739 	}
740 
741 	mbedtls_rsa_set_padding(&rsa, lmd_padding, md_algo);
742 
743 	lmd_res = pk_info->verify_func(&rsa, md_algo, msg, msg_len,
744 				       sig, sig_len);
745 	if (lmd_res != 0) {
746 		FMSG("verify_func failed, returned 0x%x", -lmd_res);
747 		res = TEE_ERROR_SIGNATURE_INVALID;
748 		goto err;
749 	}
750 	res = TEE_SUCCESS;
751 err:
752 	/* Reset mpi to skip freeing here, those mpis will be freed with key */
753 	mbedtls_mpi_init(&rsa.E);
754 	mbedtls_mpi_init(&rsa.N);
755 	mbedtls_rsa_free(&rsa);
756 	return res;
757 }
758