xref: /OK3568_Linux_fs/kernel/net/core/lwt_bpf.c (revision 4882a59341e53eb6f0b4789bf948001014eff981)
1*4882a593Smuzhiyun // SPDX-License-Identifier: GPL-2.0-only
2*4882a593Smuzhiyun /* Copyright (c) 2016 Thomas Graf <tgraf@tgraf.ch>
3*4882a593Smuzhiyun  */
4*4882a593Smuzhiyun 
5*4882a593Smuzhiyun #include <linux/kernel.h>
6*4882a593Smuzhiyun #include <linux/module.h>
7*4882a593Smuzhiyun #include <linux/skbuff.h>
8*4882a593Smuzhiyun #include <linux/types.h>
9*4882a593Smuzhiyun #include <linux/bpf.h>
10*4882a593Smuzhiyun #include <net/lwtunnel.h>
11*4882a593Smuzhiyun #include <net/gre.h>
12*4882a593Smuzhiyun #include <net/ip6_route.h>
13*4882a593Smuzhiyun #include <net/ipv6_stubs.h>
14*4882a593Smuzhiyun 
/* One BPF program attached to a single LWT hook, plus the user-supplied
 * name it was configured with (used for logging and state comparison).
 */
struct bpf_lwt_prog {
	struct bpf_prog *prog;
	char *name;
};

/* Per-route BPF LWT state: one optional program per hook point. */
struct bpf_lwt {
	struct bpf_lwt_prog in;   /* input path  (BPF_PROG_TYPE_LWT_IN)   */
	struct bpf_lwt_prog out;  /* output path (BPF_PROG_TYPE_LWT_OUT)  */
	struct bpf_lwt_prog xmit; /* xmit path   (BPF_PROG_TYPE_LWT_XMIT) */
	int family;               /* AF_INET or AF_INET6 */
};

/* Upper bound on the length of a configured program name attribute. */
#define MAX_PROG_NAME 256
28*4882a593Smuzhiyun 
/* Extract the BPF LWT state embedded in the generic lwtunnel state. */
static inline struct bpf_lwt *bpf_lwt_lwtunnel(struct lwtunnel_state *lwt)
{
	return (struct bpf_lwt *)lwt->data;
}

/* Readable aliases for run_lwt_bpf()'s can_redirect argument. */
#define NO_REDIRECT false
#define CAN_REDIRECT true
36*4882a593Smuzhiyun 
/* Run one LWT BPF program on @skb and translate its verdict.
 *
 * Returns a BPF_* verdict (BPF_OK, BPF_REDIRECT, BPF_LWT_REROUTE) or a
 * negative errno.  On BPF_DROP and on unknown verdicts the skb is freed
 * here.  @can_redirect says whether BPF_REDIRECT is legal at this hook
 * (only the xmit hook passes CAN_REDIRECT); an illegal redirect verdict
 * is downgraded to BPF_OK with a one-time warning.
 */
static int run_lwt_bpf(struct sk_buff *skb, struct bpf_lwt_prog *lwt,
		       struct dst_entry *dst, bool can_redirect)
{
	int ret;

	/* Migration disable and BH disable are needed to protect per-cpu
	 * redirect_info between BPF prog and skb_do_redirect().
	 */
	migrate_disable();
	local_bh_disable();
	bpf_compute_data_pointers(skb);
	ret = bpf_prog_run_save_cb(lwt->prog, skb);

	switch (ret) {
	case BPF_OK:
	case BPF_LWT_REROUTE:
		break;

	case BPF_REDIRECT:
		if (unlikely(!can_redirect)) {
			pr_warn_once("Illegal redirect return code in prog %s\n",
				     lwt->name ? : "<unknown>");
			ret = BPF_OK;
		} else {
			/* skb_do_redirect() consumes the skb on success */
			skb_reset_mac_header(skb);
			ret = skb_do_redirect(skb);
			if (ret == 0)
				ret = BPF_REDIRECT;
		}
		break;

	case BPF_DROP:
		kfree_skb(skb);
		ret = -EPERM;
		break;

	default:
		pr_warn_once("bpf-lwt: Illegal return value %u, expect packet loss\n", ret);
		kfree_skb(skb);
		ret = -EINVAL;
		break;
	}

	local_bh_enable();
	migrate_enable();

	return ret;
}
85*4882a593Smuzhiyun 
/* Redo the input route lookup after a BPF_LWT_REROUTE verdict from the
 * input program (typically because the program pushed a new IP header).
 * Consumes the skb on error.
 */
