xref: /OK3568_Linux_fs/kernel/net/ipv4/bpf_tcp_ca.c (revision 4882a59341e53eb6f0b4789bf948001014eff981)
1*4882a593Smuzhiyun // SPDX-License-Identifier: GPL-2.0
2*4882a593Smuzhiyun /* Copyright (c) 2019 Facebook  */
3*4882a593Smuzhiyun 
4*4882a593Smuzhiyun #include <linux/types.h>
5*4882a593Smuzhiyun #include <linux/bpf_verifier.h>
6*4882a593Smuzhiyun #include <linux/bpf.h>
7*4882a593Smuzhiyun #include <linux/btf.h>
8*4882a593Smuzhiyun #include <linux/filter.h>
9*4882a593Smuzhiyun #include <net/tcp.h>
10*4882a593Smuzhiyun #include <net/bpf_sk_storage.h>
11*4882a593Smuzhiyun 
/* tcp_congestion_ops members a BPF CA implementation may leave NULL;
 * init_member() skips the "compulsory func ptr provided" check for these.
 */
static u32 optional_ops[] = {
	offsetof(struct tcp_congestion_ops, init),
	offsetof(struct tcp_congestion_ops, release),
	offsetof(struct tcp_congestion_ops, set_state),
	offsetof(struct tcp_congestion_ops, cwnd_event),
	offsetof(struct tcp_congestion_ops, in_ack_event),
	offsetof(struct tcp_congestion_ops, pkts_acked),
	offsetof(struct tcp_congestion_ops, min_tso_segs),
	offsetof(struct tcp_congestion_ops, sndbuf_expand),
	offsetof(struct tcp_congestion_ops, cong_control),
};
23*4882a593Smuzhiyun 
/* tcp_congestion_ops members that cannot be implemented in BPF;
 * check_member() rejects any struct_ops map that sets one of these.
 */
