xref: /OK3568_Linux_fs/u-boot/lib/rsa/rsa-mod-exp.c (revision 4882a59341e53eb6f0b4789bf948001014eff981)
1*4882a593Smuzhiyun /*
2*4882a593Smuzhiyun  * Copyright (c) 2013, Google Inc.
3*4882a593Smuzhiyun  *
4*4882a593Smuzhiyun  * SPDX-License-Identifier:	GPL-2.0+
5*4882a593Smuzhiyun  */
6*4882a593Smuzhiyun 
7*4882a593Smuzhiyun #ifndef USE_HOSTCC
8*4882a593Smuzhiyun #include <common.h>
9*4882a593Smuzhiyun #include <fdtdec.h>
10*4882a593Smuzhiyun #include <asm/types.h>
11*4882a593Smuzhiyun #include <asm/byteorder.h>
12*4882a593Smuzhiyun #include <linux/errno.h>
13*4882a593Smuzhiyun #include <asm/types.h>
14*4882a593Smuzhiyun #include <asm/unaligned.h>
15*4882a593Smuzhiyun #else
16*4882a593Smuzhiyun #include "fdt_host.h"
17*4882a593Smuzhiyun #include "mkimage.h"
18*4882a593Smuzhiyun #include <fdt_support.h>
19*4882a593Smuzhiyun #endif
20*4882a593Smuzhiyun #include <u-boot/rsa.h>
21*4882a593Smuzhiyun #include <u-boot/rsa-mod-exp.h>
22*4882a593Smuzhiyun 
/* Multiply v (widened to 64 bits) by a multiplier truncated to 32 bits */
#define UINT64_MULT32(v, multby)  (((uint64_t)(v)) * ((uint32_t)(multby)))

/*
 * Local big-endian word accessors. Despite their names they perform a plain
 * 32-bit load/store through a uint32_t pointer, so the address must be
 * 4-byte aligned. They convert between FDT (big-endian) and CPU byte order.
 */
#define get_unaligned_be32(a) fdt32_to_cpu(*(const uint32_t *)(a))
#define put_unaligned_be32(a, b) (*(uint32_t *)(b) = cpu_to_fdt32(a))

