xref: /optee_os/lib/libmbedtls/core/rsa.c (revision 731185b11620a6de1f824278ab3c166c7853ef66)
1 // SPDX-License-Identifier: BSD-2-Clause
2 /*
3  * Copyright (C) 2018, ARM Limited
4  * Copyright (C) 2019, Linaro Limited
5  */
6 
7 #include <assert.h>
8 #include <crypto/crypto.h>
9 #include <crypto/crypto_impl.h>
10 #include <mbedtls/ctr_drbg.h>
11 #include <mbedtls/entropy.h>
12 #include <mbedtls/pk.h>
13 #include <stdlib.h>
14 #include <string.h>
15 #include <tee/tee_cryp_utl.h>
16 #include <utee_defines.h>
17 #include <fault_mitigation.h>
18 
19 #include "mbed_helpers.h"
20 #include "../mbedtls/library/pk_wrap.h"
21 #include "../mbedtls/library/rsa_alt_helpers.h"
22 
23 static TEE_Result get_tee_result(int lmd_res)
24 {
25 	switch (lmd_res) {
26 	case 0:
27 		return TEE_SUCCESS;
28 	case MBEDTLS_ERR_RSA_PRIVATE_FAILED +
29 		MBEDTLS_ERR_MPI_BAD_INPUT_DATA:
30 	case MBEDTLS_ERR_RSA_BAD_INPUT_DATA:
31 	case MBEDTLS_ERR_RSA_INVALID_PADDING:
32 	case MBEDTLS_ERR_PK_TYPE_MISMATCH:
33 		return TEE_ERROR_BAD_PARAMETERS;
34 	case MBEDTLS_ERR_RSA_OUTPUT_TOO_LARGE:
35 		return TEE_ERROR_SHORT_BUFFER;
36 	default:
37 		return TEE_ERROR_BAD_STATE;
38 	}
39 }
40 
41 static uint32_t tee_algo_to_mbedtls_hash_algo(uint32_t algo)
42 {
43 	switch (algo) {
44 #if defined(CFG_CRYPTO_SHA1)
45 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA1:
46 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA1:
47 	case TEE_ALG_RSAES_PKCS1_OAEP_MGF1_SHA1:
48 	case TEE_ALG_SHA1:
49 	case TEE_ALG_DSA_SHA1:
50 	case TEE_ALG_HMAC_SHA1:
51 		return MBEDTLS_MD_SHA1;
52 #endif
53 #if defined(CFG_CRYPTO_MD5)
54 	case TEE_ALG_RSASSA_PKCS1_V1_5_MD5:
55 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_MD5:
56 	case TEE_ALG_RSAES_PKCS1_OAEP_MGF1_MD5:
57 	case TEE_ALG_MD5:
58 	case TEE_ALG_HMAC_MD5:
59 		return MBEDTLS_MD_MD5;
60 #endif
61 #if defined(CFG_CRYPTO_SHA224)
62 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA224:
63 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA224:
64 	case TEE_ALG_RSAES_PKCS1_OAEP_MGF1_SHA224:
65 	case TEE_ALG_SHA224:
66 	case TEE_ALG_DSA_SHA224:
67 	case TEE_ALG_HMAC_SHA224:
68 		return MBEDTLS_MD_SHA224;
69 #endif
70 #if defined(CFG_CRYPTO_SHA256)
71 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA256:
72 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA256:
73 	case TEE_ALG_RSAES_PKCS1_OAEP_MGF1_SHA256:
74 	case TEE_ALG_SHA256:
75 	case TEE_ALG_DSA_SHA256:
76 	case TEE_ALG_HMAC_SHA256:
77 		return MBEDTLS_MD_SHA256;
78 #endif
79 #if defined(CFG_CRYPTO_SHA384)
80 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA384:
81 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA384:
82 	case TEE_ALG_RSAES_PKCS1_OAEP_MGF1_SHA384:
83 	case TEE_ALG_SHA384:
84 	case TEE_ALG_HMAC_SHA384:
85 		return MBEDTLS_MD_SHA384;
86 #endif
87 #if defined(CFG_CRYPTO_SHA512)
88 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA512:
89 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA512:
90 	case TEE_ALG_RSAES_PKCS1_OAEP_MGF1_SHA512:
91 	case TEE_ALG_SHA512:
92 	case TEE_ALG_HMAC_SHA512:
93 		return MBEDTLS_MD_SHA512;
94 #endif
95 	default:
96 		return MBEDTLS_MD_NONE;
97 	}
98 }
99 
100 static TEE_Result rsa_init_and_complete_from_key_pair(mbedtls_rsa_context *rsa,
101 						      struct rsa_keypair *key)
102 {
103 	int lmd_res = 0;
104 
105 	mbedtls_rsa_init(rsa);
106 
107 	rsa->E = *(mbedtls_mpi *)key->e;
108 	rsa->N = *(mbedtls_mpi *)key->n;
109 	rsa->D = *(mbedtls_mpi *)key->d;
110 	rsa->len = mbedtls_mpi_size(&rsa->N);
111 
112 	if (key->p && crypto_bignum_num_bytes(key->p)) {
113 		rsa->P = *(mbedtls_mpi *)key->p;
114 		rsa->Q = *(mbedtls_mpi *)key->q;
115 		rsa->QP = *(mbedtls_mpi *)key->qp;
116 		rsa->DP = *(mbedtls_mpi *)key->dp;
117 		rsa->DQ = *(mbedtls_mpi *)key->dq;
118 	} else {
119 		mbedtls_mpi_init_mempool(&rsa->P);
120 		mbedtls_mpi_init_mempool(&rsa->Q);
121 		mbedtls_mpi_init_mempool(&rsa->QP);
122 		mbedtls_mpi_init_mempool(&rsa->DP);
123 		mbedtls_mpi_init_mempool(&rsa->DQ);
124 
125 		lmd_res = mbedtls_rsa_deduce_primes(&rsa->N, &rsa->E, &rsa->D,
126 						    &rsa->P, &rsa->Q);
127 		if (lmd_res) {
128 			DMSG("mbedtls_rsa_deduce_primes() returned 0x%x",
129 			     -lmd_res);
130 			goto err;
131 		}
132 
133 		lmd_res = mbedtls_rsa_deduce_crt(&rsa->P, &rsa->Q, &rsa->D,
134 						 &rsa->DP, &rsa->DQ, &rsa->QP);
135 		if (lmd_res) {
136 			DMSG("mbedtls_rsa_deduce_crt() returned 0x%x",
137 			     -lmd_res);
138 			goto err;
139 		}
140 	}
141 
142 	return TEE_SUCCESS;
143 err:
144 	mbedtls_mpi_free(&rsa->P);
145 	mbedtls_mpi_free(&rsa->Q);
146 	mbedtls_mpi_free(&rsa->QP);
147 	mbedtls_mpi_free(&rsa->DP);
148 	mbedtls_mpi_free(&rsa->DQ);
149 
150 	return get_tee_result(lmd_res);
151 }
152 
153 static void mbd_rsa_free(mbedtls_rsa_context *rsa, struct rsa_keypair *key)
154 {
155 	/*
156 	 * The mpi's in @rsa are initialized from @key, but the primes and
157 	 * CRT part are generated if @key doesn't have them. When freeing
158 	 * we should only free the generated mpi's, the ones copied are
159 	 * reset instead.
160 	 */
161 	mbedtls_mpi_init(&rsa->E);
162 	mbedtls_mpi_init(&rsa->N);
163 	mbedtls_mpi_init(&rsa->D);
164 	if (key->p && crypto_bignum_num_bytes(key->p)) {
165 		mbedtls_mpi_init(&rsa->P);
166 		mbedtls_mpi_init(&rsa->Q);
167 		mbedtls_mpi_init(&rsa->QP);
168 		mbedtls_mpi_init(&rsa->DP);
169 		mbedtls_mpi_init(&rsa->DQ);
170 	}
171 	mbedtls_rsa_free(rsa);
172 }
173 
174 TEE_Result crypto_acipher_alloc_rsa_keypair(struct rsa_keypair *s,
175 					    size_t key_size_bits)
176 __weak __alias("sw_crypto_acipher_alloc_rsa_keypair");
177 
178 TEE_Result sw_crypto_acipher_alloc_rsa_keypair(struct rsa_keypair *s,
179 					       size_t key_size_bits)
180 {
181 	memset(s, 0, sizeof(*s));
182 	s->e = crypto_bignum_allocate(key_size_bits);
183 	if (!s->e)
184 		goto err;
185 	s->d = crypto_bignum_allocate(key_size_bits);
186 	if (!s->d)
187 		goto err;
188 	s->n = crypto_bignum_allocate(key_size_bits);
189 	if (!s->n)
190 		goto err;
191 	s->p = crypto_bignum_allocate(key_size_bits);
192 	if (!s->p)
193 		goto err;
194 	s->q = crypto_bignum_allocate(key_size_bits);
195 	if (!s->q)
196 		goto err;
197 	s->qp = crypto_bignum_allocate(key_size_bits);
198 	if (!s->qp)
199 		goto err;
200 	s->dp = crypto_bignum_allocate(key_size_bits);
201 	if (!s->dp)
202 		goto err;
203 	s->dq = crypto_bignum_allocate(key_size_bits);
204 	if (!s->dq)
205 		goto err;
206 
207 	return TEE_SUCCESS;
208 err:
209 	crypto_acipher_free_rsa_keypair(s);
210 	return TEE_ERROR_OUT_OF_MEMORY;
211 }
212 
213 TEE_Result crypto_acipher_alloc_rsa_public_key(struct rsa_public_key *s,
214 					       size_t key_size_bits)
215 __weak __alias("sw_crypto_acipher_alloc_rsa_public_key");
216 
217 TEE_Result sw_crypto_acipher_alloc_rsa_public_key(struct rsa_public_key *s,
218 						  size_t key_size_bits)
219 {
220 	memset(s, 0, sizeof(*s));
221 	s->e = crypto_bignum_allocate(key_size_bits);
222 	if (!s->e)
223 		return TEE_ERROR_OUT_OF_MEMORY;
224 	s->n = crypto_bignum_allocate(key_size_bits);
225 	if (!s->n)
226 		goto err;
227 	return TEE_SUCCESS;
228 err:
229 	crypto_bignum_free(&s->e);
230 	return TEE_ERROR_OUT_OF_MEMORY;
231 }
232 
233 void crypto_acipher_free_rsa_public_key(struct rsa_public_key *s)
234 __weak __alias("sw_crypto_acipher_free_rsa_public_key");
235 
236 void sw_crypto_acipher_free_rsa_public_key(struct rsa_public_key *s)
237 {
238 	if (!s)
239 		return;
240 	crypto_bignum_free(&s->n);
241 	crypto_bignum_free(&s->e);
242 }
243 
244 void crypto_acipher_free_rsa_keypair(struct rsa_keypair *s)
245 __weak __alias("sw_crypto_acipher_free_rsa_keypair");
246 
247 void sw_crypto_acipher_free_rsa_keypair(struct rsa_keypair *s)
248 {
249 	if (!s)
250 		return;
251 	crypto_bignum_free(&s->e);
252 	crypto_bignum_free(&s->d);
253 	crypto_bignum_free(&s->n);
254 	crypto_bignum_free(&s->p);
255 	crypto_bignum_free(&s->q);
256 	crypto_bignum_free(&s->qp);
257 	crypto_bignum_free(&s->dp);
258 	crypto_bignum_free(&s->dq);
259 }
260 
261 TEE_Result crypto_acipher_gen_rsa_key(struct rsa_keypair *key,
262 				      size_t key_size)
263 __weak __alias("sw_crypto_acipher_gen_rsa_key");
264 
265 TEE_Result sw_crypto_acipher_gen_rsa_key(struct rsa_keypair *key,
266 					 size_t key_size)
267 {
268 	TEE_Result res = TEE_SUCCESS;
269 	mbedtls_rsa_context rsa;
270 	mbedtls_ctr_drbg_context rngctx;
271 	int lmd_res = 0;
272 	uint32_t e = 0;
273 
274 	mbedtls_ctr_drbg_init(&rngctx);
275 	if (mbedtls_ctr_drbg_seed(&rngctx, mbd_rand, NULL, NULL, 0))
276 		return TEE_ERROR_BAD_STATE;
277 
278 	memset(&rsa, 0, sizeof(rsa));
279 	mbedtls_rsa_init(&rsa);
280 
281 	/* get the public exponent */
282 	mbedtls_mpi_write_binary((mbedtls_mpi *)key->e,
283 				 (unsigned char *)&e, sizeof(uint32_t));
284 
285 	e = TEE_U32_FROM_BIG_ENDIAN(e);
286 	lmd_res = mbedtls_rsa_gen_key(&rsa, mbedtls_ctr_drbg_random, &rngctx,
287 				      key_size, (int)e);
288 	mbedtls_ctr_drbg_free(&rngctx);
289 	if (lmd_res != 0) {
290 		res = get_tee_result(lmd_res);
291 	} else if ((size_t)mbedtls_mpi_bitlen(&rsa.N) != key_size) {
292 		res = TEE_ERROR_BAD_PARAMETERS;
293 	} else {
294 		/* Copy the key */
295 		crypto_bignum_copy(key->e, (void *)&rsa.E);
296 		crypto_bignum_copy(key->d, (void *)&rsa.D);
297 		crypto_bignum_copy(key->n, (void *)&rsa.N);
298 		crypto_bignum_copy(key->p, (void *)&rsa.P);
299 
300 		crypto_bignum_copy(key->q, (void *)&rsa.Q);
301 		crypto_bignum_copy(key->qp, (void *)&rsa.QP);
302 		crypto_bignum_copy(key->dp, (void *)&rsa.DP);
303 		crypto_bignum_copy(key->dq, (void *)&rsa.DQ);
304 
305 		res = TEE_SUCCESS;
306 	}
307 
308 	mbedtls_rsa_free(&rsa);
309 
310 	return res;
311 }
312 
313 TEE_Result crypto_acipher_rsanopad_encrypt(struct rsa_public_key *key,
314 					   const uint8_t *src,
315 					   size_t src_len, uint8_t *dst,
316 					   size_t *dst_len)
317 __weak __alias("sw_crypto_acipher_rsanopad_encrypt");
318 
319 TEE_Result sw_crypto_acipher_rsanopad_encrypt(struct rsa_public_key *key,
320 					      const uint8_t *src,
321 					      size_t src_len, uint8_t *dst,
322 					      size_t *dst_len)
323 {
324 	TEE_Result res = TEE_SUCCESS;
325 	mbedtls_rsa_context rsa;
326 	int lmd_res = 0;
327 	uint8_t *buf = NULL;
328 	unsigned long blen = 0;
329 	unsigned long offset = 0;
330 
331 	memset(&rsa, 0, sizeof(rsa));
332 	mbedtls_rsa_init(&rsa);
333 
334 	rsa.E = *(mbedtls_mpi *)key->e;
335 	rsa.N = *(mbedtls_mpi *)key->n;
336 
337 	rsa.len = crypto_bignum_num_bytes((void *)&rsa.N);
338 
339 	blen = CFG_CORE_BIGNUM_MAX_BITS / 8;
340 	buf = malloc(blen);
341 	if (!buf) {
342 		res = TEE_ERROR_OUT_OF_MEMORY;
343 		goto out;
344 	}
345 
346 	memset(buf, 0, blen);
347 	memcpy(buf + rsa.len - src_len, src, src_len);
348 
349 	lmd_res = mbedtls_rsa_public(&rsa, buf, buf);
350 	if (lmd_res != 0) {
351 		FMSG("mbedtls_rsa_public() returned 0x%x", -lmd_res);
352 		res = get_tee_result(lmd_res);
353 		goto out;
354 	}
355 
356 	/* Remove the zero-padding (leave one zero if buff is all zeroes) */
357 	offset = 0;
358 	while ((offset < rsa.len - 1) && (buf[offset] == 0))
359 		offset++;
360 
361 	if (*dst_len < rsa.len - offset) {
362 		*dst_len = rsa.len - offset;
363 		res = TEE_ERROR_SHORT_BUFFER;
364 		goto out;
365 	}
366 	*dst_len = rsa.len - offset;
367 	memcpy(dst, buf + offset, *dst_len);
368 out:
369 	free(buf);
370 	/* Reset mpi to skip freeing here, those mpis will be freed with key */
371 	mbedtls_mpi_init(&rsa.E);
372 	mbedtls_mpi_init(&rsa.N);
373 	mbedtls_rsa_free(&rsa);
374 
375 	return res;
376 }
377 
378 TEE_Result crypto_acipher_rsanopad_decrypt(struct rsa_keypair *key,
379 					   const uint8_t *src,
380 					   size_t src_len, uint8_t *dst,
381 					   size_t *dst_len)
382 __weak __alias("sw_crypto_acipher_rsanopad_decrypt");
383 
384 TEE_Result sw_crypto_acipher_rsanopad_decrypt(struct rsa_keypair *key,
385 					      const uint8_t *src,
386 					      size_t src_len, uint8_t *dst,
387 					      size_t *dst_len)
388 {
389 	TEE_Result res = TEE_SUCCESS;
390 	mbedtls_rsa_context rsa = { };
391 	int lmd_res = 0;
392 	uint8_t *buf = NULL;
393 	unsigned long blen = 0;
394 	unsigned long offset = 0;
395 
396 	res = rsa_init_and_complete_from_key_pair(&rsa, key);
397 	if (res)
398 		return res;
399 
400 	blen = CFG_CORE_BIGNUM_MAX_BITS / 8;
401 	buf = malloc(blen);
402 	if (!buf) {
403 		res = TEE_ERROR_OUT_OF_MEMORY;
404 		goto out;
405 	}
406 
407 	memset(buf, 0, blen);
408 	memcpy(buf + rsa.len - src_len, src, src_len);
409 
410 	lmd_res = mbedtls_rsa_private(&rsa, mbd_rand, NULL, buf, buf);
411 	if (lmd_res != 0) {
412 		FMSG("mbedtls_rsa_private() returned 0x%x", -lmd_res);
413 		res = get_tee_result(lmd_res);
414 		goto out;
415 	}
416 
417 	/* Remove the zero-padding (leave one zero if buff is all zeroes) */
418 	offset = 0;
419 	while ((offset < rsa.len - 1) && (buf[offset] == 0))
420 		offset++;
421 
422 	if (*dst_len < rsa.len - offset) {
423 		*dst_len = rsa.len - offset;
424 		res = TEE_ERROR_SHORT_BUFFER;
425 		goto out;
426 	}
427 	*dst_len = rsa.len - offset;
428 	memcpy(dst, (char *)buf + offset, *dst_len);
429 out:
430 	if (buf)
431 		free(buf);
432 	mbd_rsa_free(&rsa, key);
433 	return res;
434 }
435 
436 TEE_Result crypto_acipher_rsaes_decrypt(uint32_t algo,
437 					struct rsa_keypair *key,
438 					const uint8_t *label __unused,
439 					size_t label_len __unused,
440 					const uint8_t *src, size_t src_len,
441 					uint8_t *dst, size_t *dst_len)
442 __weak __alias("sw_crypto_acipher_rsaes_decrypt");
443 
444 TEE_Result sw_crypto_acipher_rsaes_decrypt(uint32_t algo,
445 					   struct rsa_keypair *key,
446 					   const uint8_t *label __unused,
447 					   size_t label_len __unused,
448 					   const uint8_t *src, size_t src_len,
449 					   uint8_t *dst, size_t *dst_len)
450 {
451 	TEE_Result res = TEE_SUCCESS;
452 	int lmd_res = 0;
453 	int lmd_padding = 0;
454 	size_t blen = 0;
455 	size_t mod_size = 0;
456 	void *buf = NULL;
457 	mbedtls_rsa_context rsa = { };
458 	const mbedtls_pk_info_t *pk_info = NULL;
459 	uint32_t md_algo = MBEDTLS_MD_NONE;
460 
461 	res = rsa_init_and_complete_from_key_pair(&rsa, key);
462 	if (res)
463 		return res;
464 
465 	/*
466 	 * Use a temporary buffer since we don't know exactly how large
467 	 * the required size of the out buffer without doing a partial
468 	 * decrypt. We know the upper bound though.
469 	 */
470 	if (algo == TEE_ALG_RSAES_PKCS1_V1_5) {
471 		mod_size = crypto_bignum_num_bytes(key->n);
472 		blen = mod_size - 11;
473 		lmd_padding = MBEDTLS_RSA_PKCS_V15;
474 	} else {
475 		/* Decoded message is always shorter than encrypted message */
476 		blen = src_len;
477 		lmd_padding = MBEDTLS_RSA_PKCS_V21;
478 	}
479 
480 	buf = malloc(blen);
481 	if (!buf) {
482 		res = TEE_ERROR_OUT_OF_MEMORY;
483 		goto out;
484 	}
485 
486 	pk_info = mbedtls_pk_info_from_type(MBEDTLS_PK_RSA);
487 	if (!pk_info) {
488 		res = TEE_ERROR_NOT_SUPPORTED;
489 		goto out;
490 	}
491 
492 	/*
493 	 * TEE_ALG_RSAES_PKCS1_V1_5 is invalid in hash. But its hash algo will
494 	 * not be used in rsa, so skip it here.
495 	 */
496 	if (algo != TEE_ALG_RSAES_PKCS1_V1_5) {
497 		md_algo = tee_algo_to_mbedtls_hash_algo(algo);
498 		if (md_algo == MBEDTLS_MD_NONE) {
499 			res = TEE_ERROR_NOT_SUPPORTED;
500 			goto out;
501 		}
502 	}
503 
504 	mbedtls_rsa_set_padding(&rsa, lmd_padding, md_algo);
505 
506 	lmd_res = pk_info->decrypt_func(&rsa, src, src_len, buf, &blen,
507 					blen, mbd_rand, NULL);
508 	if (lmd_res != 0) {
509 		FMSG("decrypt_func() returned 0x%x", -lmd_res);
510 		res = get_tee_result(lmd_res);
511 		goto out;
512 	}
513 
514 	if (*dst_len < blen) {
515 		*dst_len = blen;
516 		res = TEE_ERROR_SHORT_BUFFER;
517 		goto out;
518 	}
519 
520 	res = TEE_SUCCESS;
521 	*dst_len = blen;
522 	memcpy(dst, buf, blen);
523 out:
524 	if (buf)
525 		free(buf);
526 	mbd_rsa_free(&rsa, key);
527 	return res;
528 }
529 
530 TEE_Result crypto_acipher_rsaes_encrypt(uint32_t algo,
531 					struct rsa_public_key *key,
532 					const uint8_t *label __unused,
533 					size_t label_len __unused,
534 					const uint8_t *src, size_t src_len,
535 					uint8_t *dst, size_t *dst_len)
536 __weak __alias("sw_crypto_acipher_rsaes_encrypt");
537 
538 TEE_Result sw_crypto_acipher_rsaes_encrypt(uint32_t algo,
539 					   struct rsa_public_key *key,
540 					   const uint8_t *label __unused,
541 					   size_t label_len __unused,
542 					   const uint8_t *src, size_t src_len,
543 					   uint8_t *dst, size_t *dst_len)
544 {
545 	TEE_Result res = TEE_SUCCESS;
546 	int lmd_res = 0;
547 	int lmd_padding = 0;
548 	size_t mod_size = 0;
549 	mbedtls_rsa_context rsa;
550 	const mbedtls_pk_info_t *pk_info = NULL;
551 	uint32_t md_algo = MBEDTLS_MD_NONE;
552 
553 	memset(&rsa, 0, sizeof(rsa));
554 	mbedtls_rsa_init(&rsa);
555 
556 	rsa.E = *(mbedtls_mpi *)key->e;
557 	rsa.N = *(mbedtls_mpi *)key->n;
558 
559 	mod_size = crypto_bignum_num_bytes(key->n);
560 	if (*dst_len < mod_size) {
561 		*dst_len = mod_size;
562 		res = TEE_ERROR_SHORT_BUFFER;
563 		goto out;
564 	}
565 	*dst_len = mod_size;
566 	rsa.len = mod_size;
567 
568 	if (algo == TEE_ALG_RSAES_PKCS1_V1_5)
569 		lmd_padding = MBEDTLS_RSA_PKCS_V15;
570 	else
571 		lmd_padding = MBEDTLS_RSA_PKCS_V21;
572 
573 	pk_info = mbedtls_pk_info_from_type(MBEDTLS_PK_RSA);
574 	if (!pk_info) {
575 		res = TEE_ERROR_NOT_SUPPORTED;
576 		goto out;
577 	}
578 
579 	/*
580 	 * TEE_ALG_RSAES_PKCS1_V1_5 is invalid in hash. But its hash algo will
581 	 * not be used in rsa, so skip it here.
582 	 */
583 	if (algo != TEE_ALG_RSAES_PKCS1_V1_5) {
584 		md_algo = tee_algo_to_mbedtls_hash_algo(algo);
585 		if (md_algo == MBEDTLS_MD_NONE) {
586 			res = TEE_ERROR_NOT_SUPPORTED;
587 			goto out;
588 		}
589 	}
590 
591 	mbedtls_rsa_set_padding(&rsa, lmd_padding, md_algo);
592 
593 	lmd_res = pk_info->encrypt_func(&rsa, src, src_len, dst, dst_len,
594 					*dst_len, mbd_rand, NULL);
595 	if (lmd_res != 0) {
596 		FMSG("encrypt_func() returned 0x%x", -lmd_res);
597 		res = get_tee_result(lmd_res);
598 		goto out;
599 	}
600 	res = TEE_SUCCESS;
601 out:
602 	/* Reset mpi to skip freeing here, those mpis will be freed with key */
603 	mbedtls_mpi_init(&rsa.E);
604 	mbedtls_mpi_init(&rsa.N);
605 	mbedtls_rsa_free(&rsa);
606 	return res;
607 }
608 
609 TEE_Result crypto_acipher_rsassa_sign(uint32_t algo, struct rsa_keypair *key,
610 				      int salt_len __unused,
611 				      const uint8_t *msg, size_t msg_len,
612 				      uint8_t *sig, size_t *sig_len)
613 __weak __alias("sw_crypto_acipher_rsassa_sign");
614 
615 TEE_Result sw_crypto_acipher_rsassa_sign(uint32_t algo, struct rsa_keypair *key,
616 					 int salt_len __unused,
617 					 const uint8_t *msg, size_t msg_len,
618 					 uint8_t *sig, size_t *sig_len)
619 {
620 	TEE_Result res = TEE_SUCCESS;
621 	int lmd_res = 0;
622 	int lmd_padding = 0;
623 	size_t mod_size = 0;
624 	size_t hash_size = 0;
625 	mbedtls_rsa_context rsa = { };
626 	const mbedtls_pk_info_t *pk_info = NULL;
627 	uint32_t md_algo = 0;
628 
629 	res = rsa_init_and_complete_from_key_pair(&rsa, key);
630 	if (res)
631 		return res;
632 
633 	switch (algo) {
634 	case TEE_ALG_RSASSA_PKCS1_V1_5_MD5:
635 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA1:
636 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA224:
637 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA256:
638 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA384:
639 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA512:
640 		lmd_padding = MBEDTLS_RSA_PKCS_V15;
641 		break;
642 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_MD5:
643 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA1:
644 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA224:
645 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA256:
646 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA384:
647 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA512:
648 		lmd_padding = MBEDTLS_RSA_PKCS_V21;
649 		break;
650 	default:
651 		res = TEE_ERROR_BAD_PARAMETERS;
652 		goto err;
653 	}
654 
655 	res = tee_alg_get_digest_size(TEE_DIGEST_HASH_TO_ALGO(algo),
656 				      &hash_size);
657 	if (res != TEE_SUCCESS)
658 		goto err;
659 
660 	if (msg_len != hash_size) {
661 		res = TEE_ERROR_BAD_PARAMETERS;
662 		goto err;
663 	}
664 
665 	mod_size = crypto_bignum_num_bytes(key->n);
666 	if (*sig_len < mod_size) {
667 		*sig_len = mod_size;
668 		res = TEE_ERROR_SHORT_BUFFER;
669 		goto err;
670 	}
671 	rsa.len = mod_size;
672 
673 	md_algo = tee_algo_to_mbedtls_hash_algo(algo);
674 	if (md_algo == MBEDTLS_MD_NONE) {
675 		res = TEE_ERROR_NOT_SUPPORTED;
676 		goto err;
677 	}
678 
679 	pk_info = mbedtls_pk_info_from_type(MBEDTLS_PK_RSA);
680 	if (!pk_info) {
681 		res = TEE_ERROR_NOT_SUPPORTED;
682 		goto err;
683 	}
684 
685 	mbedtls_rsa_set_padding(&rsa, lmd_padding, md_algo);
686 
687 	lmd_res = pk_info->sign_func(&rsa, md_algo, msg, msg_len, sig,
688 				     *sig_len, sig_len, mbd_rand, NULL);
689 	if (lmd_res != 0) {
690 		FMSG("sign_func failed, returned 0x%x", -lmd_res);
691 		res = get_tee_result(lmd_res);
692 		goto err;
693 	}
694 	res = TEE_SUCCESS;
695 err:
696 	mbd_rsa_free(&rsa, key);
697 	return res;
698 }
699 
700 TEE_Result crypto_acipher_rsassa_verify(uint32_t algo,
701 					struct rsa_public_key *key,
702 					int salt_len __unused,
703 					const uint8_t *msg,
704 					size_t msg_len, const uint8_t *sig,
705 					size_t sig_len)
706 __weak __alias("sw_crypto_acipher_rsassa_verify");
707 
708 TEE_Result sw_crypto_acipher_rsassa_verify(uint32_t algo,
709 					   struct rsa_public_key *key,
710 					   int salt_len __unused,
711 					   const uint8_t *msg,
712 					   size_t msg_len, const uint8_t *sig,
713 					   size_t sig_len)
714 {
715 	TEE_Result res = TEE_SUCCESS;
716 	int lmd_res = 0;
717 	int lmd_padding = 0;
718 	size_t hash_size = 0;
719 	size_t bigint_size = 0;
720 	mbedtls_rsa_context rsa;
721 	const mbedtls_pk_info_t *pk_info = NULL;
722 	uint32_t md_algo = 0;
723 	struct ftmn ftmn = { };
724 	unsigned long arg_hash = 0;
725 
726 	/*
727 	 * The caller expects to call crypto_acipher_rsassa_verify(),
728 	 * update the hash as needed.
729 	 */
730 	FTMN_CALLEE_SWAP_HASH(FTMN_FUNC_HASH("crypto_acipher_rsassa_verify"));
731 
732 	memset(&rsa, 0, sizeof(rsa));
733 	mbedtls_rsa_init(&rsa);
734 
735 	rsa.E = *(mbedtls_mpi *)key->e;
736 	rsa.N = *(mbedtls_mpi *)key->n;
737 
738 	res = tee_alg_get_digest_size(TEE_DIGEST_HASH_TO_ALGO(algo),
739 				      &hash_size);
740 	if (res != TEE_SUCCESS)
741 		goto err;
742 
743 	if (msg_len != hash_size) {
744 		res = TEE_ERROR_BAD_PARAMETERS;
745 		goto err;
746 	}
747 
748 	bigint_size = crypto_bignum_num_bytes(key->n);
749 	if (sig_len < bigint_size) {
750 		res = TEE_ERROR_SIGNATURE_INVALID;
751 		goto err;
752 	}
753 
754 	rsa.len = bigint_size;
755 
756 	switch (algo) {
757 	case TEE_ALG_RSASSA_PKCS1_V1_5_MD5:
758 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA1:
759 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA224:
760 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA256:
761 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA384:
762 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA512:
763 		arg_hash = FTMN_FUNC_HASH("mbedtls_rsa_rsassa_pkcs1_v15_verify");
764 		lmd_padding = MBEDTLS_RSA_PKCS_V15;
765 		break;
766 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_MD5:
767 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA1:
768 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA224:
769 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA256:
770 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA384:
771 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA512:
772 		arg_hash = FTMN_FUNC_HASH("mbedtls_rsa_rsassa_pss_verify_ext");
773 		lmd_padding = MBEDTLS_RSA_PKCS_V21;
774 		break;
775 	default:
776 		res = TEE_ERROR_BAD_PARAMETERS;
777 		goto err;
778 	}
779 
780 	md_algo = tee_algo_to_mbedtls_hash_algo(algo);
781 	if (md_algo == MBEDTLS_MD_NONE) {
782 		res = TEE_ERROR_NOT_SUPPORTED;
783 		goto err;
784 	}
785 
786 	pk_info = mbedtls_pk_info_from_type(MBEDTLS_PK_RSA);
787 	if (!pk_info) {
788 		res = TEE_ERROR_NOT_SUPPORTED;
789 		goto err;
790 	}
791 
792 	mbedtls_rsa_set_padding(&rsa, lmd_padding, md_algo);
793 
794 	FTMN_PUSH_LINKED_CALL(&ftmn, arg_hash);
795 	lmd_res = pk_info->verify_func(&rsa, md_algo, msg, msg_len,
796 				       sig, sig_len);
797 	if (!lmd_res)
798 		FTMN_SET_CHECK_RES_FROM_CALL(&ftmn, FTMN_INCR0, lmd_res);
799 	FTMN_POP_LINKED_CALL(&ftmn);
800 	if (lmd_res != 0) {
801 		FMSG("verify_func failed, returned 0x%x", -lmd_res);
802 		res = TEE_ERROR_SIGNATURE_INVALID;
803 		goto err;
804 	}
805 	res = TEE_SUCCESS;
806 	goto out;
807 
808 err:
809 	FTMN_SET_CHECK_RES_NOT_ZERO(&ftmn, FTMN_INCR0, res);
810 out:
811 	FTMN_CALLEE_DONE_CHECK(&ftmn, FTMN_INCR0, FTMN_STEP_COUNT(1), res);
812 	/* Reset mpi to skip freeing here, those mpis will be freed with key */
813 	mbedtls_mpi_init(&rsa.E);
814 	mbedtls_mpi_init(&rsa.N);
815 	mbedtls_rsa_free(&rsa);
816 	return res;
817 }
818