1 // SPDX-License-Identifier: BSD-2-Clause
2 /*
3 * Copyright (c) 2022-2024, HiSilicon Technologies Co., Ltd.
4 * Kunpeng hardware accelerator hpre rsa algorithm implementation.
5 */
6
7 #include <crypto/crypto_impl.h>
8 #include <drvcrypt.h>
9 #include <drvcrypt_acipher.h>
10 #include <drvcrypt_math.h>
11 #include <initcall.h>
12 #include <malloc.h>
13 #include <mbedtls/rsa.h>
14 #include <rng_support.h>
15 #include <stdlib_ext.h>
16 #include <string.h>
17 #include <string_ext.h>
18 #include <tee/tee_cryp_utl.h>
19 #include <trace.h>
20
21 #include "hpre_main.h"
22 #include "hpre_rsa.h"
23
hpre_rsa_fill_addr_params(struct hpre_rsa_msg * msg,struct hpre_sqe * sqe)24 static enum hisi_drv_status hpre_rsa_fill_addr_params(struct hpre_rsa_msg *msg,
25 struct hpre_sqe *sqe)
26 {
27 switch (msg->alg_type) {
28 case HPRE_ALG_NC_NCRT:
29 case HPRE_ALG_NC_CRT:
30 if (msg->is_private) {
31 /* DECRYPT */
32 sqe->key = msg->prikey_dma;
33 sqe->in = msg->in_dma;
34 sqe->out = msg->out_dma;
35 } else {
36 /* ENCRYPT */
37 sqe->key = msg->pubkey_dma;
38 sqe->in = msg->in_dma;
39 sqe->out = msg->out_dma;
40 }
41 return HISI_QM_DRVCRYPT_NO_ERR;
42 default:
43 EMSG("Invalid alg_type[%"PRIu32"]", msg->alg_type);
44 return HISI_QM_DRVCRYPT_IN_EPARA;
45 }
46 }
47
hpre_rsa_fill_sqe(void * bd,void * info)48 static enum hisi_drv_status hpre_rsa_fill_sqe(void *bd, void *info)
49 {
50 struct hpre_rsa_msg *msg = info;
51 struct hpre_sqe *sqe = bd;
52
53 sqe->w0 = msg->alg_type | SHIFT_U32(0x1, HPRE_DONE_SHIFT);
54 sqe->task_len1 = TASK_LENGTH(msg->key_bytes);
55
56 return hpre_rsa_fill_addr_params(msg, sqe);
57 }
58
/*
 * parse_sqe callback installed on the RSA queue pair: inspect the
 * completed descriptor @bd.
 *
 * Returns HISI_QM_DRVCRYPT_NO_ERR only when the done field of w0 equals
 * HPRE_HW_TASK_DONE and the error type field is zero; any other
 * combination is logged and reported as HISI_QM_DRVCRYPT_IN_EPARA.
 */
static enum hisi_drv_status hpre_rsa_parse_sqe(void *bd, void *info __unused)
{
	struct hpre_sqe *desc = bd;
	uint16_t etype = HPRE_TASK_ETYPE(desc->w0);
	uint16_t status = HPRE_TASK_DONE(desc->w0);

	if (status == HPRE_HW_TASK_DONE && !etype)
		return HISI_QM_DRVCRYPT_NO_ERR;

	EMSG("HPRE do rsa fail! done=0x%"PRIX16", etype=0x%"PRIX16,
	     status, etype);

	return HISI_QM_DRVCRYPT_IN_EPARA;
}
75
hpre_rsa_do_task(void * msg)76 static TEE_Result hpre_rsa_do_task(void *msg)
77 {
78 struct hisi_qp *rsa_qp = NULL;
79 TEE_Result res = TEE_SUCCESS;
80 enum hisi_drv_status ret = HISI_QM_DRVCRYPT_NO_ERR;
81
82 rsa_qp = hpre_create_qp(HISI_QM_CHANNEL_TYPE0);
83 if (!rsa_qp) {
84 EMSG("Fail to create rsa qp");
85 return TEE_ERROR_BUSY;
86 }
87
88 rsa_qp->fill_sqe = hpre_rsa_fill_sqe;
89 rsa_qp->parse_sqe = hpre_rsa_parse_sqe;
90 ret = hisi_qp_send(rsa_qp, msg);
91 if (ret) {
92 EMSG("Fail to send task, ret=%x"PRIx32, ret);
93 res = TEE_ERROR_BAD_STATE;
94 goto done_proc;
95 }
96
97 ret = hisi_qp_recv_sync(rsa_qp, msg);
98 if (ret) {
99 EMSG("Recv task error, ret=%x"PRIx32, ret);
100 res = TEE_ERROR_BAD_STATE;
101 goto done_proc;
102 }
103
104 done_proc:
105 hisi_qm_release_qp(rsa_qp);
106
107 return res;
108 }
109
mgf_process(size_t digest_size,uint8_t * seed,size_t seed_len,uint8_t * mask,size_t mask_len,struct drvcrypt_rsa_ed * rsa_data)110 static TEE_Result mgf_process(size_t digest_size, uint8_t *seed,
111 size_t seed_len, uint8_t *mask, size_t mask_len,
112 struct drvcrypt_rsa_ed *rsa_data)
113 {
114 struct drvcrypt_rsa_mgf mgf = { };
115
116 if (!rsa_data->mgf) {
117 EMSG("mgf function is NULL");
118 return TEE_ERROR_BAD_PARAMETERS;
119 }
120
121 mgf.hash_algo = rsa_data->hash_algo;
122 mgf.digest_size = digest_size;
123 mgf.seed.data = seed;
124 mgf.seed.length = seed_len;
125 mgf.mask.data = mask;
126 mgf.mask.length = mask_len;
127
128 return rsa_data->mgf(&mgf);
129 }
130
xor_process(uint8_t * a,uint8_t * b,uint8_t * out,size_t len)131 static TEE_Result xor_process(uint8_t *a, uint8_t *b, uint8_t *out, size_t len)
132 {
133 struct drvcrypt_mod_op xor_mod = { };
134
135 xor_mod.n.length = len;
136 xor_mod.a.data = a;
137 xor_mod.a.length = len;
138 xor_mod.b.data = b;
139 xor_mod.b.length = len;
140 xor_mod.result.data = out;
141 xor_mod.result.length = len;
142
143 return drvcrypt_xor_mod_n(&xor_mod);
144 }
145
/*
 * Round a key size up to the next width used for the hardware buffers.
 *
 * @key_bits: key size in bits
 *
 * Returns the byte size of the smallest of 1024, 2048, 3072 or 4096 bits
 * that can hold @key_bits, or 0 (after logging) when the key is larger
 * than 4096 bits.
 */