/* Default public exponent for backward compatibility */
#define RSA_DEFAULT_PUBEXP	65537
30*4882a593Smuzhiyun 
31*4882a593Smuzhiyun /**
32*4882a593Smuzhiyun  * subtract_modulus() - subtract modulus from the given value
33*4882a593Smuzhiyun  *
34*4882a593Smuzhiyun  * @key:	Key containing modulus to subtract
35*4882a593Smuzhiyun  * @num:	Number to subtract modulus from, as little endian word array
36*4882a593Smuzhiyun  */
subtract_modulus(const struct rsa_public_key * key,uint32_t num[])37*4882a593Smuzhiyun static void subtract_modulus(const struct rsa_public_key *key, uint32_t num[])
38*4882a593Smuzhiyun {
39*4882a593Smuzhiyun 	int64_t acc = 0;
40*4882a593Smuzhiyun 	uint i;
41*4882a593Smuzhiyun 
42*4882a593Smuzhiyun 	for (i = 0; i < key->len; i++) {
43*4882a593Smuzhiyun 		acc += (uint64_t)num[i] - key->modulus[i];
44*4882a593Smuzhiyun 		num[i] = (uint32_t)acc;
45*4882a593Smuzhiyun 		acc >>= 32;
46*4882a593Smuzhiyun 	}
47*4882a593Smuzhiyun }
48*4882a593Smuzhiyun 
49*4882a593Smuzhiyun /**
50*4882a593Smuzhiyun  * greater_equal_modulus() - check if a value is >= modulus
51*4882a593Smuzhiyun  *
52*4882a593Smuzhiyun  * @key:	Key containing modulus to check
53*4882a593Smuzhiyun  * @num:	Number to check against modulus, as little endian word array
54*4882a593Smuzhiyun  * @return 0 if num < modulus, 1 if num >= modulus
55*4882a593Smuzhiyun  */
greater_equal_modulus(const struct rsa_public_key * key,uint32_t num[])56*4882a593Smuzhiyun static int greater_equal_modulus(const struct rsa_public_key *key,
57*4882a593Smuzhiyun 				 uint32_t num[])
58*4882a593Smuzhiyun {
59*4882a593Smuzhiyun 	int i;
60*4882a593Smuzhiyun 
61*4882a593Smuzhiyun 	for (i = (int)key->len - 1; i >= 0; i--) {
62*4882a593Smuzhiyun 		if (num[i] < key->modulus[i])
63*4882a593Smuzhiyun 			return 0;
64*4882a593Smuzhiyun 		if (num[i] > key->modulus[i])
65*4882a593Smuzhiyun 			return 1;
66*4882a593Smuzhiyun 	}
67*4882a593Smuzhiyun 
68*4882a593Smuzhiyun 	return 1;  /* equal */
69*4882a593Smuzhiyun }
70*4882a593Smuzhiyun 
71*4882a593Smuzhiyun /**
72*4882a593Smuzhiyun  * montgomery_mul_add_step() - Perform montgomery multiply-add step
73*4882a593Smuzhiyun  *
74*4882a593Smuzhiyun  * Operation: montgomery result[] += a * b[] / n0inv % modulus
75*4882a593Smuzhiyun  *
76*4882a593Smuzhiyun  * @key:	RSA key
77*4882a593Smuzhiyun  * @result:	Place to put result, as little endian word array
78*4882a593Smuzhiyun  * @a:		Multiplier
79*4882a593Smuzhiyun  * @b:		Multiplicand, as little endian word array
80*4882a593Smuzhiyun  */
montgomery_mul_add_step(const struct rsa_public_key * key,uint32_t result[],const uint32_t a,const uint32_t b[])81*4882a593Smuzhiyun static void montgomery_mul_add_step(const struct rsa_public_key *key,
82*4882a593Smuzhiyun 		uint32_t result[], const uint32_t a, const uint32_t b[])
83*4882a593Smuzhiyun {
84*4882a593Smuzhiyun 	uint64_t acc_a, acc_b;
85*4882a593Smuzhiyun 	uint32_t d0;
86*4882a593Smuzhiyun 	uint i;
87*4882a593Smuzhiyun 
88*4882a593Smuzhiyun 	acc_a = (uint64_t)a * b[0] + result[0];
89*4882a593Smuzhiyun 	d0 = (uint32_t)acc_a * key->n0inv;
90*4882a593Smuzhiyun 	acc_b = (uint64_t)d0 * key->modulus[0] + (uint32_t)acc_a;
91*4882a593Smuzhiyun 	for (i = 1; i < key->len; i++) {
92*4882a593Smuzhiyun 		acc_a = (acc_a >> 32) + (uint64_t)a * b[i] + result[i];
93*4882a593Smuzhiyun 		acc_b = (acc_b >> 32) + (uint64_t)d0 * key->modulus[i] +
94*4882a593Smuzhiyun 				(uint32_t)acc_a;
95*4882a593Smuzhiyun 		result[i - 1] = (uint32_t)acc_b;
96*4882a593Smuzhiyun 	}
97*4882a593Smuzhiyun 
98*4882a593Smuzhiyun 	acc_a = (acc_a >> 32) + (acc_b >> 32);
99*4882a593Smuzhiyun 
100*4882a593Smuzhiyun 	result[i - 1] = (uint32_t)acc_a;
101*4882a593Smuzhiyun 
102*4882a593Smuzhiyun 	if (acc_a >> 32)
103*4882a593Smuzhiyun 		subtract_modulus(key, result);
104*4882a593Smuzhiyun }
105*4882a593Smuzhiyun 
106*4882a593Smuzhiyun /**
107*4882a593Smuzhiyun  * montgomery_mul() - Perform montgomery mutitply
108*4882a593Smuzhiyun  *
109*4882a593Smuzhiyun  * Operation: montgomery result[] = a[] * b[] / n0inv % modulus
110*4882a593Smuzhiyun  *
111*4882a593Smuzhiyun  * @key:	RSA key
112*4882a593Smuzhiyun  * @result:	Place to put result, as little endian word array
113*4882a593Smuzhiyun  * @a:		Multiplier, as little endian word array
114*4882a593Smuzhiyun  * @b:		Multiplicand, as little endian word array
115*4882a593Smuzhiyun  */
montgomery_mul(const struct rsa_public_key * key,uint32_t result[],uint32_t a[],const uint32_t b[])116*4882a593Smuzhiyun static void montgomery_mul(const struct rsa_public_key *key,
117*4882a593Smuzhiyun 		uint32_t result[], uint32_t a[], const uint32_t b[])
118*4882a593Smuzhiyun {
119*4882a593Smuzhiyun 	uint i;
120*4882a593Smuzhiyun 
121*4882a593Smuzhiyun 	for (i = 0; i < key->len; ++i)
122*4882a593Smuzhiyun 		result[i] = 0;
123*4882a593Smuzhiyun 	for (i = 0; i < key->len; ++i)
124*4882a593Smuzhiyun 		montgomery_mul_add_step(key, result, a[i], b);
125*4882a593Smuzhiyun }
126*4882a593Smuzhiyun 
127*4882a593Smuzhiyun /**
128*4882a593Smuzhiyun  * num_pub_exponent_bits() - Number of bits in the public exponent
129*4882a593Smuzhiyun  *
130*4882a593Smuzhiyun  * @key:	RSA key
131*4882a593Smuzhiyun  * @num_bits:	Storage for the number of public exponent bits
132*4882a593Smuzhiyun  */
num_public_exponent_bits(const struct rsa_public_key * key,int * num_bits)133*4882a593Smuzhiyun static int num_public_exponent_bits(const struct rsa_public_key *key,
134*4882a593Smuzhiyun 		int *num_bits)
135*4882a593Smuzhiyun {
136*4882a593Smuzhiyun 	uint64_t exponent;
137*4882a593Smuzhiyun 	int exponent_bits;
138*4882a593Smuzhiyun 	const uint max_bits = (sizeof(exponent) * 8);
139*4882a593Smuzhiyun 
140*4882a593Smuzhiyun 	exponent = key->exponent;
141*4882a593Smuzhiyun 	exponent_bits = 0;
142*4882a593Smuzhiyun 
143*4882a593Smuzhiyun 	if (!exponent) {
144*4882a593Smuzhiyun 		*num_bits = exponent_bits;
145*4882a593Smuzhiyun 		return 0;
146*4882a593Smuzhiyun 	}
147*4882a593Smuzhiyun 
148*4882a593Smuzhiyun 	for (exponent_bits = 1; exponent_bits < max_bits + 1; ++exponent_bits)
149*4882a593Smuzhiyun 		if (!(exponent >>= 1)) {
150*4882a593Smuzhiyun 			*num_bits = exponent_bits;
151*4882a593Smuzhiyun 			return 0;
152*4882a593Smuzhiyun 		}
153*4882a593Smuzhiyun 
154*4882a593Smuzhiyun 	return -EINVAL;
155*4882a593Smuzhiyun }
156*4882a593Smuzhiyun 
157*4882a593Smuzhiyun /**
158*4882a593Smuzhiyun  * is_public_exponent_bit_set() - Check if a bit in the public exponent is set
159*4882a593Smuzhiyun  *
160*4882a593Smuzhiyun  * @key:	RSA key
161*4882a593Smuzhiyun  * @pos:	The bit position to check
162*4882a593Smuzhiyun  */
is_public_exponent_bit_set(const struct rsa_public_key * key,int pos)163*4882a593Smuzhiyun static int is_public_exponent_bit_set(const struct rsa_public_key *key,
164*4882a593Smuzhiyun 		int pos)
165*4882a593Smuzhiyun {
166*4882a593Smuzhiyun 	return key->exponent & (1ULL << pos);
167*4882a593Smuzhiyun }
168*4882a593Smuzhiyun 
169*4882a593Smuzhiyun /**
170*4882a593Smuzhiyun  * pow_mod() - in-place public exponentiation
171*4882a593Smuzhiyun  *
172*4882a593Smuzhiyun  * @key:	RSA key
173*4882a593Smuzhiyun  * @inout:	Big-endian word array containing value and result
174*4882a593Smuzhiyun  */
pow_mod(const struct rsa_public_key * key,uint32_t * inout)175*4882a593Smuzhiyun static int pow_mod(const struct rsa_public_key *key, uint32_t *inout)
176*4882a593Smuzhiyun {
177*4882a593Smuzhiyun 	uint32_t *result, *ptr;
178*4882a593Smuzhiyun 	uint i;
179*4882a593Smuzhiyun 	int j, k;
180*4882a593Smuzhiyun 
181*4882a593Smuzhiyun 	/* Sanity check for stack size - key->len is in 32-bit words */
182*4882a593Smuzhiyun 	if (key->len > RSA_MAX_KEY_BITS / 32) {
183*4882a593Smuzhiyun 		debug("RSA key words %u exceeds maximum %d\n", key->len,
184*4882a593Smuzhiyun 		      RSA_MAX_KEY_BITS / 32);
185*4882a593Smuzhiyun 		return -EINVAL;
186*4882a593Smuzhiyun 	}
187*4882a593Smuzhiyun 
188*4882a593Smuzhiyun 	uint32_t val[key->len], acc[key->len], tmp[key->len];
189*4882a593Smuzhiyun 	uint32_t a_scaled[key->len];
190*4882a593Smuzhiyun 	result = tmp;  /* Re-use location. */
191*4882a593Smuzhiyun 
192*4882a593Smuzhiyun 	/* Convert from big endian byte array to little endian word array. */
193*4882a593Smuzhiyun 	for (i = 0, ptr = inout + key->len - 1; i < key->len; i++, ptr--)
194*4882a593Smuzhiyun 		val[i] = get_unaligned_be32(ptr);
195*4882a593Smuzhiyun 
196*4882a593Smuzhiyun 	if (0 != num_public_exponent_bits(key, &k))
197*4882a593Smuzhiyun 		return -EINVAL;
198*4882a593Smuzhiyun 
199*4882a593Smuzhiyun 	if (k < 2) {
200*4882a593Smuzhiyun 		debug("Public exponent is too short (%d bits, minimum 2)\n",
201*4882a593Smuzhiyun 		      k);
202*4882a593Smuzhiyun 		return -EINVAL;
203*4882a593Smuzhiyun 	}
204*4882a593Smuzhiyun 
205*4882a593Smuzhiyun 	if (!is_public_exponent_bit_set(key, 0)) {
206*4882a593Smuzhiyun 		debug("LSB of RSA public exponent must be set.\n");
207*4882a593Smuzhiyun 		return -EINVAL;
208*4882a593Smuzhiyun 	}
209*4882a593Smuzhiyun 
210*4882a593Smuzhiyun 	/* the bit at e[k-1] is 1 by definition, so start with: C := M */
211*4882a593Smuzhiyun 	montgomery_mul(key, acc, val, key->rr); /* acc = a * RR / R mod n */
212*4882a593Smuzhiyun 	/* retain scaled version for intermediate use */
213*4882a593Smuzhiyun 	memcpy(a_scaled, acc, key->len * sizeof(a_scaled[0]));
214*4882a593Smuzhiyun 
215*4882a593Smuzhiyun 	for (j = k - 2; j > 0; --j) {
216*4882a593Smuzhiyun 		montgomery_mul(key, tmp, acc, acc); /* tmp = acc^2 / R mod n */
217*4882a593Smuzhiyun 
218*4882a593Smuzhiyun 		if (is_public_exponent_bit_set(key, j)) {
219*4882a593Smuzhiyun 			/* acc = tmp * val / R mod n */
220*4882a593Smuzhiyun 			montgomery_mul(key, acc, tmp, a_scaled);
221*4882a593Smuzhiyun 		} else {
222*4882a593Smuzhiyun 			/* e[j] == 0, copy tmp back to acc for next operation */
223*4882a593Smuzhiyun 			memcpy(acc, tmp, key->len * sizeof(acc[0]));
224*4882a593Smuzhiyun 		}
225*4882a593Smuzhiyun 	}
226*4882a593Smuzhiyun 
227*4882a593Smuzhiyun 	/* the bit at e[0] is always 1 */
228*4882a593Smuzhiyun 	montgomery_mul(key, tmp, acc, acc); /* tmp = acc^2 / R mod n */
229*4882a593Smuzhiyun 	montgomery_mul(key, acc, tmp, val); /* acc = tmp * a / R mod M */
230*4882a593Smuzhiyun 	memcpy(result, acc, key->len * sizeof(result[0]));
231*4882a593Smuzhiyun 
232*4882a593Smuzhiyun 	/* Make sure result < mod; result is at most 1x mod too large. */
233*4882a593Smuzhiyun 	if (greater_equal_modulus(key, result))
234*4882a593Smuzhiyun 		subtract_modulus(key, result);
235*4882a593Smuzhiyun 
236*4882a593Smuzhiyun 	/* Convert to bigendian byte array */
237*4882a593Smuzhiyun 	for (i = key->len - 1, ptr = inout; (int)i >= 0; i--, ptr++)
238*4882a593Smuzhiyun 		put_unaligned_be32(result[i], ptr);
239*4882a593Smuzhiyun 	return 0;
240*4882a593Smuzhiyun }
241*4882a593Smuzhiyun 
/**
 * rsa_convert_big_endian() - reverse word order and convert to CPU byte order
 *
 * Copies @len 32-bit words from @src into @dst in reverse order, applying
 * fdt32_to_cpu() to each word. dst[0] receives the last word of @src.
 * Nothing is written when @len <= 0.
 *
 * @dst:	Destination, at least @len words
 * @src:	Source words in FDT byte order
 * @len:	Number of 32-bit words to convert
 */
