xref: /OK3568_Linux_fs/kernel/net/mptcp/token.c (revision 4882a59341e53eb6f0b4789bf948001014eff981)
// SPDX-License-Identifier: GPL-2.0
/* Multipath TCP token management
 * Copyright (c) 2017 - 2019, Intel Corporation.
 *
 * Note: This code is based on mptcp_ctrl.c from multipath-tcp.org,
 *       authored by:
 *
 *       Sébastien Barré <sebastien.barre@uclouvain.be>
 *       Christoph Paasch <christoph.paasch@uclouvain.be>
 *       Jaakko Korkeaniemi <jaakko.korkeaniemi@aalto.fi>
 *       Gregory Detal <gregory.detal@uclouvain.be>
 *       Fabien Duchêne <fabien.duchene@uclouvain.be>
 *       Andreas Seelinger <Andreas.Seelinger@rwth-aachen.de>
 *       Lavkesh Lahngir <lavkesh51@gmail.com>
 *       Andreas Ripke <ripke@neclab.eu>
 *       Vlad Dogaru <vlad.dogaru@intel.com>
 *       Octavian Purdila <octavian.purdila@intel.com>
 *       John Ronan <jronan@tssg.org>
 *       Catalin Nicutar <catalin.nicutar@gmail.com>
 *       Brandon Heller <brandonh@stanford.edu>
 */
22*4882a593Smuzhiyun 
23*4882a593Smuzhiyun #define pr_fmt(fmt) "MPTCP: " fmt
24*4882a593Smuzhiyun 
25*4882a593Smuzhiyun #include <linux/kernel.h>
26*4882a593Smuzhiyun #include <linux/module.h>
27*4882a593Smuzhiyun #include <linux/memblock.h>
28*4882a593Smuzhiyun #include <linux/ip.h>
29*4882a593Smuzhiyun #include <linux/tcp.h>
30*4882a593Smuzhiyun #include <net/sock.h>
31*4882a593Smuzhiyun #include <net/inet_common.h>
32*4882a593Smuzhiyun #include <net/protocol.h>
33*4882a593Smuzhiyun #include <net/mptcp.h>
34*4882a593Smuzhiyun #include "protocol.h"
35*4882a593Smuzhiyun 
/* Number of random key generation attempts made by
 * mptcp_token_new_connect() before giving up with -EBUSY.
 */
#define TOKEN_MAX_RETRIES	4
/* A bucket whose chain_len reached this value refuses new tokens,
 * see __token_bucket_busy().
 */
#define TOKEN_MAX_CHAIN_LEN	4
38*4882a593Smuzhiyun 
39*4882a593Smuzhiyun struct token_bucket {
40*4882a593Smuzhiyun 	spinlock_t		lock;
41*4882a593Smuzhiyun 	int			chain_len;
42*4882a593Smuzhiyun 	struct hlist_nulls_head	req_chain;
43*4882a593Smuzhiyun 	struct hlist_nulls_head	msk_chain;
44*4882a593Smuzhiyun };
45*4882a593Smuzhiyun 
46*4882a593Smuzhiyun static struct token_bucket *token_hash __read_mostly;
47*4882a593Smuzhiyun static unsigned int token_mask __read_mostly;
48*4882a593Smuzhiyun 
/* Return the hash bucket @token belongs to: the bits of @token selected by
 * token_mask are used as index into token_hash.
 */
