xref: /OK3568_Linux_fs/kernel/drivers/net/wireguard/noise.c (revision 4882a59341e53eb6f0b4789bf948001014eff981)
1*4882a593Smuzhiyun // SPDX-License-Identifier: GPL-2.0
2*4882a593Smuzhiyun /*
3*4882a593Smuzhiyun  * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
4*4882a593Smuzhiyun  */
5*4882a593Smuzhiyun 
6*4882a593Smuzhiyun #include "noise.h"
7*4882a593Smuzhiyun #include "device.h"
8*4882a593Smuzhiyun #include "peer.h"
9*4882a593Smuzhiyun #include "messages.h"
10*4882a593Smuzhiyun #include "queueing.h"
11*4882a593Smuzhiyun #include "peerlookup.h"
12*4882a593Smuzhiyun 
#include <linux/rcupdate.h>
#include <linux/slab.h>
#include <linux/bitmap.h>
#include <linux/scatterlist.h>
#include <linux/highmem.h>
#include <crypto/algapi.h>
#include <asm/unaligned.h>
19*4882a593Smuzhiyun 
20*4882a593Smuzhiyun /* This implements Noise_IKpsk2:
21*4882a593Smuzhiyun  *
22*4882a593Smuzhiyun  * <- s
23*4882a593Smuzhiyun  * ******
24*4882a593Smuzhiyun  * -> e, es, s, ss, {t}
25*4882a593Smuzhiyun  * <- e, ee, se, psk, {}
26*4882a593Smuzhiyun  */
27*4882a593Smuzhiyun 
28*4882a593Smuzhiyun static const u8 handshake_name[37] = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s";
29*4882a593Smuzhiyun static const u8 identifier_name[34] = "WireGuard v1 zx2c4 Jason@zx2c4.com";
30*4882a593Smuzhiyun static u8 handshake_init_hash[NOISE_HASH_LEN] __ro_after_init;
31*4882a593Smuzhiyun static u8 handshake_init_chaining_key[NOISE_HASH_LEN] __ro_after_init;
32*4882a593Smuzhiyun static atomic64_t keypair_counter = ATOMIC64_INIT(0);
33*4882a593Smuzhiyun 
wg_noise_init(void)34*4882a593Smuzhiyun void __init wg_noise_init(void)
35*4882a593Smuzhiyun {
36*4882a593Smuzhiyun 	struct blake2s_state blake;
37*4882a593Smuzhiyun 
38*4882a593Smuzhiyun 	blake2s(handshake_init_chaining_key, handshake_name, NULL,
39*4882a593Smuzhiyun 		NOISE_HASH_LEN, sizeof(handshake_name), 0);
40*4882a593Smuzhiyun 	blake2s_init(&blake, NOISE_HASH_LEN);
41*4882a593Smuzhiyun 	blake2s_update(&blake, handshake_init_chaining_key, NOISE_HASH_LEN);
42*4882a593Smuzhiyun 	blake2s_update(&blake, identifier_name, sizeof(identifier_name));
43*4882a593Smuzhiyun 	blake2s_final(&blake, handshake_init_hash);
44*4882a593Smuzhiyun }
45*4882a593Smuzhiyun 
46*4882a593Smuzhiyun /* Must hold peer->handshake.static_identity->lock */
wg_noise_precompute_static_static(struct wg_peer * peer)47*4882a593Smuzhiyun void wg_noise_precompute_static_static(struct wg_peer *peer)
48*4882a593Smuzhiyun {
49*4882a593Smuzhiyun 	down_write(&peer->handshake.lock);
50*4882a593Smuzhiyun 	if (!peer->handshake.static_identity->has_identity ||
51*4882a593Smuzhiyun 	    !curve25519(peer->handshake.precomputed_static_static,
52*4882a593Smuzhiyun 			peer->handshake.static_identity->static_private,
53*4882a593Smuzhiyun 			peer->handshake.remote_static))
54*4882a593Smuzhiyun 		memset(peer->handshake.precomputed_static_static, 0,
55*4882a593Smuzhiyun 		       NOISE_PUBLIC_KEY_LEN);
56*4882a593Smuzhiyun 	up_write(&peer->handshake.lock);
57*4882a593Smuzhiyun }
58*4882a593Smuzhiyun 
wg_noise_handshake_init(struct noise_handshake * handshake,struct noise_static_identity * static_identity,const u8 peer_public_key[NOISE_PUBLIC_KEY_LEN],const u8 peer_preshared_key[NOISE_SYMMETRIC_KEY_LEN],struct wg_peer * peer)59*4882a593Smuzhiyun void wg_noise_handshake_init(struct noise_handshake *handshake,
60*4882a593Smuzhiyun 			     struct noise_static_identity *static_identity,
61*4882a593Smuzhiyun 			     const u8 peer_public_key[NOISE_PUBLIC_KEY_LEN],
62*4882a593Smuzhiyun 			     const u8 peer_preshared_key[NOISE_SYMMETRIC_KEY_LEN],
63*4882a593Smuzhiyun 			     struct wg_peer *peer)
64*4882a593Smuzhiyun {
65*4882a593Smuzhiyun 	memset(handshake, 0, sizeof(*handshake));
66*4882a593Smuzhiyun 	init_rwsem(&handshake->lock);
67*4882a593Smuzhiyun 	handshake->entry.type = INDEX_HASHTABLE_HANDSHAKE;
68*4882a593Smuzhiyun 	handshake->entry.peer = peer;
69*4882a593Smuzhiyun 	memcpy(handshake->remote_static, peer_public_key, NOISE_PUBLIC_KEY_LEN);
70*4882a593Smuzhiyun 	if (peer_preshared_key)
71*4882a593Smuzhiyun 		memcpy(handshake->preshared_key, peer_preshared_key,
72*4882a593Smuzhiyun 		       NOISE_SYMMETRIC_KEY_LEN);
73*4882a593Smuzhiyun 	handshake->static_identity = static_identity;
74*4882a593Smuzhiyun 	handshake->state = HANDSHAKE_ZEROED;
75*4882a593Smuzhiyun 	wg_noise_precompute_static_static(peer);
76*4882a593Smuzhiyun }
77*4882a593Smuzhiyun 
handshake_zero(struct noise_handshake * handshake)78*4882a593Smuzhiyun static void handshake_zero(struct noise_handshake *handshake)
79*4882a593Smuzhiyun {
80*4882a593Smuzhiyun 	memset(&handshake->ephemeral_private, 0, NOISE_PUBLIC_KEY_LEN);
81*4882a593Smuzhiyun 	memset(&handshake->remote_ephemeral, 0, NOISE_PUBLIC_KEY_LEN);
82*4882a593Smuzhiyun 	memset(&handshake->hash, 0, NOISE_HASH_LEN);
83*4882a593Smuzhiyun 	memset(&handshake->chaining_key, 0, NOISE_HASH_LEN);
84*4882a593Smuzhiyun 	handshake->remote_index = 0;
85*4882a593Smuzhiyun 	handshake->state = HANDSHAKE_ZEROED;
86*4882a593Smuzhiyun }
87*4882a593Smuzhiyun 
wg_noise_handshake_clear(struct noise_handshake * handshake)88*4882a593Smuzhiyun void wg_noise_handshake_clear(struct noise_handshake *handshake)
89*4882a593Smuzhiyun {
90*4882a593Smuzhiyun 	down_write(&handshake->lock);
91*4882a593Smuzhiyun 	wg_index_hashtable_remove(
92*4882a593Smuzhiyun 			handshake->entry.peer->device->index_hashtable,
93*4882a593Smuzhiyun 			&handshake->entry);
94*4882a593Smuzhiyun 	handshake_zero(handshake);
95*4882a593Smuzhiyun 	up_write(&handshake->lock);
96*4882a593Smuzhiyun }
97*4882a593Smuzhiyun 
keypair_create(struct wg_peer * peer)98*4882a593Smuzhiyun static struct noise_keypair *keypair_create(struct wg_peer *peer)
99*4882a593Smuzhiyun {
100*4882a593Smuzhiyun 	struct noise_keypair *keypair = kzalloc(sizeof(*keypair), GFP_KERNEL);
101*4882a593Smuzhiyun 
102*4882a593Smuzhiyun 	if (unlikely(!keypair))
103*4882a593Smuzhiyun 		return NULL;
104*4882a593Smuzhiyun 	spin_lock_init(&keypair->receiving_counter.lock);
105*4882a593Smuzhiyun 	keypair->internal_id = atomic64_inc_return(&keypair_counter);
106*4882a593Smuzhiyun 	keypair->entry.type = INDEX_HASHTABLE_KEYPAIR;
107*4882a593Smuzhiyun 	keypair->entry.peer = peer;
108*4882a593Smuzhiyun 	kref_init(&keypair->refcount);
109*4882a593Smuzhiyun 	return keypair;
110*4882a593Smuzhiyun }
111*4882a593Smuzhiyun 
keypair_free_rcu(struct rcu_head * rcu)112*4882a593Smuzhiyun static void keypair_free_rcu(struct rcu_head *rcu)
113*4882a593Smuzhiyun {
114*4882a593Smuzhiyun 	kfree_sensitive(container_of(rcu, struct noise_keypair, rcu));
115*4882a593Smuzhiyun }
116*4882a593Smuzhiyun 
keypair_free_kref(struct kref * kref)117*4882a593Smuzhiyun static void keypair_free_kref(struct kref *kref)
118*4882a593Smuzhiyun {
119*4882a593Smuzhiyun 	struct noise_keypair *keypair =
120*4882a593Smuzhiyun 		container_of(kref, struct noise_keypair, refcount);
121*4882a593Smuzhiyun 
122*4882a593Smuzhiyun 	net_dbg_ratelimited("%s: Keypair %llu destroyed for peer %llu\n",
123*4882a593Smuzhiyun 			    keypair->entry.peer->device->dev->name,
124*4882a593Smuzhiyun 			    keypair->internal_id,
125*4882a593Smuzhiyun 			    keypair->entry.peer->internal_id);
126*4882a593Smuzhiyun 	wg_index_hashtable_remove(keypair->entry.peer->device->index_hashtable,
127*4882a593Smuzhiyun 				  &keypair->entry);
128*4882a593Smuzhiyun 	call_rcu(&keypair->rcu, keypair_free_rcu);
129*4882a593Smuzhiyun }
130*4882a593Smuzhiyun 
wg_noise_keypair_put(struct noise_keypair * keypair,bool unreference_now)131*4882a593Smuzhiyun void wg_noise_keypair_put(struct noise_keypair *keypair, bool unreference_now)
132*4882a593Smuzhiyun {
133*4882a593Smuzhiyun 	if (unlikely(!keypair))
134*4882a593Smuzhiyun 		return;
135*4882a593Smuzhiyun 	if (unlikely(unreference_now))
136*4882a593Smuzhiyun 		wg_index_hashtable_remove(
137*4882a593Smuzhiyun 			keypair->entry.peer->device->index_hashtable,
138*4882a593Smuzhiyun 			&keypair->entry);
139*4882a593Smuzhiyun 	kref_put(&keypair->refcount, keypair_free_kref);
140*4882a593Smuzhiyun }
141*4882a593Smuzhiyun 
wg_noise_keypair_get(struct noise_keypair * keypair)142*4882a593Smuzhiyun struct noise_keypair *wg_noise_keypair_get(struct noise_keypair *keypair)
143*4882a593Smuzhiyun {
144*4882a593Smuzhiyun 	RCU_LOCKDEP_WARN(!rcu_read_lock_bh_held(),
145*4882a593Smuzhiyun 		"Taking noise keypair reference without holding the RCU BH read lock");
146*4882a593Smuzhiyun 	if (unlikely(!keypair || !kref_get_unless_zero(&keypair->refcount)))
147*4882a593Smuzhiyun 		return NULL;
148*4882a593Smuzhiyun 	return keypair;
149*4882a593Smuzhiyun }
150*4882a593Smuzhiyun 
wg_noise_keypairs_clear(struct noise_keypairs * keypairs)151*4882a593Smuzhiyun void wg_noise_keypairs_clear(struct noise_keypairs *keypairs)
152*4882a593Smuzhiyun {
153*4882a593Smuzhiyun 	struct noise_keypair *old;
154*4882a593Smuzhiyun 
155*4882a593Smuzhiyun 	spin_lock_bh(&keypairs->keypair_update_lock);
156*4882a593Smuzhiyun 
157*4882a593Smuzhiyun 	/* We zero the next_keypair before zeroing the others, so that
158*4882a593Smuzhiyun 	 * wg_noise_received_with_keypair returns early before subsequent ones
159*4882a593Smuzhiyun 	 * are zeroed.
160*4882a593Smuzhiyun 	 */
161*4882a593Smuzhiyun 	old = rcu_dereference_protected(keypairs->next_keypair,
162*4882a593Smuzhiyun 		lockdep_is_held(&keypairs->keypair_update_lock));
163*4882a593Smuzhiyun 	RCU_INIT_POINTER(keypairs->next_keypair, NULL);
164*4882a593Smuzhiyun 	wg_noise_keypair_put(old, true);
165*4882a593Smuzhiyun 
166*4882a593Smuzhiyun 	old = rcu_dereference_protected(keypairs->previous_keypair,
167*4882a593Smuzhiyun 		lockdep_is_held(&keypairs->keypair_update_lock));
168*4882a593Smuzhiyun 	RCU_INIT_POINTER(keypairs->previous_keypair, NULL);
169*4882a593Smuzhiyun 	wg_noise_keypair_put(old, true);
170*4882a593Smuzhiyun 
171*4882a593Smuzhiyun 	old = rcu_dereference_protected(keypairs->current_keypair,
172*4882a593Smuzhiyun 		lockdep_is_held(&keypairs->keypair_update_lock));
173*4882a593Smuzhiyun 	RCU_INIT_POINTER(keypairs->current_keypair, NULL);
174*4882a593Smuzhiyun 	wg_noise_keypair_put(old, true);
175*4882a593Smuzhiyun 
176*4882a593Smuzhiyun 	spin_unlock_bh(&keypairs->keypair_update_lock);
177*4882a593Smuzhiyun }
178*4882a593Smuzhiyun 
wg_noise_expire_current_peer_keypairs(struct wg_peer * peer)179*4882a593Smuzhiyun void wg_noise_expire_current_peer_keypairs(struct wg_peer *peer)
180*4882a593Smuzhiyun {
181*4882a593Smuzhiyun 	struct noise_keypair *keypair;
182*4882a593Smuzhiyun 
183*4882a593Smuzhiyun 	wg_noise_handshake_clear(&peer->handshake);
184*4882a593Smuzhiyun 	wg_noise_reset_last_sent_handshake(&peer->last_sent_handshake);
185*4882a593Smuzhiyun 
186*4882a593Smuzhiyun 	spin_lock_bh(&peer->keypairs.keypair_update_lock);
187*4882a593Smuzhiyun 	keypair = rcu_dereference_protected(peer->keypairs.next_keypair,
188*4882a593Smuzhiyun 			lockdep_is_held(&peer->keypairs.keypair_update_lock));
189*4882a593Smuzhiyun 	if (keypair)
190*4882a593Smuzhiyun 		keypair->sending.is_valid = false;
191*4882a593Smuzhiyun 	keypair = rcu_dereference_protected(peer->keypairs.current_keypair,
192*4882a593Smuzhiyun 			lockdep_is_held(&peer->keypairs.keypair_update_lock));
193*4882a593Smuzhiyun 	if (keypair)
194*4882a593Smuzhiyun 		keypair->sending.is_valid = false;
195*4882a593Smuzhiyun 	spin_unlock_bh(&peer->keypairs.keypair_update_lock);
196*4882a593Smuzhiyun }
197*4882a593Smuzhiyun 
add_new_keypair(struct noise_keypairs * keypairs,struct noise_keypair * new_keypair)198*4882a593Smuzhiyun static void add_new_keypair(struct noise_keypairs *keypairs,
199*4882a593Smuzhiyun 			    struct noise_keypair *new_keypair)
200*4882a593Smuzhiyun {
201*4882a593Smuzhiyun 	struct noise_keypair *previous_keypair, *next_keypair, *current_keypair;
202*4882a593Smuzhiyun 
203*4882a593Smuzhiyun 	spin_lock_bh(&keypairs->keypair_update_lock);
204*4882a593Smuzhiyun 	previous_keypair = rcu_dereference_protected(keypairs->previous_keypair,
205*4882a593Smuzhiyun 		lockdep_is_held(&keypairs->keypair_update_lock));
206*4882a593Smuzhiyun 	next_keypair = rcu_dereference_protected(keypairs->next_keypair,
207*4882a593Smuzhiyun 		lockdep_is_held(&keypairs->keypair_update_lock));
208*4882a593Smuzhiyun 	current_keypair = rcu_dereference_protected(keypairs->current_keypair,
209*4882a593Smuzhiyun 		lockdep_is_held(&keypairs->keypair_update_lock));
210*4882a593Smuzhiyun 	if (new_keypair->i_am_the_initiator) {
211*4882a593Smuzhiyun 		/* If we're the initiator, it means we've sent a handshake, and
212*4882a593Smuzhiyun 		 * received a confirmation response, which means this new
213*4882a593Smuzhiyun 		 * keypair can now be used.
214*4882a593Smuzhiyun 		 */
215*4882a593Smuzhiyun 		if (next_keypair) {
216*4882a593Smuzhiyun 			/* If there already was a next keypair pending, we
217*4882a593Smuzhiyun 			 * demote it to be the previous keypair, and free the
218*4882a593Smuzhiyun 			 * existing current. Note that this means KCI can result
219*4882a593Smuzhiyun 			 * in this transition. It would perhaps be more sound to
220*4882a593Smuzhiyun 			 * always just get rid of the unused next keypair
221*4882a593Smuzhiyun 			 * instead of putting it in the previous slot, but this
222*4882a593Smuzhiyun 			 * might be a bit less robust. Something to think about
223*4882a593Smuzhiyun 			 * for the future.
224*4882a593Smuzhiyun 			 */
225*4882a593Smuzhiyun 			RCU_INIT_POINTER(keypairs->next_keypair, NULL);
226*4882a593Smuzhiyun 			rcu_assign_pointer(keypairs->previous_keypair,
227*4882a593Smuzhiyun 					   next_keypair);
228*4882a593Smuzhiyun 			wg_noise_keypair_put(current_keypair, true);
229*4882a593Smuzhiyun 		} else /* If there wasn't an existing next keypair, we replace
230*4882a593Smuzhiyun 			* the previous with the current one.
231*4882a593Smuzhiyun 			*/
232*4882a593Smuzhiyun 			rcu_assign_pointer(keypairs->previous_keypair,
233*4882a593Smuzhiyun 					   current_keypair);
234*4882a593Smuzhiyun 		/* At this point we can get rid of the old previous keypair, and
235*4882a593Smuzhiyun 		 * set up the new keypair.
236*4882a593Smuzhiyun 		 */
237*4882a593Smuzhiyun 		wg_noise_keypair_put(previous_keypair, true);
238*4882a593Smuzhiyun 		rcu_assign_pointer(keypairs->current_keypair, new_keypair);
239*4882a593Smuzhiyun 	} else {
240*4882a593Smuzhiyun 		/* If we're the responder, it means we can't use the new keypair
241*4882a593Smuzhiyun 		 * until we receive confirmation via the first data packet, so
242*4882a593Smuzhiyun 		 * we get rid of the existing previous one, the possibly
243*4882a593Smuzhiyun 		 * existing next one, and slide in the new next one.
244*4882a593Smuzhiyun 		 */
245*4882a593Smuzhiyun 		rcu_assign_pointer(keypairs->next_keypair, new_keypair);
246*4882a593Smuzhiyun 		wg_noise_keypair_put(next_keypair, true);
247*4882a593Smuzhiyun 		RCU_INIT_POINTER(keypairs->previous_keypair, NULL);
248*4882a593Smuzhiyun 		wg_noise_keypair_put(previous_keypair, true);
249*4882a593Smuzhiyun 	}
250*4882a593Smuzhiyun 	spin_unlock_bh(&keypairs->keypair_update_lock);
251*4882a593Smuzhiyun }
252*4882a593Smuzhiyun 
wg_noise_received_with_keypair(struct noise_keypairs * keypairs,struct noise_keypair * received_keypair)253*4882a593Smuzhiyun bool wg_noise_received_with_keypair(struct noise_keypairs *keypairs,
254*4882a593Smuzhiyun 				    struct noise_keypair *received_keypair)
255*4882a593Smuzhiyun {
256*4882a593Smuzhiyun 	struct noise_keypair *old_keypair;
257*4882a593Smuzhiyun 	bool key_is_new;
258*4882a593Smuzhiyun 
259*4882a593Smuzhiyun 	/* We first check without taking the spinlock. */
260*4882a593Smuzhiyun 	key_is_new = received_keypair ==
261*4882a593Smuzhiyun 		     rcu_access_pointer(keypairs->next_keypair);
262*4882a593Smuzhiyun 	if (likely(!key_is_new))
263*4882a593Smuzhiyun 		return false;
264*4882a593Smuzhiyun 
265*4882a593Smuzhiyun 	spin_lock_bh(&keypairs->keypair_update_lock);
266*4882a593Smuzhiyun 	/* After locking, we double check that things didn't change from
267*4882a593Smuzhiyun 	 * beneath us.
268*4882a593Smuzhiyun 	 */
269*4882a593Smuzhiyun 	if (unlikely(received_keypair !=
270*4882a593Smuzhiyun 		    rcu_dereference_protected(keypairs->next_keypair,
271*4882a593Smuzhiyun 			    lockdep_is_held(&keypairs->keypair_update_lock)))) {
272*4882a593Smuzhiyun 		spin_unlock_bh(&keypairs->keypair_update_lock);
273*4882a593Smuzhiyun 		return false;
274*4882a593Smuzhiyun 	}
275*4882a593Smuzhiyun 
276*4882a593Smuzhiyun 	/* When we've finally received the confirmation, we slide the next
277*4882a593Smuzhiyun 	 * into the current, the current into the previous, and get rid of
278*4882a593Smuzhiyun 	 * the old previous.
279*4882a593Smuzhiyun 	 */
280*4882a593Smuzhiyun 	old_keypair = rcu_dereference_protected(keypairs->previous_keypair,
281*4882a593Smuzhiyun 		lockdep_is_held(&keypairs->keypair_update_lock));
282*4882a593Smuzhiyun 	rcu_assign_pointer(keypairs->previous_keypair,
283*4882a593Smuzhiyun 		rcu_dereference_protected(keypairs->current_keypair,
284*4882a593Smuzhiyun 			lockdep_is_held(&keypairs->keypair_update_lock)));
285*4882a593Smuzhiyun 	wg_noise_keypair_put(old_keypair, true);
286*4882a593Smuzhiyun 	rcu_assign_pointer(keypairs->current_keypair, received_keypair);
287*4882a593Smuzhiyun 	RCU_INIT_POINTER(keypairs->next_keypair, NULL);
288*4882a593Smuzhiyun 
289*4882a593Smuzhiyun 	spin_unlock_bh(&keypairs->keypair_update_lock);
290*4882a593Smuzhiyun 	return true;
291*4882a593Smuzhiyun }
292*4882a593Smuzhiyun 
293*4882a593Smuzhiyun /* Must hold static_identity->lock */
wg_noise_set_static_identity_private_key(struct noise_static_identity * static_identity,const u8 private_key[NOISE_PUBLIC_KEY_LEN])294*4882a593Smuzhiyun void wg_noise_set_static_identity_private_key(
295*4882a593Smuzhiyun 	struct noise_static_identity *static_identity,
296*4882a593Smuzhiyun 	const u8 private_key[NOISE_PUBLIC_KEY_LEN])
297*4882a593Smuzhiyun {
298*4882a593Smuzhiyun 	memcpy(static_identity->static_private, private_key,
299*4882a593Smuzhiyun 	       NOISE_PUBLIC_KEY_LEN);
300*4882a593Smuzhiyun 	curve25519_clamp_secret(static_identity->static_private);
301*4882a593Smuzhiyun 	static_identity->has_identity = curve25519_generate_public(
302*4882a593Smuzhiyun 		static_identity->static_public, private_key);
303*4882a593Smuzhiyun }
304*4882a593Smuzhiyun 
/* HMAC over BLAKE2s: out = H((K ^ opad) || H((K ^ ipad) || in)), where
 * ipad = 0x36 and opad = 0x5c, each repeated over one BLAKE2S_BLOCK_SIZE
 * block.
 *
 * @key is zero-padded to one block, or first hashed down if it is longer
 * than a block. @out receives BLAKE2S_HASH_SIZE bytes. It is written only
 * after @in has been fully consumed, so @out may alias @in (kdf() relies on
 * this). Key and intermediate material on the stack are wiped before
 * returning.
 */
