1*4882a593Smuzhiyun // SPDX-License-Identifier: GPL-2.0
2*4882a593Smuzhiyun /*
3*4882a593Smuzhiyun * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
4*4882a593Smuzhiyun *
5*4882a593Smuzhiyun * This contains some basic static unit tests for the allowedips data structure.
6*4882a593Smuzhiyun * It also has two additional modes that are disabled and meant to be used by
7*4882a593Smuzhiyun * folks directly playing with this file. If you define the macro
8*4882a593Smuzhiyun * DEBUG_PRINT_TRIE_GRAPHVIZ to be 1, then every time there's a full tree in
9*4882a593Smuzhiyun * memory, it will be printed out as KERN_DEBUG in a format that can be passed
10*4882a593Smuzhiyun * to graphviz (the dot command) to visualize it. If you define the macro
11*4882a593Smuzhiyun * DEBUG_RANDOM_TRIE to be 1, then there will be an extremely costly set of
12*4882a593Smuzhiyun * randomized tests done against a trivial implementation, which may take
13*4882a593Smuzhiyun * upwards of a half-hour to complete. There's no set of users who should be
14*4882a593Smuzhiyun * enabling these, and the only developers that should go anywhere near these
 * knobs are the ones who are reading this comment.
16*4882a593Smuzhiyun */
17*4882a593Smuzhiyun
18*4882a593Smuzhiyun #ifdef DEBUG
19*4882a593Smuzhiyun
20*4882a593Smuzhiyun #include <linux/siphash.h>
21*4882a593Smuzhiyun
print_node(struct allowedips_node * node,u8 bits)22*4882a593Smuzhiyun static __init void print_node(struct allowedips_node *node, u8 bits)
23*4882a593Smuzhiyun {
24*4882a593Smuzhiyun char *fmt_connection = KERN_DEBUG "\t\"%p/%d\" -> \"%p/%d\";\n";
25*4882a593Smuzhiyun char *fmt_declaration = KERN_DEBUG "\t\"%p/%d\"[style=%s, color=\"#%06x\"];\n";
26*4882a593Smuzhiyun u8 ip1[16], ip2[16], cidr1, cidr2;
27*4882a593Smuzhiyun char *style = "dotted";
28*4882a593Smuzhiyun u32 color = 0;
29*4882a593Smuzhiyun
30*4882a593Smuzhiyun if (node == NULL)
31*4882a593Smuzhiyun return;
32*4882a593Smuzhiyun if (bits == 32) {
33*4882a593Smuzhiyun fmt_connection = KERN_DEBUG "\t\"%pI4/%d\" -> \"%pI4/%d\";\n";
34*4882a593Smuzhiyun fmt_declaration = KERN_DEBUG "\t\"%pI4/%d\"[style=%s, color=\"#%06x\"];\n";
35*4882a593Smuzhiyun } else if (bits == 128) {
36*4882a593Smuzhiyun fmt_connection = KERN_DEBUG "\t\"%pI6/%d\" -> \"%pI6/%d\";\n";
37*4882a593Smuzhiyun fmt_declaration = KERN_DEBUG "\t\"%pI6/%d\"[style=%s, color=\"#%06x\"];\n";
38*4882a593Smuzhiyun }
39*4882a593Smuzhiyun if (node->peer) {
40*4882a593Smuzhiyun hsiphash_key_t key = { { 0 } };
41*4882a593Smuzhiyun
42*4882a593Smuzhiyun memcpy(&key, &node->peer, sizeof(node->peer));
43*4882a593Smuzhiyun color = hsiphash_1u32(0xdeadbeef, &key) % 200 << 16 |
44*4882a593Smuzhiyun hsiphash_1u32(0xbabecafe, &key) % 200 << 8 |
45*4882a593Smuzhiyun hsiphash_1u32(0xabad1dea, &key) % 200;
46*4882a593Smuzhiyun style = "bold";
47*4882a593Smuzhiyun }
48*4882a593Smuzhiyun wg_allowedips_read_node(node, ip1, &cidr1);
49*4882a593Smuzhiyun printk(fmt_declaration, ip1, cidr1, style, color);
50*4882a593Smuzhiyun if (node->bit[0]) {
51*4882a593Smuzhiyun wg_allowedips_read_node(rcu_dereference_raw(node->bit[0]), ip2, &cidr2);
52*4882a593Smuzhiyun printk(fmt_connection, ip1, cidr1, ip2, cidr2);
53*4882a593Smuzhiyun }
54*4882a593Smuzhiyun if (node->bit[1]) {
55*4882a593Smuzhiyun wg_allowedips_read_node(rcu_dereference_raw(node->bit[1]), ip2, &cidr2);
56*4882a593Smuzhiyun printk(fmt_connection, ip1, cidr1, ip2, cidr2);
57*4882a593Smuzhiyun }
58*4882a593Smuzhiyun if (node->bit[0])
59*4882a593Smuzhiyun print_node(rcu_dereference_raw(node->bit[0]), bits);
60*4882a593Smuzhiyun if (node->bit[1])
61*4882a593Smuzhiyun print_node(rcu_dereference_raw(node->bit[1]), bits);
62*4882a593Smuzhiyun }
63*4882a593Smuzhiyun
print_tree(struct allowedips_node __rcu * top,u8 bits)64*4882a593Smuzhiyun static __init void print_tree(struct allowedips_node __rcu *top, u8 bits)
65*4882a593Smuzhiyun {
66*4882a593Smuzhiyun printk(KERN_DEBUG "digraph trie {\n");
67*4882a593Smuzhiyun print_node(rcu_dereference_raw(top), bits);
68*4882a593Smuzhiyun printk(KERN_DEBUG "}\n");
69*4882a593Smuzhiyun }
70*4882a593Smuzhiyun
/* Sizing knobs for randomized_test(). */
enum {
	/* Number of peers that random routes are assigned to. */
	NUM_PEERS = 2000,
	/* Random base routes inserted per address family. */
	NUM_RAND_ROUTES = 400,
	/* Prefix-sharing routes derived from each base route. */
	NUM_MUTATED_ROUTES = 100,
	/* Query iterations (one v4 + one v6 lookup each) per removal round. */
	NUM_QUERIES = 30 * NUM_RAND_ROUTES * NUM_MUTATED_ROUTES
};
77*4882a593Smuzhiyun
78*4882a593Smuzhiyun struct horrible_allowedips {
79*4882a593Smuzhiyun struct hlist_head head;
80*4882a593Smuzhiyun };
81*4882a593Smuzhiyun
82*4882a593Smuzhiyun struct horrible_allowedips_node {
83*4882a593Smuzhiyun struct hlist_node table;
84*4882a593Smuzhiyun union nf_inet_addr ip;
85*4882a593Smuzhiyun union nf_inet_addr mask;
86*4882a593Smuzhiyun u8 ip_version;
87*4882a593Smuzhiyun void *value;
88*4882a593Smuzhiyun };
89*4882a593Smuzhiyun
horrible_allowedips_init(struct horrible_allowedips * table)90*4882a593Smuzhiyun static __init void horrible_allowedips_init(struct horrible_allowedips *table)
91*4882a593Smuzhiyun {
92*4882a593Smuzhiyun INIT_HLIST_HEAD(&table->head);
93*4882a593Smuzhiyun }
94*4882a593Smuzhiyun
horrible_allowedips_free(struct horrible_allowedips * table)95*4882a593Smuzhiyun static __init void horrible_allowedips_free(struct horrible_allowedips *table)
96*4882a593Smuzhiyun {
97*4882a593Smuzhiyun struct horrible_allowedips_node *node;
98*4882a593Smuzhiyun struct hlist_node *h;
99*4882a593Smuzhiyun
100*4882a593Smuzhiyun hlist_for_each_entry_safe(node, h, &table->head, table) {
101*4882a593Smuzhiyun hlist_del(&node->table);
102*4882a593Smuzhiyun kfree(node);
103*4882a593Smuzhiyun }
104*4882a593Smuzhiyun }
105*4882a593Smuzhiyun
horrible_cidr_to_mask(u8 cidr)106*4882a593Smuzhiyun static __init inline union nf_inet_addr horrible_cidr_to_mask(u8 cidr)
107*4882a593Smuzhiyun {
108*4882a593Smuzhiyun union nf_inet_addr mask;
109*4882a593Smuzhiyun
110*4882a593Smuzhiyun memset(&mask, 0, sizeof(mask));
111*4882a593Smuzhiyun memset(&mask.all, 0xff, cidr / 8);
112*4882a593Smuzhiyun if (cidr % 32)
113*4882a593Smuzhiyun mask.all[cidr / 32] = (__force u32)htonl(
114*4882a593Smuzhiyun (0xFFFFFFFFUL << (32 - (cidr % 32))) & 0xFFFFFFFFUL);
115*4882a593Smuzhiyun return mask;
116*4882a593Smuzhiyun }
117*4882a593Smuzhiyun
horrible_mask_to_cidr(union nf_inet_addr subnet)118*4882a593Smuzhiyun static __init inline u8 horrible_mask_to_cidr(union nf_inet_addr subnet)
119*4882a593Smuzhiyun {
120*4882a593Smuzhiyun return hweight32(subnet.all[0]) + hweight32(subnet.all[1]) +
121*4882a593Smuzhiyun hweight32(subnet.all[2]) + hweight32(subnet.all[3]);
122*4882a593Smuzhiyun }
123*4882a593Smuzhiyun
124*4882a593Smuzhiyun static __init inline void
horrible_mask_self(struct horrible_allowedips_node * node)125*4882a593Smuzhiyun horrible_mask_self(struct horrible_allowedips_node *node)
126*4882a593Smuzhiyun {
127*4882a593Smuzhiyun if (node->ip_version == 4) {
128*4882a593Smuzhiyun node->ip.ip &= node->mask.ip;
129*4882a593Smuzhiyun } else if (node->ip_version == 6) {
130*4882a593Smuzhiyun node->ip.ip6[0] &= node->mask.ip6[0];
131*4882a593Smuzhiyun node->ip.ip6[1] &= node->mask.ip6[1];
132*4882a593Smuzhiyun node->ip.ip6[2] &= node->mask.ip6[2];
133*4882a593Smuzhiyun node->ip.ip6[3] &= node->mask.ip6[3];
134*4882a593Smuzhiyun }
135*4882a593Smuzhiyun }
136*4882a593Smuzhiyun
137*4882a593Smuzhiyun static __init inline bool
horrible_match_v4(const struct horrible_allowedips_node * node,struct in_addr * ip)138*4882a593Smuzhiyun horrible_match_v4(const struct horrible_allowedips_node *node, struct in_addr *ip)
139*4882a593Smuzhiyun {
140*4882a593Smuzhiyun return (ip->s_addr & node->mask.ip) == node->ip.ip;
141*4882a593Smuzhiyun }
142*4882a593Smuzhiyun
143*4882a593Smuzhiyun static __init inline bool
horrible_match_v6(const struct horrible_allowedips_node * node,struct in6_addr * ip)144*4882a593Smuzhiyun horrible_match_v6(const struct horrible_allowedips_node *node, struct in6_addr *ip)
145*4882a593Smuzhiyun {
146*4882a593Smuzhiyun return (ip->in6_u.u6_addr32[0] & node->mask.ip6[0]) == node->ip.ip6[0] &&
147*4882a593Smuzhiyun (ip->in6_u.u6_addr32[1] & node->mask.ip6[1]) == node->ip.ip6[1] &&
148*4882a593Smuzhiyun (ip->in6_u.u6_addr32[2] & node->mask.ip6[2]) == node->ip.ip6[2] &&
149*4882a593Smuzhiyun (ip->in6_u.u6_addr32[3] & node->mask.ip6[3]) == node->ip.ip6[3];
150*4882a593Smuzhiyun }
151*4882a593Smuzhiyun
152*4882a593Smuzhiyun static __init void
horrible_insert_ordered(struct horrible_allowedips * table,struct horrible_allowedips_node * node)153*4882a593Smuzhiyun horrible_insert_ordered(struct horrible_allowedips *table, struct horrible_allowedips_node *node)
154*4882a593Smuzhiyun {
155*4882a593Smuzhiyun struct horrible_allowedips_node *other = NULL, *where = NULL;
156*4882a593Smuzhiyun u8 my_cidr = horrible_mask_to_cidr(node->mask);
157*4882a593Smuzhiyun
158*4882a593Smuzhiyun hlist_for_each_entry(other, &table->head, table) {
159*4882a593Smuzhiyun if (other->ip_version == node->ip_version &&
160*4882a593Smuzhiyun !memcmp(&other->mask, &node->mask, sizeof(union nf_inet_addr)) &&
161*4882a593Smuzhiyun !memcmp(&other->ip, &node->ip, sizeof(union nf_inet_addr))) {
162*4882a593Smuzhiyun other->value = node->value;
163*4882a593Smuzhiyun kfree(node);
164*4882a593Smuzhiyun return;
165*4882a593Smuzhiyun }
166*4882a593Smuzhiyun }
167*4882a593Smuzhiyun hlist_for_each_entry(other, &table->head, table) {
168*4882a593Smuzhiyun where = other;
169*4882a593Smuzhiyun if (horrible_mask_to_cidr(other->mask) <= my_cidr)
170*4882a593Smuzhiyun break;
171*4882a593Smuzhiyun }
172*4882a593Smuzhiyun if (!other && !where)
173*4882a593Smuzhiyun hlist_add_head(&node->table, &table->head);
174*4882a593Smuzhiyun else if (!other)
175*4882a593Smuzhiyun hlist_add_behind(&node->table, &where->table);
176*4882a593Smuzhiyun else
177*4882a593Smuzhiyun hlist_add_before(&node->table, &where->table);
178*4882a593Smuzhiyun }
179*4882a593Smuzhiyun
180*4882a593Smuzhiyun static __init int
horrible_allowedips_insert_v4(struct horrible_allowedips * table,struct in_addr * ip,u8 cidr,void * value)181*4882a593Smuzhiyun horrible_allowedips_insert_v4(struct horrible_allowedips *table,
182*4882a593Smuzhiyun struct in_addr *ip, u8 cidr, void *value)
183*4882a593Smuzhiyun {
184*4882a593Smuzhiyun struct horrible_allowedips_node *node = kzalloc(sizeof(*node), GFP_KERNEL);
185*4882a593Smuzhiyun
186*4882a593Smuzhiyun if (unlikely(!node))
187*4882a593Smuzhiyun return -ENOMEM;
188*4882a593Smuzhiyun node->ip.in = *ip;
189*4882a593Smuzhiyun node->mask = horrible_cidr_to_mask(cidr);
190*4882a593Smuzhiyun node->ip_version = 4;
191*4882a593Smuzhiyun node->value = value;
192*4882a593Smuzhiyun horrible_mask_self(node);
193*4882a593Smuzhiyun horrible_insert_ordered(table, node);
194*4882a593Smuzhiyun return 0;
195*4882a593Smuzhiyun }
196*4882a593Smuzhiyun
197*4882a593Smuzhiyun static __init int
horrible_allowedips_insert_v6(struct horrible_allowedips * table,struct in6_addr * ip,u8 cidr,void * value)198*4882a593Smuzhiyun horrible_allowedips_insert_v6(struct horrible_allowedips *table,
199*4882a593Smuzhiyun struct in6_addr *ip, u8 cidr, void *value)
200*4882a593Smuzhiyun {
201*4882a593Smuzhiyun struct horrible_allowedips_node *node = kzalloc(sizeof(*node), GFP_KERNEL);
202*4882a593Smuzhiyun
203*4882a593Smuzhiyun if (unlikely(!node))
204*4882a593Smuzhiyun return -ENOMEM;
205*4882a593Smuzhiyun node->ip.in6 = *ip;
206*4882a593Smuzhiyun node->mask = horrible_cidr_to_mask(cidr);
207*4882a593Smuzhiyun node->ip_version = 6;
208*4882a593Smuzhiyun node->value = value;
209*4882a593Smuzhiyun horrible_mask_self(node);
210*4882a593Smuzhiyun horrible_insert_ordered(table, node);
211*4882a593Smuzhiyun return 0;
212*4882a593Smuzhiyun }
213*4882a593Smuzhiyun
214*4882a593Smuzhiyun static __init void *
horrible_allowedips_lookup_v4(struct horrible_allowedips * table,struct in_addr * ip)215*4882a593Smuzhiyun horrible_allowedips_lookup_v4(struct horrible_allowedips *table, struct in_addr *ip)
216*4882a593Smuzhiyun {
217*4882a593Smuzhiyun struct horrible_allowedips_node *node;
218*4882a593Smuzhiyun
219*4882a593Smuzhiyun hlist_for_each_entry(node, &table->head, table) {
220*4882a593Smuzhiyun if (node->ip_version == 4 && horrible_match_v4(node, ip))
221*4882a593Smuzhiyun return node->value;
222*4882a593Smuzhiyun }
223*4882a593Smuzhiyun return NULL;
224*4882a593Smuzhiyun }
225*4882a593Smuzhiyun
226*4882a593Smuzhiyun static __init void *
horrible_allowedips_lookup_v6(struct horrible_allowedips * table,struct in6_addr * ip)227*4882a593Smuzhiyun horrible_allowedips_lookup_v6(struct horrible_allowedips *table, struct in6_addr *ip)
228*4882a593Smuzhiyun {
229*4882a593Smuzhiyun struct horrible_allowedips_node *node;
230*4882a593Smuzhiyun
231*4882a593Smuzhiyun hlist_for_each_entry(node, &table->head, table) {
232*4882a593Smuzhiyun if (node->ip_version == 6 && horrible_match_v6(node, ip))
233*4882a593Smuzhiyun return node->value;
234*4882a593Smuzhiyun }
235*4882a593Smuzhiyun return NULL;
236*4882a593Smuzhiyun }
237*4882a593Smuzhiyun
238*4882a593Smuzhiyun
239*4882a593Smuzhiyun static __init void
horrible_allowedips_remove_by_value(struct horrible_allowedips * table,void * value)240*4882a593Smuzhiyun horrible_allowedips_remove_by_value(struct horrible_allowedips *table, void *value)
241*4882a593Smuzhiyun {
242*4882a593Smuzhiyun struct horrible_allowedips_node *node;
243*4882a593Smuzhiyun struct hlist_node *h;
244*4882a593Smuzhiyun
245*4882a593Smuzhiyun hlist_for_each_entry_safe(node, h, &table->head, table) {
246*4882a593Smuzhiyun if (node->value != value)
247*4882a593Smuzhiyun continue;
248*4882a593Smuzhiyun hlist_del(&node->table);
249*4882a593Smuzhiyun kfree(node);
250*4882a593Smuzhiyun }
251*4882a593Smuzhiyun
252*4882a593Smuzhiyun }
253*4882a593Smuzhiyun
randomized_test(void)254*4882a593Smuzhiyun static __init bool randomized_test(void)
255*4882a593Smuzhiyun {
256*4882a593Smuzhiyun unsigned int i, j, k, mutate_amount, cidr;
257*4882a593Smuzhiyun u8 ip[16], mutate_mask[16], mutated[16];
258*4882a593Smuzhiyun struct wg_peer **peers, *peer;
259*4882a593Smuzhiyun struct horrible_allowedips h;
260*4882a593Smuzhiyun DEFINE_MUTEX(mutex);
261*4882a593Smuzhiyun struct allowedips t;
262*4882a593Smuzhiyun bool ret = false;
263*4882a593Smuzhiyun
264*4882a593Smuzhiyun mutex_init(&mutex);
265*4882a593Smuzhiyun
266*4882a593Smuzhiyun wg_allowedips_init(&t);
267*4882a593Smuzhiyun horrible_allowedips_init(&h);
268*4882a593Smuzhiyun
269*4882a593Smuzhiyun peers = kcalloc(NUM_PEERS, sizeof(*peers), GFP_KERNEL);
270*4882a593Smuzhiyun if (unlikely(!peers)) {
271*4882a593Smuzhiyun pr_err("allowedips random self-test malloc: FAIL\n");
272*4882a593Smuzhiyun goto free;
273*4882a593Smuzhiyun }
274*4882a593Smuzhiyun for (i = 0; i < NUM_PEERS; ++i) {
275*4882a593Smuzhiyun peers[i] = kzalloc(sizeof(*peers[i]), GFP_KERNEL);
276*4882a593Smuzhiyun if (unlikely(!peers[i])) {
277*4882a593Smuzhiyun pr_err("allowedips random self-test malloc: FAIL\n");
278*4882a593Smuzhiyun goto free;
279*4882a593Smuzhiyun }
280*4882a593Smuzhiyun kref_init(&peers[i]->refcount);
281*4882a593Smuzhiyun INIT_LIST_HEAD(&peers[i]->allowedips_list);
282*4882a593Smuzhiyun }
283*4882a593Smuzhiyun
284*4882a593Smuzhiyun mutex_lock(&mutex);
285*4882a593Smuzhiyun
286*4882a593Smuzhiyun for (i = 0; i < NUM_RAND_ROUTES; ++i) {
287*4882a593Smuzhiyun prandom_bytes(ip, 4);
288*4882a593Smuzhiyun cidr = prandom_u32_max(32) + 1;
289*4882a593Smuzhiyun peer = peers[prandom_u32_max(NUM_PEERS)];
290*4882a593Smuzhiyun if (wg_allowedips_insert_v4(&t, (struct in_addr *)ip, cidr,
291*4882a593Smuzhiyun peer, &mutex) < 0) {
292*4882a593Smuzhiyun pr_err("allowedips random self-test malloc: FAIL\n");
293*4882a593Smuzhiyun goto free_locked;
294*4882a593Smuzhiyun }
295*4882a593Smuzhiyun if (horrible_allowedips_insert_v4(&h, (struct in_addr *)ip,
296*4882a593Smuzhiyun cidr, peer) < 0) {
297*4882a593Smuzhiyun pr_err("allowedips random self-test malloc: FAIL\n");
298*4882a593Smuzhiyun goto free_locked;
299*4882a593Smuzhiyun }
300*4882a593Smuzhiyun for (j = 0; j < NUM_MUTATED_ROUTES; ++j) {
301*4882a593Smuzhiyun memcpy(mutated, ip, 4);
302*4882a593Smuzhiyun prandom_bytes(mutate_mask, 4);
303*4882a593Smuzhiyun mutate_amount = prandom_u32_max(32);
304*4882a593Smuzhiyun for (k = 0; k < mutate_amount / 8; ++k)
305*4882a593Smuzhiyun mutate_mask[k] = 0xff;
306*4882a593Smuzhiyun mutate_mask[k] = 0xff
307*4882a593Smuzhiyun << ((8 - (mutate_amount % 8)) % 8);
308*4882a593Smuzhiyun for (; k < 4; ++k)
309*4882a593Smuzhiyun mutate_mask[k] = 0;
310*4882a593Smuzhiyun for (k = 0; k < 4; ++k)
311*4882a593Smuzhiyun mutated[k] = (mutated[k] & mutate_mask[k]) |
312*4882a593Smuzhiyun (~mutate_mask[k] &
313*4882a593Smuzhiyun prandom_u32_max(256));
314*4882a593Smuzhiyun cidr = prandom_u32_max(32) + 1;
315*4882a593Smuzhiyun peer = peers[prandom_u32_max(NUM_PEERS)];
316*4882a593Smuzhiyun if (wg_allowedips_insert_v4(&t,
317*4882a593Smuzhiyun (struct in_addr *)mutated,
318*4882a593Smuzhiyun cidr, peer, &mutex) < 0) {
319*4882a593Smuzhiyun pr_err("allowedips random self-test malloc: FAIL\n");
320*4882a593Smuzhiyun goto free_locked;
321*4882a593Smuzhiyun }
322*4882a593Smuzhiyun if (horrible_allowedips_insert_v4(&h,
323*4882a593Smuzhiyun (struct in_addr *)mutated, cidr, peer)) {
324*4882a593Smuzhiyun pr_err("allowedips random self-test malloc: FAIL\n");
325*4882a593Smuzhiyun goto free_locked;
326*4882a593Smuzhiyun }
327*4882a593Smuzhiyun }
328*4882a593Smuzhiyun }
329*4882a593Smuzhiyun
330*4882a593Smuzhiyun for (i = 0; i < NUM_RAND_ROUTES; ++i) {
331*4882a593Smuzhiyun prandom_bytes(ip, 16);
332*4882a593Smuzhiyun cidr = prandom_u32_max(128) + 1;
333*4882a593Smuzhiyun peer = peers[prandom_u32_max(NUM_PEERS)];
334*4882a593Smuzhiyun if (wg_allowedips_insert_v6(&t, (struct in6_addr *)ip, cidr,
335*4882a593Smuzhiyun peer, &mutex) < 0) {
336*4882a593Smuzhiyun pr_err("allowedips random self-test malloc: FAIL\n");
337*4882a593Smuzhiyun goto free_locked;
338*4882a593Smuzhiyun }
339*4882a593Smuzhiyun if (horrible_allowedips_insert_v6(&h, (struct in6_addr *)ip,
340*4882a593Smuzhiyun cidr, peer) < 0) {
341*4882a593Smuzhiyun pr_err("allowedips random self-test malloc: FAIL\n");
342*4882a593Smuzhiyun goto free_locked;
343*4882a593Smuzhiyun }
344*4882a593Smuzhiyun for (j = 0; j < NUM_MUTATED_ROUTES; ++j) {
345*4882a593Smuzhiyun memcpy(mutated, ip, 16);
346*4882a593Smuzhiyun prandom_bytes(mutate_mask, 16);
347*4882a593Smuzhiyun mutate_amount = prandom_u32_max(128);
348*4882a593Smuzhiyun for (k = 0; k < mutate_amount / 8; ++k)
349*4882a593Smuzhiyun mutate_mask[k] = 0xff;
350*4882a593Smuzhiyun mutate_mask[k] = 0xff
351*4882a593Smuzhiyun << ((8 - (mutate_amount % 8)) % 8);
352*4882a593Smuzhiyun for (; k < 4; ++k)
353*4882a593Smuzhiyun mutate_mask[k] = 0;
354*4882a593Smuzhiyun for (k = 0; k < 4; ++k)
355*4882a593Smuzhiyun mutated[k] = (mutated[k] & mutate_mask[k]) |
356*4882a593Smuzhiyun (~mutate_mask[k] &
357*4882a593Smuzhiyun prandom_u32_max(256));
358*4882a593Smuzhiyun cidr = prandom_u32_max(128) + 1;
359*4882a593Smuzhiyun peer = peers[prandom_u32_max(NUM_PEERS)];
360*4882a593Smuzhiyun if (wg_allowedips_insert_v6(&t,
361*4882a593Smuzhiyun (struct in6_addr *)mutated,
362*4882a593Smuzhiyun cidr, peer, &mutex) < 0) {
363*4882a593Smuzhiyun pr_err("allowedips random self-test malloc: FAIL\n");
364*4882a593Smuzhiyun goto free_locked;
365*4882a593Smuzhiyun }
366*4882a593Smuzhiyun if (horrible_allowedips_insert_v6(
367*4882a593Smuzhiyun &h, (struct in6_addr *)mutated, cidr,
368*4882a593Smuzhiyun peer)) {
369*4882a593Smuzhiyun pr_err("allowedips random self-test malloc: FAIL\n");
370*4882a593Smuzhiyun goto free_locked;
371*4882a593Smuzhiyun }
372*4882a593Smuzhiyun }
373*4882a593Smuzhiyun }
374*4882a593Smuzhiyun
375*4882a593Smuzhiyun mutex_unlock(&mutex);
376*4882a593Smuzhiyun
377*4882a593Smuzhiyun if (IS_ENABLED(DEBUG_PRINT_TRIE_GRAPHVIZ)) {
378*4882a593Smuzhiyun print_tree(t.root4, 32);
379*4882a593Smuzhiyun print_tree(t.root6, 128);
380*4882a593Smuzhiyun }
381*4882a593Smuzhiyun
382*4882a593Smuzhiyun for (j = 0;; ++j) {
383*4882a593Smuzhiyun for (i = 0; i < NUM_QUERIES; ++i) {
384*4882a593Smuzhiyun prandom_bytes(ip, 4);
385*4882a593Smuzhiyun if (lookup(t.root4, 32, ip) != horrible_allowedips_lookup_v4(&h, (struct in_addr *)ip)) {
386*4882a593Smuzhiyun horrible_allowedips_lookup_v4(&h, (struct in_addr *)ip);
387*4882a593Smuzhiyun pr_err("allowedips random v4 self-test: FAIL\n");
388*4882a593Smuzhiyun goto free;
389*4882a593Smuzhiyun }
390*4882a593Smuzhiyun prandom_bytes(ip, 16);
391*4882a593Smuzhiyun if (lookup(t.root6, 128, ip) != horrible_allowedips_lookup_v6(&h, (struct in6_addr *)ip)) {
392*4882a593Smuzhiyun pr_err("allowedips random v6 self-test: FAIL\n");
393*4882a593Smuzhiyun goto free;
394*4882a593Smuzhiyun }
395*4882a593Smuzhiyun }
396*4882a593Smuzhiyun if (j >= NUM_PEERS)
397*4882a593Smuzhiyun break;
398*4882a593Smuzhiyun mutex_lock(&mutex);
399*4882a593Smuzhiyun wg_allowedips_remove_by_peer(&t, peers[j], &mutex);
400*4882a593Smuzhiyun mutex_unlock(&mutex);
401*4882a593Smuzhiyun horrible_allowedips_remove_by_value(&h, peers[j]);
402*4882a593Smuzhiyun }
403*4882a593Smuzhiyun
404*4882a593Smuzhiyun if (t.root4 || t.root6) {
405*4882a593Smuzhiyun pr_err("allowedips random self-test removal: FAIL\n");
406*4882a593Smuzhiyun goto free;
407*4882a593Smuzhiyun }
408*4882a593Smuzhiyun
409*4882a593Smuzhiyun ret = true;
410*4882a593Smuzhiyun
411*4882a593Smuzhiyun free:
412*4882a593Smuzhiyun mutex_lock(&mutex);
413*4882a593Smuzhiyun free_locked:
414*4882a593Smuzhiyun wg_allowedips_free(&t, &mutex);
415*4882a593Smuzhiyun mutex_unlock(&mutex);
416*4882a593Smuzhiyun horrible_allowedips_free(&h);
417*4882a593Smuzhiyun if (peers) {
418*4882a593Smuzhiyun for (i = 0; i < NUM_PEERS; ++i)
419*4882a593Smuzhiyun kfree(peers[i]);
420*4882a593Smuzhiyun }
421*4882a593Smuzhiyun kfree(peers);
422*4882a593Smuzhiyun return ret;
423*4882a593Smuzhiyun }
424*4882a593Smuzhiyun
ip4(u8 a,u8 b,u8 c,u8 d)425*4882a593Smuzhiyun static __init inline struct in_addr *ip4(u8 a, u8 b, u8 c, u8 d)
426*4882a593Smuzhiyun {
427*4882a593Smuzhiyun static struct in_addr ip;
428*4882a593Smuzhiyun u8 *split = (u8 *)&ip;
429*4882a593Smuzhiyun
430*4882a593Smuzhiyun split[0] = a;
431*4882a593Smuzhiyun split[1] = b;
432*4882a593Smuzhiyun split[2] = c;
433*4882a593Smuzhiyun split[3] = d;
434*4882a593Smuzhiyun return &ip;
435*4882a593Smuzhiyun }
436*4882a593Smuzhiyun
ip6(u32 a,u32 b,u32 c,u32 d)437*4882a593Smuzhiyun static __init inline struct in6_addr *ip6(u32 a, u32 b, u32 c, u32 d)
438*4882a593Smuzhiyun {
439*4882a593Smuzhiyun static struct in6_addr ip;
440*4882a593Smuzhiyun __be32 *split = (__be32 *)&ip;
441*4882a593Smuzhiyun
442*4882a593Smuzhiyun split[0] = cpu_to_be32(a);
443*4882a593Smuzhiyun split[1] = cpu_to_be32(b);
444*4882a593Smuzhiyun split[2] = cpu_to_be32(c);
445*4882a593Smuzhiyun split[3] = cpu_to_be32(d);
446*4882a593Smuzhiyun return &ip;
447*4882a593Smuzhiyun }
448*4882a593Smuzhiyun
init_peer(void)449*4882a593Smuzhiyun static __init struct wg_peer *init_peer(void)
450*4882a593Smuzhiyun {
451*4882a593Smuzhiyun struct wg_peer *peer = kzalloc(sizeof(*peer), GFP_KERNEL);
452*4882a593Smuzhiyun
453*4882a593Smuzhiyun if (!peer)
454*4882a593Smuzhiyun return NULL;
455*4882a593Smuzhiyun kref_init(&peer->refcount);
456*4882a593Smuzhiyun INIT_LIST_HEAD(&peer->allowedips_list);
457*4882a593Smuzhiyun return peer;
458*4882a593Smuzhiyun }
459*4882a593Smuzhiyun
/*
 * Test helpers for wg_allowedips_selftest(). They intentionally refer to
 * that function's locals by name: the trie `t`, its `mutex`, the running
 * test counter `i` (a size_t) and the overall `success` flag.
 */

