xref: /optee_os/core/crypto/sm3.c (revision 5f7f88c6b9d618d1e068166bbf2b07757350791d)
1 // SPDX-License-Identifier: BSD-2-Clause
2 /*
3  * Copyright (c) 2019 Huawei Technologies Co., Ltd
4  */
5 /*
6  * SM3 Hash algorithm
7  * thanks to Xyssl
8  * author:goldboar
9  * email:goldboar@163.com
10  * 2011-10-26
11  */
12 
13 #include <compiler.h>
14 #include <crypto/crypto_accel.h>
15 #include <string_ext.h>
16 #include <string.h>
17 
18 #include "sm3.h"
19 
20 #define SM3_BLOCK_SIZE	64
21 
22 #define GET_UINT32_BE(n, b, i)				\
23 	do {						\
24 		(n) = ((uint32_t)(b)[(i)] << 24)     |	\
25 		      ((uint32_t)(b)[(i) + 1] << 16) |	\
26 		      ((uint32_t)(b)[(i) + 2] <<  8) |	\
27 		      ((uint32_t)(b)[(i) + 3]);		\
28 	} while (0)
29 
30 #define PUT_UINT32_BE(n, b, i)				\
31 	do {						\
32 		(b)[(i)] = (uint8_t)((n) >> 24);	\
33 		(b)[(i) + 1] = (uint8_t)((n) >> 16);	\
34 		(b)[(i) + 2] = (uint8_t)((n) >>  8);	\
35 		(b)[(i) + 3] = (uint8_t)((n));		\
36 	} while (0)
37 
38 void sm3_init(struct sm3_context *ctx)
39 {
40 	ctx->total[0] = 0;
41 	ctx->total[1] = 0;
42 
43 	ctx->state[0] = 0x7380166F;
44 	ctx->state[1] = 0x4914B2B9;
45 	ctx->state[2] = 0x172442D7;
46 	ctx->state[3] = 0xDA8A0600;
47 	ctx->state[4] = 0xA96F30BC;
48 	ctx->state[5] = 0x163138AA;
49 	ctx->state[6] = 0xE38DEE4D;
50 	ctx->state[7] = 0xB0FB0E4E;
51 }
52 
53 #define SHL(x, n)	((x) << (n))
54 
55 static uint32_t rotl(uint32_t val, int shift)
56 {
57 	shift &= 0x1F;
58 
59 	if (shift == 0)
60 		return val;
61 
62 	return SHL(val, shift) | (val >> (32 - shift));
63 }
64 
65 #define ROTL(x, n)	rotl((x), (n))
66 
67 static void __maybe_unused sm3_process(struct sm3_context *ctx,
68 				       const uint8_t data[64])
69 {
70 	uint32_t SS1, SS2, TT1, TT2, W[68], W1[64];
71 	uint32_t A, B, C, D, E, F, G, H;
72 	uint32_t T[64];
73 	uint32_t Temp1, Temp2, Temp3, Temp4, Temp5;
74 	int j;
75 
76 	for (j = 0; j < 16; j++)
77 		T[j] = 0x79CC4519;
78 	for (j = 16; j < 64; j++)
79 		T[j] = 0x7A879D8A;
80 
81 	GET_UINT32_BE(W[0], data,  0);
82 	GET_UINT32_BE(W[1], data,  4);
83 	GET_UINT32_BE(W[2], data,  8);
84 	GET_UINT32_BE(W[3], data, 12);
85 	GET_UINT32_BE(W[4], data, 16);
86 	GET_UINT32_BE(W[5], data, 20);
87 	GET_UINT32_BE(W[6], data, 24);
88 	GET_UINT32_BE(W[7], data, 28);
89 	GET_UINT32_BE(W[8], data, 32);
90 	GET_UINT32_BE(W[9], data, 36);
91 	GET_UINT32_BE(W[10], data, 40);
92 	GET_UINT32_BE(W[11], data, 44);
93 	GET_UINT32_BE(W[12], data, 48);
94 	GET_UINT32_BE(W[13], data, 52);
95 	GET_UINT32_BE(W[14], data, 56);
96 	GET_UINT32_BE(W[15], data, 60);
97 
98 #define FF0(x, y, z)	((x) ^ (y) ^ (z))
99 #define FF1(x, y, z)	(((x) & (y)) | ((x) & (z)) | ((y) & (z)))
100 
101 #define GG0(x, y, z)	((x) ^ (y) ^ (z))
102 #define GG1(x, y, z)	(((x) & (y)) | ((~(x)) & (z)))
103 
104 #define P0(x)	((x) ^ ROTL((x), 9) ^ ROTL((x), 17))
105 #define P1(x)	((x) ^ ROTL((x), 15) ^ ROTL((x), 23))
106 
107 	for (j = 16; j < 68; j++) {
108 		/*
109 		 * W[j] = P1( W[j-16] ^ W[j-9] ^ ROTL(W[j-3],15)) ^
110 		 *        ROTL(W[j - 13],7 ) ^ W[j-6];
111 		 */
112 
113 		Temp1 = W[j - 16] ^ W[j - 9];
114 		Temp2 = ROTL(W[j - 3], 15);
115 		Temp3 = Temp1 ^ Temp2;
116 		Temp4 = P1(Temp3);
117 		Temp5 =  ROTL(W[j - 13], 7) ^ W[j - 6];
118 		W[j] = Temp4 ^ Temp5;
119 	}
120 
121 	for (j =  0; j < 64; j++)
122 		W1[j] = W[j] ^ W[j + 4];
123 
124 	A = ctx->state[0];
125 	B = ctx->state[1];
126 	C = ctx->state[2];
127 	D = ctx->state[3];
128 	E = ctx->state[4];
129 	F = ctx->state[5];
130 	G = ctx->state[6];
131 	H = ctx->state[7];
132 
133 	for (j = 0; j < 16; j++) {
134 		SS1 = ROTL(ROTL(A, 12) + E + ROTL(T[j], j), 7);
135 		SS2 = SS1 ^ ROTL(A, 12);
136 		TT1 = FF0(A, B, C) + D + SS2 + W1[j];
137 		TT2 = GG0(E, F, G) + H + SS1 + W[j];
138 		D = C;
139 		C = ROTL(B, 9);
140 		B = A;
141 		A = TT1;
142 		H = G;
143 		G = ROTL(F, 19);
144 		F = E;
145 		E = P0(TT2);
146 	}
147 
148 	for (j = 16; j < 64; j++) {
149 		SS1 = ROTL(ROTL(A, 12) + E + ROTL(T[j], j), 7);
150 		SS2 = SS1 ^ ROTL(A, 12);
151 		TT1 = FF1(A, B, C) + D + SS2 + W1[j];
152 		TT2 = GG1(E, F, G) + H + SS1 + W[j];
153 		D = C;
154 		C = ROTL(B, 9);
155 		B = A;
156 		A = TT1;
157 		H = G;
158 		G = ROTL(F, 19);
159 		F = E;
160 		E = P0(TT2);
161 	}
162 
163 	ctx->state[0] ^= A;
164 	ctx->state[1] ^= B;
165 	ctx->state[2] ^= C;
166 	ctx->state[3] ^= D;
167 	ctx->state[4] ^= E;
168 	ctx->state[5] ^= F;
169 	ctx->state[6] ^= G;
170 	ctx->state[7] ^= H;
171 }
172 
173 static void sm3_process_blocks(struct sm3_context *ctx, const uint8_t *input,
174 			       unsigned int block_count)
175 {
176 #ifdef CFG_CRYPTO_SM3_ARM_CE
177 	if (block_count)
178 		crypto_accel_sm3_compress(ctx->state, input, block_count);
179 #else
180 	unsigned int n = 0;
181 
182 	for (n = 0; n < block_count; n++)
183 		sm3_process(ctx, input + n * SM3_BLOCK_SIZE);
184 #endif
185 }
186 
187 void sm3_update(struct sm3_context *ctx, const uint8_t *input, size_t ilen)
188 {
189 	unsigned int block_count = 0;
190 	size_t fill = 0;
191 	size_t left = 0;
192 
193 	if (!ilen)
194 		return;
195 
196 	left = ctx->total[0] & 0x3F;
197 	fill = 64 - left;
198 
199 	ctx->total[0] += ilen;
200 
201 	if (ctx->total[0] < ilen)
202 		ctx->total[1]++;
203 
204 	if (left && ilen >= fill) {
205 		memcpy(ctx->buffer + left, input, fill);
206 		sm3_process_blocks(ctx, ctx->buffer, 1);
207 		input += fill;
208 		ilen -= fill;
209 		left = 0;
210 	}
211 
212 	block_count = ilen / SM3_BLOCK_SIZE;
213 	sm3_process_blocks(ctx, input, block_count);
214 	ilen -= block_count * SM3_BLOCK_SIZE;
215 	input += block_count * SM3_BLOCK_SIZE;
216 
217 	if (ilen > 0)
218 		memcpy(ctx->buffer + left, input, ilen);
219 }
220 
221 static const uint8_t sm3_padding[64] = {
222 	0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
223 	0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
224 	0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
225 	0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
226 };
227 
228 void sm3_final(struct sm3_context *ctx, uint8_t output[32])
229 {
230 	uint32_t last, padn;
231 	uint32_t high, low;
232 	uint8_t msglen[8];
233 
234 	high = (ctx->total[0] >> 29) | (ctx->total[1] <<  3);
235 	low  = ctx->total[0] << 3;
236 
237 	PUT_UINT32_BE(high, msglen, 0);
238 	PUT_UINT32_BE(low,  msglen, 4);
239 
240 	last = ctx->total[0] & 0x3F;
241 	padn = (last < 56) ? (56 - last) : (120 - last);
242 
243 	sm3_update(ctx, sm3_padding, padn);
244 	sm3_update(ctx, msglen, 8);
245 
246 	PUT_UINT32_BE(ctx->state[0], output,  0);
247 	PUT_UINT32_BE(ctx->state[1], output,  4);
248 	PUT_UINT32_BE(ctx->state[2], output,  8);
249 	PUT_UINT32_BE(ctx->state[3], output, 12);
250 	PUT_UINT32_BE(ctx->state[4], output, 16);
251 	PUT_UINT32_BE(ctx->state[5], output, 20);
252 	PUT_UINT32_BE(ctx->state[6], output, 24);
253 	PUT_UINT32_BE(ctx->state[7], output, 28);
254 }
255 
256 void sm3(const uint8_t *input, size_t ilen, uint8_t output[32])
257 {
258 	struct sm3_context ctx = { };
259 
260 	sm3_init(&ctx);
261 	sm3_update(&ctx, input, ilen);
262 	sm3_final(&ctx, output);
263 
264 	memzero_explicit(&ctx, sizeof(ctx));
265 }
266 
267 void sm3_hmac_init(struct sm3_context *ctx, const uint8_t *key, size_t keylen)
268 {
269 	size_t i;
270 	uint8_t sum[32];
271 
272 	if (keylen > 64) {
273 		sm3(key, keylen, sum);
274 		keylen = 32;
275 		key = sum;
276 	}
277 
278 	memset(ctx->ipad, 0x36, 64);
279 	memset(ctx->opad, 0x5C, 64);
280 
281 	for (i = 0; i < keylen; i++) {
282 		ctx->ipad[i] ^= key[i];
283 		ctx->opad[i] ^= key[i];
284 	}
285 
286 	sm3_init(ctx);
287 	sm3_update(ctx, ctx->ipad, 64);
288 
289 	memzero_explicit(sum, sizeof(sum));
290 }
291 
292 void sm3_hmac_update(struct sm3_context *ctx, const uint8_t *input, size_t ilen)
293 {
294 	sm3_update(ctx, input, ilen);
295 }
296 
297 void sm3_hmac_final(struct sm3_context *ctx, uint8_t output[32])
298 {
299 	uint8_t tmpbuf[32];
300 
301 	sm3_final(ctx, tmpbuf);
302 	sm3_init(ctx);
303 	sm3_update(ctx, ctx->opad, 64);
304 	sm3_update(ctx, tmpbuf, 32);
305 	sm3_final(ctx, output);
306 
307 	memzero_explicit(tmpbuf, sizeof(tmpbuf));
308 }
309 
310 void sm3_hmac(const uint8_t *key, size_t keylen, const uint8_t *input,
311 	      size_t ilen, uint8_t output[32])
312 {
313 	struct sm3_context ctx;
314 
315 	sm3_hmac_init(&ctx, key, keylen);
316 	sm3_hmac_update(&ctx, input, ilen);
317 	sm3_hmac_final(&ctx, output);
318 
319 	memzero_explicit(&ctx, sizeof(ctx));
320 }
321