xref: /OK3568_Linux_fs/kernel/drivers/net/wireguard/netlink.c (revision 4882a59341e53eb6f0b4789bf948001014eff981)
1*4882a593Smuzhiyun // SPDX-License-Identifier: GPL-2.0
2*4882a593Smuzhiyun /*
3*4882a593Smuzhiyun  * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
4*4882a593Smuzhiyun  */
5*4882a593Smuzhiyun 
6*4882a593Smuzhiyun #include "netlink.h"
7*4882a593Smuzhiyun #include "device.h"
8*4882a593Smuzhiyun #include "peer.h"
9*4882a593Smuzhiyun #include "socket.h"
10*4882a593Smuzhiyun #include "queueing.h"
11*4882a593Smuzhiyun #include "messages.h"
12*4882a593Smuzhiyun 
13*4882a593Smuzhiyun #include <uapi/linux/wireguard.h>
14*4882a593Smuzhiyun 
15*4882a593Smuzhiyun #include <linux/if.h>
16*4882a593Smuzhiyun #include <net/genetlink.h>
17*4882a593Smuzhiyun #include <net/sock.h>
18*4882a593Smuzhiyun #include <crypto/algapi.h>
19*4882a593Smuzhiyun 
/* Forward declaration: the family is defined at the end of the file, after
 * genl_ops, but wg_get_device_dump() needs it for genlmsg_put().
 */
static struct genl_family genl_family;

/* Top-level attributes of WG_CMD_GET_DEVICE / WG_CMD_SET_DEVICE. Installed as
 * the family-wide policy (genl_family.policy).
 */
static const struct nla_policy device_policy[WGDEVICE_A_MAX + 1] = {
	[WGDEVICE_A_IFINDEX]		= { .type = NLA_U32 },
	/* NUL-terminated and bounded, so nla_data() can be handed directly to
	 * dev_get_by_name() in lookup_interface().
	 */
	[WGDEVICE_A_IFNAME]		= { .type = NLA_NUL_STRING, .len = IFNAMSIZ - 1 },
	/* The private key is checked against NOISE_PUBLIC_KEY_LEN, the same
	 * length used when it is dumped and set below.
	 */
	[WGDEVICE_A_PRIVATE_KEY]	= NLA_POLICY_EXACT_LEN(NOISE_PUBLIC_KEY_LEN),
	[WGDEVICE_A_PUBLIC_KEY]		= NLA_POLICY_EXACT_LEN(NOISE_PUBLIC_KEY_LEN),
	[WGDEVICE_A_FLAGS]		= { .type = NLA_U32 },
	[WGDEVICE_A_LISTEN_PORT]	= { .type = NLA_U16 },
	[WGDEVICE_A_FWMARK]		= { .type = NLA_U32 },
	/* Each nested element is parsed with peer_policy in wg_set_device(). */
	[WGDEVICE_A_PEERS]		= { .type = NLA_NESTED }
};

/* Attributes of one peer inside WGDEVICE_A_PEERS. */
static const struct nla_policy peer_policy[WGPEER_A_MAX + 1] = {
	[WGPEER_A_PUBLIC_KEY]				= NLA_POLICY_EXACT_LEN(NOISE_PUBLIC_KEY_LEN),
	[WGPEER_A_PRESHARED_KEY]			= NLA_POLICY_EXACT_LEN(NOISE_SYMMETRIC_KEY_LEN),
	[WGPEER_A_FLAGS]				= { .type = NLA_U32 },
	/* Only a lower bound here; set_peer() accepts exactly a sockaddr_in
	 * (AF_INET) or sockaddr_in6 (AF_INET6) and ignores anything else.
	 */
	[WGPEER_A_ENDPOINT]				= NLA_POLICY_MIN_LEN(sizeof(struct sockaddr)),
	[WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL]	= { .type = NLA_U16 },
	[WGPEER_A_LAST_HANDSHAKE_TIME]			= NLA_POLICY_EXACT_LEN(sizeof(struct __kernel_timespec)),
	[WGPEER_A_RX_BYTES]				= { .type = NLA_U64 },
	[WGPEER_A_TX_BYTES]				= { .type = NLA_U64 },
	/* Each nested element is parsed with allowedip_policy in set_peer(). */
	[WGPEER_A_ALLOWEDIPS]				= { .type = NLA_NESTED },
	[WGPEER_A_PROTOCOL_VERSION]			= { .type = NLA_U32 }
};