/* Add a route (a.b.c.d for v4, four 32-bit words for v6)/cidr for @mem. */
#define insert(version, mem, ipa, ipb, ipc, ipd, cidr) \
	wg_allowedips_insert_v##version(&t, ip##version(ipa, ipb, ipc, ipd), \
					cidr, mem, &mutex)

/*
 * Count one check and report/record a failure unless @passed is true.
 * The condition is passed explicitly rather than read from a caller-declared
 * `_s`, so the macro no longer depends on a hidden local.
 */
#define maybe_fail(passed) do { \
		++i; \
		if (!(passed)) { \
			pr_info("allowedips self-test %zu: FAIL\n", i); \
			success = false; \
		} \
	} while (0)

/* Check that looking up the given address returns peer @mem. */
#define test(version, mem, ipa, ipb, ipc, ipd) do { \
		bool _s = lookup(t.root##version, (version) == 4 ? 32 : 128, \
				 ip##version(ipa, ipb, ipc, ipd)) == (mem); \
		maybe_fail(_s); \
	} while (0)

/* Check that looking up the given address does NOT return peer @mem. */
#define test_negative(version, mem, ipa, ipb, ipc, ipd) do { \
		bool _s = lookup(t.root##version, (version) == 4 ? 32 : 128, \
				 ip##version(ipa, ipb, ipc, ipd)) != (mem); \
		maybe_fail(_s); \
	} while (0)

