/* net/tls/tls_sw.c (OK3568_Linux_fs tree, revision 4882a593) */
/*
 * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved.
 * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved.
 * Copyright (c) 2016-2017, Lance Chao <lancerchao@fb.com>. All rights reserved.
 * Copyright (c) 2016, Fridolin Pokorny <fridolin.pokorny@gmail.com>. All rights reserved.
 * Copyright (c) 2016, Nikos Mavrogiannopoulos <nmav@gnutls.org>. All rights reserved.
 * Copyright (c) 2018, Covalent IO, Inc. http://covalent.io
 *
 * This software is available to you under a choice of one of two
 * licenses.  You may choose to be licensed under the terms of the GNU
 * General Public License (GPL) Version 2, available from the file
 * COPYING in the main directory of this source tree, or the
 * OpenIB.org BSD license below:
 *
 *     Redistribution and use in source and binary forms, with or
 *     without modification, are permitted provided that the following
 *     conditions are met:
 *
 *      - Redistributions of source code must retain the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer.
 *
 *      - Redistributions in binary form must reproduce the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer in the documentation and/or other materials
 *        provided with the distribution.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include <linux/bug.h>
#include <linux/sched/signal.h>
#include <linux/module.h>
#include <linux/splice.h>
#include <crypto/aead.h>

#include <net/strparser.h>
#include <net/tls.h>

noinline void tls_err_abort(struct sock *sk, int err)
{
	WARN_ON_ONCE(err >= 0);
	/* sk->sk_err should contain a positive error code. */
	sk->sk_err = -err;
	sk->sk_error_report(sk);
}

static int __skb_nsg(struct sk_buff *skb, int offset, int len,
		     unsigned int recursion_level)
{
	int start = skb_headlen(skb);
	int i, chunk = start - offset;
	struct sk_buff *frag_iter;
	int elt = 0;

	if (unlikely(recursion_level >= 24))
		return -EMSGSIZE;

	if (chunk > 0) {
		if (chunk > len)
			chunk = len;
		elt++;
		len -= chunk;
		if (len == 0)
			return elt;
		offset += chunk;
	}

	for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) {
		int end;

		WARN_ON(start > offset + len);

		end = start + skb_frag_size(&skb_shinfo(skb)->frags[i]);
		chunk = end - offset;
		if (chunk > 0) {
			if (chunk > len)
				chunk = len;
			elt++;
			len -= chunk;
			if (len == 0)
				return elt;
			offset += chunk;
		}
		start = end;
	}

	if (unlikely(skb_has_frag_list(skb))) {
		skb_walk_frags(skb, frag_iter) {
			int end, ret;

			WARN_ON(start > offset + len);

			end = start + frag_iter->len;
			chunk = end - offset;
			if (chunk > 0) {
				if (chunk > len)
					chunk = len;
				ret = __skb_nsg(frag_iter, offset - start, chunk,
						recursion_level + 1);
				if (unlikely(ret < 0))
					return ret;
				elt += ret;
				len -= chunk;
				if (len == 0)
					return elt;
				offset += chunk;
			}
			start = end;
		}
	}
	BUG_ON(len);
	return elt;
}

/* Return the number of scatterlist elements required to completely map the
 * skb, or -EMSGSIZE if the recursion depth is exceeded.
 */
static int skb_nsg(struct sk_buff *skb, int offset, int len)
{
	return __skb_nsg(skb, offset, len, 0);
}
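
/* Illustrative sketch (added for clarity, not part of the original file):
 * a decrypt path would typically size its input scatterlist from the
 * ciphertext span of the record, e.g. something like
 *
 *	n_sgin = skb_nsg(skb, rxm->offset + prot->prepend_size,
 *			 rxm->full_len - prot->prepend_size);
 *
 * yielding one scatterlist entry per linear/frag/frag_list chunk the range
 * touches. The exact offsets above are an assumption for illustration.
 */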

static int padding_length(struct tls_sw_context_rx *ctx,
			  struct tls_prot_info *prot, struct sk_buff *skb)
{
	struct strp_msg *rxm = strp_msg(skb);
	int sub = 0;

	/* Determine zero-padding length */
	if (prot->version == TLS_1_3_VERSION) {
		char content_type = 0;
		int err;
		int back = 17;

		while (content_type == 0) {
			if (back > rxm->full_len - prot->prepend_size)
				return -EBADMSG;
			err = skb_copy_bits(skb,
					    rxm->offset + rxm->full_len - back,
					    &content_type, 1);
			if (err)
				return err;
			if (content_type)
				break;
			sub++;
			back++;
		}
		ctx->control = content_type;
	}
	return sub;
}
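
/* Layout note (added for clarity): a TLS 1.3 record body ends in
 * "content || content-type byte || zero padding", followed on the wire by
 * the AEAD tag:
 *
 *	... payload ... | type | 00 00 ... 00 | <16-byte tag>
 *
 * which is why the scan starts at back = 17 (tag plus one type byte) and
 * walks backwards until it hits the first non-zero byte. The 16-byte tag
 * matches the AEAD ciphers handled here; treat that as an assumption if
 * other tag sizes are ever added.
 */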

static void tls_decrypt_done(struct crypto_async_request *req, int err)
{
	struct aead_request *aead_req = (struct aead_request *)req;
	struct scatterlist *sgout = aead_req->dst;
	struct scatterlist *sgin = aead_req->src;
	struct tls_sw_context_rx *ctx;
	struct tls_context *tls_ctx;
	struct tls_prot_info *prot;
	struct scatterlist *sg;
	struct sk_buff *skb;
	unsigned int pages;
	int pending;

	skb = (struct sk_buff *)req->data;
	tls_ctx = tls_get_ctx(skb->sk);
	ctx = tls_sw_ctx_rx(tls_ctx);
	prot = &tls_ctx->prot_info;

	/* Propagate if there was an error */
	if (err) {
		if (err == -EBADMSG)
			TLS_INC_STATS(sock_net(skb->sk),
				      LINUX_MIB_TLSDECRYPTERROR);
		ctx->async_wait.err = err;
		tls_err_abort(skb->sk, err);
	} else {
		struct strp_msg *rxm = strp_msg(skb);
		int pad;

		pad = padding_length(ctx, prot, skb);
		if (pad < 0) {
			ctx->async_wait.err = pad;
			tls_err_abort(skb->sk, pad);
		} else {
			rxm->full_len -= pad;
			rxm->offset += prot->prepend_size;
			rxm->full_len -= prot->overhead_size;
		}
	}

	/* After using skb->sk to propagate sk through the crypto async
	 * callback we need to NULL it again.
	 */
	skb->sk = NULL;

	/* Free the destination pages if the skb was not decrypted in place */
	if (sgout != sgin) {
		/* Skip the first S/G entry as it points to AAD */
		for_each_sg(sg_next(sgout), sg, UINT_MAX, pages) {
			if (!sg)
				break;
			put_page(sg_page(sg));
		}
	}

	kfree(aead_req);

	spin_lock_bh(&ctx->decrypt_compl_lock);
	pending = atomic_dec_return(&ctx->decrypt_pending);

	if (!pending && ctx->async_notify)
		complete(&ctx->async_wait.completion);
	spin_unlock_bh(&ctx->decrypt_compl_lock);
}

static int tls_do_decryption(struct sock *sk,
			     struct sk_buff *skb,
			     struct scatterlist *sgin,
			     struct scatterlist *sgout,
			     char *iv_recv,
			     size_t data_len,
			     struct aead_request *aead_req,
			     bool async)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_prot_info *prot = &tls_ctx->prot_info;
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
	int ret;

	aead_request_set_tfm(aead_req, ctx->aead_recv);
	aead_request_set_ad(aead_req, prot->aad_size);
	aead_request_set_crypt(aead_req, sgin, sgout,
			       data_len + prot->tag_size,
			       (u8 *)iv_recv);

	if (async) {
		/* Using skb->sk to push sk through to crypto async callback
		 * handler. This allows propagating errors up to the socket
		 * if needed. It _must_ be cleared in the async handler
		 * before consume_skb is called. We _know_ skb->sk is NULL
		 * because it is a clone from strparser.
		 */
		skb->sk = sk;
		aead_request_set_callback(aead_req,
					  CRYPTO_TFM_REQ_MAY_BACKLOG,
					  tls_decrypt_done, skb);
		atomic_inc(&ctx->decrypt_pending);
	} else {
		aead_request_set_callback(aead_req,
					  CRYPTO_TFM_REQ_MAY_BACKLOG,
					  crypto_req_done, &ctx->async_wait);
	}

	ret = crypto_aead_decrypt(aead_req);
	if (ret == -EINPROGRESS) {
		if (async)
			return ret;

		ret = crypto_wait_req(ret, &ctx->async_wait);
	}

	if (async)
		atomic_dec(&ctx->decrypt_pending);

	return ret;
}
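
/* Call-shape note (added for clarity): the AEAD request covers the AAD in
 * the first prot->aad_size bytes of sgin, then data_len bytes of ciphertext
 * plus the prot->tag_size authentication tag, with iv_recv as the nonce.
 * With async == true a return of -EINPROGRESS is the normal success path;
 * completion and error reporting then happen in tls_decrypt_done() above.
 */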

static void tls_trim_both_msgs(struct sock *sk, int target_size)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_prot_info *prot = &tls_ctx->prot_info;
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
	struct tls_rec *rec = ctx->open_rec;

	sk_msg_trim(sk, &rec->msg_plaintext, target_size);
	if (target_size > 0)
		target_size += prot->overhead_size;
	sk_msg_trim(sk, &rec->msg_encrypted, target_size);
}

static int tls_alloc_encrypted_msg(struct sock *sk, int len)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
	struct tls_rec *rec = ctx->open_rec;
	struct sk_msg *msg_en = &rec->msg_encrypted;

	return sk_msg_alloc(sk, msg_en, len, 0);
}

static int tls_clone_plaintext_msg(struct sock *sk, int required)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_prot_info *prot = &tls_ctx->prot_info;
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
	struct tls_rec *rec = ctx->open_rec;
	struct sk_msg *msg_pl = &rec->msg_plaintext;
	struct sk_msg *msg_en = &rec->msg_encrypted;
	int skip, len;

	/* We add page references worth len bytes from encrypted sg
	 * at the end of plaintext sg. It is guaranteed that msg_en
	 * has enough required room (ensured by caller).
	 */
	len = required - msg_pl->sg.size;

	/* Skip initial bytes in msg_en's data to be able to use
	 * the same offset for both plain and encrypted data.
	 */
	skip = prot->prepend_size + msg_pl->sg.size;

	return sk_msg_clone(sk, msg_pl, msg_en, skip, len);
}

static struct tls_rec *tls_get_rec(struct sock *sk)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_prot_info *prot = &tls_ctx->prot_info;
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
	struct sk_msg *msg_pl, *msg_en;
	struct tls_rec *rec;
	int mem_size;

	mem_size = sizeof(struct tls_rec) + crypto_aead_reqsize(ctx->aead_send);

	rec = kzalloc(mem_size, sk->sk_allocation);
	if (!rec)
		return NULL;

	msg_pl = &rec->msg_plaintext;
	msg_en = &rec->msg_encrypted;

	sk_msg_init(msg_pl);
	sk_msg_init(msg_en);

	sg_init_table(rec->sg_aead_in, 2);
	sg_set_buf(&rec->sg_aead_in[0], rec->aad_space, prot->aad_size);
	sg_unmark_end(&rec->sg_aead_in[1]);

	sg_init_table(rec->sg_aead_out, 2);
	sg_set_buf(&rec->sg_aead_out[0], rec->aad_space, prot->aad_size);
	sg_unmark_end(&rec->sg_aead_out[1]);

	return rec;
}
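
/* Geometry note (added for clarity): both two-entry tables set up above
 * look like
 *
 *	sg_aead_in/out[0] -> rec->aad_space  (prot->aad_size bytes of AAD)
 *	sg_aead_in/out[1] -> (end unmarked, chained later)
 *
 * Entry [1] is left without an end marker so tls_push_record() can
 * sg_chain() it to the plaintext/encrypted sk_msg scatterlists when the
 * record is closed.
 */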

static void tls_free_rec(struct sock *sk, struct tls_rec *rec)
{
	sk_msg_free(sk, &rec->msg_encrypted);
	sk_msg_free(sk, &rec->msg_plaintext);
	kfree(rec);
}

static void tls_free_open_rec(struct sock *sk)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
	struct tls_rec *rec = ctx->open_rec;

	if (rec) {
		tls_free_rec(sk, rec);
		ctx->open_rec = NULL;
	}
}

int tls_tx_records(struct sock *sk, int flags)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
	struct tls_rec *rec, *tmp;
	struct sk_msg *msg_en;
	int tx_flags, rc = 0;

	if (tls_is_partially_sent_record(tls_ctx)) {
		rec = list_first_entry(&ctx->tx_list,
				       struct tls_rec, list);

		if (flags == -1)
			tx_flags = rec->tx_flags;
		else
			tx_flags = flags;

		rc = tls_push_partial_record(sk, tls_ctx, tx_flags);
		if (rc)
			goto tx_err;

		/* Full record has been transmitted.
		 * Remove the head of tx_list
		 */
		list_del(&rec->list);
		sk_msg_free(sk, &rec->msg_plaintext);
		kfree(rec);
	}

	/* Tx all ready records */
	list_for_each_entry_safe(rec, tmp, &ctx->tx_list, list) {
		if (READ_ONCE(rec->tx_ready)) {
			if (flags == -1)
				tx_flags = rec->tx_flags;
			else
				tx_flags = flags;

			msg_en = &rec->msg_encrypted;
			rc = tls_push_sg(sk, tls_ctx,
					 &msg_en->sg.data[msg_en->sg.curr],
					 0, tx_flags);
			if (rc)
				goto tx_err;

			list_del(&rec->list);
			sk_msg_free(sk, &rec->msg_plaintext);
			kfree(rec);
		} else {
			break;
		}
	}

tx_err:
	if (rc < 0 && rc != -EAGAIN)
		tls_err_abort(sk, -EBADMSG);

	return rc;
}
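
/* Usage note (added for clarity): callers with real sendmsg flags pass them
 * through, while callers that have no flags of their own, such as the
 * delayed tx_work worker, pass flags == -1 so each record is transmitted
 * with the tx_flags it was queued with.
 */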

static void tls_encrypt_done(struct crypto_async_request *req, int err)
{
	struct aead_request *aead_req = (struct aead_request *)req;
	struct sock *sk = req->data;
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_prot_info *prot = &tls_ctx->prot_info;
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
	struct scatterlist *sge;
	struct sk_msg *msg_en;
	struct tls_rec *rec;
	bool ready = false;
	int pending;

	rec = container_of(aead_req, struct tls_rec, aead_req);
	msg_en = &rec->msg_encrypted;

	sge = sk_msg_elem(msg_en, msg_en->sg.curr);
	sge->offset -= prot->prepend_size;
	sge->length += prot->prepend_size;

	/* Check if error is previously set on socket */
	if (err || sk->sk_err) {
		rec = NULL;

		/* If err is already set on socket, return the same code */
		if (sk->sk_err) {
			ctx->async_wait.err = -sk->sk_err;
		} else {
			ctx->async_wait.err = err;
			tls_err_abort(sk, err);
		}
	}

	if (rec) {
		struct tls_rec *first_rec;

		/* Mark the record as ready for transmission */
		smp_store_mb(rec->tx_ready, true);

		/* If received record is at head of tx_list, schedule tx */
		first_rec = list_first_entry(&ctx->tx_list,
					     struct tls_rec, list);
		if (rec == first_rec)
			ready = true;
	}

	spin_lock_bh(&ctx->encrypt_compl_lock);
	pending = atomic_dec_return(&ctx->encrypt_pending);

	if (!pending && ctx->async_notify)
		complete(&ctx->async_wait.completion);
	spin_unlock_bh(&ctx->encrypt_compl_lock);

	if (!ready)
		return;

	/* Schedule the transmission */
	if (!test_and_set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask))
		schedule_delayed_work(&ctx->tx_work.work, 1);
}

static int tls_do_encryption(struct sock *sk,
			     struct tls_context *tls_ctx,
			     struct tls_sw_context_tx *ctx,
			     struct aead_request *aead_req,
			     size_t data_len, u32 start)
{
	struct tls_prot_info *prot = &tls_ctx->prot_info;
	struct tls_rec *rec = ctx->open_rec;
	struct sk_msg *msg_en = &rec->msg_encrypted;
	struct scatterlist *sge = sk_msg_elem(msg_en, start);
	int rc, iv_offset = 0;

	/* For CCM based ciphers, first byte of IV is a constant */
	if (prot->cipher_type == TLS_CIPHER_AES_CCM_128) {
		rec->iv_data[0] = TLS_AES_CCM_IV_B0_BYTE;
		iv_offset = 1;
	}

	memcpy(&rec->iv_data[iv_offset], tls_ctx->tx.iv,
	       prot->iv_size + prot->salt_size);

	xor_iv_with_seq(prot->version, rec->iv_data + iv_offset, tls_ctx->tx.rec_seq);

	sge->offset += prot->prepend_size;
	sge->length -= prot->prepend_size;

	msg_en->sg.curr = start;

	aead_request_set_tfm(aead_req, ctx->aead_send);
	aead_request_set_ad(aead_req, prot->aad_size);
	aead_request_set_crypt(aead_req, rec->sg_aead_in,
			       rec->sg_aead_out,
			       data_len, rec->iv_data);

	aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG,
				  tls_encrypt_done, sk);

	/* Add the record in tx_list */
	list_add_tail((struct list_head *)&rec->list, &ctx->tx_list);
	atomic_inc(&ctx->encrypt_pending);

	rc = crypto_aead_encrypt(aead_req);
	if (!rc || rc != -EINPROGRESS) {
		atomic_dec(&ctx->encrypt_pending);
		sge->offset -= prot->prepend_size;
		sge->length += prot->prepend_size;
	}

	if (!rc) {
		WRITE_ONCE(rec->tx_ready, true);
	} else if (rc != -EINPROGRESS) {
		list_del(&rec->list);
		return rc;
	}

	/* Unhook the record from context if encryption did not fail */
	ctx->open_rec = NULL;
	tls_advance_record_sn(sk, prot, &tls_ctx->tx);
	return rc;
}
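
/* IV layout sketch (added for clarity): rec->iv_data ends up as
 *
 *	[ B0 byte (CCM only) | salt | per-record IV ]
 *
 * after which xor_iv_with_seq() mixes the record sequence number into the
 * nonce for TLS 1.3; for TLS 1.2 the explicit nonce travels in the record
 * header instead. The TLS 1.2 behaviour described here is an assumption
 * based on how xor_iv_with_seq() is defined elsewhere in this tree.
 */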

static int tls_split_open_record(struct sock *sk, struct tls_rec *from,
				 struct tls_rec **to, struct sk_msg *msg_opl,
				 struct sk_msg *msg_oen, u32 split_point,
				 u32 tx_overhead_size, u32 *orig_end)
{
	u32 i, j, bytes = 0, apply = msg_opl->apply_bytes;
	struct scatterlist *sge, *osge, *nsge;
	u32 orig_size = msg_opl->sg.size;
	struct scatterlist tmp = { };
	struct sk_msg *msg_npl;
	struct tls_rec *new;
	int ret;

	new = tls_get_rec(sk);
	if (!new)
		return -ENOMEM;
	ret = sk_msg_alloc(sk, &new->msg_encrypted, msg_opl->sg.size +
			   tx_overhead_size, 0);
	if (ret < 0) {
		tls_free_rec(sk, new);
		return ret;
	}

	*orig_end = msg_opl->sg.end;
	i = msg_opl->sg.start;
	sge = sk_msg_elem(msg_opl, i);
	while (apply && sge->length) {
		if (sge->length > apply) {
			u32 len = sge->length - apply;

			get_page(sg_page(sge));
			sg_set_page(&tmp, sg_page(sge), len,
				    sge->offset + apply);
			sge->length = apply;
			bytes += apply;
			apply = 0;
		} else {
			apply -= sge->length;
			bytes += sge->length;
		}

		sk_msg_iter_var_next(i);
		if (i == msg_opl->sg.end)
			break;
		sge = sk_msg_elem(msg_opl, i);
	}

	msg_opl->sg.end = i;
	msg_opl->sg.curr = i;
	msg_opl->sg.copybreak = 0;
	msg_opl->apply_bytes = 0;
	msg_opl->sg.size = bytes;

	msg_npl = &new->msg_plaintext;
	msg_npl->apply_bytes = apply;
	msg_npl->sg.size = orig_size - bytes;

	j = msg_npl->sg.start;
	nsge = sk_msg_elem(msg_npl, j);
	if (tmp.length) {
		memcpy(nsge, &tmp, sizeof(*nsge));
		sk_msg_iter_var_next(j);
		nsge = sk_msg_elem(msg_npl, j);
	}

	osge = sk_msg_elem(msg_opl, i);
	while (osge->length) {
		memcpy(nsge, osge, sizeof(*nsge));
		sg_unmark_end(nsge);
		sk_msg_iter_var_next(i);
		sk_msg_iter_var_next(j);
		if (i == *orig_end)
			break;
		osge = sk_msg_elem(msg_opl, i);
		nsge = sk_msg_elem(msg_npl, j);
	}

	msg_npl->sg.end = j;
	msg_npl->sg.curr = j;
	msg_npl->sg.copybreak = 0;

	*to = new;
	return 0;
}

static void tls_merge_open_record(struct sock *sk, struct tls_rec *to,
				  struct tls_rec *from, u32 orig_end)
{
	struct sk_msg *msg_npl = &from->msg_plaintext;
	struct sk_msg *msg_opl = &to->msg_plaintext;
	struct scatterlist *osge, *nsge;
	u32 i, j;

	i = msg_opl->sg.end;
	sk_msg_iter_var_prev(i);
	j = msg_npl->sg.start;

	osge = sk_msg_elem(msg_opl, i);
	nsge = sk_msg_elem(msg_npl, j);

	if (sg_page(osge) == sg_page(nsge) &&
	    osge->offset + osge->length == nsge->offset) {
		osge->length += nsge->length;
		put_page(sg_page(nsge));
	}

	msg_opl->sg.end = orig_end;
	msg_opl->sg.curr = orig_end;
	msg_opl->sg.copybreak = 0;
	msg_opl->apply_bytes = msg_opl->sg.size + msg_npl->sg.size;
	msg_opl->sg.size += msg_npl->sg.size;

	sk_msg_free(sk, &to->msg_encrypted);
	sk_msg_xfer_full(&to->msg_encrypted, &from->msg_encrypted);

	kfree(from);
}
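
/* Pairing note (added for clarity): tls_split_open_record() carves the
 * first apply_bytes of an open record into its own record so a BPF verdict
 * can be applied to exactly that span; tls_merge_open_record() is its
 * inverse, used when the split turns out to be unnecessary or encryption
 * fails, restoring the original scatterlist bounds recorded in orig_end.
 */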

static int tls_push_record(struct sock *sk, int flags,
			   unsigned char record_type)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_prot_info *prot = &tls_ctx->prot_info;
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
	struct tls_rec *rec = ctx->open_rec, *tmp = NULL;
	u32 i, split_point, orig_end;
	struct sk_msg *msg_pl, *msg_en;
	struct aead_request *req;
	bool split;
	int rc;

	if (!rec)
		return 0;

	msg_pl = &rec->msg_plaintext;
	msg_en = &rec->msg_encrypted;

	split_point = msg_pl->apply_bytes;
	split = split_point && split_point < msg_pl->sg.size;
	if (unlikely((!split &&
		      msg_pl->sg.size +
		      prot->overhead_size > msg_en->sg.size) ||
		     (split &&
		      split_point +
		      prot->overhead_size > msg_en->sg.size))) {
		split = true;
		split_point = msg_en->sg.size;
	}
	if (split) {
		rc = tls_split_open_record(sk, rec, &tmp, msg_pl, msg_en,
					   split_point, prot->overhead_size,
					   &orig_end);
		if (rc < 0)
			return rc;
		/* This can happen if above tls_split_open_record allocates
		 * a single large encryption buffer instead of two smaller
		 * ones. In this case adjust pointers and continue without
		 * split.
		 */
		if (!msg_pl->sg.size) {
			tls_merge_open_record(sk, rec, tmp, orig_end);
			msg_pl = &rec->msg_plaintext;
			msg_en = &rec->msg_encrypted;
			split = false;
		}
		sk_msg_trim(sk, msg_en, msg_pl->sg.size +
			    prot->overhead_size);
	}

	rec->tx_flags = flags;
	req = &rec->aead_req;

	i = msg_pl->sg.end;
	sk_msg_iter_var_prev(i);

	rec->content_type = record_type;
	if (prot->version == TLS_1_3_VERSION) {
		/* Add content type to end of message.  No padding added */
		sg_set_buf(&rec->sg_content_type, &rec->content_type, 1);
		sg_mark_end(&rec->sg_content_type);
		sg_chain(msg_pl->sg.data, msg_pl->sg.end + 1,
			 &rec->sg_content_type);
	} else {
		sg_mark_end(sk_msg_elem(msg_pl, i));
	}

	if (msg_pl->sg.end < msg_pl->sg.start) {
		sg_chain(&msg_pl->sg.data[msg_pl->sg.start],
			 MAX_SKB_FRAGS - msg_pl->sg.start + 1,
			 msg_pl->sg.data);
	}

	i = msg_pl->sg.start;
	sg_chain(rec->sg_aead_in, 2, &msg_pl->sg.data[i]);

	i = msg_en->sg.end;
	sk_msg_iter_var_prev(i);
	sg_mark_end(sk_msg_elem(msg_en, i));

	i = msg_en->sg.start;
	sg_chain(rec->sg_aead_out, 2, &msg_en->sg.data[i]);

	tls_make_aad(rec->aad_space, msg_pl->sg.size + prot->tail_size,
		     tls_ctx->tx.rec_seq, prot->rec_seq_size,
		     record_type, prot->version);

	tls_fill_prepend(tls_ctx,
			 page_address(sg_page(&msg_en->sg.data[i])) +
			 msg_en->sg.data[i].offset,
			 msg_pl->sg.size + prot->tail_size,
			 record_type, prot->version);

	tls_ctx->pending_open_record_frags = false;

	rc = tls_do_encryption(sk, tls_ctx, ctx, req,
			       msg_pl->sg.size + prot->tail_size, i);
	if (rc < 0) {
		if (rc != -EINPROGRESS) {
			tls_err_abort(sk, -EBADMSG);
			if (split) {
				tls_ctx->pending_open_record_frags = true;
				tls_merge_open_record(sk, rec, tmp, orig_end);
			}
		}
		ctx->async_capable = 1;
		return rc;
	} else if (split) {
		msg_pl = &tmp->msg_plaintext;
		msg_en = &tmp->msg_encrypted;
		sk_msg_trim(sk, msg_en, msg_pl->sg.size + prot->overhead_size);
		tls_ctx->pending_open_record_frags = true;
		ctx->open_rec = tmp;
	}

	return tls_tx_records(sk, flags);
}
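
/* Assembly sketch (added for clarity): once the chains above are in place
 * the AEAD source and destination look roughly like
 *
 *	src: aad_space -> plaintext sg [-> content-type byte, TLS 1.3]
 *	dst: aad_space -> encrypted sg (record header written at the start
 *	     of the first data entry via tls_fill_prepend())
 *
 * so a single crypto_aead_encrypt() produces a complete on-the-wire record.
 */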

static int bpf_exec_tx_verdict(struct sk_msg *msg, struct sock *sk,
			       bool full_record, u8 record_type,
			       ssize_t *copied, int flags)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
	struct sk_msg msg_redir = { };
	struct sk_psock *psock;
	struct sock *sk_redir;
	struct tls_rec *rec;
	bool enospc, policy;
	int err = 0, send;
	u32 delta = 0;

	policy = !(flags & MSG_SENDPAGE_NOPOLICY);
	psock = sk_psock_get(sk);
	if (!psock || !policy) {
		err = tls_push_record(sk, flags, record_type);
		if (err && sk->sk_err == EBADMSG) {
			*copied -= sk_msg_free(sk, msg);
			tls_free_open_rec(sk);
			err = -sk->sk_err;
		}
		if (psock)
			sk_psock_put(sk, psock);
		return err;
	}
more_data:
	enospc = sk_msg_full(msg);
	if (psock->eval == __SK_NONE) {
		delta = msg->sg.size;
		psock->eval = sk_psock_msg_verdict(sk, psock, msg);
		delta -= msg->sg.size;
	}
	if (msg->cork_bytes && msg->cork_bytes > msg->sg.size &&
	    !enospc && !full_record) {
		err = -ENOSPC;
		goto out_err;
	}
	msg->cork_bytes = 0;
	send = msg->sg.size;
	if (msg->apply_bytes && msg->apply_bytes < send)
		send = msg->apply_bytes;

	switch (psock->eval) {
	case __SK_PASS:
		err = tls_push_record(sk, flags, record_type);
		if (err && sk->sk_err == EBADMSG) {
			*copied -= sk_msg_free(sk, msg);
			tls_free_open_rec(sk);
			err = -sk->sk_err;
			goto out_err;
		}
		break;
	case __SK_REDIRECT:
		sk_redir = psock->sk_redir;
		memcpy(&msg_redir, msg, sizeof(*msg));
		if (msg->apply_bytes < send)
			msg->apply_bytes = 0;
		else
			msg->apply_bytes -= send;
		sk_msg_return_zero(sk, msg, send);
		msg->sg.size -= send;
		release_sock(sk);
		err = tcp_bpf_sendmsg_redir(sk_redir, &msg_redir, send, flags);
		lock_sock(sk);
		if (err < 0) {
			*copied -= sk_msg_free_nocharge(sk, &msg_redir);
			msg->sg.size = 0;
		}
		if (msg->sg.size == 0)
			tls_free_open_rec(sk);
		break;
	case __SK_DROP:
	default:
		sk_msg_free_partial(sk, msg, send);
		if (msg->apply_bytes < send)
			msg->apply_bytes = 0;
		else
			msg->apply_bytes -= send;
		if (msg->sg.size == 0)
			tls_free_open_rec(sk);
		*copied -= (send + delta);
		err = -EACCES;
	}

	if (likely(!err)) {
		bool reset_eval = !ctx->open_rec;

		rec = ctx->open_rec;
		if (rec) {
			msg = &rec->msg_plaintext;
			if (!msg->apply_bytes)
				reset_eval = true;
		}
		if (reset_eval) {
			psock->eval = __SK_NONE;
			if (psock->sk_redir) {
				sock_put(psock->sk_redir);
				psock->sk_redir = NULL;
			}
		}
		if (rec)
			goto more_data;
	}
 out_err:
	sk_psock_put(sk, psock);
	return err;
}
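
/* Verdict note (added for clarity): with no psock or no policy the record
 * is simply pushed; otherwise the BPF program's verdict decides the fate of
 * the next "send" bytes: __SK_PASS encrypts and transmits on this socket,
 * __SK_REDIRECT hands the plaintext sk_msg to another socket via
 * tcp_bpf_sendmsg_redir(), and __SK_DROP frees it and reports -EACCES.
 */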

static int tls_sw_push_pending_record(struct sock *sk, int flags)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
	struct tls_rec *rec = ctx->open_rec;
	struct sk_msg *msg_pl;
	size_t copied;

	if (!rec)
		return 0;

	msg_pl = &rec->msg_plaintext;
	copied = msg_pl->sg.size;
	if (!copied)
		return 0;

	return bpf_exec_tx_verdict(msg_pl, sk, true, TLS_RECORD_TYPE_DATA,
				   &copied, flags);
}

int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
	long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_prot_info *prot = &tls_ctx->prot_info;
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
	bool async_capable = ctx->async_capable;
	unsigned char record_type = TLS_RECORD_TYPE_DATA;
	bool is_kvec = iov_iter_is_kvec(&msg->msg_iter);
	bool eor = !(msg->msg_flags & MSG_MORE);
	size_t try_to_copy;
	ssize_t copied = 0;
	struct sk_msg *msg_pl, *msg_en;
	struct tls_rec *rec;
	int required_size;
	int num_async = 0;
	bool full_record;
	int record_room;
	int num_zc = 0;
	int orig_size;
	int ret = 0;
	int pending;

	if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
			       MSG_CMSG_COMPAT))
		return -EOPNOTSUPP;

	mutex_lock(&tls_ctx->tx_lock);
	lock_sock(sk);

	if (unlikely(msg->msg_controllen)) {
		ret = tls_proccess_cmsg(sk, msg, &record_type);
		if (ret) {
			if (ret == -EINPROGRESS)
				num_async++;
			else if (ret != -EAGAIN)
				goto send_end;
		}
	}

	while (msg_data_left(msg)) {
		if (sk->sk_err) {
			ret = -sk->sk_err;
			goto send_end;
		}

		if (ctx->open_rec)
			rec = ctx->open_rec;
		else
			rec = ctx->open_rec = tls_get_rec(sk);
		if (!rec) {
			ret = -ENOMEM;
			goto send_end;
		}

		msg_pl = &rec->msg_plaintext;
		msg_en = &rec->msg_encrypted;

		orig_size = msg_pl->sg.size;
		full_record = false;
		try_to_copy = msg_data_left(msg);
		record_room = TLS_MAX_PAYLOAD_SIZE - msg_pl->sg.size;
		if (try_to_copy >= record_room) {
			try_to_copy = record_room;
			full_record = true;
		}

		required_size = msg_pl->sg.size + try_to_copy +
				prot->overhead_size;

		if (!sk_stream_memory_free(sk))
			goto wait_for_sndbuf;

alloc_encrypted:
		ret = tls_alloc_encrypted_msg(sk, required_size);
		if (ret) {
			if (ret != -ENOSPC)
				goto wait_for_memory;

			/* Adjust try_to_copy according to the amount that was
			 * actually allocated. The difference is due
			 * to max sg elements limit
			 */
			try_to_copy -= required_size - msg_en->sg.size;
			full_record = true;
		}

		if (!is_kvec && (full_record || eor) && !async_capable) {
			u32 first = msg_pl->sg.end;

			ret = sk_msg_zerocopy_from_iter(sk, &msg->msg_iter,
							msg_pl, try_to_copy);
			if (ret)
				goto fallback_to_reg_send;

			num_zc++;
			copied += try_to_copy;

			sk_msg_sg_copy_set(msg_pl, first);
			ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
						  record_type, &copied,
						  msg->msg_flags);
			if (ret) {
				if (ret == -EINPROGRESS)
					num_async++;
				else if (ret == -ENOMEM)
					goto wait_for_memory;
				else if (ctx->open_rec && ret == -ENOSPC)
					goto rollback_iter;
				else if (ret != -EAGAIN)
					goto send_end;
			}
			continue;
rollback_iter:
			copied -= try_to_copy;
			sk_msg_sg_copy_clear(msg_pl, first);
			iov_iter_revert(&msg->msg_iter,
					msg_pl->sg.size - orig_size);
fallback_to_reg_send:
			sk_msg_trim(sk, msg_pl, orig_size);
		}

		required_size = msg_pl->sg.size + try_to_copy;

		ret = tls_clone_plaintext_msg(sk, required_size);
		if (ret) {
			if (ret != -ENOSPC)
				goto send_end;

			/* Adjust try_to_copy according to the amount that was
			 * actually allocated. The difference is due
			 * to max sg elements limit
			 */
			try_to_copy -= required_size - msg_pl->sg.size;
			full_record = true;
			sk_msg_trim(sk, msg_en,
				    msg_pl->sg.size + prot->overhead_size);
		}

		if (try_to_copy) {
			ret = sk_msg_memcopy_from_iter(sk, &msg->msg_iter,
						       msg_pl, try_to_copy);
			if (ret < 0)
				goto trim_sgl;
		}

		/* Open records defined only if successfully copied, otherwise
		 * we would trim the sg but not reset the open record frags.
		 */
		tls_ctx->pending_open_record_frags = true;
		copied += try_to_copy;
		if (full_record || eor) {
			ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
						  record_type, &copied,
						  msg->msg_flags);
			if (ret) {
				if (ret == -EINPROGRESS)
					num_async++;
				else if (ret == -ENOMEM)
					goto wait_for_memory;
				else if (ret != -EAGAIN) {
					if (ret == -ENOSPC)
						ret = 0;
					goto send_end;
				}
			}
		}

		continue;

wait_for_sndbuf:
		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
wait_for_memory:
		ret = sk_stream_wait_memory(sk, &timeo);
		if (ret) {
trim_sgl:
			if (ctx->open_rec)
				tls_trim_both_msgs(sk, orig_size);
			goto send_end;
		}

		if (ctx->open_rec && msg_en->sg.size < required_size)
			goto alloc_encrypted;
	}

	if (!num_async) {
		goto send_end;
	} else if (num_zc) {
		/* Wait for pending encryptions to get completed */
		spin_lock_bh(&ctx->encrypt_compl_lock);
		ctx->async_notify = true;

		pending = atomic_read(&ctx->encrypt_pending);
		spin_unlock_bh(&ctx->encrypt_compl_lock);
		if (pending)
			crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
		else
			reinit_completion(&ctx->async_wait.completion);

		/* There can be no concurrent accesses, since we have no
		 * pending encrypt operations
		 */
		WRITE_ONCE(ctx->async_notify, false);

		if (ctx->async_wait.err) {
			ret = ctx->async_wait.err;
			copied = 0;
		}
	}

	/* Transmit if any encryptions have completed */
	if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) {
		cancel_delayed_work(&ctx->tx_work.work);
		tls_tx_records(sk, msg->msg_flags);
	}

send_end:
	ret = sk_stream_error(sk, msg->msg_flags, ret);

	release_sock(sk);
	mutex_unlock(&tls_ctx->tx_lock);
	return copied > 0 ? copied : ret;
}
1148*4882a593Smuzhiyun 
tls_sw_do_sendpage(struct sock * sk,struct page * page,int offset,size_t size,int flags)1149*4882a593Smuzhiyun static int tls_sw_do_sendpage(struct sock *sk, struct page *page,
1150*4882a593Smuzhiyun 			      int offset, size_t size, int flags)
1151*4882a593Smuzhiyun {
1152*4882a593Smuzhiyun 	long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
1153*4882a593Smuzhiyun 	struct tls_context *tls_ctx = tls_get_ctx(sk);
1154*4882a593Smuzhiyun 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
1155*4882a593Smuzhiyun 	struct tls_prot_info *prot = &tls_ctx->prot_info;
1156*4882a593Smuzhiyun 	unsigned char record_type = TLS_RECORD_TYPE_DATA;
1157*4882a593Smuzhiyun 	struct sk_msg *msg_pl;
1158*4882a593Smuzhiyun 	struct tls_rec *rec;
1159*4882a593Smuzhiyun 	int num_async = 0;
1160*4882a593Smuzhiyun 	ssize_t copied = 0;
1161*4882a593Smuzhiyun 	bool full_record;
1162*4882a593Smuzhiyun 	int record_room;
1163*4882a593Smuzhiyun 	int ret = 0;
1164*4882a593Smuzhiyun 	bool eor;
1165*4882a593Smuzhiyun 
1166*4882a593Smuzhiyun 	eor = !(flags & MSG_SENDPAGE_NOTLAST);
1167*4882a593Smuzhiyun 	sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk);
1168*4882a593Smuzhiyun 
1169*4882a593Smuzhiyun 	/* Call the sk_stream functions to manage the sndbuf mem. */
1170*4882a593Smuzhiyun 	while (size > 0) {
1171*4882a593Smuzhiyun 		size_t copy, required_size;
1172*4882a593Smuzhiyun 
1173*4882a593Smuzhiyun 		if (sk->sk_err) {
1174*4882a593Smuzhiyun 			ret = -sk->sk_err;
1175*4882a593Smuzhiyun 			goto sendpage_end;
1176*4882a593Smuzhiyun 		}
1177*4882a593Smuzhiyun 
1178*4882a593Smuzhiyun 		if (ctx->open_rec)
1179*4882a593Smuzhiyun 			rec = ctx->open_rec;
1180*4882a593Smuzhiyun 		else
1181*4882a593Smuzhiyun 			rec = ctx->open_rec = tls_get_rec(sk);
1182*4882a593Smuzhiyun 		if (!rec) {
1183*4882a593Smuzhiyun 			ret = -ENOMEM;
1184*4882a593Smuzhiyun 			goto sendpage_end;
1185*4882a593Smuzhiyun 		}
1186*4882a593Smuzhiyun 
1187*4882a593Smuzhiyun 		msg_pl = &rec->msg_plaintext;
1188*4882a593Smuzhiyun 
1189*4882a593Smuzhiyun 		full_record = false;
1190*4882a593Smuzhiyun 		record_room = TLS_MAX_PAYLOAD_SIZE - msg_pl->sg.size;
1191*4882a593Smuzhiyun 		copy = size;
1192*4882a593Smuzhiyun 		if (copy >= record_room) {
1193*4882a593Smuzhiyun 			copy = record_room;
1194*4882a593Smuzhiyun 			full_record = true;
1195*4882a593Smuzhiyun 		}
1196*4882a593Smuzhiyun 
1197*4882a593Smuzhiyun 		required_size = msg_pl->sg.size + copy + prot->overhead_size;
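		/* Worked example (illustrative numbers, assuming AES-GCM-128
		 * under TLS 1.2): prepend_size = 5-byte record header +
		 * 8-byte explicit nonce = 13 and tag_size = 16, so
		 * overhead_size = 29; appending 1000 payload bytes to an
		 * empty record needs required_size = 0 + 1000 + 29 = 1029.
		 */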
1198*4882a593Smuzhiyun 
1199*4882a593Smuzhiyun 		if (!sk_stream_memory_free(sk))
1200*4882a593Smuzhiyun 			goto wait_for_sndbuf;
1201*4882a593Smuzhiyun alloc_payload:
1202*4882a593Smuzhiyun 		ret = tls_alloc_encrypted_msg(sk, required_size);
1203*4882a593Smuzhiyun 		if (ret) {
1204*4882a593Smuzhiyun 			if (ret != -ENOSPC)
1205*4882a593Smuzhiyun 				goto wait_for_memory;
1206*4882a593Smuzhiyun 
1207*4882a593Smuzhiyun 			/* Adjust copy according to the amount that was
1208*4882a593Smuzhiyun 			 * actually allocated. The difference is due to the
1209*4882a593Smuzhiyun 			 * max sg elements limit.
1210*4882a593Smuzhiyun 			 */
1211*4882a593Smuzhiyun 			copy -= required_size - msg_pl->sg.size;
1212*4882a593Smuzhiyun 			full_record = true;
1213*4882a593Smuzhiyun 		}
1214*4882a593Smuzhiyun 
1215*4882a593Smuzhiyun 		sk_msg_page_add(msg_pl, page, copy, offset);
1216*4882a593Smuzhiyun 		sk_mem_charge(sk, copy);
1217*4882a593Smuzhiyun 
1218*4882a593Smuzhiyun 		offset += copy;
1219*4882a593Smuzhiyun 		size -= copy;
1220*4882a593Smuzhiyun 		copied += copy;
1221*4882a593Smuzhiyun 
1222*4882a593Smuzhiyun 		tls_ctx->pending_open_record_frags = true;
1223*4882a593Smuzhiyun 		if (full_record || eor || sk_msg_full(msg_pl)) {
1224*4882a593Smuzhiyun 			ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
1225*4882a593Smuzhiyun 						  record_type, &copied, flags);
1226*4882a593Smuzhiyun 			if (ret) {
1227*4882a593Smuzhiyun 				if (ret == -EINPROGRESS)
1228*4882a593Smuzhiyun 					num_async++;
1229*4882a593Smuzhiyun 				else if (ret == -ENOMEM)
1230*4882a593Smuzhiyun 					goto wait_for_memory;
1231*4882a593Smuzhiyun 				else if (ret != -EAGAIN) {
1232*4882a593Smuzhiyun 					if (ret == -ENOSPC)
1233*4882a593Smuzhiyun 						ret = 0;
1234*4882a593Smuzhiyun 					goto sendpage_end;
1235*4882a593Smuzhiyun 				}
1236*4882a593Smuzhiyun 			}
1237*4882a593Smuzhiyun 		}
1238*4882a593Smuzhiyun 		continue;
1239*4882a593Smuzhiyun wait_for_sndbuf:
1240*4882a593Smuzhiyun 		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
1241*4882a593Smuzhiyun wait_for_memory:
1242*4882a593Smuzhiyun 		ret = sk_stream_wait_memory(sk, &timeo);
1243*4882a593Smuzhiyun 		if (ret) {
1244*4882a593Smuzhiyun 			if (ctx->open_rec)
1245*4882a593Smuzhiyun 				tls_trim_both_msgs(sk, msg_pl->sg.size);
1246*4882a593Smuzhiyun 			goto sendpage_end;
1247*4882a593Smuzhiyun 		}
1248*4882a593Smuzhiyun 
1249*4882a593Smuzhiyun 		if (ctx->open_rec)
1250*4882a593Smuzhiyun 			goto alloc_payload;
1251*4882a593Smuzhiyun 	}
1252*4882a593Smuzhiyun 
1253*4882a593Smuzhiyun 	if (num_async) {
1254*4882a593Smuzhiyun 		/* Transmit if any encryptions have completed */
1255*4882a593Smuzhiyun 		if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) {
1256*4882a593Smuzhiyun 			cancel_delayed_work(&ctx->tx_work.work);
1257*4882a593Smuzhiyun 			tls_tx_records(sk, flags);
1258*4882a593Smuzhiyun 		}
1259*4882a593Smuzhiyun 	}
1260*4882a593Smuzhiyun sendpage_end:
1261*4882a593Smuzhiyun 	ret = sk_stream_error(sk, flags, ret);
1262*4882a593Smuzhiyun 	return copied > 0 ? copied : ret;
1263*4882a593Smuzhiyun }
1264*4882a593Smuzhiyun 
tls_sw_sendpage_locked(struct sock * sk,struct page * page,int offset,size_t size,int flags)1265*4882a593Smuzhiyun int tls_sw_sendpage_locked(struct sock *sk, struct page *page,
1266*4882a593Smuzhiyun 			   int offset, size_t size, int flags)
1267*4882a593Smuzhiyun {
1268*4882a593Smuzhiyun 	if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
1269*4882a593Smuzhiyun 		      MSG_SENDPAGE_NOTLAST | MSG_SENDPAGE_NOPOLICY |
1270*4882a593Smuzhiyun 		      MSG_NO_SHARED_FRAGS))
1271*4882a593Smuzhiyun 		return -EOPNOTSUPP;
1272*4882a593Smuzhiyun 
1273*4882a593Smuzhiyun 	return tls_sw_do_sendpage(sk, page, offset, size, flags);
1274*4882a593Smuzhiyun }
1275*4882a593Smuzhiyun 
tls_sw_sendpage(struct sock * sk,struct page * page,int offset,size_t size,int flags)1276*4882a593Smuzhiyun int tls_sw_sendpage(struct sock *sk, struct page *page,
1277*4882a593Smuzhiyun 		    int offset, size_t size, int flags)
1278*4882a593Smuzhiyun {
1279*4882a593Smuzhiyun 	struct tls_context *tls_ctx = tls_get_ctx(sk);
1280*4882a593Smuzhiyun 	int ret;
1281*4882a593Smuzhiyun 
1282*4882a593Smuzhiyun 	if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
1283*4882a593Smuzhiyun 		      MSG_SENDPAGE_NOTLAST | MSG_SENDPAGE_NOPOLICY))
1284*4882a593Smuzhiyun 		return -EOPNOTSUPP;
1285*4882a593Smuzhiyun 
1286*4882a593Smuzhiyun 	mutex_lock(&tls_ctx->tx_lock);
1287*4882a593Smuzhiyun 	lock_sock(sk);
1288*4882a593Smuzhiyun 	ret = tls_sw_do_sendpage(sk, page, offset, size, flags);
1289*4882a593Smuzhiyun 	release_sock(sk);
1290*4882a593Smuzhiyun 	mutex_unlock(&tls_ctx->tx_lock);
1291*4882a593Smuzhiyun 	return ret;
1292*4882a593Smuzhiyun }
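/* Userspace usage sketch (illustrative only, not from the original source):
 * this path is reached through the socket's sendpage hook once the TLS ULP
 * is attached and TX keys are installed, e.g.:
 *
 *	struct tls12_crypto_info_aes_gcm_128 crypto_info = { ... };
 *
 *	setsockopt(fd, SOL_TCP, TCP_ULP, "tls", sizeof("tls"));
 *	setsockopt(fd, SOL_TLS, TLS_TX, &crypto_info, sizeof(crypto_info));
 *	sendfile(fd, file_fd, NULL, count);
 */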
1293*4882a593Smuzhiyun 
tls_wait_data(struct sock * sk,struct sk_psock * psock,bool nonblock,long timeo,int * err)1294*4882a593Smuzhiyun static struct sk_buff *tls_wait_data(struct sock *sk, struct sk_psock *psock,
1295*4882a593Smuzhiyun 				     bool nonblock, long timeo, int *err)
1296*4882a593Smuzhiyun {
1297*4882a593Smuzhiyun 	struct tls_context *tls_ctx = tls_get_ctx(sk);
1298*4882a593Smuzhiyun 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1299*4882a593Smuzhiyun 	struct sk_buff *skb;
1300*4882a593Smuzhiyun 	DEFINE_WAIT_FUNC(wait, woken_wake_function);
1301*4882a593Smuzhiyun 
1302*4882a593Smuzhiyun 	while (!(skb = ctx->recv_pkt) && sk_psock_queue_empty(psock)) {
1303*4882a593Smuzhiyun 		if (sk->sk_err) {
1304*4882a593Smuzhiyun 			*err = sock_error(sk);
1305*4882a593Smuzhiyun 			return NULL;
1306*4882a593Smuzhiyun 		}
1307*4882a593Smuzhiyun 
1308*4882a593Smuzhiyun 		if (!skb_queue_empty(&sk->sk_receive_queue)) {
1309*4882a593Smuzhiyun 			__strp_unpause(&ctx->strp);
1310*4882a593Smuzhiyun 			if (ctx->recv_pkt)
1311*4882a593Smuzhiyun 				return ctx->recv_pkt;
1312*4882a593Smuzhiyun 		}
1313*4882a593Smuzhiyun 
1314*4882a593Smuzhiyun 		if (sk->sk_shutdown & RCV_SHUTDOWN)
1315*4882a593Smuzhiyun 			return NULL;
1316*4882a593Smuzhiyun 
1317*4882a593Smuzhiyun 		if (sock_flag(sk, SOCK_DONE))
1318*4882a593Smuzhiyun 			return NULL;
1319*4882a593Smuzhiyun 
1320*4882a593Smuzhiyun 		if (nonblock || !timeo) {
1321*4882a593Smuzhiyun 			*err = -EAGAIN;
1322*4882a593Smuzhiyun 			return NULL;
1323*4882a593Smuzhiyun 		}
1324*4882a593Smuzhiyun 
1325*4882a593Smuzhiyun 		add_wait_queue(sk_sleep(sk), &wait);
1326*4882a593Smuzhiyun 		sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
1327*4882a593Smuzhiyun 		sk_wait_event(sk, &timeo,
1328*4882a593Smuzhiyun 			      ctx->recv_pkt != skb ||
1329*4882a593Smuzhiyun 			      !sk_psock_queue_empty(psock),
1330*4882a593Smuzhiyun 			      &wait);
1331*4882a593Smuzhiyun 		sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
1332*4882a593Smuzhiyun 		remove_wait_queue(sk_sleep(sk), &wait);
1333*4882a593Smuzhiyun 
1334*4882a593Smuzhiyun 		/* Handle signals */
1335*4882a593Smuzhiyun 		if (signal_pending(current)) {
1336*4882a593Smuzhiyun 			*err = sock_intr_errno(timeo);
1337*4882a593Smuzhiyun 			return NULL;
1338*4882a593Smuzhiyun 		}
1339*4882a593Smuzhiyun 	}
1340*4882a593Smuzhiyun 
1341*4882a593Smuzhiyun 	return skb;
1342*4882a593Smuzhiyun }
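/* Annotation (summary, not from the original source): tls_wait_data()
 * returns the next record skb once the strparser has produced one. It
 * returns NULL with *err set on a socket error, in non-blocking mode
 * (-EAGAIN) or on a pending signal; on RCV_SHUTDOWN or SOCK_DONE it
 * returns NULL with *err left untouched, which callers treat as EOF.
 */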
1343*4882a593Smuzhiyun 
tls_setup_from_iter(struct sock * sk,struct iov_iter * from,int length,int * pages_used,unsigned int * size_used,struct scatterlist * to,int to_max_pages)1344*4882a593Smuzhiyun static int tls_setup_from_iter(struct sock *sk, struct iov_iter *from,
1345*4882a593Smuzhiyun 			       int length, int *pages_used,
1346*4882a593Smuzhiyun 			       unsigned int *size_used,
1347*4882a593Smuzhiyun 			       struct scatterlist *to,
1348*4882a593Smuzhiyun 			       int to_max_pages)
1349*4882a593Smuzhiyun {
1350*4882a593Smuzhiyun 	int rc = 0, i = 0, num_elem = *pages_used, maxpages;
1351*4882a593Smuzhiyun 	struct page *pages[MAX_SKB_FRAGS];
1352*4882a593Smuzhiyun 	unsigned int size = *size_used;
1353*4882a593Smuzhiyun 	ssize_t copied, use;
1354*4882a593Smuzhiyun 	size_t offset;
1355*4882a593Smuzhiyun 
1356*4882a593Smuzhiyun 	while (length > 0) {
1357*4882a593Smuzhiyun 		i = 0;
1358*4882a593Smuzhiyun 		maxpages = to_max_pages - num_elem;
1359*4882a593Smuzhiyun 		if (maxpages == 0) {
1360*4882a593Smuzhiyun 			rc = -EFAULT;
1361*4882a593Smuzhiyun 			goto out;
1362*4882a593Smuzhiyun 		}
1363*4882a593Smuzhiyun 		copied = iov_iter_get_pages(from, pages,
1364*4882a593Smuzhiyun 					    length,
1365*4882a593Smuzhiyun 					    maxpages, &offset);
1366*4882a593Smuzhiyun 		if (copied <= 0) {
1367*4882a593Smuzhiyun 			rc = -EFAULT;
1368*4882a593Smuzhiyun 			goto out;
1369*4882a593Smuzhiyun 		}
1370*4882a593Smuzhiyun 
1371*4882a593Smuzhiyun 		iov_iter_advance(from, copied);
1372*4882a593Smuzhiyun 
1373*4882a593Smuzhiyun 		length -= copied;
1374*4882a593Smuzhiyun 		size += copied;
1375*4882a593Smuzhiyun 		while (copied) {
1376*4882a593Smuzhiyun 			use = min_t(int, copied, PAGE_SIZE - offset);
1377*4882a593Smuzhiyun 
1378*4882a593Smuzhiyun 			sg_set_page(&to[num_elem],
1379*4882a593Smuzhiyun 				    pages[i], use, offset);
1380*4882a593Smuzhiyun 			sg_unmark_end(&to[num_elem]);
1381*4882a593Smuzhiyun 			/* We do not uncharge memory from this API */
1382*4882a593Smuzhiyun 
1383*4882a593Smuzhiyun 			offset = 0;
1384*4882a593Smuzhiyun 			copied -= use;
1385*4882a593Smuzhiyun 
1386*4882a593Smuzhiyun 			i++;
1387*4882a593Smuzhiyun 			num_elem++;
1388*4882a593Smuzhiyun 		}
1389*4882a593Smuzhiyun 	}
1390*4882a593Smuzhiyun 	/* Mark the end in the last sg entry if newly added */
1391*4882a593Smuzhiyun 	if (num_elem > *pages_used)
1392*4882a593Smuzhiyun 		sg_mark_end(&to[num_elem - 1]);
1393*4882a593Smuzhiyun out:
1394*4882a593Smuzhiyun 	if (rc)
1395*4882a593Smuzhiyun 		iov_iter_revert(from, size - *size_used);
1396*4882a593Smuzhiyun 	*size_used = size;
1397*4882a593Smuzhiyun 	*pages_used = num_elem;
1398*4882a593Smuzhiyun 
1399*4882a593Smuzhiyun 	return rc;
1400*4882a593Smuzhiyun }
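/* Annotation (not from the original source): the pages returned by
 * iov_iter_get_pages() above each carry a reference; ownership passes to
 * the caller through the scatterlist, and e.g. decrypt_internal() below
 * drops them with put_page() once the AEAD request has finished.
 */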
1401*4882a593Smuzhiyun 
1402*4882a593Smuzhiyun /* This function decrypts the input skb into either out_iov, out_sg or
1403*4882a593Smuzhiyun  * the skb buffers themselves. The input parameter 'zc' indicates whether
1404*4882a593Smuzhiyun  * zero-copy mode should be tried. With zero-copy mode, either
1405*4882a593Smuzhiyun  * out_iov or out_sg must be non-NULL. If both out_iov and out_sg are
1406*4882a593Smuzhiyun  * NULL, the decryption happens inside the skb buffers themselves, i.e.
1407*4882a593Smuzhiyun  * zero-copy gets disabled and '*zc' is updated to reflect that.
1408*4882a593Smuzhiyun  */
1409*4882a593Smuzhiyun 
decrypt_internal(struct sock * sk,struct sk_buff * skb,struct iov_iter * out_iov,struct scatterlist * out_sg,int * chunk,bool * zc,bool async)1410*4882a593Smuzhiyun static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
1411*4882a593Smuzhiyun 			    struct iov_iter *out_iov,
1412*4882a593Smuzhiyun 			    struct scatterlist *out_sg,
1413*4882a593Smuzhiyun 			    int *chunk, bool *zc, bool async)
1414*4882a593Smuzhiyun {
1415*4882a593Smuzhiyun 	struct tls_context *tls_ctx = tls_get_ctx(sk);
1416*4882a593Smuzhiyun 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1417*4882a593Smuzhiyun 	struct tls_prot_info *prot = &tls_ctx->prot_info;
1418*4882a593Smuzhiyun 	struct strp_msg *rxm = strp_msg(skb);
1419*4882a593Smuzhiyun 	int n_sgin, n_sgout, nsg, mem_size, aead_size, err, pages = 0;
1420*4882a593Smuzhiyun 	struct aead_request *aead_req;
1421*4882a593Smuzhiyun 	struct sk_buff *unused;
1422*4882a593Smuzhiyun 	u8 *aad, *iv, *mem = NULL;
1423*4882a593Smuzhiyun 	struct scatterlist *sgin = NULL;
1424*4882a593Smuzhiyun 	struct scatterlist *sgout = NULL;
1425*4882a593Smuzhiyun 	const int data_len = rxm->full_len - prot->overhead_size +
1426*4882a593Smuzhiyun 			     prot->tail_size;
1427*4882a593Smuzhiyun 	int iv_offset = 0;
1428*4882a593Smuzhiyun 
1429*4882a593Smuzhiyun 	if (*zc && (out_iov || out_sg)) {
1430*4882a593Smuzhiyun 		if (out_iov)
1431*4882a593Smuzhiyun 			n_sgout = iov_iter_npages(out_iov, INT_MAX) + 1;
1432*4882a593Smuzhiyun 		else
1433*4882a593Smuzhiyun 			n_sgout = sg_nents(out_sg);
1434*4882a593Smuzhiyun 		n_sgin = skb_nsg(skb, rxm->offset + prot->prepend_size,
1435*4882a593Smuzhiyun 				 rxm->full_len - prot->prepend_size);
1436*4882a593Smuzhiyun 	} else {
1437*4882a593Smuzhiyun 		n_sgout = 0;
1438*4882a593Smuzhiyun 		*zc = false;
1439*4882a593Smuzhiyun 		n_sgin = skb_cow_data(skb, 0, &unused);
1440*4882a593Smuzhiyun 	}
1441*4882a593Smuzhiyun 
1442*4882a593Smuzhiyun 	if (n_sgin < 1)
1443*4882a593Smuzhiyun 		return -EBADMSG;
1444*4882a593Smuzhiyun 
1445*4882a593Smuzhiyun 	/* Increment to accommodate AAD */
1446*4882a593Smuzhiyun 	n_sgin = n_sgin + 1;
1447*4882a593Smuzhiyun 
1448*4882a593Smuzhiyun 	nsg = n_sgin + n_sgout;
1449*4882a593Smuzhiyun 
1450*4882a593Smuzhiyun 	aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv);
1451*4882a593Smuzhiyun 	mem_size = aead_size + (nsg * sizeof(struct scatterlist));
1452*4882a593Smuzhiyun 	mem_size = mem_size + prot->aad_size;
1453*4882a593Smuzhiyun 	mem_size = mem_size + crypto_aead_ivsize(ctx->aead_recv);
1454*4882a593Smuzhiyun 
1455*4882a593Smuzhiyun 	/* Allocate a single block of memory which contains
1456*4882a593Smuzhiyun 	 * aead_req || sgin[] || sgout[] || aad || iv.
1457*4882a593Smuzhiyun 	 * This order achieves correct alignment for aead_req, sgin, sgout.
1458*4882a593Smuzhiyun 	 */
1459*4882a593Smuzhiyun 	mem = kmalloc(mem_size, sk->sk_allocation);
1460*4882a593Smuzhiyun 	if (!mem)
1461*4882a593Smuzhiyun 		return -ENOMEM;
1462*4882a593Smuzhiyun 
1463*4882a593Smuzhiyun 	/* Segment the allocated memory */
1464*4882a593Smuzhiyun 	aead_req = (struct aead_request *)mem;
1465*4882a593Smuzhiyun 	sgin = (struct scatterlist *)(mem + aead_size);
1466*4882a593Smuzhiyun 	sgout = sgin + n_sgin;
1467*4882a593Smuzhiyun 	aad = (u8 *)(sgout + n_sgout);
1468*4882a593Smuzhiyun 	iv = aad + prot->aad_size;
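	/* Illustrative layout (assumed example: AES-GCM-128 under TLS 1.2,
	 * zero-copy into a single-page iovec, so n_sgin = 3 including the
	 * AAD entry and n_sgout = 2):
	 *
	 *	mem: [aead_req + reqsize][sgin x3][sgout x2][aad 13B][iv]
	 *
	 * The exact sizes depend on the cipher and the skb geometry.
	 */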
1469*4882a593Smuzhiyun 
1470*4882a593Smuzhiyun 	/* For CCM based ciphers, the first byte of nonce+iv is always '2' */
1471*4882a593Smuzhiyun 	if (prot->cipher_type == TLS_CIPHER_AES_CCM_128) {
1472*4882a593Smuzhiyun 		iv[0] = 2;
1473*4882a593Smuzhiyun 		iv_offset = 1;
1474*4882a593Smuzhiyun 	}
1475*4882a593Smuzhiyun 
1476*4882a593Smuzhiyun 	/* Prepare IV */
1477*4882a593Smuzhiyun 	err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE,
1478*4882a593Smuzhiyun 			    iv + iv_offset + prot->salt_size,
1479*4882a593Smuzhiyun 			    prot->iv_size);
1480*4882a593Smuzhiyun 	if (err < 0) {
1481*4882a593Smuzhiyun 		kfree(mem);
1482*4882a593Smuzhiyun 		return err;
1483*4882a593Smuzhiyun 	}
1484*4882a593Smuzhiyun 	if (prot->version == TLS_1_3_VERSION)
1485*4882a593Smuzhiyun 		memcpy(iv + iv_offset, tls_ctx->rx.iv,
1486*4882a593Smuzhiyun 		       prot->iv_size + prot->salt_size);
1487*4882a593Smuzhiyun 	else
1488*4882a593Smuzhiyun 		memcpy(iv + iv_offset, tls_ctx->rx.iv, prot->salt_size);
1489*4882a593Smuzhiyun 
1490*4882a593Smuzhiyun 	xor_iv_with_seq(prot->version, iv + iv_offset, tls_ctx->rx.rec_seq);
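	/* Annotation (not from the original source): per RFC 8446 section
	 * 5.3 the TLS 1.3 per-record nonce is the static IV XORed with the
	 * padded 64-bit record sequence number, which is what the copy and
	 * xor above compute; for TLS 1.2 AES-GCM the implicit salt is
	 * instead prepended to the explicit nonce read from the record.
	 */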
1491*4882a593Smuzhiyun 
1492*4882a593Smuzhiyun 	/* Prepare AAD */
1493*4882a593Smuzhiyun 	tls_make_aad(aad, rxm->full_len - prot->overhead_size +
1494*4882a593Smuzhiyun 		     prot->tail_size,
1495*4882a593Smuzhiyun 		     tls_ctx->rx.rec_seq, prot->rec_seq_size,
1496*4882a593Smuzhiyun 		     ctx->control, prot->version);
1497*4882a593Smuzhiyun 
1498*4882a593Smuzhiyun 	/* Prepare sgin */
1499*4882a593Smuzhiyun 	sg_init_table(sgin, n_sgin);
1500*4882a593Smuzhiyun 	sg_set_buf(&sgin[0], aad, prot->aad_size);
1501*4882a593Smuzhiyun 	err = skb_to_sgvec(skb, &sgin[1],
1502*4882a593Smuzhiyun 			   rxm->offset + prot->prepend_size,
1503*4882a593Smuzhiyun 			   rxm->full_len - prot->prepend_size);
1504*4882a593Smuzhiyun 	if (err < 0) {
1505*4882a593Smuzhiyun 		kfree(mem);
1506*4882a593Smuzhiyun 		return err;
1507*4882a593Smuzhiyun 	}
1508*4882a593Smuzhiyun 
1509*4882a593Smuzhiyun 	if (n_sgout) {
1510*4882a593Smuzhiyun 		if (out_iov) {
1511*4882a593Smuzhiyun 			sg_init_table(sgout, n_sgout);
1512*4882a593Smuzhiyun 			sg_set_buf(&sgout[0], aad, prot->aad_size);
1513*4882a593Smuzhiyun 
1514*4882a593Smuzhiyun 			*chunk = 0;
1515*4882a593Smuzhiyun 			err = tls_setup_from_iter(sk, out_iov, data_len,
1516*4882a593Smuzhiyun 						  &pages, chunk, &sgout[1],
1517*4882a593Smuzhiyun 						  (n_sgout - 1));
1518*4882a593Smuzhiyun 			if (err < 0)
1519*4882a593Smuzhiyun 				goto fallback_to_reg_recv;
1520*4882a593Smuzhiyun 		} else if (out_sg) {
1521*4882a593Smuzhiyun 			memcpy(sgout, out_sg, n_sgout * sizeof(*sgout));
1522*4882a593Smuzhiyun 		} else {
1523*4882a593Smuzhiyun 			goto fallback_to_reg_recv;
1524*4882a593Smuzhiyun 		}
1525*4882a593Smuzhiyun 	} else {
1526*4882a593Smuzhiyun fallback_to_reg_recv:
1527*4882a593Smuzhiyun 		sgout = sgin;
1528*4882a593Smuzhiyun 		pages = 0;
1529*4882a593Smuzhiyun 		*chunk = data_len;
1530*4882a593Smuzhiyun 		*zc = false;
1531*4882a593Smuzhiyun 	}
1532*4882a593Smuzhiyun 
1533*4882a593Smuzhiyun 	/* Prepare and submit AEAD request */
1534*4882a593Smuzhiyun 	err = tls_do_decryption(sk, skb, sgin, sgout, iv,
1535*4882a593Smuzhiyun 				data_len, aead_req, async);
1536*4882a593Smuzhiyun 	if (err == -EINPROGRESS)
1537*4882a593Smuzhiyun 		return err;
1538*4882a593Smuzhiyun 
1539*4882a593Smuzhiyun 	/* Release the pages in case iov was mapped to pages */
1540*4882a593Smuzhiyun 	for (; pages > 0; pages--)
1541*4882a593Smuzhiyun 		put_page(sg_page(&sgout[pages]));
1542*4882a593Smuzhiyun 
1543*4882a593Smuzhiyun 	kfree(mem);
1544*4882a593Smuzhiyun 	return err;
1545*4882a593Smuzhiyun }
1546*4882a593Smuzhiyun 
decrypt_skb_update(struct sock * sk,struct sk_buff * skb,struct iov_iter * dest,int * chunk,bool * zc,bool async)1547*4882a593Smuzhiyun static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
1548*4882a593Smuzhiyun 			      struct iov_iter *dest, int *chunk, bool *zc,
1549*4882a593Smuzhiyun 			      bool async)
1550*4882a593Smuzhiyun {
1551*4882a593Smuzhiyun 	struct tls_context *tls_ctx = tls_get_ctx(sk);
1552*4882a593Smuzhiyun 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1553*4882a593Smuzhiyun 	struct tls_prot_info *prot = &tls_ctx->prot_info;
1554*4882a593Smuzhiyun 	struct strp_msg *rxm = strp_msg(skb);
1555*4882a593Smuzhiyun 	int pad, err = 0;
1556*4882a593Smuzhiyun 
1557*4882a593Smuzhiyun 	if (!ctx->decrypted) {
1558*4882a593Smuzhiyun 		if (tls_ctx->rx_conf == TLS_HW) {
1559*4882a593Smuzhiyun 			err = tls_device_decrypted(sk, tls_ctx, skb, rxm);
1560*4882a593Smuzhiyun 			if (err < 0)
1561*4882a593Smuzhiyun 				return err;
1562*4882a593Smuzhiyun 		}
1563*4882a593Smuzhiyun 
1564*4882a593Smuzhiyun 		/* Still not decrypted after tls_device */
1565*4882a593Smuzhiyun 		if (!ctx->decrypted) {
1566*4882a593Smuzhiyun 			err = decrypt_internal(sk, skb, dest, NULL, chunk, zc,
1567*4882a593Smuzhiyun 					       async);
1568*4882a593Smuzhiyun 			if (err < 0) {
1569*4882a593Smuzhiyun 				if (err == -EINPROGRESS)
1570*4882a593Smuzhiyun 					tls_advance_record_sn(sk, prot,
1571*4882a593Smuzhiyun 							      &tls_ctx->rx);
1572*4882a593Smuzhiyun 				else if (err == -EBADMSG)
1573*4882a593Smuzhiyun 					TLS_INC_STATS(sock_net(sk),
1574*4882a593Smuzhiyun 						      LINUX_MIB_TLSDECRYPTERROR);
1575*4882a593Smuzhiyun 				return err;
1576*4882a593Smuzhiyun 			}
1577*4882a593Smuzhiyun 		} else {
1578*4882a593Smuzhiyun 			*zc = false;
1579*4882a593Smuzhiyun 		}
1580*4882a593Smuzhiyun 
1581*4882a593Smuzhiyun 		pad = padding_length(ctx, prot, skb);
1582*4882a593Smuzhiyun 		if (pad < 0)
1583*4882a593Smuzhiyun 			return pad;
1584*4882a593Smuzhiyun 
1585*4882a593Smuzhiyun 		rxm->full_len -= pad;
1586*4882a593Smuzhiyun 		rxm->offset += prot->prepend_size;
1587*4882a593Smuzhiyun 		rxm->full_len -= prot->overhead_size;
1588*4882a593Smuzhiyun 		tls_advance_record_sn(sk, prot, &tls_ctx->rx);
1589*4882a593Smuzhiyun 		ctx->decrypted = 1;
1590*4882a593Smuzhiyun 		ctx->saved_data_ready(sk);
1591*4882a593Smuzhiyun 	} else {
1592*4882a593Smuzhiyun 		*zc = false;
1593*4882a593Smuzhiyun 	}
1594*4882a593Smuzhiyun 
1595*4882a593Smuzhiyun 	return err;
1596*4882a593Smuzhiyun }
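/* Worked example for the offset/length fixup in decrypt_skb_update()
 * above (illustrative, assuming AES-GCM-128 under TLS 1.2 and pad = 0):
 * a 1029-byte record has rxm->offset advanced by prepend_size (13) and
 * rxm->full_len reduced by overhead_size (29), leaving exactly the 1000
 * plaintext bytes addressable through the strp_msg.
 */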
1597*4882a593Smuzhiyun 
decrypt_skb(struct sock * sk,struct sk_buff * skb,struct scatterlist * sgout)1598*4882a593Smuzhiyun int decrypt_skb(struct sock *sk, struct sk_buff *skb,
1599*4882a593Smuzhiyun 		struct scatterlist *sgout)
1600*4882a593Smuzhiyun {
1601*4882a593Smuzhiyun 	bool zc = true;
1602*4882a593Smuzhiyun 	int chunk;
1603*4882a593Smuzhiyun 
1604*4882a593Smuzhiyun 	return decrypt_internal(sk, skb, NULL, sgout, &chunk, &zc, false);
1605*4882a593Smuzhiyun }
1606*4882a593Smuzhiyun 
tls_sw_advance_skb(struct sock * sk,struct sk_buff * skb,unsigned int len)1607*4882a593Smuzhiyun static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
1608*4882a593Smuzhiyun 			       unsigned int len)
1609*4882a593Smuzhiyun {
1610*4882a593Smuzhiyun 	struct tls_context *tls_ctx = tls_get_ctx(sk);
1611*4882a593Smuzhiyun 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1612*4882a593Smuzhiyun 
1613*4882a593Smuzhiyun 	if (skb) {
1614*4882a593Smuzhiyun 		struct strp_msg *rxm = strp_msg(skb);
1615*4882a593Smuzhiyun 
1616*4882a593Smuzhiyun 		if (len < rxm->full_len) {
1617*4882a593Smuzhiyun 			rxm->offset += len;
1618*4882a593Smuzhiyun 			rxm->full_len -= len;
1619*4882a593Smuzhiyun 			return false;
1620*4882a593Smuzhiyun 		}
1621*4882a593Smuzhiyun 		consume_skb(skb);
1622*4882a593Smuzhiyun 	}
1623*4882a593Smuzhiyun 
1624*4882a593Smuzhiyun 	/* Finished with message */
1625*4882a593Smuzhiyun 	ctx->recv_pkt = NULL;
1626*4882a593Smuzhiyun 	__strp_unpause(&ctx->strp);
1627*4882a593Smuzhiyun 
1628*4882a593Smuzhiyun 	return true;
1629*4882a593Smuzhiyun }
1630*4882a593Smuzhiyun 
1631*4882a593Smuzhiyun /* This function traverses the rx_list in the tls receive context and copies
1632*4882a593Smuzhiyun  * the decrypted records into the buffer provided by the caller when zero
1633*4882a593Smuzhiyun  * copy is not true. Further, a record is removed from the rx_list if it is
1634*4882a593Smuzhiyun  * not a peek case and the record has been consumed completely.
1635*4882a593Smuzhiyun  */
process_rx_list(struct tls_sw_context_rx * ctx,struct msghdr * msg,u8 * control,bool * cmsg,size_t skip,size_t len,bool zc,bool is_peek)1636*4882a593Smuzhiyun static int process_rx_list(struct tls_sw_context_rx *ctx,
1637*4882a593Smuzhiyun 			   struct msghdr *msg,
1638*4882a593Smuzhiyun 			   u8 *control,
1639*4882a593Smuzhiyun 			   bool *cmsg,
1640*4882a593Smuzhiyun 			   size_t skip,
1641*4882a593Smuzhiyun 			   size_t len,
1642*4882a593Smuzhiyun 			   bool zc,
1643*4882a593Smuzhiyun 			   bool is_peek)
1644*4882a593Smuzhiyun {
1645*4882a593Smuzhiyun 	struct sk_buff *skb = skb_peek(&ctx->rx_list);
1646*4882a593Smuzhiyun 	u8 ctrl = *control;
1647*4882a593Smuzhiyun 	u8 msgc = *cmsg;
1648*4882a593Smuzhiyun 	struct tls_msg *tlm;
1649*4882a593Smuzhiyun 	ssize_t copied = 0;
1650*4882a593Smuzhiyun 
1651*4882a593Smuzhiyun 	/* Set the record type in 'control' if caller didn't pass it */
1652*4882a593Smuzhiyun 	if (!ctrl && skb) {
1653*4882a593Smuzhiyun 		tlm = tls_msg(skb);
1654*4882a593Smuzhiyun 		ctrl = tlm->control;
1655*4882a593Smuzhiyun 	}
1656*4882a593Smuzhiyun 
1657*4882a593Smuzhiyun 	while (skip && skb) {
1658*4882a593Smuzhiyun 		struct strp_msg *rxm = strp_msg(skb);
1659*4882a593Smuzhiyun 		tlm = tls_msg(skb);
1660*4882a593Smuzhiyun 
1661*4882a593Smuzhiyun 		/* Cannot process a record of different type */
1662*4882a593Smuzhiyun 		if (ctrl != tlm->control)
1663*4882a593Smuzhiyun 			return 0;
1664*4882a593Smuzhiyun 
1665*4882a593Smuzhiyun 		if (skip < rxm->full_len)
1666*4882a593Smuzhiyun 			break;
1667*4882a593Smuzhiyun 
1668*4882a593Smuzhiyun 		skip = skip - rxm->full_len;
1669*4882a593Smuzhiyun 		skb = skb_peek_next(skb, &ctx->rx_list);
1670*4882a593Smuzhiyun 	}
1671*4882a593Smuzhiyun 
1672*4882a593Smuzhiyun 	while (len && skb) {
1673*4882a593Smuzhiyun 		struct sk_buff *next_skb;
1674*4882a593Smuzhiyun 		struct strp_msg *rxm = strp_msg(skb);
1675*4882a593Smuzhiyun 		int chunk = min_t(unsigned int, rxm->full_len - skip, len);
1676*4882a593Smuzhiyun 
1677*4882a593Smuzhiyun 		tlm = tls_msg(skb);
1678*4882a593Smuzhiyun 
1679*4882a593Smuzhiyun 		/* Cannot process a record of different type */
1680*4882a593Smuzhiyun 		if (ctrl != tlm->control)
1681*4882a593Smuzhiyun 			return 0;
1682*4882a593Smuzhiyun 
1683*4882a593Smuzhiyun 		/* Set record type if not already done. For a non-data record,
1684*4882a593Smuzhiyun 		 * do not proceed if record type could not be copied.
1685*4882a593Smuzhiyun 		 */
1686*4882a593Smuzhiyun 		if (!msgc) {
1687*4882a593Smuzhiyun 			int cerr = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
1688*4882a593Smuzhiyun 					    sizeof(ctrl), &ctrl);
1689*4882a593Smuzhiyun 			msgc = true;
1690*4882a593Smuzhiyun 			if (ctrl != TLS_RECORD_TYPE_DATA) {
1691*4882a593Smuzhiyun 				if (cerr || msg->msg_flags & MSG_CTRUNC)
1692*4882a593Smuzhiyun 					return -EIO;
1693*4882a593Smuzhiyun 
1694*4882a593Smuzhiyun 				*cmsg = msgc;
1695*4882a593Smuzhiyun 			}
1696*4882a593Smuzhiyun 		}
1697*4882a593Smuzhiyun 
1698*4882a593Smuzhiyun 		if (!zc || (rxm->full_len - skip) > len) {
1699*4882a593Smuzhiyun 			int err = skb_copy_datagram_msg(skb, rxm->offset + skip,
1700*4882a593Smuzhiyun 						    msg, chunk);
1701*4882a593Smuzhiyun 			if (err < 0)
1702*4882a593Smuzhiyun 				return err;
1703*4882a593Smuzhiyun 		}
1704*4882a593Smuzhiyun 
1705*4882a593Smuzhiyun 		len = len - chunk;
1706*4882a593Smuzhiyun 		copied = copied + chunk;
1707*4882a593Smuzhiyun 
1708*4882a593Smuzhiyun 		/* Consume the data from the record if it is a non-peek case */
1709*4882a593Smuzhiyun 		if (!is_peek) {
1710*4882a593Smuzhiyun 			rxm->offset = rxm->offset + chunk;
1711*4882a593Smuzhiyun 			rxm->full_len = rxm->full_len - chunk;
1712*4882a593Smuzhiyun 
1713*4882a593Smuzhiyun 			/* Return if there is unconsumed data in the record */
1714*4882a593Smuzhiyun 			if (rxm->full_len - skip)
1715*4882a593Smuzhiyun 				break;
1716*4882a593Smuzhiyun 		}
1717*4882a593Smuzhiyun 
1718*4882a593Smuzhiyun 		/* The remaining skip-bytes must lie in the 1st record in
1719*4882a593Smuzhiyun 		 * rx_list. So from the 2nd record onwards, 'skip' should be 0.
1720*4882a593Smuzhiyun 		 */
1721*4882a593Smuzhiyun 		skip = 0;
1722*4882a593Smuzhiyun 
1723*4882a593Smuzhiyun 		if (msg)
1724*4882a593Smuzhiyun 			msg->msg_flags |= MSG_EOR;
1725*4882a593Smuzhiyun 
1726*4882a593Smuzhiyun 		next_skb = skb_peek_next(skb, &ctx->rx_list);
1727*4882a593Smuzhiyun 
1728*4882a593Smuzhiyun 		if (!is_peek) {
1729*4882a593Smuzhiyun 			skb_unlink(skb, &ctx->rx_list);
1730*4882a593Smuzhiyun 			consume_skb(skb);
1731*4882a593Smuzhiyun 		}
1732*4882a593Smuzhiyun 
1733*4882a593Smuzhiyun 		skb = next_skb;
1734*4882a593Smuzhiyun 	}
1735*4882a593Smuzhiyun 
1736*4882a593Smuzhiyun 	*control = ctrl;
1737*4882a593Smuzhiyun 	return copied;
1738*4882a593Smuzhiyun }
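/* Annotation (summary, not from the original source): process_rx_list()
 * returns the number of bytes copied (possibly 0 when the first queued
 * record is of a different type than *control) or a negative error, and
 * leaves the record type in *control for the caller's cmsg handling.
 */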
1739*4882a593Smuzhiyun 
tls_sw_recvmsg(struct sock * sk,struct msghdr * msg,size_t len,int nonblock,int flags,int * addr_len)1740*4882a593Smuzhiyun int tls_sw_recvmsg(struct sock *sk,
1741*4882a593Smuzhiyun 		   struct msghdr *msg,
1742*4882a593Smuzhiyun 		   size_t len,
1743*4882a593Smuzhiyun 		   int nonblock,
1744*4882a593Smuzhiyun 		   int flags,
1745*4882a593Smuzhiyun 		   int *addr_len)
1746*4882a593Smuzhiyun {
1747*4882a593Smuzhiyun 	struct tls_context *tls_ctx = tls_get_ctx(sk);
1748*4882a593Smuzhiyun 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1749*4882a593Smuzhiyun 	struct tls_prot_info *prot = &tls_ctx->prot_info;
1750*4882a593Smuzhiyun 	struct sk_psock *psock;
1751*4882a593Smuzhiyun 	unsigned char control = 0;
1752*4882a593Smuzhiyun 	ssize_t decrypted = 0;
1753*4882a593Smuzhiyun 	struct strp_msg *rxm;
1754*4882a593Smuzhiyun 	struct tls_msg *tlm;
1755*4882a593Smuzhiyun 	struct sk_buff *skb;
1756*4882a593Smuzhiyun 	ssize_t copied = 0;
1757*4882a593Smuzhiyun 	bool cmsg = false;
1758*4882a593Smuzhiyun 	int target, err = 0;
1759*4882a593Smuzhiyun 	long timeo;
1760*4882a593Smuzhiyun 	bool is_kvec = iov_iter_is_kvec(&msg->msg_iter);
1761*4882a593Smuzhiyun 	bool is_peek = flags & MSG_PEEK;
1762*4882a593Smuzhiyun 	bool bpf_strp_enabled;
1763*4882a593Smuzhiyun 	int num_async = 0;
1764*4882a593Smuzhiyun 	int pending;
1765*4882a593Smuzhiyun 
1766*4882a593Smuzhiyun 	flags |= nonblock;
1767*4882a593Smuzhiyun 
1768*4882a593Smuzhiyun 	if (unlikely(flags & MSG_ERRQUEUE))
1769*4882a593Smuzhiyun 		return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR);
1770*4882a593Smuzhiyun 
1771*4882a593Smuzhiyun 	psock = sk_psock_get(sk);
1772*4882a593Smuzhiyun 	lock_sock(sk);
1773*4882a593Smuzhiyun 	bpf_strp_enabled = sk_psock_strp_enabled(psock);
1774*4882a593Smuzhiyun 
1775*4882a593Smuzhiyun 	/* Process pending decrypted records. They must be copied (non-zero-copy) */
1776*4882a593Smuzhiyun 	err = process_rx_list(ctx, msg, &control, &cmsg, 0, len, false,
1777*4882a593Smuzhiyun 			      is_peek);
1778*4882a593Smuzhiyun 	if (err < 0) {
1779*4882a593Smuzhiyun 		tls_err_abort(sk, err);
1780*4882a593Smuzhiyun 		goto end;
1781*4882a593Smuzhiyun 	} else {
1782*4882a593Smuzhiyun 		copied = err;
1783*4882a593Smuzhiyun 	}
1784*4882a593Smuzhiyun 
1785*4882a593Smuzhiyun 	if (len <= copied)
1786*4882a593Smuzhiyun 		goto recv_end;
1787*4882a593Smuzhiyun 
1788*4882a593Smuzhiyun 	target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
1789*4882a593Smuzhiyun 	len = len - copied;
1790*4882a593Smuzhiyun 	timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
1791*4882a593Smuzhiyun 
1792*4882a593Smuzhiyun 	while (len && (decrypted + copied < target || ctx->recv_pkt)) {
1793*4882a593Smuzhiyun 		bool retain_skb = false;
1794*4882a593Smuzhiyun 		bool zc = false;
1795*4882a593Smuzhiyun 		int to_decrypt;
1796*4882a593Smuzhiyun 		int chunk = 0;
1797*4882a593Smuzhiyun 		bool async_capable;
1798*4882a593Smuzhiyun 		bool async = false;
1799*4882a593Smuzhiyun 
1800*4882a593Smuzhiyun 		skb = tls_wait_data(sk, psock, flags & MSG_DONTWAIT, timeo, &err);
1801*4882a593Smuzhiyun 		if (!skb) {
1802*4882a593Smuzhiyun 			if (psock) {
1803*4882a593Smuzhiyun 				int ret = __tcp_bpf_recvmsg(sk, psock,
1804*4882a593Smuzhiyun 							    msg, len, flags);
1805*4882a593Smuzhiyun 
1806*4882a593Smuzhiyun 				if (ret > 0) {
1807*4882a593Smuzhiyun 					decrypted += ret;
1808*4882a593Smuzhiyun 					len -= ret;
1809*4882a593Smuzhiyun 					continue;
1810*4882a593Smuzhiyun 				}
1811*4882a593Smuzhiyun 			}
1812*4882a593Smuzhiyun 			goto recv_end;
1813*4882a593Smuzhiyun 		} else {
1814*4882a593Smuzhiyun 			tlm = tls_msg(skb);
1815*4882a593Smuzhiyun 			if (prot->version == TLS_1_3_VERSION)
1816*4882a593Smuzhiyun 				tlm->control = 0;
1817*4882a593Smuzhiyun 			else
1818*4882a593Smuzhiyun 				tlm->control = ctx->control;
1819*4882a593Smuzhiyun 		}
1820*4882a593Smuzhiyun 
1821*4882a593Smuzhiyun 		rxm = strp_msg(skb);
1822*4882a593Smuzhiyun 
1823*4882a593Smuzhiyun 		to_decrypt = rxm->full_len - prot->overhead_size;
1824*4882a593Smuzhiyun 
1825*4882a593Smuzhiyun 		if (to_decrypt <= len && !is_kvec && !is_peek &&
1826*4882a593Smuzhiyun 		    ctx->control == TLS_RECORD_TYPE_DATA &&
1827*4882a593Smuzhiyun 		    prot->version != TLS_1_3_VERSION &&
1828*4882a593Smuzhiyun 		    !bpf_strp_enabled)
1829*4882a593Smuzhiyun 			zc = true;
1830*4882a593Smuzhiyun 
1831*4882a593Smuzhiyun 		/* Do not use async mode if record is non-data */
1832*4882a593Smuzhiyun 		if (ctx->control == TLS_RECORD_TYPE_DATA && !bpf_strp_enabled)
1833*4882a593Smuzhiyun 			async_capable = ctx->async_capable;
1834*4882a593Smuzhiyun 		else
1835*4882a593Smuzhiyun 			async_capable = false;
1836*4882a593Smuzhiyun 
1837*4882a593Smuzhiyun 		err = decrypt_skb_update(sk, skb, &msg->msg_iter,
1838*4882a593Smuzhiyun 					 &chunk, &zc, async_capable);
1839*4882a593Smuzhiyun 		if (err < 0 && err != -EINPROGRESS) {
1840*4882a593Smuzhiyun 			tls_err_abort(sk, -EBADMSG);
1841*4882a593Smuzhiyun 			goto recv_end;
1842*4882a593Smuzhiyun 		}
1843*4882a593Smuzhiyun 
1844*4882a593Smuzhiyun 		if (err == -EINPROGRESS) {
1845*4882a593Smuzhiyun 			async = true;
1846*4882a593Smuzhiyun 			num_async++;
1847*4882a593Smuzhiyun 		} else if (prot->version == TLS_1_3_VERSION) {
1848*4882a593Smuzhiyun 			tlm->control = ctx->control;
1849*4882a593Smuzhiyun 		}
1850*4882a593Smuzhiyun 
1851*4882a593Smuzhiyun 		/* If the type of records being processed is not known yet,
1852*4882a593Smuzhiyun 		 * set it to the record type just dequeued. If it is already
1853*4882a593Smuzhiyun 		 * known but does not match the record type just dequeued, go
1854*4882a593Smuzhiyun 		 * to end. We always have the record type here since, for
1855*4882a593Smuzhiyun 		 * tls1.2, it is known right after the record is dequeued from
1856*4882a593Smuzhiyun 		 * the stream parser. For tls1.3, we disable async.
1857*4882a593Smuzhiyun 		 */
1858*4882a593Smuzhiyun 
1859*4882a593Smuzhiyun 		if (!control)
1860*4882a593Smuzhiyun 			control = tlm->control;
1861*4882a593Smuzhiyun 		else if (control != tlm->control)
1862*4882a593Smuzhiyun 			goto recv_end;
1863*4882a593Smuzhiyun 
1864*4882a593Smuzhiyun 		if (!cmsg) {
1865*4882a593Smuzhiyun 			int cerr;
1866*4882a593Smuzhiyun 
1867*4882a593Smuzhiyun 			cerr = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
1868*4882a593Smuzhiyun 					sizeof(control), &control);
1869*4882a593Smuzhiyun 			cmsg = true;
1870*4882a593Smuzhiyun 			if (control != TLS_RECORD_TYPE_DATA) {
1871*4882a593Smuzhiyun 				if (cerr || msg->msg_flags & MSG_CTRUNC) {
1872*4882a593Smuzhiyun 					err = -EIO;
1873*4882a593Smuzhiyun 					goto recv_end;
1874*4882a593Smuzhiyun 				}
1875*4882a593Smuzhiyun 			}
1876*4882a593Smuzhiyun 		}
1877*4882a593Smuzhiyun 
1878*4882a593Smuzhiyun 		if (async)
1879*4882a593Smuzhiyun 			goto pick_next_record;
1880*4882a593Smuzhiyun 
1881*4882a593Smuzhiyun 		if (!zc) {
1882*4882a593Smuzhiyun 			if (bpf_strp_enabled) {
1883*4882a593Smuzhiyun 				err = sk_psock_tls_strp_read(psock, skb);
1884*4882a593Smuzhiyun 				if (err != __SK_PASS) {
1885*4882a593Smuzhiyun 					rxm->offset = rxm->offset + rxm->full_len;
1886*4882a593Smuzhiyun 					rxm->full_len = 0;
1887*4882a593Smuzhiyun 					if (err == __SK_DROP)
1888*4882a593Smuzhiyun 						consume_skb(skb);
1889*4882a593Smuzhiyun 					ctx->recv_pkt = NULL;
1890*4882a593Smuzhiyun 					__strp_unpause(&ctx->strp);
1891*4882a593Smuzhiyun 					continue;
1892*4882a593Smuzhiyun 				}
1893*4882a593Smuzhiyun 			}
1894*4882a593Smuzhiyun 
1895*4882a593Smuzhiyun 			if (rxm->full_len > len) {
1896*4882a593Smuzhiyun 				retain_skb = true;
1897*4882a593Smuzhiyun 				chunk = len;
1898*4882a593Smuzhiyun 			} else {
1899*4882a593Smuzhiyun 				chunk = rxm->full_len;
1900*4882a593Smuzhiyun 			}
1901*4882a593Smuzhiyun 
1902*4882a593Smuzhiyun 			err = skb_copy_datagram_msg(skb, rxm->offset,
1903*4882a593Smuzhiyun 						    msg, chunk);
1904*4882a593Smuzhiyun 			if (err < 0)
1905*4882a593Smuzhiyun 				goto recv_end;
1906*4882a593Smuzhiyun 
1907*4882a593Smuzhiyun 			if (!is_peek) {
1908*4882a593Smuzhiyun 				rxm->offset = rxm->offset + chunk;
1909*4882a593Smuzhiyun 				rxm->full_len = rxm->full_len - chunk;
1910*4882a593Smuzhiyun 			}
1911*4882a593Smuzhiyun 		}
1912*4882a593Smuzhiyun 
1913*4882a593Smuzhiyun pick_next_record:
1914*4882a593Smuzhiyun 		if (chunk > len)
1915*4882a593Smuzhiyun 			chunk = len;
1916*4882a593Smuzhiyun 
1917*4882a593Smuzhiyun 		decrypted += chunk;
1918*4882a593Smuzhiyun 		len -= chunk;
1919*4882a593Smuzhiyun 
1920*4882a593Smuzhiyun 		/* For async or peek case, queue the current skb */
1921*4882a593Smuzhiyun 		if (async || is_peek || retain_skb) {
1922*4882a593Smuzhiyun 			skb_queue_tail(&ctx->rx_list, skb);
1923*4882a593Smuzhiyun 			skb = NULL;
1924*4882a593Smuzhiyun 		}
1925*4882a593Smuzhiyun 
1926*4882a593Smuzhiyun 		if (tls_sw_advance_skb(sk, skb, chunk)) {
1927*4882a593Smuzhiyun 			/* Return full control message to
1928*4882a593Smuzhiyun 			 * userspace before trying to parse
1929*4882a593Smuzhiyun 			 * another message type
1930*4882a593Smuzhiyun 			 */
1931*4882a593Smuzhiyun 			msg->msg_flags |= MSG_EOR;
1932*4882a593Smuzhiyun 			if (control != TLS_RECORD_TYPE_DATA)
1933*4882a593Smuzhiyun 				goto recv_end;
1934*4882a593Smuzhiyun 		} else {
1935*4882a593Smuzhiyun 			break;
1936*4882a593Smuzhiyun 		}
1937*4882a593Smuzhiyun 	}
1938*4882a593Smuzhiyun 
1939*4882a593Smuzhiyun recv_end:
1940*4882a593Smuzhiyun 	if (num_async) {
1941*4882a593Smuzhiyun 		/* Wait for all previously submitted records to be decrypted */
1942*4882a593Smuzhiyun 		spin_lock_bh(&ctx->decrypt_compl_lock);
1943*4882a593Smuzhiyun 		ctx->async_notify = true;
1944*4882a593Smuzhiyun 		pending = atomic_read(&ctx->decrypt_pending);
1945*4882a593Smuzhiyun 		spin_unlock_bh(&ctx->decrypt_compl_lock);
1946*4882a593Smuzhiyun 		if (pending) {
1947*4882a593Smuzhiyun 			err = crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
1948*4882a593Smuzhiyun 			if (err) {
1949*4882a593Smuzhiyun 				/* one of the async decrypts failed */
1950*4882a593Smuzhiyun 				tls_err_abort(sk, err);
1951*4882a593Smuzhiyun 				copied = 0;
1952*4882a593Smuzhiyun 				decrypted = 0;
1953*4882a593Smuzhiyun 				goto end;
1954*4882a593Smuzhiyun 			}
1955*4882a593Smuzhiyun 		} else {
1956*4882a593Smuzhiyun 			reinit_completion(&ctx->async_wait.completion);
1957*4882a593Smuzhiyun 		}
1958*4882a593Smuzhiyun 
1959*4882a593Smuzhiyun 		/* There can be no concurrent accesses, since we have no
1960*4882a593Smuzhiyun 		 * pending decrypt operations
1961*4882a593Smuzhiyun 		 */
1962*4882a593Smuzhiyun 		WRITE_ONCE(ctx->async_notify, false);
1963*4882a593Smuzhiyun 
1964*4882a593Smuzhiyun 		/* Drain records from the rx_list & copy if required */
1965*4882a593Smuzhiyun 		if (is_peek || is_kvec)
1966*4882a593Smuzhiyun 			err = process_rx_list(ctx, msg, &control, &cmsg, copied,
1967*4882a593Smuzhiyun 					      decrypted, false, is_peek);
1968*4882a593Smuzhiyun 		else
1969*4882a593Smuzhiyun 			err = process_rx_list(ctx, msg, &control, &cmsg, 0,
1970*4882a593Smuzhiyun 					      decrypted, true, is_peek);
1971*4882a593Smuzhiyun 		if (err < 0) {
1972*4882a593Smuzhiyun 			tls_err_abort(sk, err);
1973*4882a593Smuzhiyun 			copied = 0;
1974*4882a593Smuzhiyun 			goto end;
1975*4882a593Smuzhiyun 		}
1976*4882a593Smuzhiyun 	}
1977*4882a593Smuzhiyun 
1978*4882a593Smuzhiyun 	copied += decrypted;
1979*4882a593Smuzhiyun 
1980*4882a593Smuzhiyun end:
1981*4882a593Smuzhiyun 	release_sock(sk);
1982*4882a593Smuzhiyun 	if (psock)
1983*4882a593Smuzhiyun 		sk_psock_put(sk, psock);
1984*4882a593Smuzhiyun 	return copied ? : err;
1985*4882a593Smuzhiyun }
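/* Userspace usage sketch (illustrative only, not from the original source):
 * non-data records are surfaced through a TLS_GET_RECORD_TYPE control
 * message, e.g.:
 *
 *	char buf[64], cbuf[CMSG_SPACE(1)], record_type;
 *	struct iovec iov = { buf, sizeof(buf) };
 *	struct msghdr msg = { .msg_iov = &iov, .msg_iovlen = 1,
 *			      .msg_control = cbuf,
 *			      .msg_controllen = sizeof(cbuf) };
 *	struct cmsghdr *c;
 *
 *	recvmsg(fd, &msg, 0);
 *	for (c = CMSG_FIRSTHDR(&msg); c; c = CMSG_NXTHDR(&msg, c))
 *		if (c->cmsg_level == SOL_TLS &&
 *		    c->cmsg_type == TLS_GET_RECORD_TYPE)
 *			record_type = *CMSG_DATA(c);
 */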
1986*4882a593Smuzhiyun 
tls_sw_splice_read(struct socket * sock,loff_t * ppos,struct pipe_inode_info * pipe,size_t len,unsigned int flags)1987*4882a593Smuzhiyun ssize_t tls_sw_splice_read(struct socket *sock,  loff_t *ppos,
1988*4882a593Smuzhiyun 			   struct pipe_inode_info *pipe,
1989*4882a593Smuzhiyun 			   size_t len, unsigned int flags)
1990*4882a593Smuzhiyun {
1991*4882a593Smuzhiyun 	struct tls_context *tls_ctx = tls_get_ctx(sock->sk);
1992*4882a593Smuzhiyun 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1993*4882a593Smuzhiyun 	struct strp_msg *rxm = NULL;
1994*4882a593Smuzhiyun 	struct sock *sk = sock->sk;
1995*4882a593Smuzhiyun 	struct sk_buff *skb;
1996*4882a593Smuzhiyun 	ssize_t copied = 0;
1997*4882a593Smuzhiyun 	int err = 0;
1998*4882a593Smuzhiyun 	long timeo;
1999*4882a593Smuzhiyun 	int chunk;
2000*4882a593Smuzhiyun 	bool zc = false;
2001*4882a593Smuzhiyun 
2002*4882a593Smuzhiyun 	lock_sock(sk);
2003*4882a593Smuzhiyun 
2004*4882a593Smuzhiyun 	timeo = sock_rcvtimeo(sk, flags & SPLICE_F_NONBLOCK);
2005*4882a593Smuzhiyun 
2006*4882a593Smuzhiyun 	skb = tls_wait_data(sk, NULL, flags & SPLICE_F_NONBLOCK, timeo, &err);
2007*4882a593Smuzhiyun 	if (!skb)
2008*4882a593Smuzhiyun 		goto splice_read_end;
2009*4882a593Smuzhiyun 
2010*4882a593Smuzhiyun 	err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc, false);
2011*4882a593Smuzhiyun 	if (err < 0) {
2012*4882a593Smuzhiyun 		tls_err_abort(sk, -EBADMSG);
2013*4882a593Smuzhiyun 		goto splice_read_end;
2014*4882a593Smuzhiyun 	}
2015*4882a593Smuzhiyun 
2016*4882a593Smuzhiyun 	/* splice does not support reading control messages */
2017*4882a593Smuzhiyun 	if (ctx->control != TLS_RECORD_TYPE_DATA) {
2018*4882a593Smuzhiyun 		err = -EINVAL;
2019*4882a593Smuzhiyun 		goto splice_read_end;
2020*4882a593Smuzhiyun 	}
2021*4882a593Smuzhiyun 
2022*4882a593Smuzhiyun 	rxm = strp_msg(skb);
2023*4882a593Smuzhiyun 
2024*4882a593Smuzhiyun 	chunk = min_t(unsigned int, rxm->full_len, len);
2025*4882a593Smuzhiyun 	copied = skb_splice_bits(skb, sk, rxm->offset, pipe, chunk, flags);
2026*4882a593Smuzhiyun 	if (copied < 0)
2027*4882a593Smuzhiyun 		goto splice_read_end;
2028*4882a593Smuzhiyun 
2029*4882a593Smuzhiyun 	if (likely(!(flags & MSG_PEEK)))
2030*4882a593Smuzhiyun 		tls_sw_advance_skb(sk, skb, copied);
2031*4882a593Smuzhiyun 
2032*4882a593Smuzhiyun splice_read_end:
2033*4882a593Smuzhiyun 	release_sock(sk);
2034*4882a593Smuzhiyun 	return copied ? : err;
2035*4882a593Smuzhiyun }
2036*4882a593Smuzhiyun 
tls_sw_stream_read(const struct sock * sk)2037*4882a593Smuzhiyun bool tls_sw_stream_read(const struct sock *sk)
2038*4882a593Smuzhiyun {
2039*4882a593Smuzhiyun 	struct tls_context *tls_ctx = tls_get_ctx(sk);
2040*4882a593Smuzhiyun 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2041*4882a593Smuzhiyun 	bool ingress_empty = true;
2042*4882a593Smuzhiyun 	struct sk_psock *psock;
2043*4882a593Smuzhiyun 
2044*4882a593Smuzhiyun 	rcu_read_lock();
2045*4882a593Smuzhiyun 	psock = sk_psock(sk);
2046*4882a593Smuzhiyun 	if (psock)
2047*4882a593Smuzhiyun 		ingress_empty = list_empty(&psock->ingress_msg);
2048*4882a593Smuzhiyun 	rcu_read_unlock();
2049*4882a593Smuzhiyun 
2050*4882a593Smuzhiyun 	return !ingress_empty || ctx->recv_pkt ||
2051*4882a593Smuzhiyun 		!skb_queue_empty(&ctx->rx_list);
2052*4882a593Smuzhiyun }
2053*4882a593Smuzhiyun 
tls_read_size(struct strparser * strp,struct sk_buff * skb)2054*4882a593Smuzhiyun static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
2055*4882a593Smuzhiyun {
2056*4882a593Smuzhiyun 	struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
2057*4882a593Smuzhiyun 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2058*4882a593Smuzhiyun 	struct tls_prot_info *prot = &tls_ctx->prot_info;
2059*4882a593Smuzhiyun 	char header[TLS_HEADER_SIZE + MAX_IV_SIZE];
2060*4882a593Smuzhiyun 	struct strp_msg *rxm = strp_msg(skb);
2061*4882a593Smuzhiyun 	size_t cipher_overhead;
2062*4882a593Smuzhiyun 	size_t data_len = 0;
2063*4882a593Smuzhiyun 	int ret;
2064*4882a593Smuzhiyun 
2065*4882a593Smuzhiyun 	/* Verify that we have a full TLS header, or wait for more data */
2066*4882a593Smuzhiyun 	if (rxm->offset + prot->prepend_size > skb->len)
2067*4882a593Smuzhiyun 		return 0;
2068*4882a593Smuzhiyun 
2069*4882a593Smuzhiyun 	/* Sanity-check size of on-stack buffer. */
2070*4882a593Smuzhiyun 	if (WARN_ON(prot->prepend_size > sizeof(header))) {
2071*4882a593Smuzhiyun 		ret = -EINVAL;
2072*4882a593Smuzhiyun 		goto read_failure;
2073*4882a593Smuzhiyun 	}
2074*4882a593Smuzhiyun 
2075*4882a593Smuzhiyun 	/* Linearize header to local buffer */
2076*4882a593Smuzhiyun 	ret = skb_copy_bits(skb, rxm->offset, header, prot->prepend_size);
2077*4882a593Smuzhiyun 
2078*4882a593Smuzhiyun 	if (ret < 0)
2079*4882a593Smuzhiyun 		goto read_failure;
2080*4882a593Smuzhiyun 
2081*4882a593Smuzhiyun 	ctx->control = header[0];
2082*4882a593Smuzhiyun 
2083*4882a593Smuzhiyun 	data_len = ((header[4] & 0xFF) | (header[3] << 8));
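	/* e.g. (illustrative): an application-data record header of
	 * 17 03 03 00 2a parses as ctx->control = 0x17 and
	 * data_len = 0x002a = 42 ciphertext bytes.
	 */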
2084*4882a593Smuzhiyun 
2085*4882a593Smuzhiyun 	cipher_overhead = prot->tag_size;
2086*4882a593Smuzhiyun 	if (prot->version != TLS_1_3_VERSION)
2087*4882a593Smuzhiyun 		cipher_overhead += prot->iv_size;
2088*4882a593Smuzhiyun 
2089*4882a593Smuzhiyun 	if (data_len > TLS_MAX_PAYLOAD_SIZE + cipher_overhead +
2090*4882a593Smuzhiyun 	    prot->tail_size) {
2091*4882a593Smuzhiyun 		ret = -EMSGSIZE;
2092*4882a593Smuzhiyun 		goto read_failure;
2093*4882a593Smuzhiyun 	}
2094*4882a593Smuzhiyun 	if (data_len < cipher_overhead) {
2095*4882a593Smuzhiyun 		ret = -EBADMSG;
2096*4882a593Smuzhiyun 		goto read_failure;
2097*4882a593Smuzhiyun 	}
2098*4882a593Smuzhiyun 
2099*4882a593Smuzhiyun 	/* Note that both TLS1.3 and TLS1.2 use TLS_1_2 version here */
2100*4882a593Smuzhiyun 	if (header[1] != TLS_1_2_VERSION_MINOR ||
2101*4882a593Smuzhiyun 	    header[2] != TLS_1_2_VERSION_MAJOR) {
2102*4882a593Smuzhiyun 		ret = -EINVAL;
2103*4882a593Smuzhiyun 		goto read_failure;
2104*4882a593Smuzhiyun 	}
2105*4882a593Smuzhiyun 
2106*4882a593Smuzhiyun 	tls_device_rx_resync_new_rec(strp->sk, data_len + TLS_HEADER_SIZE,
2107*4882a593Smuzhiyun 				     TCP_SKB_CB(skb)->seq + rxm->offset);
2108*4882a593Smuzhiyun 	return data_len + TLS_HEADER_SIZE;
2109*4882a593Smuzhiyun 
2110*4882a593Smuzhiyun read_failure:
2111*4882a593Smuzhiyun 	tls_err_abort(strp->sk, ret);
2112*4882a593Smuzhiyun 
2113*4882a593Smuzhiyun 	return ret;
2114*4882a593Smuzhiyun }
2115*4882a593Smuzhiyun 
tls_queue(struct strparser * strp,struct sk_buff * skb)2116*4882a593Smuzhiyun static void tls_queue(struct strparser *strp, struct sk_buff *skb)
2117*4882a593Smuzhiyun {
2118*4882a593Smuzhiyun 	struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
2119*4882a593Smuzhiyun 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2120*4882a593Smuzhiyun 
2121*4882a593Smuzhiyun 	ctx->decrypted = 0;
2122*4882a593Smuzhiyun 
2123*4882a593Smuzhiyun 	ctx->recv_pkt = skb;
2124*4882a593Smuzhiyun 	strp_pause(strp);
2125*4882a593Smuzhiyun 
2126*4882a593Smuzhiyun 	ctx->saved_data_ready(strp->sk);
2127*4882a593Smuzhiyun }
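/* Annotation (not from the original source): the strparser hands over one
 * complete record at a time; strp_pause() above keeps it paused until
 * tls_sw_advance_skb() or the recvmsg path calls __strp_unpause(), so
 * ctx->recv_pkt only ever holds a single pending record.
 */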
2128*4882a593Smuzhiyun 
tls_data_ready(struct sock * sk)2129*4882a593Smuzhiyun static void tls_data_ready(struct sock *sk)
2130*4882a593Smuzhiyun {
2131*4882a593Smuzhiyun 	struct tls_context *tls_ctx = tls_get_ctx(sk);
2132*4882a593Smuzhiyun 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2133*4882a593Smuzhiyun 	struct sk_psock *psock;
2134*4882a593Smuzhiyun 
2135*4882a593Smuzhiyun 	strp_data_ready(&ctx->strp);
2136*4882a593Smuzhiyun 
2137*4882a593Smuzhiyun 	psock = sk_psock_get(sk);
2138*4882a593Smuzhiyun 	if (psock) {
2139*4882a593Smuzhiyun 		if (!list_empty(&psock->ingress_msg))
2140*4882a593Smuzhiyun 			ctx->saved_data_ready(sk);
2141*4882a593Smuzhiyun 		sk_psock_put(sk, psock);
2142*4882a593Smuzhiyun 	}
2143*4882a593Smuzhiyun }
2144*4882a593Smuzhiyun 
tls_sw_cancel_work_tx(struct tls_context * tls_ctx)2145*4882a593Smuzhiyun void tls_sw_cancel_work_tx(struct tls_context *tls_ctx)
2146*4882a593Smuzhiyun {
2147*4882a593Smuzhiyun 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
2148*4882a593Smuzhiyun 
2149*4882a593Smuzhiyun 	set_bit(BIT_TX_CLOSING, &ctx->tx_bitmask);
2150*4882a593Smuzhiyun 	set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask);
2151*4882a593Smuzhiyun 	cancel_delayed_work_sync(&ctx->tx_work.work);
2152*4882a593Smuzhiyun }
2153*4882a593Smuzhiyun 
tls_sw_release_resources_tx(struct sock * sk)2154*4882a593Smuzhiyun void tls_sw_release_resources_tx(struct sock *sk)
2155*4882a593Smuzhiyun {
2156*4882a593Smuzhiyun 	struct tls_context *tls_ctx = tls_get_ctx(sk);
2157*4882a593Smuzhiyun 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
2158*4882a593Smuzhiyun 	struct tls_rec *rec, *tmp;
2159*4882a593Smuzhiyun 	int pending;
2160*4882a593Smuzhiyun 
2161*4882a593Smuzhiyun 	/* Wait for any pending async encryptions to complete */
2162*4882a593Smuzhiyun 	spin_lock_bh(&ctx->encrypt_compl_lock);
2163*4882a593Smuzhiyun 	ctx->async_notify = true;
2164*4882a593Smuzhiyun 	pending = atomic_read(&ctx->encrypt_pending);
2165*4882a593Smuzhiyun 	spin_unlock_bh(&ctx->encrypt_compl_lock);
2166*4882a593Smuzhiyun 
2167*4882a593Smuzhiyun 	if (pending)
2168*4882a593Smuzhiyun 		crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
2169*4882a593Smuzhiyun 
2170*4882a593Smuzhiyun 	tls_tx_records(sk, -1);
2171*4882a593Smuzhiyun 
2172*4882a593Smuzhiyun 	/* Free up un-sent records in tx_list. First, free
2173*4882a593Smuzhiyun 	 * the partially sent record if any at head of tx_list.
2174*4882a593Smuzhiyun 	 */
2175*4882a593Smuzhiyun 	if (tls_ctx->partially_sent_record) {
2176*4882a593Smuzhiyun 		tls_free_partial_record(sk, tls_ctx);
2177*4882a593Smuzhiyun 		rec = list_first_entry(&ctx->tx_list,
2178*4882a593Smuzhiyun 				       struct tls_rec, list);
2179*4882a593Smuzhiyun 		list_del(&rec->list);
2180*4882a593Smuzhiyun 		sk_msg_free(sk, &rec->msg_plaintext);
2181*4882a593Smuzhiyun 		kfree(rec);
2182*4882a593Smuzhiyun 	}
2183*4882a593Smuzhiyun 
2184*4882a593Smuzhiyun 	list_for_each_entry_safe(rec, tmp, &ctx->tx_list, list) {
2185*4882a593Smuzhiyun 		list_del(&rec->list);
2186*4882a593Smuzhiyun 		sk_msg_free(sk, &rec->msg_encrypted);
2187*4882a593Smuzhiyun 		sk_msg_free(sk, &rec->msg_plaintext);
2188*4882a593Smuzhiyun 		kfree(rec);
2189*4882a593Smuzhiyun 	}
2190*4882a593Smuzhiyun 
2191*4882a593Smuzhiyun 	crypto_free_aead(ctx->aead_send);
2192*4882a593Smuzhiyun 	tls_free_open_rec(sk);
2193*4882a593Smuzhiyun }
2194*4882a593Smuzhiyun 
tls_sw_free_ctx_tx(struct tls_context * tls_ctx)2195*4882a593Smuzhiyun void tls_sw_free_ctx_tx(struct tls_context *tls_ctx)
2196*4882a593Smuzhiyun {
2197*4882a593Smuzhiyun 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
2198*4882a593Smuzhiyun 
2199*4882a593Smuzhiyun 	kfree(ctx);
2200*4882a593Smuzhiyun }
2201*4882a593Smuzhiyun 
tls_sw_release_resources_rx(struct sock * sk)2202*4882a593Smuzhiyun void tls_sw_release_resources_rx(struct sock *sk)
2203*4882a593Smuzhiyun {
2204*4882a593Smuzhiyun 	struct tls_context *tls_ctx = tls_get_ctx(sk);
2205*4882a593Smuzhiyun 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2206*4882a593Smuzhiyun 
2207*4882a593Smuzhiyun 	kfree(tls_ctx->rx.rec_seq);
2208*4882a593Smuzhiyun 	kfree(tls_ctx->rx.iv);
2209*4882a593Smuzhiyun 
2210*4882a593Smuzhiyun 	if (ctx->aead_recv) {
2211*4882a593Smuzhiyun 		kfree_skb(ctx->recv_pkt);
2212*4882a593Smuzhiyun 		ctx->recv_pkt = NULL;
2213*4882a593Smuzhiyun 		skb_queue_purge(&ctx->rx_list);
2214*4882a593Smuzhiyun 		crypto_free_aead(ctx->aead_recv);
2215*4882a593Smuzhiyun 		strp_stop(&ctx->strp);
2216*4882a593Smuzhiyun 		/* If tls_sw_strparser_arm() was not called (cleanup paths)
2217*4882a593Smuzhiyun 		 * we still want to strp_stop(), but sk->sk_data_ready was
2218*4882a593Smuzhiyun 		 * never swapped.
2219*4882a593Smuzhiyun 		 */
2220*4882a593Smuzhiyun 		if (ctx->saved_data_ready) {
2221*4882a593Smuzhiyun 			write_lock_bh(&sk->sk_callback_lock);
2222*4882a593Smuzhiyun 			sk->sk_data_ready = ctx->saved_data_ready;
2223*4882a593Smuzhiyun 			write_unlock_bh(&sk->sk_callback_lock);
2224*4882a593Smuzhiyun 		}
2225*4882a593Smuzhiyun 	}
2226*4882a593Smuzhiyun }
2227*4882a593Smuzhiyun 
tls_sw_strparser_done(struct tls_context * tls_ctx)2228*4882a593Smuzhiyun void tls_sw_strparser_done(struct tls_context *tls_ctx)
2229*4882a593Smuzhiyun {
2230*4882a593Smuzhiyun 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2231*4882a593Smuzhiyun 
2232*4882a593Smuzhiyun 	strp_done(&ctx->strp);
2233*4882a593Smuzhiyun }
2234*4882a593Smuzhiyun 
tls_sw_free_ctx_rx(struct tls_context * tls_ctx)2235*4882a593Smuzhiyun void tls_sw_free_ctx_rx(struct tls_context *tls_ctx)
2236*4882a593Smuzhiyun {
2237*4882a593Smuzhiyun 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2238*4882a593Smuzhiyun 
2239*4882a593Smuzhiyun 	kfree(ctx);
2240*4882a593Smuzhiyun }
2241*4882a593Smuzhiyun 
tls_sw_free_resources_rx(struct sock * sk)2242*4882a593Smuzhiyun void tls_sw_free_resources_rx(struct sock *sk)
2243*4882a593Smuzhiyun {
2244*4882a593Smuzhiyun 	struct tls_context *tls_ctx = tls_get_ctx(sk);
2245*4882a593Smuzhiyun 
2246*4882a593Smuzhiyun 	tls_sw_release_resources_rx(sk);
2247*4882a593Smuzhiyun 	tls_sw_free_ctx_rx(tls_ctx);
2248*4882a593Smuzhiyun }
2249*4882a593Smuzhiyun 
2250*4882a593Smuzhiyun /* The work handler to transmit the encrypted records in tx_list */
tx_work_handler(struct work_struct * work)2251*4882a593Smuzhiyun static void tx_work_handler(struct work_struct *work)
2252*4882a593Smuzhiyun {
2253*4882a593Smuzhiyun 	struct delayed_work *delayed_work = to_delayed_work(work);
2254*4882a593Smuzhiyun 	struct tx_work *tx_work = container_of(delayed_work,
2255*4882a593Smuzhiyun 					       struct tx_work, work);
2256*4882a593Smuzhiyun 	struct sock *sk = tx_work->sk;
2257*4882a593Smuzhiyun 	struct tls_context *tls_ctx = tls_get_ctx(sk);
2258*4882a593Smuzhiyun 	struct tls_sw_context_tx *ctx;
2259*4882a593Smuzhiyun 
2260*4882a593Smuzhiyun 	if (unlikely(!tls_ctx))
2261*4882a593Smuzhiyun 		return;
2262*4882a593Smuzhiyun 
2263*4882a593Smuzhiyun 	ctx = tls_sw_ctx_tx(tls_ctx);
2264*4882a593Smuzhiyun 	if (test_bit(BIT_TX_CLOSING, &ctx->tx_bitmask))
2265*4882a593Smuzhiyun 		return;
2266*4882a593Smuzhiyun 
2267*4882a593Smuzhiyun 	if (!test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask))
2268*4882a593Smuzhiyun 		return;
2269*4882a593Smuzhiyun 	mutex_lock(&tls_ctx->tx_lock);
2270*4882a593Smuzhiyun 	lock_sock(sk);
2271*4882a593Smuzhiyun 	tls_tx_records(sk, -1);
2272*4882a593Smuzhiyun 	release_sock(sk);
2273*4882a593Smuzhiyun 	mutex_unlock(&tls_ctx->tx_lock);
2274*4882a593Smuzhiyun }
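/* Annotation (summary, not from the original source): BIT_TX_SCHEDULED is
 * set when tls_sw_write_space() below schedules this worker and is cleared
 * either here or by the inline transmit paths in sendmsg/sendpage
 * (test_and_clear_bit plus cancel_delayed_work), so a record is never
 * transmitted twice.
 */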
2275*4882a593Smuzhiyun 
tls_sw_write_space(struct sock * sk,struct tls_context * ctx)2276*4882a593Smuzhiyun void tls_sw_write_space(struct sock *sk, struct tls_context *ctx)
2277*4882a593Smuzhiyun {
2278*4882a593Smuzhiyun 	struct tls_sw_context_tx *tx_ctx = tls_sw_ctx_tx(ctx);
2279*4882a593Smuzhiyun 
2280*4882a593Smuzhiyun 	/* Schedule the transmission if tx list is ready */
2281*4882a593Smuzhiyun 	if (is_tx_ready(tx_ctx) &&
2282*4882a593Smuzhiyun 	    !test_and_set_bit(BIT_TX_SCHEDULED, &tx_ctx->tx_bitmask))
2283*4882a593Smuzhiyun 		schedule_delayed_work(&tx_ctx->tx_work.work, 0);
2284*4882a593Smuzhiyun }
2285*4882a593Smuzhiyun 
tls_sw_strparser_arm(struct sock * sk,struct tls_context * tls_ctx)2286*4882a593Smuzhiyun void tls_sw_strparser_arm(struct sock *sk, struct tls_context *tls_ctx)
2287*4882a593Smuzhiyun {
2288*4882a593Smuzhiyun 	struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(tls_ctx);
2289*4882a593Smuzhiyun 
2290*4882a593Smuzhiyun 	write_lock_bh(&sk->sk_callback_lock);
2291*4882a593Smuzhiyun 	rx_ctx->saved_data_ready = sk->sk_data_ready;
2292*4882a593Smuzhiyun 	sk->sk_data_ready = tls_data_ready;
2293*4882a593Smuzhiyun 	write_unlock_bh(&sk->sk_callback_lock);
2294*4882a593Smuzhiyun 
2295*4882a593Smuzhiyun 	strp_check_rcv(&rx_ctx->strp);
2296*4882a593Smuzhiyun }
2297*4882a593Smuzhiyun 
2298*4882a593Smuzhiyun int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
2299*4882a593Smuzhiyun {
2300*4882a593Smuzhiyun 	struct tls_context *tls_ctx = tls_get_ctx(sk);
2301*4882a593Smuzhiyun 	struct tls_prot_info *prot = &tls_ctx->prot_info;
2302*4882a593Smuzhiyun 	struct tls_crypto_info *crypto_info;
2303*4882a593Smuzhiyun 	struct tls12_crypto_info_aes_gcm_128 *gcm_128_info;
2304*4882a593Smuzhiyun 	struct tls12_crypto_info_aes_gcm_256 *gcm_256_info;
2305*4882a593Smuzhiyun 	struct tls12_crypto_info_aes_ccm_128 *ccm_128_info;
2306*4882a593Smuzhiyun 	struct tls_sw_context_tx *sw_ctx_tx = NULL;
2307*4882a593Smuzhiyun 	struct tls_sw_context_rx *sw_ctx_rx = NULL;
2308*4882a593Smuzhiyun 	struct cipher_context *cctx;
2309*4882a593Smuzhiyun 	struct crypto_aead **aead;
2310*4882a593Smuzhiyun 	struct strp_callbacks cb;
2311*4882a593Smuzhiyun 	u16 nonce_size, tag_size, iv_size, rec_seq_size, salt_size;
2312*4882a593Smuzhiyun 	struct crypto_tfm *tfm;
2313*4882a593Smuzhiyun 	char *iv, *rec_seq, *key, *salt, *cipher_name;
2314*4882a593Smuzhiyun 	size_t keysize;
2315*4882a593Smuzhiyun 	int rc = 0;
2316*4882a593Smuzhiyun 
2317*4882a593Smuzhiyun 	if (!ctx) {
2318*4882a593Smuzhiyun 		rc = -EINVAL;
2319*4882a593Smuzhiyun 		goto out;
2320*4882a593Smuzhiyun 	}
2321*4882a593Smuzhiyun 
2322*4882a593Smuzhiyun 	if (tx) {
2323*4882a593Smuzhiyun 		if (!ctx->priv_ctx_tx) {
2324*4882a593Smuzhiyun 			sw_ctx_tx = kzalloc(sizeof(*sw_ctx_tx), GFP_KERNEL);
2325*4882a593Smuzhiyun 			if (!sw_ctx_tx) {
2326*4882a593Smuzhiyun 				rc = -ENOMEM;
2327*4882a593Smuzhiyun 				goto out;
2328*4882a593Smuzhiyun 			}
2329*4882a593Smuzhiyun 			ctx->priv_ctx_tx = sw_ctx_tx;
2330*4882a593Smuzhiyun 		} else {
2331*4882a593Smuzhiyun 			sw_ctx_tx =
2332*4882a593Smuzhiyun 				(struct tls_sw_context_tx *)ctx->priv_ctx_tx;
2333*4882a593Smuzhiyun 		}
2334*4882a593Smuzhiyun 	} else {
2335*4882a593Smuzhiyun 		if (!ctx->priv_ctx_rx) {
2336*4882a593Smuzhiyun 			sw_ctx_rx = kzalloc(sizeof(*sw_ctx_rx), GFP_KERNEL);
2337*4882a593Smuzhiyun 			if (!sw_ctx_rx) {
2338*4882a593Smuzhiyun 				rc = -ENOMEM;
2339*4882a593Smuzhiyun 				goto out;
2340*4882a593Smuzhiyun 			}
2341*4882a593Smuzhiyun 			ctx->priv_ctx_rx = sw_ctx_rx;
2342*4882a593Smuzhiyun 		} else {
2343*4882a593Smuzhiyun 			sw_ctx_rx =
2344*4882a593Smuzhiyun 				(struct tls_sw_context_rx *)ctx->priv_ctx_rx;
2345*4882a593Smuzhiyun 		}
2346*4882a593Smuzhiyun 	}
2347*4882a593Smuzhiyun 
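	/* Direction-specific initialization: the tx side gets the tx_list
	 * of pending encrypted records and its transmit worker, the rx
	 * side gets the rx_list queue of received records.
	 */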
2348*4882a593Smuzhiyun 	if (tx) {
2349*4882a593Smuzhiyun 		crypto_init_wait(&sw_ctx_tx->async_wait);
2350*4882a593Smuzhiyun 		spin_lock_init(&sw_ctx_tx->encrypt_compl_lock);
2351*4882a593Smuzhiyun 		crypto_info = &ctx->crypto_send.info;
2352*4882a593Smuzhiyun 		cctx = &ctx->tx;
2353*4882a593Smuzhiyun 		aead = &sw_ctx_tx->aead_send;
2354*4882a593Smuzhiyun 		INIT_LIST_HEAD(&sw_ctx_tx->tx_list);
2355*4882a593Smuzhiyun 		INIT_DELAYED_WORK(&sw_ctx_tx->tx_work.work, tx_work_handler);
2356*4882a593Smuzhiyun 		sw_ctx_tx->tx_work.sk = sk;
2357*4882a593Smuzhiyun 	} else {
2358*4882a593Smuzhiyun 		crypto_init_wait(&sw_ctx_rx->async_wait);
2359*4882a593Smuzhiyun 		spin_lock_init(&sw_ctx_rx->decrypt_compl_lock);
2360*4882a593Smuzhiyun 		crypto_info = &ctx->crypto_recv.info;
2361*4882a593Smuzhiyun 		cctx = &ctx->rx;
2362*4882a593Smuzhiyun 		skb_queue_head_init(&sw_ctx_rx->rx_list);
2363*4882a593Smuzhiyun 		aead = &sw_ctx_rx->aead_recv;
2364*4882a593Smuzhiyun 	}
2365*4882a593Smuzhiyun 
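	/* Extract the key material and size constants for the requested
	 * cipher from the crypto_info union; an unrecognized cipher_type
	 * fails with -EINVAL.
	 */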
2366*4882a593Smuzhiyun 	switch (crypto_info->cipher_type) {
2367*4882a593Smuzhiyun 	case TLS_CIPHER_AES_GCM_128: {
2368*4882a593Smuzhiyun 		nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
2369*4882a593Smuzhiyun 		tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE;
2370*4882a593Smuzhiyun 		iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
2371*4882a593Smuzhiyun 		iv = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->iv;
2372*4882a593Smuzhiyun 		rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE;
2373*4882a593Smuzhiyun 		rec_seq =
2374*4882a593Smuzhiyun 		 ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->rec_seq;
2375*4882a593Smuzhiyun 		gcm_128_info =
2376*4882a593Smuzhiyun 			(struct tls12_crypto_info_aes_gcm_128 *)crypto_info;
2377*4882a593Smuzhiyun 		keysize = TLS_CIPHER_AES_GCM_128_KEY_SIZE;
2378*4882a593Smuzhiyun 		key = gcm_128_info->key;
2379*4882a593Smuzhiyun 		salt = gcm_128_info->salt;
2380*4882a593Smuzhiyun 		salt_size = TLS_CIPHER_AES_GCM_128_SALT_SIZE;
2381*4882a593Smuzhiyun 		cipher_name = "gcm(aes)";
2382*4882a593Smuzhiyun 		break;
2383*4882a593Smuzhiyun 	}
2384*4882a593Smuzhiyun 	case TLS_CIPHER_AES_GCM_256: {
2385*4882a593Smuzhiyun 		nonce_size = TLS_CIPHER_AES_GCM_256_IV_SIZE;
2386*4882a593Smuzhiyun 		tag_size = TLS_CIPHER_AES_GCM_256_TAG_SIZE;
2387*4882a593Smuzhiyun 		iv_size = TLS_CIPHER_AES_GCM_256_IV_SIZE;
2388*4882a593Smuzhiyun 		iv = ((struct tls12_crypto_info_aes_gcm_256 *)crypto_info)->iv;
2389*4882a593Smuzhiyun 		rec_seq_size = TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE;
2390*4882a593Smuzhiyun 		rec_seq =
2391*4882a593Smuzhiyun 		 ((struct tls12_crypto_info_aes_gcm_256 *)crypto_info)->rec_seq;
2392*4882a593Smuzhiyun 		gcm_256_info =
2393*4882a593Smuzhiyun 			(struct tls12_crypto_info_aes_gcm_256 *)crypto_info;
2394*4882a593Smuzhiyun 		keysize = TLS_CIPHER_AES_GCM_256_KEY_SIZE;
2395*4882a593Smuzhiyun 		key = gcm_256_info->key;
2396*4882a593Smuzhiyun 		salt = gcm_256_info->salt;
2397*4882a593Smuzhiyun 		salt_size = TLS_CIPHER_AES_GCM_256_SALT_SIZE;
2398*4882a593Smuzhiyun 		cipher_name = "gcm(aes)";
2399*4882a593Smuzhiyun 		break;
2400*4882a593Smuzhiyun 	}
2401*4882a593Smuzhiyun 	case TLS_CIPHER_AES_CCM_128: {
2402*4882a593Smuzhiyun 		nonce_size = TLS_CIPHER_AES_CCM_128_IV_SIZE;
2403*4882a593Smuzhiyun 		tag_size = TLS_CIPHER_AES_CCM_128_TAG_SIZE;
2404*4882a593Smuzhiyun 		iv_size = TLS_CIPHER_AES_CCM_128_IV_SIZE;
2405*4882a593Smuzhiyun 		iv = ((struct tls12_crypto_info_aes_ccm_128 *)crypto_info)->iv;
2406*4882a593Smuzhiyun 		rec_seq_size = TLS_CIPHER_AES_CCM_128_REC_SEQ_SIZE;
2407*4882a593Smuzhiyun 		rec_seq =
2408*4882a593Smuzhiyun 		((struct tls12_crypto_info_aes_ccm_128 *)crypto_info)->rec_seq;
2409*4882a593Smuzhiyun 		ccm_128_info =
2410*4882a593Smuzhiyun 		(struct tls12_crypto_info_aes_ccm_128 *)crypto_info;
2411*4882a593Smuzhiyun 		keysize = TLS_CIPHER_AES_CCM_128_KEY_SIZE;
2412*4882a593Smuzhiyun 		key = ccm_128_info->key;
2413*4882a593Smuzhiyun 		salt = ccm_128_info->salt;
2414*4882a593Smuzhiyun 		salt_size = TLS_CIPHER_AES_CCM_128_SALT_SIZE;
2415*4882a593Smuzhiyun 		cipher_name = "ccm(aes)";
2416*4882a593Smuzhiyun 		break;
2417*4882a593Smuzhiyun 	}
2418*4882a593Smuzhiyun 	default:
2419*4882a593Smuzhiyun 		rc = -EINVAL;
2420*4882a593Smuzhiyun 		goto free_priv;
2421*4882a593Smuzhiyun 	}
2422*4882a593Smuzhiyun 
2423*4882a593Smuzhiyun 	/* Sanity-check the sizes for stack allocations. */
2424*4882a593Smuzhiyun 	if (iv_size > MAX_IV_SIZE || nonce_size > MAX_IV_SIZE ||
2425*4882a593Smuzhiyun 	    rec_seq_size > TLS_MAX_REC_SEQ_SIZE) {
2426*4882a593Smuzhiyun 		rc = -EINVAL;
2427*4882a593Smuzhiyun 		goto free_priv;
2428*4882a593Smuzhiyun 	}
2429*4882a593Smuzhiyun 
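	/* TLS 1.3 carries no explicit per-record nonce on the wire (the
	 * nonce is derived from the record sequence number), so only the
	 * 5-byte record header precedes the ciphertext, and a single tail
	 * byte holds the inner content type.
	 */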
2430*4882a593Smuzhiyun 	if (crypto_info->version == TLS_1_3_VERSION) {
2431*4882a593Smuzhiyun 		nonce_size = 0;
2432*4882a593Smuzhiyun 		prot->aad_size = TLS_HEADER_SIZE;
2433*4882a593Smuzhiyun 		prot->tail_size = 1;
2434*4882a593Smuzhiyun 	} else {
2435*4882a593Smuzhiyun 		prot->aad_size = TLS_AAD_SPACE_SIZE;
2436*4882a593Smuzhiyun 		prot->tail_size = 0;
2437*4882a593Smuzhiyun 	}
2438*4882a593Smuzhiyun 
2439*4882a593Smuzhiyun 	prot->version = crypto_info->version;
2440*4882a593Smuzhiyun 	prot->cipher_type = crypto_info->cipher_type;
2441*4882a593Smuzhiyun 	prot->prepend_size = TLS_HEADER_SIZE + nonce_size;
2442*4882a593Smuzhiyun 	prot->tag_size = tag_size;
2443*4882a593Smuzhiyun 	prot->overhead_size = prot->prepend_size +
2444*4882a593Smuzhiyun 			      prot->tag_size + prot->tail_size;
2445*4882a593Smuzhiyun 	prot->iv_size = iv_size;
2446*4882a593Smuzhiyun 	prot->salt_size = salt_size;
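	/* cctx->iv stores the full nonce seed laid out as salt || iv: the
	 * implicit salt from the handshake first, followed by the
	 * (explicit, in TLS 1.2) IV.
	 */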
2447*4882a593Smuzhiyun 	cctx->iv = kmalloc(iv_size + salt_size, GFP_KERNEL);
2448*4882a593Smuzhiyun 	if (!cctx->iv) {
2449*4882a593Smuzhiyun 		rc = -ENOMEM;
2450*4882a593Smuzhiyun 		goto free_priv;
2451*4882a593Smuzhiyun 	}
2452*4882a593Smuzhiyun 	/* Note: 128 & 256 bit salt are the same size */
2453*4882a593Smuzhiyun 	prot->rec_seq_size = rec_seq_size;
2454*4882a593Smuzhiyun 	memcpy(cctx->iv, salt, salt_size);
2455*4882a593Smuzhiyun 	memcpy(cctx->iv + salt_size, iv, iv_size);
2456*4882a593Smuzhiyun 	cctx->rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL);
2457*4882a593Smuzhiyun 	if (!cctx->rec_seq) {
2458*4882a593Smuzhiyun 		rc = -ENOMEM;
2459*4882a593Smuzhiyun 		goto free_iv;
2460*4882a593Smuzhiyun 	}
2461*4882a593Smuzhiyun 
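	/* Allocate the AEAD transform only on first configuration; an
	 * already-present transform (e.g. when the context is being
	 * reconfigured) is reused and simply rekeyed below.
	 */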
2462*4882a593Smuzhiyun 	if (!*aead) {
2463*4882a593Smuzhiyun 		*aead = crypto_alloc_aead(cipher_name, 0, 0);
2464*4882a593Smuzhiyun 		if (IS_ERR(*aead)) {
2465*4882a593Smuzhiyun 			rc = PTR_ERR(*aead);
2466*4882a593Smuzhiyun 			*aead = NULL;
2467*4882a593Smuzhiyun 			goto free_rec_seq;
2468*4882a593Smuzhiyun 		}
2469*4882a593Smuzhiyun 	}
2470*4882a593Smuzhiyun 
2471*4882a593Smuzhiyun 	ctx->push_pending_record = tls_sw_push_pending_record;
2472*4882a593Smuzhiyun 
2473*4882a593Smuzhiyun 	rc = crypto_aead_setkey(*aead, key, keysize);
2474*4882a593Smuzhiyun 
2475*4882a593Smuzhiyun 	if (rc)
2476*4882a593Smuzhiyun 		goto free_aead;
2477*4882a593Smuzhiyun 
2478*4882a593Smuzhiyun 	rc = crypto_aead_setauthsize(*aead, prot->tag_size);
2479*4882a593Smuzhiyun 	if (rc)
2480*4882a593Smuzhiyun 		goto free_aead;
2481*4882a593Smuzhiyun 
2482*4882a593Smuzhiyun 	if (sw_ctx_rx) {
2483*4882a593Smuzhiyun 		tfm = crypto_aead_tfm(sw_ctx_rx->aead_recv);
2484*4882a593Smuzhiyun 
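		/* Record whether the chosen AEAD implementation completes
		 * asynchronously; the TLS 1.3 receive path is always run
		 * synchronously.
		 */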
2485*4882a593Smuzhiyun 		if (crypto_info->version == TLS_1_3_VERSION)
2486*4882a593Smuzhiyun 			sw_ctx_rx->async_capable = 0;
2487*4882a593Smuzhiyun 		else
2488*4882a593Smuzhiyun 			sw_ctx_rx->async_capable =
2489*4882a593Smuzhiyun 				!!(tfm->__crt_alg->cra_flags &
2490*4882a593Smuzhiyun 				   CRYPTO_ALG_ASYNC);
2491*4882a593Smuzhiyun 
2492*4882a593Smuzhiyun 		/* Set up strparser */
2493*4882a593Smuzhiyun 		memset(&cb, 0, sizeof(cb));
2494*4882a593Smuzhiyun 		cb.rcv_msg = tls_queue;
2495*4882a593Smuzhiyun 		cb.parse_msg = tls_read_size;
2496*4882a593Smuzhiyun 
2497*4882a593Smuzhiyun 		strp_init(&sw_ctx_rx->strp, sk, &cb);
2498*4882a593Smuzhiyun 	}
2499*4882a593Smuzhiyun 
2500*4882a593Smuzhiyun 	goto out;
2501*4882a593Smuzhiyun 
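	/* Error unwind: release resources in the reverse order of their
	 * allocation and clear the pointers so a later teardown does not
	 * free them again.
	 */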
2502*4882a593Smuzhiyun free_aead:
2503*4882a593Smuzhiyun 	crypto_free_aead(*aead);
2504*4882a593Smuzhiyun 	*aead = NULL;
2505*4882a593Smuzhiyun free_rec_seq:
2506*4882a593Smuzhiyun 	kfree(cctx->rec_seq);
2507*4882a593Smuzhiyun 	cctx->rec_seq = NULL;
2508*4882a593Smuzhiyun free_iv:
2509*4882a593Smuzhiyun 	kfree(cctx->iv);
2510*4882a593Smuzhiyun 	cctx->iv = NULL;
2511*4882a593Smuzhiyun free_priv:
2512*4882a593Smuzhiyun 	if (tx) {
2513*4882a593Smuzhiyun 		kfree(ctx->priv_ctx_tx);
2514*4882a593Smuzhiyun 		ctx->priv_ctx_tx = NULL;
2515*4882a593Smuzhiyun 	} else {
2516*4882a593Smuzhiyun 		kfree(ctx->priv_ctx_rx);
2517*4882a593Smuzhiyun 		ctx->priv_ctx_rx = NULL;
2518*4882a593Smuzhiyun 	}
2519*4882a593Smuzhiyun out:
2520*4882a593Smuzhiyun 	return rc;
2521*4882a593Smuzhiyun }
2522