static void hmac(u8 *out, const u8 *in, const u8 *key, const size_t inlen, const size_t keylen)
{
	u8 pad[BLAKE2S_BLOCK_SIZE] __aligned(__alignof__(u32)) = { 0 };
	u8 digest[BLAKE2S_HASH_SIZE] __aligned(__alignof__(u32));
	struct blake2s_state st;
	size_t j;

	/* Normalise the key to exactly one block. */
	if (keylen <= BLAKE2S_BLOCK_SIZE) {
		memcpy(pad, key, keylen);
	} else {
		blake2s_init(&st, BLAKE2S_HASH_SIZE);
		blake2s_update(&st, key, keylen);
		blake2s_final(&st, pad);
	}

	/* Inner hash: H((K ^ ipad) || in). */
	for (j = 0; j < sizeof(pad); j++)
		pad[j] ^= 0x36;
	blake2s_init(&st, BLAKE2S_HASH_SIZE);
	blake2s_update(&st, pad, sizeof(pad));
	blake2s_update(&st, in, inlen);
	blake2s_final(&st, digest);

	/* Outer hash: flip ipad into opad in place, then H((K ^ opad) || inner). */
	for (j = 0; j < sizeof(pad); j++)
		pad[j] ^= 0x36 ^ 0x5c;
	blake2s_init(&st, BLAKE2S_HASH_SIZE);
	blake2s_update(&st, pad, sizeof(pad));
	blake2s_update(&st, digest, sizeof(digest));
	blake2s_final(&st, digest);

	memcpy(out, digest, sizeof(digest));
	memzero_explicit(pad, sizeof(pad));
	memzero_explicit(digest, sizeof(digest));
}
339*4882a593Smuzhiyun 
340*4882a593Smuzhiyun /* This is Hugo Krawczyk's HKDF:
341*4882a593Smuzhiyun  *  - https://eprint.iacr.org/2010/264.pdf
342*4882a593Smuzhiyun  *  - https://tools.ietf.org/html/rfc5869
343*4882a593Smuzhiyun  */
kdf(u8 * first_dst,u8 * second_dst,u8 * third_dst,const u8 * data,size_t first_len,size_t second_len,size_t third_len,size_t data_len,const u8 chaining_key[NOISE_HASH_LEN])344*4882a593Smuzhiyun static void kdf(u8 *first_dst, u8 *second_dst, u8 *third_dst, const u8 *data,
345*4882a593Smuzhiyun 		size_t first_len, size_t second_len, size_t third_len,
346*4882a593Smuzhiyun 		size_t data_len, const u8 chaining_key[NOISE_HASH_LEN])
347*4882a593Smuzhiyun {
348*4882a593Smuzhiyun 	u8 output[BLAKE2S_HASH_SIZE + 1];
349*4882a593Smuzhiyun 	u8 secret[BLAKE2S_HASH_SIZE];
350*4882a593Smuzhiyun 
351*4882a593Smuzhiyun 	WARN_ON(IS_ENABLED(DEBUG) &&
352*4882a593Smuzhiyun 		(first_len > BLAKE2S_HASH_SIZE ||
353*4882a593Smuzhiyun 		 second_len > BLAKE2S_HASH_SIZE ||
354*4882a593Smuzhiyun 		 third_len > BLAKE2S_HASH_SIZE ||
355*4882a593Smuzhiyun 		 ((second_len || second_dst || third_len || third_dst) &&
356*4882a593Smuzhiyun 		  (!first_len || !first_dst)) ||
357*4882a593Smuzhiyun 		 ((third_len || third_dst) && (!second_len || !second_dst))));
358*4882a593Smuzhiyun 
359*4882a593Smuzhiyun 	/* Extract entropy from data into secret */
360*4882a593Smuzhiyun 	hmac(secret, data, chaining_key, data_len, NOISE_HASH_LEN);
361*4882a593Smuzhiyun 
362*4882a593Smuzhiyun 	if (!first_dst || !first_len)
363*4882a593Smuzhiyun 		goto out;
364*4882a593Smuzhiyun 
365*4882a593Smuzhiyun 	/* Expand first key: key = secret, data = 0x1 */
366*4882a593Smuzhiyun 	output[0] = 1;
367*4882a593Smuzhiyun 	hmac(output, output, secret, 1, BLAKE2S_HASH_SIZE);
368*4882a593Smuzhiyun 	memcpy(first_dst, output, first_len);
369*4882a593Smuzhiyun 
370*4882a593Smuzhiyun 	if (!second_dst || !second_len)
371*4882a593Smuzhiyun 		goto out;
372*4882a593Smuzhiyun 
373*4882a593Smuzhiyun 	/* Expand second key: key = secret, data = first-key || 0x2 */
374*4882a593Smuzhiyun 	output[BLAKE2S_HASH_SIZE] = 2;
375*4882a593Smuzhiyun 	hmac(output, output, secret, BLAKE2S_HASH_SIZE + 1, BLAKE2S_HASH_SIZE);
376*4882a593Smuzhiyun 	memcpy(second_dst, output, second_len);
377*4882a593Smuzhiyun 
378*4882a593Smuzhiyun 	if (!third_dst || !third_len)
379*4882a593Smuzhiyun 		goto out;
380*4882a593Smuzhiyun 
381*4882a593Smuzhiyun 	/* Expand third key: key = secret, data = second-key || 0x3 */
382*4882a593Smuzhiyun 	output[BLAKE2S_HASH_SIZE] = 3;
383*4882a593Smuzhiyun 	hmac(output, output, secret, BLAKE2S_HASH_SIZE + 1, BLAKE2S_HASH_SIZE);
384*4882a593Smuzhiyun 	memcpy(third_dst, output, third_len);
385*4882a593Smuzhiyun 
386*4882a593Smuzhiyun out:
387*4882a593Smuzhiyun 	/* Clear sensitive data from stack */
388*4882a593Smuzhiyun 	memzero_explicit(secret, BLAKE2S_HASH_SIZE);
389*4882a593Smuzhiyun 	memzero_explicit(output, BLAKE2S_HASH_SIZE + 1);
390*4882a593Smuzhiyun }
391*4882a593Smuzhiyun 
derive_keys(struct noise_symmetric_key * first_dst,struct noise_symmetric_key * second_dst,const u8 chaining_key[NOISE_HASH_LEN])392*4882a593Smuzhiyun static void derive_keys(struct noise_symmetric_key *first_dst,
393*4882a593Smuzhiyun 			struct noise_symmetric_key *second_dst,
394*4882a593Smuzhiyun 			const u8 chaining_key[NOISE_HASH_LEN])
395*4882a593Smuzhiyun {
396*4882a593Smuzhiyun 	u64 birthdate = ktime_get_coarse_boottime_ns();
397*4882a593Smuzhiyun 	kdf(first_dst->key, second_dst->key, NULL, NULL,
398*4882a593Smuzhiyun 	    NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0,
399*4882a593Smuzhiyun 	    chaining_key);
400*4882a593Smuzhiyun 	first_dst->birthdate = second_dst->birthdate = birthdate;
401*4882a593Smuzhiyun 	first_dst->is_valid = second_dst->is_valid = true;
402*4882a593Smuzhiyun }
403*4882a593Smuzhiyun 
mix_dh(u8 chaining_key[NOISE_HASH_LEN],u8 key[NOISE_SYMMETRIC_KEY_LEN],const u8 private[NOISE_PUBLIC_KEY_LEN],const u8 public[NOISE_PUBLIC_KEY_LEN])404*4882a593Smuzhiyun static bool __must_check mix_dh(u8 chaining_key[NOISE_HASH_LEN],
405*4882a593Smuzhiyun 				u8 key[NOISE_SYMMETRIC_KEY_LEN],
406*4882a593Smuzhiyun 				const u8 private[NOISE_PUBLIC_KEY_LEN],
407*4882a593Smuzhiyun 				const u8 public[NOISE_PUBLIC_KEY_LEN])
408*4882a593Smuzhiyun {
409*4882a593Smuzhiyun 	u8 dh_calculation[NOISE_PUBLIC_KEY_LEN];
410*4882a593Smuzhiyun 
411*4882a593Smuzhiyun 	if (unlikely(!curve25519(dh_calculation, private, public)))
412*4882a593Smuzhiyun 		return false;
413*4882a593Smuzhiyun 	kdf(chaining_key, key, NULL, dh_calculation, NOISE_HASH_LEN,
414*4882a593Smuzhiyun 	    NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN, chaining_key);
415*4882a593Smuzhiyun 	memzero_explicit(dh_calculation, NOISE_PUBLIC_KEY_LEN);
416*4882a593Smuzhiyun 	return true;
417*4882a593Smuzhiyun }
418*4882a593Smuzhiyun 
mix_precomputed_dh(u8 chaining_key[NOISE_HASH_LEN],u8 key[NOISE_SYMMETRIC_KEY_LEN],const u8 precomputed[NOISE_PUBLIC_KEY_LEN])419*4882a593Smuzhiyun static bool __must_check mix_precomputed_dh(u8 chaining_key[NOISE_HASH_LEN],
420*4882a593Smuzhiyun 					    u8 key[NOISE_SYMMETRIC_KEY_LEN],
421*4882a593Smuzhiyun 					    const u8 precomputed[NOISE_PUBLIC_KEY_LEN])
422*4882a593Smuzhiyun {
423*4882a593Smuzhiyun 	static u8 zero_point[NOISE_PUBLIC_KEY_LEN];
424*4882a593Smuzhiyun 	if (unlikely(!crypto_memneq(precomputed, zero_point, NOISE_PUBLIC_KEY_LEN)))
425*4882a593Smuzhiyun 		return false;
426*4882a593Smuzhiyun 	kdf(chaining_key, key, NULL, precomputed, NOISE_HASH_LEN,
427*4882a593Smuzhiyun 	    NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN,
428*4882a593Smuzhiyun 	    chaining_key);
429*4882a593Smuzhiyun 	return true;
430*4882a593Smuzhiyun }
431*4882a593Smuzhiyun 
/* Absorb @src_len bytes of @src into the running transcript hash, in place:
 *   hash = BLAKE2s(hash || src)
 */
