xref: /optee_os/core/drivers/crypto/hisilicon/hpre_rsa.c (revision 4f75eab013a2bfcefc24d8d877300795e5e87568)
1 // SPDX-License-Identifier: BSD-2-Clause
2 /*
3  * Copyright (c) 2022-2024, HiSilicon Technologies Co., Ltd.
4  * Kunpeng hardware accelerator hpre rsa algorithm implementation.
5  */
6 
7 #include <crypto/crypto_impl.h>
8 #include <drvcrypt.h>
9 #include <drvcrypt_acipher.h>
10 #include <drvcrypt_math.h>
11 #include <initcall.h>
12 #include <malloc.h>
13 #include <mbedtls/rsa.h>
14 #include <rng_support.h>
15 #include <stdlib_ext.h>
16 #include <string.h>
17 #include <string_ext.h>
18 #include <tee/tee_cryp_utl.h>
19 #include <trace.h>
20 
21 #include "hpre_main.h"
22 #include "hpre_rsa.h"
23 
hpre_rsa_fill_addr_params(struct hpre_rsa_msg * msg,struct hpre_sqe * sqe)24 static enum hisi_drv_status hpre_rsa_fill_addr_params(struct hpre_rsa_msg *msg,
25 						      struct hpre_sqe *sqe)
26 {
27 	switch (msg->alg_type) {
28 	case HPRE_ALG_NC_NCRT:
29 	case HPRE_ALG_NC_CRT:
30 		if (msg->is_private) {
31 			/* DECRYPT */
32 			sqe->key = msg->prikey_dma;
33 			sqe->in = msg->in_dma;
34 			sqe->out = msg->out_dma;
35 		} else {
36 			/* ENCRYPT */
37 			sqe->key = msg->pubkey_dma;
38 			sqe->in = msg->in_dma;
39 			sqe->out = msg->out_dma;
40 		}
41 		return HISI_QM_DRVCRYPT_NO_ERR;
42 	default:
43 		EMSG("Invalid alg_type[%"PRIu32"]", msg->alg_type);
44 		return HISI_QM_DRVCRYPT_IN_EPARA;
45 	}
46 }
47 
hpre_rsa_fill_sqe(void * bd,void * info)48 static enum hisi_drv_status hpre_rsa_fill_sqe(void *bd, void *info)
49 {
50 	struct hpre_rsa_msg *msg = info;
51 	struct hpre_sqe *sqe = bd;
52 
53 	sqe->w0 = msg->alg_type | SHIFT_U32(0x1, HPRE_DONE_SHIFT);
54 	sqe->task_len1 = TASK_LENGTH(msg->key_bytes);
55 
56 	return hpre_rsa_fill_addr_params(msg, sqe);
57 }
58 
/*
 * hisi_qp parse_sqe callback: check the completion status the hardware
 * wrote back into the submission entry @bd.
 *
 * Return HISI_QM_DRVCRYPT_NO_ERR when the task is reported done with no
 * error type, HISI_QM_DRVCRYPT_IN_EPARA otherwise.
 */
static enum hisi_drv_status hpre_rsa_parse_sqe(void *bd, void *info __unused)
{
	struct hpre_sqe *sqe = bd;
	uint16_t done = HPRE_TASK_DONE(sqe->w0);
	uint16_t etype = HPRE_TASK_ETYPE(sqe->w0);

	if (done == HPRE_HW_TASK_DONE && !etype)
		return HISI_QM_DRVCRYPT_NO_ERR;

	EMSG("HPRE do rsa fail! done=0x%"PRIX16", etype=0x%"PRIX16,
	     done, etype);
	return HISI_QM_DRVCRYPT_IN_EPARA;
}
75 
hpre_rsa_do_task(void * msg)76 static TEE_Result hpre_rsa_do_task(void *msg)
77 {
78 	struct hisi_qp *rsa_qp = NULL;
79 	TEE_Result res = TEE_SUCCESS;
80 	enum hisi_drv_status ret = HISI_QM_DRVCRYPT_NO_ERR;
81 
82 	rsa_qp = hpre_create_qp(HISI_QM_CHANNEL_TYPE0);
83 	if (!rsa_qp) {
84 		EMSG("Fail to create rsa qp");
85 		return TEE_ERROR_BUSY;
86 	}
87 
88 	rsa_qp->fill_sqe = hpre_rsa_fill_sqe;
89 	rsa_qp->parse_sqe = hpre_rsa_parse_sqe;
90 	ret = hisi_qp_send(rsa_qp, msg);
91 	if (ret) {
92 		EMSG("Fail to send task, ret=%x"PRIx32, ret);
93 		res = TEE_ERROR_BAD_STATE;
94 		goto done_proc;
95 	}
96 
97 	ret = hisi_qp_recv_sync(rsa_qp, msg);
98 	if (ret) {
99 		EMSG("Recv task error, ret=%x"PRIx32, ret);
100 		res = TEE_ERROR_BAD_STATE;
101 		goto done_proc;
102 	}
103 
104 done_proc:
105 	hisi_qm_release_qp(rsa_qp);
106 
107 	return res;
108 }
109 
mgf_process(size_t digest_size,uint8_t * seed,size_t seed_len,uint8_t * mask,size_t mask_len,struct drvcrypt_rsa_ed * rsa_data)110 static TEE_Result mgf_process(size_t digest_size, uint8_t *seed,
111 			      size_t seed_len, uint8_t *mask, size_t mask_len,
112 			      struct drvcrypt_rsa_ed *rsa_data)
113 {
114 	struct drvcrypt_rsa_mgf mgf = { };
115 
116 	if (!rsa_data->mgf) {
117 		EMSG("mgf function is NULL");
118 		return TEE_ERROR_BAD_PARAMETERS;
119 	}
120 
121 	mgf.hash_algo = rsa_data->hash_algo;
122 	mgf.digest_size = digest_size;
123 	mgf.seed.data = seed;
124 	mgf.seed.length = seed_len;
125 	mgf.mask.data = mask;
126 	mgf.mask.length = mask_len;
127 
128 	return rsa_data->mgf(&mgf);
129 }
130 
xor_process(uint8_t * a,uint8_t * b,uint8_t * out,size_t len)131 static TEE_Result xor_process(uint8_t *a, uint8_t *b, uint8_t *out, size_t len)
132 {
133 	struct drvcrypt_mod_op xor_mod = { };
134 
135 	xor_mod.n.length = len;
136 	xor_mod.a.data = a;
137 	xor_mod.a.length = len;
138 	xor_mod.b.data = b;
139 	xor_mod.b.length = len;
140 	xor_mod.result.data = out;
141 	xor_mod.result.length = len;
142 
143 	return drvcrypt_xor_mod_n(&xor_mod);
144 }
145 
/*
 * Round an RSA modulus size up to the next operand width handled by the
 * HPRE engine (1024, 2048, 3072 or 4096 bits).
 *
 * @key_bits: modulus size in bits
 * Return the operand width in bytes, or 0 if @key_bits is zero or larger
 * than 4096.
 */