static int bpf_lwt_input_reroute(struct sk_buff *skb)
{
	int err = -EINVAL;

	if (skb->protocol == htons(ETH_P_IP)) {
		struct net_device *dev = skb_dst(skb)->dev;
		struct iphdr *iph = ip_hdr(skb);

		/* The dst we are about to drop may hold the only reference
		 * keeping @dev alive; pin it across the route lookup.
		 */
		dev_hold(dev);
		skb_dst_drop(skb);
		err = ip_route_input_noref(skb, iph->daddr, iph->saddr,
					   iph->tos, dev);
		dev_put(dev);
	} else if (skb->protocol == htons(ETH_P_IPV6)) {
		skb_dst_drop(skb);
		err = ipv6_stub->ipv6_route_input(skb);
	} else {
		err = -EAFNOSUPPORT;
	}

	if (err)
		goto err;
	/* Deliver via the freshly looked-up dst. */
	return dst_input(skb);

err:
	kfree_skb(skb);
	return err;
}
114*4882a593Smuzhiyun 
bpf_input(struct sk_buff * skb)115*4882a593Smuzhiyun static int bpf_input(struct sk_buff *skb)
116*4882a593Smuzhiyun {
117*4882a593Smuzhiyun 	struct dst_entry *dst = skb_dst(skb);
118*4882a593Smuzhiyun 	struct bpf_lwt *bpf;
119*4882a593Smuzhiyun 	int ret;
120*4882a593Smuzhiyun 
121*4882a593Smuzhiyun 	bpf = bpf_lwt_lwtunnel(dst->lwtstate);
122*4882a593Smuzhiyun 	if (bpf->in.prog) {
123*4882a593Smuzhiyun 		ret = run_lwt_bpf(skb, &bpf->in, dst, NO_REDIRECT);
124*4882a593Smuzhiyun 		if (ret < 0)
125*4882a593Smuzhiyun 			return ret;
126*4882a593Smuzhiyun 		if (ret == BPF_LWT_REROUTE)
127*4882a593Smuzhiyun 			return bpf_lwt_input_reroute(skb);
128*4882a593Smuzhiyun 	}
129*4882a593Smuzhiyun 
130*4882a593Smuzhiyun 	if (unlikely(!dst->lwtstate->orig_input)) {
131*4882a593Smuzhiyun 		kfree_skb(skb);
132*4882a593Smuzhiyun 		return -EINVAL;
133*4882a593Smuzhiyun 	}
134*4882a593Smuzhiyun 
135*4882a593Smuzhiyun 	return dst->lwtstate->orig_input(skb);
136*4882a593Smuzhiyun }
137*4882a593Smuzhiyun 
bpf_output(struct net * net,struct sock * sk,struct sk_buff * skb)138*4882a593Smuzhiyun static int bpf_output(struct net *net, struct sock *sk, struct sk_buff *skb)
139*4882a593Smuzhiyun {
140*4882a593Smuzhiyun 	struct dst_entry *dst = skb_dst(skb);
141*4882a593Smuzhiyun 	struct bpf_lwt *bpf;
142*4882a593Smuzhiyun 	int ret;
143*4882a593Smuzhiyun 
144*4882a593Smuzhiyun 	bpf = bpf_lwt_lwtunnel(dst->lwtstate);
145*4882a593Smuzhiyun 	if (bpf->out.prog) {
146*4882a593Smuzhiyun 		ret = run_lwt_bpf(skb, &bpf->out, dst, NO_REDIRECT);
147*4882a593Smuzhiyun 		if (ret < 0)
148*4882a593Smuzhiyun 			return ret;
149*4882a593Smuzhiyun 	}
150*4882a593Smuzhiyun 
151*4882a593Smuzhiyun 	if (unlikely(!dst->lwtstate->orig_output)) {
152*4882a593Smuzhiyun 		pr_warn_once("orig_output not set on dst for prog %s\n",
153*4882a593Smuzhiyun 			     bpf->out.name);
154*4882a593Smuzhiyun 		kfree_skb(skb);
155*4882a593Smuzhiyun 		return -EINVAL;
156*4882a593Smuzhiyun 	}
157*4882a593Smuzhiyun 
158*4882a593Smuzhiyun 	return dst->lwtstate->orig_output(net, sk, skb);
159*4882a593Smuzhiyun }
160*4882a593Smuzhiyun 
xmit_check_hhlen(struct sk_buff * skb,int hh_len)161*4882a593Smuzhiyun static int xmit_check_hhlen(struct sk_buff *skb, int hh_len)
162*4882a593Smuzhiyun {
163*4882a593Smuzhiyun 	if (skb_headroom(skb) < hh_len) {
164*4882a593Smuzhiyun 		int nhead = HH_DATA_ALIGN(hh_len - skb_headroom(skb));
165*4882a593Smuzhiyun 
166*4882a593Smuzhiyun 		if (pskb_expand_head(skb, nhead, 0, GFP_ATOMIC))
167*4882a593Smuzhiyun 			return -ENOMEM;
168*4882a593Smuzhiyun 	}
169*4882a593Smuzhiyun 
170*4882a593Smuzhiyun 	return 0;
171*4882a593Smuzhiyun }
172*4882a593Smuzhiyun 
bpf_lwt_xmit_reroute(struct sk_buff * skb)173*4882a593Smuzhiyun static int bpf_lwt_xmit_reroute(struct sk_buff *skb)
174*4882a593Smuzhiyun {
175*4882a593Smuzhiyun 	struct net_device *l3mdev = l3mdev_master_dev_rcu(skb_dst(skb)->dev);
176*4882a593Smuzhiyun 	int oif = l3mdev ? l3mdev->ifindex : 0;
177*4882a593Smuzhiyun 	struct dst_entry *dst = NULL;
178*4882a593Smuzhiyun 	int err = -EAFNOSUPPORT;
179*4882a593Smuzhiyun 	struct sock *sk;
180*4882a593Smuzhiyun 	struct net *net;
181*4882a593Smuzhiyun 	bool ipv4;
182*4882a593Smuzhiyun 
183*4882a593Smuzhiyun 	if (skb->protocol == htons(ETH_P_IP))
184*4882a593Smuzhiyun 		ipv4 = true;
185*4882a593Smuzhiyun 	else if (skb->protocol == htons(ETH_P_IPV6))
186*4882a593Smuzhiyun 		ipv4 = false;
187*4882a593Smuzhiyun 	else
188*4882a593Smuzhiyun 		goto err;
189*4882a593Smuzhiyun 
190*4882a593Smuzhiyun 	sk = sk_to_full_sk(skb->sk);
191*4882a593Smuzhiyun 	if (sk) {
192*4882a593Smuzhiyun 		if (sk->sk_bound_dev_if)
193*4882a593Smuzhiyun 			oif = sk->sk_bound_dev_if;
194*4882a593Smuzhiyun 		net = sock_net(sk);
195*4882a593Smuzhiyun 	} else {
196*4882a593Smuzhiyun 		net = dev_net(skb_dst(skb)->dev);
197*4882a593Smuzhiyun 	}
198*4882a593Smuzhiyun 
199*4882a593Smuzhiyun 	if (ipv4) {
200*4882a593Smuzhiyun 		struct iphdr *iph = ip_hdr(skb);
201*4882a593Smuzhiyun 		struct flowi4 fl4 = {};
202*4882a593Smuzhiyun 		struct rtable *rt;
203*4882a593Smuzhiyun 
204*4882a593Smuzhiyun 		fl4.flowi4_oif = oif;
205*4882a593Smuzhiyun 		fl4.flowi4_mark = skb->mark;
206*4882a593Smuzhiyun 		fl4.flowi4_uid = sock_net_uid(net, sk);
207*4882a593Smuzhiyun 		fl4.flowi4_tos = RT_TOS(iph->tos);
208*4882a593Smuzhiyun 		fl4.flowi4_flags = FLOWI_FLAG_ANYSRC;
209*4882a593Smuzhiyun 		fl4.flowi4_proto = iph->protocol;
210*4882a593Smuzhiyun 		fl4.daddr = iph->daddr;
211*4882a593Smuzhiyun 		fl4.saddr = iph->saddr;
212*4882a593Smuzhiyun 
213*4882a593Smuzhiyun 		rt = ip_route_output_key(net, &fl4);
214*4882a593Smuzhiyun 		if (IS_ERR(rt)) {
215*4882a593Smuzhiyun 			err = PTR_ERR(rt);
216*4882a593Smuzhiyun 			goto err;
217*4882a593Smuzhiyun 		}
218*4882a593Smuzhiyun 		dst = &rt->dst;
219*4882a593Smuzhiyun 	} else {
220*4882a593Smuzhiyun 		struct ipv6hdr *iph6 = ipv6_hdr(skb);
221*4882a593Smuzhiyun 		struct flowi6 fl6 = {};
222*4882a593Smuzhiyun 
223*4882a593Smuzhiyun 		fl6.flowi6_oif = oif;
224*4882a593Smuzhiyun 		fl6.flowi6_mark = skb->mark;
225*4882a593Smuzhiyun 		fl6.flowi6_uid = sock_net_uid(net, sk);
226*4882a593Smuzhiyun 		fl6.flowlabel = ip6_flowinfo(iph6);
227*4882a593Smuzhiyun 		fl6.flowi6_proto = iph6->nexthdr;
228*4882a593Smuzhiyun 		fl6.daddr = iph6->daddr;
229*4882a593Smuzhiyun 		fl6.saddr = iph6->saddr;
230*4882a593Smuzhiyun 
231*4882a593Smuzhiyun 		dst = ipv6_stub->ipv6_dst_lookup_flow(net, skb->sk, &fl6, NULL);
232*4882a593Smuzhiyun 		if (IS_ERR(dst)) {
233*4882a593Smuzhiyun 			err = PTR_ERR(dst);
234*4882a593Smuzhiyun 			goto err;
235*4882a593Smuzhiyun 		}
236*4882a593Smuzhiyun 	}
237*4882a593Smuzhiyun 	if (unlikely(dst->error)) {
238*4882a593Smuzhiyun 		err = dst->error;
239*4882a593Smuzhiyun 		dst_release(dst);
240*4882a593Smuzhiyun 		goto err;
241*4882a593Smuzhiyun 	}
242*4882a593Smuzhiyun 
243*4882a593Smuzhiyun 	/* Although skb header was reserved in bpf_lwt_push_ip_encap(), it
244*4882a593Smuzhiyun 	 * was done for the previous dst, so we are doing it here again, in
245*4882a593Smuzhiyun 	 * case the new dst needs much more space. The call below is a noop
246*4882a593Smuzhiyun 	 * if there is enough header space in skb.
247*4882a593Smuzhiyun 	 */
248*4882a593Smuzhiyun 	err = skb_cow_head(skb, LL_RESERVED_SPACE(dst->dev));
249*4882a593Smuzhiyun 	if (unlikely(err))
250*4882a593Smuzhiyun 		goto err;
251*4882a593Smuzhiyun 
252*4882a593Smuzhiyun 	skb_dst_drop(skb);
253*4882a593Smuzhiyun 	skb_dst_set(skb, dst);
254*4882a593Smuzhiyun 
255*4882a593Smuzhiyun 	err = dst_output(dev_net(skb_dst(skb)->dev), skb->sk, skb);
256*4882a593Smuzhiyun 	if (unlikely(err))
257*4882a593Smuzhiyun 		return err;
258*4882a593Smuzhiyun 
259*4882a593Smuzhiyun 	/* ip[6]_finish_output2 understand LWTUNNEL_XMIT_DONE */
260*4882a593Smuzhiyun 	return LWTUNNEL_XMIT_DONE;
261*4882a593Smuzhiyun 
262*4882a593Smuzhiyun err:
263*4882a593Smuzhiyun 	kfree_skb(skb);
264*4882a593Smuzhiyun 	return err;
265*4882a593Smuzhiyun }
266*4882a593Smuzhiyun 
bpf_xmit(struct sk_buff * skb)267*4882a593Smuzhiyun static int bpf_xmit(struct sk_buff *skb)
268*4882a593Smuzhiyun {
269*4882a593Smuzhiyun 	struct dst_entry *dst = skb_dst(skb);
270*4882a593Smuzhiyun 	struct bpf_lwt *bpf;
271*4882a593Smuzhiyun 
272*4882a593Smuzhiyun 	bpf = bpf_lwt_lwtunnel(dst->lwtstate);
273*4882a593Smuzhiyun 	if (bpf->xmit.prog) {
274*4882a593Smuzhiyun 		int hh_len = dst->dev->hard_header_len;
275*4882a593Smuzhiyun 		__be16 proto = skb->protocol;
276*4882a593Smuzhiyun 		int ret;
277*4882a593Smuzhiyun 
278*4882a593Smuzhiyun 		ret = run_lwt_bpf(skb, &bpf->xmit, dst, CAN_REDIRECT);
279*4882a593Smuzhiyun 		switch (ret) {
280*4882a593Smuzhiyun 		case BPF_OK:
281*4882a593Smuzhiyun 			/* If the header changed, e.g. via bpf_lwt_push_encap,
282*4882a593Smuzhiyun 			 * BPF_LWT_REROUTE below should have been used if the
283*4882a593Smuzhiyun 			 * protocol was also changed.
284*4882a593Smuzhiyun 			 */
285*4882a593Smuzhiyun 			if (skb->protocol != proto) {
286*4882a593Smuzhiyun 				kfree_skb(skb);
287*4882a593Smuzhiyun 				return -EINVAL;
288*4882a593Smuzhiyun 			}
289*4882a593Smuzhiyun 			/* If the header was expanded, headroom might be too
290*4882a593Smuzhiyun 			 * small for L2 header to come, expand as needed.
291*4882a593Smuzhiyun 			 */
292*4882a593Smuzhiyun 			ret = xmit_check_hhlen(skb, hh_len);
293*4882a593Smuzhiyun 			if (unlikely(ret))
294*4882a593Smuzhiyun 				return ret;
295*4882a593Smuzhiyun 
296*4882a593Smuzhiyun 			return LWTUNNEL_XMIT_CONTINUE;
297*4882a593Smuzhiyun 		case BPF_REDIRECT:
298*4882a593Smuzhiyun 			return LWTUNNEL_XMIT_DONE;
299*4882a593Smuzhiyun 		case BPF_LWT_REROUTE:
300*4882a593Smuzhiyun 			return bpf_lwt_xmit_reroute(skb);
301*4882a593Smuzhiyun 		default:
302*4882a593Smuzhiyun 			return ret;
303*4882a593Smuzhiyun 		}
304*4882a593Smuzhiyun 	}
305*4882a593Smuzhiyun 
306*4882a593Smuzhiyun 	return LWTUNNEL_XMIT_CONTINUE;
307*4882a593Smuzhiyun }
308*4882a593Smuzhiyun 
bpf_lwt_prog_destroy(struct bpf_lwt_prog * prog)309*4882a593Smuzhiyun static void bpf_lwt_prog_destroy(struct bpf_lwt_prog *prog)
310*4882a593Smuzhiyun {
311*4882a593Smuzhiyun 	if (prog->prog)
312*4882a593Smuzhiyun 		bpf_prog_put(prog->prog);
313*4882a593Smuzhiyun 
314*4882a593Smuzhiyun 	kfree(prog->name);
315*4882a593Smuzhiyun }
316*4882a593Smuzhiyun 
bpf_destroy_state(struct lwtunnel_state * lwt)317*4882a593Smuzhiyun static void bpf_destroy_state(struct lwtunnel_state *lwt)
318*4882a593Smuzhiyun {
319*4882a593Smuzhiyun 	struct bpf_lwt *bpf = bpf_lwt_lwtunnel(lwt);
320*4882a593Smuzhiyun 
321*4882a593Smuzhiyun 	bpf_lwt_prog_destroy(&bpf->in);
322*4882a593Smuzhiyun 	bpf_lwt_prog_destroy(&bpf->out);
323*4882a593Smuzhiyun 	bpf_lwt_prog_destroy(&bpf->xmit);
324*4882a593Smuzhiyun }
325*4882a593Smuzhiyun 
/* Netlink policy for the per-program nested attributes (fd + name). */
static const struct nla_policy bpf_prog_policy[LWT_BPF_PROG_MAX + 1] = {
	[LWT_BPF_PROG_FD]   = { .type = NLA_U32, },
	[LWT_BPF_PROG_NAME] = { .type = NLA_NUL_STRING,
				.len = MAX_PROG_NAME },
};
331*4882a593Smuzhiyun 
/* Parse one LWT_BPF_{IN,OUT,XMIT} nested attribute into @prog.
 *
 * Both the program fd and the name are mandatory.  On failure, anything
 * already stored in @prog (e.g. a duplicated name when fetching the
 * program fails) is left for the caller to release via
 * bpf_lwt_prog_destroy() on its error path.
 */
