// SPDX-License-Identifier: GPL-2.0
/*
 * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
 */

#include "ratelimiter.h"
#include <linux/siphash.h>
#include <linux/mm.h>
#include <linux/slab.h>
#include <net/ip.h>

static struct kmem_cache *entry_cache;
static hsiphash_key_t key;
static spinlock_t table_lock = __SPIN_LOCK_UNLOCKED("ratelimiter_table_lock");
static DEFINE_MUTEX(init_lock);
static u64 init_refcnt; /* Protected by init_lock, hence not atomic. */
static atomic_t total_entries = ATOMIC_INIT(0);
static unsigned int max_entries, table_size;
static void wg_ratelimiter_gc_entries(struct work_struct *);
static DECLARE_DEFERRABLE_WORK(gc_work, wg_ratelimiter_gc_entries);
static struct hlist_head *table_v4;
#if IS_ENABLED(CONFIG_IPV6)
static struct hlist_head *table_v6;
#endif

struct ratelimiter_entry {
	u64 last_time_ns, tokens, ip;
	void *net;
	spinlock_t lock;
	struct hlist_node hash;
	struct rcu_head rcu;
};

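/* Token-bucket parameters: a steady rate of 20 packets per second with a burst
 * of 5 packets. Tokens are counted in nanoseconds, so PACKET_COST works out to
 * NSEC_PER_SEC / 20 = 50,000,000 ns per packet, and TOKEN_MAX caps the bucket
 * at 250,000,000 ns, i.e. 5 packets' worth of credit.
 */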
enum {
	PACKETS_PER_SECOND = 20,
	PACKETS_BURSTABLE = 5,
	PACKET_COST = NSEC_PER_SEC / PACKETS_PER_SECOND,
	TOKEN_MAX = PACKET_COST * PACKETS_BURSTABLE
};

static void entry_free(struct rcu_head *rcu)
{
	kmem_cache_free(entry_cache,
			container_of(rcu, struct ratelimiter_entry, rcu));
	atomic_dec(&total_entries);
}

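/* Called with table_lock held; the hlist removal is immediate, while freeing
 * is deferred via RCU so that concurrent lockless readers in
 * wg_ratelimiter_allow() can finish with the entry first.
 */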
static void entry_uninit(struct ratelimiter_entry *entry)
{
	hlist_del_rcu(&entry->hash);
	call_rcu(&entry->rcu, entry_free);
}

/* Calling this function with a NULL work uninits all entries. */
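/* Otherwise it runs as deferrable work roughly once a second, evicting entries
 * whose source has been idle for longer than NSEC_PER_SEC.
 */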
static void wg_ratelimiter_gc_entries(struct work_struct *work)
{
	const u64 now = ktime_get_coarse_boottime_ns();
	struct ratelimiter_entry *entry;
	struct hlist_node *temp;
	unsigned int i;

	for (i = 0; i < table_size; ++i) {
		spin_lock(&table_lock);
		hlist_for_each_entry_safe(entry, temp, &table_v4[i], hash) {
			if (unlikely(!work) ||
			    now - entry->last_time_ns > NSEC_PER_SEC)
				entry_uninit(entry);
		}
#if IS_ENABLED(CONFIG_IPV6)
		hlist_for_each_entry_safe(entry, temp, &table_v6[i], hash) {
			if (unlikely(!work) ||
			    now - entry->last_time_ns > NSEC_PER_SEC)
				entry_uninit(entry);
		}
#endif
		spin_unlock(&table_lock);
		if (likely(work))
			cond_resched();
	}
	if (likely(work))
		queue_delayed_work(system_power_efficient_wq, &gc_work, HZ);
}

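/* Returns true if a packet from this source should be allowed through: each
 * source IPv4 address (or IPv6 /64), per network namespace, gets a
 * token-bucket budget of PACKETS_PER_SECOND with a burst of PACKETS_BURSTABLE.
 * Packets that are neither IPv4 nor IPv6 are always refused.
 */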
bool wg_ratelimiter_allow(struct sk_buff *skb, struct net *net)
{
	/* We only take the bottom half of the net pointer, so that we can hash
	 * 3 words in the end. This way, siphash's len param fits into the final
	 * u32, and we don't incur an extra round.
	 */
	const u32 net_word = (unsigned long)net;
	struct ratelimiter_entry *entry;
	struct hlist_head *bucket;
	u64 ip;

	if (skb->protocol == htons(ETH_P_IP)) {
		ip = (u64 __force)ip_hdr(skb)->saddr;
		bucket = &table_v4[hsiphash_2u32(net_word, ip, &key) &
				   (table_size - 1)];
	}
#if IS_ENABLED(CONFIG_IPV6)
	else if (skb->protocol == htons(ETH_P_IPV6)) {
		/* Only use 64 bits, so as to ratelimit the whole /64. */
		memcpy(&ip, &ipv6_hdr(skb)->saddr, sizeof(ip));
		bucket = &table_v6[hsiphash_3u32(net_word, ip >> 32, ip, &key) &
				   (table_size - 1)];
	}
#endif
	else
		return false;
	rcu_read_lock();
	hlist_for_each_entry_rcu(entry, bucket, hash) {
		if (entry->net == net && entry->ip == ip) {
			u64 now, tokens;
			bool ret;
			/* Quasi-inspired by nft_limit.c, but this is actually a
			 * slightly different algorithm. Namely, we incorporate
			 * the burst as part of the maximum tokens, rather than
			 * as part of the rate.
			 */
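			/* Tokens are nanoseconds of accumulated credit: add the
			 * time elapsed since the entry was last updated, cap at
			 * TOKEN_MAX, and spend PACKET_COST if the packet is
			 * allowed.
			 */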
			spin_lock(&entry->lock);
			now = ktime_get_coarse_boottime_ns();
			tokens = min_t(u64, TOKEN_MAX,
				       entry->tokens + now -
					       entry->last_time_ns);
			entry->last_time_ns = now;
			ret = tokens >= PACKET_COST;
			entry->tokens = ret ? tokens - PACKET_COST : tokens;
			spin_unlock(&entry->lock);
			rcu_read_unlock();
			return ret;
		}
	}
	rcu_read_unlock();

	if (atomic_inc_return(&total_entries) > max_entries)
		goto err_oom;

	entry = kmem_cache_alloc(entry_cache, GFP_KERNEL);
	if (unlikely(!entry))
		goto err_oom;

	entry->net = net;
	entry->ip = ip;
	INIT_HLIST_NODE(&entry->hash);
	spin_lock_init(&entry->lock);
	entry->last_time_ns = ktime_get_coarse_boottime_ns();
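	/* A new source starts with a full bucket, minus the cost of the packet
	 * being allowed right now.
	 */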
	entry->tokens = TOKEN_MAX - PACKET_COST;
	spin_lock(&table_lock);
	hlist_add_head_rcu(&entry->hash, bucket);
	spin_unlock(&table_lock);
	return true;

err_oom:
	atomic_dec(&total_entries);
	return false;
}

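/* Init and uninit are reference-counted so that multiple users share a single
 * set of tables; only the first wg_ratelimiter_init() allocates, and only the
 * last wg_ratelimiter_uninit() tears everything down.
 */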
int wg_ratelimiter_init(void)
{
	mutex_lock(&init_lock);
	if (++init_refcnt != 1)
		goto out;

	entry_cache = KMEM_CACHE(ratelimiter_entry, 0);
	if (!entry_cache)
		goto err;

	/* xt_hashlimit.c uses a slightly different algorithm for ratelimiting,
	 * but what it shares in common is that it uses a massive hashtable. So,
	 * we borrow their wisdom about good table sizes on different systems
	 * dependent on RAM. This calculation here comes from there.
	 */
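	/* Systems with more than 1 GiB of RAM get 8192 buckets. As a rough
	 * worked example (assuming an 8-byte struct hlist_head), a 256 MiB
	 * system gets roundup_pow_of_two(2^28 / 2^14 / 8) = 2048 buckets, with
	 * a floor of 16.
	 */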
	table_size = (totalram_pages() > (1U << 30) / PAGE_SIZE) ? 8192 :
		max_t(unsigned long, 16, roundup_pow_of_two(
			(totalram_pages() << PAGE_SHIFT) /
			(1U << 14) / sizeof(struct hlist_head)));
	max_entries = table_size * 8;

	table_v4 = kvcalloc(table_size, sizeof(*table_v4), GFP_KERNEL);
	if (unlikely(!table_v4))
		goto err_kmemcache;

#if IS_ENABLED(CONFIG_IPV6)
	table_v6 = kvcalloc(table_size, sizeof(*table_v6), GFP_KERNEL);
	if (unlikely(!table_v6)) {
		kvfree(table_v4);
		goto err_kmemcache;
	}
#endif

	queue_delayed_work(system_power_efficient_wq, &gc_work, HZ);
	get_random_bytes(&key, sizeof(key));
out:
	mutex_unlock(&init_lock);
	return 0;

err_kmemcache:
	kmem_cache_destroy(entry_cache);
err:
	--init_refcnt;
	mutex_unlock(&init_lock);
	return -ENOMEM;
}

void wg_ratelimiter_uninit(void)
{
	mutex_lock(&init_lock);
	if (!init_refcnt || --init_refcnt)
		goto out;

	cancel_delayed_work_sync(&gc_work);
	wg_ratelimiter_gc_entries(NULL);
	rcu_barrier();
	kvfree(table_v4);
#if IS_ENABLED(CONFIG_IPV6)
	kvfree(table_v6);
#endif
	kmem_cache_destroy(entry_cache);
out:
	mutex_unlock(&init_lock);
}

#include "selftest/ratelimiter.c"