// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2019 Facebook */

#include <linux/types.h>
#include <linux/bpf_verifier.h>
#include <linux/bpf.h>
#include <linux/btf.h>
#include <linux/filter.h>
#include <net/tcp.h>
#include <net/bpf_sk_storage.h>
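/* tcp_congestion_ops members that a bpf congestion control may leave unset */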
static u32 optional_ops[] = {
	offsetof(struct tcp_congestion_ops, init),
	offsetof(struct tcp_congestion_ops, release),
	offsetof(struct tcp_congestion_ops, set_state),
	offsetof(struct tcp_congestion_ops, cwnd_event),
	offsetof(struct tcp_congestion_ops, in_ack_event),
	offsetof(struct tcp_congestion_ops, pkts_acked),
	offsetof(struct tcp_congestion_ops, min_tso_segs),
	offsetof(struct tcp_congestion_ops, sndbuf_expand),
	offsetof(struct tcp_congestion_ops, cong_control),
};
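/* tcp_congestion_ops members that cannot be implemented by a bpf congestion control */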
static u32 unsupported_ops[] = {
	offsetof(struct tcp_congestion_ops, get_info),
};

static const struct btf_type *tcp_sock_type;
static u32 tcp_sock_id, sock_id;
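/* Look up and cache the vmlinux BTF ids of struct sock and struct tcp_sock */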
static int bpf_tcp_ca_init(struct btf *btf)
{
	s32 type_id;

	type_id = btf_find_by_name_kind(btf, "sock", BTF_KIND_STRUCT);
	if (type_id < 0)
		return -EINVAL;
	sock_id = type_id;

	type_id = btf_find_by_name_kind(btf, "tcp_sock", BTF_KIND_STRUCT);
	if (type_id < 0)
		return -EINVAL;
	tcp_sock_id = type_id;
	tcp_sock_type = btf_type_by_id(btf, tcp_sock_id);

	return 0;
}

static bool is_optional(u32 member_offset)
{
	unsigned int i;

	for (i = 0; i < ARRAY_SIZE(optional_ops); i++) {
		if (member_offset == optional_ops[i])
			return true;
	}

	return false;
}

static bool is_unsupported(u32 member_offset)
{
	unsigned int i;

	for (i = 0; i < ARRAY_SIZE(unsupported_ops); i++) {
		if (member_offset == unsupported_ops[i])
			return true;
	}

	return false;
}

extern struct btf *btf_vmlinux;
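/* A bpf_tcp_ca prog's ctx is the argument array of the struct_ops member it
 * implements; a PTR_TO_BTF_ID arg of type "sock" is promoted to "tcp_sock"
 * so the prog can read tcp_sock fields directly.
 */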
static bool bpf_tcp_ca_is_valid_access(int off, int size,
				       enum bpf_access_type type,
				       const struct bpf_prog *prog,
				       struct bpf_insn_access_aux *info)
{
	if (off < 0 || off >= sizeof(__u64) * MAX_BPF_FUNC_ARGS)
		return false;
	if (type != BPF_READ)
		return false;
	if (off % size != 0)
		return false;

	if (!btf_ctx_access(off, size, type, prog, info))
		return false;

	if (info->reg_type == PTR_TO_BTF_ID && info->btf_id == sock_id)
		/* promote it to tcp_sock */
		info->btf_id = tcp_sock_id;

	return true;
}
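/* Reads through a tcp_sock pointer go through the generic btf_struct_access().
 * Writes are only allowed to the few inet_connection_sock / tcp_sock fields
 * whitelisted below and must stay within the bounds of the written member.
 */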
static int bpf_tcp_ca_btf_struct_access(struct bpf_verifier_log *log,
					const struct btf_type *t, int off,
					int size, enum bpf_access_type atype,
					u32 *next_btf_id)
{
	size_t end;

	if (atype == BPF_READ)
		return btf_struct_access(log, t, off, size, atype, next_btf_id);

	if (t != tcp_sock_type) {
		bpf_log(log, "only read is supported\n");
		return -EACCES;
	}

	switch (off) {
	case bpf_ctx_range(struct inet_connection_sock, icsk_ca_priv):
		end = offsetofend(struct inet_connection_sock, icsk_ca_priv);
		break;
	case offsetof(struct inet_connection_sock, icsk_ack.pending):
		end = offsetofend(struct inet_connection_sock,
				  icsk_ack.pending);
		break;
	case offsetof(struct tcp_sock, snd_cwnd):
		end = offsetofend(struct tcp_sock, snd_cwnd);
		break;
	case offsetof(struct tcp_sock, snd_cwnd_cnt):
		end = offsetofend(struct tcp_sock, snd_cwnd_cnt);
		break;
	case offsetof(struct tcp_sock, snd_ssthresh):
		end = offsetofend(struct tcp_sock, snd_ssthresh);
		break;
	case offsetof(struct tcp_sock, ecn_flags):
		end = offsetofend(struct tcp_sock, ecn_flags);
		break;
	default:
		bpf_log(log, "no write support to tcp_sock at off %d\n", off);
		return -EACCES;
	}

	if (off + size > end) {
		bpf_log(log,
			"write access at off %d with size %d beyond the member of tcp_sock ended at %zu\n",
			off, size, end);
		return -EACCES;
	}

	return NOT_INIT;
}
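/* Helper exposed to bpf_tcp_ca programs: send an ACK on the given socket,
 * advertising the provided rcv_nxt.
 */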
BPF_CALL_2(bpf_tcp_send_ack, struct tcp_sock *, tp, u32, rcv_nxt)
{
	/* bpf_tcp_ca prog cannot have NULL tp */
	__tcp_send_ack((struct sock *)tp, rcv_nxt);
	return 0;
}

static const struct bpf_func_proto bpf_tcp_send_ack_proto = {
	.func = bpf_tcp_send_ack,
	.gpl_only = false,
	/* In case we want to report error later */
	.ret_type = RET_INTEGER,
	.arg1_type = ARG_PTR_TO_BTF_ID,
	.arg1_btf_id = &tcp_sock_id,
	.arg2_type = ARG_ANYTHING,
};

static const struct bpf_func_proto *
bpf_tcp_ca_get_func_proto(enum bpf_func_id func_id,
			  const struct bpf_prog *prog)
{
	switch (func_id) {
	case BPF_FUNC_tcp_send_ack:
		return &bpf_tcp_send_ack_proto;
	case BPF_FUNC_sk_storage_get:
		return &bpf_sk_storage_get_proto;
	case BPF_FUNC_sk_storage_delete:
		return &bpf_sk_storage_delete_proto;
	default:
		return bpf_base_func_proto(func_id);
	}
}

static const struct bpf_verifier_ops bpf_tcp_ca_verifier_ops = {
	.get_func_proto = bpf_tcp_ca_get_func_proto,
	.is_valid_access = bpf_tcp_ca_is_valid_access,
	.btf_struct_access = bpf_tcp_ca_btf_struct_access,
};
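/* Copy the non func-ptr members (flags, name) from the user-supplied
 * tcp_congestion_ops and verify that a bpf prog fd was provided for every
 * mandatory func ptr.
 */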
static int bpf_tcp_ca_init_member(const struct btf_type *t,
				  const struct btf_member *member,
				  void *kdata, const void *udata)
{
	const struct tcp_congestion_ops *utcp_ca;
	struct tcp_congestion_ops *tcp_ca;
	int prog_fd;
	u32 moff;

	utcp_ca = (const struct tcp_congestion_ops *)udata;
	tcp_ca = (struct tcp_congestion_ops *)kdata;

	moff = btf_member_bit_offset(t, member) / 8;
	switch (moff) {
	case offsetof(struct tcp_congestion_ops, flags):
		if (utcp_ca->flags & ~TCP_CONG_MASK)
			return -EINVAL;
		tcp_ca->flags = utcp_ca->flags;
		return 1;
	case offsetof(struct tcp_congestion_ops, name):
		if (bpf_obj_name_cpy(tcp_ca->name, utcp_ca->name,
				     sizeof(tcp_ca->name)) <= 0)
			return -EINVAL;
		if (tcp_ca_find(utcp_ca->name))
			return -EEXIST;
		return 1;
	}

	if (!btf_type_resolve_func_ptr(btf_vmlinux, member->type, NULL))
		return 0;

	/* Ensure bpf_prog is provided for compulsory func ptr */
	prog_fd = (int)(*(unsigned long *)(udata + moff));
	if (!prog_fd && !is_optional(moff) && !is_unsupported(moff))
		return -EINVAL;

	return 0;
}
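/* Reject members that a bpf congestion control cannot implement */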
static int bpf_tcp_ca_check_member(const struct btf_type *t,
				   const struct btf_member *member)
{
	if (is_unsupported(btf_member_bit_offset(t, member) / 8))
		return -ENOTSUPP;
	return 0;
}

static int bpf_tcp_ca_reg(void *kdata)
{
	return tcp_register_congestion_control(kdata);
}

static void bpf_tcp_ca_unreg(void *kdata)
{
	tcp_unregister_congestion_control(kdata);
}

/* Avoid sparse warning.  It is only used in bpf_struct_ops.c. */
extern struct bpf_struct_ops bpf_tcp_congestion_ops;

struct bpf_struct_ops bpf_tcp_congestion_ops = {
	.verifier_ops = &bpf_tcp_ca_verifier_ops,
	.reg = bpf_tcp_ca_reg,
	.unreg = bpf_tcp_ca_unreg,
	.check_member = bpf_tcp_ca_check_member,
	.init_member = bpf_tcp_ca_init_member,
	.init = bpf_tcp_ca_init,
	.name = "tcp_congestion_ops",
};