static int bpf_parse_prog(struct nlattr *attr, struct bpf_lwt_prog *prog,
			  enum bpf_prog_type type)
{
	struct nlattr *tb[LWT_BPF_PROG_MAX + 1];
	struct bpf_prog *p;
	int ret;
	u32 fd;

	ret = nla_parse_nested_deprecated(tb, LWT_BPF_PROG_MAX, attr,
					  bpf_prog_policy, NULL);
	if (ret < 0)
		return ret;

	if (!tb[LWT_BPF_PROG_FD] || !tb[LWT_BPF_PROG_NAME])
		return -EINVAL;

	prog->name = nla_memdup(tb[LWT_BPF_PROG_NAME], GFP_ATOMIC);
	if (!prog->name)
		return -ENOMEM;

	/* Takes a reference on the program; dropped in bpf_lwt_prog_destroy(). */
	fd = nla_get_u32(tb[LWT_BPF_PROG_FD]);
	p = bpf_prog_get_type(fd, type);
	if (IS_ERR(p))
		return PTR_ERR(p);

	prog->prog = p;

	return 0;
}
361*4882a593Smuzhiyun 
/* Netlink policy for the top-level LWT_BPF_* route attributes. */
static const struct nla_policy bpf_nl_policy[LWT_BPF_MAX + 1] = {
	[LWT_BPF_IN]		= { .type = NLA_NESTED, },
	[LWT_BPF_OUT]		= { .type = NLA_NESTED, },
	[LWT_BPF_XMIT]		= { .type = NLA_NESTED, },
	[LWT_BPF_XMIT_HEADROOM]	= { .type = NLA_U32 },
};
368*4882a593Smuzhiyun 
/* Build the lwtunnel state for a BPF encap route from netlink attributes.
 *
 * At least one of LWT_BPF_{IN,OUT,XMIT} must be present.  On success the
 * new state is handed out via @ts; on failure any partially attached
 * programs are released and the state is freed.
 */