/* Attributes of one allowed-IP entry inside WGPEER_A_ALLOWEDIPS. */
static const struct nla_policy allowedip_policy[WGALLOWEDIP_A_MAX + 1] = {
	[WGALLOWEDIP_A_FAMILY]		= { .type = NLA_U16 },
	/* Only a lower bound here; set_allowedip() requires the exact
	 * in_addr / in6_addr length matching WGALLOWEDIP_A_FAMILY.
	 */
	[WGALLOWEDIP_A_IPADDR]		= NLA_POLICY_MIN_LEN(sizeof(struct in_addr)),
	[WGALLOWEDIP_A_CIDR_MASK]	= { .type = NLA_U8 }
};
51*4882a593Smuzhiyun 
lookup_interface(struct nlattr ** attrs,struct sk_buff * skb)52*4882a593Smuzhiyun static struct wg_device *lookup_interface(struct nlattr **attrs,
53*4882a593Smuzhiyun 					  struct sk_buff *skb)
54*4882a593Smuzhiyun {
55*4882a593Smuzhiyun 	struct net_device *dev = NULL;
56*4882a593Smuzhiyun 
57*4882a593Smuzhiyun 	if (!attrs[WGDEVICE_A_IFINDEX] == !attrs[WGDEVICE_A_IFNAME])
58*4882a593Smuzhiyun 		return ERR_PTR(-EBADR);
59*4882a593Smuzhiyun 	if (attrs[WGDEVICE_A_IFINDEX])
60*4882a593Smuzhiyun 		dev = dev_get_by_index(sock_net(skb->sk),
61*4882a593Smuzhiyun 				       nla_get_u32(attrs[WGDEVICE_A_IFINDEX]));
62*4882a593Smuzhiyun 	else if (attrs[WGDEVICE_A_IFNAME])
63*4882a593Smuzhiyun 		dev = dev_get_by_name(sock_net(skb->sk),
64*4882a593Smuzhiyun 				      nla_data(attrs[WGDEVICE_A_IFNAME]));
65*4882a593Smuzhiyun 	if (!dev)
66*4882a593Smuzhiyun 		return ERR_PTR(-ENODEV);
67*4882a593Smuzhiyun 	if (!dev->rtnl_link_ops || !dev->rtnl_link_ops->kind ||
68*4882a593Smuzhiyun 	    strcmp(dev->rtnl_link_ops->kind, KBUILD_MODNAME)) {
69*4882a593Smuzhiyun 		dev_put(dev);
70*4882a593Smuzhiyun 		return ERR_PTR(-EOPNOTSUPP);
71*4882a593Smuzhiyun 	}
72*4882a593Smuzhiyun 	return netdev_priv(dev);
73*4882a593Smuzhiyun }
74*4882a593Smuzhiyun 
/* Append one allowed-IP entry to @skb as an unnamed nest containing the CIDR
 * mask, the address family and the address. @ip is read as an in6_addr-sized
 * buffer for AF_INET6 and as an in_addr-sized one for any other family.
 *
 * Returns 0 on success, or -EMSGSIZE if the entry does not fit; in that case
 * any partially written nest is cancelled.
 */
static int get_allowedips(struct sk_buff *skb, const u8 *ip, u8 cidr,
			  int family)
{
	const int addr_len = family == AF_INET6 ? sizeof(struct in6_addr) :
						  sizeof(struct in_addr);
	struct nlattr *nest = nla_nest_start(skb, 0);

	if (!nest)
		return -EMSGSIZE;

	if (!nla_put_u8(skb, WGALLOWEDIP_A_CIDR_MASK, cidr) &&
	    !nla_put_u16(skb, WGALLOWEDIP_A_FAMILY, family) &&
	    !nla_put(skb, WGALLOWEDIP_A_IPADDR, addr_len, ip)) {
		nla_nest_end(skb, nest);
		return 0;
	}

	nla_nest_cancel(skb, nest);
	return -EMSGSIZE;
}
95*4882a593Smuzhiyun 
/* Resume state for a WG_CMD_GET_DEVICE dump. It is overlaid on
 * netlink_callback::args through DUMP_CTX(), so it persists across successive
 * dumpit invocations and must not outgrow that array.
 */
struct dump_ctx {
	/* Device being dumped. Set by wg_get_device_start() together with a
	 * device reference, which wg_get_device_done() releases.
	 */
	struct wg_device *wg;
	/* Last peer completely emitted, holding a peer reference; NULL means
	 * start at the head of wg->peer_list.
	 */
	struct wg_peer *next_peer;
	/* Snapshot of peer_allowedips.seq taken when a peer's allowed-IP list
	 * starts being emitted; 0 when not in the middle of such a list.
	 */
	u64 allowedips_seq;
	/* Allowed-IP node at which to resume the current peer, or NULL. */
	struct allowedips_node *next_allowedip;
};

#define DUMP_CTX(cb) ((struct dump_ctx *)(cb)->args)
104*4882a593Smuzhiyun 
/* Emit one peer as an unnamed nested attribute into @skb.
 *
 * On a fresh peer (ctx->next_allowedip == NULL) this writes the public key,
 * preshared key, last handshake time, persistent keepalive interval, TX/RX
 * byte counters, protocol version and endpoint, followed by as many allowed
 * IPs as fit. When resuming a peer whose allowed-IP list was cut short by a
 * previous call, only the public key and the remaining allowed IPs are
 * written.
 *
 * Returns 0 when the peer was completed; the allowed-IP cursor in @ctx is then
 * reset. Returns -EMSGSIZE when it did not fit:
 *  - if the allowed-IP list was interrupted, the partial peer nest is kept and
 *    ctx->next_allowedip records where to continue;
 *  - otherwise the whole peer nest is cancelled.
 */