static u32 unsupported_ops[] = {
	offsetof(struct tcp_congestion_ops, get_info),
};
27*4882a593Smuzhiyun 
/* Cached vmlinux BTF info, resolved once in bpf_tcp_ca_init(). */
static const struct btf_type *tcp_sock_type;
/* BTF type ids of "tcp_sock" and "sock"; used to promote sock args. */
static u32 tcp_sock_id, sock_id;
30*4882a593Smuzhiyun 
bpf_tcp_ca_init(struct btf * btf)31*4882a593Smuzhiyun static int bpf_tcp_ca_init(struct btf *btf)
32*4882a593Smuzhiyun {
33*4882a593Smuzhiyun 	s32 type_id;
34*4882a593Smuzhiyun 
35*4882a593Smuzhiyun 	type_id = btf_find_by_name_kind(btf, "sock", BTF_KIND_STRUCT);
36*4882a593Smuzhiyun 	if (type_id < 0)
37*4882a593Smuzhiyun 		return -EINVAL;
38*4882a593Smuzhiyun 	sock_id = type_id;
39*4882a593Smuzhiyun 
40*4882a593Smuzhiyun 	type_id = btf_find_by_name_kind(btf, "tcp_sock", BTF_KIND_STRUCT);
41*4882a593Smuzhiyun 	if (type_id < 0)
42*4882a593Smuzhiyun 		return -EINVAL;
43*4882a593Smuzhiyun 	tcp_sock_id = type_id;
44*4882a593Smuzhiyun 	tcp_sock_type = btf_type_by_id(btf, tcp_sock_id);
45*4882a593Smuzhiyun 
46*4882a593Smuzhiyun 	return 0;
47*4882a593Smuzhiyun }
48*4882a593Smuzhiyun 
is_optional(u32 member_offset)49*4882a593Smuzhiyun static bool is_optional(u32 member_offset)
50*4882a593Smuzhiyun {
51*4882a593Smuzhiyun 	unsigned int i;
52*4882a593Smuzhiyun 
53*4882a593Smuzhiyun 	for (i = 0; i < ARRAY_SIZE(optional_ops); i++) {
54*4882a593Smuzhiyun 		if (member_offset == optional_ops[i])
55*4882a593Smuzhiyun 			return true;
56*4882a593Smuzhiyun 	}
57*4882a593Smuzhiyun 
58*4882a593Smuzhiyun 	return false;
59*4882a593Smuzhiyun }
60*4882a593Smuzhiyun 
is_unsupported(u32 member_offset)61*4882a593Smuzhiyun static bool is_unsupported(u32 member_offset)
62*4882a593Smuzhiyun {
63*4882a593Smuzhiyun 	unsigned int i;
64*4882a593Smuzhiyun 
65*4882a593Smuzhiyun 	for (i = 0; i < ARRAY_SIZE(unsupported_ops); i++) {
66*4882a593Smuzhiyun 		if (member_offset == unsupported_ops[i])
67*4882a593Smuzhiyun 			return true;
68*4882a593Smuzhiyun 	}
69*4882a593Smuzhiyun 
70*4882a593Smuzhiyun 	return false;
71*4882a593Smuzhiyun }
72*4882a593Smuzhiyun 
73*4882a593Smuzhiyun extern struct btf *btf_vmlinux;
74*4882a593Smuzhiyun 
bpf_tcp_ca_is_valid_access(int off,int size,enum bpf_access_type type,const struct bpf_prog * prog,struct bpf_insn_access_aux * info)75*4882a593Smuzhiyun static bool bpf_tcp_ca_is_valid_access(int off, int size,
76*4882a593Smuzhiyun 				       enum bpf_access_type type,
77*4882a593Smuzhiyun 				       const struct bpf_prog *prog,
78*4882a593Smuzhiyun 				       struct bpf_insn_access_aux *info)
79*4882a593Smuzhiyun {
80*4882a593Smuzhiyun 	if (off < 0 || off >= sizeof(__u64) * MAX_BPF_FUNC_ARGS)
81*4882a593Smuzhiyun 		return false;
82*4882a593Smuzhiyun 	if (type != BPF_READ)
83*4882a593Smuzhiyun 		return false;
84*4882a593Smuzhiyun 	if (off % size != 0)
85*4882a593Smuzhiyun 		return false;
86*4882a593Smuzhiyun 
87*4882a593Smuzhiyun 	if (!btf_ctx_access(off, size, type, prog, info))
88*4882a593Smuzhiyun 		return false;
89*4882a593Smuzhiyun 
90*4882a593Smuzhiyun 	if (info->reg_type == PTR_TO_BTF_ID && info->btf_id == sock_id)
91*4882a593Smuzhiyun 		/* promote it to tcp_sock */
92*4882a593Smuzhiyun 		info->btf_id = tcp_sock_id;
93*4882a593Smuzhiyun 
94*4882a593Smuzhiyun 	return true;
95*4882a593Smuzhiyun }
96*4882a593Smuzhiyun 
/* Check a PTR_TO_BTF_ID struct access from a bpf_tcp_ca prog.
 *
 * Reads are delegated to the generic btf_struct_access().  Writes are
 * only allowed to struct tcp_sock, and only to a small whitelist of
 * members (CA private area, delayed-ack pending flag, cwnd/ssthresh
 * state and ECN flags).  Everything else gets -EACCES.
 *
 * Returns NOT_INIT on an accepted write (the verifier marks the dst reg
 * as uninitialized), a reg type from btf_struct_access() on read, or a
 * negative error.
 */