static int bpf_build_state(struct net *net, struct nlattr *nla,
			   unsigned int family, const void *cfg,
			   struct lwtunnel_state **ts,
			   struct netlink_ext_ack *extack)
{
	struct nlattr *tb[LWT_BPF_MAX + 1];
	struct lwtunnel_state *newts;
	struct bpf_lwt *bpf;
	int ret;

	if (family != AF_INET && family != AF_INET6)
		return -EAFNOSUPPORT;

	ret = nla_parse_nested_deprecated(tb, LWT_BPF_MAX, nla, bpf_nl_policy,
					  extack);
	if (ret < 0)
		return ret;

	if (!tb[LWT_BPF_IN] && !tb[LWT_BPF_OUT] && !tb[LWT_BPF_XMIT])
		return -EINVAL;

	newts = lwtunnel_state_alloc(sizeof(*bpf));
	if (!newts)
		return -ENOMEM;

	newts->type = LWTUNNEL_ENCAP_BPF;
	bpf = bpf_lwt_lwtunnel(newts);

	if (tb[LWT_BPF_IN]) {
		newts->flags |= LWTUNNEL_STATE_INPUT_REDIRECT;
		ret = bpf_parse_prog(tb[LWT_BPF_IN], &bpf->in,
				     BPF_PROG_TYPE_LWT_IN);
		if (ret  < 0)
			goto errout;
	}