static int
get_peer(struct wg_peer *peer, struct sk_buff *skb, struct dump_ctx *ctx)
{

	struct nlattr *allowedips_nest, *peer_nest = nla_nest_start(skb, 0);
	struct allowedips_node *allowedips_node = ctx->next_allowedip;
	bool fail;

	if (!peer_nest)
		return -EMSGSIZE;

	/* The public key identifies the peer and is sent in every fragment. */
	down_read(&peer->handshake.lock);
	fail = nla_put(skb, WGPEER_A_PUBLIC_KEY, NOISE_PUBLIC_KEY_LEN,
		       peer->handshake.remote_static);
	up_read(&peer->handshake.lock);
	if (fail)
		goto err;

	/* Scalar attributes are only emitted on the first fragment of a peer. */
	if (!allowedips_node) {
		/* Copied into a __kernel_timespec so the attribute has the exact
		 * size required by peer_policy[WGPEER_A_LAST_HANDSHAKE_TIME].
		 */
		const struct __kernel_timespec last_handshake = {
			.tv_sec = peer->walltime_last_handshake.tv_sec,
			.tv_nsec = peer->walltime_last_handshake.tv_nsec
		};

		down_read(&peer->handshake.lock);
		fail = nla_put(skb, WGPEER_A_PRESHARED_KEY,
			       NOISE_SYMMETRIC_KEY_LEN,
			       peer->handshake.preshared_key);
		up_read(&peer->handshake.lock);
		if (fail)
			goto err;

		if (nla_put(skb, WGPEER_A_LAST_HANDSHAKE_TIME,
			    sizeof(last_handshake), &last_handshake) ||
		    nla_put_u16(skb, WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL,
				peer->persistent_keepalive_interval) ||
		    nla_put_u64_64bit(skb, WGPEER_A_TX_BYTES, peer->tx_bytes,
				      WGPEER_A_UNSPEC) ||
		    nla_put_u64_64bit(skb, WGPEER_A_RX_BYTES, peer->rx_bytes,
				      WGPEER_A_UNSPEC) ||
		    nla_put_u32(skb, WGPEER_A_PROTOCOL_VERSION, 1))
			goto err;

		/* The endpoint is only emitted when it is IPv4 or IPv6; if
		 * neither, fail keeps its previous (false) value.
		 */
		read_lock_bh(&peer->endpoint_lock);
		if (peer->endpoint.addr.sa_family == AF_INET)
			fail = nla_put(skb, WGPEER_A_ENDPOINT,
				       sizeof(peer->endpoint.addr4),
				       &peer->endpoint.addr4);
		else if (peer->endpoint.addr.sa_family == AF_INET6)
			fail = nla_put(skb, WGPEER_A_ENDPOINT,
				       sizeof(peer->endpoint.addr6),
				       &peer->endpoint.addr6);
		read_unlock_bh(&peer->endpoint_lock);
		if (fail)
			goto err;
		allowedips_node =
			list_first_entry_or_null(&peer->allowedips_list,
					struct allowedips_node, peer_list);
	}
	if (!allowedips_node)
		goto no_allowedips;
	/* Record the allowed-IP table generation when starting this peer's
	 * list. If it changed since a previous fragment, stop emitting allowed
	 * IPs for this peer; presumably the saved node can no longer be trusted.
	 */
	if (!ctx->allowedips_seq)
		ctx->allowedips_seq = peer->device->peer_allowedips.seq;
	else if (ctx->allowedips_seq != peer->device->peer_allowedips.seq)
		goto no_allowedips;

	allowedips_nest = nla_nest_start(skb, WGPEER_A_ALLOWEDIPS);
	if (!allowedips_nest)
		goto err;

	list_for_each_entry_from(allowedips_node, &peer->allowedips_list,
				 peer_list) {
		/* 16 bytes: large enough for an in6_addr. */
		u8 cidr, ip[16] __aligned(__alignof(u64));
		int family;

		family = wg_allowedips_read_node(allowedips_node, ip, &cidr);
		if (get_allowedips(skb, ip, cidr, family)) {
			/* Out of room: close the nests with what was written
			 * and remember where to resume on the next call.
			 */
			nla_nest_end(skb, allowedips_nest);
			nla_nest_end(skb, peer_nest);
			ctx->next_allowedip = allowedips_node;
			return -EMSGSIZE;
		}
	}
	nla_nest_end(skb, allowedips_nest);
no_allowedips:
	/* Peer finished: reset the per-peer allowed-IP cursor. */
	nla_nest_end(skb, peer_nest);
	ctx->next_allowedip = NULL;
	ctx->allowedips_seq = 0;
	return 0;
err:
	nla_nest_cancel(skb, peer_nest);
	return -EMSGSIZE;
}
198*4882a593Smuzhiyun 
wg_get_device_start(struct netlink_callback * cb)199*4882a593Smuzhiyun static int wg_get_device_start(struct netlink_callback *cb)
200*4882a593Smuzhiyun {
201*4882a593Smuzhiyun 	struct wg_device *wg;
202*4882a593Smuzhiyun 
203*4882a593Smuzhiyun 	wg = lookup_interface(genl_dumpit_info(cb)->attrs, cb->skb);
204*4882a593Smuzhiyun 	if (IS_ERR(wg))
205*4882a593Smuzhiyun 		return PTR_ERR(wg);
206*4882a593Smuzhiyun 	DUMP_CTX(cb)->wg = wg;
207*4882a593Smuzhiyun 	return 0;
208*4882a593Smuzhiyun }
209*4882a593Smuzhiyun 
wg_get_device_dump(struct sk_buff * skb,struct netlink_callback * cb)210*4882a593Smuzhiyun static int wg_get_device_dump(struct sk_buff *skb, struct netlink_callback *cb)
211*4882a593Smuzhiyun {
212*4882a593Smuzhiyun 	struct wg_peer *peer, *next_peer_cursor;
213*4882a593Smuzhiyun 	struct dump_ctx *ctx = DUMP_CTX(cb);
214*4882a593Smuzhiyun 	struct wg_device *wg = ctx->wg;
215*4882a593Smuzhiyun 	struct nlattr *peers_nest;
216*4882a593Smuzhiyun 	int ret = -EMSGSIZE;
217*4882a593Smuzhiyun 	bool done = true;
218*4882a593Smuzhiyun 	void *hdr;
219*4882a593Smuzhiyun 
220*4882a593Smuzhiyun 	rtnl_lock();
221*4882a593Smuzhiyun 	mutex_lock(&wg->device_update_lock);
222*4882a593Smuzhiyun 	cb->seq = wg->device_update_gen;
223*4882a593Smuzhiyun 	next_peer_cursor = ctx->next_peer;
224*4882a593Smuzhiyun 
225*4882a593Smuzhiyun 	hdr = genlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq,
226*4882a593Smuzhiyun 			  &genl_family, NLM_F_MULTI, WG_CMD_GET_DEVICE);
227*4882a593Smuzhiyun 	if (!hdr)
228*4882a593Smuzhiyun 		goto out;
229*4882a593Smuzhiyun 	genl_dump_check_consistent(cb, hdr);
230*4882a593Smuzhiyun 
231*4882a593Smuzhiyun 	if (!ctx->next_peer) {
232*4882a593Smuzhiyun 		if (nla_put_u16(skb, WGDEVICE_A_LISTEN_PORT,
233*4882a593Smuzhiyun 				wg->incoming_port) ||
234*4882a593Smuzhiyun 		    nla_put_u32(skb, WGDEVICE_A_FWMARK, wg->fwmark) ||
235*4882a593Smuzhiyun 		    nla_put_u32(skb, WGDEVICE_A_IFINDEX, wg->dev->ifindex) ||
236*4882a593Smuzhiyun 		    nla_put_string(skb, WGDEVICE_A_IFNAME, wg->dev->name))
237*4882a593Smuzhiyun 			goto out;
238*4882a593Smuzhiyun 
239*4882a593Smuzhiyun 		down_read(&wg->static_identity.lock);
240*4882a593Smuzhiyun 		if (wg->static_identity.has_identity) {
241*4882a593Smuzhiyun 			if (nla_put(skb, WGDEVICE_A_PRIVATE_KEY,
242*4882a593Smuzhiyun 				    NOISE_PUBLIC_KEY_LEN,
243*4882a593Smuzhiyun 				    wg->static_identity.static_private) ||
244*4882a593Smuzhiyun 			    nla_put(skb, WGDEVICE_A_PUBLIC_KEY,
245*4882a593Smuzhiyun 				    NOISE_PUBLIC_KEY_LEN,
246*4882a593Smuzhiyun 				    wg->static_identity.static_public)) {
247*4882a593Smuzhiyun 				up_read(&wg->static_identity.lock);
248*4882a593Smuzhiyun 				goto out;
249*4882a593Smuzhiyun 			}
250*4882a593Smuzhiyun 		}
251*4882a593Smuzhiyun 		up_read(&wg->static_identity.lock);
252*4882a593Smuzhiyun 	}
253*4882a593Smuzhiyun 
254*4882a593Smuzhiyun 	peers_nest = nla_nest_start(skb, WGDEVICE_A_PEERS);
255*4882a593Smuzhiyun 	if (!peers_nest)
256*4882a593Smuzhiyun 		goto out;
257*4882a593Smuzhiyun 	ret = 0;
258*4882a593Smuzhiyun 	/* If the last cursor was removed via list_del_init in peer_remove, then
259*4882a593Smuzhiyun 	 * we just treat this the same as there being no more peers left. The
260*4882a593Smuzhiyun 	 * reason is that seq_nr should indicate to userspace that this isn't a
261*4882a593Smuzhiyun 	 * coherent dump anyway, so they'll try again.
262*4882a593Smuzhiyun 	 */
263*4882a593Smuzhiyun 	if (list_empty(&wg->peer_list) ||
264*4882a593Smuzhiyun 	    (ctx->next_peer && list_empty(&ctx->next_peer->peer_list))) {
265*4882a593Smuzhiyun 		nla_nest_cancel(skb, peers_nest);
266*4882a593Smuzhiyun 		goto out;
267*4882a593Smuzhiyun 	}
268*4882a593Smuzhiyun 	lockdep_assert_held(&wg->device_update_lock);
269*4882a593Smuzhiyun 	peer = list_prepare_entry(ctx->next_peer, &wg->peer_list, peer_list);
270*4882a593Smuzhiyun 	list_for_each_entry_continue(peer, &wg->peer_list, peer_list) {
271*4882a593Smuzhiyun 		if (get_peer(peer, skb, ctx)) {
272*4882a593Smuzhiyun 			done = false;
273*4882a593Smuzhiyun 			break;
274*4882a593Smuzhiyun 		}
275*4882a593Smuzhiyun 		next_peer_cursor = peer;
276*4882a593Smuzhiyun 	}
277*4882a593Smuzhiyun 	nla_nest_end(skb, peers_nest);
278*4882a593Smuzhiyun 
279*4882a593Smuzhiyun out:
280*4882a593Smuzhiyun 	if (!ret && !done && next_peer_cursor)
281*4882a593Smuzhiyun 		wg_peer_get(next_peer_cursor);
282*4882a593Smuzhiyun 	wg_peer_put(ctx->next_peer);
283*4882a593Smuzhiyun 	mutex_unlock(&wg->device_update_lock);
284*4882a593Smuzhiyun 	rtnl_unlock();
285*4882a593Smuzhiyun 
286*4882a593Smuzhiyun 	if (ret) {
287*4882a593Smuzhiyun 		genlmsg_cancel(skb, hdr);
288*4882a593Smuzhiyun 		return ret;
289*4882a593Smuzhiyun 	}
290*4882a593Smuzhiyun 	genlmsg_end(skb, hdr);
291*4882a593Smuzhiyun 	if (done) {
292*4882a593Smuzhiyun 		ctx->next_peer = NULL;
293*4882a593Smuzhiyun 		return 0;
294*4882a593Smuzhiyun 	}
295*4882a593Smuzhiyun 	ctx->next_peer = next_peer_cursor;
296*4882a593Smuzhiyun 	return skb->len;
297*4882a593Smuzhiyun 
298*4882a593Smuzhiyun 	/* At this point, we can't really deal ourselves with safely zeroing out
299*4882a593Smuzhiyun 	 * the private key material after usage. This will need an additional API
300*4882a593Smuzhiyun 	 * in the kernel for marking skbs as zero_on_free.
301*4882a593Smuzhiyun 	 */
302*4882a593Smuzhiyun }
303*4882a593Smuzhiyun 
wg_get_device_done(struct netlink_callback * cb)304*4882a593Smuzhiyun static int wg_get_device_done(struct netlink_callback *cb)
305*4882a593Smuzhiyun {
306*4882a593Smuzhiyun 	struct dump_ctx *ctx = DUMP_CTX(cb);
307*4882a593Smuzhiyun 
308*4882a593Smuzhiyun 	if (ctx->wg)
309*4882a593Smuzhiyun 		dev_put(ctx->wg->dev);
310*4882a593Smuzhiyun 	wg_peer_put(ctx->next_peer);
311*4882a593Smuzhiyun 	return 0;
312*4882a593Smuzhiyun }
313*4882a593Smuzhiyun 
/* Change the device's UDP listen port to @port.
 *
 * A no-op when the port is unchanged. Otherwise every peer's endpoint source
 * is cleared (wg_socket_clear_peer_endpoint_src()). If the interface is down,
 * the port is simply recorded in wg->incoming_port. If it is up, the new port
 * is handed to wg_socket_init() and its result is returned.
 */