static size_t hpre_rsa_get_hw_kbytes(size_t key_bits)
{
	if (key_bits <= 1024)
		return BITS_TO_BYTES(1024);

	if (key_bits <= 2048)
		return BITS_TO_BYTES(2048);

	if (key_bits <= 3072)
		return BITS_TO_BYTES(3072);

	if (key_bits <= 4096)
		return BITS_TO_BYTES(4096);

	EMSG("Invalid key_bits[%zu]", key_bits);

	return 0;
}
163
hpre_rsa_params_free(struct hpre_rsa_msg * msg)164 static void hpre_rsa_params_free(struct hpre_rsa_msg *msg)
165 {
166 switch (msg->alg_type) {
167 case HPRE_ALG_NC_NCRT:
168 if (msg->is_private)
169 free_wipe(msg->prikey);
170 else
171 free(msg->pubkey);
172 break;
173 case HPRE_ALG_NC_CRT:
174 if (msg->is_private)
175 free_wipe(msg->prikey);
176 break;
177 default:
178 EMSG("Invalid alg_type[%"PRIu32"]", msg->alg_type);
179 break;
180 }
181 }
182
hpre_rsa_encrypt_alloc(struct hpre_rsa_msg * msg)183 static enum hisi_drv_status hpre_rsa_encrypt_alloc(struct hpre_rsa_msg *msg)
184 {
185 uint32_t size = HPRE_RSA_NCRT_TOTAL_BUF_SIZE(msg->key_bytes);
186 uint8_t *data = NULL;
187
188 data = calloc(1, size);
189 if (!data) {
190 EMSG("Fail to alloc rsa ncrt buf");
191 return HISI_QM_DRVCRYPT_ENOMEM;
192 }
193
194 msg->pubkey = data;
195 msg->pubkey_dma = virt_to_phys(msg->pubkey);
196
197 msg->in = data + (msg->key_bytes * 2);
198 msg->in_dma = msg->pubkey_dma + (msg->key_bytes * 2);
199
200 msg->out = msg->in + msg->key_bytes;
201 msg->out_dma = msg->in_dma + msg->key_bytes;
202
203 return HISI_QM_DRVCRYPT_NO_ERR;
204 }
205
206 static enum hisi_drv_status
hpre_rsa_encrypt_bn2bin(struct hpre_rsa_msg * msg,struct drvcrypt_rsa_ed * rsa_data)207 hpre_rsa_encrypt_bn2bin(struct hpre_rsa_msg *msg,
208 struct drvcrypt_rsa_ed *rsa_data)
209 {
210 struct rsa_public_key *key = rsa_data->key.key;
211 uint32_t e_len = 0;
212 uint32_t n_len = 0;
213 enum hisi_drv_status ret = HISI_QM_DRVCRYPT_NO_ERR;
214 uint8_t *n = NULL;
215
216 n = msg->pubkey + msg->key_bytes;
217
218 crypto_bignum_bn2bin(key->e, msg->pubkey);
219 crypto_bignum_bn2bin(key->n, n);
220 e_len = crypto_bignum_num_bytes(key->e);
221 n_len = crypto_bignum_num_bytes(key->n);
222
223 ret = hpre_bin_from_crypto_bin(msg->pubkey, msg->pubkey,
224 msg->key_bytes, e_len);
225 if (ret) {
226 EMSG("Fail to transfer rsa ncrt e from crypto_bin to hpre_bin");
227 return ret;
228 }
229
230 ret = hpre_bin_from_crypto_bin(n, n, msg->key_bytes, n_len);
231 if (ret) {
232 EMSG("Fail to transfer rsa ncrt n from crypto_bin to hpre_bin");
233 return ret;
234 }
235
236 ret = hpre_bin_from_crypto_bin(msg->in, rsa_data->message.data,
237 msg->key_bytes,
238 rsa_data->message.length);
239 if (ret)
240 EMSG("Fail to transfer rsa plaintext from crypto_bin to hpre_bin");
241
242 return ret;
243 }
244
hpre_rsa_encrypt_init(struct hpre_rsa_msg * msg,struct drvcrypt_rsa_ed * rsa_data)245 static TEE_Result hpre_rsa_encrypt_init(struct hpre_rsa_msg *msg,
246 struct drvcrypt_rsa_ed *rsa_data)
247 {
248 size_t n_bytes = rsa_data->key.n_size;
249 enum hisi_drv_status ret = HISI_QM_DRVCRYPT_NO_ERR;
250
251 msg->alg_type = HPRE_ALG_NC_NCRT;
252 msg->is_private = rsa_data->key.isprivate;
253 msg->key_bytes = hpre_rsa_get_hw_kbytes(BYTES_TO_BITS(n_bytes));
254 if (!msg->key_bytes)
255 return TEE_ERROR_BAD_PARAMETERS;
256
257 ret = hpre_rsa_encrypt_alloc(msg);
258 if (ret)
259 return TEE_ERROR_OUT_OF_MEMORY;
260
261 ret = hpre_rsa_encrypt_bn2bin(msg, rsa_data);
262 if (ret) {
263 hpre_rsa_params_free(msg);
264 return TEE_ERROR_BAD_STATE;
265 }
266
267 return TEE_SUCCESS;
268 }
269
rsa_nopad_encrypt(struct drvcrypt_rsa_ed * rsa_data)270 static TEE_Result rsa_nopad_encrypt(struct drvcrypt_rsa_ed *rsa_data)
271 {
272 size_t n_bytes = rsa_data->key.n_size;
273 struct hpre_rsa_msg msg = { };
274 TEE_Result ret = TEE_SUCCESS;
275
276 if (rsa_data->message.length > n_bytes) {
277 EMSG("Invalid msg length[%zu]", rsa_data->message.length);
278 return TEE_ERROR_BAD_PARAMETERS;
279 }
280
281 ret = hpre_rsa_encrypt_init(&msg, rsa_data);
282 if (ret) {
283 EMSG("Fail to init rsa msg");
284 return ret;
285 }
286
287 ret = hpre_rsa_do_task(&msg);
288 if (ret)
289 goto encrypt_deinit;
290
291 /* Ciphertext can have valid zero data in NOPAD MODE */
292 memcpy(rsa_data->cipher.data, msg.out + msg.key_bytes - n_bytes,
293 n_bytes);
294 rsa_data->cipher.length = n_bytes;
295
296 encrypt_deinit:
297 hpre_rsa_params_free(&msg);
298
299 return ret;
300 }
301
pkcs_v1_5_fill_ps(uint8_t * ps,size_t ps_len)302 static TEE_Result pkcs_v1_5_fill_ps(uint8_t *ps, size_t ps_len)
303 {
304 size_t i = 0;
305
306 if (hw_get_random_bytes(ps, ps_len)) {
307 EMSG("Fail to get ps data");
308 return TEE_ERROR_NO_DATA;
309 }
310
311 for (i = 0; i < ps_len; i++) {
312 if (ps[i] == 0)
313 ps[i] = PKCS_V1_5_PS_DATA;
314 }
315
316 return TEE_SUCCESS;
317 }
318
rsaes_pkcs_v1_5_encode(struct drvcrypt_rsa_ed * rsa_data,uint8_t * out)319 static TEE_Result rsaes_pkcs_v1_5_encode(struct drvcrypt_rsa_ed *rsa_data,
320 uint8_t *out)
321 {
322 size_t msg_len = rsa_data->message.length;
323 size_t out_len = rsa_data->cipher.length;
324 size_t n_bytes = rsa_data->key.n_size;
325 uint8_t *ps = out + PKCS_V1_5_PS_POS;
326 TEE_Result ret = TEE_SUCCESS;
327 size_t ps_len = 0;
328
329 /* PKCS_V1.5 format 0x00 || 0x02 || PS non-zero || 0x00 || M */
330 if ((msg_len + PKCS_V1_5_MSG_MIN_LEN) > n_bytes || out_len < n_bytes) {
331 EMSG("Invalid pkcs_v1.5 encode parameter");
332 return TEE_ERROR_BAD_PARAMETERS;
333 }
334
335 ps_len = n_bytes - PKCS_V1_5_FIXED_LEN - msg_len;
336 ret = pkcs_v1_5_fill_ps(ps, ps_len);
337 if (ret)
338 return ret;
339
340 out[0] = 0;
341 out[1] = ENCRYPT_PAD;
342 out[PKCS_V1_5_FIXED_LEN + ps_len - 1] = 0;
343 memcpy(out + PKCS_V1_5_FIXED_LEN + ps_len, rsa_data->message.data,
344 msg_len);
345
346 return TEE_SUCCESS;
347 }
348
rsa_pkcs_encrypt(struct drvcrypt_rsa_ed * rsa_data)349 static TEE_Result rsa_pkcs_encrypt(struct drvcrypt_rsa_ed *rsa_data)
350 {
351 uint32_t n_bytes = rsa_data->key.n_size;
352 struct drvcrypt_rsa_ed rsa_enc_info = *rsa_data;
353 TEE_Result ret = TEE_SUCCESS;
354
355 /* Alloc pkcs_v1.5 encode message data buf */
356 rsa_enc_info.message.data = malloc(n_bytes);
357 if (!rsa_enc_info.message.data) {
358 EMSG("Fail to alloc message data buf");
359 return TEE_ERROR_OUT_OF_MEMORY;
360 }
361
362 rsa_enc_info.message.length = n_bytes;
363 ret = rsaes_pkcs_v1_5_encode(rsa_data, rsa_enc_info.message.data);
364 if (ret) {
365 EMSG("Fail to get pkcs_v1.5 encode message data");
366 goto free_data;
367 }
368
369 ret = rsa_nopad_encrypt(&rsa_enc_info);
370 if (ret)
371 goto free_data;
372
373 memcpy(rsa_data->cipher.data, rsa_enc_info.cipher.data,
374 rsa_enc_info.cipher.length);
375 rsa_data->cipher.length = rsa_enc_info.cipher.length;
376
377 free_data:
378 free(rsa_enc_info.message.data);
379
380 return ret;
381 }
382
rsa_oaep_fill_db(struct drvcrypt_rsa_ed * rsa_data,uint8_t * db)383 static TEE_Result rsa_oaep_fill_db(struct drvcrypt_rsa_ed *rsa_data,
384 uint8_t *db)
385 {
386 size_t lhash_len = rsa_data->digest_size;
387 size_t n_bytes = rsa_data->key.n_size;
388 size_t db_len = n_bytes - lhash_len - 1;
389 size_t ps_len = 0;
390 TEE_Result ret = TEE_SUCCESS;
391
392 /* oaep db format lhash || ps zero || 01 || M */
393 ret = tee_hash_createdigest(rsa_data->hash_algo, rsa_data->label.data,
394 rsa_data->label.length, db, lhash_len);
395 if (ret) {
396 EMSG("Fail to get label hash");
397 return ret;
398 }
399
400 ps_len = db_len - lhash_len - rsa_data->message.length - 1;
401 db[lhash_len + ps_len] = 1;
402 memcpy(db + lhash_len + ps_len + 1, rsa_data->message.data,
403 rsa_data->message.length);
404
405 return TEE_SUCCESS;
406 }
407
rsa_oaep_fill_maskdb(struct drvcrypt_rsa_ed * rsa_data,uint8_t * seed,uint8_t * db,uint8_t * mask_db)408 static TEE_Result rsa_oaep_fill_maskdb(struct drvcrypt_rsa_ed *rsa_data,
409 uint8_t *seed, uint8_t *db,
410 uint8_t *mask_db)
411 {
412 size_t lhash_len = rsa_data->digest_size;
413 size_t n_bytes = rsa_data->key.n_size;
414 size_t db_len = n_bytes - lhash_len - 1;
415 uint8_t seed_mgf[OAEP_MAX_DB_LEN] = { };
416 TEE_Result ret = TEE_SUCCESS;
417
418 ret = mgf_process(lhash_len, seed, lhash_len, seed_mgf, db_len,
419 rsa_data);
420 if (ret) {
421 EMSG("Fail to get seed_mgf");
422 return ret;
423 }
424
425 return xor_process(db, seed_mgf, mask_db, db_len);
426 }
427
rsa_oaep_fill_maskseed(struct drvcrypt_rsa_ed * rsa_data,uint8_t * seed,uint8_t * em)428 static TEE_Result rsa_oaep_fill_maskseed(struct drvcrypt_rsa_ed *rsa_data,
429 uint8_t *seed, uint8_t *em)
430 {
431 uint8_t mask_db_mgf[OAEP_MAX_HASH_LEN] = { 0 };
432 size_t lhash_len = rsa_data->digest_size;
433 size_t n_bytes = rsa_data->key.n_size;
434 size_t db_len = n_bytes - lhash_len - 1;
435 uint8_t *mask_db = em + lhash_len + 1;
436 uint8_t *mask_seed = em + 1;
437 TEE_Result ret = TEE_SUCCESS;
438
439 ret = mgf_process(lhash_len, mask_db, db_len, mask_db_mgf, lhash_len,
440 rsa_data);
441 if (ret) {
442 EMSG("Fail to get mask_db_mgf");
443 return ret;
444 }
445
446 return xor_process(seed, mask_db_mgf, mask_seed, lhash_len);
447 }
448
rsa_oaep_encode(struct drvcrypt_rsa_ed * rsa_data,uint8_t * em)449 static TEE_Result rsa_oaep_encode(struct drvcrypt_rsa_ed *rsa_data,
450 uint8_t *em)
451 {
452 size_t lhash_len = rsa_data->digest_size;
453 uint8_t db[OAEP_MAX_DB_LEN] = { };
454 uint8_t seed[OAEP_MAX_HASH_LEN] = { };
455 TEE_Result ret = TEE_SUCCESS;
456
457 /* oaep format 00 || maskedseed || maskeddb */
458 em[0] = 0;
459
460 ret = rsa_oaep_fill_db(rsa_data, db);
461 if (ret)
462 return ret;
463
464 ret = hw_get_random_bytes(seed, lhash_len);
465 if (ret)
466 return ret;
467
468 ret = rsa_oaep_fill_maskdb(rsa_data, seed, db, em + lhash_len + 1);
469 if (ret)
470 return ret;
471
472 return rsa_oaep_fill_maskseed(rsa_data, seed, em);
473 }
474
rsa_oaep_encrypt(struct drvcrypt_rsa_ed * rsa_data)475 static TEE_Result rsa_oaep_encrypt(struct drvcrypt_rsa_ed *rsa_data)
476 {
477 size_t n_bytes = rsa_data->key.n_size;
478 struct drvcrypt_rsa_ed rsa_enc_info = *rsa_data;
479 TEE_Result ret = TEE_SUCCESS;
480
481 /* Alloc oaep encode message data buf */
482 rsa_enc_info.message.data = malloc(n_bytes);
483 if (!rsa_enc_info.message.data) {
484 EMSG("Fail to alloc message data buf");
485 return TEE_ERROR_OUT_OF_MEMORY;
486 }
487
488 rsa_enc_info.message.length = n_bytes;
489 ret = rsa_oaep_encode(rsa_data, rsa_enc_info.message.data);
490 if (ret) {
491 EMSG("Fail to get oaep encode message data");
492 goto free_data;
493 }
494
495 ret = rsa_nopad_encrypt(&rsa_enc_info);
496 if (ret)
497 goto free_data;
498
499 memcpy(rsa_data->cipher.data, rsa_enc_info.cipher.data,
500 rsa_enc_info.cipher.length);
501 rsa_data->cipher.length = rsa_enc_info.cipher.length;
502
503 free_data:
504 free(rsa_enc_info.message.data);
505
506 return ret;
507 }
508
hpre_rsa_encrypt(struct drvcrypt_rsa_ed * rsa_data)509 static TEE_Result hpre_rsa_encrypt(struct drvcrypt_rsa_ed *rsa_data)
510 {
511 if (!rsa_data) {
512 EMSG("Invalid rsa encrypt input parameter");
513 return TEE_ERROR_BAD_PARAMETERS;
514 }
515
516 switch (rsa_data->rsa_id) {
517 case DRVCRYPT_RSA_NOPAD:
518 case DRVCRYPT_RSASSA_PKCS_V1_5:
519 case DRVCRYPT_RSASSA_PSS:
520 return rsa_nopad_encrypt(rsa_data);
521 case DRVCRYPT_RSA_PKCS_V1_5:
522 return rsa_pkcs_encrypt(rsa_data);
523 case DRVCRYPT_RSA_OAEP:
524 return rsa_oaep_encrypt(rsa_data);
525 default:
526 EMSG("Invalid rsa id");
527 return TEE_ERROR_BAD_PARAMETERS;
528 }
529 }
530
hpre_rsa_crt_decrypt_alloc(struct hpre_rsa_msg * msg)531 static enum hisi_drv_status hpre_rsa_crt_decrypt_alloc(struct hpre_rsa_msg *msg)
532 {
533 uint32_t size = HPRE_RSA_CRT_TOTAL_BUF_SIZE(msg->key_bytes);
534 uint8_t *data = NULL;
535
536 data = calloc(1, size);
537 if (!data) {
538 EMSG("Fail to alloc rsa crt total buf");
539 return HISI_QM_DRVCRYPT_ENOMEM;
540 }
541
542 msg->prikey = data;
543 msg->prikey_dma = virt_to_phys(msg->prikey);
544
545 msg->in = data + (msg->key_bytes * 2) + (msg->key_bytes >> 1);
546 msg->in_dma = msg->prikey_dma + (msg->key_bytes * 2) +
547 (msg->key_bytes >> 1);
548
549 msg->out = msg->in + msg->key_bytes;
550 msg->out_dma = msg->in_dma + msg->key_bytes;
551
552 return HISI_QM_DRVCRYPT_NO_ERR;
553 }
554
555 static enum hisi_drv_status
hpre_rsa_ncrt_decrypt_alloc(struct hpre_rsa_msg * msg)556 hpre_rsa_ncrt_decrypt_alloc(struct hpre_rsa_msg *msg)
557 {
558 uint32_t size = HPRE_RSA_NCRT_TOTAL_BUF_SIZE(msg->key_bytes);
559 uint8_t *data = NULL;
560
561 data = calloc(1, size);
562 if (!data) {
563 EMSG("Fail to alloc rsa ncrt buf");
564 return HISI_QM_DRVCRYPT_ENOMEM;
565 }
566
567 msg->prikey = data;
568 msg->prikey_dma = virt_to_phys(msg->prikey);
569
570 msg->in = data + (msg->key_bytes * 2);
571 msg->in_dma = msg->prikey_dma + (msg->key_bytes * 2);
572
573 msg->out = msg->in + msg->key_bytes;
574 msg->out_dma = msg->in_dma + msg->key_bytes;
575
576 return HISI_QM_DRVCRYPT_NO_ERR;
577 }
578
579 static enum hisi_drv_status
hpre_rsa_crt_decrypt_bn2bin(struct hpre_rsa_msg * msg,struct drvcrypt_rsa_ed * rsa_data)580 hpre_rsa_crt_decrypt_bn2bin(struct hpre_rsa_msg *msg,
581 struct drvcrypt_rsa_ed *rsa_data)
582 {
583 struct rsa_keypair *key = rsa_data->key.key;
584 uint32_t p_bytes = msg->key_bytes >> 1;
585 uint32_t dq_len = crypto_bignum_num_bytes(key->dq);
586 uint32_t dp_len = crypto_bignum_num_bytes(key->dp);
587 uint32_t q_len = crypto_bignum_num_bytes(key->q);
588 uint32_t p_len = crypto_bignum_num_bytes(key->p);
589 uint32_t qp_len = crypto_bignum_num_bytes(key->qp);
590 uint8_t *dq = msg->prikey;
591 uint8_t *dp = msg->prikey + p_bytes;
592 uint8_t *q = dp + p_bytes;
593 uint8_t *p = q + p_bytes;
594 uint8_t *qp = p + p_bytes;
595 enum hisi_drv_status ret = HISI_QM_DRVCRYPT_NO_ERR;
596
597 crypto_bignum_bn2bin(key->dq, dq);
598 crypto_bignum_bn2bin(key->dp, dp);
599 crypto_bignum_bn2bin(key->q, q);
600 crypto_bignum_bn2bin(key->p, p);
601 crypto_bignum_bn2bin(key->qp, qp);
602
603 ret = hpre_bin_from_crypto_bin(dq, dq, p_bytes, dq_len);
604 if (ret) {
605 EMSG("Fail to transfer rsa crt dq from crypto_bin to hpre_bin");
606 return ret;
607 }
608
609 ret = hpre_bin_from_crypto_bin(dp, dp, p_bytes, dp_len);
610 if (ret) {
611 EMSG("Fail to transfer rsa crt dp from crypto_bin to hpre_bin");
612 return ret;
613 }
614
615 ret = hpre_bin_from_crypto_bin(q, q, p_bytes, q_len);
616 if (ret) {
617 EMSG("Fail to transfer rsa crt q from crypto_bin to hpre_bin");
618 return ret;
619 }
620
621 ret = hpre_bin_from_crypto_bin(p, p, p_bytes, p_len);
622 if (ret) {
623 EMSG("Fail to transfer rsa crt p from crypto_bin to hpre_bin");
624 return ret;
625 }
626
627 ret = hpre_bin_from_crypto_bin(qp, qp, p_bytes, qp_len);
628 if (ret) {
629 EMSG("Fail to transfer rsa crt qinv from crypto_bin to hpre_bin");
630 return ret;
631 }
632
633 ret = hpre_bin_from_crypto_bin(msg->in, rsa_data->cipher.data,
634 msg->key_bytes, rsa_data->cipher.length);
635 if (ret)
636 EMSG("Fail to transfer rsa ciphertext from crypto_bin to hpre_bin");
637
638 return ret;
639 }
640
641 static enum hisi_drv_status
hpre_rsa_ncrt_decrypt_bn2bin(struct hpre_rsa_msg * msg,struct drvcrypt_rsa_ed * rsa_data)642 hpre_rsa_ncrt_decrypt_bn2bin(struct hpre_rsa_msg *msg,
643 struct drvcrypt_rsa_ed *rsa_data)
644 {
645 struct rsa_keypair *key = rsa_data->key.key;
646 uint32_t d_len = 0;
647 uint32_t n_len = 0;
648 enum hisi_drv_status ret = HISI_QM_DRVCRYPT_NO_ERR;
649 uint8_t *n = NULL;
650
651 n = msg->prikey + msg->key_bytes;
652
653 crypto_bignum_bn2bin(key->d, msg->prikey);
654 crypto_bignum_bn2bin(key->n, n);
655 d_len = crypto_bignum_num_bytes(key->d);
656 n_len = crypto_bignum_num_bytes(key->n);
657
658 ret = hpre_bin_from_crypto_bin(msg->prikey, msg->prikey, msg->key_bytes,
659 d_len);
660 if (ret) {
661 EMSG("Fail to transfer rsa ncrt d from crypto_bin to hpre_bin");
662 return ret;
663 }
664
665 ret = hpre_bin_from_crypto_bin(n, n, msg->key_bytes, n_len);
666 if (ret) {
667 EMSG("Fail to transfer rsa ncrt n from crypto_bin to hpre_bin");
668 return ret;
669 }
670
671 ret = hpre_bin_from_crypto_bin(msg->in, rsa_data->cipher.data,
672 msg->key_bytes, rsa_data->cipher.length);
673 if (ret)
674 EMSG("Fail to transfer rsa ciphertext from crypto_bin to hpre_bin");
675
676 return ret;
677 }
678
hpre_rsa_is_crt_mod(struct rsa_keypair * key)679 static bool hpre_rsa_is_crt_mod(struct rsa_keypair *key)
680 {
681 if (key->p && crypto_bignum_num_bits(key->p) &&
682 key->q && crypto_bignum_num_bits(key->q) &&
683 key->dp && crypto_bignum_num_bits(key->dp) &&
684 key->dq && crypto_bignum_num_bits(key->dq) &&
685 key->qp && crypto_bignum_num_bits(key->qp))
686 return true;
687
688 return false;
689 }
690
hpre_rsa_decrypt_init(struct hpre_rsa_msg * msg,struct drvcrypt_rsa_ed * rsa_data)691 static TEE_Result hpre_rsa_decrypt_init(struct hpre_rsa_msg *msg,
692 struct drvcrypt_rsa_ed *rsa_data)
693 {
694 struct rsa_keypair *key = rsa_data->key.key;
695 size_t n_bytes = rsa_data->key.n_size;
696 bool is_crt = false;
697 enum hisi_drv_status ret = HISI_QM_DRVCRYPT_NO_ERR;
698
699 msg->is_private = rsa_data->key.isprivate;
700 msg->key_bytes = hpre_rsa_get_hw_kbytes(BYTES_TO_BITS(n_bytes));
701 if (!msg->key_bytes)
702 return TEE_ERROR_BAD_PARAMETERS;
703
704 is_crt = hpre_rsa_is_crt_mod(key);
705 if (is_crt) {
706 msg->alg_type = HPRE_ALG_NC_CRT;
707 ret = hpre_rsa_crt_decrypt_alloc(msg);
708 if (ret)
709 return TEE_ERROR_OUT_OF_MEMORY;
710
711 ret = hpre_rsa_crt_decrypt_bn2bin(msg, rsa_data);
712 if (ret) {
713 hpre_rsa_params_free(msg);
714 return TEE_ERROR_BAD_STATE;
715 }
716 } else {
717 msg->alg_type = HPRE_ALG_NC_NCRT;
718 ret = hpre_rsa_ncrt_decrypt_alloc(msg);
719 if (ret)
720 return TEE_ERROR_OUT_OF_MEMORY;
721
722 ret = hpre_rsa_ncrt_decrypt_bn2bin(msg, rsa_data);
723 if (ret) {
724 hpre_rsa_params_free(msg);
725 return TEE_ERROR_BAD_STATE;
726 }
727 }
728
729 return TEE_SUCCESS;
730 }
731
rsa_nopad_decrypt(struct drvcrypt_rsa_ed * rsa_data)732 static TEE_Result rsa_nopad_decrypt(struct drvcrypt_rsa_ed *rsa_data)
733 {
734 size_t n_bytes = rsa_data->key.n_size;
735 struct hpre_rsa_msg msg = { };
736 uint32_t offset = 0;
737 TEE_Result ret = TEE_SUCCESS;
738 uint8_t *pos = NULL;
739
740 if (rsa_data->cipher.length > n_bytes) {
741 EMSG("Invalid cipher length[%zu]", rsa_data->cipher.length);
742 return TEE_ERROR_BAD_PARAMETERS;
743 }
744
745 ret = hpre_rsa_decrypt_init(&msg, rsa_data);
746 if (ret) {
747 EMSG("Fail to init rsa msg");
748 return ret;
749 }
750
751 ret = hpre_rsa_do_task(&msg);
752 if (ret)
753 goto decrypt_deinit;
754
755 pos = msg.out + msg.key_bytes - n_bytes;
756 if (rsa_data->rsa_id == DRVCRYPT_RSA_NOPAD) {
757 /* Plaintext can not have valid zero data in NOPAD MODE */
758 while ((offset < n_bytes - 1) && (pos[offset] == 0))
759 offset++;
760 }
761
762 rsa_data->message.length = n_bytes - offset;
763 memcpy(rsa_data->message.data, pos + offset, rsa_data->message.length);
764
765 decrypt_deinit:
766 hpre_rsa_params_free(&msg);
767
768 return ret;
769 }
770
rsaes_pkcs_v1_5_decode(struct drvcrypt_rsa_ed * rsa_data,uint8_t * out,size_t * out_len)771 static TEE_Result rsaes_pkcs_v1_5_decode(struct drvcrypt_rsa_ed *rsa_data,
772 uint8_t *out, size_t *out_len)
773 {
774 size_t em_len = rsa_data->message.length;
775 uint8_t *em = rsa_data->message.data;
776 size_t ps_len = 0;
777 size_t i = 0;
778
779 /* PKCS_V1.5 EM format 0x00 || 0x02 || PS non-zero || 0x00 || M */
780 if (em_len < PKCS_V1_5_MSG_MIN_LEN || em[0] != 0 ||
781 em[1] != ENCRYPT_PAD) {
782 EMSG("Invalid pkcs_v1.5 decode parameter");
783 return TEE_ERROR_BAD_PARAMETERS;
784 }
785
786 for (i = PKCS_V1_5_PS_POS; i < em_len; i++) {
787 if (em[i] == 0)
788 break;
789 }
790
791 if (i >= em_len) {
792 EMSG("Fail to find zero pos");
793 return TEE_ERROR_BAD_PARAMETERS;
794 }
795
796 ps_len = i - PKCS_V1_5_PS_POS;
797 if (em_len - ps_len - PKCS_V1_5_FIXED_LEN > *out_len ||
798 ps_len < PKCS_V1_5_PS_MIN_LEN) {
799 EMSG("Invalid pkcs_v1.5 decode ps_len");
800 return TEE_ERROR_BAD_PARAMETERS;
801 }
802
803 *out_len = em_len - ps_len - PKCS_V1_5_FIXED_LEN;
804 memcpy(out, em + ps_len + PKCS_V1_5_FIXED_LEN, *out_len);
805
806 return TEE_SUCCESS;
807 }
808
rsa_pkcs_decrypt(struct drvcrypt_rsa_ed * rsa_data)809 static TEE_Result rsa_pkcs_decrypt(struct drvcrypt_rsa_ed *rsa_data)
810 {
811 uint32_t n_bytes = rsa_data->key.n_size;
812 struct drvcrypt_rsa_ed rsa_dec_info = *rsa_data;
813 TEE_Result ret = TEE_SUCCESS;
814
815 /* Alloc pkcs_v1.5 encode message data buf */
816 rsa_dec_info.message.data = malloc(n_bytes);
817 if (!rsa_dec_info.message.data) {
818 EMSG("Fail to alloc message data buf");
819 return TEE_ERROR_OUT_OF_MEMORY;
820 }
821
822 rsa_dec_info.message.length = n_bytes;
823 ret = rsa_nopad_decrypt(&rsa_dec_info);
824 if (ret)
825 goto free_data;
826
827 ret = rsaes_pkcs_v1_5_decode(&rsa_dec_info, rsa_data->message.data,
828 &rsa_data->message.length);
829 if (ret)
830 EMSG("Fail to get pkcs_v1.5 decode message data");
831
832 free_data:
833 free(rsa_dec_info.message.data);
834
835 return ret;
836 }
837
rsa_oaep_get_seed(struct drvcrypt_rsa_ed * rsa_data,uint8_t * mask_db,uint8_t * seed)838 static TEE_Result rsa_oaep_get_seed(struct drvcrypt_rsa_ed *rsa_data,
839 uint8_t *mask_db, uint8_t *seed)
840 {
841 size_t db_len = rsa_data->key.n_size - rsa_data->digest_size - 1;
842 uint8_t mask_db_mgf[OAEP_MAX_HASH_LEN] = { };
843 size_t lhash_len = rsa_data->digest_size;
844 uint8_t *mask_seed = NULL;
845 TEE_Result ret = TEE_SUCCESS;
846
847 mask_seed = rsa_data->message.data + 1;
848
849 ret = mgf_process(lhash_len, mask_db, db_len, mask_db_mgf, lhash_len,
850 rsa_data);
851 if (ret) {
852 EMSG("Fail to get mask_db mgf result");
853 return ret;
854 }
855
856 return xor_process(mask_seed, mask_db_mgf, seed, lhash_len);
857 }
858
rsa_oaep_get_db(struct drvcrypt_rsa_ed * rsa_data,uint8_t * mask_db,uint8_t * seed,uint8_t * db)859 static TEE_Result rsa_oaep_get_db(struct drvcrypt_rsa_ed *rsa_data,
860 uint8_t *mask_db, uint8_t *seed, uint8_t *db)
861 {
862 size_t db_len = rsa_data->key.n_size - rsa_data->digest_size - 1;
863 size_t lhash_len = rsa_data->digest_size;
864 uint8_t seed_mgf[OAEP_MAX_DB_LEN] = { };
865 TEE_Result ret = TEE_SUCCESS;
866
867 ret = mgf_process(lhash_len, seed, lhash_len, seed_mgf, db_len,
868 rsa_data);
869 if (ret) {
870 EMSG("Fail to get seed mgf result");
871 return ret;
872 }
873
874 return xor_process(mask_db, seed_mgf, db, db_len);
875 }
876
rsa_oaep_get_msg(struct drvcrypt_rsa_ed * rsa_data,uint8_t * db,uint8_t * out,size_t * out_len)877 static TEE_Result rsa_oaep_get_msg(struct drvcrypt_rsa_ed *rsa_data,
878 uint8_t *db, uint8_t *out, size_t *out_len)
879 {
880 size_t db_len = rsa_data->key.n_size - rsa_data->digest_size - 1;
881 size_t lhash_len = rsa_data->digest_size;
882 uint8_t hash[OAEP_MAX_HASH_LEN] = { };
883 size_t msg_len = 0;
884 size_t lp_len = 0;
885 TEE_Result ret = TEE_SUCCESS;
886
887 /* oaep db format lhash || ps zero || 01 || M */
888 ret = tee_hash_createdigest(rsa_data->hash_algo, rsa_data->label.data,
889 rsa_data->label.length, hash, lhash_len);
890 if (ret) {
891 EMSG("Fail to get label hash");
892 return ret;
893 }
894
895 if (memcmp(hash, db, lhash_len)) {
896 EMSG("Hash is not equal");
897 return TEE_ERROR_BAD_PARAMETERS;
898 }
899
900 for (lp_len = lhash_len; lp_len < db_len; lp_len++) {
901 if (db[lp_len] != 0)
902 break;
903 }
904
905 if (lp_len == db_len) {
906 EMSG("Fail to find fixed 01 in db");
907 return TEE_ERROR_BAD_PARAMETERS;
908 }
909
910 msg_len = db_len - lp_len - 1;
911 if (msg_len > rsa_data->message.length) {
912 DMSG("Message space is not enough");
913 *out_len = msg_len;
914 return TEE_ERROR_SHORT_BUFFER;
915 }
916
917 *out_len = msg_len;
918 memcpy(out, db + lp_len + 1, msg_len);
919
920 return TEE_SUCCESS;
921 }
922
rsa_oaep_decode(struct drvcrypt_rsa_ed * rsa_data,uint8_t * out,size_t * out_len)923 static TEE_Result rsa_oaep_decode(struct drvcrypt_rsa_ed *rsa_data,
924 uint8_t *out, size_t *out_len)
925 {
926 size_t lhash_len = rsa_data->digest_size;
927 uint8_t seed[OAEP_MAX_HASH_LEN] = { };
928 uint8_t db[OAEP_MAX_DB_LEN] = { };
929 uint8_t *mask_db = NULL;
930 TEE_Result ret = TEE_SUCCESS;
931
932 /* oaep format 00 || maskedseed || maskeddb */
933 mask_db = rsa_data->message.data + lhash_len + 1;
934 ret = rsa_oaep_get_seed(rsa_data, mask_db, seed);
935 if (ret)
936 return ret;
937
938 ret = rsa_oaep_get_db(rsa_data, mask_db, seed, db);
939 if (ret)
940 return ret;
941
942 return rsa_oaep_get_msg(rsa_data, db, out, out_len);
943 }
944
rsa_oaep_decrypt(struct drvcrypt_rsa_ed * rsa_data)945 static TEE_Result rsa_oaep_decrypt(struct drvcrypt_rsa_ed *rsa_data)
946 {
947 size_t n_bytes = rsa_data->key.n_size;
948 struct drvcrypt_rsa_ed rsa_dec_info = *rsa_data;
949 TEE_Result ret = TEE_SUCCESS;
950
951 /* Alloc oaep encode message data buf */
952 rsa_dec_info.message.data = malloc(n_bytes);
953 if (!rsa_dec_info.message.data) {
954 EMSG("Fail to alloc message data buf");
955 return TEE_ERROR_OUT_OF_MEMORY;
956 }
957
958 rsa_dec_info.message.length = n_bytes;
959 ret = rsa_nopad_decrypt(&rsa_dec_info);
960 if (ret)
961 goto free_data;
962
963 ret = rsa_oaep_decode(&rsa_dec_info, rsa_data->message.data,
964 &rsa_data->message.length);
965 if (ret)
966 EMSG("Fail to get oaep decode message data");
967
968 free_data:
969 free(rsa_dec_info.message.data);
970
971 return ret;
972 }
973
hpre_rsa_decrypt(struct drvcrypt_rsa_ed * rsa_data)974 static TEE_Result hpre_rsa_decrypt(struct drvcrypt_rsa_ed *rsa_data)
975 {
976 if (!rsa_data) {
977 EMSG("Invalid rsa decrypt input parameter");
978 return TEE_ERROR_BAD_PARAMETERS;
979 }
980
981 switch (rsa_data->rsa_id) {
982 case DRVCRYPT_RSA_NOPAD:
983 case DRVCRYPT_RSASSA_PKCS_V1_5:
984 case DRVCRYPT_RSASSA_PSS:
985 return rsa_nopad_decrypt(rsa_data);
986 case DRVCRYPT_RSA_PKCS_V1_5:
987 return rsa_pkcs_decrypt(rsa_data);
988 case DRVCRYPT_RSA_OAEP:
989 return rsa_oaep_decrypt(rsa_data);
990 default:
991 EMSG("Invalid rsa id");
992 return TEE_ERROR_NOT_SUPPORTED;
993 }
994 }
995
996 static const struct drvcrypt_rsa driver_rsa = {
997 .alloc_keypair = sw_crypto_acipher_alloc_rsa_keypair,
998 .alloc_publickey = sw_crypto_acipher_alloc_rsa_public_key,
999 .free_publickey = sw_crypto_acipher_free_rsa_public_key,
1000 .free_keypair = sw_crypto_acipher_free_rsa_keypair,
1001 .gen_keypair = sw_crypto_acipher_gen_rsa_key,
1002 .encrypt = hpre_rsa_encrypt,
1003 .decrypt = hpre_rsa_decrypt,
1004 .optional = {
1005 /*
1006 * If ssa_sign or verify is NULL, the framework will fill
1007 * data format directly by soft calculation. Then call api
1008 * encrypt or decrypt.
1009 */
1010 .ssa_sign = NULL,
1011 .ssa_verify = NULL,
1012 },
1013 };
1014
hpre_rsa_init(void)1015 static TEE_Result hpre_rsa_init(void)
1016 {
1017 TEE_Result ret = drvcrypt_register_rsa(&driver_rsa);
1018
1019 if (ret != TEE_SUCCESS)
1020 EMSG("hpre rsa register to crypto fail");
1021
1022 return ret;
1023 }
1024
1025 driver_init(hpre_rsa_init);
1026