xref: /optee_os/core/drivers/crypto/crypto_api/acipher/rsassa.c (revision 04e46975d8f02e25209af552aaea4acb4d70c7f9)
1 // SPDX-License-Identifier: BSD-2-Clause
2 /*
3  * Copyright 2018-2020 NXP
4  *
5  * RSA Signature Software common implementation.
6  * Functions preparing and/or verifying the signature
7  * encoded string.
8  *
9  * PKCS #1 v2.1: RSA Cryptography Standard
10  * https://www.ietf.org/rfc/rfc3447.txt
11  */
12 #include <crypto/crypto.h>
13 #include <drvcrypt.h>
14 #include <drvcrypt_asn1_oid.h>
15 #include <drvcrypt_math.h>
16 #include <malloc.h>
17 #include <string.h>
18 #include <tee_api_defines_extensions.h>
19 #include <tee/tee_cryp_utl.h>
20 #include <utee_defines.h>
21 #include <util.h>
22 
23 #include "local.h"
24 
25 /*
26  * PKCS#1 V1.5 - Encode the message in Distinguished Encoding Rules
27  * (DER) format.
28  * Refer to EMSA-PKCS1-v1_5 chapter of the PKCS#1 v2.1
29  *
30  * @ssa_data  RSA data to encode
31  * @EM        [out] Encoded Message
32  */
emsa_pkcs1_v1_5_encode(struct drvcrypt_rsa_ssa * ssa_data,struct drvcrypt_buf * EM)33 static TEE_Result emsa_pkcs1_v1_5_encode(struct drvcrypt_rsa_ssa *ssa_data,
34 					 struct drvcrypt_buf *EM)
35 {
36 	const struct drvcrypt_oid *hash_oid = NULL;
37 	size_t ps_size = 0;
38 	uint8_t *buf = NULL;
39 
40 	hash_oid = drvcrypt_get_alg_hash_oid(ssa_data->hash_algo);
41 	if (!hash_oid)
42 		return TEE_ERROR_NOT_SUPPORTED;
43 
44 	/*
45 	 * Calculate the PS size
46 	 *  EM Size (modulus size) - 3 bytes - DigestInfo DER format size
47 	 */
48 	ps_size = ssa_data->key.n_size - 3;
49 	ps_size -= ssa_data->digest_size;
50 	ps_size -= 10 + hash_oid->asn1_length;
51 
52 	CRYPTO_TRACE("PS size = %zu (n %zu)", ps_size, ssa_data->key.n_size);
53 
54 	/*
55 	 * EM = 0x00 || 0x01 || PS || 0x00 || T
56 	 *
57 	 * where T represent the message DigestInfo in DER:
58 	 *    DigestInfo ::= SEQUENCE {
59 	 *                digestAlgorithm AlgorithmIdentifier,
60 	 *                digest OCTET STRING
61 	 *                }
62 	 *
63 	 * T  Length = digest length + oid length
64 	 * EM Length = T Length + 11 + PS Length
65 	 */
66 	buf = EM->data;
67 
68 	/* Set the EM first byte to 0x00 */
69 	*buf++ = 0x00;
70 
71 	/* Set the EM second byte to 0x01 */
72 	*buf++ = 0x01;
73 
74 	/* Fill PS with 0xFF */
75 	memset(buf, UINT8_MAX, ps_size);
76 	buf += ps_size;
77 
78 	/* Set the Byte after PS to 0x00 */
79 	*buf++ = 0x00;
80 
81 	/*
82 	 * Create the DigestInfo DER Sequence
83 	 *
84 	 *  DigestInfo ::= SEQUENCE {
85 	 *                digestAlgorithm AlgorithmIdentifier,
86 	 *                digest OCTET STRING
87 	 *                }
88 	 *
89 	 */
90 	/* SEQUENCE { */
91 	*buf++ = DRVCRYPT_ASN1_SEQUENCE | DRVCRYPT_ASN1_CONSTRUCTED;
92 	*buf++ = 0x08 + hash_oid->asn1_length + ssa_data->digest_size;
93 
94 	/* digestAlgorithm AlgorithmIdentifier */
95 	*buf++ = DRVCRYPT_ASN1_SEQUENCE | DRVCRYPT_ASN1_CONSTRUCTED;
96 	*buf++ = 0x04 + hash_oid->asn1_length;
97 	*buf++ = DRVCRYPT_ASN1_OID;
98 	*buf++ = hash_oid->asn1_length;
99 
100 	/* digest OCTET STRING */
101 	memcpy(buf, hash_oid->asn1, hash_oid->asn1_length);
102 	buf += hash_oid->asn1_length;
103 	*buf++ = DRVCRYPT_ASN1_NULL;
104 	*buf++ = 0x00;
105 	*buf++ = DRVCRYPT_ASN1_OCTET_STRING;
106 	*buf++ = ssa_data->digest_size;
107 	/* } */
108 
109 	memcpy(buf, ssa_data->message.data, ssa_data->digest_size);
110 
111 	CRYPTO_DUMPBUF("Encoded Message", EM->data, (size_t)EM->length);
112 
113 	return TEE_SUCCESS;
114 }
115 
116 /*
117  * PKCS#1 V1.5 - Encode the message in Distinguished Encoding Rules
118  * (DER) format.
119  * Refer to EMSA-PKCS1-v1_5 chapter of the PKCS#1 v2.1
120  *
121  * @ssa_data  RSA data to encode
122  * @EM        [out] Encoded Message
123  */
124 static TEE_Result
emsa_pkcs1_v1_5_encode_noasn1(struct drvcrypt_rsa_ssa * ssa_data,struct drvcrypt_buf * EM)125 emsa_pkcs1_v1_5_encode_noasn1(struct drvcrypt_rsa_ssa *ssa_data,
126 			      struct drvcrypt_buf *EM)
127 {
128 	size_t ps_size = 0;
129 	uint8_t *buf = NULL;
130 
131 	/*
132 	 * Calculate the PS size
133 	 *  EM Size (modulus size) - 3 bytes - Message Length
134 	 */
135 	ps_size = ssa_data->key.n_size - 3;
136 
137 	if (ps_size < ssa_data->message.length)
138 		return TEE_ERROR_BAD_PARAMETERS;
139 
140 	ps_size -= ssa_data->message.length;
141 
142 	CRYPTO_TRACE("PS size = %zu (n %zu)", ps_size, ssa_data->key.n_size);
143 
144 	/*
145 	 * EM = 0x00 || 0x01 || PS || 0x00 || T
146 	 *
147 	 * T  Length = message length
148 	 * EM Length = T Length + PS Length
149 	 */
150 	buf = EM->data;
151 
152 	/* Set the EM first byte to 0x00 */
153 	*buf++ = 0x00;
154 
155 	/* Set the EM second byte to 0x01 */
156 	*buf++ = 0x01;
157 
158 	/* Fill PS with 0xFF */
159 	memset(buf, UINT8_MAX, ps_size);
160 	buf += ps_size;
161 
162 	/* Set the Byte after PS to 0x00 */
163 	*buf++ = 0x00;
164 
165 	memcpy(buf, ssa_data->message.data, ssa_data->message.length);
166 
167 	CRYPTO_DUMPBUF("Encoded Message", EM->data, EM->length);
168 
169 	return TEE_SUCCESS;
170 }
171 
172 /*
173  * PKCS#1 V1.5 - Signature of RSA message and encodes the signature.
174  * Refer to RSASSA-PKCS1-v1_5 chapter of the PKCS#1 v2.1
175  *
176  * @ssa_data   [in/out] RSA data to sign / Signature
177  */
rsassa_pkcs1_v1_5_sign(struct drvcrypt_rsa_ssa * ssa_data)178 static TEE_Result rsassa_pkcs1_v1_5_sign(struct drvcrypt_rsa_ssa *ssa_data)
179 {
180 	TEE_Result ret = TEE_ERROR_BAD_PARAMETERS;
181 	struct drvcrypt_buf EM = { };
182 	struct drvcrypt_rsa_ed rsa_data = { };
183 	struct drvcrypt_rsa *rsa = NULL;
184 
185 	EM.length = ssa_data->key.n_size;
186 	EM.data = malloc(EM.length);
187 	if (!EM.data)
188 		return TEE_ERROR_OUT_OF_MEMORY;
189 
190 	/* Encode the Message */
191 	if (ssa_data->algo != TEE_ALG_RSASSA_PKCS1_V1_5)
192 		ret = emsa_pkcs1_v1_5_encode(ssa_data, &EM);
193 	else
194 		ret = emsa_pkcs1_v1_5_encode_noasn1(ssa_data, &EM);
195 
196 	if (ret != TEE_SUCCESS)
197 		goto out;
198 
199 	/*
200 	 * RSA Encrypt/Decrypt are doing the same operation except
201 	 * that decrypt takes a RSA Private key in parameter
202 	 */
203 	rsa_data.key.key = ssa_data->key.key;
204 	rsa_data.key.isprivate = true;
205 	rsa_data.key.n_size = ssa_data->key.n_size;
206 
207 	rsa = drvcrypt_get_ops(CRYPTO_RSA);
208 	if (!rsa) {
209 		ret = TEE_ERROR_NOT_IMPLEMENTED;
210 		goto out;
211 	}
212 
213 	/* Prepare the decryption data parameters */
214 	rsa_data.rsa_id = DRVCRYPT_RSASSA_PKCS_V1_5;
215 	rsa_data.message.data = ssa_data->signature.data;
216 	rsa_data.message.length = ssa_data->signature.length;
217 	rsa_data.cipher.data = EM.data;
218 	rsa_data.cipher.length = EM.length;
219 	rsa_data.hash_algo = ssa_data->hash_algo;
220 	rsa_data.algo = ssa_data->algo;
221 
222 	ret = rsa->decrypt(&rsa_data);
223 
224 	/* Set the message decrypted size */
225 	ssa_data->signature.length = rsa_data.message.length;
226 
227 out:
228 	free(EM.data);
229 
230 	return ret;
231 }
232 
233 /*
234  * PKCS#1 V1.5 - Verification of the RSA message's signature.
235  * Refer to RSASSA-PKCS1-v1_5 chapter of the PKCS#1 v2.1
236  *
237  * @ssa_data   [int/out] RSA data signed and encoded signature
238  */
rsassa_pkcs1_v1_5_verify(struct drvcrypt_rsa_ssa * ssa_data)239 static TEE_Result rsassa_pkcs1_v1_5_verify(struct drvcrypt_rsa_ssa *ssa_data)
240 {
241 	TEE_Result ret = TEE_ERROR_BAD_PARAMETERS;
242 	struct drvcrypt_buf EM = { };
243 	struct drvcrypt_buf EM_gen = { };
244 	struct drvcrypt_rsa_ed rsa_data = { };
245 	struct drvcrypt_rsa *rsa = NULL;
246 
247 	EM.length = ssa_data->key.n_size;
248 	EM.data = malloc(EM.length);
249 
250 	if (!EM.data) {
251 		ret = TEE_ERROR_OUT_OF_MEMORY;
252 		goto end_verify;
253 	}
254 
255 	EM_gen.length = ssa_data->key.n_size;
256 	EM_gen.data = malloc(EM.length);
257 
258 	if (!EM_gen.data) {
259 		ret = TEE_ERROR_OUT_OF_MEMORY;
260 		goto end_verify;
261 	}
262 
263 	/*
264 	 * RSA Encrypt/Decrypt are doing the same operation except
265 	 * that the encrypt takes a RSA Public key in parameter
266 	 */
267 	rsa_data.key.key = ssa_data->key.key;
268 	rsa_data.key.isprivate = false;
269 	rsa_data.key.n_size = ssa_data->key.n_size;
270 
271 	rsa = drvcrypt_get_ops(CRYPTO_RSA);
272 	if (rsa) {
273 		/* Prepare the encryption data parameters */
274 		rsa_data.rsa_id = DRVCRYPT_RSASSA_PKCS_V1_5;
275 		rsa_data.message.data = ssa_data->signature.data;
276 		rsa_data.message.length = ssa_data->signature.length;
277 		rsa_data.cipher.data = EM.data;
278 		rsa_data.cipher.length = EM.length;
279 		rsa_data.hash_algo = ssa_data->hash_algo;
280 		ret = rsa->encrypt(&rsa_data);
281 
282 		/* Set the cipher size */
283 		EM.length = rsa_data.cipher.length;
284 	} else {
285 		ret = TEE_ERROR_NOT_IMPLEMENTED;
286 	}
287 
288 	if (ret != TEE_SUCCESS)
289 		goto end_verify;
290 
291 	/* Encode the Message */
292 	if (ssa_data->algo != TEE_ALG_RSASSA_PKCS1_V1_5)
293 		ret = emsa_pkcs1_v1_5_encode(ssa_data, &EM_gen);
294 	else
295 		ret = emsa_pkcs1_v1_5_encode_noasn1(ssa_data, &EM_gen);
296 
297 	if (ret != TEE_SUCCESS)
298 		goto end_verify;
299 
300 	/* Check if EM decrypted and EM re-generated are identical */
301 	ret = TEE_ERROR_SIGNATURE_INVALID;
302 	if (EM.length == EM_gen.length) {
303 		if (!memcmp(EM.data, EM_gen.data, EM.length))
304 			ret = TEE_SUCCESS;
305 	}
306 
307 end_verify:
308 	free(EM.data);
309 	free(EM_gen.data);
310 
311 	return ret;
312 }
313 
314 /*
315  * PSS - Encode the message using a Probabilistic Signature Scheme (PSS)
316  * Refer to EMSA-PSS (encoding) chapter of the PKCS#1 v2.1
317  *
318  * @ssa_data  RSA data to encode
319  * @emBits    EM size in bits
320  * @EM        [out] Encoded Message
321  */
emsa_pss_encode(struct drvcrypt_rsa_ssa * ssa_data,size_t emBits,struct drvcrypt_buf * EM)322 static TEE_Result emsa_pss_encode(struct drvcrypt_rsa_ssa *ssa_data,
323 				  size_t emBits, struct drvcrypt_buf *EM)
324 {
325 	TEE_Result ret = TEE_ERROR_GENERIC;
326 	struct drvcrypt_rsa_mgf mgf_data = { };
327 	struct drvcrypt_buf hash = { };
328 	struct drvcrypt_buf dbMask = { };
329 	struct drvcrypt_buf DB = { };
330 	size_t db_size = 0;
331 	size_t ps_size = 0;
332 	size_t msg_size = 0;
333 	uint8_t *buf = NULL;
334 	uint8_t *msg_db = NULL;
335 	uint8_t *salt = NULL;
336 	struct drvcrypt_mod_op mod_op = { };
337 
338 	/*
339 	 * Build EM = maskedDB || H || 0xbc
340 	 *
341 	 * where
342 	 *    maskedDB = DB xor dbMask
343 	 *       DB     = PS || 0x01 || salt
344 	 *       dbMask = MGF(H, emLen - hLen - 1)
345 	 *
346 	 *    H  = Hash(M')
347 	 *       M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt
348 	 *
349 	 * PS size = emLen - sLen - hLen - 2 (may be = 0)
350 	 */
351 
352 	/*
353 	 * Calculate the M' and DB size to allocate a temporary buffer
354 	 * used for both object
355 	 */
356 	ps_size = EM->length - ssa_data->digest_size - ssa_data->salt_len - 2;
357 	db_size = EM->length - ssa_data->digest_size - 1;
358 	msg_size = 8 + ssa_data->digest_size + ssa_data->salt_len;
359 
360 	CRYPTO_TRACE("PS Len = %zu, DB Len = %zu, M' Len = %zu", ps_size,
361 		     db_size, msg_size);
362 
363 	msg_db = malloc(MAX(db_size, msg_size));
364 	if (!msg_db)
365 		return TEE_ERROR_OUT_OF_MEMORY;
366 
367 	if (ssa_data->salt_len) {
368 		salt = malloc(ssa_data->salt_len);
369 
370 		if (!salt) {
371 			ret = TEE_ERROR_OUT_OF_MEMORY;
372 			goto end_pss_encode;
373 		}
374 	}
375 
376 	/*
377 	 * Step 4 and 5
378 	 * Generate the M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt
379 	 *
380 	 * where
381 	 *   mHash is the input message (already hash)
382 	 *   salt is a random number of salt_len (input data) can be empty
383 	 */
384 	buf = msg_db;
385 
386 	memset(buf, 0, 8);
387 	buf += 8;
388 
389 	memcpy(buf, ssa_data->message.data, ssa_data->message.length);
390 	buf += ssa_data->message.length;
391 
392 	/* Get salt random number if salt length not 0 */
393 	if (ssa_data->salt_len) {
394 		ret = crypto_rng_read(salt, ssa_data->salt_len);
395 		CRYPTO_TRACE("Get salt of %zu bytes (ret = 0x%08" PRIx32 ")",
396 			     ssa_data->salt_len, ret);
397 		if (ret != TEE_SUCCESS)
398 			goto end_pss_encode;
399 
400 		memcpy(buf, salt, ssa_data->salt_len);
401 	}
402 
403 	/*
404 	 * Step 6
405 	 * Hash the M' generated new message
406 	 * H = hash(M')
407 	 */
408 	hash.data = &EM->data[db_size];
409 	hash.length = ssa_data->digest_size;
410 
411 	ret = tee_hash_createdigest(ssa_data->hash_algo, msg_db, msg_size,
412 				    hash.data, hash.length);
413 
414 	CRYPTO_TRACE("H = hash(M') returned 0x%08" PRIx32, ret);
415 	if (ret != TEE_SUCCESS)
416 		goto end_pss_encode;
417 
418 	CRYPTO_DUMPBUF("H = hash(M')", hash.data, hash.length);
419 
420 	/*
421 	 * Step 7 and 8
422 	 *   DB = PS || 0x01 || salt
423 	 */
424 	buf = msg_db;
425 	if (ps_size)
426 		memset(buf, 0, ps_size);
427 	buf += ps_size;
428 	*buf++ = 0x01;
429 
430 	if (ssa_data->salt_len)
431 		memcpy(buf, salt, ssa_data->salt_len);
432 
433 	DB.data = msg_db;
434 	DB.length = db_size;
435 
436 	CRYPTO_DUMPBUF("DB", DB.data, DB.length);
437 
438 	/*
439 	 * Step 9
440 	 * Generate a Mask of the seed value
441 	 * dbMask = MGF(H, emLen - hLen - 1)
442 	 *
443 	 * Note: Will use the same buffer for the dbMask and maskedDB
444 	 *       maskedDB is in the EM output
445 	 */
446 	dbMask.data = EM->data;
447 	dbMask.length = db_size;
448 
449 	mgf_data.hash_algo = ssa_data->hash_algo;
450 	mgf_data.digest_size = ssa_data->digest_size;
451 	mgf_data.seed.data = hash.data;
452 	mgf_data.seed.length = hash.length;
453 	mgf_data.mask.data = dbMask.data;
454 	mgf_data.mask.length = dbMask.length;
455 	ret = ssa_data->mgf(&mgf_data);
456 
457 	CRYPTO_TRACE("dbMask = MGF(H, emLen - hLen - 1) returned 0x%08" PRIx32,
458 		     ret);
459 	if (ret != TEE_SUCCESS)
460 		goto end_pss_encode;
461 
462 	CRYPTO_DUMPBUF("dbMask", dbMask.data, dbMask.length);
463 
464 	/*
465 	 * Step 10
466 	 * maskedDB = DB xor dbMask
467 	 */
468 	mod_op.n.length = dbMask.length;
469 	mod_op.a.data = DB.data;
470 	mod_op.a.length = DB.length;
471 	mod_op.b.data = dbMask.data;
472 	mod_op.b.length = dbMask.length;
473 	mod_op.result.data = dbMask.data;
474 	mod_op.result.length = dbMask.length;
475 
476 	ret = drvcrypt_xor_mod_n(&mod_op);
477 	CRYPTO_TRACE("maskedDB = DB xor dbMask returned 0x%08" PRIx32, ret);
478 	if (ret != TEE_SUCCESS)
479 		goto end_pss_encode;
480 
481 	CRYPTO_DUMPBUF("maskedDB", dbMask.data, dbMask.length);
482 
483 	/*
484 	 * Step 11
485 	 * Set the leftmost 8emLen - emBits of the leftmost octet
486 	 * in maskedDB to 0'
487 	 */
488 	EM->data[0] &= (UINT8_MAX >> ((EM->length * 8) - emBits));
489 
490 	/*
491 	 * Step 12
492 	 * EM = maskedDB || H || 0xbc
493 	 */
494 	EM->data[EM->length - 1] = 0xbc;
495 
496 	CRYPTO_DUMPBUF("EM", EM->data, EM->length);
497 
498 	ret = TEE_SUCCESS;
499 end_pss_encode:
500 	free(msg_db);
501 	free(salt);
502 
503 	return ret;
504 }
505 
506 /*
507  * PSS - Verify the message using a Probabilistic Signature Scheme (PSS)
508  * Refer to EMSA-PSS (verification) chapter of the PKCS#1 v2.1
509  *
510  * @ssa_data  RSA data to encode
511  * @emBits    EM size in bits
512  * @EM        [out] Encoded Message
513  */
emsa_pss_verify(struct drvcrypt_rsa_ssa * ssa_data,size_t emBits,struct drvcrypt_buf * EM)514 static TEE_Result emsa_pss_verify(struct drvcrypt_rsa_ssa *ssa_data,
515 				  size_t emBits, struct drvcrypt_buf *EM)
516 {
517 	TEE_Result ret = TEE_ERROR_GENERIC;
518 	struct drvcrypt_rsa_mgf mgf_data = { };
519 	struct drvcrypt_buf hash = { };
520 	struct drvcrypt_buf hash_gen = { };
521 	size_t db_size = 0;
522 	size_t ps_size = 0;
523 	size_t msg_size = 0;
524 	uint8_t *msg_db = NULL;
525 	uint8_t *salt = NULL;
526 	uint8_t *buf = NULL;
527 	struct drvcrypt_mod_op mod_op = { };
528 
529 	/*
530 	 * EM = maskedDB || H || 0xbc
531 	 *
532 	 * where
533 	 *    maskedDB = DB xor dbMask
534 	 *       DB     = PS || 0x01 || salt
535 	 *       dbMask = MGF(H, emLen - hLen - 1)
536 	 *
537 	 *    H  = Hash(M')
538 	 *       M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt
539 	 *
540 	 * PS size = emLen - sLen - hLen - 2 (may be = 0)
541 	 */
542 
543 	/*
544 	 * Calculate the M' and DB size to allocate a temporary buffer
545 	 * used for both object
546 	 */
547 	ps_size = EM->length - ssa_data->digest_size - ssa_data->salt_len - 2;
548 	db_size = EM->length - ssa_data->digest_size - 1;
549 	msg_size = 8 + ssa_data->digest_size + ssa_data->salt_len;
550 
551 	CRYPTO_TRACE("PS Len = %zu, DB Len = %zu, M' Len = %zu", ps_size,
552 		     db_size, msg_size);
553 
554 	msg_db = malloc(MAX(db_size, msg_size));
555 	if (!msg_db)
556 		return TEE_ERROR_OUT_OF_MEMORY;
557 
558 	/*
559 	 * Step 4
560 	 * Check if rightmost octet of EM is 0xbc
561 	 */
562 	if (EM->data[EM->length - 1] != 0xbc) {
563 		CRYPTO_TRACE("rigthmost octet != 0xbc (0x%" PRIX8 ")",
564 			     EM->data[EM->length - 1]);
565 		ret = TEE_ERROR_SIGNATURE_INVALID;
566 		goto end_pss_verify;
567 	}
568 
569 	/*
570 	 * Step 6
571 	 * Check if the leftmost 8emLen - emBits of the leftmost octet
572 	 * in maskedDB are 0's
573 	 */
574 	if (EM->data[0] & ~(UINT8_MAX >> (EM->length * 8 - emBits))) {
575 		CRYPTO_TRACE("Error leftmost octet maskedDB not 0's");
576 		CRYPTO_TRACE("EM[0] = 0x%" PRIX8
577 			     " - EM Len = %zu, emBits = %zu",
578 			     EM->data[0], EM->length, emBits);
579 		ret = TEE_ERROR_SIGNATURE_INVALID;
580 		goto end_pss_verify;
581 	}
582 
583 	hash.data = &EM->data[db_size];
584 	hash.length = ssa_data->digest_size;
585 
586 	/*
587 	 * Step 7
588 	 * dbMask = MGF(H, emLen - hLen - 1)
589 	 *
590 	 * Note: Will use the same buffer for the dbMask and DB
591 	 */
592 	mgf_data.hash_algo = ssa_data->hash_algo;
593 	mgf_data.digest_size = ssa_data->digest_size;
594 	mgf_data.seed.data = hash.data;
595 	mgf_data.seed.length = hash.length;
596 	mgf_data.mask.data = msg_db;
597 	mgf_data.mask.length = db_size;
598 	ret = ssa_data->mgf(&mgf_data);
599 
600 	CRYPTO_TRACE("dbMask = MGF(H, emLen - hLen - 1) returned 0x%08" PRIx32,
601 		     ret);
602 	if (ret != TEE_SUCCESS)
603 		goto end_pss_verify;
604 
605 	CRYPTO_DUMPBUF("dbMask", msg_db, db_size);
606 
607 	/*
608 	 * Step 8
609 	 * DB = maskedDB xor dbMask
610 	 *
611 	 *
612 	 * Note: maskedDB is in the EM input
613 	 */
614 	mod_op.n.length = db_size;
615 	mod_op.a.data = EM->data;
616 	mod_op.a.length = db_size;
617 	mod_op.b.data = msg_db;
618 	mod_op.b.length = db_size;
619 	mod_op.result.data = msg_db;
620 	mod_op.result.length = db_size;
621 
622 	ret = drvcrypt_xor_mod_n(&mod_op);
623 	CRYPTO_TRACE("DB = maskedDB xor dbMask returned 0x%08" PRIx32, ret);
624 	if (ret != TEE_SUCCESS)
625 		goto end_pss_verify;
626 
627 	/*
628 	 * Step 9
629 	 * Set the leftmost 8emLen - emBits of the leftmost octet in
630 	 * DB to zero
631 	 */
632 	*msg_db &= UINT8_MAX >> (EM->length * 8 - emBits);
633 
634 	CRYPTO_DUMPBUF("DB", msg_db, db_size);
635 
636 	/*
637 	 * Step 10
638 	 * Expected to have
639 	 *       DB     = PS || 0x01 || salt
640 	 *
641 	 * PS must be 0
642 	 * PS size = emLen - sLen - hLen - 2 (may be = 0)
643 	 */
644 	buf = msg_db;
645 	while (buf < msg_db + ps_size) {
646 		if (*buf++ != 0) {
647 			ret = TEE_ERROR_SIGNATURE_INVALID;
648 			goto end_pss_verify;
649 		}
650 	}
651 
652 	if (*buf++ != 0x01) {
653 		ret = TEE_ERROR_SIGNATURE_INVALID;
654 		goto end_pss_verify;
655 	}
656 
657 	/*
658 	 * Step 11
659 	 * Get the salt value
660 	 */
661 	if (ssa_data->salt_len) {
662 		salt = malloc(ssa_data->salt_len);
663 		if (!salt) {
664 			ret = TEE_ERROR_OUT_OF_MEMORY;
665 			goto end_pss_verify;
666 		}
667 
668 		memcpy(salt, buf, ssa_data->salt_len);
669 	}
670 
671 	/*
672 	 * Step 12
673 	 * Generate the M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt
674 	 *
675 	 * where
676 	 *   mHash is the input message (already hash)
677 	 *   salt is a random number of salt_len (input data) can be empty
678 	 */
679 	buf = msg_db;
680 
681 	memset(buf, 0, 8);
682 	buf += 8;
683 
684 	memcpy(buf, ssa_data->message.data, ssa_data->message.length);
685 	buf += ssa_data->message.length;
686 
687 	if (ssa_data->salt_len)
688 		memcpy(buf, salt, ssa_data->salt_len);
689 
690 	/*
691 	 * Step 13
692 	 * Hash the M' generated new message
693 	 * H' = hash(M')
694 	 *
695 	 * Note: reuse the msg_db buffer as Hash result
696 	 */
697 	hash_gen.data = msg_db;
698 	hash_gen.length = ssa_data->digest_size;
699 
700 	ret = tee_hash_createdigest(ssa_data->hash_algo, msg_db, msg_size,
701 				    hash_gen.data, hash_gen.length);
702 
703 	CRYPTO_TRACE("H' = hash(M') returned 0x%08" PRIx32, ret);
704 	if (ret != TEE_SUCCESS)
705 		goto end_pss_verify;
706 
707 	CRYPTO_DUMPBUF("H' = hash(M')", hash_gen.data, hash_gen.length);
708 
709 	/*
710 	 * Step 14
711 	 * Compare H and H'
712 	 */
713 	ret = TEE_ERROR_SIGNATURE_INVALID;
714 	if (!memcmp(hash_gen.data, hash.data, hash_gen.length))
715 		ret = TEE_SUCCESS;
716 
717 end_pss_verify:
718 	free(msg_db);
719 	free(salt);
720 
721 	return ret;
722 }
723 
724 /*
725  * PSS - Signature of RSA message and encodes the signature.
726  * Refer to RSASSA-PSS chapter of the PKCS#1 v2.1
727  *
728  * @ssa_data   [in/out] RSA data to sign / Signature
729  */
rsassa_pss_sign(struct drvcrypt_rsa_ssa * ssa_data)730 static TEE_Result rsassa_pss_sign(struct drvcrypt_rsa_ssa *ssa_data)
731 {
732 	TEE_Result ret = TEE_ERROR_GENERIC;
733 	struct rsa_keypair *key = NULL;
734 	struct drvcrypt_buf EM = { };
735 	size_t modBits = 0;
736 	struct drvcrypt_rsa_ed rsa_data = { };
737 	struct drvcrypt_rsa *rsa = NULL;
738 
739 	key = ssa_data->key.key;
740 
741 	/* Get modulus length in bits */
742 	modBits = crypto_bignum_num_bits(key->n);
743 	if (modBits <= 0)
744 		return TEE_ERROR_BAD_PARAMETERS;
745 
746 	/*
747 	 * EM Length = (modBits - 1) / 8
748 	 * if (modBits - 1) is not divisible by 8, one more byte is needed
749 	 */
750 	modBits--;
751 	EM.length = ROUNDUP_DIV(modBits, 8);
752 
753 	if (EM.length < ssa_data->digest_size + ssa_data->salt_len + 2)
754 		return TEE_ERROR_BAD_PARAMETERS;
755 
756 	EM.data = malloc(EM.length);
757 	if (!EM.data)
758 		return TEE_ERROR_OUT_OF_MEMORY;
759 
760 	CRYPTO_TRACE("modBits = %zu, hence EM Length = %zu", modBits + 1,
761 		     EM.length);
762 
763 	/* Encode the Message */
764 	ret = emsa_pss_encode(ssa_data, modBits, &EM);
765 	CRYPTO_TRACE("EMSA PSS Encode returned 0x%08" PRIx32, ret);
766 
767 	/*
768 	 * RSA Encrypt/Decrypt are doing the same operation
769 	 * expect that the decrypt takes a RSA Private key in parameter
770 	 */
771 	if (ret == TEE_SUCCESS) {
772 		rsa_data.key.key = ssa_data->key.key;
773 		rsa_data.key.isprivate = true;
774 		rsa_data.key.n_size = ssa_data->key.n_size;
775 
776 		rsa = drvcrypt_get_ops(CRYPTO_RSA);
777 		if (rsa) {
778 			/* Prepare the decryption data parameters */
779 			rsa_data.rsa_id = DRVCRYPT_RSASSA_PSS;
780 			rsa_data.message.data = ssa_data->signature.data;
781 			rsa_data.message.length = ssa_data->signature.length;
782 			rsa_data.cipher.data = EM.data;
783 			rsa_data.cipher.length = EM.length;
784 			rsa_data.algo = ssa_data->algo;
785 
786 			ret = rsa->decrypt(&rsa_data);
787 
788 			/* Set the message decrypted size */
789 			ssa_data->signature.length = rsa_data.message.length;
790 		} else {
791 			ret = TEE_ERROR_NOT_IMPLEMENTED;
792 		}
793 	}
794 	free(EM.data);
795 
796 	return ret;
797 }
798 
799 /*
800  * PSS - Signature verification of RSA message.
801  * Refer to RSASSA-PSS chapter of the PKCS#1 v2.1
802  *
803  * @ssa_data   [in/out] RSA Signature vs. message to verify
804  */
rsassa_pss_verify(struct drvcrypt_rsa_ssa * ssa_data)805 static TEE_Result rsassa_pss_verify(struct drvcrypt_rsa_ssa *ssa_data)
806 {
807 	TEE_Result ret = TEE_ERROR_GENERIC;
808 	struct rsa_public_key *key = NULL;
809 	struct drvcrypt_buf EM = { };
810 	size_t modBits = 0;
811 	struct drvcrypt_rsa_ed rsa_data = { };
812 	struct drvcrypt_rsa *rsa = NULL;
813 
814 	key = ssa_data->key.key;
815 
816 	/* Get modulus length in bits */
817 	modBits = crypto_bignum_num_bits(key->n);
818 	if (modBits <= 0)
819 		return TEE_ERROR_BAD_PARAMETERS;
820 
821 	/*
822 	 * EM Length = (modBits - 1) / 8
823 	 * if (modBits - 1) is not divisible by 8, one more byte is needed
824 	 */
825 	modBits--;
826 	EM.length = ROUNDUP_DIV(modBits, 8);
827 
828 	if (EM.length < ssa_data->digest_size + ssa_data->salt_len + 2)
829 		return TEE_ERROR_BAD_PARAMETERS;
830 
831 	EM.data = malloc(EM.length);
832 	if (!EM.data)
833 		return TEE_ERROR_OUT_OF_MEMORY;
834 
835 	CRYPTO_TRACE("modBits = %zu, hence EM Length = %zu", modBits + 1,
836 		     EM.length);
837 
838 	/*
839 	 * RSA Encrypt/Decrypt are doing the same operation
840 	 * expect that the encrypt takes a RSA Public key in parameter
841 	 */
842 	rsa_data.key.key = ssa_data->key.key;
843 	rsa_data.key.isprivate = false;
844 	rsa_data.key.n_size = ssa_data->key.n_size;
845 
846 	rsa = drvcrypt_get_ops(CRYPTO_RSA);
847 	if (rsa) {
848 		/* Prepare the encryption data parameters */
849 		rsa_data.rsa_id = DRVCRYPT_RSASSA_PSS;
850 		rsa_data.message.data = ssa_data->signature.data;
851 		rsa_data.message.length = ssa_data->signature.length;
852 		rsa_data.cipher.data = EM.data;
853 		rsa_data.cipher.length = EM.length;
854 		rsa_data.algo = ssa_data->algo;
855 
856 		ret = rsa->encrypt(&rsa_data);
857 
858 		/* Set the cipher size */
859 		EM.length = rsa_data.cipher.length;
860 	} else {
861 		ret = TEE_ERROR_NOT_IMPLEMENTED;
862 		goto end_pss_verify;
863 	}
864 
865 	if (ret == TEE_SUCCESS) {
866 		/* Verify the Message */
867 		ret = emsa_pss_verify(ssa_data, modBits, &EM);
868 		CRYPTO_TRACE("EMSA PSS Verify returned 0x%08" PRIx32, ret);
869 	} else {
870 		CRYPTO_TRACE("RSA NO PAD returned 0x%08" PRIx32, ret);
871 		ret = TEE_ERROR_SIGNATURE_INVALID;
872 	}
873 
874 end_pss_verify:
875 	free(EM.data);
876 
877 	return ret;
878 }
879 
drvcrypt_rsassa_sign(struct drvcrypt_rsa_ssa * ssa_data)880 TEE_Result drvcrypt_rsassa_sign(struct drvcrypt_rsa_ssa *ssa_data)
881 {
882 	switch (ssa_data->algo) {
883 	case TEE_ALG_RSASSA_PKCS1_V1_5:
884 	case TEE_ALG_RSASSA_PKCS1_V1_5_MD5:
885 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA1:
886 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA224:
887 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA256:
888 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA384:
889 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA512:
890 		return rsassa_pkcs1_v1_5_sign(ssa_data);
891 
892 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_MD5:
893 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA1:
894 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA224:
895 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA256:
896 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA384:
897 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA512:
898 		return rsassa_pss_sign(ssa_data);
899 
900 	default:
901 		break;
902 	}
903 
904 	return TEE_ERROR_BAD_PARAMETERS;
905 }
906 
drvcrypt_rsassa_verify(struct drvcrypt_rsa_ssa * ssa_data)907 TEE_Result drvcrypt_rsassa_verify(struct drvcrypt_rsa_ssa *ssa_data)
908 {
909 	switch (ssa_data->algo) {
910 	case TEE_ALG_RSASSA_PKCS1_V1_5:
911 	case TEE_ALG_RSASSA_PKCS1_V1_5_MD5:
912 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA1:
913 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA224:
914 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA256:
915 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA384:
916 	case TEE_ALG_RSASSA_PKCS1_V1_5_SHA512:
917 		return rsassa_pkcs1_v1_5_verify(ssa_data);
918 
919 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_MD5:
920 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA1:
921 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA224:
922 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA256:
923 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA384:
924 	case TEE_ALG_RSASSA_PKCS1_PSS_MGF1_SHA512:
925 		return rsassa_pss_verify(ssa_data);
926 
927 	default:
928 		break;
929 	}
930 
931 	return TEE_ERROR_BAD_PARAMETERS;
932 }
933