static struct token_bucket *token_bucket(u32 token)
{
	return &token_hash[token & token_mask];
}
53*4882a593Smuzhiyun 
54*4882a593Smuzhiyun /* called with bucket lock held */
55*4882a593Smuzhiyun static struct mptcp_subflow_request_sock *
__token_lookup_req(struct token_bucket * t,u32 token)56*4882a593Smuzhiyun __token_lookup_req(struct token_bucket *t, u32 token)
57*4882a593Smuzhiyun {
58*4882a593Smuzhiyun 	struct mptcp_subflow_request_sock *req;
59*4882a593Smuzhiyun 	struct hlist_nulls_node *pos;
60*4882a593Smuzhiyun 
61*4882a593Smuzhiyun 	hlist_nulls_for_each_entry_rcu(req, pos, &t->req_chain, token_node)
62*4882a593Smuzhiyun 		if (req->token == token)
63*4882a593Smuzhiyun 			return req;
64*4882a593Smuzhiyun 	return NULL;
65*4882a593Smuzhiyun }
66*4882a593Smuzhiyun 
67*4882a593Smuzhiyun /* called with bucket lock held */
68*4882a593Smuzhiyun static struct mptcp_sock *
__token_lookup_msk(struct token_bucket * t,u32 token)69*4882a593Smuzhiyun __token_lookup_msk(struct token_bucket *t, u32 token)
70*4882a593Smuzhiyun {
71*4882a593Smuzhiyun 	struct hlist_nulls_node *pos;
72*4882a593Smuzhiyun 	struct sock *sk;
73*4882a593Smuzhiyun 
74*4882a593Smuzhiyun 	sk_nulls_for_each_rcu(sk, pos, &t->msk_chain)
75*4882a593Smuzhiyun 		if (mptcp_sk(sk)->token == token)
76*4882a593Smuzhiyun 			return mptcp_sk(sk);
77*4882a593Smuzhiyun 	return NULL;
78*4882a593Smuzhiyun }
79*4882a593Smuzhiyun 
__token_bucket_busy(struct token_bucket * t,u32 token)80*4882a593Smuzhiyun static bool __token_bucket_busy(struct token_bucket *t, u32 token)
81*4882a593Smuzhiyun {
82*4882a593Smuzhiyun 	return !token || t->chain_len >= TOKEN_MAX_CHAIN_LEN ||
83*4882a593Smuzhiyun 	       __token_lookup_req(t, token) || __token_lookup_msk(t, token);
84*4882a593Smuzhiyun }
85*4882a593Smuzhiyun 
/* Fill @key with random data, then hand it to mptcp_crypto_key_sha() to
 * obtain the matching @token and @idsn.
 */
