1 /* 2 * Copyright (c) 2013, Google Inc. 3 * 4 * SPDX-License-Identifier: GPL-2.0+ 5 */ 6 7 #include <common.h> 8 #include <fdtdec.h> 9 #include <rsa.h> 10 #include <sha1.h> 11 #include <sha256.h> 12 #include <asm/byteorder.h> 13 #include <asm/errno.h> 14 #include <asm/unaligned.h> 15 16 #define UINT64_MULT32(v, multby) (((uint64_t)(v)) * ((uint32_t)(multby))) 17 18 /** 19 * subtract_modulus() - subtract modulus from the given value 20 * 21 * @key: Key containing modulus to subtract 22 * @num: Number to subtract modulus from, as little endian word array 23 */ 24 static void subtract_modulus(const struct rsa_public_key *key, uint32_t num[]) 25 { 26 int64_t acc = 0; 27 uint i; 28 29 for (i = 0; i < key->len; i++) { 30 acc += (uint64_t)num[i] - key->modulus[i]; 31 num[i] = (uint32_t)acc; 32 acc >>= 32; 33 } 34 } 35 36 /** 37 * greater_equal_modulus() - check if a value is >= modulus 38 * 39 * @key: Key containing modulus to check 40 * @num: Number to check against modulus, as little endian word array 41 * @return 0 if num < modulus, 1 if num >= modulus 42 */ 43 static int greater_equal_modulus(const struct rsa_public_key *key, 44 uint32_t num[]) 45 { 46 uint32_t i; 47 48 for (i = key->len - 1; i >= 0; i--) { 49 if (num[i] < key->modulus[i]) 50 return 0; 51 if (num[i] > key->modulus[i]) 52 return 1; 53 } 54 55 return 1; /* equal */ 56 } 57 58 /** 59 * montgomery_mul_add_step() - Perform montgomery multiply-add step 60 * 61 * Operation: montgomery result[] += a * b[] / n0inv % modulus 62 * 63 * @key: RSA key 64 * @result: Place to put result, as little endian word array 65 * @a: Multiplier 66 * @b: Multiplicand, as little endian word array 67 */ 68 static void montgomery_mul_add_step(const struct rsa_public_key *key, 69 uint32_t result[], const uint32_t a, const uint32_t b[]) 70 { 71 uint64_t acc_a, acc_b; 72 uint32_t d0; 73 uint i; 74 75 acc_a = (uint64_t)a * b[0] + result[0]; 76 d0 = (uint32_t)acc_a * key->n0inv; 77 acc_b = (uint64_t)d0 * key->modulus[0] + (uint32_t)acc_a; 78 for (i = 1; i < key->len; i++) { 79 acc_a = (acc_a >> 32) + (uint64_t)a * b[i] + result[i]; 80 acc_b = (acc_b >> 32) + (uint64_t)d0 * key->modulus[i] + 81 (uint32_t)acc_a; 82 result[i - 1] = (uint32_t)acc_b; 83 } 84 85 acc_a = (acc_a >> 32) + (acc_b >> 32); 86 87 result[i - 1] = (uint32_t)acc_a; 88 89 if (acc_a >> 32) 90 subtract_modulus(key, result); 91 } 92 93 /** 94 * montgomery_mul() - Perform montgomery mutitply 95 * 96 * Operation: montgomery result[] = a[] * b[] / n0inv % modulus 97 * 98 * @key: RSA key 99 * @result: Place to put result, as little endian word array 100 * @a: Multiplier, as little endian word array 101 * @b: Multiplicand, as little endian word array 102 */ 103 static void montgomery_mul(const struct rsa_public_key *key, 104 uint32_t result[], uint32_t a[], const uint32_t b[]) 105 { 106 uint i; 107 108 for (i = 0; i < key->len; ++i) 109 result[i] = 0; 110 for (i = 0; i < key->len; ++i) 111 montgomery_mul_add_step(key, result, a[i], b); 112 } 113 114 /** 115 * pow_mod() - in-place public exponentiation 116 * 117 * @key: RSA key 118 * @inout: Big-endian word array containing value and result 119 */ 120 static int pow_mod(const struct rsa_public_key *key, uint32_t *inout) 121 { 122 uint32_t *result, *ptr; 123 uint i; 124 125 /* Sanity check for stack size - key->len is in 32-bit words */ 126 if (key->len > RSA_MAX_KEY_BITS / 32) { 127 debug("RSA key words %u exceeds maximum %d\n", key->len, 128 RSA_MAX_KEY_BITS / 32); 129 return -EINVAL; 130 } 131 132 uint32_t val[key->len], acc[key->len], tmp[key->len]; 133 result = tmp; /* Re-use location. */ 134 135 /* Convert from big endian byte array to little endian word array. */ 136 for (i = 0, ptr = inout + key->len - 1; i < key->len; i++, ptr--) 137 val[i] = get_unaligned_be32(ptr); 138 139 montgomery_mul(key, acc, val, key->rr); /* axx = a * RR / R mod M */ 140 for (i = 0; i < 16; i += 2) { 141 montgomery_mul(key, tmp, acc, acc); /* tmp = acc^2 / R mod M */ 142 montgomery_mul(key, acc, tmp, tmp); /* acc = tmp^2 / R mod M */ 143 } 144 montgomery_mul(key, result, acc, val); /* result = XX * a / R mod M */ 145 146 /* Make sure result < mod; result is at most 1x mod too large. */ 147 if (greater_equal_modulus(key, result)) 148 subtract_modulus(key, result); 149 150 /* Convert to bigendian byte array */ 151 for (i = key->len - 1, ptr = inout; (int)i >= 0; i--, ptr++) 152 put_unaligned_be32(result[i], ptr); 153 154 return 0; 155 } 156 157 static int rsa_verify_key(const struct rsa_public_key *key, const uint8_t *sig, 158 const uint32_t sig_len, const uint8_t *hash, 159 struct checksum_algo *algo) 160 { 161 const uint8_t *padding; 162 int pad_len; 163 int ret; 164 165 if (!key || !sig || !hash || !algo) 166 return -EIO; 167 168 if (sig_len != (key->len * sizeof(uint32_t))) { 169 debug("Signature is of incorrect length %d\n", sig_len); 170 return -EINVAL; 171 } 172 173 debug("Checksum algorithm: %s", algo->name); 174 175 /* Sanity check for stack size */ 176 if (sig_len > RSA_MAX_SIG_BITS / 8) { 177 debug("Signature length %u exceeds maximum %d\n", sig_len, 178 RSA_MAX_SIG_BITS / 8); 179 return -EINVAL; 180 } 181 182 uint32_t buf[sig_len / sizeof(uint32_t)]; 183 184 memcpy(buf, sig, sig_len); 185 186 ret = pow_mod(key, buf); 187 if (ret) 188 return ret; 189 190 padding = algo->rsa_padding; 191 pad_len = algo->pad_len - algo->checksum_len; 192 193 /* Check pkcs1.5 padding bytes. */ 194 if (memcmp(buf, padding, pad_len)) { 195 debug("In RSAVerify(): Padding check failed!\n"); 196 return -EINVAL; 197 } 198 199 /* Check hash. */ 200 if (memcmp((uint8_t *)buf + pad_len, hash, sig_len - pad_len)) { 201 debug("In RSAVerify(): Hash check failed!\n"); 202 return -EACCES; 203 } 204 205 return 0; 206 } 207 208 static void rsa_convert_big_endian(uint32_t *dst, const uint32_t *src, int len) 209 { 210 int i; 211 212 for (i = 0; i < len; i++) 213 dst[i] = fdt32_to_cpu(src[len - 1 - i]); 214 } 215 216 static int rsa_verify_with_keynode(struct image_sign_info *info, 217 const void *hash, uint8_t *sig, uint sig_len, int node) 218 { 219 const void *blob = info->fdt_blob; 220 struct rsa_public_key key; 221 const void *modulus, *rr; 222 int ret; 223 224 if (node < 0) { 225 debug("%s: Skipping invalid node", __func__); 226 return -EBADF; 227 } 228 if (!fdt_getprop(blob, node, "rsa,n0-inverse", NULL)) { 229 debug("%s: Missing rsa,n0-inverse", __func__); 230 return -EFAULT; 231 } 232 key.len = fdtdec_get_int(blob, node, "rsa,num-bits", 0); 233 key.n0inv = fdtdec_get_int(blob, node, "rsa,n0-inverse", 0); 234 modulus = fdt_getprop(blob, node, "rsa,modulus", NULL); 235 rr = fdt_getprop(blob, node, "rsa,r-squared", NULL); 236 if (!key.len || !modulus || !rr) { 237 debug("%s: Missing RSA key info", __func__); 238 return -EFAULT; 239 } 240 241 /* Sanity check for stack size */ 242 if (key.len > RSA_MAX_KEY_BITS || key.len < RSA_MIN_KEY_BITS) { 243 debug("RSA key bits %u outside allowed range %d..%d\n", 244 key.len, RSA_MIN_KEY_BITS, RSA_MAX_KEY_BITS); 245 return -EFAULT; 246 } 247 key.len /= sizeof(uint32_t) * 8; 248 uint32_t key1[key.len], key2[key.len]; 249 250 key.modulus = key1; 251 key.rr = key2; 252 rsa_convert_big_endian(key.modulus, modulus, key.len); 253 rsa_convert_big_endian(key.rr, rr, key.len); 254 if (!key.modulus || !key.rr) { 255 debug("%s: Out of memory", __func__); 256 return -ENOMEM; 257 } 258 259 debug("key length %d\n", key.len); 260 ret = rsa_verify_key(&key, sig, sig_len, hash, info->algo->checksum); 261 if (ret) { 262 printf("%s: RSA failed to verify: %d\n", __func__, ret); 263 return ret; 264 } 265 266 return 0; 267 } 268 269 int rsa_verify(struct image_sign_info *info, 270 const struct image_region region[], int region_count, 271 uint8_t *sig, uint sig_len) 272 { 273 const void *blob = info->fdt_blob; 274 /* Reserve memory for maximum checksum-length */ 275 uint8_t hash[info->algo->checksum->pad_len]; 276 int ndepth, noffset; 277 int sig_node, node; 278 char name[100]; 279 int ret; 280 281 /* 282 * Verify that the checksum-length does not exceed the 283 * rsa-signature-length 284 */ 285 if (info->algo->checksum->checksum_len > 286 info->algo->checksum->pad_len) { 287 debug("%s: invlaid checksum-algorithm %s for %s\n", 288 __func__, info->algo->checksum->name, info->algo->name); 289 return -EINVAL; 290 } 291 292 sig_node = fdt_subnode_offset(blob, 0, FIT_SIG_NODENAME); 293 if (sig_node < 0) { 294 debug("%s: No signature node found\n", __func__); 295 return -ENOENT; 296 } 297 298 /* Calculate checksum with checksum-algorithm */ 299 info->algo->checksum->calculate(region, region_count, hash); 300 301 /* See if we must use a particular key */ 302 if (info->required_keynode != -1) { 303 ret = rsa_verify_with_keynode(info, hash, sig, sig_len, 304 info->required_keynode); 305 if (!ret) 306 return ret; 307 } 308 309 /* Look for a key that matches our hint */ 310 snprintf(name, sizeof(name), "key-%s", info->keyname); 311 node = fdt_subnode_offset(blob, sig_node, name); 312 ret = rsa_verify_with_keynode(info, hash, sig, sig_len, node); 313 if (!ret) 314 return ret; 315 316 /* No luck, so try each of the keys in turn */ 317 for (ndepth = 0, noffset = fdt_next_node(info->fit, sig_node, &ndepth); 318 (noffset >= 0) && (ndepth > 0); 319 noffset = fdt_next_node(info->fit, noffset, &ndepth)) { 320 if (ndepth == 1 && noffset != node) { 321 ret = rsa_verify_with_keynode(info, hash, sig, sig_len, 322 noffset); 323 if (!ret) 324 break; 325 } 326 } 327 328 return ret; 329 } 330