static int bpf_tcp_ca_btf_struct_access(struct bpf_verifier_log *log,
					const struct btf_type *t, int off,
					int size, enum bpf_access_type atype,
					u32 *next_btf_id)
{
	size_t end;

	if (atype == BPF_READ)
		return btf_struct_access(log, t, off, size, atype, next_btf_id);

	if (t != tcp_sock_type) {
		bpf_log(log, "only read is supported\n");
		return -EACCES;
	}

	switch (off) {
	/* case-range: any offset inside the CA private scratch area */
	case bpf_ctx_range(struct inet_connection_sock, icsk_ca_priv):
		end = offsetofend(struct inet_connection_sock, icsk_ca_priv);
		break;
	case offsetof(struct inet_connection_sock, icsk_ack.pending):
		end = offsetofend(struct inet_connection_sock,
				  icsk_ack.pending);
		break;
	case offsetof(struct tcp_sock, snd_cwnd):
		end = offsetofend(struct tcp_sock, snd_cwnd);
		break;
	case offsetof(struct tcp_sock, snd_cwnd_cnt):
		end = offsetofend(struct tcp_sock, snd_cwnd_cnt);
		break;
	case offsetof(struct tcp_sock, snd_ssthresh):
		end = offsetofend(struct tcp_sock, snd_ssthresh);
		break;
	case offsetof(struct tcp_sock, ecn_flags):
		end = offsetofend(struct tcp_sock, ecn_flags);
		break;
	default:
		bpf_log(log, "no write support to tcp_sock at off %d\n", off);
		return -EACCES;
	}

	/* The write must not spill past the end of the matched member. */
	if (off + size > end) {
		bpf_log(log,
			"write access at off %d with size %d beyond the member of tcp_sock ended at %zu\n",
			off, size, end);
		return -EACCES;
	}

	return NOT_INIT;
}
146*4882a593Smuzhiyun 
/* bpf_tcp_send_ack() helper: send an ACK for @tp acknowledging @rcv_nxt.
 * Always returns 0 today; the int return is reserved for future errors.
 */
BPF_CALL_2(bpf_tcp_send_ack, struct tcp_sock *, tp, u32, rcv_nxt)
{
	/* bpf_tcp_ca prog cannot have NULL tp */
	__tcp_send_ack((struct sock *)tp, rcv_nxt);
	return 0;
}
153*4882a593Smuzhiyun 
/* Verifier prototype for bpf_tcp_send_ack: arg1 must be a trusted
 * tcp_sock BTF pointer, arg2 is any scalar (the rcv_nxt sequence).
 */
static const struct bpf_func_proto bpf_tcp_send_ack_proto = {
	.func		= bpf_tcp_send_ack,
	.gpl_only	= false,
	/* In case we want to report error later */
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_BTF_ID,
	.arg1_btf_id	= &tcp_sock_id,
	.arg2_type	= ARG_ANYTHING,
};
163*4882a593Smuzhiyun 
164*4882a593Smuzhiyun static const struct bpf_func_proto *
bpf_tcp_ca_get_func_proto(enum bpf_func_id func_id,const struct bpf_prog * prog)165*4882a593Smuzhiyun bpf_tcp_ca_get_func_proto(enum bpf_func_id func_id,
166*4882a593Smuzhiyun 			  const struct bpf_prog *prog)
167*4882a593Smuzhiyun {
168*4882a593Smuzhiyun 	switch (func_id) {
169*4882a593Smuzhiyun 	case BPF_FUNC_tcp_send_ack:
170*4882a593Smuzhiyun 		return &bpf_tcp_send_ack_proto;
171*4882a593Smuzhiyun 	case BPF_FUNC_sk_storage_get:
172*4882a593Smuzhiyun 		return &bpf_sk_storage_get_proto;
173*4882a593Smuzhiyun 	case BPF_FUNC_sk_storage_delete:
174*4882a593Smuzhiyun 		return &bpf_sk_storage_delete_proto;
175*4882a593Smuzhiyun 	default:
176*4882a593Smuzhiyun 		return bpf_base_func_proto(func_id);
177*4882a593Smuzhiyun 	}
178*4882a593Smuzhiyun }
179*4882a593Smuzhiyun 
/* Verifier callbacks used while loading bpf_tcp_ca struct_ops progs. */
static const struct bpf_verifier_ops bpf_tcp_ca_verifier_ops = {
	.get_func_proto		= bpf_tcp_ca_get_func_proto,
	.is_valid_access	= bpf_tcp_ca_is_valid_access,
	.btf_struct_access	= bpf_tcp_ca_btf_struct_access,
};
185*4882a593Smuzhiyun 
/* Initialize one tcp_congestion_ops member of the kernel copy (@kdata)
 * from the user-supplied map value (@udata).
 *
 * Returns 1 when the member was fully handled here (flags, name),
 * 0 to let the generic struct_ops code copy/patch it, or a negative
 * error (-EINVAL on bad flags/name or a missing compulsory func ptr,
 * -EEXIST if a CA with the same name is already registered).
 */