	if (tb[LWT_BPF_OUT]) {
		newts->flags |= LWTUNNEL_STATE_OUTPUT_REDIRECT;
		ret = bpf_parse_prog(tb[LWT_BPF_OUT], &bpf->out,
				     BPF_PROG_TYPE_LWT_OUT);
		if (ret < 0)
			goto errout;
	}

	if (tb[LWT_BPF_XMIT]) {
		newts->flags |= LWTUNNEL_STATE_XMIT_REDIRECT;
		ret = bpf_parse_prog(tb[LWT_BPF_XMIT], &bpf->xmit,
				     BPF_PROG_TYPE_LWT_XMIT);
		if (ret < 0)
			goto errout;
	}

	/* Optional extra headroom for encap pushed by the xmit program. */
	if (tb[LWT_BPF_XMIT_HEADROOM]) {
		u32 headroom = nla_get_u32(tb[LWT_BPF_XMIT_HEADROOM]);

		if (headroom > LWT_BPF_MAX_HEADROOM) {
			ret = -ERANGE;
			goto errout;
		}

		newts->headroom = headroom;
	}

	bpf->family = family;
	*ts = newts;

	return 0;

errout:
	/* bpf_destroy_state() safely handles partially parsed programs. */
	bpf_destroy_state(newts);
	kfree(newts);
	return ret;
}
442*4882a593Smuzhiyun 
/* Dump one program section as a nested attribute (name only; the fd is
 * meaningless outside the configuring process).  Returns 0 if there is
 * no program to dump, -EMSGSIZE when out of message space.
 *
 * NOTE(review): the nest is not nla_nest_cancel()ed on -EMSGSIZE; this
 * relies on the caller discarding the whole message on error — confirm
 * against the lwtunnel fill_encap callers.
 */