static void mix_hash(u8 hash[NOISE_HASH_LEN], const u8 *src, size_t src_len)
{
	struct blake2s_state ctx;

	blake2s_init(&ctx, NOISE_HASH_LEN);
	blake2s_update(&ctx, hash, NOISE_HASH_LEN);
	blake2s_update(&ctx, src, src_len);
	blake2s_final(&ctx, hash);
}
441*4882a593Smuzhiyun 
/* Mix the preshared key into the handshake:
 *   (chaining_key, tau, key) = HKDF3(chaining_key, psk)
 *   hash = BLAKE2s(hash || tau)
 * The intermediate tau is wiped from the stack afterwards.
 */
static void mix_psk(u8 chaining_key[NOISE_HASH_LEN], u8 hash[NOISE_HASH_LEN],
		    u8 key[NOISE_SYMMETRIC_KEY_LEN],
		    const u8 psk[NOISE_SYMMETRIC_KEY_LEN])
{
	u8 tau[NOISE_HASH_LEN];

	kdf(chaining_key, tau, key, psk,
	    NOISE_HASH_LEN, NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN,
	    NOISE_SYMMETRIC_KEY_LEN, chaining_key);
	mix_hash(hash, tau, sizeof(tau));
	memzero_explicit(tau, sizeof(tau));
}
453*4882a593Smuzhiyun 
/* Start a handshake transcript from the constants precomputed by
 * wg_noise_init(), then absorb the responder's static public key into the
 * hash.
 */
