// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/skmsg.h>
#include <linux/skbuff.h>
#include <linux/scatterlist.h>

#include <net/sock.h>
#include <net/tcp.h>
#include <net/tls.h>

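/* Check whether it is safe to coalesce new data into the element just
 * before msg->sg.end. Coalescing is allowed only while elem_first_coalesce
 * lies inside the occupied part of the ring, which may be either linear
 * (end > start) or wrapped around (end < start).
 */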
static bool sk_msg_try_coalesce_ok(struct sk_msg *msg, int elem_first_coalesce)
{
	if (msg->sg.end > msg->sg.start &&
	    elem_first_coalesce < msg->sg.end)
		return true;

	if (msg->sg.end < msg->sg.start &&
	    (elem_first_coalesce > msg->sg.start ||
	     elem_first_coalesce < msg->sg.end))
		return true;

	return false;
}

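/* Grow @msg's scatterlist until it covers @len bytes, charging the socket
 * for the new memory. Pages come from the socket's page_frag allocator;
 * where possible, new bytes are coalesced into the last element instead of
 * consuming a new slot. On failure the message is trimmed back to its
 * original size.
 *
 * Illustrative caller sketch (not from this file; names hypothetical,
 * error handling abbreviated). Passing the current sg.end as
 * @elem_first_coalesce restricts coalescing to elements added by this call:
 *
 *	osize = msg->sg.size;
 *	err = sk_msg_alloc(sk, msg, osize + copy, msg->sg.end);
 *	if (!err)
 *		err = sk_msg_memcopy_from_iter(sk, &m->msg_iter, msg, copy);
 */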
int sk_msg_alloc(struct sock *sk, struct sk_msg *msg, int len,
		 int elem_first_coalesce)
{
	struct page_frag *pfrag = sk_page_frag(sk);
	u32 osize = msg->sg.size;
	int ret = 0;

	len -= msg->sg.size;
	while (len > 0) {
		struct scatterlist *sge;
		u32 orig_offset;
		int use, i;

		if (!sk_page_frag_refill(sk, pfrag)) {
			ret = -ENOMEM;
			goto msg_trim;
		}

		orig_offset = pfrag->offset;
		use = min_t(int, len, pfrag->size - orig_offset);
		if (!sk_wmem_schedule(sk, use)) {
			ret = -ENOMEM;
			goto msg_trim;
		}

		i = msg->sg.end;
		sk_msg_iter_var_prev(i);
		sge = &msg->sg.data[i];

		if (sk_msg_try_coalesce_ok(msg, elem_first_coalesce) &&
		    sg_page(sge) == pfrag->page &&
		    sge->offset + sge->length == orig_offset) {
			sge->length += use;
		} else {
			if (sk_msg_full(msg)) {
				ret = -ENOSPC;
				break;
			}

			sge = &msg->sg.data[msg->sg.end];
			sg_unmark_end(sge);
			sg_set_page(sge, pfrag->page, use, orig_offset);
			get_page(pfrag->page);
			sk_msg_iter_next(msg, end);
		}

		sk_mem_charge(sk, use);
		msg->sg.size += use;
		pfrag->offset += use;
		len -= use;
	}

	return ret;

msg_trim:
	sk_msg_trim(sk, msg, osize);
	return ret;
}
EXPORT_SYMBOL_GPL(sk_msg_alloc);

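/* Clone @len bytes starting at offset @off of @src into @dst without
 * copying data: @dst's elements are pointed at the same pages, and the
 * socket is charged for the cloned bytes. Where possible a range is
 * merged into the previous @dst element. Returns -ENOSPC if @src runs
 * out of data or @dst runs out of slots.
 */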
int sk_msg_clone(struct sock *sk, struct sk_msg *dst, struct sk_msg *src,
		 u32 off, u32 len)
{
	int i = src->sg.start;
	struct scatterlist *sge = sk_msg_elem(src, i);
	struct scatterlist *sgd = NULL;
	u32 sge_len, sge_off;

	while (off) {
		if (sge->length > off)
			break;
		off -= sge->length;
		sk_msg_iter_var_next(i);
		if (i == src->sg.end && off)
			return -ENOSPC;
		sge = sk_msg_elem(src, i);
	}

	while (len) {
		sge_len = sge->length - off;
		if (sge_len > len)
			sge_len = len;

		if (dst->sg.end)
			sgd = sk_msg_elem(dst, dst->sg.end - 1);

		if (sgd &&
		    (sg_page(sge) == sg_page(sgd)) &&
		    (sg_virt(sge) + off == sg_virt(sgd) + sgd->length)) {
			sgd->length += sge_len;
			dst->sg.size += sge_len;
		} else if (!sk_msg_full(dst)) {
			sge_off = sge->offset + off;
			sk_msg_page_add(dst, sg_page(sge), sge_len, sge_off);
		} else {
			return -ENOSPC;
		}

		off = 0;
		len -= sge_len;
		sk_mem_charge(sk, sge_len);
		sk_msg_iter_var_next(i);
		if (i == src->sg.end && len)
			return -ENOSPC;
		sge = sk_msg_elem(src, i);
	}

	return 0;
}
EXPORT_SYMBOL_GPL(sk_msg_clone);

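/* Uncharge up to @bytes from the front of @msg, resetting the length and
 * offset of fully consumed elements to zero and advancing msg->sg.start
 * past them. A partially consumed element keeps its page but has its
 * offset/length adjusted.
 */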
void sk_msg_return_zero(struct sock *sk, struct sk_msg *msg, int bytes)
{
	int i = msg->sg.start;

	do {
		struct scatterlist *sge = sk_msg_elem(msg, i);

		if (bytes < sge->length) {
			sge->length -= bytes;
			sge->offset += bytes;
			sk_mem_uncharge(sk, bytes);
			break;
		}

		sk_mem_uncharge(sk, sge->length);
		bytes -= sge->length;
		sge->length = 0;
		sge->offset = 0;
		sk_msg_iter_var_next(i);
	} while (bytes && i != msg->sg.end);
	msg->sg.start = i;
}
EXPORT_SYMBOL_GPL(sk_msg_return_zero);

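/* Uncharge up to @bytes of socket memory accounted to @msg. Unlike
 * sk_msg_return_zero(), the scatterlist itself is left untouched.
 */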
void sk_msg_return(struct sock *sk, struct sk_msg *msg, int bytes)
{
	int i = msg->sg.start;

	do {
		struct scatterlist *sge = &msg->sg.data[i];
		int uncharge = (bytes < sge->length) ? bytes : sge->length;

		sk_mem_uncharge(sk, uncharge);
		bytes -= uncharge;
		sk_msg_iter_var_next(i);
	} while (i != msg->sg.end);
}
EXPORT_SYMBOL_GPL(sk_msg_return);

static int sk_msg_free_elem(struct sock *sk, struct sk_msg *msg, u32 i,
			    bool charge)
{
	struct scatterlist *sge = sk_msg_elem(msg, i);
	u32 len = sge->length;

	/* When the skb owns the memory we free it from the consume_skb path. */
	if (!msg->skb) {
		if (charge)
			sk_mem_uncharge(sk, len);
		put_page(sg_page(sge));
	}
	memset(sge, 0, sizeof(*sge));
	return len;
}

static int __sk_msg_free(struct sock *sk, struct sk_msg *msg, u32 i,
			 bool charge)
{
	struct scatterlist *sge = sk_msg_elem(msg, i);
	int freed = 0;

	while (msg->sg.size) {
		msg->sg.size -= sge->length;
		freed += sk_msg_free_elem(sk, msg, i, charge);
		sk_msg_iter_var_next(i);
		sk_msg_check_to_free(msg, i, msg->sg.size);
		sge = sk_msg_elem(msg, i);
	}
	consume_skb(msg->skb);
	sk_msg_init(msg);
	return freed;
}

int sk_msg_free_nocharge(struct sock *sk, struct sk_msg *msg)
{
	return __sk_msg_free(sk, msg, msg->sg.start, false);
}
EXPORT_SYMBOL_GPL(sk_msg_free_nocharge);

int sk_msg_free(struct sock *sk, struct sk_msg *msg)
{
	return __sk_msg_free(sk, msg, msg->sg.start, true);
}
EXPORT_SYMBOL_GPL(sk_msg_free);

static void __sk_msg_free_partial(struct sock *sk, struct sk_msg *msg,
				  u32 bytes, bool charge)
{
	struct scatterlist *sge;
	u32 i = msg->sg.start;

	while (bytes) {
		sge = sk_msg_elem(msg, i);
		if (!sge->length)
			break;
		if (bytes < sge->length) {
			if (charge)
				sk_mem_uncharge(sk, bytes);
			sge->length -= bytes;
			sge->offset += bytes;
			msg->sg.size -= bytes;
			break;
		}

		msg->sg.size -= sge->length;
		bytes -= sge->length;
		sk_msg_free_elem(sk, msg, i, charge);
		sk_msg_iter_var_next(i);
		sk_msg_check_to_free(msg, i, bytes);
	}
	msg->sg.start = i;
}

void sk_msg_free_partial(struct sock *sk, struct sk_msg *msg, u32 bytes)
{
	__sk_msg_free_partial(sk, msg, bytes, true);
}
EXPORT_SYMBOL_GPL(sk_msg_free_partial);

void sk_msg_free_partial_nocharge(struct sock *sk, struct sk_msg *msg,
				  u32 bytes)
{
	__sk_msg_free_partial(sk, msg, bytes, false);
}

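/* Shrink @msg so that it holds exactly @len bytes, freeing whole trailing
 * elements and shortening the last remaining one. msg->sg.curr and
 * msg->sg.copybreak are pulled back if the trim invalidated them.
 */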
void sk_msg_trim(struct sock *sk, struct sk_msg *msg, int len)
{
	int trim = msg->sg.size - len;
	u32 i = msg->sg.end;

	if (trim <= 0) {
		WARN_ON(trim < 0);
		return;
	}

	sk_msg_iter_var_prev(i);
	msg->sg.size = len;
	while (msg->sg.data[i].length &&
	       trim >= msg->sg.data[i].length) {
		trim -= msg->sg.data[i].length;
		sk_msg_free_elem(sk, msg, i, true);
		sk_msg_iter_var_prev(i);
		if (!trim)
			goto out;
	}

	msg->sg.data[i].length -= trim;
	sk_mem_uncharge(sk, trim);
	/* Adjust copybreak if it falls into the trimmed part of last buf */
	if (msg->sg.curr == i && msg->sg.copybreak > msg->sg.data[i].length)
		msg->sg.copybreak = msg->sg.data[i].length;
out:
	sk_msg_iter_var_next(i);
	msg->sg.end = i;

	/* If we trim data a full sg elem before the curr pointer, update
	 * copybreak and curr so that any future copy operations start at
	 * the new copy location.
	 * However, trimmed data that has not yet been used in a copy op
	 * does not require an update.
	 */
	if (!msg->sg.size) {
		msg->sg.curr = msg->sg.start;
		msg->sg.copybreak = 0;
	} else if (sk_msg_iter_dist(msg->sg.start, msg->sg.curr) >=
		   sk_msg_iter_dist(msg->sg.start, msg->sg.end)) {
		sk_msg_iter_var_prev(i);
		msg->sg.curr = i;
		msg->sg.copybreak = msg->sg.data[i].length;
	}
}
EXPORT_SYMBOL_GPL(sk_msg_trim);

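/* Pin up to @bytes of user pages from @from directly into @msg's
 * scatterlist (zerocopy path for sendmsg data). On error, any advance of
 * the iterator is reverted; the caller is expected to trim @msg back
 * itself if needed.
 */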
int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
			      struct sk_msg *msg, u32 bytes)
{
	int i, maxpages, ret = 0, num_elems = sk_msg_elem_used(msg);
	const int to_max_pages = MAX_MSG_FRAGS;
	struct page *pages[MAX_MSG_FRAGS];
	ssize_t orig, copied, use, offset;

	orig = msg->sg.size;
	while (bytes > 0) {
		i = 0;
		maxpages = to_max_pages - num_elems;
		if (maxpages == 0) {
			ret = -EFAULT;
			goto out;
		}

		copied = iov_iter_get_pages(from, pages, bytes, maxpages,
					    &offset);
		if (copied <= 0) {
			ret = -EFAULT;
			goto out;
		}

		iov_iter_advance(from, copied);
		bytes -= copied;
		msg->sg.size += copied;

		while (copied) {
			use = min_t(int, copied, PAGE_SIZE - offset);
			sg_set_page(&msg->sg.data[msg->sg.end],
				    pages[i], use, offset);
			sg_unmark_end(&msg->sg.data[msg->sg.end]);
			sk_mem_charge(sk, use);

			offset = 0;
			copied -= use;
			sk_msg_iter_next(msg, end);
			num_elems++;
			i++;
		}
		/* When zerocopy is mixed with sk_msg_*copy* operations we
		 * may have a copybreak set. In this case clear it and prefer
		 * the zerocopy remainder when possible.
		 */
		msg->sg.copybreak = 0;
		msg->sg.curr = msg->sg.end;
	}
out:
	/* Revert iov_iter updates; msg will need to use 'trim' later if it
	 * also needs to be cleared.
	 */
	if (ret)
		iov_iter_revert(from, msg->sg.size - orig);
	return ret;
}
EXPORT_SYMBOL_GPL(sk_msg_zerocopy_from_iter);

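/* Copy up to @bytes from @from into the buffers already allocated in
 * @msg, starting at msg->sg.curr/copybreak. The memory must have been
 * allocated beforehand, e.g. with sk_msg_alloc(). Returns -EFAULT on a
 * failed copy, or -ENOSPC when there is no room to copy into.
 */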
int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from,
			     struct sk_msg *msg, u32 bytes)
{
	int ret = -ENOSPC, i = msg->sg.curr;
	struct scatterlist *sge;
	u32 copy, buf_size;
	void *to;

	do {
		sge = sk_msg_elem(msg, i);
		/* This is possible if a trim operation shrunk the buffer */
		if (msg->sg.copybreak >= sge->length) {
			msg->sg.copybreak = 0;
			sk_msg_iter_var_next(i);
			if (i == msg->sg.end)
				break;
			sge = sk_msg_elem(msg, i);
		}

		buf_size = sge->length - msg->sg.copybreak;
		copy = (buf_size > bytes) ? bytes : buf_size;
		to = sg_virt(sge) + msg->sg.copybreak;
		msg->sg.copybreak += copy;
		if (sk->sk_route_caps & NETIF_F_NOCACHE_COPY)
			ret = copy_from_iter_nocache(to, copy, from);
		else
			ret = copy_from_iter(to, copy, from);
		if (ret != copy) {
			ret = -EFAULT;
			goto out;
		}
		bytes -= copy;
		if (!bytes)
			break;
		msg->sg.copybreak = 0;
		sk_msg_iter_var_next(i);
	} while (i != msg->sg.end);
out:
	msg->sg.curr = i;
	return ret;
}
EXPORT_SYMBOL_GPL(sk_msg_memcopy_from_iter);

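/* Allocate an empty sk_msg for queueing @skb on @sk's ingress queue.
 * Returns NULL if the receive buffer is already over its limit, the
 * memory cannot be scheduled, or the allocation fails, so the caller
 * can retry later.
 */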
static struct sk_msg *sk_psock_create_ingress_msg(struct sock *sk,
						  struct sk_buff *skb)
{
	struct sk_msg *msg;

	if (atomic_read(&sk->sk_rmem_alloc) > sk->sk_rcvbuf)
		return NULL;

	if (!sk_rmem_schedule(sk, skb, skb->truesize))
		return NULL;

	msg = kzalloc(sizeof(*msg), __GFP_NOWARN | GFP_ATOMIC);
	if (unlikely(!msg))
		return NULL;

	sk_msg_init(msg);
	return msg;
}

static int sk_psock_skb_ingress_enqueue(struct sk_buff *skb,
					struct sk_psock *psock,
					struct sock *sk,
					struct sk_msg *msg)
{
	int num_sge, copied;

	/* skb_linearize() may fail with ENOMEM, but let's simply try again
	 * later if this happens. Under memory pressure we don't want to
	 * drop the skb. We need to linearize the skb so that the mapping
	 * in skb_to_sgvec() cannot error.
	 */
	if (skb_linearize(skb))
		return -EAGAIN;
	num_sge = skb_to_sgvec(skb, msg->sg.data, 0, skb->len);
	if (unlikely(num_sge < 0))
		return num_sge;

	copied = skb->len;
	msg->sg.start = 0;
	msg->sg.size = copied;
	msg->sg.end = num_sge;
	msg->skb = skb;

	sk_psock_queue_msg(psock, msg);
	sk_psock_data_ready(sk, psock);
	return copied;
}

static int sk_psock_skb_ingress_self(struct sk_psock *psock, struct sk_buff *skb);

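/* Queue @skb as ingress data on psock->sk. Ownership of the skb memory
 * is transferred to the receiving socket unless the skb already belongs
 * to it, in which case sk_psock_skb_ingress_self() is used instead.
 * Returns -EAGAIN when no msg could be allocated, so the caller can
 * retry from the backlog workqueue.
 */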
static int sk_psock_skb_ingress(struct sk_psock *psock, struct sk_buff *skb)
{
	struct sock *sk = psock->sk;
	struct sk_msg *msg;
	int err;

	/* If we are receiving on the same sock skb->sk is already assigned,
	 * skip memory accounting and owner transition since it is already
	 * set correctly.
	 */
	if (unlikely(skb->sk == sk))
		return sk_psock_skb_ingress_self(psock, skb);
	msg = sk_psock_create_ingress_msg(sk, skb);
	if (!msg)
		return -EAGAIN;

	/* This will transition ownership of the data from the socket where
	 * the BPF program was run initiating the redirect to the socket
	 * we will eventually receive this data on. The data will be released
	 * from consume_skb() found in __tcp_bpf_recvmsg() after it has been
	 * copied into user buffers.
	 */
	skb_set_owner_r(skb, sk);
	err = sk_psock_skb_ingress_enqueue(skb, psock, sk, msg);
	if (err < 0)
		kfree(msg);
	return err;
}

/* Puts an skb on the ingress queue of the socket already assigned to the
 * skb. In this case we do not need to check memory limits or skb_set_owner_r()
 * because the skb is already accounted for here.
 */
static int sk_psock_skb_ingress_self(struct sk_psock *psock, struct sk_buff *skb)
{
	struct sk_msg *msg = kzalloc(sizeof(*msg), __GFP_NOWARN | GFP_ATOMIC);
	struct sock *sk = psock->sk;
	int err;

	if (unlikely(!msg))
		return -EAGAIN;
	sk_msg_init(msg);
	skb_set_owner_r(skb, sk);
	err = sk_psock_skb_ingress_enqueue(skb, psock, sk, msg);
	if (err < 0)
		kfree(msg);
	return err;
}

static int sk_psock_handle_skb(struct sk_psock *psock, struct sk_buff *skb,
			       u32 off, u32 len, bool ingress)
{
	if (!ingress) {
		if (!sock_writeable(psock->sk))
			return -EAGAIN;
		return skb_send_sock_locked(psock->sk, skb, off, len);
	}
	return sk_psock_skb_ingress(psock, skb);
}

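/* Workqueue handler that drains psock->ingress_skb, either queueing each
 * skb onto the local ingress msg queue or transmitting it for egress
 * redirects. Partial progress is saved in psock->work_state so an skb
 * that returned -EAGAIN can be resumed on the next run.
 */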
static void sk_psock_backlog(struct work_struct *work)
{
	struct sk_psock *psock = container_of(work, struct sk_psock, work);
	struct sk_psock_work_state *state = &psock->work_state;
	struct sk_buff *skb;
	bool ingress;
	u32 len, off;
	int ret;

	/* Lock sock to avoid losing sk_socket during loop. */
	lock_sock(psock->sk);
	if (state->skb) {
		skb = state->skb;
		len = state->len;
		off = state->off;
		state->skb = NULL;
		goto start;
	}

	while ((skb = skb_dequeue(&psock->ingress_skb))) {
		len = skb->len;
		off = 0;
start:
		ingress = tcp_skb_bpf_ingress(skb);
		do {
			ret = -EIO;
			if (likely(psock->sk->sk_socket))
				ret = sk_psock_handle_skb(psock, skb, off,
							  len, ingress);
			if (ret <= 0) {
				if (ret == -EAGAIN) {
					state->skb = skb;
					state->len = len;
					state->off = off;
					goto end;
				}
				/* Hard errors break pipe and stop xmit. */
				sk_psock_report_error(psock, ret ? -ret : EPIPE);
				sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED);
				kfree_skb(skb);
				goto end;
			}
			off += ret;
			len -= ret;
		} while (len);

		if (!ingress)
			kfree_skb(skb);
	}
end:
	release_sock(psock->sk);
}

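/* Allocate a psock for @sk and attach it via sk_user_data, saving the
 * socket's original proto callbacks so they can be restored later. Fails
 * with -EINVAL if a ULP is in use and -EBUSY if sk_user_data is already
 * taken.
 *
 * Typical usage (illustrative sketch, roughly as in the sock_map update
 * path):
 *
 *	struct sk_psock *psock = sk_psock_init(sk, NUMA_NO_NODE);
 *
 *	if (IS_ERR(psock))
 *		return PTR_ERR(psock);
 */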
struct sk_psock *sk_psock_init(struct sock *sk, int node)
{
	struct sk_psock *psock;
	struct proto *prot;

	write_lock_bh(&sk->sk_callback_lock);

	if (inet_csk_has_ulp(sk)) {
		psock = ERR_PTR(-EINVAL);
		goto out;
	}

	if (sk->sk_user_data) {
		psock = ERR_PTR(-EBUSY);
		goto out;
	}

	psock = kzalloc_node(sizeof(*psock), GFP_ATOMIC | __GFP_NOWARN, node);
	if (!psock) {
		psock = ERR_PTR(-ENOMEM);
		goto out;
	}

	prot = READ_ONCE(sk->sk_prot);
	psock->sk = sk;
	psock->eval = __SK_NONE;
	psock->sk_proto = prot;
	psock->saved_unhash = prot->unhash;
	psock->saved_close = prot->close;
	psock->saved_write_space = sk->sk_write_space;

	INIT_LIST_HEAD(&psock->link);
	spin_lock_init(&psock->link_lock);

	INIT_WORK(&psock->work, sk_psock_backlog);
	INIT_LIST_HEAD(&psock->ingress_msg);
	skb_queue_head_init(&psock->ingress_skb);

	sk_psock_set_state(psock, SK_PSOCK_TX_ENABLED);
	refcount_set(&psock->refcnt, 1);

	__rcu_assign_sk_user_data_with_flags(sk, psock,
					     SK_USER_DATA_NOCOPY |
					     SK_USER_DATA_PSOCK);
	sock_hold(sk);

out:
	write_unlock_bh(&sk->sk_callback_lock);
	return psock;
}
EXPORT_SYMBOL_GPL(sk_psock_init);

struct sk_psock_link *sk_psock_link_pop(struct sk_psock *psock)
{
	struct sk_psock_link *link;

	spin_lock_bh(&psock->link_lock);
	link = list_first_entry_or_null(&psock->link, struct sk_psock_link,
					list);
	if (link)
		list_del(&link->list);
	spin_unlock_bh(&psock->link_lock);
	return link;
}

void __sk_psock_purge_ingress_msg(struct sk_psock *psock)
{
	struct sk_msg *msg, *tmp;

	list_for_each_entry_safe(msg, tmp, &psock->ingress_msg, list) {
		list_del(&msg->list);
		sk_msg_free(psock->sk, msg);
		kfree(msg);
	}
}

static void sk_psock_zap_ingress(struct sk_psock *psock)
{
	__skb_queue_purge(&psock->ingress_skb);
	__sk_psock_purge_ingress_msg(psock);
}

static void sk_psock_link_destroy(struct sk_psock *psock)
{
	struct sk_psock_link *link, *tmp;

	list_for_each_entry_safe(link, tmp, &psock->link, list) {
		list_del(&link->list);
		sk_psock_free_link(link);
	}
}

static void sk_psock_destroy_deferred(struct work_struct *gc)
{
	struct sk_psock *psock = container_of(gc, struct sk_psock, gc);

	/* No sk_callback_lock since already detached. */

	/* Parser has been stopped */
	if (psock->progs.skb_parser)
		strp_done(&psock->parser.strp);

	cancel_work_sync(&psock->work);

	psock_progs_drop(&psock->progs);

	sk_psock_link_destroy(psock);
	sk_psock_cork_free(psock);
	sk_psock_zap_ingress(psock);

	if (psock->sk_redir)
		sock_put(psock->sk_redir);
	sock_put(psock->sk);
	kfree(psock);
}

static void sk_psock_destroy(struct rcu_head *rcu)
{
	struct sk_psock *psock = container_of(rcu, struct sk_psock, rcu);

	INIT_WORK(&psock->gc, sk_psock_destroy_deferred);
	schedule_work(&psock->gc);
}

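/* Detach the psock from @sk, restore the socket's original callbacks and
 * proto, and schedule the psock itself for destruction once an RCU grace
 * period has elapsed.
 */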
void sk_psock_drop(struct sock *sk, struct sk_psock *psock)
{
	sk_psock_cork_free(psock);
	sk_psock_zap_ingress(psock);

	write_lock_bh(&sk->sk_callback_lock);
	sk_psock_restore_proto(sk, psock);
	rcu_assign_sk_user_data(sk, NULL);
	if (psock->progs.skb_parser)
		sk_psock_stop_strp(sk, psock);
	else if (psock->progs.skb_verdict)
		sk_psock_stop_verdict(sk, psock);
	write_unlock_bh(&sk->sk_callback_lock);
	sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED);

	call_rcu(&psock->rcu, sk_psock_destroy);
}
EXPORT_SYMBOL_GPL(sk_psock_drop);

static int sk_psock_map_verd(int verdict, bool redir)
{
	switch (verdict) {
	case SK_PASS:
		return redir ? __SK_REDIRECT : __SK_PASS;
	case SK_DROP:
	default:
		break;
	}

	return __SK_DROP;
}

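/* Run the attached msg_parser BPF program on @msg and map its return
 * code to a __SK_* action. For __SK_REDIRECT, a reference to the target
 * socket is held in psock->sk_redir. Without a program, the default is
 * __SK_PASS.
 */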
int sk_psock_msg_verdict(struct sock *sk, struct sk_psock *psock,
			 struct sk_msg *msg)
{
	struct bpf_prog *prog;
	int ret;

	rcu_read_lock();
	prog = READ_ONCE(psock->progs.msg_parser);
	if (unlikely(!prog)) {
		ret = __SK_PASS;
		goto out;
	}

	sk_msg_compute_data_pointers(msg);
	msg->sk = sk;
	ret = bpf_prog_run_pin_on_cpu(prog, msg);
	ret = sk_psock_map_verd(ret, msg->sk_redir);
	psock->apply_bytes = msg->apply_bytes;
	if (ret == __SK_REDIRECT) {
		if (psock->sk_redir)
			sock_put(psock->sk_redir);
		psock->sk_redir = msg->sk_redir;
		if (!psock->sk_redir) {
			ret = __SK_DROP;
			goto out;
		}
		sock_hold(psock->sk_redir);
	}
out:
	rcu_read_unlock();
	return ret;
}
EXPORT_SYMBOL_GPL(sk_psock_msg_verdict);

static int sk_psock_bpf_run(struct sk_psock *psock, struct bpf_prog *prog,
			    struct sk_buff *skb)
{
	bpf_compute_data_end_sk_skb(skb);
	return bpf_prog_run_pin_on_cpu(prog, skb);
}

static struct sk_psock *sk_psock_from_strp(struct strparser *strp)
{
	struct sk_psock_parser *parser;

	parser = container_of(strp, struct sk_psock_parser, strp);
	return container_of(parser, struct sk_psock, parser);
}

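/* Deliver @skb to the socket selected by the BPF program's redirect,
 * dropping it if no target was set or if the target psock is gone or no
 * longer accepting traffic. Actual delivery happens from the target's
 * backlog workqueue.
 */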
static void sk_psock_skb_redirect(struct sk_buff *skb)
{
	struct sk_psock *psock_other;
	struct sock *sk_other;

	sk_other = tcp_skb_bpf_redirect_fetch(skb);
	/* This error indicates a buggy BPF program: it returned a redirect
	 * return code, but then didn't set a redirect target.
	 */
	if (unlikely(!sk_other)) {
		kfree_skb(skb);
		return;
	}
	psock_other = sk_psock(sk_other);
	/* This error indicates the socket is being torn down or had another
	 * error that caused the pipe to break. We can't send a packet on
	 * a socket that is in this state so we drop the skb.
	 */
	if (!psock_other || sock_flag(sk_other, SOCK_DEAD) ||
	    !sk_psock_test_state(psock_other, SK_PSOCK_TX_ENABLED)) {
		kfree_skb(skb);
		return;
	}

	skb_queue_tail(&psock_other->ingress_skb, skb);
	schedule_work(&psock_other->work);
}

static void sk_psock_tls_verdict_apply(struct sk_buff *skb, struct sock *sk, int verdict)
{
	switch (verdict) {
	case __SK_REDIRECT:
		sk_psock_skb_redirect(skb);
		break;
	case __SK_PASS:
	case __SK_DROP:
	default:
		break;
	}
}

int sk_psock_tls_strp_read(struct sk_psock *psock, struct sk_buff *skb)
{
	struct bpf_prog *prog;
	int ret = __SK_PASS;

	rcu_read_lock();
	prog = READ_ONCE(psock->progs.skb_verdict);
	if (likely(prog)) {
		skb->sk = psock->sk;
		tcp_skb_bpf_redirect_clear(skb);
		ret = sk_psock_bpf_run(psock, prog, skb);
		ret = sk_psock_map_verd(ret, tcp_skb_bpf_redirect_fetch(skb));
		skb->sk = NULL;
	}
	sk_psock_tls_verdict_apply(skb, psock->sk, ret);
	rcu_read_unlock();
	return ret;
}
EXPORT_SYMBOL_GPL(sk_psock_tls_strp_read);

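/* Apply a __SK_* verdict to @skb: queue it locally for __SK_PASS,
 * redirect it for __SK_REDIRECT, and free it otherwise.
 */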
static void sk_psock_verdict_apply(struct sk_psock *psock,
				   struct sk_buff *skb, int verdict)
{
	struct tcp_skb_cb *tcp;
	struct sock *sk_other;
	int err = -EIO;

	switch (verdict) {
	case __SK_PASS:
		sk_other = psock->sk;
		if (sock_flag(sk_other, SOCK_DEAD) ||
		    !sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED)) {
			goto out_free;
		}

		tcp = TCP_SKB_CB(skb);
		tcp->bpf.flags |= BPF_F_INGRESS;

		/* If the queue is empty then we can submit directly
		 * into the msg queue. If it's not empty we have to
		 * queue work, otherwise we may get OOO data. Either
		 * way, errors from sk_psock_skb_ingress_self() are
		 * handled by retrying later from the workqueue.
		 */
		if (skb_queue_empty(&psock->ingress_skb)) {
			err = sk_psock_skb_ingress_self(psock, skb);
		}
		if (err < 0) {
			skb_queue_tail(&psock->ingress_skb, skb);
			schedule_work(&psock->work);
		}
		break;
	case __SK_REDIRECT:
		sk_psock_skb_redirect(skb);
		break;
	case __SK_DROP:
	default:
out_free:
		kfree_skb(skb);
	}
}

static void sk_psock_strp_read(struct strparser *strp, struct sk_buff *skb)
{
	struct sk_psock *psock;
	struct bpf_prog *prog;
	int ret = __SK_DROP;
	struct sock *sk;

	rcu_read_lock();
	sk = strp->sk;
	psock = sk_psock(sk);
	if (unlikely(!psock)) {
		kfree_skb(skb);
		goto out;
	}
	prog = READ_ONCE(psock->progs.skb_verdict);
	if (likely(prog)) {
		skb->sk = sk;
		tcp_skb_bpf_redirect_clear(skb);
		ret = sk_psock_bpf_run(psock, prog, skb);
		ret = sk_psock_map_verd(ret, tcp_skb_bpf_redirect_fetch(skb));
		skb->sk = NULL;
	}
	sk_psock_verdict_apply(psock, skb, ret);
out:
	rcu_read_unlock();
}

static int sk_psock_strp_read_done(struct strparser *strp, int err)
{
	return err;
}

static int sk_psock_strp_parse(struct strparser *strp, struct sk_buff *skb)
{
	struct sk_psock *psock = sk_psock_from_strp(strp);
	struct bpf_prog *prog;
	int ret = skb->len;

	rcu_read_lock();
	prog = READ_ONCE(psock->progs.skb_parser);
	if (likely(prog)) {
		skb->sk = psock->sk;
		ret = sk_psock_bpf_run(psock, prog, skb);
		skb->sk = NULL;
	}
	rcu_read_unlock();
	return ret;
}

/* Called with socket lock held. */
static void sk_psock_strp_data_ready(struct sock *sk)
{
	struct sk_psock *psock;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (likely(psock)) {
		if (tls_sw_has_ctx_rx(sk)) {
			psock->parser.saved_data_ready(sk);
		} else {
			write_lock_bh(&sk->sk_callback_lock);
			strp_data_ready(&psock->parser.strp);
			write_unlock_bh(&sk->sk_callback_lock);
		}
	}
	rcu_read_unlock();
}

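/* read_sock() callback used in the non-strparser (verdict only) mode:
 * clone the skb, run the verdict program on it and apply the result.
 */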
static int sk_psock_verdict_recv(read_descriptor_t *desc, struct sk_buff *skb,
				 unsigned int offset, size_t orig_len)
{
	struct sock *sk = (struct sock *)desc->arg.data;
	struct sk_psock *psock;
	struct bpf_prog *prog;
	int ret = __SK_DROP;
	int len = orig_len;

	/* clone here so sk_eat_skb() in tcp_read_sock() does not drop our data */
	skb = skb_clone(skb, GFP_ATOMIC);
	if (!skb) {
		desc->error = -ENOMEM;
		return 0;
	}

	rcu_read_lock();
	psock = sk_psock(sk);
	if (unlikely(!psock)) {
		len = 0;
		kfree_skb(skb);
		goto out;
	}
	prog = READ_ONCE(psock->progs.skb_verdict);
	if (likely(prog)) {
		skb->sk = sk;
		tcp_skb_bpf_redirect_clear(skb);
		ret = sk_psock_bpf_run(psock, prog, skb);
		ret = sk_psock_map_verd(ret, tcp_skb_bpf_redirect_fetch(skb));
		skb->sk = NULL;
	}
	sk_psock_verdict_apply(psock, skb, ret);
out:
	rcu_read_unlock();
	return len;
}

static void sk_psock_verdict_data_ready(struct sock *sk)
{
	struct socket *sock = sk->sk_socket;
	read_descriptor_t desc;

	if (unlikely(!sock || !sock->ops || !sock->ops->read_sock))
		return;

	desc.arg.data = sk;
	desc.error = 0;
	desc.count = 1;

	sock->ops->read_sock(sk, &desc, sk_psock_verdict_recv);
}

static void sk_psock_write_space(struct sock *sk)
{
	struct sk_psock *psock;
	void (*write_space)(struct sock *sk) = NULL;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (likely(psock)) {
		if (sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED))
			schedule_work(&psock->work);
		write_space = psock->saved_write_space;
	}
	rcu_read_unlock();
	if (write_space)
		write_space(sk);
}

int sk_psock_init_strp(struct sock *sk, struct sk_psock *psock)
{
	static const struct strp_callbacks cb = {
		.rcv_msg	= sk_psock_strp_read,
		.read_sock_done	= sk_psock_strp_read_done,
		.parse_msg	= sk_psock_strp_parse,
	};

	psock->parser.enabled = false;
	return strp_init(&psock->parser.strp, sk, &cb);
}

void sk_psock_start_verdict(struct sock *sk, struct sk_psock *psock)
{
	struct sk_psock_parser *parser = &psock->parser;

	if (parser->enabled)
		return;

	parser->saved_data_ready = sk->sk_data_ready;
	sk->sk_data_ready = sk_psock_verdict_data_ready;
	sk->sk_write_space = sk_psock_write_space;
	parser->enabled = true;
}

void sk_psock_start_strp(struct sock *sk, struct sk_psock *psock)
{
	struct sk_psock_parser *parser = &psock->parser;

	if (parser->enabled)
		return;

	parser->saved_data_ready = sk->sk_data_ready;
	sk->sk_data_ready = sk_psock_strp_data_ready;
	sk->sk_write_space = sk_psock_write_space;
	parser->enabled = true;
}

void sk_psock_stop_strp(struct sock *sk, struct sk_psock *psock)
{
	struct sk_psock_parser *parser = &psock->parser;

	if (!parser->enabled)
		return;

	sk->sk_data_ready = parser->saved_data_ready;
	parser->saved_data_ready = NULL;
	strp_stop(&parser->strp);
	parser->enabled = false;
}

void sk_psock_stop_verdict(struct sock *sk, struct sk_psock *psock)
{
	struct sk_psock_parser *parser = &psock->parser;

	if (!parser->enabled)
		return;

	sk->sk_data_ready = parser->saved_data_ready;
	parser->saved_data_ready = NULL;
	parser->enabled = false;
}