static int bpf_fill_lwt_prog(struct sk_buff *skb, int attr,
			     struct bpf_lwt_prog *prog)
{
	struct nlattr *nest;

	if (!prog->prog)
		return 0;

	nest = nla_nest_start_noflag(skb, attr);
	if (!nest)
		return -EMSGSIZE;

	if (prog->name &&
	    nla_put_string(skb, LWT_BPF_PROG_NAME, prog->name))
		return -EMSGSIZE;

	return nla_nest_end(skb, nest);
}
461*4882a593Smuzhiyun 
bpf_fill_encap_info(struct sk_buff * skb,struct lwtunnel_state * lwt)462*4882a593Smuzhiyun static int bpf_fill_encap_info(struct sk_buff *skb, struct lwtunnel_state *lwt)
463*4882a593Smuzhiyun {
464*4882a593Smuzhiyun 	struct bpf_lwt *bpf = bpf_lwt_lwtunnel(lwt);
465*4882a593Smuzhiyun 
466*4882a593Smuzhiyun 	if (bpf_fill_lwt_prog(skb, LWT_BPF_IN, &bpf->in) < 0 ||
467*4882a593Smuzhiyun 	    bpf_fill_lwt_prog(skb, LWT_BPF_OUT, &bpf->out) < 0 ||
468*4882a593Smuzhiyun 	    bpf_fill_lwt_prog(skb, LWT_BPF_XMIT, &bpf->xmit) < 0)
469*4882a593Smuzhiyun 		return -EMSGSIZE;
470*4882a593Smuzhiyun 
471*4882a593Smuzhiyun 	return 0;
472*4882a593Smuzhiyun }
473*4882a593Smuzhiyun 
bpf_encap_nlsize(struct lwtunnel_state * lwtstate)474*4882a593Smuzhiyun static int bpf_encap_nlsize(struct lwtunnel_state *lwtstate)
475*4882a593Smuzhiyun {
476*4882a593Smuzhiyun 	int nest_len = nla_total_size(sizeof(struct nlattr)) +
477*4882a593Smuzhiyun 		       nla_total_size(MAX_PROG_NAME) + /* LWT_BPF_PROG_NAME */
478*4882a593Smuzhiyun 		       0;
479*4882a593Smuzhiyun 
480*4882a593Smuzhiyun 	return nest_len + /* LWT_BPF_IN */
481*4882a593Smuzhiyun 	       nest_len + /* LWT_BPF_OUT */
482*4882a593Smuzhiyun 	       nest_len + /* LWT_BPF_XMIT */
483*4882a593Smuzhiyun 	       0;
484*4882a593Smuzhiyun }
485*4882a593Smuzhiyun 
bpf_lwt_prog_cmp(struct bpf_lwt_prog * a,struct bpf_lwt_prog * b)486*4882a593Smuzhiyun static int bpf_lwt_prog_cmp(struct bpf_lwt_prog *a, struct bpf_lwt_prog *b)
487*4882a593Smuzhiyun {
488*4882a593Smuzhiyun 	/* FIXME:
489*4882a593Smuzhiyun 	 * The LWT state is currently rebuilt for delete requests which
490*4882a593Smuzhiyun 	 * results in a new bpf_prog instance. Comparing names for now.
491*4882a593Smuzhiyun 	 */
492*4882a593Smuzhiyun 	if (!a->name && !b->name)
493*4882a593Smuzhiyun 		return 0;
494*4882a593Smuzhiyun 
495*4882a593Smuzhiyun 	if (!a->name || !b->name)
496*4882a593Smuzhiyun 		return 1;
497*4882a593Smuzhiyun 
498*4882a593Smuzhiyun 	return strcmp(a->name, b->name);
499*4882a593Smuzhiyun }
500*4882a593Smuzhiyun 
bpf_encap_cmp(struct lwtunnel_state * a,struct lwtunnel_state * b)501*4882a593Smuzhiyun static int bpf_encap_cmp(struct lwtunnel_state *a, struct lwtunnel_state *b)
502*4882a593Smuzhiyun {
503*4882a593Smuzhiyun 	struct bpf_lwt *a_bpf = bpf_lwt_lwtunnel(a);
504*4882a593Smuzhiyun 	struct bpf_lwt *b_bpf = bpf_lwt_lwtunnel(b);
505*4882a593Smuzhiyun 
506*4882a593Smuzhiyun 	return bpf_lwt_prog_cmp(&a_bpf->in, &b_bpf->in) ||
507*4882a593Smuzhiyun 	       bpf_lwt_prog_cmp(&a_bpf->out, &b_bpf->out) ||
508*4882a593Smuzhiyun 	       bpf_lwt_prog_cmp(&a_bpf->xmit, &b_bpf->xmit);
509*4882a593Smuzhiyun }
510*4882a593Smuzhiyun 
/* lwtunnel encap operations registered for LWTUNNEL_ENCAP_BPF routes. */
static const struct lwtunnel_encap_ops bpf_encap_ops = {
	.build_state	= bpf_build_state,
	.destroy_state	= bpf_destroy_state,
	.input		= bpf_input,
	.output		= bpf_output,
	.xmit		= bpf_xmit,
	.fill_encap	= bpf_fill_encap_info,
	.get_encap_size = bpf_encap_nlsize,
	.cmp_encap	= bpf_encap_cmp,
	.owner		= THIS_MODULE,
};
522*4882a593Smuzhiyun 
handle_gso_type(struct sk_buff * skb,unsigned int gso_type,int encap_len)523*4882a593Smuzhiyun static int handle_gso_type(struct sk_buff *skb, unsigned int gso_type,
524*4882a593Smuzhiyun 			   int encap_len)
525*4882a593Smuzhiyun {
526*4882a593Smuzhiyun 	struct skb_shared_info *shinfo = skb_shinfo(skb);
527*4882a593Smuzhiyun 
528*4882a593Smuzhiyun 	gso_type |= SKB_GSO_DODGY;
529*4882a593Smuzhiyun 	shinfo->gso_type |= gso_type;
530*4882a593Smuzhiyun 	skb_decrease_gso_size(shinfo, encap_len);
531*4882a593Smuzhiyun 	shinfo->gso_segs = 0;
532*4882a593Smuzhiyun 	return 0;
533*4882a593Smuzhiyun }
534*4882a593Smuzhiyun 
/* Fix up GSO metadata after the BPF program pushed an encap header of
 * @encap_len bytes.  Inspects the protocol right after the new outer IP
 * header to pick the matching SKB_GSO_* tunnel type.
 */