static void handshake_init(u8 chaining_key[NOISE_HASH_LEN],
			   u8 hash[NOISE_HASH_LEN],
			   const u8 remote_static[NOISE_PUBLIC_KEY_LEN])
{
	memcpy(chaining_key, handshake_init_chaining_key, NOISE_HASH_LEN);
	memcpy(hash, handshake_init_hash, NOISE_HASH_LEN);
	mix_hash(hash, remote_static, NOISE_PUBLIC_KEY_LEN);
}
462*4882a593Smuzhiyun 
/* Encrypt @src_len bytes of @src_plaintext into @dst_ciphertext with
 * ChaCha20-Poly1305 under @key, using the transcript @hash as associated
 * data and a zero nonce. Then absorb the noise_encrypted_len(src_len) bytes
 * of output into @hash.
 */
static void message_encrypt(u8 *dst_ciphertext, const u8 *src_plaintext,
			    size_t src_len, u8 key[NOISE_SYMMETRIC_KEY_LEN],
			    u8 hash[NOISE_HASH_LEN])
{
	const size_t ciphertext_len = noise_encrypted_len(src_len);

	chacha20poly1305_encrypt(dst_ciphertext, src_plaintext, src_len,
				 hash, NOISE_HASH_LEN,
				 0 /* Always zero for Noise_IK */, key);
	mix_hash(hash, dst_ciphertext, ciphertext_len);
}
472*4882a593Smuzhiyun 
/* Decrypt and authenticate @src_len bytes of @src_ciphertext into
 * @dst_plaintext with ChaCha20-Poly1305 under @key, with @hash as associated
 * data and a zero nonce.
 *
 * Only if authentication succeeds is the ciphertext absorbed into @hash.
 * Returns whether authentication succeeded.
 */