static void rsa_convert_big_endian(uint32_t *dst, const uint32_t *src, int len)
{
	int n;

	if (len <= 0)
		return;

	for (n = len - 1; n >= 0; n--)
		*dst++ = fdt32_to_cpu(src[n]);
}
249*4882a593Smuzhiyun 
rsa_mod_exp_sw(const uint8_t * sig,uint32_t sig_len,struct key_prop * prop,uint8_t * out)250*4882a593Smuzhiyun int rsa_mod_exp_sw(const uint8_t *sig, uint32_t sig_len,
251*4882a593Smuzhiyun 		struct key_prop *prop, uint8_t *out)
252*4882a593Smuzhiyun {
253*4882a593Smuzhiyun #ifndef USE_HOSTCC
254*4882a593Smuzhiyun 	__cacheline_aligned uint64_t tmp;
255*4882a593Smuzhiyun #else
256*4882a593Smuzhiyun 	uint64_t tmp;
257*4882a593Smuzhiyun #endif
258*4882a593Smuzhiyun 	struct rsa_public_key key;
259*4882a593Smuzhiyun 	int ret;
260*4882a593Smuzhiyun 
261*4882a593Smuzhiyun 	if (!prop) {
262*4882a593Smuzhiyun 		debug("%s: Skipping invalid prop", __func__);
263*4882a593Smuzhiyun 		return -EBADF;
264*4882a593Smuzhiyun 	}
265*4882a593Smuzhiyun 	key.n0inv = prop->n0inv;
266*4882a593Smuzhiyun 	key.len = prop->num_bits;
267*4882a593Smuzhiyun 
268*4882a593Smuzhiyun 	if (!prop->public_exponent) {
269*4882a593Smuzhiyun 		key.exponent = RSA_DEFAULT_PUBEXP;
270*4882a593Smuzhiyun 	} else {
271*4882a593Smuzhiyun 		/*
272*4882a593Smuzhiyun 		 * it seems fdt64_to_cpu() input param address must be 8-bytes
273*4882a593Smuzhiyun 		 * align, otherwise it brings a data-abort. No root cause was
274*4882a593Smuzhiyun 		 * found.
275*4882a593Smuzhiyun 		 *
276*4882a593Smuzhiyun 		 * workaround it this a tmp value.
277*4882a593Smuzhiyun 		 */
278*4882a593Smuzhiyun 		memcpy((void *)&tmp, prop->public_exponent, sizeof(uint64_t));
279*4882a593Smuzhiyun 		key.exponent = fdt64_to_cpu(tmp);
280*4882a593Smuzhiyun 	}
281*4882a593Smuzhiyun 
282*4882a593Smuzhiyun 	if (!key.len || !prop->modulus || !prop->rr) {
283*4882a593Smuzhiyun 		debug("%s: Missing RSA key info", __func__);
284*4882a593Smuzhiyun 		return -EFAULT;
285*4882a593Smuzhiyun 	}
286*4882a593Smuzhiyun 
287*4882a593Smuzhiyun 	/* Sanity check for stack size */
288*4882a593Smuzhiyun 	if (key.len > RSA_MAX_KEY_BITS || key.len < RSA_MIN_KEY_BITS) {
289*4882a593Smuzhiyun 		debug("RSA key bits %u outside allowed range %d..%d\n",
290*4882a593Smuzhiyun 		      key.len, RSA_MIN_KEY_BITS, RSA_MAX_KEY_BITS);
291*4882a593Smuzhiyun 		return -EFAULT;
292*4882a593Smuzhiyun 	}
293*4882a593Smuzhiyun 	key.len /= sizeof(uint32_t) * 8;
294*4882a593Smuzhiyun 	uint32_t key1[key.len], key2[key.len];
295*4882a593Smuzhiyun 
296*4882a593Smuzhiyun 	key.modulus = key1;
297*4882a593Smuzhiyun 	key.rr = key2;
298*4882a593Smuzhiyun 	rsa_convert_big_endian(key.modulus, (uint32_t *)prop->modulus, key.len);
299*4882a593Smuzhiyun 	rsa_convert_big_endian(key.rr, (uint32_t *)prop->rr, key.len);
300*4882a593Smuzhiyun 	if (!key.modulus || !key.rr) {
301*4882a593Smuzhiyun 		debug("%s: Out of memory", __func__);
302*4882a593Smuzhiyun 		return -ENOMEM;
303*4882a593Smuzhiyun 	}
304*4882a593Smuzhiyun 
305*4882a593Smuzhiyun 	uint32_t buf[sig_len / sizeof(uint32_t)];
306*4882a593Smuzhiyun 
307*4882a593Smuzhiyun 	memcpy(buf, sig, sig_len);
308*4882a593Smuzhiyun 
309*4882a593Smuzhiyun 	ret = pow_mod(&key, buf);
310*4882a593Smuzhiyun 	if (ret)
311*4882a593Smuzhiyun 		return ret;
312*4882a593Smuzhiyun 
313*4882a593Smuzhiyun 	memcpy(out, buf, sig_len);
314*4882a593Smuzhiyun 
315*4882a593Smuzhiyun 	return 0;
316*4882a593Smuzhiyun }
317