static void mptcp_crypto_key_gen_sha(u64 *key, u32 *token, u64 *idsn)
{
	/* we might consider a faster version that computes the key as a
	 * hash of some information available in the MPTCP socket. Use
	 * random data at the moment, as it's probably the safest option
	 * in case multiple sockets are opened in different namespaces at
	 * the same time.
	 */
	get_random_bytes(key, sizeof(u64));
	mptcp_crypto_key_sha(*key, token, idsn);
}
97*4882a593Smuzhiyun 
98*4882a593Smuzhiyun /**
99*4882a593Smuzhiyun  * mptcp_token_new_request - create new key/idsn/token for subflow_request
100*4882a593Smuzhiyun  * @req: the request socket
101*4882a593Smuzhiyun  *
102*4882a593Smuzhiyun  * This function is called when a new mptcp connection is coming in.
103*4882a593Smuzhiyun  *
104*4882a593Smuzhiyun  * It creates a unique token to identify the new mptcp connection,
105*4882a593Smuzhiyun  * a secret local key and the initial data sequence number (idsn).
106*4882a593Smuzhiyun  *
107*4882a593Smuzhiyun  * Returns 0 on success.
108*4882a593Smuzhiyun  */
mptcp_token_new_request(struct request_sock * req)109*4882a593Smuzhiyun int mptcp_token_new_request(struct request_sock *req)
110*4882a593Smuzhiyun {
111*4882a593Smuzhiyun 	struct mptcp_subflow_request_sock *subflow_req = mptcp_subflow_rsk(req);
112*4882a593Smuzhiyun 	struct token_bucket *bucket;
113*4882a593Smuzhiyun 	u32 token;
114*4882a593Smuzhiyun 
115*4882a593Smuzhiyun 	mptcp_crypto_key_sha(subflow_req->local_key,
116*4882a593Smuzhiyun 			     &subflow_req->token,
117*4882a593Smuzhiyun 			     &subflow_req->idsn);
118*4882a593Smuzhiyun 	pr_debug("req=%p local_key=%llu, token=%u, idsn=%llu\n",
119*4882a593Smuzhiyun 		 req, subflow_req->local_key, subflow_req->token,
120*4882a593Smuzhiyun 		 subflow_req->idsn);
121*4882a593Smuzhiyun 
122*4882a593Smuzhiyun 	token = subflow_req->token;
123*4882a593Smuzhiyun 	bucket = token_bucket(token);
124*4882a593Smuzhiyun 	spin_lock_bh(&bucket->lock);
125*4882a593Smuzhiyun 	if (__token_bucket_busy(bucket, token)) {
126*4882a593Smuzhiyun 		spin_unlock_bh(&bucket->lock);
127*4882a593Smuzhiyun 		return -EBUSY;
128*4882a593Smuzhiyun 	}
129*4882a593Smuzhiyun 
130*4882a593Smuzhiyun 	hlist_nulls_add_head_rcu(&subflow_req->token_node, &bucket->req_chain);
131*4882a593Smuzhiyun 	bucket->chain_len++;
132*4882a593Smuzhiyun 	spin_unlock_bh(&bucket->lock);
133*4882a593Smuzhiyun 	return 0;
134*4882a593Smuzhiyun }
135*4882a593Smuzhiyun 
136*4882a593Smuzhiyun /**
137*4882a593Smuzhiyun  * mptcp_token_new_connect - create new key/idsn/token for subflow
138*4882a593Smuzhiyun  * @sk: the socket that will initiate a connection
139*4882a593Smuzhiyun  *
140*4882a593Smuzhiyun  * This function is called when a new outgoing mptcp connection is
141*4882a593Smuzhiyun  * initiated.
142*4882a593Smuzhiyun  *
143*4882a593Smuzhiyun  * It creates a unique token to identify the new mptcp connection,
144*4882a593Smuzhiyun  * a secret local key and the initial data sequence number (idsn).
145*4882a593Smuzhiyun  *
146*4882a593Smuzhiyun  * On success, the mptcp connection can be found again using
147*4882a593Smuzhiyun  * the computed token at a later time, this is needed to process
148*4882a593Smuzhiyun  * join requests.
149*4882a593Smuzhiyun  *
150*4882a593Smuzhiyun  * returns 0 on success.
151*4882a593Smuzhiyun  */
mptcp_token_new_connect(struct sock * sk)152*4882a593Smuzhiyun int mptcp_token_new_connect(struct sock *sk)
153*4882a593Smuzhiyun {
154*4882a593Smuzhiyun 	struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(sk);
155*4882a593Smuzhiyun 	struct mptcp_sock *msk = mptcp_sk(subflow->conn);
156*4882a593Smuzhiyun 	int retries = TOKEN_MAX_RETRIES;
157*4882a593Smuzhiyun 	struct token_bucket *bucket;
158*4882a593Smuzhiyun 
159*4882a593Smuzhiyun again:
160*4882a593Smuzhiyun 	mptcp_crypto_key_gen_sha(&subflow->local_key, &subflow->token,
161*4882a593Smuzhiyun 				 &subflow->idsn);
162*4882a593Smuzhiyun 
163*4882a593Smuzhiyun 	bucket = token_bucket(subflow->token);
164*4882a593Smuzhiyun 	spin_lock_bh(&bucket->lock);
165*4882a593Smuzhiyun 	if (__token_bucket_busy(bucket, subflow->token)) {
166*4882a593Smuzhiyun 		spin_unlock_bh(&bucket->lock);
167*4882a593Smuzhiyun 		if (!--retries)
168*4882a593Smuzhiyun 			return -EBUSY;
169*4882a593Smuzhiyun 		goto again;
170*4882a593Smuzhiyun 	}
171*4882a593Smuzhiyun 
172*4882a593Smuzhiyun 	pr_debug("ssk=%p, local_key=%llu, token=%u, idsn=%llu\n",
173*4882a593Smuzhiyun 		 sk, subflow->local_key, subflow->token, subflow->idsn);
174*4882a593Smuzhiyun 
175*4882a593Smuzhiyun 	WRITE_ONCE(msk->token, subflow->token);
176*4882a593Smuzhiyun 	__sk_nulls_add_node_rcu((struct sock *)msk, &bucket->msk_chain);
177*4882a593Smuzhiyun 	bucket->chain_len++;
178*4882a593Smuzhiyun 	spin_unlock_bh(&bucket->lock);
179*4882a593Smuzhiyun 	return 0;
180*4882a593Smuzhiyun }
181*4882a593Smuzhiyun 
182*4882a593Smuzhiyun /**
183*4882a593Smuzhiyun  * mptcp_token_accept - replace a req sk with full sock in token hash
184*4882a593Smuzhiyun  * @req: the request socket to be removed
185*4882a593Smuzhiyun  * @msk: the just cloned socket linked to the new connection
186*4882a593Smuzhiyun  *
187*4882a593Smuzhiyun  * Called when a SYN packet creates a new logical connection, i.e.
188*4882a593Smuzhiyun  * is not a join request.
189*4882a593Smuzhiyun  */
mptcp_token_accept(struct mptcp_subflow_request_sock * req,struct mptcp_sock * msk)190*4882a593Smuzhiyun void mptcp_token_accept(struct mptcp_subflow_request_sock *req,
191*4882a593Smuzhiyun 			struct mptcp_sock *msk)
192*4882a593Smuzhiyun {
193*4882a593Smuzhiyun 	struct mptcp_subflow_request_sock *pos;
194*4882a593Smuzhiyun 	struct token_bucket *bucket;
195*4882a593Smuzhiyun 
196*4882a593Smuzhiyun 	bucket = token_bucket(req->token);
197*4882a593Smuzhiyun 	spin_lock_bh(&bucket->lock);
198*4882a593Smuzhiyun 
199*4882a593Smuzhiyun 	/* pedantic lookup check for the moved token */
200*4882a593Smuzhiyun 	pos = __token_lookup_req(bucket, req->token);
201*4882a593Smuzhiyun 	if (!WARN_ON_ONCE(pos != req))
202*4882a593Smuzhiyun 		hlist_nulls_del_init_rcu(&req->token_node);
203*4882a593Smuzhiyun 	__sk_nulls_add_node_rcu((struct sock *)msk, &bucket->msk_chain);
204*4882a593Smuzhiyun 	spin_unlock_bh(&bucket->lock);
205*4882a593Smuzhiyun }
206*4882a593Smuzhiyun 
mptcp_token_exists(u32 token)207*4882a593Smuzhiyun bool mptcp_token_exists(u32 token)
208*4882a593Smuzhiyun {
209*4882a593Smuzhiyun 	struct hlist_nulls_node *pos;
210*4882a593Smuzhiyun 	struct token_bucket *bucket;
211*4882a593Smuzhiyun 	struct mptcp_sock *msk;
212*4882a593Smuzhiyun 	struct sock *sk;
213*4882a593Smuzhiyun 
214*4882a593Smuzhiyun 	rcu_read_lock();
215*4882a593Smuzhiyun 	bucket = token_bucket(token);
216*4882a593Smuzhiyun 
217*4882a593Smuzhiyun again:
218*4882a593Smuzhiyun 	sk_nulls_for_each_rcu(sk, pos, &bucket->msk_chain) {
219*4882a593Smuzhiyun 		msk = mptcp_sk(sk);
220*4882a593Smuzhiyun 		if (READ_ONCE(msk->token) == token)
221*4882a593Smuzhiyun 			goto found;
222*4882a593Smuzhiyun 	}
223*4882a593Smuzhiyun 	if (get_nulls_value(pos) != (token & token_mask))
224*4882a593Smuzhiyun 		goto again;
225*4882a593Smuzhiyun 
226*4882a593Smuzhiyun 	rcu_read_unlock();
227*4882a593Smuzhiyun 	return false;
228*4882a593Smuzhiyun found:
229*4882a593Smuzhiyun 	rcu_read_unlock();
230*4882a593Smuzhiyun 	return true;
231*4882a593Smuzhiyun }
232*4882a593Smuzhiyun 
233*4882a593Smuzhiyun /**
234*4882a593Smuzhiyun  * mptcp_token_get_sock - retrieve mptcp connection sock using its token
235*4882a593Smuzhiyun  * @net: restrict to this namespace
236*4882a593Smuzhiyun  * @token: token of the mptcp connection to retrieve
237*4882a593Smuzhiyun  *
238*4882a593Smuzhiyun  * This function returns the mptcp connection structure with the given token.
239*4882a593Smuzhiyun  * A reference count on the mptcp socket returned is taken.
240*4882a593Smuzhiyun  *
241*4882a593Smuzhiyun  * returns NULL if no connection with the given token value exists.
242*4882a593Smuzhiyun  */
mptcp_token_get_sock(struct net * net,u32 token)243*4882a593Smuzhiyun struct mptcp_sock *mptcp_token_get_sock(struct net *net, u32 token)
244*4882a593Smuzhiyun {
245*4882a593Smuzhiyun 	struct hlist_nulls_node *pos;
246*4882a593Smuzhiyun 	struct token_bucket *bucket;
247*4882a593Smuzhiyun 	struct mptcp_sock *msk;
248*4882a593Smuzhiyun 	struct sock *sk;
249*4882a593Smuzhiyun 
250*4882a593Smuzhiyun 	rcu_read_lock();
251*4882a593Smuzhiyun 	bucket = token_bucket(token);
252*4882a593Smuzhiyun 
253*4882a593Smuzhiyun again:
254*4882a593Smuzhiyun 	sk_nulls_for_each_rcu(sk, pos, &bucket->msk_chain) {
255*4882a593Smuzhiyun 		msk = mptcp_sk(sk);
256*4882a593Smuzhiyun 		if (READ_ONCE(msk->token) != token ||
257*4882a593Smuzhiyun 		    !net_eq(sock_net(sk), net))
258*4882a593Smuzhiyun 			continue;
259*4882a593Smuzhiyun 
260*4882a593Smuzhiyun 		if (!refcount_inc_not_zero(&sk->sk_refcnt))
261*4882a593Smuzhiyun 			goto not_found;
262*4882a593Smuzhiyun 
263*4882a593Smuzhiyun 		if (READ_ONCE(msk->token) != token ||
264*4882a593Smuzhiyun 		    !net_eq(sock_net(sk), net)) {
265*4882a593Smuzhiyun 			sock_put(sk);
266*4882a593Smuzhiyun 			goto again;
267*4882a593Smuzhiyun 		}
268*4882a593Smuzhiyun 		goto found;
269*4882a593Smuzhiyun 	}
270*4882a593Smuzhiyun 	if (get_nulls_value(pos) != (token & token_mask))
271*4882a593Smuzhiyun 		goto again;
272*4882a593Smuzhiyun 
273*4882a593Smuzhiyun not_found:
274*4882a593Smuzhiyun 	msk = NULL;
275*4882a593Smuzhiyun 
276*4882a593Smuzhiyun found:
277*4882a593Smuzhiyun 	rcu_read_unlock();
278*4882a593Smuzhiyun 	return msk;
279*4882a593Smuzhiyun }
280*4882a593Smuzhiyun EXPORT_SYMBOL_GPL(mptcp_token_get_sock);
281*4882a593Smuzhiyun 
282*4882a593Smuzhiyun /**
283*4882a593Smuzhiyun  * mptcp_token_iter_next - iterate over the token container from given pos
284*4882a593Smuzhiyun  * @net: namespace to be iterated
285*4882a593Smuzhiyun  * @s_slot: start slot number
286*4882a593Smuzhiyun  * @s_num: start number inside the given lock
287*4882a593Smuzhiyun  *
288*4882a593Smuzhiyun  * This function returns the first mptcp connection structure found inside the
289*4882a593Smuzhiyun  * token container starting from the specified position, or NULL.
290*4882a593Smuzhiyun  *
291*4882a593Smuzhiyun  * On successful iteration, the iterator is move to the next position and the
292*4882a593Smuzhiyun  * the acquires a reference to the returned socket.
293*4882a593Smuzhiyun  */
mptcp_token_iter_next(const struct net * net,long * s_slot,long * s_num)294*4882a593Smuzhiyun struct mptcp_sock *mptcp_token_iter_next(const struct net *net, long *s_slot,
295*4882a593Smuzhiyun 					 long *s_num)
296*4882a593Smuzhiyun {
297*4882a593Smuzhiyun 	struct mptcp_sock *ret = NULL;
298*4882a593Smuzhiyun 	struct hlist_nulls_node *pos;
299*4882a593Smuzhiyun 	int slot, num = 0;
300*4882a593Smuzhiyun 
301*4882a593Smuzhiyun 	for (slot = *s_slot; slot <= token_mask; *s_num = 0, slot++) {
302*4882a593Smuzhiyun 		struct token_bucket *bucket = &token_hash[slot];
303*4882a593Smuzhiyun 		struct sock *sk;
304*4882a593Smuzhiyun 
305*4882a593Smuzhiyun 		num = 0;
306*4882a593Smuzhiyun 
307*4882a593Smuzhiyun 		if (hlist_nulls_empty(&bucket->msk_chain))
308*4882a593Smuzhiyun 			continue;
309*4882a593Smuzhiyun 
310*4882a593Smuzhiyun 		rcu_read_lock();
311*4882a593Smuzhiyun 		sk_nulls_for_each_rcu(sk, pos, &bucket->msk_chain) {
312*4882a593Smuzhiyun 			++num;
313*4882a593Smuzhiyun 			if (!net_eq(sock_net(sk), net))
314*4882a593Smuzhiyun 				continue;
315*4882a593Smuzhiyun 
316*4882a593Smuzhiyun 			if (num <= *s_num)
317*4882a593Smuzhiyun 				continue;
318*4882a593Smuzhiyun 
319*4882a593Smuzhiyun 			if (!refcount_inc_not_zero(&sk->sk_refcnt))
320*4882a593Smuzhiyun 				continue;
321*4882a593Smuzhiyun 
322*4882a593Smuzhiyun 			if (!net_eq(sock_net(sk), net)) {
323*4882a593Smuzhiyun 				sock_put(sk);
324*4882a593Smuzhiyun 				continue;
325*4882a593Smuzhiyun 			}
326*4882a593Smuzhiyun 
327*4882a593Smuzhiyun 			ret = mptcp_sk(sk);
328*4882a593Smuzhiyun 			rcu_read_unlock();
329*4882a593Smuzhiyun 			goto out;
330*4882a593Smuzhiyun 		}
331*4882a593Smuzhiyun 		rcu_read_unlock();
332*4882a593Smuzhiyun 	}
333*4882a593Smuzhiyun 
334*4882a593Smuzhiyun out:
335*4882a593Smuzhiyun 	*s_slot = slot;
336*4882a593Smuzhiyun 	*s_num = num;
337*4882a593Smuzhiyun 	return ret;
338*4882a593Smuzhiyun }
339*4882a593Smuzhiyun EXPORT_SYMBOL_GPL(mptcp_token_iter_next);
340*4882a593Smuzhiyun 
341*4882a593Smuzhiyun /**
342*4882a593Smuzhiyun  * mptcp_token_destroy_request - remove mptcp connection/token
343*4882a593Smuzhiyun  * @req: mptcp request socket dropping the token
344*4882a593Smuzhiyun  *
345*4882a593Smuzhiyun  * Remove the token associated to @req.
346*4882a593Smuzhiyun  */
mptcp_token_destroy_request(struct request_sock * req)347*4882a593Smuzhiyun void mptcp_token_destroy_request(struct request_sock *req)
348*4882a593Smuzhiyun {
349*4882a593Smuzhiyun 	struct mptcp_subflow_request_sock *subflow_req = mptcp_subflow_rsk(req);
350*4882a593Smuzhiyun 	struct mptcp_subflow_request_sock *pos;
351*4882a593Smuzhiyun 	struct token_bucket *bucket;
352*4882a593Smuzhiyun 
353*4882a593Smuzhiyun 	if (hlist_nulls_unhashed(&subflow_req->token_node))
354*4882a593Smuzhiyun 		return;
355*4882a593Smuzhiyun 
356*4882a593Smuzhiyun 	bucket = token_bucket(subflow_req->token);
357*4882a593Smuzhiyun 	spin_lock_bh(&bucket->lock);
358*4882a593Smuzhiyun 	pos = __token_lookup_req(bucket, subflow_req->token);
359*4882a593Smuzhiyun 	if (!WARN_ON_ONCE(pos != subflow_req)) {
360*4882a593Smuzhiyun 		hlist_nulls_del_init_rcu(&pos->token_node);
361*4882a593Smuzhiyun 		bucket->chain_len--;
362*4882a593Smuzhiyun 	}
363*4882a593Smuzhiyun 	spin_unlock_bh(&bucket->lock);
364*4882a593Smuzhiyun }
365*4882a593Smuzhiyun 
366*4882a593Smuzhiyun /**
367*4882a593Smuzhiyun  * mptcp_token_destroy - remove mptcp connection/token
368*4882a593Smuzhiyun  * @msk: mptcp connection dropping the token
369*4882a593Smuzhiyun  *
370*4882a593Smuzhiyun  * Remove the token associated to @msk
371*4882a593Smuzhiyun  */
mptcp_token_destroy(struct mptcp_sock * msk)372*4882a593Smuzhiyun void mptcp_token_destroy(struct mptcp_sock *msk)
373*4882a593Smuzhiyun {
374*4882a593Smuzhiyun 	struct token_bucket *bucket;
375*4882a593Smuzhiyun 	struct mptcp_sock *pos;
376*4882a593Smuzhiyun 
377*4882a593Smuzhiyun 	if (sk_unhashed((struct sock *)msk))
378*4882a593Smuzhiyun 		return;
379*4882a593Smuzhiyun 
380*4882a593Smuzhiyun 	bucket = token_bucket(msk->token);
381*4882a593Smuzhiyun 	spin_lock_bh(&bucket->lock);
382*4882a593Smuzhiyun 	pos = __token_lookup_msk(bucket, msk->token);
383*4882a593Smuzhiyun 	if (!WARN_ON_ONCE(pos != msk)) {
384*4882a593Smuzhiyun 		__sk_nulls_del_node_init_rcu((struct sock *)pos);
385*4882a593Smuzhiyun 		bucket->chain_len--;
386*4882a593Smuzhiyun 	}
387*4882a593Smuzhiyun 	spin_unlock_bh(&bucket->lock);
388*4882a593Smuzhiyun }
389*4882a593Smuzhiyun 
mptcp_token_init(void)390*4882a593Smuzhiyun void __init mptcp_token_init(void)
391*4882a593Smuzhiyun {
392*4882a593Smuzhiyun 	int i;
393*4882a593Smuzhiyun 
394*4882a593Smuzhiyun 	token_hash = alloc_large_system_hash("MPTCP token",
395*4882a593Smuzhiyun 					     sizeof(struct token_bucket),
396*4882a593Smuzhiyun 					     0,
397*4882a593Smuzhiyun 					     20,/* one slot per 1MB of memory */
398*4882a593Smuzhiyun 					     HASH_ZERO,
399*4882a593Smuzhiyun 					     NULL,
400*4882a593Smuzhiyun 					     &token_mask,
401*4882a593Smuzhiyun 					     0,
402*4882a593Smuzhiyun 					     64 * 1024);
403*4882a593Smuzhiyun 	for (i = 0; i < token_mask + 1; ++i) {
404*4882a593Smuzhiyun 		INIT_HLIST_NULLS_HEAD(&token_hash[i].req_chain, i);
405*4882a593Smuzhiyun 		INIT_HLIST_NULLS_HEAD(&token_hash[i].msk_chain, i);
406*4882a593Smuzhiyun 		spin_lock_init(&token_hash[i].lock);
407*4882a593Smuzhiyun 	}
408*4882a593Smuzhiyun }
409*4882a593Smuzhiyun 
410*4882a593Smuzhiyun #if IS_MODULE(CONFIG_MPTCP_KUNIT_TESTS)
411*4882a593Smuzhiyun EXPORT_SYMBOL_GPL(mptcp_token_new_request);
412*4882a593Smuzhiyun EXPORT_SYMBOL_GPL(mptcp_token_new_connect);
413*4882a593Smuzhiyun EXPORT_SYMBOL_GPL(mptcp_token_accept);
414*4882a593Smuzhiyun EXPORT_SYMBOL_GPL(mptcp_token_destroy_request);
415*4882a593Smuzhiyun EXPORT_SYMBOL_GPL(mptcp_token_destroy);
416*4882a593Smuzhiyun #endif
417