static int handle_gso_encap(struct sk_buff *skb, bool ipv4, int encap_len)
{
	int next_hdr_offset;
	void *next_hdr;
	__u8 protocol;

	/* SCTP and UDP_L4 gso need more nuanced handling than what
	 * handle_gso_type() does above: skb_decrease_gso_size() is not enough.
	 * So at the moment only TCP GSO packets are let through.
	 */
	if (!(skb_shinfo(skb)->gso_type & (SKB_GSO_TCPV4 | SKB_GSO_TCPV6)))
		return -ENOTSUPP;

	/* Locate the header following the outer IP header just pushed. */
	if (ipv4) {
		protocol = ip_hdr(skb)->protocol;
		next_hdr_offset = sizeof(struct iphdr);
		next_hdr = skb_network_header(skb) + next_hdr_offset;
	} else {
		protocol = ipv6_hdr(skb)->nexthdr;
		next_hdr_offset = sizeof(struct ipv6hdr);
		next_hdr = skb_network_header(skb) + next_hdr_offset;
	}

	switch (protocol) {
	case IPPROTO_GRE:
		/* The full GRE base header must lie within the pushed encap. */
		next_hdr_offset += sizeof(struct gre_base_hdr);
		if (next_hdr_offset > encap_len)
			return -EINVAL;

		if (((struct gre_base_hdr *)next_hdr)->flags & GRE_CSUM)
			return handle_gso_type(skb, SKB_GSO_GRE_CSUM,
					       encap_len);
		return handle_gso_type(skb, SKB_GSO_GRE, encap_len);

	case IPPROTO_UDP:
		/* The full UDP header must lie within the pushed encap. */
		next_hdr_offset += sizeof(struct udphdr);
		if (next_hdr_offset > encap_len)
			return -EINVAL;

		if (((struct udphdr *)next_hdr)->check)
			return handle_gso_type(skb, SKB_GSO_UDP_TUNNEL_CSUM,
					       encap_len);
		return handle_gso_type(skb, SKB_GSO_UDP_TUNNEL, encap_len);

	case IPPROTO_IP:
	case IPPROTO_IPV6:
		/* Plain IP-in-IP encapsulation. */
		if (ipv4)
			return handle_gso_type(skb, SKB_GSO_IPXIP4, encap_len);
		else
			return handle_gso_type(skb, SKB_GSO_IPXIP6, encap_len);

	default:
		return -EPROTONOSUPPORT;
	}
}
590*4882a593Smuzhiyun 
/* Push an IP encapsulation header (@hdr, @len bytes) in front of the
 * packet on behalf of a BPF program (bpf_lwt_push_encap() helper).
 *
 * @ingress selects how much extra headroom to reserve (mac_len on the
 * input path, the dst device's link-layer space on output).  Updates
 * skb->protocol, the checksum state and, for GSO packets, the GSO
 * metadata.  Returns 0 or a negative errno; the skb is not freed here.
 */