static int set_port(struct wg_device *wg, u16 port)
{
	struct wg_peer *p;

	if (wg->incoming_port != port) {
		list_for_each_entry(p, &wg->peer_list, peer_list)
			wg_socket_clear_peer_endpoint_src(p);
		if (netif_running(wg->dev))
			return wg_socket_init(wg, port);
		wg->incoming_port = port;
	}
	return 0;
}
328*4882a593Smuzhiyun 
set_allowedip(struct wg_peer * peer,struct nlattr ** attrs)329*4882a593Smuzhiyun static int set_allowedip(struct wg_peer *peer, struct nlattr **attrs)
330*4882a593Smuzhiyun {
331*4882a593Smuzhiyun 	int ret = -EINVAL;
332*4882a593Smuzhiyun 	u16 family;
333*4882a593Smuzhiyun 	u8 cidr;
334*4882a593Smuzhiyun 
335*4882a593Smuzhiyun 	if (!attrs[WGALLOWEDIP_A_FAMILY] || !attrs[WGALLOWEDIP_A_IPADDR] ||
336*4882a593Smuzhiyun 	    !attrs[WGALLOWEDIP_A_CIDR_MASK])
337*4882a593Smuzhiyun 		return ret;
338*4882a593Smuzhiyun 	family = nla_get_u16(attrs[WGALLOWEDIP_A_FAMILY]);
339*4882a593Smuzhiyun 	cidr = nla_get_u8(attrs[WGALLOWEDIP_A_CIDR_MASK]);
340*4882a593Smuzhiyun 
341*4882a593Smuzhiyun 	if (family == AF_INET && cidr <= 32 &&
342*4882a593Smuzhiyun 	    nla_len(attrs[WGALLOWEDIP_A_IPADDR]) == sizeof(struct in_addr))
343*4882a593Smuzhiyun 		ret = wg_allowedips_insert_v4(
344*4882a593Smuzhiyun 			&peer->device->peer_allowedips,
345*4882a593Smuzhiyun 			nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, peer,
346*4882a593Smuzhiyun 			&peer->device->device_update_lock);
347*4882a593Smuzhiyun 	else if (family == AF_INET6 && cidr <= 128 &&
348*4882a593Smuzhiyun 		 nla_len(attrs[WGALLOWEDIP_A_IPADDR]) == sizeof(struct in6_addr))
349*4882a593Smuzhiyun 		ret = wg_allowedips_insert_v6(
350*4882a593Smuzhiyun 			&peer->device->peer_allowedips,
351*4882a593Smuzhiyun 			nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, peer,
352*4882a593Smuzhiyun 			&peer->device->device_update_lock);
353*4882a593Smuzhiyun 
354*4882a593Smuzhiyun 	return ret;
355*4882a593Smuzhiyun }
356*4882a593Smuzhiyun 
/* Apply one peer configuration (a nest parsed with peer_policy) to @wg.
 *
 * The peer is looked up by WGPEER_A_PUBLIC_KEY, which is required. If no such
 * peer exists, one is created, unless WGPEER_F_REMOVE_ME or
 * WGPEER_F_UPDATE_ONLY is set, in which case nothing is done. A would-be new
 * peer whose public key equals the device's own is silently ignored.
 *
 * For the (possibly new) peer, WGPEER_F_REMOVE_ME removes it. Otherwise the
 * following are updated when supplied:
 *  - the preshared key;
 *  - the endpoint (only an exact sockaddr_in/AF_INET or sockaddr_in6/AF_INET6
 *    is accepted, anything else is ignored);
 *  - the allowed IPs, with the old ones dropped first if
 *    WGPEER_F_REPLACE_ALLOWEDIPS is set;
 *  - the persistent keepalive interval.
 * If the interface is running, staged packets are then sent.
 *
 * Called from wg_set_device() with RTNL and wg->device_update_lock held.
 * The preshared key attribute is zeroed before returning.
 *
 * Returns 0 on success or a negative errno: -EINVAL (missing/wrong-size public
 * key, invalid allowed IP), -EOPNOTSUPP (unknown flags), -EPFNOSUPPORT
 * (protocol version other than 1), or an error from wg_peer_create() or
 * nla_parse_nested().
 */