static int bpf_tcp_ca_init_member(const struct btf_type *t,
				  const struct btf_member *member,
				  void *kdata, const void *udata)
{
	const struct tcp_congestion_ops *utcp_ca;
	struct tcp_congestion_ops *tcp_ca;
	int prog_fd;
	u32 moff;

	utcp_ca = (const struct tcp_congestion_ops *)udata;
	tcp_ca = (struct tcp_congestion_ops *)kdata;

	/* byte offset of this member within tcp_congestion_ops */
	moff = btf_member_bit_offset(t, member) / 8;
	switch (moff) {
	case offsetof(struct tcp_congestion_ops, flags):
		/* only flag bits in TCP_CONG_MASK may be set from user */
		if (utcp_ca->flags & ~TCP_CONG_MASK)
			return -EINVAL;
		tcp_ca->flags = utcp_ca->flags;
		return 1;
	case offsetof(struct tcp_congestion_ops, name):
		if (bpf_obj_name_cpy(tcp_ca->name, utcp_ca->name,
				     sizeof(tcp_ca->name)) <= 0)
			return -EINVAL;
		if (tcp_ca_find(utcp_ca->name))
			return -EEXIST;
		return 1;
	}

	/* Non func-ptr members are handled by the generic code. */
	if (!btf_type_resolve_func_ptr(btf_vmlinux, member->type, NULL))
		return 0;

	/* Ensure bpf_prog is provided for compulsory func ptr */
	prog_fd = (int)(*(unsigned long *)(udata + moff));
	if (!prog_fd && !is_optional(moff) && !is_unsupported(moff))
		return -EINVAL;

	return 0;
}
224*4882a593Smuzhiyun 
bpf_tcp_ca_check_member(const struct btf_type * t,const struct btf_member * member)225*4882a593Smuzhiyun static int bpf_tcp_ca_check_member(const struct btf_type *t,
226*4882a593Smuzhiyun 				   const struct btf_member *member)
227*4882a593Smuzhiyun {
228*4882a593Smuzhiyun 	if (is_unsupported(btf_member_bit_offset(t, member) / 8))
229*4882a593Smuzhiyun 		return -ENOTSUPP;
230*4882a593Smuzhiyun 	return 0;
231*4882a593Smuzhiyun }
232*4882a593Smuzhiyun 
/* Register the fully-initialized kernel copy as a congestion control. */
static int bpf_tcp_ca_reg(void *kdata)
{
	struct tcp_congestion_ops *ca = kdata;

	return tcp_register_congestion_control(ca);
}
237*4882a593Smuzhiyun 
/* Unregister the congestion control when the struct_ops map goes away. */
static void bpf_tcp_ca_unreg(void *kdata)
{
	struct tcp_congestion_ops *ca = kdata;

	tcp_unregister_congestion_control(ca);
}
242*4882a593Smuzhiyun 
/* Avoid sparse warning.  It is only used in bpf_struct_ops.c. */
extern struct bpf_struct_ops bpf_tcp_congestion_ops;

/* struct_ops descriptor tying BPF maps named "tcp_congestion_ops" to
 * the callbacks above; looked up by name from bpf_struct_ops.c.
 */
struct bpf_struct_ops bpf_tcp_congestion_ops = {
	.verifier_ops = &bpf_tcp_ca_verifier_ops,
	.reg = bpf_tcp_ca_reg,
	.unreg = bpf_tcp_ca_unreg,
	.check_member = bpf_tcp_ca_check_member,
	.init_member = bpf_tcp_ca_init_member,
	.init = bpf_tcp_ca_init,
	.name = "tcp_congestion_ops",
};
255