static size_t hpre_rsa_get_hw_kbytes(size_t key_bits)
{
	static const size_t hw_key_bits[] = { 1024, 2048, 3072, 4096 };
	size_t i = 0;

	/*
	 * A zero-sized modulus is not a valid key; accepting it would make
	 * rsa_nopad_decrypt() compute n_bytes - 1 and read past its output.
	 */
	if (key_bits) {
		for (i = 0; i < sizeof(hw_key_bits) / sizeof(hw_key_bits[0]);
		     i++) {
			if (key_bits <= hw_key_bits[i])
				return BITS_TO_BYTES(hw_key_bits[i]);
		}
	}

	EMSG("Invalid key_bits[%zu]", key_bits);
	return 0;
}
163 
hpre_rsa_params_free(struct hpre_rsa_msg * msg)164 static void hpre_rsa_params_free(struct hpre_rsa_msg *msg)
165 {
166 	switch (msg->alg_type) {
167 	case HPRE_ALG_NC_NCRT:
168 		if (msg->is_private)
169 			free_wipe(msg->prikey);
170 		else
171 			free(msg->pubkey);
172 		break;
173 	case HPRE_ALG_NC_CRT:
174 		if (msg->is_private)
175 			free_wipe(msg->prikey);
176 		break;
177 	default:
178 		EMSG("Invalid alg_type[%"PRIu32"]", msg->alg_type);
179 		break;
180 	}
181 }
182 
hpre_rsa_encrypt_alloc(struct hpre_rsa_msg * msg)183 static enum hisi_drv_status hpre_rsa_encrypt_alloc(struct hpre_rsa_msg *msg)
184 {
185 	uint32_t size = HPRE_RSA_NCRT_TOTAL_BUF_SIZE(msg->key_bytes);
186 	uint8_t *data = NULL;
187 
188 	data = calloc(1, size);
189 	if (!data) {
190 		EMSG("Fail to alloc rsa ncrt buf");
191 		return HISI_QM_DRVCRYPT_ENOMEM;
192 	}
193 
194 	msg->pubkey = data;
195 	msg->pubkey_dma = virt_to_phys(msg->pubkey);
196 
197 	msg->in = data + (msg->key_bytes * 2);
198 	msg->in_dma = msg->pubkey_dma + (msg->key_bytes * 2);
199 
200 	msg->out = msg->in + msg->key_bytes;
201 	msg->out_dma = msg->in_dma + msg->key_bytes;
202 
203 	return HISI_QM_DRVCRYPT_NO_ERR;
204 }
205 
206 static enum hisi_drv_status
hpre_rsa_encrypt_bn2bin(struct hpre_rsa_msg * msg,struct drvcrypt_rsa_ed * rsa_data)207 hpre_rsa_encrypt_bn2bin(struct hpre_rsa_msg *msg,
208 			struct drvcrypt_rsa_ed *rsa_data)
209 {
210 	struct rsa_public_key *key = rsa_data->key.key;
211 	uint32_t e_len = 0;
212 	uint32_t n_len = 0;
213 	enum hisi_drv_status ret = HISI_QM_DRVCRYPT_NO_ERR;
214 	uint8_t *n = NULL;
215 
216 	n = msg->pubkey + msg->key_bytes;
217 
218 	crypto_bignum_bn2bin(key->e, msg->pubkey);
219 	crypto_bignum_bn2bin(key->n, n);
220 	e_len = crypto_bignum_num_bytes(key->e);
221 	n_len = crypto_bignum_num_bytes(key->n);
222 
223 	ret = hpre_bin_from_crypto_bin(msg->pubkey, msg->pubkey,
224 				       msg->key_bytes, e_len);
225 	if (ret) {
226 		EMSG("Fail to transfer rsa ncrt e from crypto_bin to hpre_bin");
227 		return ret;
228 	}
229 
230 	ret = hpre_bin_from_crypto_bin(n, n, msg->key_bytes, n_len);
231 	if (ret) {
232 		EMSG("Fail to transfer rsa ncrt n from crypto_bin to hpre_bin");
233 		return ret;
234 	}
235 
236 	ret = hpre_bin_from_crypto_bin(msg->in, rsa_data->message.data,
237 				       msg->key_bytes,
238 				       rsa_data->message.length);
239 	if (ret)
240 		EMSG("Fail to transfer rsa plaintext from crypto_bin to hpre_bin");
241 
242 	return ret;
243 }
244 
hpre_rsa_encrypt_init(struct hpre_rsa_msg * msg,struct drvcrypt_rsa_ed * rsa_data)245 static TEE_Result hpre_rsa_encrypt_init(struct hpre_rsa_msg *msg,
246 					struct drvcrypt_rsa_ed *rsa_data)
247 {
248 	size_t n_bytes = rsa_data->key.n_size;
249 	enum hisi_drv_status ret = HISI_QM_DRVCRYPT_NO_ERR;
250 
251 	msg->alg_type = HPRE_ALG_NC_NCRT;
252 	msg->is_private = rsa_data->key.isprivate;
253 	msg->key_bytes = hpre_rsa_get_hw_kbytes(BYTES_TO_BITS(n_bytes));
254 	if (!msg->key_bytes)
255 		return TEE_ERROR_BAD_PARAMETERS;
256 
257 	ret = hpre_rsa_encrypt_alloc(msg);
258 	if (ret)
259 		return TEE_ERROR_OUT_OF_MEMORY;
260 
261 	ret = hpre_rsa_encrypt_bn2bin(msg, rsa_data);
262 	if (ret) {
263 		hpre_rsa_params_free(msg);
264 		return TEE_ERROR_BAD_STATE;
265 	}
266 
267 	return TEE_SUCCESS;
268 }
269 
rsa_nopad_encrypt(struct drvcrypt_rsa_ed * rsa_data)270 static TEE_Result rsa_nopad_encrypt(struct drvcrypt_rsa_ed *rsa_data)
271 {
272 	size_t n_bytes = rsa_data->key.n_size;
273 	struct hpre_rsa_msg msg = { };
274 	TEE_Result ret = TEE_SUCCESS;
275 
276 	if (rsa_data->message.length > n_bytes) {
277 		EMSG("Invalid msg length[%zu]", rsa_data->message.length);
278 		return TEE_ERROR_BAD_PARAMETERS;
279 	}
280 
281 	ret = hpre_rsa_encrypt_init(&msg, rsa_data);
282 	if (ret) {
283 		EMSG("Fail to init rsa msg");
284 		return ret;
285 	}
286 
287 	ret = hpre_rsa_do_task(&msg);
288 	if (ret)
289 		goto encrypt_deinit;
290 
291 	/* Ciphertext can have valid zero data in NOPAD MODE */
292 	memcpy(rsa_data->cipher.data, msg.out + msg.key_bytes - n_bytes,
293 	       n_bytes);
294 	rsa_data->cipher.length = n_bytes;
295 
296 encrypt_deinit:
297 	hpre_rsa_params_free(&msg);
298 
299 	return ret;
300 }
301 
pkcs_v1_5_fill_ps(uint8_t * ps,size_t ps_len)302 static TEE_Result pkcs_v1_5_fill_ps(uint8_t *ps, size_t ps_len)
303 {
304 	size_t i = 0;
305 
306 	if (hw_get_random_bytes(ps, ps_len)) {
307 		EMSG("Fail to get ps data");
308 		return TEE_ERROR_NO_DATA;
309 	}
310 
311 	for (i = 0; i < ps_len; i++) {
312 		if (ps[i] == 0)
313 			ps[i] = PKCS_V1_5_PS_DATA;
314 	}
315 
316 	return TEE_SUCCESS;
317 }
318 
rsaes_pkcs_v1_5_encode(struct drvcrypt_rsa_ed * rsa_data,uint8_t * out)319 static TEE_Result rsaes_pkcs_v1_5_encode(struct drvcrypt_rsa_ed *rsa_data,
320 					 uint8_t *out)
321 {
322 	size_t msg_len = rsa_data->message.length;
323 	size_t out_len = rsa_data->cipher.length;
324 	size_t n_bytes = rsa_data->key.n_size;
325 	uint8_t *ps = out + PKCS_V1_5_PS_POS;
326 	TEE_Result ret = TEE_SUCCESS;
327 	size_t ps_len = 0;
328 
329 	/* PKCS_V1.5 format 0x00 || 0x02 || PS non-zero || 0x00 || M */
330 	if ((msg_len + PKCS_V1_5_MSG_MIN_LEN) > n_bytes || out_len < n_bytes) {
331 		EMSG("Invalid pkcs_v1.5 encode parameter");
332 		return TEE_ERROR_BAD_PARAMETERS;
333 	}
334 
335 	ps_len = n_bytes - PKCS_V1_5_FIXED_LEN - msg_len;
336 	ret = pkcs_v1_5_fill_ps(ps, ps_len);
337 	if (ret)
338 		return ret;
339 
340 	out[0] = 0;
341 	out[1] = ENCRYPT_PAD;
342 	out[PKCS_V1_5_FIXED_LEN + ps_len - 1] = 0;
343 	memcpy(out + PKCS_V1_5_FIXED_LEN + ps_len, rsa_data->message.data,
344 	       msg_len);
345 
346 	return TEE_SUCCESS;
347 }
348 
rsa_pkcs_encrypt(struct drvcrypt_rsa_ed * rsa_data)349 static TEE_Result rsa_pkcs_encrypt(struct drvcrypt_rsa_ed *rsa_data)
350 {
351 	uint32_t n_bytes = rsa_data->key.n_size;
352 	struct drvcrypt_rsa_ed rsa_enc_info = *rsa_data;
353 	TEE_Result ret = TEE_SUCCESS;
354 
355 	/* Alloc pkcs_v1.5 encode message data buf */
356 	rsa_enc_info.message.data = malloc(n_bytes);
357 	if (!rsa_enc_info.message.data) {
358 		EMSG("Fail to alloc message data buf");
359 		return TEE_ERROR_OUT_OF_MEMORY;
360 	}
361 
362 	rsa_enc_info.message.length = n_bytes;
363 	ret = rsaes_pkcs_v1_5_encode(rsa_data, rsa_enc_info.message.data);
364 	if (ret) {
365 		EMSG("Fail to get pkcs_v1.5 encode message data");
366 		goto free_data;
367 	}
368 
369 	ret = rsa_nopad_encrypt(&rsa_enc_info);
370 	if (ret)
371 		goto free_data;
372 
373 	memcpy(rsa_data->cipher.data, rsa_enc_info.cipher.data,
374 	       rsa_enc_info.cipher.length);
375 	rsa_data->cipher.length = rsa_enc_info.cipher.length;
376 
377 free_data:
378 	free(rsa_enc_info.message.data);
379 
380 	return ret;
381 }
382 
rsa_oaep_fill_db(struct drvcrypt_rsa_ed * rsa_data,uint8_t * db)383 static TEE_Result rsa_oaep_fill_db(struct drvcrypt_rsa_ed *rsa_data,
384 				   uint8_t *db)
385 {
386 	size_t lhash_len = rsa_data->digest_size;
387 	size_t n_bytes = rsa_data->key.n_size;
388 	size_t db_len = n_bytes - lhash_len - 1;
389 	size_t ps_len = 0;
390 	TEE_Result ret = TEE_SUCCESS;
391 
392 	/* oaep db format lhash || ps zero || 01 || M */
393 	ret = tee_hash_createdigest(rsa_data->hash_algo, rsa_data->label.data,
394 				    rsa_data->label.length, db, lhash_len);
395 	if (ret) {
396 		EMSG("Fail to get label hash");
397 		return ret;
398 	}
399 
400 	ps_len = db_len - lhash_len - rsa_data->message.length - 1;
401 	db[lhash_len + ps_len] = 1;
402 	memcpy(db + lhash_len + ps_len + 1, rsa_data->message.data,
403 	       rsa_data->message.length);
404 
405 	return TEE_SUCCESS;
406 }
407 
rsa_oaep_fill_maskdb(struct drvcrypt_rsa_ed * rsa_data,uint8_t * seed,uint8_t * db,uint8_t * mask_db)408 static TEE_Result rsa_oaep_fill_maskdb(struct drvcrypt_rsa_ed *rsa_data,
409 				       uint8_t *seed, uint8_t *db,
410 				       uint8_t *mask_db)
411 {
412 	size_t lhash_len = rsa_data->digest_size;
413 	size_t n_bytes = rsa_data->key.n_size;
414 	size_t db_len = n_bytes - lhash_len - 1;
415 	uint8_t seed_mgf[OAEP_MAX_DB_LEN] = { };
416 	TEE_Result ret = TEE_SUCCESS;
417 
418 	ret = mgf_process(lhash_len, seed, lhash_len, seed_mgf, db_len,
419 			  rsa_data);
420 	if (ret) {
421 		EMSG("Fail to get seed_mgf");
422 		return ret;
423 	}
424 
425 	return xor_process(db, seed_mgf, mask_db, db_len);
426 }
427 
rsa_oaep_fill_maskseed(struct drvcrypt_rsa_ed * rsa_data,uint8_t * seed,uint8_t * em)428 static TEE_Result rsa_oaep_fill_maskseed(struct drvcrypt_rsa_ed *rsa_data,
429 					 uint8_t *seed, uint8_t *em)
430 {
431 	uint8_t mask_db_mgf[OAEP_MAX_HASH_LEN] = { 0 };
432 	size_t lhash_len = rsa_data->digest_size;
433 	size_t n_bytes = rsa_data->key.n_size;
434 	size_t db_len = n_bytes - lhash_len - 1;
435 	uint8_t *mask_db = em + lhash_len + 1;
436 	uint8_t *mask_seed = em + 1;
437 	TEE_Result ret = TEE_SUCCESS;
438 
439 	ret = mgf_process(lhash_len, mask_db, db_len, mask_db_mgf, lhash_len,
440 			  rsa_data);
441 	if (ret) {
442 		EMSG("Fail to get mask_db_mgf");
443 		return ret;
444 	}
445 
446 	return xor_process(seed, mask_db_mgf, mask_seed, lhash_len);
447 }
448 
rsa_oaep_encode(struct drvcrypt_rsa_ed * rsa_data,uint8_t * em)449 static TEE_Result rsa_oaep_encode(struct drvcrypt_rsa_ed *rsa_data,
450 				  uint8_t *em)
451 {
452 	size_t lhash_len = rsa_data->digest_size;
453 	uint8_t db[OAEP_MAX_DB_LEN] = { };
454 	uint8_t seed[OAEP_MAX_HASH_LEN] = { };
455 	TEE_Result ret = TEE_SUCCESS;
456 
457 	/* oaep format 00 || maskedseed || maskeddb */
458 	em[0] = 0;
459 
460 	ret = rsa_oaep_fill_db(rsa_data, db);
461 	if (ret)
462 		return ret;
463 
464 	ret = hw_get_random_bytes(seed, lhash_len);
465 	if (ret)
466 		return ret;
467 
468 	ret = rsa_oaep_fill_maskdb(rsa_data, seed, db, em + lhash_len + 1);
469 	if (ret)
470 		return ret;
471 
472 	return rsa_oaep_fill_maskseed(rsa_data, seed, em);
473 }
474 
rsa_oaep_encrypt(struct drvcrypt_rsa_ed * rsa_data)475 static TEE_Result rsa_oaep_encrypt(struct drvcrypt_rsa_ed *rsa_data)
476 {
477 	size_t n_bytes = rsa_data->key.n_size;
478 	struct drvcrypt_rsa_ed rsa_enc_info = *rsa_data;
479 	TEE_Result ret = TEE_SUCCESS;
480 
481 	/* Alloc oaep encode message data buf */
482 	rsa_enc_info.message.data = malloc(n_bytes);
483 	if (!rsa_enc_info.message.data) {
484 		EMSG("Fail to alloc message data buf");
485 		return TEE_ERROR_OUT_OF_MEMORY;
486 	}
487 
488 	rsa_enc_info.message.length = n_bytes;
489 	ret = rsa_oaep_encode(rsa_data, rsa_enc_info.message.data);
490 	if (ret) {
491 		EMSG("Fail to get oaep encode message data");
492 		goto free_data;
493 	}
494 
495 	ret = rsa_nopad_encrypt(&rsa_enc_info);
496 	if (ret)
497 		goto free_data;
498 
499 	memcpy(rsa_data->cipher.data, rsa_enc_info.cipher.data,
500 	       rsa_enc_info.cipher.length);
501 	rsa_data->cipher.length = rsa_enc_info.cipher.length;
502 
503 free_data:
504 	free(rsa_enc_info.message.data);
505 
506 	return ret;
507 }
508 
hpre_rsa_encrypt(struct drvcrypt_rsa_ed * rsa_data)509 static TEE_Result hpre_rsa_encrypt(struct drvcrypt_rsa_ed *rsa_data)
510 {
511 	if (!rsa_data) {
512 		EMSG("Invalid rsa encrypt input parameter");
513 		return TEE_ERROR_BAD_PARAMETERS;
514 	}
515 
516 	switch (rsa_data->rsa_id) {
517 	case DRVCRYPT_RSA_NOPAD:
518 	case DRVCRYPT_RSASSA_PKCS_V1_5:
519 	case DRVCRYPT_RSASSA_PSS:
520 		return rsa_nopad_encrypt(rsa_data);
521 	case DRVCRYPT_RSA_PKCS_V1_5:
522 		return rsa_pkcs_encrypt(rsa_data);
523 	case DRVCRYPT_RSA_OAEP:
524 		return rsa_oaep_encrypt(rsa_data);
525 	default:
526 		EMSG("Invalid rsa id");
527 		return TEE_ERROR_BAD_PARAMETERS;
528 	}
529 }
530 
hpre_rsa_crt_decrypt_alloc(struct hpre_rsa_msg * msg)531 static enum hisi_drv_status hpre_rsa_crt_decrypt_alloc(struct hpre_rsa_msg *msg)
532 {
533 	uint32_t size = HPRE_RSA_CRT_TOTAL_BUF_SIZE(msg->key_bytes);
534 	uint8_t *data = NULL;
535 
536 	data = calloc(1, size);
537 	if (!data) {
538 		EMSG("Fail to alloc rsa crt total buf");
539 		return HISI_QM_DRVCRYPT_ENOMEM;
540 	}
541 
542 	msg->prikey = data;
543 	msg->prikey_dma = virt_to_phys(msg->prikey);
544 
545 	msg->in = data + (msg->key_bytes * 2) + (msg->key_bytes >> 1);
546 	msg->in_dma = msg->prikey_dma + (msg->key_bytes * 2) +
547 		      (msg->key_bytes >> 1);
548 
549 	msg->out = msg->in + msg->key_bytes;
550 	msg->out_dma = msg->in_dma + msg->key_bytes;
551 
552 	return HISI_QM_DRVCRYPT_NO_ERR;
553 }
554 
555 static enum hisi_drv_status
hpre_rsa_ncrt_decrypt_alloc(struct hpre_rsa_msg * msg)556 hpre_rsa_ncrt_decrypt_alloc(struct hpre_rsa_msg *msg)
557 {
558 	uint32_t size = HPRE_RSA_NCRT_TOTAL_BUF_SIZE(msg->key_bytes);
559 	uint8_t *data = NULL;
560 
561 	data = calloc(1, size);
562 	if (!data) {
563 		EMSG("Fail to alloc rsa ncrt buf");
564 		return HISI_QM_DRVCRYPT_ENOMEM;
565 	}
566 
567 	msg->prikey = data;
568 	msg->prikey_dma = virt_to_phys(msg->prikey);
569 
570 	msg->in = data + (msg->key_bytes * 2);
571 	msg->in_dma = msg->prikey_dma + (msg->key_bytes * 2);
572 
573 	msg->out = msg->in + msg->key_bytes;
574 	msg->out_dma = msg->in_dma + msg->key_bytes;
575 
576 	return HISI_QM_DRVCRYPT_NO_ERR;
577 }
578 
579 static enum hisi_drv_status
hpre_rsa_crt_decrypt_bn2bin(struct hpre_rsa_msg * msg,struct drvcrypt_rsa_ed * rsa_data)580 hpre_rsa_crt_decrypt_bn2bin(struct hpre_rsa_msg *msg,
581 			    struct drvcrypt_rsa_ed *rsa_data)
582 {
583 	struct rsa_keypair *key = rsa_data->key.key;
584 	uint32_t p_bytes = msg->key_bytes >> 1;
585 	uint32_t dq_len = crypto_bignum_num_bytes(key->dq);
586 	uint32_t dp_len = crypto_bignum_num_bytes(key->dp);
587 	uint32_t q_len = crypto_bignum_num_bytes(key->q);
588 	uint32_t p_len = crypto_bignum_num_bytes(key->p);
589 	uint32_t qp_len = crypto_bignum_num_bytes(key->qp);
590 	uint8_t *dq = msg->prikey;
591 	uint8_t *dp = msg->prikey + p_bytes;
592 	uint8_t *q = dp + p_bytes;
593 	uint8_t *p = q + p_bytes;
594 	uint8_t *qp = p + p_bytes;
595 	enum hisi_drv_status ret = HISI_QM_DRVCRYPT_NO_ERR;
596 
597 	crypto_bignum_bn2bin(key->dq, dq);
598 	crypto_bignum_bn2bin(key->dp, dp);
599 	crypto_bignum_bn2bin(key->q, q);
600 	crypto_bignum_bn2bin(key->p, p);
601 	crypto_bignum_bn2bin(key->qp, qp);
602 
603 	ret = hpre_bin_from_crypto_bin(dq, dq, p_bytes, dq_len);
604 	if (ret) {
605 		EMSG("Fail to transfer rsa crt dq from crypto_bin to hpre_bin");
606 		return ret;
607 	}
608 
609 	ret = hpre_bin_from_crypto_bin(dp, dp, p_bytes, dp_len);
610 	if (ret) {
611 		EMSG("Fail to transfer rsa crt dp from crypto_bin to hpre_bin");
612 		return ret;
613 	}
614 
615 	ret = hpre_bin_from_crypto_bin(q, q, p_bytes, q_len);
616 	if (ret) {
617 		EMSG("Fail to transfer rsa crt q from crypto_bin to hpre_bin");
618 		return ret;
619 	}
620 
621 	ret = hpre_bin_from_crypto_bin(p, p, p_bytes, p_len);
622 	if (ret) {
623 		EMSG("Fail to transfer rsa crt p from crypto_bin to hpre_bin");
624 		return ret;
625 	}
626 
627 	ret = hpre_bin_from_crypto_bin(qp, qp, p_bytes, qp_len);
628 	if (ret) {
629 		EMSG("Fail to transfer rsa crt qinv from crypto_bin to hpre_bin");
630 		return ret;
631 	}
632 
633 	ret = hpre_bin_from_crypto_bin(msg->in, rsa_data->cipher.data,
634 				       msg->key_bytes, rsa_data->cipher.length);
635 	if (ret)
636 		EMSG("Fail to transfer rsa ciphertext from crypto_bin to hpre_bin");
637 
638 	return ret;
639 }
640 
641 static enum hisi_drv_status
hpre_rsa_ncrt_decrypt_bn2bin(struct hpre_rsa_msg * msg,struct drvcrypt_rsa_ed * rsa_data)642 hpre_rsa_ncrt_decrypt_bn2bin(struct hpre_rsa_msg *msg,
643 			     struct drvcrypt_rsa_ed *rsa_data)
644 {
645 	struct rsa_keypair *key = rsa_data->key.key;
646 	uint32_t d_len = 0;
647 	uint32_t n_len = 0;
648 	enum hisi_drv_status ret = HISI_QM_DRVCRYPT_NO_ERR;
649 	uint8_t *n = NULL;
650 
651 	n = msg->prikey + msg->key_bytes;
652 
653 	crypto_bignum_bn2bin(key->d, msg->prikey);
654 	crypto_bignum_bn2bin(key->n, n);
655 	d_len = crypto_bignum_num_bytes(key->d);
656 	n_len = crypto_bignum_num_bytes(key->n);
657 
658 	ret = hpre_bin_from_crypto_bin(msg->prikey, msg->prikey, msg->key_bytes,
659 				       d_len);
660 	if (ret) {
661 		EMSG("Fail to transfer rsa ncrt d from crypto_bin to hpre_bin");
662 		return ret;
663 	}
664 
665 	ret = hpre_bin_from_crypto_bin(n, n, msg->key_bytes, n_len);
666 	if (ret) {
667 		EMSG("Fail to transfer rsa ncrt n from crypto_bin to hpre_bin");
668 		return ret;
669 	}
670 
671 	ret = hpre_bin_from_crypto_bin(msg->in, rsa_data->cipher.data,
672 				       msg->key_bytes, rsa_data->cipher.length);
673 	if (ret)
674 		EMSG("Fail to transfer rsa ciphertext from crypto_bin to hpre_bin");
675 
676 	return ret;
677 }
678 
hpre_rsa_is_crt_mod(struct rsa_keypair * key)679 static bool hpre_rsa_is_crt_mod(struct rsa_keypair *key)
680 {
681 	if (key->p && crypto_bignum_num_bits(key->p) &&
682 	    key->q && crypto_bignum_num_bits(key->q) &&
683 	    key->dp && crypto_bignum_num_bits(key->dp) &&
684 	    key->dq && crypto_bignum_num_bits(key->dq) &&
685 	    key->qp && crypto_bignum_num_bits(key->qp))
686 		return true;
687 
688 	return false;
689 }
690 
hpre_rsa_decrypt_init(struct hpre_rsa_msg * msg,struct drvcrypt_rsa_ed * rsa_data)691 static TEE_Result hpre_rsa_decrypt_init(struct hpre_rsa_msg *msg,
692 					struct drvcrypt_rsa_ed *rsa_data)
693 {
694 	struct rsa_keypair *key = rsa_data->key.key;
695 	size_t n_bytes = rsa_data->key.n_size;
696 	bool is_crt = false;
697 	enum hisi_drv_status ret = HISI_QM_DRVCRYPT_NO_ERR;
698 
699 	msg->is_private = rsa_data->key.isprivate;
700 	msg->key_bytes = hpre_rsa_get_hw_kbytes(BYTES_TO_BITS(n_bytes));
701 	if (!msg->key_bytes)
702 		return TEE_ERROR_BAD_PARAMETERS;
703 
704 	is_crt = hpre_rsa_is_crt_mod(key);
705 	if (is_crt) {
706 		msg->alg_type = HPRE_ALG_NC_CRT;
707 		ret = hpre_rsa_crt_decrypt_alloc(msg);
708 		if (ret)
709 			return TEE_ERROR_OUT_OF_MEMORY;
710 
711 		ret = hpre_rsa_crt_decrypt_bn2bin(msg, rsa_data);
712 		if (ret) {
713 			hpre_rsa_params_free(msg);
714 			return TEE_ERROR_BAD_STATE;
715 		}
716 	} else {
717 		msg->alg_type = HPRE_ALG_NC_NCRT;
718 		ret = hpre_rsa_ncrt_decrypt_alloc(msg);
719 		if (ret)
720 			return TEE_ERROR_OUT_OF_MEMORY;
721 
722 		ret = hpre_rsa_ncrt_decrypt_bn2bin(msg, rsa_data);
723 		if (ret) {
724 			hpre_rsa_params_free(msg);
725 			return TEE_ERROR_BAD_STATE;
726 		}
727 	}
728 
729 	return TEE_SUCCESS;
730 }
731 
rsa_nopad_decrypt(struct drvcrypt_rsa_ed * rsa_data)732 static TEE_Result rsa_nopad_decrypt(struct drvcrypt_rsa_ed *rsa_data)
733 {
734 	size_t n_bytes = rsa_data->key.n_size;
735 	struct hpre_rsa_msg msg = { };
736 	uint32_t offset = 0;
737 	TEE_Result ret = TEE_SUCCESS;
738 	uint8_t *pos = NULL;
739 
740 	if (rsa_data->cipher.length > n_bytes) {
741 		EMSG("Invalid cipher length[%zu]", rsa_data->cipher.length);
742 		return TEE_ERROR_BAD_PARAMETERS;
743 	}
744 
745 	ret = hpre_rsa_decrypt_init(&msg, rsa_data);
746 	if (ret) {
747 		EMSG("Fail to init rsa msg");
748 		return ret;
749 	}
750 
751 	ret = hpre_rsa_do_task(&msg);
752 	if (ret)
753 		goto decrypt_deinit;
754 
755 	pos = msg.out + msg.key_bytes - n_bytes;
756 	if (rsa_data->rsa_id == DRVCRYPT_RSA_NOPAD) {
757 		/* Plaintext can not have valid zero data in NOPAD MODE */
758 		while ((offset < n_bytes - 1) && (pos[offset] == 0))
759 			offset++;
760 	}
761 
762 	rsa_data->message.length = n_bytes - offset;
763 	memcpy(rsa_data->message.data, pos + offset, rsa_data->message.length);
764 
765 decrypt_deinit:
766 	hpre_rsa_params_free(&msg);
767 
768 	return ret;
769 }
770 
rsaes_pkcs_v1_5_decode(struct drvcrypt_rsa_ed * rsa_data,uint8_t * out,size_t * out_len)771 static TEE_Result rsaes_pkcs_v1_5_decode(struct drvcrypt_rsa_ed *rsa_data,
772 					 uint8_t *out, size_t *out_len)
773 {
774 	size_t em_len = rsa_data->message.length;
775 	uint8_t *em = rsa_data->message.data;
776 	size_t ps_len = 0;
777 	size_t i = 0;
778 
779 	/* PKCS_V1.5 EM format 0x00 || 0x02 || PS non-zero || 0x00 || M */
780 	if (em_len < PKCS_V1_5_MSG_MIN_LEN || em[0] != 0 ||
781 	    em[1] != ENCRYPT_PAD) {
782 		EMSG("Invalid pkcs_v1.5 decode parameter");
783 		return TEE_ERROR_BAD_PARAMETERS;
784 	}
785 
786 	for (i = PKCS_V1_5_PS_POS; i < em_len; i++) {
787 		if (em[i] == 0)
788 			break;
789 	}
790 
791 	if (i >= em_len) {
792 		EMSG("Fail to find zero pos");
793 		return TEE_ERROR_BAD_PARAMETERS;
794 	}
795 
796 	ps_len = i - PKCS_V1_5_PS_POS;
797 	if (em_len - ps_len - PKCS_V1_5_FIXED_LEN > *out_len ||
798 	    ps_len < PKCS_V1_5_PS_MIN_LEN) {
799 		EMSG("Invalid pkcs_v1.5 decode ps_len");
800 		return TEE_ERROR_BAD_PARAMETERS;
801 	}
802 
803 	*out_len = em_len - ps_len - PKCS_V1_5_FIXED_LEN;
804 	memcpy(out, em + ps_len + PKCS_V1_5_FIXED_LEN, *out_len);
805 
806 	return TEE_SUCCESS;
807 }
808 
rsa_pkcs_decrypt(struct drvcrypt_rsa_ed * rsa_data)809 static TEE_Result rsa_pkcs_decrypt(struct drvcrypt_rsa_ed *rsa_data)
810 {
811 	uint32_t n_bytes = rsa_data->key.n_size;
812 	struct drvcrypt_rsa_ed rsa_dec_info = *rsa_data;
813 	TEE_Result ret = TEE_SUCCESS;
814 
815 	/* Alloc pkcs_v1.5 encode message data buf */
816 	rsa_dec_info.message.data = malloc(n_bytes);
817 	if (!rsa_dec_info.message.data) {
818 		EMSG("Fail to alloc message data buf");
819 		return TEE_ERROR_OUT_OF_MEMORY;
820 	}
821 
822 	rsa_dec_info.message.length = n_bytes;
823 	ret = rsa_nopad_decrypt(&rsa_dec_info);
824 	if (ret)
825 		goto free_data;
826 
827 	ret = rsaes_pkcs_v1_5_decode(&rsa_dec_info, rsa_data->message.data,
828 				     &rsa_data->message.length);
829 	if (ret)
830 		EMSG("Fail to get pkcs_v1.5 decode message data");
831 
832 free_data:
833 	free(rsa_dec_info.message.data);
834 
835 	return ret;
836 }
837 
rsa_oaep_get_seed(struct drvcrypt_rsa_ed * rsa_data,uint8_t * mask_db,uint8_t * seed)838 static TEE_Result rsa_oaep_get_seed(struct drvcrypt_rsa_ed *rsa_data,
839 				    uint8_t *mask_db, uint8_t *seed)
840 {
841 	size_t db_len = rsa_data->key.n_size - rsa_data->digest_size - 1;
842 	uint8_t mask_db_mgf[OAEP_MAX_HASH_LEN] = { };
843 	size_t lhash_len = rsa_data->digest_size;
844 	uint8_t *mask_seed = NULL;
845 	TEE_Result ret = TEE_SUCCESS;
846 
847 	mask_seed = rsa_data->message.data + 1;
848 
849 	ret = mgf_process(lhash_len, mask_db, db_len, mask_db_mgf, lhash_len,
850 			  rsa_data);
851 	if (ret) {
852 		EMSG("Fail to get mask_db mgf result");
853 		return ret;
854 	}
855 
856 	return xor_process(mask_seed, mask_db_mgf, seed, lhash_len);
857 }
858 
rsa_oaep_get_db(struct drvcrypt_rsa_ed * rsa_data,uint8_t * mask_db,uint8_t * seed,uint8_t * db)859 static TEE_Result rsa_oaep_get_db(struct drvcrypt_rsa_ed *rsa_data,
860 				  uint8_t *mask_db, uint8_t *seed, uint8_t *db)
861 {
862 	size_t db_len = rsa_data->key.n_size - rsa_data->digest_size - 1;
863 	size_t lhash_len = rsa_data->digest_size;
864 	uint8_t seed_mgf[OAEP_MAX_DB_LEN] = { };
865 	TEE_Result ret = TEE_SUCCESS;
866 
867 	ret = mgf_process(lhash_len, seed, lhash_len, seed_mgf, db_len,
868 			  rsa_data);
869 	if (ret) {
870 		EMSG("Fail to get seed mgf result");
871 		return ret;
872 	}
873 
874 	return xor_process(mask_db, seed_mgf, db, db_len);
875 }
876 
rsa_oaep_get_msg(struct drvcrypt_rsa_ed * rsa_data,uint8_t * db,uint8_t * out,size_t * out_len)877 static TEE_Result rsa_oaep_get_msg(struct drvcrypt_rsa_ed *rsa_data,
878 				   uint8_t *db, uint8_t *out, size_t *out_len)
879 {
880 	size_t db_len = rsa_data->key.n_size - rsa_data->digest_size - 1;
881 	size_t lhash_len = rsa_data->digest_size;
882 	uint8_t hash[OAEP_MAX_HASH_LEN] = { };
883 	size_t msg_len = 0;
884 	size_t lp_len = 0;
885 	TEE_Result ret = TEE_SUCCESS;
886 
887 	/* oaep db format lhash || ps zero || 01 || M */
888 	ret = tee_hash_createdigest(rsa_data->hash_algo, rsa_data->label.data,
889 				    rsa_data->label.length, hash, lhash_len);
890 	if (ret) {
891 		EMSG("Fail to get label hash");
892 		return ret;
893 	}
894 
895 	if (memcmp(hash, db, lhash_len)) {
896 		EMSG("Hash is not equal");
897 		return TEE_ERROR_BAD_PARAMETERS;
898 	}
899 
900 	for (lp_len = lhash_len; lp_len < db_len; lp_len++) {
901 		if (db[lp_len] != 0)
902 			break;
903 	}
904 
905 	if (lp_len == db_len) {
906 		EMSG("Fail to find fixed 01 in db");
907 		return TEE_ERROR_BAD_PARAMETERS;
908 	}
909 
910 	msg_len = db_len - lp_len - 1;
911 	if (msg_len > rsa_data->message.length) {
912 		DMSG("Message space is not enough");
913 		*out_len = msg_len;
914 		return TEE_ERROR_SHORT_BUFFER;
915 	}
916 
917 	*out_len = msg_len;
918 	memcpy(out, db + lp_len + 1, msg_len);
919 
920 	return TEE_SUCCESS;
921 }
922 
rsa_oaep_decode(struct drvcrypt_rsa_ed * rsa_data,uint8_t * out,size_t * out_len)923 static TEE_Result rsa_oaep_decode(struct drvcrypt_rsa_ed *rsa_data,
924 				  uint8_t *out, size_t *out_len)
925 {
926 	size_t lhash_len = rsa_data->digest_size;
927 	uint8_t seed[OAEP_MAX_HASH_LEN] = { };
928 	uint8_t db[OAEP_MAX_DB_LEN] = { };
929 	uint8_t *mask_db = NULL;
930 	TEE_Result ret = TEE_SUCCESS;
931 
932 	/* oaep format 00 || maskedseed || maskeddb */
933 	mask_db = rsa_data->message.data + lhash_len + 1;
934 	ret = rsa_oaep_get_seed(rsa_data, mask_db, seed);
935 	if (ret)
936 		return ret;
937 
938 	ret = rsa_oaep_get_db(rsa_data, mask_db, seed, db);
939 	if (ret)
940 		return ret;
941 
942 	return rsa_oaep_get_msg(rsa_data, db, out, out_len);
943 }
944 
rsa_oaep_decrypt(struct drvcrypt_rsa_ed * rsa_data)945 static TEE_Result rsa_oaep_decrypt(struct drvcrypt_rsa_ed *rsa_data)
946 {
947 	size_t n_bytes = rsa_data->key.n_size;
948 	struct drvcrypt_rsa_ed rsa_dec_info = *rsa_data;
949 	TEE_Result ret = TEE_SUCCESS;
950 
951 	/* Alloc oaep encode message data buf */
952 	rsa_dec_info.message.data = malloc(n_bytes);
953 	if (!rsa_dec_info.message.data) {
954 		EMSG("Fail to alloc message data buf");
955 		return TEE_ERROR_OUT_OF_MEMORY;
956 	}
957 
958 	rsa_dec_info.message.length = n_bytes;
959 	ret = rsa_nopad_decrypt(&rsa_dec_info);
960 	if (ret)
961 		goto free_data;
962 
963 	ret = rsa_oaep_decode(&rsa_dec_info, rsa_data->message.data,
964 			      &rsa_data->message.length);
965 	if (ret)
966 		EMSG("Fail to get oaep decode message data");
967 
968 free_data:
969 	free(rsa_dec_info.message.data);
970 
971 	return ret;
972 }
973 
hpre_rsa_decrypt(struct drvcrypt_rsa_ed * rsa_data)974 static TEE_Result hpre_rsa_decrypt(struct drvcrypt_rsa_ed *rsa_data)
975 {
976 	if (!rsa_data) {
977 		EMSG("Invalid rsa decrypt input parameter");
978 		return TEE_ERROR_BAD_PARAMETERS;
979 	}
980 
981 	switch (rsa_data->rsa_id) {
982 	case DRVCRYPT_RSA_NOPAD:
983 	case DRVCRYPT_RSASSA_PKCS_V1_5:
984 	case DRVCRYPT_RSASSA_PSS:
985 		return rsa_nopad_decrypt(rsa_data);
986 	case DRVCRYPT_RSA_PKCS_V1_5:
987 		return rsa_pkcs_decrypt(rsa_data);
988 	case DRVCRYPT_RSA_OAEP:
989 		return rsa_oaep_decrypt(rsa_data);
990 	default:
991 		EMSG("Invalid rsa id");
992 		return TEE_ERROR_NOT_SUPPORTED;
993 	}
994 }
995 
996 static const struct drvcrypt_rsa driver_rsa = {
997 	.alloc_keypair = sw_crypto_acipher_alloc_rsa_keypair,
998 	.alloc_publickey = sw_crypto_acipher_alloc_rsa_public_key,
999 	.free_publickey = sw_crypto_acipher_free_rsa_public_key,
1000 	.free_keypair = sw_crypto_acipher_free_rsa_keypair,
1001 	.gen_keypair = sw_crypto_acipher_gen_rsa_key,
1002 	.encrypt = hpre_rsa_encrypt,
1003 	.decrypt = hpre_rsa_decrypt,
1004 	.optional = {
1005 		/*
1006 		 * If ssa_sign or verify is NULL, the framework will fill
1007 		 * data format directly by soft calculation. Then call api
1008 		 * encrypt or decrypt.
1009 		 */
1010 		.ssa_sign = NULL,
1011 		.ssa_verify = NULL,
1012 	},
1013 };
1014 
hpre_rsa_init(void)1015 static TEE_Result hpre_rsa_init(void)
1016 {
1017 	TEE_Result ret = drvcrypt_register_rsa(&driver_rsa);
1018 
1019 	if (ret != TEE_SUCCESS)
1020 		EMSG("hpre rsa register to crypto fail");
1021 
1022 	return ret;
1023 }
1024 
1025 driver_init(hpre_rsa_init);
1026