static int set_peer(struct wg_device *wg, struct nlattr **attrs)
{
	u8 *public_key = NULL, *preshared_key = NULL;
	struct wg_peer *peer = NULL;
	u32 flags = 0;
	int ret;

	ret = -EINVAL;
	if (attrs[WGPEER_A_PUBLIC_KEY] &&
	    nla_len(attrs[WGPEER_A_PUBLIC_KEY]) == NOISE_PUBLIC_KEY_LEN)
		public_key = nla_data(attrs[WGPEER_A_PUBLIC_KEY]);
	else
		goto out;
	if (attrs[WGPEER_A_PRESHARED_KEY] &&
	    nla_len(attrs[WGPEER_A_PRESHARED_KEY]) == NOISE_SYMMETRIC_KEY_LEN)
		preshared_key = nla_data(attrs[WGPEER_A_PRESHARED_KEY]);

	if (attrs[WGPEER_A_FLAGS])
		flags = nla_get_u32(attrs[WGPEER_A_FLAGS]);
	ret = -EOPNOTSUPP;
	if (flags & ~__WGPEER_F_ALL)
		goto out;

	ret = -EPFNOSUPPORT;
	if (attrs[WGPEER_A_PROTOCOL_VERSION]) {
		if (nla_get_u32(attrs[WGPEER_A_PROTOCOL_VERSION]) != 1)
			goto out;
	}

	/* On success the lookup hands back a peer reference, released at out. */
	peer = wg_pubkey_hashtable_lookup(wg->peer_hashtable,
					  nla_data(attrs[WGPEER_A_PUBLIC_KEY]));
	ret = 0;
	if (!peer) { /* Peer doesn't exist yet. Add a new one. */
		if (flags & (WGPEER_F_REMOVE_ME | WGPEER_F_UPDATE_ONLY))
			goto out;

		/* The peer is new, so there aren't allowed IPs to remove. */
		flags &= ~WGPEER_F_REPLACE_ALLOWEDIPS;

		down_read(&wg->static_identity.lock);
		if (wg->static_identity.has_identity &&
		    !memcmp(nla_data(attrs[WGPEER_A_PUBLIC_KEY]),
			    wg->static_identity.static_public,
			    NOISE_PUBLIC_KEY_LEN)) {
			/* We silently ignore peers that have the same public
			 * key as the device. The reason we do it silently is
			 * that we'd like for people to be able to reuse the
			 * same set of API calls across peers.
			 */
			up_read(&wg->static_identity.lock);
			ret = 0;
			goto out;
		}
		up_read(&wg->static_identity.lock);

		peer = wg_peer_create(wg, public_key, preshared_key);
		if (IS_ERR(peer)) {
			ret = PTR_ERR(peer);
			peer = NULL;
			goto out;
		}
		/* Take additional reference, as though we've just been
		 * looked up.
		 */
		wg_peer_get(peer);
	}

	if (flags & WGPEER_F_REMOVE_ME) {
		wg_peer_remove(peer);
		goto out;
	}

	if (preshared_key) {
		down_write(&peer->handshake.lock);
		memcpy(&peer->handshake.preshared_key, preshared_key,
		       NOISE_SYMMETRIC_KEY_LEN);
		up_write(&peer->handshake.lock);
	}

	/* Endpoints of any other size/family combination are ignored. */
	if (attrs[WGPEER_A_ENDPOINT]) {
		struct sockaddr *addr = nla_data(attrs[WGPEER_A_ENDPOINT]);
		size_t len = nla_len(attrs[WGPEER_A_ENDPOINT]);
		struct endpoint endpoint = { { { 0 } } };

		if (len == sizeof(struct sockaddr_in) && addr->sa_family == AF_INET) {
			endpoint.addr4 = *(struct sockaddr_in *)addr;
			wg_socket_set_peer_endpoint(peer, &endpoint);
		} else if (len == sizeof(struct sockaddr_in6) && addr->sa_family == AF_INET6) {
			endpoint.addr6 = *(struct sockaddr_in6 *)addr;
			wg_socket_set_peer_endpoint(peer, &endpoint);
		}
	}

	if (flags & WGPEER_F_REPLACE_ALLOWEDIPS)
		wg_allowedips_remove_by_peer(&wg->peer_allowedips, peer,
					     &wg->device_update_lock);

	if (attrs[WGPEER_A_ALLOWEDIPS]) {
		struct nlattr *attr, *allowedip[WGALLOWEDIP_A_MAX + 1];
		int rem;

		/* Stop at the first malformed or rejected entry; entries
		 * before it remain applied.
		 */
		nla_for_each_nested(attr, attrs[WGPEER_A_ALLOWEDIPS], rem) {
			ret = nla_parse_nested(allowedip, WGALLOWEDIP_A_MAX,
					       attr, allowedip_policy, NULL);
			if (ret < 0)
				goto out;
			ret = set_allowedip(peer, allowedip);
			if (ret < 0)
				goto out;
		}
	}

	if (attrs[WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL]) {
		const u16 persistent_keepalive_interval = nla_get_u16(
				attrs[WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL]);
		/* Send a keepalive right away only when turning the feature
		 * on (0 -> non-zero) on a running interface.
		 */
		const bool send_keepalive =
			!peer->persistent_keepalive_interval &&
			persistent_keepalive_interval &&
			netif_running(wg->dev);

		peer->persistent_keepalive_interval = persistent_keepalive_interval;
		if (send_keepalive)
			wg_packet_send_keepalive(peer);
	}

	if (netif_running(wg->dev))
		wg_packet_send_staged_packets(peer);

out:
	wg_peer_put(peer);
	/* Don't leave key material lying around in the request buffer. */
	if (attrs[WGPEER_A_PRESHARED_KEY])
		memzero_explicit(nla_data(attrs[WGPEER_A_PRESHARED_KEY]),
				 nla_len(attrs[WGPEER_A_PRESHARED_KEY]));
	return ret;
}
492*4882a593Smuzhiyun 
wg_set_device(struct sk_buff * skb,struct genl_info * info)493*4882a593Smuzhiyun static int wg_set_device(struct sk_buff *skb, struct genl_info *info)
494*4882a593Smuzhiyun {
495*4882a593Smuzhiyun 	struct wg_device *wg = lookup_interface(info->attrs, skb);
496*4882a593Smuzhiyun 	u32 flags = 0;
497*4882a593Smuzhiyun 	int ret;
498*4882a593Smuzhiyun 
499*4882a593Smuzhiyun 	if (IS_ERR(wg)) {
500*4882a593Smuzhiyun 		ret = PTR_ERR(wg);
501*4882a593Smuzhiyun 		goto out_nodev;
502*4882a593Smuzhiyun 	}
503*4882a593Smuzhiyun 
504*4882a593Smuzhiyun 	rtnl_lock();
505*4882a593Smuzhiyun 	mutex_lock(&wg->device_update_lock);
506*4882a593Smuzhiyun 
507*4882a593Smuzhiyun 	if (info->attrs[WGDEVICE_A_FLAGS])
508*4882a593Smuzhiyun 		flags = nla_get_u32(info->attrs[WGDEVICE_A_FLAGS]);
509*4882a593Smuzhiyun 	ret = -EOPNOTSUPP;
510*4882a593Smuzhiyun 	if (flags & ~__WGDEVICE_F_ALL)
511*4882a593Smuzhiyun 		goto out;
512*4882a593Smuzhiyun 
513*4882a593Smuzhiyun 	if (info->attrs[WGDEVICE_A_LISTEN_PORT] || info->attrs[WGDEVICE_A_FWMARK]) {
514*4882a593Smuzhiyun 		struct net *net;
515*4882a593Smuzhiyun 		rcu_read_lock();
516*4882a593Smuzhiyun 		net = rcu_dereference(wg->creating_net);
517*4882a593Smuzhiyun 		ret = !net || !ns_capable(net->user_ns, CAP_NET_ADMIN) ? -EPERM : 0;
518*4882a593Smuzhiyun 		rcu_read_unlock();
519*4882a593Smuzhiyun 		if (ret)
520*4882a593Smuzhiyun 			goto out;
521*4882a593Smuzhiyun 	}
522*4882a593Smuzhiyun 
523*4882a593Smuzhiyun 	++wg->device_update_gen;
524*4882a593Smuzhiyun 
525*4882a593Smuzhiyun 	if (info->attrs[WGDEVICE_A_FWMARK]) {
526*4882a593Smuzhiyun 		struct wg_peer *peer;
527*4882a593Smuzhiyun 
528*4882a593Smuzhiyun 		wg->fwmark = nla_get_u32(info->attrs[WGDEVICE_A_FWMARK]);
529*4882a593Smuzhiyun 		list_for_each_entry(peer, &wg->peer_list, peer_list)
530*4882a593Smuzhiyun 			wg_socket_clear_peer_endpoint_src(peer);
531*4882a593Smuzhiyun 	}
532*4882a593Smuzhiyun 
533*4882a593Smuzhiyun 	if (info->attrs[WGDEVICE_A_LISTEN_PORT]) {
534*4882a593Smuzhiyun 		ret = set_port(wg,
535*4882a593Smuzhiyun 			nla_get_u16(info->attrs[WGDEVICE_A_LISTEN_PORT]));
536*4882a593Smuzhiyun 		if (ret)
537*4882a593Smuzhiyun 			goto out;
538*4882a593Smuzhiyun 	}
539*4882a593Smuzhiyun 
540*4882a593Smuzhiyun 	if (flags & WGDEVICE_F_REPLACE_PEERS)
541*4882a593Smuzhiyun 		wg_peer_remove_all(wg);
542*4882a593Smuzhiyun 
543*4882a593Smuzhiyun 	if (info->attrs[WGDEVICE_A_PRIVATE_KEY] &&
544*4882a593Smuzhiyun 	    nla_len(info->attrs[WGDEVICE_A_PRIVATE_KEY]) ==
545*4882a593Smuzhiyun 		    NOISE_PUBLIC_KEY_LEN) {
546*4882a593Smuzhiyun 		u8 *private_key = nla_data(info->attrs[WGDEVICE_A_PRIVATE_KEY]);
547*4882a593Smuzhiyun 		u8 public_key[NOISE_PUBLIC_KEY_LEN];
548*4882a593Smuzhiyun 		struct wg_peer *peer, *temp;
549*4882a593Smuzhiyun 
550*4882a593Smuzhiyun 		if (!crypto_memneq(wg->static_identity.static_private,
551*4882a593Smuzhiyun 				   private_key, NOISE_PUBLIC_KEY_LEN))
552*4882a593Smuzhiyun 			goto skip_set_private_key;
553*4882a593Smuzhiyun 
554*4882a593Smuzhiyun 		/* We remove before setting, to prevent race, which means doing
555*4882a593Smuzhiyun 		 * two 25519-genpub ops.
556*4882a593Smuzhiyun 		 */
557*4882a593Smuzhiyun 		if (curve25519_generate_public(public_key, private_key)) {
558*4882a593Smuzhiyun 			peer = wg_pubkey_hashtable_lookup(wg->peer_hashtable,
559*4882a593Smuzhiyun 							  public_key);
560*4882a593Smuzhiyun 			if (peer) {
561*4882a593Smuzhiyun 				wg_peer_put(peer);
562*4882a593Smuzhiyun 				wg_peer_remove(peer);
563*4882a593Smuzhiyun 			}
564*4882a593Smuzhiyun 		}
565*4882a593Smuzhiyun 
566*4882a593Smuzhiyun 		down_write(&wg->static_identity.lock);
567*4882a593Smuzhiyun 		wg_noise_set_static_identity_private_key(&wg->static_identity,
568*4882a593Smuzhiyun 							 private_key);
569*4882a593Smuzhiyun 		list_for_each_entry_safe(peer, temp, &wg->peer_list,
570*4882a593Smuzhiyun 					 peer_list) {
571*4882a593Smuzhiyun 			wg_noise_precompute_static_static(peer);
572*4882a593Smuzhiyun 			wg_noise_expire_current_peer_keypairs(peer);
573*4882a593Smuzhiyun 		}
574*4882a593Smuzhiyun 		wg_cookie_checker_precompute_device_keys(&wg->cookie_checker);
575*4882a593Smuzhiyun 		up_write(&wg->static_identity.lock);
576*4882a593Smuzhiyun 	}
577*4882a593Smuzhiyun skip_set_private_key:
578*4882a593Smuzhiyun 
579*4882a593Smuzhiyun 	if (info->attrs[WGDEVICE_A_PEERS]) {
580*4882a593Smuzhiyun 		struct nlattr *attr, *peer[WGPEER_A_MAX + 1];
581*4882a593Smuzhiyun 		int rem;
582*4882a593Smuzhiyun 
583*4882a593Smuzhiyun 		nla_for_each_nested(attr, info->attrs[WGDEVICE_A_PEERS], rem) {
584*4882a593Smuzhiyun 			ret = nla_parse_nested(peer, WGPEER_A_MAX, attr,
585*4882a593Smuzhiyun 					       peer_policy, NULL);
586*4882a593Smuzhiyun 			if (ret < 0)
587*4882a593Smuzhiyun 				goto out;
588*4882a593Smuzhiyun 			ret = set_peer(wg, peer);
589*4882a593Smuzhiyun 			if (ret < 0)
590*4882a593Smuzhiyun 				goto out;
591*4882a593Smuzhiyun 		}
592*4882a593Smuzhiyun 	}
593*4882a593Smuzhiyun 	ret = 0;
594*4882a593Smuzhiyun 
595*4882a593Smuzhiyun out:
596*4882a593Smuzhiyun 	mutex_unlock(&wg->device_update_lock);
597*4882a593Smuzhiyun 	rtnl_unlock();
598*4882a593Smuzhiyun 	dev_put(wg->dev);
599*4882a593Smuzhiyun out_nodev:
600*4882a593Smuzhiyun 	if (info->attrs[WGDEVICE_A_PRIVATE_KEY])
601*4882a593Smuzhiyun 		memzero_explicit(nla_data(info->attrs[WGDEVICE_A_PRIVATE_KEY]),
602*4882a593Smuzhiyun 				 nla_len(info->attrs[WGDEVICE_A_PRIVATE_KEY]));
603*4882a593Smuzhiyun 	return ret;
604*4882a593Smuzhiyun }
605*4882a593Smuzhiyun 
/* Generic netlink commands. GET_DEVICE is a dump with start/done callbacks
 * managing the per-dump cursor in DUMP_CTX(); SET_DEVICE is a plain doit.
 * Both carry GENL_UNS_ADMIN_PERM, i.e. require CAP_NET_ADMIN with respect to
 * the user namespace owning the request's network namespace.
 */
