1 // SPDX-License-Identifier: BSD-2-Clause 2 /* 3 * Copyright (c) 2022-2024, HiSilicon Technologies Co., Ltd. 4 * Kunpeng hardware accelerator hpre rsa algorithm implementation. 5 */ 6 7 #include <crypto/crypto_impl.h> 8 #include <drvcrypt.h> 9 #include <drvcrypt_acipher.h> 10 #include <drvcrypt_math.h> 11 #include <initcall.h> 12 #include <malloc.h> 13 #include <mbedtls/rsa.h> 14 #include <rng_support.h> 15 #include <stdlib_ext.h> 16 #include <string.h> 17 #include <string_ext.h> 18 #include <tee/tee_cryp_utl.h> 19 #include <trace.h> 20 21 #include "hpre_main.h" 22 #include "hpre_rsa.h" 23 24 static enum hisi_drv_status hpre_rsa_fill_addr_params(struct hpre_rsa_msg *msg, 25 struct hpre_sqe *sqe) 26 { 27 switch (msg->alg_type) { 28 case HPRE_ALG_NC_NCRT: 29 case HPRE_ALG_NC_CRT: 30 if (msg->is_private) { 31 /* DECRYPT */ 32 sqe->key = msg->prikey_dma; 33 sqe->in = msg->in_dma; 34 sqe->out = msg->out_dma; 35 } else { 36 /* ENCRYPT */ 37 sqe->key = msg->pubkey_dma; 38 sqe->in = msg->in_dma; 39 sqe->out = msg->out_dma; 40 } 41 return HISI_QM_DRVCRYPT_NO_ERR; 42 default: 43 EMSG("Invalid alg_type[%"PRIu32"]", msg->alg_type); 44 return HISI_QM_DRVCRYPT_IN_EPARA; 45 } 46 } 47 48 static enum hisi_drv_status hpre_rsa_fill_sqe(void *bd, void *info) 49 { 50 struct hpre_rsa_msg *msg = info; 51 struct hpre_sqe *sqe = bd; 52 53 sqe->w0 = msg->alg_type | SHIFT_U32(0x1, HPRE_DONE_SHIFT); 54 sqe->task_len1 = TASK_LENGTH(msg->key_bytes); 55 56 return hpre_rsa_fill_addr_params(msg, sqe); 57 } 58 59 static enum hisi_drv_status hpre_rsa_parse_sqe(void *bd, void *info __unused) 60 { 61 struct hpre_sqe *sqe = bd; 62 uint16_t err = 0; 63 uint16_t done = 0; 64 65 err = HPRE_TASK_ETYPE(sqe->w0); 66 done = HPRE_TASK_DONE(sqe->w0); 67 if (done != HPRE_HW_TASK_DONE || err) { 68 EMSG("HPRE do rsa fail! done=0x%"PRIX16", etype=0x%"PRIX16, 69 done, err); 70 return HISI_QM_DRVCRYPT_IN_EPARA; 71 } 72 73 return HISI_QM_DRVCRYPT_NO_ERR; 74 } 75 76 static TEE_Result hpre_rsa_do_task(void *msg) 77 { 78 struct hisi_qp *rsa_qp = NULL; 79 TEE_Result res = TEE_SUCCESS; 80 enum hisi_drv_status ret = HISI_QM_DRVCRYPT_NO_ERR; 81 82 rsa_qp = hpre_create_qp(HISI_QM_CHANNEL_TYPE0); 83 if (!rsa_qp) { 84 EMSG("Fail to create rsa qp"); 85 return TEE_ERROR_BUSY; 86 } 87 88 rsa_qp->fill_sqe = hpre_rsa_fill_sqe; 89 rsa_qp->parse_sqe = hpre_rsa_parse_sqe; 90 ret = hisi_qp_send(rsa_qp, msg); 91 if (ret) { 92 EMSG("Fail to send task, ret=%x"PRIx32, ret); 93 res = TEE_ERROR_BAD_STATE; 94 goto done_proc; 95 } 96 97 ret = hisi_qp_recv_sync(rsa_qp, msg); 98 if (ret) { 99 EMSG("Recv task error, ret=%x"PRIx32, ret); 100 res = TEE_ERROR_BAD_STATE; 101 goto done_proc; 102 } 103 104 done_proc: 105 hisi_qm_release_qp(rsa_qp); 106 107 return res; 108 } 109 110 static TEE_Result mgf_process(size_t digest_size, uint8_t *seed, 111 size_t seed_len, uint8_t *mask, size_t mask_len, 112 struct drvcrypt_rsa_ed *rsa_data) 113 { 114 struct drvcrypt_rsa_mgf mgf = { }; 115 116 if (!rsa_data->mgf) { 117 EMSG("mgf function is NULL"); 118 return TEE_ERROR_BAD_PARAMETERS; 119 } 120 121 mgf.hash_algo = rsa_data->hash_algo; 122 mgf.digest_size = digest_size; 123 mgf.seed.data = seed; 124 mgf.seed.length = seed_len; 125 mgf.mask.data = mask; 126 mgf.mask.length = mask_len; 127 128 return rsa_data->mgf(&mgf); 129 } 130 131 static TEE_Result xor_process(uint8_t *a, uint8_t *b, uint8_t *out, size_t len) 132 { 133 struct drvcrypt_mod_op xor_mod = { }; 134 135 xor_mod.n.length = len; 136 xor_mod.a.data = a; 137 xor_mod.a.length = len; 138 xor_mod.b.data = b; 139 xor_mod.b.length = len; 140 xor_mod.result.data = out; 141 xor_mod.result.length = len; 142 143 return drvcrypt_xor_mod_n(&xor_mod); 144 } 145 146 static size_t hpre_rsa_get_hw_kbytes(size_t key_bits) 147 { 148 size_t size = 0; 149 150 if (key_bits <= 1024) 151 size = BITS_TO_BYTES(1024); 152 else if (key_bits <= 2048) 153 size = BITS_TO_BYTES(2048); 154 else if (key_bits <= 3072) 155 size = BITS_TO_BYTES(3072); 156 else if (key_bits <= 4096) 157 size = BITS_TO_BYTES(4096); 158 else 159 EMSG("Invalid key_bits[%zu]", key_bits); 160 161 return size; 162 } 163 164 static void hpre_rsa_params_free(struct hpre_rsa_msg *msg) 165 { 166 switch (msg->alg_type) { 167 case HPRE_ALG_NC_NCRT: 168 if (msg->is_private) 169 free_wipe(msg->prikey); 170 else 171 free(msg->pubkey); 172 break; 173 case HPRE_ALG_NC_CRT: 174 if (msg->is_private) 175 free_wipe(msg->prikey); 176 break; 177 default: 178 EMSG("Invalid alg_type[%"PRIu32"]", msg->alg_type); 179 break; 180 } 181 } 182 183 static enum hisi_drv_status hpre_rsa_encrypt_alloc(struct hpre_rsa_msg *msg) 184 { 185 uint32_t size = HPRE_RSA_NCRT_TOTAL_BUF_SIZE(msg->key_bytes); 186 uint8_t *data = NULL; 187 188 data = calloc(1, size); 189 if (!data) { 190 EMSG("Fail to alloc rsa ncrt buf"); 191 return HISI_QM_DRVCRYPT_ENOMEM; 192 } 193 194 msg->pubkey = data; 195 msg->pubkey_dma = virt_to_phys(msg->pubkey); 196 197 msg->in = data + (msg->key_bytes * 2); 198 msg->in_dma = msg->pubkey_dma + (msg->key_bytes * 2); 199 200 msg->out = msg->in + msg->key_bytes; 201 msg->out_dma = msg->in_dma + msg->key_bytes; 202 203 return HISI_QM_DRVCRYPT_NO_ERR; 204 } 205 206 static enum hisi_drv_status 207 hpre_rsa_encrypt_bn2bin(struct hpre_rsa_msg *msg, 208 struct drvcrypt_rsa_ed *rsa_data) 209 { 210 struct rsa_public_key *key = rsa_data->key.key; 211 uint32_t e_len = 0; 212 uint32_t n_len = 0; 213 enum hisi_drv_status ret = HISI_QM_DRVCRYPT_NO_ERR; 214 uint8_t *n = NULL; 215 216 n = msg->pubkey + msg->key_bytes; 217 218 crypto_bignum_bn2bin(key->e, msg->pubkey); 219 crypto_bignum_bn2bin(key->n, n); 220 e_len = crypto_bignum_num_bytes(key->e); 221 n_len = crypto_bignum_num_bytes(key->n); 222 223 ret = hpre_bin_from_crypto_bin(msg->pubkey, msg->pubkey, 224 msg->key_bytes, e_len); 225 if (ret) { 226 EMSG("Fail to transfer rsa ncrt e from crypto_bin to hpre_bin"); 227 return ret; 228 } 229 230 ret = hpre_bin_from_crypto_bin(n, n, msg->key_bytes, n_len); 231 if (ret) { 232 EMSG("Fail to transfer rsa ncrt n from crypto_bin to hpre_bin"); 233 return ret; 234 } 235 236 ret = hpre_bin_from_crypto_bin(msg->in, rsa_data->message.data, 237 msg->key_bytes, 238 rsa_data->message.length); 239 if (ret) 240 EMSG("Fail to transfer rsa plaintext from crypto_bin to hpre_bin"); 241 242 return ret; 243 } 244 245 static TEE_Result hpre_rsa_encrypt_init(struct hpre_rsa_msg *msg, 246 struct drvcrypt_rsa_ed *rsa_data) 247 { 248 size_t n_bytes = rsa_data->key.n_size; 249 enum hisi_drv_status ret = HISI_QM_DRVCRYPT_NO_ERR; 250 251 msg->alg_type = HPRE_ALG_NC_NCRT; 252 msg->is_private = rsa_data->key.isprivate; 253 msg->key_bytes = hpre_rsa_get_hw_kbytes(BYTES_TO_BITS(n_bytes)); 254 if (!msg->key_bytes) 255 return TEE_ERROR_BAD_PARAMETERS; 256 257 ret = hpre_rsa_encrypt_alloc(msg); 258 if (ret) 259 return TEE_ERROR_OUT_OF_MEMORY; 260 261 ret = hpre_rsa_encrypt_bn2bin(msg, rsa_data); 262 if (ret) { 263 hpre_rsa_params_free(msg); 264 return TEE_ERROR_BAD_STATE; 265 } 266 267 return TEE_SUCCESS; 268 } 269 270 static TEE_Result rsa_nopad_encrypt(struct drvcrypt_rsa_ed *rsa_data) 271 { 272 size_t n_bytes = rsa_data->key.n_size; 273 struct hpre_rsa_msg msg = { }; 274 TEE_Result ret = TEE_SUCCESS; 275 276 if (rsa_data->message.length > n_bytes) { 277 EMSG("Invalid msg length[%zu]", rsa_data->message.length); 278 return TEE_ERROR_BAD_PARAMETERS; 279 } 280 281 ret = hpre_rsa_encrypt_init(&msg, rsa_data); 282 if (ret) { 283 EMSG("Fail to init rsa msg"); 284 return ret; 285 } 286 287 ret = hpre_rsa_do_task(&msg); 288 if (ret) 289 goto encrypt_deinit; 290 291 /* Ciphertext can have valid zero data in NOPAD MODE */ 292 memcpy(rsa_data->cipher.data, msg.out + msg.key_bytes - n_bytes, 293 n_bytes); 294 rsa_data->cipher.length = n_bytes; 295 296 encrypt_deinit: 297 hpre_rsa_params_free(&msg); 298 299 return ret; 300 } 301 302 static TEE_Result pkcs_v1_5_fill_ps(uint8_t *ps, size_t ps_len) 303 { 304 size_t i = 0; 305 306 if (hw_get_random_bytes(ps, ps_len)) { 307 EMSG("Fail to get ps data"); 308 return TEE_ERROR_NO_DATA; 309 } 310 311 for (i = 0; i < ps_len; i++) { 312 if (ps[i] == 0) 313 ps[i] = PKCS_V1_5_PS_DATA; 314 } 315 316 return TEE_SUCCESS; 317 } 318 319 static TEE_Result rsaes_pkcs_v1_5_encode(struct drvcrypt_rsa_ed *rsa_data, 320 uint8_t *out) 321 { 322 size_t msg_len = rsa_data->message.length; 323 size_t out_len = rsa_data->cipher.length; 324 size_t n_bytes = rsa_data->key.n_size; 325 uint8_t *ps = out + PKCS_V1_5_PS_POS; 326 TEE_Result ret = TEE_SUCCESS; 327 size_t ps_len = 0; 328 329 /* PKCS_V1.5 format 0x00 || 0x02 || PS non-zero || 0x00 || M */ 330 if ((msg_len + PKCS_V1_5_MSG_MIN_LEN) > n_bytes || out_len < n_bytes) { 331 EMSG("Invalid pkcs_v1.5 encode parameter"); 332 return TEE_ERROR_BAD_PARAMETERS; 333 } 334 335 ps_len = n_bytes - PKCS_V1_5_FIXED_LEN - msg_len; 336 ret = pkcs_v1_5_fill_ps(ps, ps_len); 337 if (ret) 338 return ret; 339 340 out[0] = 0; 341 out[1] = ENCRYPT_PAD; 342 out[PKCS_V1_5_FIXED_LEN + ps_len - 1] = 0; 343 memcpy(out + PKCS_V1_5_FIXED_LEN + ps_len, rsa_data->message.data, 344 msg_len); 345 346 return TEE_SUCCESS; 347 } 348 349 static TEE_Result rsa_pkcs_encrypt(struct drvcrypt_rsa_ed *rsa_data) 350 { 351 uint32_t n_bytes = rsa_data->key.n_size; 352 struct drvcrypt_rsa_ed rsa_enc_info = *rsa_data; 353 TEE_Result ret = TEE_SUCCESS; 354 355 /* Alloc pkcs_v1.5 encode message data buf */ 356 rsa_enc_info.message.data = malloc(n_bytes); 357 if (!rsa_enc_info.message.data) { 358 EMSG("Fail to alloc message data buf"); 359 return TEE_ERROR_OUT_OF_MEMORY; 360 } 361 362 rsa_enc_info.message.length = n_bytes; 363 ret = rsaes_pkcs_v1_5_encode(rsa_data, rsa_enc_info.message.data); 364 if (ret) { 365 EMSG("Fail to get pkcs_v1.5 encode message data"); 366 goto free_data; 367 } 368 369 ret = rsa_nopad_encrypt(&rsa_enc_info); 370 if (ret) 371 goto free_data; 372 373 memcpy(rsa_data->cipher.data, rsa_enc_info.cipher.data, 374 rsa_enc_info.cipher.length); 375 rsa_data->cipher.length = rsa_enc_info.cipher.length; 376 377 free_data: 378 free(rsa_enc_info.message.data); 379 380 return ret; 381 } 382 383 static TEE_Result rsa_oaep_fill_db(struct drvcrypt_rsa_ed *rsa_data, 384 uint8_t *db) 385 { 386 size_t lhash_len = rsa_data->digest_size; 387 size_t n_bytes = rsa_data->key.n_size; 388 size_t db_len = n_bytes - lhash_len - 1; 389 size_t ps_len = 0; 390 TEE_Result ret = TEE_SUCCESS; 391 392 /* oaep db format lhash || ps zero || 01 || M */ 393 ret = tee_hash_createdigest(rsa_data->hash_algo, rsa_data->label.data, 394 rsa_data->label.length, db, lhash_len); 395 if (ret) { 396 EMSG("Fail to get label hash"); 397 return ret; 398 } 399 400 ps_len = db_len - lhash_len - rsa_data->message.length - 1; 401 db[lhash_len + ps_len] = 1; 402 memcpy(db + lhash_len + ps_len + 1, rsa_data->message.data, 403 rsa_data->message.length); 404 405 return TEE_SUCCESS; 406 } 407 408 static TEE_Result rsa_oaep_fill_maskdb(struct drvcrypt_rsa_ed *rsa_data, 409 uint8_t *seed, uint8_t *db, 410 uint8_t *mask_db) 411 { 412 size_t lhash_len = rsa_data->digest_size; 413 size_t n_bytes = rsa_data->key.n_size; 414 size_t db_len = n_bytes - lhash_len - 1; 415 uint8_t seed_mgf[OAEP_MAX_DB_LEN] = { }; 416 TEE_Result ret = TEE_SUCCESS; 417 418 ret = mgf_process(lhash_len, seed, lhash_len, seed_mgf, db_len, 419 rsa_data); 420 if (ret) { 421 EMSG("Fail to get seed_mgf"); 422 return ret; 423 } 424 425 return xor_process(db, seed_mgf, mask_db, db_len); 426 } 427 428 static TEE_Result rsa_oaep_fill_maskseed(struct drvcrypt_rsa_ed *rsa_data, 429 uint8_t *seed, uint8_t *em) 430 { 431 uint8_t mask_db_mgf[OAEP_MAX_HASH_LEN] = { 0 }; 432 size_t lhash_len = rsa_data->digest_size; 433 size_t n_bytes = rsa_data->key.n_size; 434 size_t db_len = n_bytes - lhash_len - 1; 435 uint8_t *mask_db = em + lhash_len + 1; 436 uint8_t *mask_seed = em + 1; 437 TEE_Result ret = TEE_SUCCESS; 438 439 ret = mgf_process(lhash_len, mask_db, db_len, mask_db_mgf, lhash_len, 440 rsa_data); 441 if (ret) { 442 EMSG("Fail to get mask_db_mgf"); 443 return ret; 444 } 445 446 return xor_process(seed, mask_db_mgf, mask_seed, lhash_len); 447 } 448 449 static TEE_Result rsa_oaep_encode(struct drvcrypt_rsa_ed *rsa_data, 450 uint8_t *em) 451 { 452 size_t lhash_len = rsa_data->digest_size; 453 uint8_t db[OAEP_MAX_DB_LEN] = { }; 454 uint8_t seed[OAEP_MAX_HASH_LEN] = { }; 455 TEE_Result ret = TEE_SUCCESS; 456 457 /* oaep format 00 || maskedseed || maskeddb */ 458 em[0] = 0; 459 460 ret = rsa_oaep_fill_db(rsa_data, db); 461 if (ret) 462 return ret; 463 464 ret = hw_get_random_bytes(seed, lhash_len); 465 if (ret) 466 return ret; 467 468 ret = rsa_oaep_fill_maskdb(rsa_data, seed, db, em + lhash_len + 1); 469 if (ret) 470 return ret; 471 472 return rsa_oaep_fill_maskseed(rsa_data, seed, em); 473 } 474 475 static TEE_Result rsa_oaep_encrypt(struct drvcrypt_rsa_ed *rsa_data) 476 { 477 size_t n_bytes = rsa_data->key.n_size; 478 struct drvcrypt_rsa_ed rsa_enc_info = *rsa_data; 479 TEE_Result ret = TEE_SUCCESS; 480 481 /* Alloc oaep encode message data buf */ 482 rsa_enc_info.message.data = malloc(n_bytes); 483 if (!rsa_enc_info.message.data) { 484 EMSG("Fail to alloc message data buf"); 485 return TEE_ERROR_OUT_OF_MEMORY; 486 } 487 488 rsa_enc_info.message.length = n_bytes; 489 ret = rsa_oaep_encode(rsa_data, rsa_enc_info.message.data); 490 if (ret) { 491 EMSG("Fail to get oaep encode message data"); 492 goto free_data; 493 } 494 495 ret = rsa_nopad_encrypt(&rsa_enc_info); 496 if (ret) 497 goto free_data; 498 499 memcpy(rsa_data->cipher.data, rsa_enc_info.cipher.data, 500 rsa_enc_info.cipher.length); 501 rsa_data->cipher.length = rsa_enc_info.cipher.length; 502 503 free_data: 504 free(rsa_enc_info.message.data); 505 506 return ret; 507 } 508 509 static TEE_Result hpre_rsa_encrypt(struct drvcrypt_rsa_ed *rsa_data) 510 { 511 if (!rsa_data) { 512 EMSG("Invalid rsa encrypt input parameter"); 513 return TEE_ERROR_BAD_PARAMETERS; 514 } 515 516 switch (rsa_data->rsa_id) { 517 case DRVCRYPT_RSA_NOPAD: 518 case DRVCRYPT_RSASSA_PKCS_V1_5: 519 case DRVCRYPT_RSASSA_PSS: 520 return rsa_nopad_encrypt(rsa_data); 521 case DRVCRYPT_RSA_PKCS_V1_5: 522 return rsa_pkcs_encrypt(rsa_data); 523 case DRVCRYPT_RSA_OAEP: 524 return rsa_oaep_encrypt(rsa_data); 525 default: 526 EMSG("Invalid rsa id"); 527 return TEE_ERROR_BAD_PARAMETERS; 528 } 529 } 530 531 static enum hisi_drv_status hpre_rsa_crt_decrypt_alloc(struct hpre_rsa_msg *msg) 532 { 533 uint32_t size = HPRE_RSA_CRT_TOTAL_BUF_SIZE(msg->key_bytes); 534 uint8_t *data = NULL; 535 536 data = calloc(1, size); 537 if (!data) { 538 EMSG("Fail to alloc rsa crt total buf"); 539 return HISI_QM_DRVCRYPT_ENOMEM; 540 } 541 542 msg->prikey = data; 543 msg->prikey_dma = virt_to_phys(msg->prikey); 544 545 msg->in = data + (msg->key_bytes * 2) + (msg->key_bytes >> 1); 546 msg->in_dma = msg->prikey_dma + (msg->key_bytes * 2) + 547 (msg->key_bytes >> 1); 548 549 msg->out = msg->in + msg->key_bytes; 550 msg->out_dma = msg->in_dma + msg->key_bytes; 551 552 return HISI_QM_DRVCRYPT_NO_ERR; 553 } 554 555 static enum hisi_drv_status 556 hpre_rsa_ncrt_decrypt_alloc(struct hpre_rsa_msg *msg) 557 { 558 uint32_t size = HPRE_RSA_NCRT_TOTAL_BUF_SIZE(msg->key_bytes); 559 uint8_t *data = NULL; 560 561 data = calloc(1, size); 562 if (!data) { 563 EMSG("Fail to alloc rsa ncrt buf"); 564 return HISI_QM_DRVCRYPT_ENOMEM; 565 } 566 567 msg->prikey = data; 568 msg->prikey_dma = virt_to_phys(msg->prikey); 569 570 msg->in = data + (msg->key_bytes * 2); 571 msg->in_dma = msg->prikey_dma + (msg->key_bytes * 2); 572 573 msg->out = msg->in + msg->key_bytes; 574 msg->out_dma = msg->in_dma + msg->key_bytes; 575 576 return HISI_QM_DRVCRYPT_NO_ERR; 577 } 578 579 static enum hisi_drv_status 580 hpre_rsa_crt_decrypt_bn2bin(struct hpre_rsa_msg *msg, 581 struct drvcrypt_rsa_ed *rsa_data) 582 { 583 struct rsa_keypair *key = rsa_data->key.key; 584 uint32_t p_bytes = msg->key_bytes >> 1; 585 uint32_t dq_len = crypto_bignum_num_bytes(key->dq); 586 uint32_t dp_len = crypto_bignum_num_bytes(key->dp); 587 uint32_t q_len = crypto_bignum_num_bytes(key->q); 588 uint32_t p_len = crypto_bignum_num_bytes(key->p); 589 uint32_t qp_len = crypto_bignum_num_bytes(key->qp); 590 uint8_t *dq = msg->prikey; 591 uint8_t *dp = msg->prikey + p_bytes; 592 uint8_t *q = dp + p_bytes; 593 uint8_t *p = q + p_bytes; 594 uint8_t *qp = p + p_bytes; 595 enum hisi_drv_status ret = HISI_QM_DRVCRYPT_NO_ERR; 596 597 crypto_bignum_bn2bin(key->dq, dq); 598 crypto_bignum_bn2bin(key->dp, dp); 599 crypto_bignum_bn2bin(key->q, q); 600 crypto_bignum_bn2bin(key->p, p); 601 crypto_bignum_bn2bin(key->qp, qp); 602 603 ret = hpre_bin_from_crypto_bin(dq, dq, p_bytes, dq_len); 604 if (ret) { 605 EMSG("Fail to transfer rsa crt dq from crypto_bin to hpre_bin"); 606 return ret; 607 } 608 609 ret = hpre_bin_from_crypto_bin(dp, dp, p_bytes, dp_len); 610 if (ret) { 611 EMSG("Fail to transfer rsa crt dp from crypto_bin to hpre_bin"); 612 return ret; 613 } 614 615 ret = hpre_bin_from_crypto_bin(q, q, p_bytes, q_len); 616 if (ret) { 617 EMSG("Fail to transfer rsa crt q from crypto_bin to hpre_bin"); 618 return ret; 619 } 620 621 ret = hpre_bin_from_crypto_bin(p, p, p_bytes, p_len); 622 if (ret) { 623 EMSG("Fail to transfer rsa crt p from crypto_bin to hpre_bin"); 624 return ret; 625 } 626 627 ret = hpre_bin_from_crypto_bin(qp, qp, p_bytes, qp_len); 628 if (ret) { 629 EMSG("Fail to transfer rsa crt qinv from crypto_bin to hpre_bin"); 630 return ret; 631 } 632 633 ret = hpre_bin_from_crypto_bin(msg->in, rsa_data->cipher.data, 634 msg->key_bytes, rsa_data->cipher.length); 635 if (ret) 636 EMSG("Fail to transfer rsa ciphertext from crypto_bin to hpre_bin"); 637 638 return ret; 639 } 640 641 static enum hisi_drv_status 642 hpre_rsa_ncrt_decrypt_bn2bin(struct hpre_rsa_msg *msg, 643 struct drvcrypt_rsa_ed *rsa_data) 644 { 645 struct rsa_keypair *key = rsa_data->key.key; 646 uint32_t d_len = 0; 647 uint32_t n_len = 0; 648 enum hisi_drv_status ret = HISI_QM_DRVCRYPT_NO_ERR; 649 uint8_t *n = NULL; 650 651 n = msg->prikey + msg->key_bytes; 652 653 crypto_bignum_bn2bin(key->d, msg->prikey); 654 crypto_bignum_bn2bin(key->n, n); 655 d_len = crypto_bignum_num_bytes(key->d); 656 n_len = crypto_bignum_num_bytes(key->n); 657 658 ret = hpre_bin_from_crypto_bin(msg->prikey, msg->prikey, msg->key_bytes, 659 d_len); 660 if (ret) { 661 EMSG("Fail to transfer rsa ncrt d from crypto_bin to hpre_bin"); 662 return ret; 663 } 664 665 ret = hpre_bin_from_crypto_bin(n, n, msg->key_bytes, n_len); 666 if (ret) { 667 EMSG("Fail to transfer rsa ncrt n from crypto_bin to hpre_bin"); 668 return ret; 669 } 670 671 ret = hpre_bin_from_crypto_bin(msg->in, rsa_data->cipher.data, 672 msg->key_bytes, rsa_data->cipher.length); 673 if (ret) 674 EMSG("Fail to transfer rsa ciphertext from crypto_bin to hpre_bin"); 675 676 return ret; 677 } 678 679 static bool hpre_rsa_is_crt_mod(struct rsa_keypair *key) 680 { 681 if (key->p && crypto_bignum_num_bits(key->p) && 682 key->q && crypto_bignum_num_bits(key->q) && 683 key->dp && crypto_bignum_num_bits(key->dp) && 684 key->dq && crypto_bignum_num_bits(key->dq) && 685 key->qp && crypto_bignum_num_bits(key->qp)) 686 return true; 687 688 return false; 689 } 690 691 static TEE_Result hpre_rsa_decrypt_init(struct hpre_rsa_msg *msg, 692 struct drvcrypt_rsa_ed *rsa_data) 693 { 694 struct rsa_keypair *key = rsa_data->key.key; 695 size_t n_bytes = rsa_data->key.n_size; 696 bool is_crt = false; 697 enum hisi_drv_status ret = HISI_QM_DRVCRYPT_NO_ERR; 698 699 msg->is_private = rsa_data->key.isprivate; 700 msg->key_bytes = hpre_rsa_get_hw_kbytes(BYTES_TO_BITS(n_bytes)); 701 if (!msg->key_bytes) 702 return TEE_ERROR_BAD_PARAMETERS; 703 704 is_crt = hpre_rsa_is_crt_mod(key); 705 if (is_crt) { 706 msg->alg_type = HPRE_ALG_NC_CRT; 707 ret = hpre_rsa_crt_decrypt_alloc(msg); 708 if (ret) 709 return TEE_ERROR_OUT_OF_MEMORY; 710 711 ret = hpre_rsa_crt_decrypt_bn2bin(msg, rsa_data); 712 if (ret) { 713 hpre_rsa_params_free(msg); 714 return TEE_ERROR_BAD_STATE; 715 } 716 } else { 717 msg->alg_type = HPRE_ALG_NC_NCRT; 718 ret = hpre_rsa_ncrt_decrypt_alloc(msg); 719 if (ret) 720 return TEE_ERROR_OUT_OF_MEMORY; 721 722 ret = hpre_rsa_ncrt_decrypt_bn2bin(msg, rsa_data); 723 if (ret) { 724 hpre_rsa_params_free(msg); 725 return TEE_ERROR_BAD_STATE; 726 } 727 } 728 729 return TEE_SUCCESS; 730 } 731 732 static TEE_Result rsa_nopad_decrypt(struct drvcrypt_rsa_ed *rsa_data) 733 { 734 size_t n_bytes = rsa_data->key.n_size; 735 struct hpre_rsa_msg msg = { }; 736 uint32_t offset = 0; 737 TEE_Result ret = TEE_SUCCESS; 738 uint8_t *pos = NULL; 739 740 if (rsa_data->cipher.length > n_bytes) { 741 EMSG("Invalid cipher length[%zu]", rsa_data->cipher.length); 742 return TEE_ERROR_BAD_PARAMETERS; 743 } 744 745 ret = hpre_rsa_decrypt_init(&msg, rsa_data); 746 if (ret) { 747 EMSG("Fail to init rsa msg"); 748 return ret; 749 } 750 751 ret = hpre_rsa_do_task(&msg); 752 if (ret) 753 goto decrypt_deinit; 754 755 pos = msg.out + msg.key_bytes - n_bytes; 756 if (rsa_data->rsa_id == DRVCRYPT_RSA_NOPAD) { 757 /* Plaintext can not have valid zero data in NOPAD MODE */ 758 while ((offset < n_bytes - 1) && (pos[offset] == 0)) 759 offset++; 760 } 761 762 rsa_data->message.length = n_bytes - offset; 763 memcpy(rsa_data->message.data, pos + offset, rsa_data->message.length); 764 765 decrypt_deinit: 766 hpre_rsa_params_free(&msg); 767 768 return ret; 769 } 770 771 static TEE_Result rsaes_pkcs_v1_5_decode(struct drvcrypt_rsa_ed *rsa_data, 772 uint8_t *out, size_t *out_len) 773 { 774 size_t em_len = rsa_data->message.length; 775 uint8_t *em = rsa_data->message.data; 776 size_t ps_len = 0; 777 size_t i = 0; 778 779 /* PKCS_V1.5 EM format 0x00 || 0x02 || PS non-zero || 0x00 || M */ 780 if (em_len < PKCS_V1_5_MSG_MIN_LEN || em[0] != 0 || 781 em[1] != ENCRYPT_PAD) { 782 EMSG("Invalid pkcs_v1.5 decode parameter"); 783 return TEE_ERROR_BAD_PARAMETERS; 784 } 785 786 for (i = PKCS_V1_5_PS_POS; i < em_len; i++) { 787 if (em[i] == 0) 788 break; 789 } 790 791 if (i >= em_len) { 792 EMSG("Fail to find zero pos"); 793 return TEE_ERROR_BAD_PARAMETERS; 794 } 795 796 ps_len = i - PKCS_V1_5_PS_POS; 797 if (em_len - ps_len - PKCS_V1_5_FIXED_LEN > *out_len || 798 ps_len < PKCS_V1_5_PS_MIN_LEN) { 799 EMSG("Invalid pkcs_v1.5 decode ps_len"); 800 return TEE_ERROR_BAD_PARAMETERS; 801 } 802 803 *out_len = em_len - ps_len - PKCS_V1_5_FIXED_LEN; 804 memcpy(out, em + ps_len + PKCS_V1_5_FIXED_LEN, *out_len); 805 806 return TEE_SUCCESS; 807 } 808 809 static TEE_Result rsa_pkcs_decrypt(struct drvcrypt_rsa_ed *rsa_data) 810 { 811 uint32_t n_bytes = rsa_data->key.n_size; 812 struct drvcrypt_rsa_ed rsa_dec_info = *rsa_data; 813 TEE_Result ret = TEE_SUCCESS; 814 815 /* Alloc pkcs_v1.5 encode message data buf */ 816 rsa_dec_info.message.data = malloc(n_bytes); 817 if (!rsa_dec_info.message.data) { 818 EMSG("Fail to alloc message data buf"); 819 return TEE_ERROR_OUT_OF_MEMORY; 820 } 821 822 rsa_dec_info.message.length = n_bytes; 823 ret = rsa_nopad_decrypt(&rsa_dec_info); 824 if (ret) 825 goto free_data; 826 827 ret = rsaes_pkcs_v1_5_decode(&rsa_dec_info, rsa_data->message.data, 828 &rsa_data->message.length); 829 if (ret) 830 EMSG("Fail to get pkcs_v1.5 decode message data"); 831 832 free_data: 833 free(rsa_dec_info.message.data); 834 835 return ret; 836 } 837 838 static TEE_Result rsa_oaep_get_seed(struct drvcrypt_rsa_ed *rsa_data, 839 uint8_t *mask_db, uint8_t *seed) 840 { 841 size_t db_len = rsa_data->key.n_size - rsa_data->digest_size - 1; 842 uint8_t mask_db_mgf[OAEP_MAX_HASH_LEN] = { }; 843 size_t lhash_len = rsa_data->digest_size; 844 uint8_t *mask_seed = NULL; 845 TEE_Result ret = TEE_SUCCESS; 846 847 mask_seed = rsa_data->message.data + 1; 848 849 ret = mgf_process(lhash_len, mask_db, db_len, mask_db_mgf, lhash_len, 850 rsa_data); 851 if (ret) { 852 EMSG("Fail to get mask_db mgf result"); 853 return ret; 854 } 855 856 return xor_process(mask_seed, mask_db_mgf, seed, lhash_len); 857 } 858 859 static TEE_Result rsa_oaep_get_db(struct drvcrypt_rsa_ed *rsa_data, 860 uint8_t *mask_db, uint8_t *seed, uint8_t *db) 861 { 862 size_t db_len = rsa_data->key.n_size - rsa_data->digest_size - 1; 863 size_t lhash_len = rsa_data->digest_size; 864 uint8_t seed_mgf[OAEP_MAX_DB_LEN] = { }; 865 TEE_Result ret = TEE_SUCCESS; 866 867 ret = mgf_process(lhash_len, seed, lhash_len, seed_mgf, db_len, 868 rsa_data); 869 if (ret) { 870 EMSG("Fail to get seed mgf result"); 871 return ret; 872 } 873 874 return xor_process(mask_db, seed_mgf, db, db_len); 875 } 876 877 static TEE_Result rsa_oaep_get_msg(struct drvcrypt_rsa_ed *rsa_data, 878 uint8_t *db, uint8_t *out, size_t *out_len) 879 { 880 size_t db_len = rsa_data->key.n_size - rsa_data->digest_size - 1; 881 size_t lhash_len = rsa_data->digest_size; 882 uint8_t hash[OAEP_MAX_HASH_LEN] = { }; 883 size_t msg_len = 0; 884 size_t lp_len = 0; 885 TEE_Result ret = TEE_SUCCESS; 886 887 /* oaep db format lhash || ps zero || 01 || M */ 888 ret = tee_hash_createdigest(rsa_data->hash_algo, rsa_data->label.data, 889 rsa_data->label.length, hash, lhash_len); 890 if (ret) { 891 EMSG("Fail to get label hash"); 892 return ret; 893 } 894 895 if (memcmp(hash, db, lhash_len)) { 896 EMSG("Hash is not equal"); 897 return TEE_ERROR_BAD_PARAMETERS; 898 } 899 900 for (lp_len = lhash_len; lp_len < db_len; lp_len++) { 901 if (db[lp_len] != 0) 902 break; 903 } 904 905 if (lp_len == db_len) { 906 EMSG("Fail to find fixed 01 in db"); 907 return TEE_ERROR_BAD_PARAMETERS; 908 } 909 910 msg_len = db_len - lp_len - 1; 911 if (msg_len > rsa_data->message.length) { 912 DMSG("Message space is not enough"); 913 *out_len = msg_len; 914 return TEE_ERROR_SHORT_BUFFER; 915 } 916 917 *out_len = msg_len; 918 memcpy(out, db + lp_len + 1, msg_len); 919 920 return TEE_SUCCESS; 921 } 922 923 static TEE_Result rsa_oaep_decode(struct drvcrypt_rsa_ed *rsa_data, 924 uint8_t *out, size_t *out_len) 925 { 926 size_t lhash_len = rsa_data->digest_size; 927 uint8_t seed[OAEP_MAX_HASH_LEN] = { }; 928 uint8_t db[OAEP_MAX_DB_LEN] = { }; 929 uint8_t *mask_db = NULL; 930 TEE_Result ret = TEE_SUCCESS; 931 932 /* oaep format 00 || maskedseed || maskeddb */ 933 mask_db = rsa_data->message.data + lhash_len + 1; 934 ret = rsa_oaep_get_seed(rsa_data, mask_db, seed); 935 if (ret) 936 return ret; 937 938 ret = rsa_oaep_get_db(rsa_data, mask_db, seed, db); 939 if (ret) 940 return ret; 941 942 return rsa_oaep_get_msg(rsa_data, db, out, out_len); 943 } 944 945 static TEE_Result rsa_oaep_decrypt(struct drvcrypt_rsa_ed *rsa_data) 946 { 947 size_t n_bytes = rsa_data->key.n_size; 948 struct drvcrypt_rsa_ed rsa_dec_info = *rsa_data; 949 TEE_Result ret = TEE_SUCCESS; 950 951 /* Alloc oaep encode message data buf */ 952 rsa_dec_info.message.data = malloc(n_bytes); 953 if (!rsa_dec_info.message.data) { 954 EMSG("Fail to alloc message data buf"); 955 return TEE_ERROR_OUT_OF_MEMORY; 956 } 957 958 rsa_dec_info.message.length = n_bytes; 959 ret = rsa_nopad_decrypt(&rsa_dec_info); 960 if (ret) 961 goto free_data; 962 963 ret = rsa_oaep_decode(&rsa_dec_info, rsa_data->message.data, 964 &rsa_data->message.length); 965 if (ret) 966 EMSG("Fail to get oaep decode message data"); 967 968 free_data: 969 free(rsa_dec_info.message.data); 970 971 return ret; 972 } 973 974 static TEE_Result hpre_rsa_decrypt(struct drvcrypt_rsa_ed *rsa_data) 975 { 976 if (!rsa_data) { 977 EMSG("Invalid rsa decrypt input parameter"); 978 return TEE_ERROR_BAD_PARAMETERS; 979 } 980 981 switch (rsa_data->rsa_id) { 982 case DRVCRYPT_RSA_NOPAD: 983 case DRVCRYPT_RSASSA_PKCS_V1_5: 984 case DRVCRYPT_RSASSA_PSS: 985 return rsa_nopad_decrypt(rsa_data); 986 case DRVCRYPT_RSA_PKCS_V1_5: 987 return rsa_pkcs_decrypt(rsa_data); 988 case DRVCRYPT_RSA_OAEP: 989 return rsa_oaep_decrypt(rsa_data); 990 default: 991 EMSG("Invalid rsa id"); 992 return TEE_ERROR_NOT_SUPPORTED; 993 } 994 } 995 996 static const struct drvcrypt_rsa driver_rsa = { 997 .alloc_keypair = sw_crypto_acipher_alloc_rsa_keypair, 998 .alloc_publickey = sw_crypto_acipher_alloc_rsa_public_key, 999 .free_publickey = sw_crypto_acipher_free_rsa_public_key, 1000 .free_keypair = sw_crypto_acipher_free_rsa_keypair, 1001 .gen_keypair = sw_crypto_acipher_gen_rsa_key, 1002 .encrypt = hpre_rsa_encrypt, 1003 .decrypt = hpre_rsa_decrypt, 1004 .optional = { 1005 /* 1006 * If ssa_sign or verify is NULL, the framework will fill 1007 * data format directly by soft calculation. Then call api 1008 * encrypt or decrypt. 1009 */ 1010 .ssa_sign = NULL, 1011 .ssa_verify = NULL, 1012 }, 1013 }; 1014 1015 static TEE_Result hpre_rsa_init(void) 1016 { 1017 TEE_Result ret = drvcrypt_register_rsa(&driver_rsa); 1018 1019 if (ret != TEE_SUCCESS) 1020 EMSG("hpre rsa register to crypto fail"); 1021 1022 return ret; 1023 } 1024 1025 driver_init(hpre_rsa_init); 1026