int bpf_lwt_push_ip_encap(struct sk_buff *skb, void *hdr, u32 len, bool ingress)
{
	struct iphdr *iph;
	bool ipv4;
	int err;

	if (unlikely(len < sizeof(struct iphdr) || len > LWT_BPF_MAX_HEADROOM))
		return -EINVAL;

	/* validate protocol and length */
	iph = (struct iphdr *)hdr;
	if (iph->version == 4) {
		ipv4 = true;
		if (unlikely(len < iph->ihl * 4))
			return -EINVAL;
	} else if (iph->version == 6) {
		ipv4 = false;
		if (unlikely(len < sizeof(struct ipv6hdr)))
			return -EINVAL;
	} else {
		return -EINVAL;
	}

	if (ingress)
		err = skb_cow_head(skb, len + skb->mac_len);
	else
		err = skb_cow_head(skb,
				   len + LL_RESERVED_SPACE(skb_dst(skb)->dev));
	if (unlikely(err))
		return err;

	/* push the encap headers and fix pointers */
	skb_reset_inner_headers(skb);
	skb_reset_inner_mac_header(skb);  /* mac header is not yet set */
	skb_set_inner_protocol(skb, skb->protocol);
	skb->encapsulation = 1;
	skb_push(skb, len);
	if (ingress)
		skb_postpush_rcsum(skb, iph, len);
	skb_reset_network_header(skb);
	memcpy(skb_network_header(skb), hdr, len);
	bpf_compute_data_pointers(skb);
	skb_clear_hash(skb);	/* old flow hash no longer matches the headers */

	if (ipv4) {
		skb->protocol = htons(ETH_P_IP);
		iph = ip_hdr(skb);

		/* Fill in the header checksum if the program left it zero. */
		if (!iph->check)
			iph->check = ip_fast_csum((unsigned char *)iph,
						  iph->ihl);
	} else {
		skb->protocol = htons(ETH_P_IPV6);
	}

	if (skb_is_gso(skb))
		return handle_gso_encap(skb, ipv4, len);

	return 0;
}
651*4882a593Smuzhiyun 
/* Register the BPF lwtunnel encap ops at boot. */
static int __init bpf_lwt_init(void)
{
	return lwtunnel_encap_add_ops(&bpf_encap_ops, LWTUNNEL_ENCAP_BPF);
}

subsys_initcall(bpf_lwt_init)
658