static bool message_decrypt(u8 *dst_plaintext, const u8 *src_ciphertext,
			    size_t src_len, u8 key[NOISE_SYMMETRIC_KEY_LEN],
			    u8 hash[NOISE_HASH_LEN])
{
	bool authentic = chacha20poly1305_decrypt(dst_plaintext, src_ciphertext,
						  src_len, hash, NOISE_HASH_LEN,
						  0 /* Always zero for Noise_IK */,
						  key);

	if (authentic)
		mix_hash(hash, src_ciphertext, src_len);
	return authentic;
}
484*4882a593Smuzhiyun 
/* Process an ephemeral public key. It is copied into @ephemeral_dst, unless
 * both buffers are the same (the initiator passes one buffer for both). Then
 * it is absorbed into the transcript:
 *   hash = BLAKE2s(hash || e)
 *   chaining_key = HKDF1(chaining_key, e)
 */
static void message_ephemeral(u8 ephemeral_dst[NOISE_PUBLIC_KEY_LEN],
			      const u8 ephemeral_src[NOISE_PUBLIC_KEY_LEN],
			      u8 chaining_key[NOISE_HASH_LEN],
			      u8 hash[NOISE_HASH_LEN])
{
	if (ephemeral_src != ephemeral_dst)
		memcpy(ephemeral_dst, ephemeral_src, NOISE_PUBLIC_KEY_LEN);

	mix_hash(hash, ephemeral_src, NOISE_PUBLIC_KEY_LEN);
	kdf(chaining_key, NULL, NULL, ephemeral_src, NOISE_HASH_LEN, 0, 0,
	    NOISE_PUBLIC_KEY_LEN, chaining_key);
}
496*4882a593Smuzhiyun 
/* Write the current wall-clock time into @output as a TAI64N label
 * (https://cr.yp.to/libtai/tai64.html). The label is an 8-byte big-endian
 * seconds field followed by a 4-byte big-endian nanoseconds field.
 *
 * To avoid leaking precise timer values, the nanoseconds are rounded down to
 * a multiple of the largest power of two that does not exceed the spacing
 * between initiations allowed by INITIATIONS_PER_SECOND.
 *
 * @output is a plain byte buffer with no alignment guarantee. The fields are
 * therefore stored with put_unaligned_be64()/put_unaligned_be32() instead of
 * through (__be64 *)/(__be32 *) casts, which would be misaligned accesses
 * (undefined behaviour, and a trap or slow fixup on architectures that
 * require natural alignment). The bytes produced are identical.
 */