static const struct genl_ops genl_ops[] = {
	{
		.cmd = WG_CMD_GET_DEVICE,
		.start = wg_get_device_start,
		.dumpit = wg_get_device_dump,
		.done = wg_get_device_done,
		.flags = GENL_UNS_ADMIN_PERM
	}, {
		.cmd = WG_CMD_SET_DEVICE,
		.doit = wg_set_device,
		.flags = GENL_UNS_ADMIN_PERM
	}
};
619*4882a593Smuzhiyun 
/* The generic netlink family, registered as WG_GENL_NAME. Top-level request
 * attributes are validated against device_policy; netnsok permits use from
 * network namespaces other than the initial one.
 */
static struct genl_family genl_family __ro_after_init = {
	.ops = genl_ops,
	.n_ops = ARRAY_SIZE(genl_ops),
	.name = WG_GENL_NAME,
	.version = WG_GENL_VERSION,
	.maxattr = WGDEVICE_A_MAX,
	.module = THIS_MODULE,
	.policy = device_policy,
	.netnsok = true
};
630*4882a593Smuzhiyun 
wg_genetlink_init(void)631*4882a593Smuzhiyun int __init wg_genetlink_init(void)
632*4882a593Smuzhiyun {
633*4882a593Smuzhiyun 	return genl_register_family(&genl_family);
634*4882a593Smuzhiyun }
635*4882a593Smuzhiyun 
wg_genetlink_uninit(void)636*4882a593Smuzhiyun void __exit wg_genetlink_uninit(void)
637*4882a593Smuzhiyun {
638*4882a593Smuzhiyun 	genl_unregister_family(&genl_family);
639*4882a593Smuzhiyun }
640