/* Check an arbitrary condition; it is evaluated before the counter bumps. */
#define test_boolean(cond) do { \
		bool _s = (cond); \
		maybe_fail(_s); \
	} while (0)
488*4882a593Smuzhiyun
wg_allowedips_selftest(void)489*4882a593Smuzhiyun bool __init wg_allowedips_selftest(void)
490*4882a593Smuzhiyun {
491*4882a593Smuzhiyun bool found_a = false, found_b = false, found_c = false, found_d = false,
492*4882a593Smuzhiyun found_e = false, found_other = false;
493*4882a593Smuzhiyun struct wg_peer *a = init_peer(), *b = init_peer(), *c = init_peer(),
494*4882a593Smuzhiyun *d = init_peer(), *e = init_peer(), *f = init_peer(),
495*4882a593Smuzhiyun *g = init_peer(), *h = init_peer();
496*4882a593Smuzhiyun struct allowedips_node *iter_node;
497*4882a593Smuzhiyun bool success = false;
498*4882a593Smuzhiyun struct allowedips t;
499*4882a593Smuzhiyun DEFINE_MUTEX(mutex);
500*4882a593Smuzhiyun struct in6_addr ip;
501*4882a593Smuzhiyun size_t i = 0, count = 0;
502*4882a593Smuzhiyun __be64 part;
503*4882a593Smuzhiyun
504*4882a593Smuzhiyun mutex_init(&mutex);
505*4882a593Smuzhiyun mutex_lock(&mutex);
506*4882a593Smuzhiyun wg_allowedips_init(&t);
507*4882a593Smuzhiyun
508*4882a593Smuzhiyun if (!a || !b || !c || !d || !e || !f || !g || !h) {
509*4882a593Smuzhiyun pr_err("allowedips self-test malloc: FAIL\n");
510*4882a593Smuzhiyun goto free;
511*4882a593Smuzhiyun }
512*4882a593Smuzhiyun
513*4882a593Smuzhiyun insert(4, a, 192, 168, 4, 0, 24);
514*4882a593Smuzhiyun insert(4, b, 192, 168, 4, 4, 32);
515*4882a593Smuzhiyun insert(4, c, 192, 168, 0, 0, 16);
516*4882a593Smuzhiyun insert(4, d, 192, 95, 5, 64, 27);
517*4882a593Smuzhiyun /* replaces previous entry, and maskself is required */
518*4882a593Smuzhiyun insert(4, c, 192, 95, 5, 65, 27);
519*4882a593Smuzhiyun insert(6, d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128);
520*4882a593Smuzhiyun insert(6, c, 0x26075300, 0x60006b00, 0, 0, 64);
521*4882a593Smuzhiyun insert(4, e, 0, 0, 0, 0, 0);
522*4882a593Smuzhiyun insert(6, e, 0, 0, 0, 0, 0);
523*4882a593Smuzhiyun /* replaces previous entry */
524*4882a593Smuzhiyun insert(6, f, 0, 0, 0, 0, 0);
525*4882a593Smuzhiyun insert(6, g, 0x24046800, 0, 0, 0, 32);
526*4882a593Smuzhiyun /* maskself is required */
527*4882a593Smuzhiyun insert(6, h, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 64);
528*4882a593Smuzhiyun insert(6, a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 128);
529*4882a593Smuzhiyun insert(6, c, 0x24446800, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128);
530*4882a593Smuzhiyun insert(6, b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98);
531*4882a593Smuzhiyun insert(4, g, 64, 15, 112, 0, 20);
532*4882a593Smuzhiyun /* maskself is required */
533*4882a593Smuzhiyun insert(4, h, 64, 15, 123, 211, 25);
534*4882a593Smuzhiyun insert(4, a, 10, 0, 0, 0, 25);
535*4882a593Smuzhiyun insert(4, b, 10, 0, 0, 128, 25);
536*4882a593Smuzhiyun insert(4, a, 10, 1, 0, 0, 30);
537*4882a593Smuzhiyun insert(4, b, 10, 1, 0, 4, 30);
538*4882a593Smuzhiyun insert(4, c, 10, 1, 0, 8, 29);
539*4882a593Smuzhiyun insert(4, d, 10, 1, 0, 16, 29);
540*4882a593Smuzhiyun
541*4882a593Smuzhiyun if (IS_ENABLED(DEBUG_PRINT_TRIE_GRAPHVIZ)) {
542*4882a593Smuzhiyun print_tree(t.root4, 32);
543*4882a593Smuzhiyun print_tree(t.root6, 128);
544*4882a593Smuzhiyun }
545*4882a593Smuzhiyun
546*4882a593Smuzhiyun success = true;
547*4882a593Smuzhiyun
548*4882a593Smuzhiyun test(4, a, 192, 168, 4, 20);
549*4882a593Smuzhiyun test(4, a, 192, 168, 4, 0);
550*4882a593Smuzhiyun test(4, b, 192, 168, 4, 4);
551*4882a593Smuzhiyun test(4, c, 192, 168, 200, 182);
552*4882a593Smuzhiyun test(4, c, 192, 95, 5, 68);
553*4882a593Smuzhiyun test(4, e, 192, 95, 5, 96);
554*4882a593Smuzhiyun test(6, d, 0x26075300, 0x60006b00, 0, 0xc05f0543);
555*4882a593Smuzhiyun test(6, c, 0x26075300, 0x60006b00, 0, 0xc02e01ee);
556*4882a593Smuzhiyun test(6, f, 0x26075300, 0x60006b01, 0, 0);
557*4882a593Smuzhiyun test(6, g, 0x24046800, 0x40040806, 0, 0x1006);
558*4882a593Smuzhiyun test(6, g, 0x24046800, 0x40040806, 0x1234, 0x5678);
559*4882a593Smuzhiyun test(6, f, 0x240467ff, 0x40040806, 0x1234, 0x5678);
560*4882a593Smuzhiyun test(6, f, 0x24046801, 0x40040806, 0x1234, 0x5678);
561*4882a593Smuzhiyun test(6, h, 0x24046800, 0x40040800, 0x1234, 0x5678);
562*4882a593Smuzhiyun test(6, h, 0x24046800, 0x40040800, 0, 0);
563*4882a593Smuzhiyun test(6, h, 0x24046800, 0x40040800, 0x10101010, 0x10101010);
564*4882a593Smuzhiyun test(6, a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef);
565*4882a593Smuzhiyun test(4, g, 64, 15, 116, 26);
566*4882a593Smuzhiyun test(4, g, 64, 15, 127, 3);
567*4882a593Smuzhiyun test(4, g, 64, 15, 123, 1);
568*4882a593Smuzhiyun test(4, h, 64, 15, 123, 128);
569*4882a593Smuzhiyun test(4, h, 64, 15, 123, 129);
570*4882a593Smuzhiyun test(4, a, 10, 0, 0, 52);
571*4882a593Smuzhiyun test(4, b, 10, 0, 0, 220);
572*4882a593Smuzhiyun test(4, a, 10, 1, 0, 2);
573*4882a593Smuzhiyun test(4, b, 10, 1, 0, 6);
574*4882a593Smuzhiyun test(4, c, 10, 1, 0, 10);
575*4882a593Smuzhiyun test(4, d, 10, 1, 0, 20);
576*4882a593Smuzhiyun
577*4882a593Smuzhiyun insert(4, a, 1, 0, 0, 0, 32);
578*4882a593Smuzhiyun insert(4, a, 64, 0, 0, 0, 32);
579*4882a593Smuzhiyun insert(4, a, 128, 0, 0, 0, 32);
580*4882a593Smuzhiyun insert(4, a, 192, 0, 0, 0, 32);
581*4882a593Smuzhiyun insert(4, a, 255, 0, 0, 0, 32);
582*4882a593Smuzhiyun wg_allowedips_remove_by_peer(&t, a, &mutex);
583*4882a593Smuzhiyun test_negative(4, a, 1, 0, 0, 0);
584*4882a593Smuzhiyun test_negative(4, a, 64, 0, 0, 0);
585*4882a593Smuzhiyun test_negative(4, a, 128, 0, 0, 0);
586*4882a593Smuzhiyun test_negative(4, a, 192, 0, 0, 0);
587*4882a593Smuzhiyun test_negative(4, a, 255, 0, 0, 0);
588*4882a593Smuzhiyun
589*4882a593Smuzhiyun wg_allowedips_free(&t, &mutex);
590*4882a593Smuzhiyun wg_allowedips_init(&t);
591*4882a593Smuzhiyun insert(4, a, 192, 168, 0, 0, 16);
592*4882a593Smuzhiyun insert(4, a, 192, 168, 0, 0, 24);
593*4882a593Smuzhiyun wg_allowedips_remove_by_peer(&t, a, &mutex);
594*4882a593Smuzhiyun test_negative(4, a, 192, 168, 0, 1);
595*4882a593Smuzhiyun
596*4882a593Smuzhiyun /* These will hit the WARN_ON(len >= MAX_ALLOWEDIPS_BITS) in free_node
597*4882a593Smuzhiyun * if something goes wrong.
598*4882a593Smuzhiyun */
599*4882a593Smuzhiyun for (i = 0; i < MAX_ALLOWEDIPS_BITS; ++i) {
600*4882a593Smuzhiyun part = cpu_to_be64(~(1LLU << (i % 64)));
601*4882a593Smuzhiyun memset(&ip, 0xff, 16);
602*4882a593Smuzhiyun memcpy((u8 *)&ip + (i < 64) * 8, &part, 8);
603*4882a593Smuzhiyun wg_allowedips_insert_v6(&t, &ip, 128, a, &mutex);
604*4882a593Smuzhiyun }
605*4882a593Smuzhiyun
606*4882a593Smuzhiyun wg_allowedips_free(&t, &mutex);
607*4882a593Smuzhiyun
608*4882a593Smuzhiyun wg_allowedips_init(&t);
609*4882a593Smuzhiyun insert(4, a, 192, 95, 5, 93, 27);
610*4882a593Smuzhiyun insert(6, a, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128);
611*4882a593Smuzhiyun insert(4, a, 10, 1, 0, 20, 29);
612*4882a593Smuzhiyun insert(6, a, 0x26075300, 0x6d8a6bf8, 0xdab1f1df, 0xc05f1523, 83);
613*4882a593Smuzhiyun insert(6, a, 0x26075300, 0x6d8a6bf8, 0xdab1f1df, 0xc05f1523, 21);
614*4882a593Smuzhiyun list_for_each_entry(iter_node, &a->allowedips_list, peer_list) {
615*4882a593Smuzhiyun u8 cidr, ip[16] __aligned(__alignof(u64));
616*4882a593Smuzhiyun int family = wg_allowedips_read_node(iter_node, ip, &cidr);
617*4882a593Smuzhiyun
618*4882a593Smuzhiyun count++;
619*4882a593Smuzhiyun
620*4882a593Smuzhiyun if (cidr == 27 && family == AF_INET &&
621*4882a593Smuzhiyun !memcmp(ip, ip4(192, 95, 5, 64), sizeof(struct in_addr)))
622*4882a593Smuzhiyun found_a = true;
623*4882a593Smuzhiyun else if (cidr == 128 && family == AF_INET6 &&
624*4882a593Smuzhiyun !memcmp(ip, ip6(0x26075300, 0x60006b00, 0, 0xc05f0543),
625*4882a593Smuzhiyun sizeof(struct in6_addr)))
626*4882a593Smuzhiyun found_b = true;
627*4882a593Smuzhiyun else if (cidr == 29 && family == AF_INET &&
628*4882a593Smuzhiyun !memcmp(ip, ip4(10, 1, 0, 16), sizeof(struct in_addr)))
629*4882a593Smuzhiyun found_c = true;
630*4882a593Smuzhiyun else if (cidr == 83 && family == AF_INET6 &&
631*4882a593Smuzhiyun !memcmp(ip, ip6(0x26075300, 0x6d8a6bf8, 0xdab1e000, 0),
632*4882a593Smuzhiyun sizeof(struct in6_addr)))
633*4882a593Smuzhiyun found_d = true;
634*4882a593Smuzhiyun else if (cidr == 21 && family == AF_INET6 &&
635*4882a593Smuzhiyun !memcmp(ip, ip6(0x26075000, 0, 0, 0),
636*4882a593Smuzhiyun sizeof(struct in6_addr)))
637*4882a593Smuzhiyun found_e = true;
638*4882a593Smuzhiyun else
639*4882a593Smuzhiyun found_other = true;
640*4882a593Smuzhiyun }
641*4882a593Smuzhiyun test_boolean(count == 5);
642*4882a593Smuzhiyun test_boolean(found_a);
643*4882a593Smuzhiyun test_boolean(found_b);
644*4882a593Smuzhiyun test_boolean(found_c);
645*4882a593Smuzhiyun test_boolean(found_d);
646*4882a593Smuzhiyun test_boolean(found_e);
647*4882a593Smuzhiyun test_boolean(!found_other);
648*4882a593Smuzhiyun
649*4882a593Smuzhiyun if (IS_ENABLED(DEBUG_RANDOM_TRIE) && success)
650*4882a593Smuzhiyun success = randomized_test();
651*4882a593Smuzhiyun
652*4882a593Smuzhiyun if (success)
653*4882a593Smuzhiyun pr_info("allowedips self-tests: pass\n");
654*4882a593Smuzhiyun
655*4882a593Smuzhiyun free:
656*4882a593Smuzhiyun wg_allowedips_free(&t, &mutex);
657*4882a593Smuzhiyun kfree(a);
658*4882a593Smuzhiyun kfree(b);
659*4882a593Smuzhiyun kfree(c);
660*4882a593Smuzhiyun kfree(d);
661*4882a593Smuzhiyun kfree(e);
662*4882a593Smuzhiyun kfree(f);
663*4882a593Smuzhiyun kfree(g);
664*4882a593Smuzhiyun kfree(h);
665*4882a593Smuzhiyun mutex_unlock(&mutex);
666*4882a593Smuzhiyun
667*4882a593Smuzhiyun return success;
668*4882a593Smuzhiyun }
669*4882a593Smuzhiyun
/*
 * Drop the test-only helper macros so they do not leak past this point in
 * the translation unit. This covers every macro defined above, including
 * maybe_fail and test_boolean, which were previously left defined.
 * init_peer is a function, not a macro, so it needs no #undef.
 */
#undef test_boolean
#undef test_negative
#undef test
#undef maybe_fail
#undef remove
#undef insert
675*4882a593Smuzhiyun
676*4882a593Smuzhiyun #endif
677