static void tai64n_now(u8 output[NOISE_TIMESTAMP_LEN])
{
	struct timespec64 now;

	ktime_get_real_ts64(&now);

	now.tv_nsec = ALIGN_DOWN(now.tv_nsec,
		rounddown_pow_of_two(NSEC_PER_SEC / INITIATIONS_PER_SECOND));

	/* 0x400000000000000a = 2^62 + 10: the TAI64 label base plus a fixed
	 * 10-second offset.
	 */
	put_unaligned_be64(0x400000000000000aULL + now.tv_sec, output);
	put_unaligned_be32(now.tv_nsec, output + sizeof(__be64));
}
515*4882a593Smuzhiyun 
516*4882a593Smuzhiyun bool
wg_noise_handshake_create_initiation(struct message_handshake_initiation * dst,struct noise_handshake * handshake)517*4882a593Smuzhiyun wg_noise_handshake_create_initiation(struct message_handshake_initiation *dst,
518*4882a593Smuzhiyun 				     struct noise_handshake *handshake)
519*4882a593Smuzhiyun {
520*4882a593Smuzhiyun 	u8 timestamp[NOISE_TIMESTAMP_LEN];
521*4882a593Smuzhiyun 	u8 key[NOISE_SYMMETRIC_KEY_LEN];
522*4882a593Smuzhiyun 	bool ret = false;
523*4882a593Smuzhiyun 
524*4882a593Smuzhiyun 	/* We need to wait for crng _before_ taking any locks, since
525*4882a593Smuzhiyun 	 * curve25519_generate_secret uses get_random_bytes_wait.
526*4882a593Smuzhiyun 	 */
527*4882a593Smuzhiyun 	wait_for_random_bytes();
528*4882a593Smuzhiyun 
529*4882a593Smuzhiyun 	down_read(&handshake->static_identity->lock);
530*4882a593Smuzhiyun 	down_write(&handshake->lock);
531*4882a593Smuzhiyun 
532*4882a593Smuzhiyun 	if (unlikely(!handshake->static_identity->has_identity))
533*4882a593Smuzhiyun 		goto out;
534*4882a593Smuzhiyun 
535*4882a593Smuzhiyun 	dst->header.type = cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION);
536*4882a593Smuzhiyun 
537*4882a593Smuzhiyun 	handshake_init(handshake->chaining_key, handshake->hash,
538*4882a593Smuzhiyun 		       handshake->remote_static);
539*4882a593Smuzhiyun 
540*4882a593Smuzhiyun 	/* e */
541*4882a593Smuzhiyun 	curve25519_generate_secret(handshake->ephemeral_private);
542*4882a593Smuzhiyun 	if (!curve25519_generate_public(dst->unencrypted_ephemeral,
543*4882a593Smuzhiyun 					handshake->ephemeral_private))
544*4882a593Smuzhiyun 		goto out;
545*4882a593Smuzhiyun 	message_ephemeral(dst->unencrypted_ephemeral,
546*4882a593Smuzhiyun 			  dst->unencrypted_ephemeral, handshake->chaining_key,
547*4882a593Smuzhiyun 			  handshake->hash);
548*4882a593Smuzhiyun 
549*4882a593Smuzhiyun 	/* es */
550*4882a593Smuzhiyun 	if (!mix_dh(handshake->chaining_key, key, handshake->ephemeral_private,
551*4882a593Smuzhiyun 		    handshake->remote_static))
552*4882a593Smuzhiyun 		goto out;
553*4882a593Smuzhiyun 
554*4882a593Smuzhiyun 	/* s */
555*4882a593Smuzhiyun 	message_encrypt(dst->encrypted_static,
556*4882a593Smuzhiyun 			handshake->static_identity->static_public,
557*4882a593Smuzhiyun 			NOISE_PUBLIC_KEY_LEN, key, handshake->hash);
558*4882a593Smuzhiyun 
559*4882a593Smuzhiyun 	/* ss */
560*4882a593Smuzhiyun 	if (!mix_precomputed_dh(handshake->chaining_key, key,
561*4882a593Smuzhiyun 				handshake->precomputed_static_static))
562*4882a593Smuzhiyun 		goto out;
563*4882a593Smuzhiyun 
564*4882a593Smuzhiyun 	/* {t} */
565*4882a593Smuzhiyun 	tai64n_now(timestamp);
566*4882a593Smuzhiyun 	message_encrypt(dst->encrypted_timestamp, timestamp,
567*4882a593Smuzhiyun 			NOISE_TIMESTAMP_LEN, key, handshake->hash);
568*4882a593Smuzhiyun 
569*4882a593Smuzhiyun 	dst->sender_index = wg_index_hashtable_insert(
570*4882a593Smuzhiyun 		handshake->entry.peer->device->index_hashtable,
571*4882a593Smuzhiyun 		&handshake->entry);
572*4882a593Smuzhiyun 
573*4882a593Smuzhiyun 	handshake->state = HANDSHAKE_CREATED_INITIATION;
574*4882a593Smuzhiyun 	ret = true;
575*4882a593Smuzhiyun 
576*4882a593Smuzhiyun out:
577*4882a593Smuzhiyun 	up_write(&handshake->lock);
578*4882a593Smuzhiyun 	up_read(&handshake->static_identity->lock);
579*4882a593Smuzhiyun 	memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
580*4882a593Smuzhiyun 	return ret;
581*4882a593Smuzhiyun }
582*4882a593Smuzhiyun 
583*4882a593Smuzhiyun struct wg_peer *
wg_noise_handshake_consume_initiation(struct message_handshake_initiation * src,struct wg_device * wg)584*4882a593Smuzhiyun wg_noise_handshake_consume_initiation(struct message_handshake_initiation *src,
585*4882a593Smuzhiyun 				      struct wg_device *wg)
586*4882a593Smuzhiyun {
587*4882a593Smuzhiyun 	struct wg_peer *peer = NULL, *ret_peer = NULL;
588*4882a593Smuzhiyun 	struct noise_handshake *handshake;
589*4882a593Smuzhiyun 	bool replay_attack, flood_attack;
590*4882a593Smuzhiyun 	u8 key[NOISE_SYMMETRIC_KEY_LEN];
591*4882a593Smuzhiyun 	u8 chaining_key[NOISE_HASH_LEN];
592*4882a593Smuzhiyun 	u8 hash[NOISE_HASH_LEN];
593*4882a593Smuzhiyun 	u8 s[NOISE_PUBLIC_KEY_LEN];
594*4882a593Smuzhiyun 	u8 e[NOISE_PUBLIC_KEY_LEN];
595*4882a593Smuzhiyun 	u8 t[NOISE_TIMESTAMP_LEN];
596*4882a593Smuzhiyun 	u64 initiation_consumption;
597*4882a593Smuzhiyun 
598*4882a593Smuzhiyun 	down_read(&wg->static_identity.lock);
599*4882a593Smuzhiyun 	if (unlikely(!wg->static_identity.has_identity))
600*4882a593Smuzhiyun 		goto out;
601*4882a593Smuzhiyun 
602*4882a593Smuzhiyun 	handshake_init(chaining_key, hash, wg->static_identity.static_public);
603*4882a593Smuzhiyun 
604*4882a593Smuzhiyun 	/* e */
605*4882a593Smuzhiyun 	message_ephemeral(e, src->unencrypted_ephemeral, chaining_key, hash);
606*4882a593Smuzhiyun 
607*4882a593Smuzhiyun 	/* es */
608*4882a593Smuzhiyun 	if (!mix_dh(chaining_key, key, wg->static_identity.static_private, e))
609*4882a593Smuzhiyun 		goto out;
610*4882a593Smuzhiyun 
611*4882a593Smuzhiyun 	/* s */
612*4882a593Smuzhiyun 	if (!message_decrypt(s, src->encrypted_static,
613*4882a593Smuzhiyun 			     sizeof(src->encrypted_static), key, hash))
614*4882a593Smuzhiyun 		goto out;
615*4882a593Smuzhiyun 
616*4882a593Smuzhiyun 	/* Lookup which peer we're actually talking to */
617*4882a593Smuzhiyun 	peer = wg_pubkey_hashtable_lookup(wg->peer_hashtable, s);
618*4882a593Smuzhiyun 	if (!peer)
619*4882a593Smuzhiyun 		goto out;
620*4882a593Smuzhiyun 	handshake = &peer->handshake;
621*4882a593Smuzhiyun 
622*4882a593Smuzhiyun 	/* ss */
623*4882a593Smuzhiyun 	if (!mix_precomputed_dh(chaining_key, key,
624*4882a593Smuzhiyun 				handshake->precomputed_static_static))
625*4882a593Smuzhiyun 	    goto out;
626*4882a593Smuzhiyun 
627*4882a593Smuzhiyun 	/* {t} */
628*4882a593Smuzhiyun 	if (!message_decrypt(t, src->encrypted_timestamp,
629*4882a593Smuzhiyun 			     sizeof(src->encrypted_timestamp), key, hash))
630*4882a593Smuzhiyun 		goto out;
631*4882a593Smuzhiyun 
632*4882a593Smuzhiyun 	down_read(&handshake->lock);
633*4882a593Smuzhiyun 	replay_attack = memcmp(t, handshake->latest_timestamp,
634*4882a593Smuzhiyun 			       NOISE_TIMESTAMP_LEN) <= 0;
635*4882a593Smuzhiyun 	flood_attack = (s64)handshake->last_initiation_consumption +
636*4882a593Smuzhiyun 			       NSEC_PER_SEC / INITIATIONS_PER_SECOND >
637*4882a593Smuzhiyun 		       (s64)ktime_get_coarse_boottime_ns();
638*4882a593Smuzhiyun 	up_read(&handshake->lock);
639*4882a593Smuzhiyun 	if (replay_attack || flood_attack)
640*4882a593Smuzhiyun 		goto out;
641*4882a593Smuzhiyun 
642*4882a593Smuzhiyun 	/* Success! Copy everything to peer */
643*4882a593Smuzhiyun 	down_write(&handshake->lock);
644*4882a593Smuzhiyun 	memcpy(handshake->remote_ephemeral, e, NOISE_PUBLIC_KEY_LEN);
645*4882a593Smuzhiyun 	if (memcmp(t, handshake->latest_timestamp, NOISE_TIMESTAMP_LEN) > 0)
646*4882a593Smuzhiyun 		memcpy(handshake->latest_timestamp, t, NOISE_TIMESTAMP_LEN);
647*4882a593Smuzhiyun 	memcpy(handshake->hash, hash, NOISE_HASH_LEN);
648*4882a593Smuzhiyun 	memcpy(handshake->chaining_key, chaining_key, NOISE_HASH_LEN);
649*4882a593Smuzhiyun 	handshake->remote_index = src->sender_index;
650*4882a593Smuzhiyun 	initiation_consumption = ktime_get_coarse_boottime_ns();
651*4882a593Smuzhiyun 	if ((s64)(handshake->last_initiation_consumption - initiation_consumption) < 0)
652*4882a593Smuzhiyun 		handshake->last_initiation_consumption = initiation_consumption;
653*4882a593Smuzhiyun 	handshake->state = HANDSHAKE_CONSUMED_INITIATION;
654*4882a593Smuzhiyun 	up_write(&handshake->lock);
655*4882a593Smuzhiyun 	ret_peer = peer;
656*4882a593Smuzhiyun 
657*4882a593Smuzhiyun out:
658*4882a593Smuzhiyun 	memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
659*4882a593Smuzhiyun 	memzero_explicit(hash, NOISE_HASH_LEN);
660*4882a593Smuzhiyun 	memzero_explicit(chaining_key, NOISE_HASH_LEN);
661*4882a593Smuzhiyun 	up_read(&wg->static_identity.lock);
662*4882a593Smuzhiyun 	if (!ret_peer)
663*4882a593Smuzhiyun 		wg_peer_put(peer);
664*4882a593Smuzhiyun 	return ret_peer;
665*4882a593Smuzhiyun }
666*4882a593Smuzhiyun 
wg_noise_handshake_create_response(struct message_handshake_response * dst,struct noise_handshake * handshake)667*4882a593Smuzhiyun bool wg_noise_handshake_create_response(struct message_handshake_response *dst,
668*4882a593Smuzhiyun 					struct noise_handshake *handshake)
669*4882a593Smuzhiyun {
670*4882a593Smuzhiyun 	u8 key[NOISE_SYMMETRIC_KEY_LEN];
671*4882a593Smuzhiyun 	bool ret = false;
672*4882a593Smuzhiyun 
673*4882a593Smuzhiyun 	/* We need to wait for crng _before_ taking any locks, since
674*4882a593Smuzhiyun 	 * curve25519_generate_secret uses get_random_bytes_wait.
675*4882a593Smuzhiyun 	 */
676*4882a593Smuzhiyun 	wait_for_random_bytes();
677*4882a593Smuzhiyun 
678*4882a593Smuzhiyun 	down_read(&handshake->static_identity->lock);
679*4882a593Smuzhiyun 	down_write(&handshake->lock);
680*4882a593Smuzhiyun 
681*4882a593Smuzhiyun 	if (handshake->state != HANDSHAKE_CONSUMED_INITIATION)
682*4882a593Smuzhiyun 		goto out;
683*4882a593Smuzhiyun 
684*4882a593Smuzhiyun 	dst->header.type = cpu_to_le32(MESSAGE_HANDSHAKE_RESPONSE);
685*4882a593Smuzhiyun 	dst->receiver_index = handshake->remote_index;
686*4882a593Smuzhiyun 
687*4882a593Smuzhiyun 	/* e */
688*4882a593Smuzhiyun 	curve25519_generate_secret(handshake->ephemeral_private);
689*4882a593Smuzhiyun 	if (!curve25519_generate_public(dst->unencrypted_ephemeral,
690*4882a593Smuzhiyun 					handshake->ephemeral_private))
691*4882a593Smuzhiyun 		goto out;
692*4882a593Smuzhiyun 	message_ephemeral(dst->unencrypted_ephemeral,
693*4882a593Smuzhiyun 			  dst->unencrypted_ephemeral, handshake->chaining_key,
694*4882a593Smuzhiyun 			  handshake->hash);
695*4882a593Smuzhiyun 
696*4882a593Smuzhiyun 	/* ee */
697*4882a593Smuzhiyun 	if (!mix_dh(handshake->chaining_key, NULL, handshake->ephemeral_private,
698*4882a593Smuzhiyun 		    handshake->remote_ephemeral))
699*4882a593Smuzhiyun 		goto out;
700*4882a593Smuzhiyun 
701*4882a593Smuzhiyun 	/* se */
702*4882a593Smuzhiyun 	if (!mix_dh(handshake->chaining_key, NULL, handshake->ephemeral_private,
703*4882a593Smuzhiyun 		    handshake->remote_static))
704*4882a593Smuzhiyun 		goto out;
705*4882a593Smuzhiyun 
706*4882a593Smuzhiyun 	/* psk */
707*4882a593Smuzhiyun 	mix_psk(handshake->chaining_key, handshake->hash, key,
708*4882a593Smuzhiyun 		handshake->preshared_key);
709*4882a593Smuzhiyun 
710*4882a593Smuzhiyun 	/* {} */
711*4882a593Smuzhiyun 	message_encrypt(dst->encrypted_nothing, NULL, 0, key, handshake->hash);
712*4882a593Smuzhiyun 
713*4882a593Smuzhiyun 	dst->sender_index = wg_index_hashtable_insert(
714*4882a593Smuzhiyun 		handshake->entry.peer->device->index_hashtable,
715*4882a593Smuzhiyun 		&handshake->entry);
716*4882a593Smuzhiyun 
717*4882a593Smuzhiyun 	handshake->state = HANDSHAKE_CREATED_RESPONSE;
718*4882a593Smuzhiyun 	ret = true;
719*4882a593Smuzhiyun 
720*4882a593Smuzhiyun out:
721*4882a593Smuzhiyun 	up_write(&handshake->lock);
722*4882a593Smuzhiyun 	up_read(&handshake->static_identity->lock);
723*4882a593Smuzhiyun 	memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
724*4882a593Smuzhiyun 	return ret;
725*4882a593Smuzhiyun }
726*4882a593Smuzhiyun 
727*4882a593Smuzhiyun struct wg_peer *
wg_noise_handshake_consume_response(struct message_handshake_response * src,struct wg_device * wg)728*4882a593Smuzhiyun wg_noise_handshake_consume_response(struct message_handshake_response *src,
729*4882a593Smuzhiyun 				    struct wg_device *wg)
730*4882a593Smuzhiyun {
731*4882a593Smuzhiyun 	enum noise_handshake_state state = HANDSHAKE_ZEROED;
732*4882a593Smuzhiyun 	struct wg_peer *peer = NULL, *ret_peer = NULL;
733*4882a593Smuzhiyun 	struct noise_handshake *handshake;
734*4882a593Smuzhiyun 	u8 key[NOISE_SYMMETRIC_KEY_LEN];
735*4882a593Smuzhiyun 	u8 hash[NOISE_HASH_LEN];
736*4882a593Smuzhiyun 	u8 chaining_key[NOISE_HASH_LEN];
737*4882a593Smuzhiyun 	u8 e[NOISE_PUBLIC_KEY_LEN];
738*4882a593Smuzhiyun 	u8 ephemeral_private[NOISE_PUBLIC_KEY_LEN];
739*4882a593Smuzhiyun 	u8 static_private[NOISE_PUBLIC_KEY_LEN];
740*4882a593Smuzhiyun 	u8 preshared_key[NOISE_SYMMETRIC_KEY_LEN];
741*4882a593Smuzhiyun 
742*4882a593Smuzhiyun 	down_read(&wg->static_identity.lock);
743*4882a593Smuzhiyun 
744*4882a593Smuzhiyun 	if (unlikely(!wg->static_identity.has_identity))
745*4882a593Smuzhiyun 		goto out;
746*4882a593Smuzhiyun 
747*4882a593Smuzhiyun 	handshake = (struct noise_handshake *)wg_index_hashtable_lookup(
748*4882a593Smuzhiyun 		wg->index_hashtable, INDEX_HASHTABLE_HANDSHAKE,
749*4882a593Smuzhiyun 		src->receiver_index, &peer);
750*4882a593Smuzhiyun 	if (unlikely(!handshake))
751*4882a593Smuzhiyun 		goto out;
752*4882a593Smuzhiyun 
753*4882a593Smuzhiyun 	down_read(&handshake->lock);
754*4882a593Smuzhiyun 	state = handshake->state;
755*4882a593Smuzhiyun 	memcpy(hash, handshake->hash, NOISE_HASH_LEN);
756*4882a593Smuzhiyun 	memcpy(chaining_key, handshake->chaining_key, NOISE_HASH_LEN);
757*4882a593Smuzhiyun 	memcpy(ephemeral_private, handshake->ephemeral_private,
758*4882a593Smuzhiyun 	       NOISE_PUBLIC_KEY_LEN);
759*4882a593Smuzhiyun 	memcpy(preshared_key, handshake->preshared_key,
760*4882a593Smuzhiyun 	       NOISE_SYMMETRIC_KEY_LEN);
761*4882a593Smuzhiyun 	up_read(&handshake->lock);
762*4882a593Smuzhiyun 
763*4882a593Smuzhiyun 	if (state != HANDSHAKE_CREATED_INITIATION)
764*4882a593Smuzhiyun 		goto fail;
765*4882a593Smuzhiyun 
766*4882a593Smuzhiyun 	/* e */
767*4882a593Smuzhiyun 	message_ephemeral(e, src->unencrypted_ephemeral, chaining_key, hash);
768*4882a593Smuzhiyun 
769*4882a593Smuzhiyun 	/* ee */
770*4882a593Smuzhiyun 	if (!mix_dh(chaining_key, NULL, ephemeral_private, e))
771*4882a593Smuzhiyun 		goto fail;
772*4882a593Smuzhiyun 
773*4882a593Smuzhiyun 	/* se */
774*4882a593Smuzhiyun 	if (!mix_dh(chaining_key, NULL, wg->static_identity.static_private, e))
775*4882a593Smuzhiyun 		goto fail;
776*4882a593Smuzhiyun 
777*4882a593Smuzhiyun 	/* psk */
778*4882a593Smuzhiyun 	mix_psk(chaining_key, hash, key, preshared_key);
779*4882a593Smuzhiyun 
780*4882a593Smuzhiyun 	/* {} */
781*4882a593Smuzhiyun 	if (!message_decrypt(NULL, src->encrypted_nothing,
782*4882a593Smuzhiyun 			     sizeof(src->encrypted_nothing), key, hash))
783*4882a593Smuzhiyun 		goto fail;
784*4882a593Smuzhiyun 
785*4882a593Smuzhiyun 	/* Success! Copy everything to peer */
786*4882a593Smuzhiyun 	down_write(&handshake->lock);
787*4882a593Smuzhiyun 	/* It's important to check that the state is still the same, while we
788*4882a593Smuzhiyun 	 * have an exclusive lock.
789*4882a593Smuzhiyun 	 */
790*4882a593Smuzhiyun 	if (handshake->state != state) {
791*4882a593Smuzhiyun 		up_write(&handshake->lock);
792*4882a593Smuzhiyun 		goto fail;
793*4882a593Smuzhiyun 	}
794*4882a593Smuzhiyun 	memcpy(handshake->remote_ephemeral, e, NOISE_PUBLIC_KEY_LEN);
795*4882a593Smuzhiyun 	memcpy(handshake->hash, hash, NOISE_HASH_LEN);
796*4882a593Smuzhiyun 	memcpy(handshake->chaining_key, chaining_key, NOISE_HASH_LEN);
797*4882a593Smuzhiyun 	handshake->remote_index = src->sender_index;
798*4882a593Smuzhiyun 	handshake->state = HANDSHAKE_CONSUMED_RESPONSE;
799*4882a593Smuzhiyun 	up_write(&handshake->lock);
800*4882a593Smuzhiyun 	ret_peer = peer;
801*4882a593Smuzhiyun 	goto out;
802*4882a593Smuzhiyun 
803*4882a593Smuzhiyun fail:
804*4882a593Smuzhiyun 	wg_peer_put(peer);
805*4882a593Smuzhiyun out:
806*4882a593Smuzhiyun 	memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
807*4882a593Smuzhiyun 	memzero_explicit(hash, NOISE_HASH_LEN);
808*4882a593Smuzhiyun 	memzero_explicit(chaining_key, NOISE_HASH_LEN);
809*4882a593Smuzhiyun 	memzero_explicit(ephemeral_private, NOISE_PUBLIC_KEY_LEN);
810*4882a593Smuzhiyun 	memzero_explicit(static_private, NOISE_PUBLIC_KEY_LEN);
811*4882a593Smuzhiyun 	memzero_explicit(preshared_key, NOISE_SYMMETRIC_KEY_LEN);
812*4882a593Smuzhiyun 	up_read(&wg->static_identity.lock);
813*4882a593Smuzhiyun 	return ret_peer;
814*4882a593Smuzhiyun }
815*4882a593Smuzhiyun 
wg_noise_handshake_begin_session(struct noise_handshake * handshake,struct noise_keypairs * keypairs)816*4882a593Smuzhiyun bool wg_noise_handshake_begin_session(struct noise_handshake *handshake,
817*4882a593Smuzhiyun 				      struct noise_keypairs *keypairs)
818*4882a593Smuzhiyun {
819*4882a593Smuzhiyun 	struct noise_keypair *new_keypair;
820*4882a593Smuzhiyun 	bool ret = false;
821*4882a593Smuzhiyun 
822*4882a593Smuzhiyun 	down_write(&handshake->lock);
823*4882a593Smuzhiyun 	if (handshake->state != HANDSHAKE_CREATED_RESPONSE &&
824*4882a593Smuzhiyun 	    handshake->state != HANDSHAKE_CONSUMED_RESPONSE)
825*4882a593Smuzhiyun 		goto out;
826*4882a593Smuzhiyun 
827*4882a593Smuzhiyun 	new_keypair = keypair_create(handshake->entry.peer);
828*4882a593Smuzhiyun 	if (!new_keypair)
829*4882a593Smuzhiyun 		goto out;
830*4882a593Smuzhiyun 	new_keypair->i_am_the_initiator = handshake->state ==
831*4882a593Smuzhiyun 					  HANDSHAKE_CONSUMED_RESPONSE;
832*4882a593Smuzhiyun 	new_keypair->remote_index = handshake->remote_index;
833*4882a593Smuzhiyun 
834*4882a593Smuzhiyun 	if (new_keypair->i_am_the_initiator)
835*4882a593Smuzhiyun 		derive_keys(&new_keypair->sending, &new_keypair->receiving,
836*4882a593Smuzhiyun 			    handshake->chaining_key);
837*4882a593Smuzhiyun 	else
838*4882a593Smuzhiyun 		derive_keys(&new_keypair->receiving, &new_keypair->sending,
839*4882a593Smuzhiyun 			    handshake->chaining_key);
840*4882a593Smuzhiyun 
841*4882a593Smuzhiyun 	handshake_zero(handshake);
842*4882a593Smuzhiyun 	rcu_read_lock_bh();
843*4882a593Smuzhiyun 	if (likely(!READ_ONCE(container_of(handshake, struct wg_peer,
844*4882a593Smuzhiyun 					   handshake)->is_dead))) {
845*4882a593Smuzhiyun 		add_new_keypair(keypairs, new_keypair);
846*4882a593Smuzhiyun 		net_dbg_ratelimited("%s: Keypair %llu created for peer %llu\n",
847*4882a593Smuzhiyun 				    handshake->entry.peer->device->dev->name,
848*4882a593Smuzhiyun 				    new_keypair->internal_id,
849*4882a593Smuzhiyun 				    handshake->entry.peer->internal_id);
850*4882a593Smuzhiyun 		ret = wg_index_hashtable_replace(
851*4882a593Smuzhiyun 			handshake->entry.peer->device->index_hashtable,
852*4882a593Smuzhiyun 			&handshake->entry, &new_keypair->entry);
853*4882a593Smuzhiyun 	} else {
854*4882a593Smuzhiyun 		kfree_sensitive(new_keypair);
855*4882a593Smuzhiyun 	}
856*4882a593Smuzhiyun 	rcu_read_unlock_bh();
857*4882a593Smuzhiyun 
858*4882a593Smuzhiyun out:
859*4882a593Smuzhiyun 	up_write(&handshake->lock);
860*4882a593Smuzhiyun 	return ret;
861*4882a593Smuzhiyun }
862