xref: /OK3568_Linux_fs/kernel/net/tls/tls_device.c (revision 4882a59341e53eb6f0b4789bf948001014eff981)
1*4882a593Smuzhiyun /* Copyright (c) 2018, Mellanox Technologies All rights reserved.
2*4882a593Smuzhiyun  *
3*4882a593Smuzhiyun  * This software is available to you under a choice of one of two
4*4882a593Smuzhiyun  * licenses.  You may choose to be licensed under the terms of the GNU
5*4882a593Smuzhiyun  * General Public License (GPL) Version 2, available from the file
6*4882a593Smuzhiyun  * COPYING in the main directory of this source tree, or the
7*4882a593Smuzhiyun  * OpenIB.org BSD license below:
8*4882a593Smuzhiyun  *
9*4882a593Smuzhiyun  *     Redistribution and use in source and binary forms, with or
10*4882a593Smuzhiyun  *     without modification, are permitted provided that the following
11*4882a593Smuzhiyun  *     conditions are met:
12*4882a593Smuzhiyun  *
13*4882a593Smuzhiyun  *      - Redistributions of source code must retain the above
14*4882a593Smuzhiyun  *        copyright notice, this list of conditions and the following
15*4882a593Smuzhiyun  *        disclaimer.
16*4882a593Smuzhiyun  *
17*4882a593Smuzhiyun  *      - Redistributions in binary form must reproduce the above
18*4882a593Smuzhiyun  *        copyright notice, this list of conditions and the following
19*4882a593Smuzhiyun  *        disclaimer in the documentation and/or other materials
20*4882a593Smuzhiyun  *        provided with the distribution.
21*4882a593Smuzhiyun  *
22*4882a593Smuzhiyun  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
23*4882a593Smuzhiyun  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
24*4882a593Smuzhiyun  * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
25*4882a593Smuzhiyun  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
26*4882a593Smuzhiyun  * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
27*4882a593Smuzhiyun  * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
28*4882a593Smuzhiyun  * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
29*4882a593Smuzhiyun  * SOFTWARE.
30*4882a593Smuzhiyun  */
31*4882a593Smuzhiyun 
32*4882a593Smuzhiyun #include <crypto/aead.h>
33*4882a593Smuzhiyun #include <linux/highmem.h>
34*4882a593Smuzhiyun #include <linux/module.h>
35*4882a593Smuzhiyun #include <linux/netdevice.h>
36*4882a593Smuzhiyun #include <net/dst.h>
37*4882a593Smuzhiyun #include <net/inet_connection_sock.h>
38*4882a593Smuzhiyun #include <net/tcp.h>
39*4882a593Smuzhiyun #include <net/tls.h>
40*4882a593Smuzhiyun 
41*4882a593Smuzhiyun #include "trace.h"
42*4882a593Smuzhiyun 
43*4882a593Smuzhiyun /* device_offload_lock is used to synchronize tls_dev_add
44*4882a593Smuzhiyun  * against NETDEV_DOWN notifications.
45*4882a593Smuzhiyun  */
46*4882a593Smuzhiyun static DECLARE_RWSEM(device_offload_lock);
47*4882a593Smuzhiyun 
48*4882a593Smuzhiyun static void tls_device_gc_task(struct work_struct *work);
49*4882a593Smuzhiyun 
50*4882a593Smuzhiyun static DECLARE_WORK(tls_device_gc_work, tls_device_gc_task);
51*4882a593Smuzhiyun static LIST_HEAD(tls_device_gc_list);
52*4882a593Smuzhiyun static LIST_HEAD(tls_device_list);
53*4882a593Smuzhiyun static LIST_HEAD(tls_device_down_list);
54*4882a593Smuzhiyun static DEFINE_SPINLOCK(tls_device_lock);
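/* tls_device_lock protects tls_device_list, tls_device_down_list and
 * tls_device_gc_list.  Contexts whose last reference is dropped are moved to
 * the gc list and torn down later from tls_device_gc_work, keeping the
 * (possibly sleeping) ->tls_dev_del() call out of the socket destructor path.
 */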
55*4882a593Smuzhiyun 
56*4882a593Smuzhiyun static void tls_device_free_ctx(struct tls_context *ctx)
57*4882a593Smuzhiyun {
58*4882a593Smuzhiyun 	if (ctx->tx_conf == TLS_HW) {
59*4882a593Smuzhiyun 		kfree(tls_offload_ctx_tx(ctx));
60*4882a593Smuzhiyun 		kfree(ctx->tx.rec_seq);
61*4882a593Smuzhiyun 		kfree(ctx->tx.iv);
62*4882a593Smuzhiyun 	}
63*4882a593Smuzhiyun 
64*4882a593Smuzhiyun 	if (ctx->rx_conf == TLS_HW)
65*4882a593Smuzhiyun 		kfree(tls_offload_ctx_rx(ctx));
66*4882a593Smuzhiyun 
67*4882a593Smuzhiyun 	tls_ctx_free(NULL, ctx);
68*4882a593Smuzhiyun }
69*4882a593Smuzhiyun 
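/* Deferred context teardown: splice the gc list under tls_device_lock, then
 * release the device's TX state via ->tls_dev_del() and free each context
 * outside the spinlock.
 */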
70*4882a593Smuzhiyun static void tls_device_gc_task(struct work_struct *work)
71*4882a593Smuzhiyun {
72*4882a593Smuzhiyun 	struct tls_context *ctx, *tmp;
73*4882a593Smuzhiyun 	unsigned long flags;
74*4882a593Smuzhiyun 	LIST_HEAD(gc_list);
75*4882a593Smuzhiyun 
76*4882a593Smuzhiyun 	spin_lock_irqsave(&tls_device_lock, flags);
77*4882a593Smuzhiyun 	list_splice_init(&tls_device_gc_list, &gc_list);
78*4882a593Smuzhiyun 	spin_unlock_irqrestore(&tls_device_lock, flags);
79*4882a593Smuzhiyun 
80*4882a593Smuzhiyun 	list_for_each_entry_safe(ctx, tmp, &gc_list, list) {
81*4882a593Smuzhiyun 		struct net_device *netdev = ctx->netdev;
82*4882a593Smuzhiyun 
83*4882a593Smuzhiyun 		if (netdev && ctx->tx_conf == TLS_HW) {
84*4882a593Smuzhiyun 			netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
85*4882a593Smuzhiyun 							TLS_OFFLOAD_CTX_DIR_TX);
86*4882a593Smuzhiyun 			dev_put(netdev);
87*4882a593Smuzhiyun 			ctx->netdev = NULL;
88*4882a593Smuzhiyun 		}
89*4882a593Smuzhiyun 
90*4882a593Smuzhiyun 		list_del(&ctx->list);
91*4882a593Smuzhiyun 		tls_device_free_ctx(ctx);
92*4882a593Smuzhiyun 	}
93*4882a593Smuzhiyun }
94*4882a593Smuzhiyun 
95*4882a593Smuzhiyun static void tls_device_queue_ctx_destruction(struct tls_context *ctx)
96*4882a593Smuzhiyun {
97*4882a593Smuzhiyun 	unsigned long flags;
98*4882a593Smuzhiyun 
99*4882a593Smuzhiyun 	spin_lock_irqsave(&tls_device_lock, flags);
100*4882a593Smuzhiyun 	if (unlikely(!refcount_dec_and_test(&ctx->refcount)))
101*4882a593Smuzhiyun 		goto unlock;
102*4882a593Smuzhiyun 
103*4882a593Smuzhiyun 	list_move_tail(&ctx->list, &tls_device_gc_list);
104*4882a593Smuzhiyun 
105*4882a593Smuzhiyun 	/* schedule_work inside the spinlock
106*4882a593Smuzhiyun 	 * to make sure tls_device_down waits for that work.
107*4882a593Smuzhiyun 	 */
108*4882a593Smuzhiyun 	schedule_work(&tls_device_gc_work);
109*4882a593Smuzhiyun unlock:
110*4882a593Smuzhiyun 	spin_unlock_irqrestore(&tls_device_lock, flags);
111*4882a593Smuzhiyun }
112*4882a593Smuzhiyun 
113*4882a593Smuzhiyun /* We assume that the socket is already connected */
114*4882a593Smuzhiyun static struct net_device *get_netdev_for_sock(struct sock *sk)
115*4882a593Smuzhiyun {
116*4882a593Smuzhiyun 	struct dst_entry *dst = sk_dst_get(sk);
117*4882a593Smuzhiyun 	struct net_device *netdev = NULL;
118*4882a593Smuzhiyun 
119*4882a593Smuzhiyun 	if (likely(dst)) {
120*4882a593Smuzhiyun 		netdev = dst->dev;
121*4882a593Smuzhiyun 		dev_hold(netdev);
122*4882a593Smuzhiyun 	}
123*4882a593Smuzhiyun 
124*4882a593Smuzhiyun 	dst_release(dst);
125*4882a593Smuzhiyun 
126*4882a593Smuzhiyun 	return netdev;
127*4882a593Smuzhiyun }
128*4882a593Smuzhiyun 
129*4882a593Smuzhiyun static void destroy_record(struct tls_record_info *record)
130*4882a593Smuzhiyun {
131*4882a593Smuzhiyun 	int i;
132*4882a593Smuzhiyun 
133*4882a593Smuzhiyun 	for (i = 0; i < record->num_frags; i++)
134*4882a593Smuzhiyun 		__skb_frag_unref(&record->frags[i]);
135*4882a593Smuzhiyun 	kfree(record);
136*4882a593Smuzhiyun }
137*4882a593Smuzhiyun 
138*4882a593Smuzhiyun static void delete_all_records(struct tls_offload_context_tx *offload_ctx)
139*4882a593Smuzhiyun {
140*4882a593Smuzhiyun 	struct tls_record_info *info, *temp;
141*4882a593Smuzhiyun 
142*4882a593Smuzhiyun 	list_for_each_entry_safe(info, temp, &offload_ctx->records_list, list) {
143*4882a593Smuzhiyun 		list_del(&info->list);
144*4882a593Smuzhiyun 		destroy_record(info);
145*4882a593Smuzhiyun 	}
146*4882a593Smuzhiyun 
147*4882a593Smuzhiyun 	offload_ctx->retransmit_hint = NULL;
148*4882a593Smuzhiyun }
149*4882a593Smuzhiyun 
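/* clean_acked_data callback: once TCP has acked up to @acked_seq, drop every
 * closed record that ends at or before it and advance unacked_record_sn
 * accordingly.  The retransmit hint is invalidated if it is among them.
 */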
150*4882a593Smuzhiyun static void tls_icsk_clean_acked(struct sock *sk, u32 acked_seq)
151*4882a593Smuzhiyun {
152*4882a593Smuzhiyun 	struct tls_context *tls_ctx = tls_get_ctx(sk);
153*4882a593Smuzhiyun 	struct tls_record_info *info, *temp;
154*4882a593Smuzhiyun 	struct tls_offload_context_tx *ctx;
155*4882a593Smuzhiyun 	u64 deleted_records = 0;
156*4882a593Smuzhiyun 	unsigned long flags;
157*4882a593Smuzhiyun 
158*4882a593Smuzhiyun 	if (!tls_ctx)
159*4882a593Smuzhiyun 		return;
160*4882a593Smuzhiyun 
161*4882a593Smuzhiyun 	ctx = tls_offload_ctx_tx(tls_ctx);
162*4882a593Smuzhiyun 
163*4882a593Smuzhiyun 	spin_lock_irqsave(&ctx->lock, flags);
164*4882a593Smuzhiyun 	info = ctx->retransmit_hint;
165*4882a593Smuzhiyun 	if (info && !before(acked_seq, info->end_seq))
166*4882a593Smuzhiyun 		ctx->retransmit_hint = NULL;
167*4882a593Smuzhiyun 
168*4882a593Smuzhiyun 	list_for_each_entry_safe(info, temp, &ctx->records_list, list) {
169*4882a593Smuzhiyun 		if (before(acked_seq, info->end_seq))
170*4882a593Smuzhiyun 			break;
171*4882a593Smuzhiyun 		list_del(&info->list);
172*4882a593Smuzhiyun 
173*4882a593Smuzhiyun 		destroy_record(info);
174*4882a593Smuzhiyun 		deleted_records++;
175*4882a593Smuzhiyun 	}
176*4882a593Smuzhiyun 
177*4882a593Smuzhiyun 	ctx->unacked_record_sn += deleted_records;
178*4882a593Smuzhiyun 	spin_unlock_irqrestore(&ctx->lock, flags);
179*4882a593Smuzhiyun }
180*4882a593Smuzhiyun 
181*4882a593Smuzhiyun /* At this point, there should be no references on this
182*4882a593Smuzhiyun  * socket and no in-flight SKBs associated with this
183*4882a593Smuzhiyun  * socket, so it is safe to free all the resources.
184*4882a593Smuzhiyun  */
185*4882a593Smuzhiyun void tls_device_sk_destruct(struct sock *sk)
186*4882a593Smuzhiyun {
187*4882a593Smuzhiyun 	struct tls_context *tls_ctx = tls_get_ctx(sk);
188*4882a593Smuzhiyun 	struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx);
189*4882a593Smuzhiyun 
190*4882a593Smuzhiyun 	tls_ctx->sk_destruct(sk);
191*4882a593Smuzhiyun 
192*4882a593Smuzhiyun 	if (tls_ctx->tx_conf == TLS_HW) {
193*4882a593Smuzhiyun 		if (ctx->open_record)
194*4882a593Smuzhiyun 			destroy_record(ctx->open_record);
195*4882a593Smuzhiyun 		delete_all_records(ctx);
196*4882a593Smuzhiyun 		crypto_free_aead(ctx->aead_send);
197*4882a593Smuzhiyun 		clean_acked_data_disable(inet_csk(sk));
198*4882a593Smuzhiyun 	}
199*4882a593Smuzhiyun 
200*4882a593Smuzhiyun 	tls_device_queue_ctx_destruction(tls_ctx);
201*4882a593Smuzhiyun }
202*4882a593Smuzhiyun EXPORT_SYMBOL_GPL(tls_device_sk_destruct);
203*4882a593Smuzhiyun 
204*4882a593Smuzhiyun void tls_device_free_resources_tx(struct sock *sk)
205*4882a593Smuzhiyun {
206*4882a593Smuzhiyun 	struct tls_context *tls_ctx = tls_get_ctx(sk);
207*4882a593Smuzhiyun 
208*4882a593Smuzhiyun 	tls_free_partial_record(sk, tls_ctx);
209*4882a593Smuzhiyun }
210*4882a593Smuzhiyun 
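/* tls_offload_tx_resync_request() - ask the core to resynchronize TX offload
 * state with the device.  Drivers call this from their datapath when the
 * record sequence they track no longer matches the stream; the core then
 * re-posts the current sequence via ->tls_dev_resync() from
 * tls_device_resync_tx() on the next record push.
 *
 * A minimal, hypothetical driver-side sketch (the mydrv_* names are
 * illustrative only, not part of this file):
 *
 *	if (unlikely(mydrv_tx_seq != tcp_seq))
 *		tls_offload_tx_resync_request(skb->sk, tcp_seq, mydrv_tx_seq);
 */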
211*4882a593Smuzhiyun void tls_offload_tx_resync_request(struct sock *sk, u32 got_seq, u32 exp_seq)
212*4882a593Smuzhiyun {
213*4882a593Smuzhiyun 	struct tls_context *tls_ctx = tls_get_ctx(sk);
214*4882a593Smuzhiyun 
215*4882a593Smuzhiyun 	trace_tls_device_tx_resync_req(sk, got_seq, exp_seq);
216*4882a593Smuzhiyun 	WARN_ON(test_and_set_bit(TLS_TX_SYNC_SCHED, &tls_ctx->flags));
217*4882a593Smuzhiyun }
218*4882a593Smuzhiyun EXPORT_SYMBOL_GPL(tls_offload_tx_resync_request);
219*4882a593Smuzhiyun 
220*4882a593Smuzhiyun static void tls_device_resync_tx(struct sock *sk, struct tls_context *tls_ctx,
221*4882a593Smuzhiyun 				 u32 seq)
222*4882a593Smuzhiyun {
223*4882a593Smuzhiyun 	struct net_device *netdev;
224*4882a593Smuzhiyun 	struct sk_buff *skb;
225*4882a593Smuzhiyun 	int err = 0;
226*4882a593Smuzhiyun 	u8 *rcd_sn;
227*4882a593Smuzhiyun 
228*4882a593Smuzhiyun 	skb = tcp_write_queue_tail(sk);
229*4882a593Smuzhiyun 	if (skb)
230*4882a593Smuzhiyun 		TCP_SKB_CB(skb)->eor = 1;
231*4882a593Smuzhiyun 
232*4882a593Smuzhiyun 	rcd_sn = tls_ctx->tx.rec_seq;
233*4882a593Smuzhiyun 
234*4882a593Smuzhiyun 	trace_tls_device_tx_resync_send(sk, seq, rcd_sn);
235*4882a593Smuzhiyun 	down_read(&device_offload_lock);
236*4882a593Smuzhiyun 	netdev = tls_ctx->netdev;
237*4882a593Smuzhiyun 	if (netdev)
238*4882a593Smuzhiyun 		err = netdev->tlsdev_ops->tls_dev_resync(netdev, sk, seq,
239*4882a593Smuzhiyun 							 rcd_sn,
240*4882a593Smuzhiyun 							 TLS_OFFLOAD_CTX_DIR_TX);
241*4882a593Smuzhiyun 	up_read(&device_offload_lock);
242*4882a593Smuzhiyun 	if (err)
243*4882a593Smuzhiyun 		return;
244*4882a593Smuzhiyun 
245*4882a593Smuzhiyun 	clear_bit_unlock(TLS_TX_SYNC_SCHED, &tls_ctx->flags);
246*4882a593Smuzhiyun }
247*4882a593Smuzhiyun 
248*4882a593Smuzhiyun static void tls_append_frag(struct tls_record_info *record,
249*4882a593Smuzhiyun 			    struct page_frag *pfrag,
250*4882a593Smuzhiyun 			    int size)
251*4882a593Smuzhiyun {
252*4882a593Smuzhiyun 	skb_frag_t *frag;
253*4882a593Smuzhiyun 
254*4882a593Smuzhiyun 	frag = &record->frags[record->num_frags - 1];
255*4882a593Smuzhiyun 	if (skb_frag_page(frag) == pfrag->page &&
256*4882a593Smuzhiyun 	    skb_frag_off(frag) + skb_frag_size(frag) == pfrag->offset) {
257*4882a593Smuzhiyun 		skb_frag_size_add(frag, size);
258*4882a593Smuzhiyun 	} else {
259*4882a593Smuzhiyun 		++frag;
260*4882a593Smuzhiyun 		__skb_frag_set_page(frag, pfrag->page);
261*4882a593Smuzhiyun 		skb_frag_off_set(frag, pfrag->offset);
262*4882a593Smuzhiyun 		skb_frag_size_set(frag, size);
263*4882a593Smuzhiyun 		++record->num_frags;
264*4882a593Smuzhiyun 		get_page(pfrag->page);
265*4882a593Smuzhiyun 	}
266*4882a593Smuzhiyun 
267*4882a593Smuzhiyun 	pfrag->offset += size;
268*4882a593Smuzhiyun 	record->len += size;
269*4882a593Smuzhiyun }
270*4882a593Smuzhiyun 
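/* Finalize and transmit a closed record: link it into records_list, advance
 * the record sequence number, mirror the record's frags into the offload
 * context's scatterlist (taking page and socket-memory references) and hand
 * the data to TCP via tls_push_sg().
 */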
271*4882a593Smuzhiyun static int tls_push_record(struct sock *sk,
272*4882a593Smuzhiyun 			   struct tls_context *ctx,
273*4882a593Smuzhiyun 			   struct tls_offload_context_tx *offload_ctx,
274*4882a593Smuzhiyun 			   struct tls_record_info *record,
275*4882a593Smuzhiyun 			   int flags)
276*4882a593Smuzhiyun {
277*4882a593Smuzhiyun 	struct tls_prot_info *prot = &ctx->prot_info;
278*4882a593Smuzhiyun 	struct tcp_sock *tp = tcp_sk(sk);
279*4882a593Smuzhiyun 	skb_frag_t *frag;
280*4882a593Smuzhiyun 	int i;
281*4882a593Smuzhiyun 
282*4882a593Smuzhiyun 	record->end_seq = tp->write_seq + record->len;
283*4882a593Smuzhiyun 	list_add_tail_rcu(&record->list, &offload_ctx->records_list);
284*4882a593Smuzhiyun 	offload_ctx->open_record = NULL;
285*4882a593Smuzhiyun 
286*4882a593Smuzhiyun 	if (test_bit(TLS_TX_SYNC_SCHED, &ctx->flags))
287*4882a593Smuzhiyun 		tls_device_resync_tx(sk, ctx, tp->write_seq);
288*4882a593Smuzhiyun 
289*4882a593Smuzhiyun 	tls_advance_record_sn(sk, prot, &ctx->tx);
290*4882a593Smuzhiyun 
291*4882a593Smuzhiyun 	for (i = 0; i < record->num_frags; i++) {
292*4882a593Smuzhiyun 		frag = &record->frags[i];
293*4882a593Smuzhiyun 		sg_unmark_end(&offload_ctx->sg_tx_data[i]);
294*4882a593Smuzhiyun 		sg_set_page(&offload_ctx->sg_tx_data[i], skb_frag_page(frag),
295*4882a593Smuzhiyun 			    skb_frag_size(frag), skb_frag_off(frag));
296*4882a593Smuzhiyun 		sk_mem_charge(sk, skb_frag_size(frag));
297*4882a593Smuzhiyun 		get_page(skb_frag_page(frag));
298*4882a593Smuzhiyun 	}
299*4882a593Smuzhiyun 	sg_mark_end(&offload_ctx->sg_tx_data[record->num_frags - 1]);
300*4882a593Smuzhiyun 
301*4882a593Smuzhiyun 	/* all ready, send */
302*4882a593Smuzhiyun 	return tls_push_sg(sk, ctx, offload_ctx->sg_tx_data, 0, flags);
303*4882a593Smuzhiyun }
304*4882a593Smuzhiyun 
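/* Close the open record: append a placeholder for the authentication tag and
 * write the TLS header.  Returns 0 on success, or the number of tag bytes
 * that could not be allocated and must be stolen back from the payload by the
 * caller (tls_push_data adds this value back to its remaining size), or
 * -ENOMEM if the record is too short to steal from.
 */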
305*4882a593Smuzhiyun static int tls_device_record_close(struct sock *sk,
306*4882a593Smuzhiyun 				   struct tls_context *ctx,
307*4882a593Smuzhiyun 				   struct tls_record_info *record,
308*4882a593Smuzhiyun 				   struct page_frag *pfrag,
309*4882a593Smuzhiyun 				   unsigned char record_type)
310*4882a593Smuzhiyun {
311*4882a593Smuzhiyun 	struct tls_prot_info *prot = &ctx->prot_info;
312*4882a593Smuzhiyun 	int ret;
313*4882a593Smuzhiyun 
314*4882a593Smuzhiyun 	/* append tag
315*4882a593Smuzhiyun 	 * device will fill in the tag, we just need to append a placeholder
316*4882a593Smuzhiyun 	 * use socket memory to improve coalescing (re-using a single buffer
317*4882a593Smuzhiyun 	 * increases frag count)
318*4882a593Smuzhiyun 	 * if we can't allocate memory now, steal some back from data
319*4882a593Smuzhiyun 	 */
320*4882a593Smuzhiyun 	if (likely(skb_page_frag_refill(prot->tag_size, pfrag,
321*4882a593Smuzhiyun 					sk->sk_allocation))) {
322*4882a593Smuzhiyun 		ret = 0;
323*4882a593Smuzhiyun 		tls_append_frag(record, pfrag, prot->tag_size);
324*4882a593Smuzhiyun 	} else {
325*4882a593Smuzhiyun 		ret = prot->tag_size;
326*4882a593Smuzhiyun 		if (record->len <= prot->overhead_size)
327*4882a593Smuzhiyun 			return -ENOMEM;
328*4882a593Smuzhiyun 	}
329*4882a593Smuzhiyun 
330*4882a593Smuzhiyun 	/* fill prepend */
331*4882a593Smuzhiyun 	tls_fill_prepend(ctx, skb_frag_address(&record->frags[0]),
332*4882a593Smuzhiyun 			 record->len - prot->overhead_size,
333*4882a593Smuzhiyun 			 record_type, prot->version);
334*4882a593Smuzhiyun 	return ret;
335*4882a593Smuzhiyun }
336*4882a593Smuzhiyun 
337*4882a593Smuzhiyun static int tls_create_new_record(struct tls_offload_context_tx *offload_ctx,
338*4882a593Smuzhiyun 				 struct page_frag *pfrag,
339*4882a593Smuzhiyun 				 size_t prepend_size)
340*4882a593Smuzhiyun {
341*4882a593Smuzhiyun 	struct tls_record_info *record;
342*4882a593Smuzhiyun 	skb_frag_t *frag;
343*4882a593Smuzhiyun 
344*4882a593Smuzhiyun 	record = kmalloc(sizeof(*record), GFP_KERNEL);
345*4882a593Smuzhiyun 	if (!record)
346*4882a593Smuzhiyun 		return -ENOMEM;
347*4882a593Smuzhiyun 
348*4882a593Smuzhiyun 	frag = &record->frags[0];
349*4882a593Smuzhiyun 	__skb_frag_set_page(frag, pfrag->page);
350*4882a593Smuzhiyun 	skb_frag_off_set(frag, pfrag->offset);
351*4882a593Smuzhiyun 	skb_frag_size_set(frag, prepend_size);
352*4882a593Smuzhiyun 
353*4882a593Smuzhiyun 	get_page(pfrag->page);
354*4882a593Smuzhiyun 	pfrag->offset += prepend_size;
355*4882a593Smuzhiyun 
356*4882a593Smuzhiyun 	record->num_frags = 1;
357*4882a593Smuzhiyun 	record->len = prepend_size;
358*4882a593Smuzhiyun 	offload_ctx->open_record = record;
359*4882a593Smuzhiyun 	return 0;
360*4882a593Smuzhiyun }
361*4882a593Smuzhiyun 
362*4882a593Smuzhiyun static int tls_do_allocation(struct sock *sk,
363*4882a593Smuzhiyun 			     struct tls_offload_context_tx *offload_ctx,
364*4882a593Smuzhiyun 			     struct page_frag *pfrag,
365*4882a593Smuzhiyun 			     size_t prepend_size)
366*4882a593Smuzhiyun {
367*4882a593Smuzhiyun 	int ret;
368*4882a593Smuzhiyun 
369*4882a593Smuzhiyun 	if (!offload_ctx->open_record) {
370*4882a593Smuzhiyun 		if (unlikely(!skb_page_frag_refill(prepend_size, pfrag,
371*4882a593Smuzhiyun 						   sk->sk_allocation))) {
372*4882a593Smuzhiyun 			READ_ONCE(sk->sk_prot)->enter_memory_pressure(sk);
373*4882a593Smuzhiyun 			sk_stream_moderate_sndbuf(sk);
374*4882a593Smuzhiyun 			return -ENOMEM;
375*4882a593Smuzhiyun 		}
376*4882a593Smuzhiyun 
377*4882a593Smuzhiyun 		ret = tls_create_new_record(offload_ctx, pfrag, prepend_size);
378*4882a593Smuzhiyun 		if (ret)
379*4882a593Smuzhiyun 			return ret;
380*4882a593Smuzhiyun 
381*4882a593Smuzhiyun 		if (pfrag->size > pfrag->offset)
382*4882a593Smuzhiyun 			return 0;
383*4882a593Smuzhiyun 	}
384*4882a593Smuzhiyun 
385*4882a593Smuzhiyun 	if (!sk_page_frag_refill(sk, pfrag))
386*4882a593Smuzhiyun 		return -ENOMEM;
387*4882a593Smuzhiyun 
388*4882a593Smuzhiyun 	return 0;
389*4882a593Smuzhiyun }
390*4882a593Smuzhiyun 
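/* Copy @bytes from the iterator into @addr, bypassing the CPU cache for the
 * cache-line-aligned bulk of the buffer: copy up to the first SMP_CACHE_BYTES
 * boundary normally, use copy_from_iter_nocache() for the aligned middle,
 * then copy any unaligned tail normally.
 */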
391*4882a593Smuzhiyun static int tls_device_copy_data(void *addr, size_t bytes, struct iov_iter *i)
392*4882a593Smuzhiyun {
393*4882a593Smuzhiyun 	size_t pre_copy, nocache;
394*4882a593Smuzhiyun 
395*4882a593Smuzhiyun 	pre_copy = ~((unsigned long)addr - 1) & (SMP_CACHE_BYTES - 1);
396*4882a593Smuzhiyun 	if (pre_copy) {
397*4882a593Smuzhiyun 		pre_copy = min(pre_copy, bytes);
398*4882a593Smuzhiyun 		if (copy_from_iter(addr, pre_copy, i) != pre_copy)
399*4882a593Smuzhiyun 			return -EFAULT;
400*4882a593Smuzhiyun 		bytes -= pre_copy;
401*4882a593Smuzhiyun 		addr += pre_copy;
402*4882a593Smuzhiyun 	}
403*4882a593Smuzhiyun 
404*4882a593Smuzhiyun 	nocache = round_down(bytes, SMP_CACHE_BYTES);
405*4882a593Smuzhiyun 	if (copy_from_iter_nocache(addr, nocache, i) != nocache)
406*4882a593Smuzhiyun 		return -EFAULT;
407*4882a593Smuzhiyun 	bytes -= nocache;
408*4882a593Smuzhiyun 	addr += nocache;
409*4882a593Smuzhiyun 
410*4882a593Smuzhiyun 	if (bytes && copy_from_iter(addr, bytes, i) != bytes)
411*4882a593Smuzhiyun 		return -EFAULT;
412*4882a593Smuzhiyun 
413*4882a593Smuzhiyun 	return 0;
414*4882a593Smuzhiyun }
415*4882a593Smuzhiyun 
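/* Core TX path for device offload: copy data from @msg_iter into the open
 * record's page frags, closing and pushing a record whenever its payload
 * reaches TLS_MAX_PAYLOAD_SIZE, it runs out of frag slots, or the caller
 * signalled the end of the message (no MSG_MORE/MSG_SENDPAGE_NOTLAST).
 * Returns the number of bytes consumed, or a negative error if nothing was
 * consumed.
 */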
416*4882a593Smuzhiyun static int tls_push_data(struct sock *sk,
417*4882a593Smuzhiyun 			 struct iov_iter *msg_iter,
418*4882a593Smuzhiyun 			 size_t size, int flags,
419*4882a593Smuzhiyun 			 unsigned char record_type)
420*4882a593Smuzhiyun {
421*4882a593Smuzhiyun 	struct tls_context *tls_ctx = tls_get_ctx(sk);
422*4882a593Smuzhiyun 	struct tls_prot_info *prot = &tls_ctx->prot_info;
423*4882a593Smuzhiyun 	struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx);
424*4882a593Smuzhiyun 	struct tls_record_info *record = ctx->open_record;
425*4882a593Smuzhiyun 	int tls_push_record_flags;
426*4882a593Smuzhiyun 	struct page_frag *pfrag;
427*4882a593Smuzhiyun 	size_t orig_size = size;
428*4882a593Smuzhiyun 	u32 max_open_record_len;
429*4882a593Smuzhiyun 	bool more = false;
430*4882a593Smuzhiyun 	bool done = false;
431*4882a593Smuzhiyun 	int copy, rc = 0;
432*4882a593Smuzhiyun 	long timeo;
433*4882a593Smuzhiyun 
434*4882a593Smuzhiyun 	if (flags &
435*4882a593Smuzhiyun 	    ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL | MSG_SENDPAGE_NOTLAST))
436*4882a593Smuzhiyun 		return -EOPNOTSUPP;
437*4882a593Smuzhiyun 
438*4882a593Smuzhiyun 	if (unlikely(sk->sk_err))
439*4882a593Smuzhiyun 		return -sk->sk_err;
440*4882a593Smuzhiyun 
441*4882a593Smuzhiyun 	flags |= MSG_SENDPAGE_DECRYPTED;
442*4882a593Smuzhiyun 	tls_push_record_flags = flags | MSG_SENDPAGE_NOTLAST;
443*4882a593Smuzhiyun 
444*4882a593Smuzhiyun 	timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
445*4882a593Smuzhiyun 	if (tls_is_partially_sent_record(tls_ctx)) {
446*4882a593Smuzhiyun 		rc = tls_push_partial_record(sk, tls_ctx, flags);
447*4882a593Smuzhiyun 		if (rc < 0)
448*4882a593Smuzhiyun 			return rc;
449*4882a593Smuzhiyun 	}
450*4882a593Smuzhiyun 
451*4882a593Smuzhiyun 	pfrag = sk_page_frag(sk);
452*4882a593Smuzhiyun 
453*4882a593Smuzhiyun 	/* TLS_HEADER_SIZE is not counted as part of the TLS record, and
454*4882a593Smuzhiyun 	 * we need to leave room for an authentication tag.
455*4882a593Smuzhiyun 	 */
456*4882a593Smuzhiyun 	max_open_record_len = TLS_MAX_PAYLOAD_SIZE +
457*4882a593Smuzhiyun 			      prot->prepend_size;
458*4882a593Smuzhiyun 	do {
459*4882a593Smuzhiyun 		rc = tls_do_allocation(sk, ctx, pfrag, prot->prepend_size);
460*4882a593Smuzhiyun 		if (unlikely(rc)) {
461*4882a593Smuzhiyun 			rc = sk_stream_wait_memory(sk, &timeo);
462*4882a593Smuzhiyun 			if (!rc)
463*4882a593Smuzhiyun 				continue;
464*4882a593Smuzhiyun 
465*4882a593Smuzhiyun 			record = ctx->open_record;
466*4882a593Smuzhiyun 			if (!record)
467*4882a593Smuzhiyun 				break;
468*4882a593Smuzhiyun handle_error:
469*4882a593Smuzhiyun 			if (record_type != TLS_RECORD_TYPE_DATA) {
470*4882a593Smuzhiyun 				/* avoid sending partial
471*4882a593Smuzhiyun 				 * record with type !=
472*4882a593Smuzhiyun 				 * application_data
473*4882a593Smuzhiyun 				 */
474*4882a593Smuzhiyun 				size = orig_size;
475*4882a593Smuzhiyun 				destroy_record(record);
476*4882a593Smuzhiyun 				ctx->open_record = NULL;
477*4882a593Smuzhiyun 			} else if (record->len > prot->prepend_size) {
478*4882a593Smuzhiyun 				goto last_record;
479*4882a593Smuzhiyun 			}
480*4882a593Smuzhiyun 
481*4882a593Smuzhiyun 			break;
482*4882a593Smuzhiyun 		}
483*4882a593Smuzhiyun 
484*4882a593Smuzhiyun 		record = ctx->open_record;
485*4882a593Smuzhiyun 		copy = min_t(size_t, size, (pfrag->size - pfrag->offset));
486*4882a593Smuzhiyun 		copy = min_t(size_t, copy, (max_open_record_len - record->len));
487*4882a593Smuzhiyun 
488*4882a593Smuzhiyun 		if (copy) {
489*4882a593Smuzhiyun 			rc = tls_device_copy_data(page_address(pfrag->page) +
490*4882a593Smuzhiyun 						  pfrag->offset, copy, msg_iter);
491*4882a593Smuzhiyun 			if (rc)
492*4882a593Smuzhiyun 				goto handle_error;
493*4882a593Smuzhiyun 			tls_append_frag(record, pfrag, copy);
494*4882a593Smuzhiyun 		}
495*4882a593Smuzhiyun 
496*4882a593Smuzhiyun 		size -= copy;
497*4882a593Smuzhiyun 		if (!size) {
498*4882a593Smuzhiyun last_record:
499*4882a593Smuzhiyun 			tls_push_record_flags = flags;
500*4882a593Smuzhiyun 			if (flags & (MSG_SENDPAGE_NOTLAST | MSG_MORE)) {
501*4882a593Smuzhiyun 				more = true;
502*4882a593Smuzhiyun 				break;
503*4882a593Smuzhiyun 			}
504*4882a593Smuzhiyun 
505*4882a593Smuzhiyun 			done = true;
506*4882a593Smuzhiyun 		}
507*4882a593Smuzhiyun 
508*4882a593Smuzhiyun 		if (done || record->len >= max_open_record_len ||
509*4882a593Smuzhiyun 		    (record->num_frags >= MAX_SKB_FRAGS - 1)) {
510*4882a593Smuzhiyun 			rc = tls_device_record_close(sk, tls_ctx, record,
511*4882a593Smuzhiyun 						     pfrag, record_type);
512*4882a593Smuzhiyun 			if (rc) {
513*4882a593Smuzhiyun 				if (rc > 0) {
514*4882a593Smuzhiyun 					size += rc;
515*4882a593Smuzhiyun 				} else {
516*4882a593Smuzhiyun 					size = orig_size;
517*4882a593Smuzhiyun 					destroy_record(record);
518*4882a593Smuzhiyun 					ctx->open_record = NULL;
519*4882a593Smuzhiyun 					break;
520*4882a593Smuzhiyun 				}
521*4882a593Smuzhiyun 			}
522*4882a593Smuzhiyun 
523*4882a593Smuzhiyun 			rc = tls_push_record(sk,
524*4882a593Smuzhiyun 					     tls_ctx,
525*4882a593Smuzhiyun 					     ctx,
526*4882a593Smuzhiyun 					     record,
527*4882a593Smuzhiyun 					     tls_push_record_flags);
528*4882a593Smuzhiyun 			if (rc < 0)
529*4882a593Smuzhiyun 				break;
530*4882a593Smuzhiyun 		}
531*4882a593Smuzhiyun 	} while (!done);
532*4882a593Smuzhiyun 
533*4882a593Smuzhiyun 	tls_ctx->pending_open_record_frags = more;
534*4882a593Smuzhiyun 
535*4882a593Smuzhiyun 	if (orig_size - size > 0)
536*4882a593Smuzhiyun 		rc = orig_size - size;
537*4882a593Smuzhiyun 
538*4882a593Smuzhiyun 	return rc;
539*4882a593Smuzhiyun }
540*4882a593Smuzhiyun 
541*4882a593Smuzhiyun int tls_device_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
542*4882a593Smuzhiyun {
543*4882a593Smuzhiyun 	unsigned char record_type = TLS_RECORD_TYPE_DATA;
544*4882a593Smuzhiyun 	struct tls_context *tls_ctx = tls_get_ctx(sk);
545*4882a593Smuzhiyun 	int rc;
546*4882a593Smuzhiyun 
547*4882a593Smuzhiyun 	mutex_lock(&tls_ctx->tx_lock);
548*4882a593Smuzhiyun 	lock_sock(sk);
549*4882a593Smuzhiyun 
550*4882a593Smuzhiyun 	if (unlikely(msg->msg_controllen)) {
551*4882a593Smuzhiyun 		rc = tls_proccess_cmsg(sk, msg, &record_type);
552*4882a593Smuzhiyun 		if (rc)
553*4882a593Smuzhiyun 			goto out;
554*4882a593Smuzhiyun 	}
555*4882a593Smuzhiyun 
556*4882a593Smuzhiyun 	rc = tls_push_data(sk, &msg->msg_iter, size,
557*4882a593Smuzhiyun 			   msg->msg_flags, record_type);
558*4882a593Smuzhiyun 
559*4882a593Smuzhiyun out:
560*4882a593Smuzhiyun 	release_sock(sk);
561*4882a593Smuzhiyun 	mutex_unlock(&tls_ctx->tx_lock);
562*4882a593Smuzhiyun 	return rc;
563*4882a593Smuzhiyun }
564*4882a593Smuzhiyun 
565*4882a593Smuzhiyun int tls_device_sendpage(struct sock *sk, struct page *page,
566*4882a593Smuzhiyun 			int offset, size_t size, int flags)
567*4882a593Smuzhiyun {
568*4882a593Smuzhiyun 	struct tls_context *tls_ctx = tls_get_ctx(sk);
569*4882a593Smuzhiyun 	struct iov_iter	msg_iter;
570*4882a593Smuzhiyun 	char *kaddr;
571*4882a593Smuzhiyun 	struct kvec iov;
572*4882a593Smuzhiyun 	int rc;
573*4882a593Smuzhiyun 
574*4882a593Smuzhiyun 	if (flags & MSG_SENDPAGE_NOTLAST)
575*4882a593Smuzhiyun 		flags |= MSG_MORE;
576*4882a593Smuzhiyun 
577*4882a593Smuzhiyun 	mutex_lock(&tls_ctx->tx_lock);
578*4882a593Smuzhiyun 	lock_sock(sk);
579*4882a593Smuzhiyun 
580*4882a593Smuzhiyun 	if (flags & MSG_OOB) {
581*4882a593Smuzhiyun 		rc = -EOPNOTSUPP;
582*4882a593Smuzhiyun 		goto out;
583*4882a593Smuzhiyun 	}
584*4882a593Smuzhiyun 
585*4882a593Smuzhiyun 	kaddr = kmap(page);
586*4882a593Smuzhiyun 	iov.iov_base = kaddr + offset;
587*4882a593Smuzhiyun 	iov.iov_len = size;
588*4882a593Smuzhiyun 	iov_iter_kvec(&msg_iter, WRITE, &iov, 1, size);
589*4882a593Smuzhiyun 	rc = tls_push_data(sk, &msg_iter, size,
590*4882a593Smuzhiyun 			   flags, TLS_RECORD_TYPE_DATA);
591*4882a593Smuzhiyun 	kunmap(page);
592*4882a593Smuzhiyun 
593*4882a593Smuzhiyun out:
594*4882a593Smuzhiyun 	release_sock(sk);
595*4882a593Smuzhiyun 	mutex_unlock(&tls_ctx->tx_lock);
596*4882a593Smuzhiyun 	return rc;
597*4882a593Smuzhiyun }
598*4882a593Smuzhiyun 
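/* tls_get_record() - map a TCP sequence number back to its TLS record.
 * Drivers use this on the transmit path (normally for retransmissions) to
 * recover the record payload and record sequence number that cover @seq,
 * starting from the cached retransmit_hint when it is still relevant.
 * Returns NULL if @seq is not covered by the record list.
 *
 * A rough, hypothetical driver-side sketch (names are illustrative only):
 *
 *	spin_lock_irqsave(&tx_ctx->lock, flags);
 *	record = tls_get_record(tx_ctx, tcp_seq, &record_sn);
 *	if (record)
 *		mydrv_build_resync_ctx(record, record_sn);
 *	spin_unlock_irqrestore(&tx_ctx->lock, flags);
 */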
599*4882a593Smuzhiyun struct tls_record_info *tls_get_record(struct tls_offload_context_tx *context,
600*4882a593Smuzhiyun 				       u32 seq, u64 *p_record_sn)
601*4882a593Smuzhiyun {
602*4882a593Smuzhiyun 	u64 record_sn = context->hint_record_sn;
603*4882a593Smuzhiyun 	struct tls_record_info *info, *last;
604*4882a593Smuzhiyun 
605*4882a593Smuzhiyun 	info = context->retransmit_hint;
606*4882a593Smuzhiyun 	if (!info ||
607*4882a593Smuzhiyun 	    before(seq, info->end_seq - info->len)) {
608*4882a593Smuzhiyun 		/* if retransmit_hint is irrelevant start
609*4882a593Smuzhiyun 		 * from the beginning of the list
610*4882a593Smuzhiyun 		 */
611*4882a593Smuzhiyun 		info = list_first_entry_or_null(&context->records_list,
612*4882a593Smuzhiyun 						struct tls_record_info, list);
613*4882a593Smuzhiyun 		if (!info)
614*4882a593Smuzhiyun 			return NULL;
615*4882a593Smuzhiyun 		/* send the start_marker record if seq number is before the
616*4882a593Smuzhiyun 		 * tls offload start marker sequence number. This record is
617*4882a593Smuzhiyun 		 * required to handle TCP packets which are before TLS offload
618*4882a593Smuzhiyun 		 * started.
619*4882a593Smuzhiyun 		 *  If it's not the start marker, check whether this seq number
620*4882a593Smuzhiyun 		 * belongs to the list.
621*4882a593Smuzhiyun 		 */
622*4882a593Smuzhiyun 		if (likely(!tls_record_is_start_marker(info))) {
623*4882a593Smuzhiyun 			/* we have the first record, get the last record to see
624*4882a593Smuzhiyun 			 * if this seq number belongs to the list.
625*4882a593Smuzhiyun 			 */
626*4882a593Smuzhiyun 			last = list_last_entry(&context->records_list,
627*4882a593Smuzhiyun 					       struct tls_record_info, list);
628*4882a593Smuzhiyun 
629*4882a593Smuzhiyun 			if (!between(seq, tls_record_start_seq(info),
630*4882a593Smuzhiyun 				     last->end_seq))
631*4882a593Smuzhiyun 				return NULL;
632*4882a593Smuzhiyun 		}
633*4882a593Smuzhiyun 		record_sn = context->unacked_record_sn;
634*4882a593Smuzhiyun 	}
635*4882a593Smuzhiyun 
636*4882a593Smuzhiyun 	/* We just need the _rcu for the READ_ONCE() */
637*4882a593Smuzhiyun 	rcu_read_lock();
638*4882a593Smuzhiyun 	list_for_each_entry_from_rcu(info, &context->records_list, list) {
639*4882a593Smuzhiyun 		if (before(seq, info->end_seq)) {
640*4882a593Smuzhiyun 			if (!context->retransmit_hint ||
641*4882a593Smuzhiyun 			    after(info->end_seq,
642*4882a593Smuzhiyun 				  context->retransmit_hint->end_seq)) {
643*4882a593Smuzhiyun 				context->hint_record_sn = record_sn;
644*4882a593Smuzhiyun 				context->retransmit_hint = info;
645*4882a593Smuzhiyun 			}
646*4882a593Smuzhiyun 			*p_record_sn = record_sn;
647*4882a593Smuzhiyun 			goto exit_rcu_unlock;
648*4882a593Smuzhiyun 		}
649*4882a593Smuzhiyun 		record_sn++;
650*4882a593Smuzhiyun 	}
651*4882a593Smuzhiyun 	info = NULL;
652*4882a593Smuzhiyun 
653*4882a593Smuzhiyun exit_rcu_unlock:
654*4882a593Smuzhiyun 	rcu_read_unlock();
655*4882a593Smuzhiyun 	return info;
656*4882a593Smuzhiyun }
657*4882a593Smuzhiyun EXPORT_SYMBOL(tls_get_record);
658*4882a593Smuzhiyun 
659*4882a593Smuzhiyun static int tls_device_push_pending_record(struct sock *sk, int flags)
660*4882a593Smuzhiyun {
661*4882a593Smuzhiyun 	struct iov_iter	msg_iter;
662*4882a593Smuzhiyun 
663*4882a593Smuzhiyun 	iov_iter_kvec(&msg_iter, WRITE, NULL, 0, 0);
664*4882a593Smuzhiyun 	return tls_push_data(sk, &msg_iter, 0, flags, TLS_RECORD_TYPE_DATA);
665*4882a593Smuzhiyun }
666*4882a593Smuzhiyun 
667*4882a593Smuzhiyun void tls_device_write_space(struct sock *sk, struct tls_context *ctx)
668*4882a593Smuzhiyun {
669*4882a593Smuzhiyun 	if (tls_is_partially_sent_record(ctx)) {
670*4882a593Smuzhiyun 		gfp_t sk_allocation = sk->sk_allocation;
671*4882a593Smuzhiyun 
672*4882a593Smuzhiyun 		WARN_ON_ONCE(sk->sk_write_pending);
673*4882a593Smuzhiyun 
674*4882a593Smuzhiyun 		sk->sk_allocation = GFP_ATOMIC;
675*4882a593Smuzhiyun 		tls_push_partial_record(sk, ctx,
676*4882a593Smuzhiyun 					MSG_DONTWAIT | MSG_NOSIGNAL |
677*4882a593Smuzhiyun 					MSG_SENDPAGE_DECRYPTED);
678*4882a593Smuzhiyun 		sk->sk_allocation = sk_allocation;
679*4882a593Smuzhiyun 	}
680*4882a593Smuzhiyun }
681*4882a593Smuzhiyun 
682*4882a593Smuzhiyun static void tls_device_resync_rx(struct tls_context *tls_ctx,
683*4882a593Smuzhiyun 				 struct sock *sk, u32 seq, u8 *rcd_sn)
684*4882a593Smuzhiyun {
685*4882a593Smuzhiyun 	struct tls_offload_context_rx *rx_ctx = tls_offload_ctx_rx(tls_ctx);
686*4882a593Smuzhiyun 	struct net_device *netdev;
687*4882a593Smuzhiyun 
688*4882a593Smuzhiyun 	trace_tls_device_rx_resync_send(sk, seq, rcd_sn, rx_ctx->resync_type);
689*4882a593Smuzhiyun 	rcu_read_lock();
690*4882a593Smuzhiyun 	netdev = READ_ONCE(tls_ctx->netdev);
691*4882a593Smuzhiyun 	if (netdev)
692*4882a593Smuzhiyun 		netdev->tlsdev_ops->tls_dev_resync(netdev, sk, seq, rcd_sn,
693*4882a593Smuzhiyun 						   TLS_OFFLOAD_CTX_DIR_RX);
694*4882a593Smuzhiyun 	rcu_read_unlock();
695*4882a593Smuzhiyun 	TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXDEVICERESYNC);
696*4882a593Smuzhiyun }
697*4882a593Smuzhiyun 
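/* Asynchronous RX resync (TLS_OFFLOAD_SYNC_TYPE_DRIVER_REQ_ASYNC).  The
 * driver's 64-bit request word packs the requested TCP seq in the upper 32
 * bits, a window length in bits 16..31 and RESYNC_REQ_ASYNC in the low bits.
 * While the request is still asynchronous we only log record-header sequence
 * numbers and count skipped records; once the driver clears the async flag we
 * match the request against the log and report, via @rcd_delta, how many
 * records the record sequence number must be rewound by.
 */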
698*4882a593Smuzhiyun static bool
699*4882a593Smuzhiyun tls_device_rx_resync_async(struct tls_offload_resync_async *resync_async,
700*4882a593Smuzhiyun 			   s64 resync_req, u32 *seq, u16 *rcd_delta)
701*4882a593Smuzhiyun {
702*4882a593Smuzhiyun 	u32 is_async = resync_req & RESYNC_REQ_ASYNC;
703*4882a593Smuzhiyun 	u32 req_seq = resync_req >> 32;
704*4882a593Smuzhiyun 	u32 req_end = req_seq + ((resync_req >> 16) & 0xffff);
705*4882a593Smuzhiyun 	u16 i;
706*4882a593Smuzhiyun 
707*4882a593Smuzhiyun 	*rcd_delta = 0;
708*4882a593Smuzhiyun 
709*4882a593Smuzhiyun 	if (is_async) {
710*4882a593Smuzhiyun 		/* shouldn't get to wraparound:
711*4882a593Smuzhiyun 		 * too long in async stage, something bad happened
712*4882a593Smuzhiyun 		 */
713*4882a593Smuzhiyun 		if (WARN_ON_ONCE(resync_async->rcd_delta == USHRT_MAX))
714*4882a593Smuzhiyun 			return false;
715*4882a593Smuzhiyun 
716*4882a593Smuzhiyun 		/* asynchronous stage: log all headers seq such that
717*4882a593Smuzhiyun 		 * req_seq <= seq <= end_seq, and wait for real resync request
718*4882a593Smuzhiyun 		 */
719*4882a593Smuzhiyun 		if (before(*seq, req_seq))
720*4882a593Smuzhiyun 			return false;
721*4882a593Smuzhiyun 		if (!after(*seq, req_end) &&
722*4882a593Smuzhiyun 		    resync_async->loglen < TLS_DEVICE_RESYNC_ASYNC_LOGMAX)
723*4882a593Smuzhiyun 			resync_async->log[resync_async->loglen++] = *seq;
724*4882a593Smuzhiyun 
725*4882a593Smuzhiyun 		resync_async->rcd_delta++;
726*4882a593Smuzhiyun 
727*4882a593Smuzhiyun 		return false;
728*4882a593Smuzhiyun 	}
729*4882a593Smuzhiyun 
730*4882a593Smuzhiyun 	/* synchronous stage: check against the logged entries and
731*4882a593Smuzhiyun 	 * proceed to check the next entries if no match was found
732*4882a593Smuzhiyun 	 */
733*4882a593Smuzhiyun 	for (i = 0; i < resync_async->loglen; i++)
734*4882a593Smuzhiyun 		if (req_seq == resync_async->log[i] &&
735*4882a593Smuzhiyun 		    atomic64_try_cmpxchg(&resync_async->req, &resync_req, 0)) {
736*4882a593Smuzhiyun 			*rcd_delta = resync_async->rcd_delta - i;
737*4882a593Smuzhiyun 			*seq = req_seq;
738*4882a593Smuzhiyun 			resync_async->loglen = 0;
739*4882a593Smuzhiyun 			resync_async->rcd_delta = 0;
740*4882a593Smuzhiyun 			return true;
741*4882a593Smuzhiyun 		}
742*4882a593Smuzhiyun 
743*4882a593Smuzhiyun 	resync_async->loglen = 0;
744*4882a593Smuzhiyun 	resync_async->rcd_delta = 0;
745*4882a593Smuzhiyun 
746*4882a593Smuzhiyun 	if (req_seq == *seq &&
747*4882a593Smuzhiyun 	    atomic64_try_cmpxchg(&resync_async->req,
748*4882a593Smuzhiyun 				 &resync_req, 0))
749*4882a593Smuzhiyun 		return true;
750*4882a593Smuzhiyun 
751*4882a593Smuzhiyun 	return false;
752*4882a593Smuzhiyun }
753*4882a593Smuzhiyun 
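/* Called for every record the software RX path parses while offload is
 * active.  Depending on the configured resync type this either answers a
 * pending driver request (synchronous or asynchronous) or acts on the core's
 * next-header hint, and finally pushes the new (seq, record sn) pair to the
 * device via tls_device_resync_rx().
 */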
754*4882a593Smuzhiyun void tls_device_rx_resync_new_rec(struct sock *sk, u32 rcd_len, u32 seq)
755*4882a593Smuzhiyun {
756*4882a593Smuzhiyun 	struct tls_context *tls_ctx = tls_get_ctx(sk);
757*4882a593Smuzhiyun 	struct tls_offload_context_rx *rx_ctx;
758*4882a593Smuzhiyun 	u8 rcd_sn[TLS_MAX_REC_SEQ_SIZE];
759*4882a593Smuzhiyun 	u32 sock_data, is_req_pending;
760*4882a593Smuzhiyun 	struct tls_prot_info *prot;
761*4882a593Smuzhiyun 	s64 resync_req;
762*4882a593Smuzhiyun 	u16 rcd_delta;
763*4882a593Smuzhiyun 	u32 req_seq;
764*4882a593Smuzhiyun 
765*4882a593Smuzhiyun 	if (tls_ctx->rx_conf != TLS_HW)
766*4882a593Smuzhiyun 		return;
767*4882a593Smuzhiyun 	if (unlikely(test_bit(TLS_RX_DEV_DEGRADED, &tls_ctx->flags)))
768*4882a593Smuzhiyun 		return;
769*4882a593Smuzhiyun 
770*4882a593Smuzhiyun 	prot = &tls_ctx->prot_info;
771*4882a593Smuzhiyun 	rx_ctx = tls_offload_ctx_rx(tls_ctx);
772*4882a593Smuzhiyun 	memcpy(rcd_sn, tls_ctx->rx.rec_seq, prot->rec_seq_size);
773*4882a593Smuzhiyun 
774*4882a593Smuzhiyun 	switch (rx_ctx->resync_type) {
775*4882a593Smuzhiyun 	case TLS_OFFLOAD_SYNC_TYPE_DRIVER_REQ:
776*4882a593Smuzhiyun 		resync_req = atomic64_read(&rx_ctx->resync_req);
777*4882a593Smuzhiyun 		req_seq = resync_req >> 32;
778*4882a593Smuzhiyun 		seq += TLS_HEADER_SIZE - 1;
779*4882a593Smuzhiyun 		is_req_pending = resync_req;
780*4882a593Smuzhiyun 
781*4882a593Smuzhiyun 		if (likely(!is_req_pending) || req_seq != seq ||
782*4882a593Smuzhiyun 		    !atomic64_try_cmpxchg(&rx_ctx->resync_req, &resync_req, 0))
783*4882a593Smuzhiyun 			return;
784*4882a593Smuzhiyun 		break;
785*4882a593Smuzhiyun 	case TLS_OFFLOAD_SYNC_TYPE_CORE_NEXT_HINT:
786*4882a593Smuzhiyun 		if (likely(!rx_ctx->resync_nh_do_now))
787*4882a593Smuzhiyun 			return;
788*4882a593Smuzhiyun 
789*4882a593Smuzhiyun 		/* head of next rec is already in, note that the sock_inq will
790*4882a593Smuzhiyun 		 * include the currently parsed message when called from parser
791*4882a593Smuzhiyun 		 */
792*4882a593Smuzhiyun 		sock_data = tcp_inq(sk);
793*4882a593Smuzhiyun 		if (sock_data > rcd_len) {
794*4882a593Smuzhiyun 			trace_tls_device_rx_resync_nh_delay(sk, sock_data,
795*4882a593Smuzhiyun 							    rcd_len);
796*4882a593Smuzhiyun 			return;
797*4882a593Smuzhiyun 		}
798*4882a593Smuzhiyun 
799*4882a593Smuzhiyun 		rx_ctx->resync_nh_do_now = 0;
800*4882a593Smuzhiyun 		seq += rcd_len;
801*4882a593Smuzhiyun 		tls_bigint_increment(rcd_sn, prot->rec_seq_size);
802*4882a593Smuzhiyun 		break;
803*4882a593Smuzhiyun 	case TLS_OFFLOAD_SYNC_TYPE_DRIVER_REQ_ASYNC:
804*4882a593Smuzhiyun 		resync_req = atomic64_read(&rx_ctx->resync_async->req);
805*4882a593Smuzhiyun 		is_req_pending = resync_req;
806*4882a593Smuzhiyun 		if (likely(!is_req_pending))
807*4882a593Smuzhiyun 			return;
808*4882a593Smuzhiyun 
809*4882a593Smuzhiyun 		if (!tls_device_rx_resync_async(rx_ctx->resync_async,
810*4882a593Smuzhiyun 						resync_req, &seq, &rcd_delta))
811*4882a593Smuzhiyun 			return;
812*4882a593Smuzhiyun 		tls_bigint_subtract(rcd_sn, rcd_delta);
813*4882a593Smuzhiyun 		break;
814*4882a593Smuzhiyun 	}
815*4882a593Smuzhiyun 
816*4882a593Smuzhiyun 	tls_device_resync_rx(tls_ctx, sk, seq, rcd_sn);
817*4882a593Smuzhiyun }
818*4882a593Smuzhiyun 
819*4882a593Smuzhiyun static void tls_device_core_ctrl_rx_resync(struct tls_context *tls_ctx,
820*4882a593Smuzhiyun 					   struct tls_offload_context_rx *ctx,
821*4882a593Smuzhiyun 					   struct sock *sk, struct sk_buff *skb)
822*4882a593Smuzhiyun {
823*4882a593Smuzhiyun 	struct strp_msg *rxm;
824*4882a593Smuzhiyun 
825*4882a593Smuzhiyun 	/* device will request resyncs by itself based on stream scan */
826*4882a593Smuzhiyun 	if (ctx->resync_type != TLS_OFFLOAD_SYNC_TYPE_CORE_NEXT_HINT)
827*4882a593Smuzhiyun 		return;
828*4882a593Smuzhiyun 	/* already scheduled */
829*4882a593Smuzhiyun 	if (ctx->resync_nh_do_now)
830*4882a593Smuzhiyun 		return;
831*4882a593Smuzhiyun 	/* seen decrypted fragments since last fully-failed record */
832*4882a593Smuzhiyun 	if (ctx->resync_nh_reset) {
833*4882a593Smuzhiyun 		ctx->resync_nh_reset = 0;
834*4882a593Smuzhiyun 		ctx->resync_nh.decrypted_failed = 1;
835*4882a593Smuzhiyun 		ctx->resync_nh.decrypted_tgt = TLS_DEVICE_RESYNC_NH_START_IVAL;
836*4882a593Smuzhiyun 		return;
837*4882a593Smuzhiyun 	}
838*4882a593Smuzhiyun 
839*4882a593Smuzhiyun 	if (++ctx->resync_nh.decrypted_failed <= ctx->resync_nh.decrypted_tgt)
840*4882a593Smuzhiyun 		return;
841*4882a593Smuzhiyun 
842*4882a593Smuzhiyun 	/* doing resync, bump the next target in case it fails */
843*4882a593Smuzhiyun 	if (ctx->resync_nh.decrypted_tgt < TLS_DEVICE_RESYNC_NH_MAX_IVAL)
844*4882a593Smuzhiyun 		ctx->resync_nh.decrypted_tgt *= 2;
845*4882a593Smuzhiyun 	else
846*4882a593Smuzhiyun 		ctx->resync_nh.decrypted_tgt += TLS_DEVICE_RESYNC_NH_MAX_IVAL;
847*4882a593Smuzhiyun 
848*4882a593Smuzhiyun 	rxm = strp_msg(skb);
849*4882a593Smuzhiyun 
850*4882a593Smuzhiyun 	/* head of next rec is already in, parser will sync for us */
851*4882a593Smuzhiyun 	if (tcp_inq(sk) > rxm->full_len) {
852*4882a593Smuzhiyun 		trace_tls_device_rx_resync_nh_schedule(sk);
853*4882a593Smuzhiyun 		ctx->resync_nh_do_now = 1;
854*4882a593Smuzhiyun 	} else {
855*4882a593Smuzhiyun 		struct tls_prot_info *prot = &tls_ctx->prot_info;
856*4882a593Smuzhiyun 		u8 rcd_sn[TLS_MAX_REC_SEQ_SIZE];
857*4882a593Smuzhiyun 
858*4882a593Smuzhiyun 		memcpy(rcd_sn, tls_ctx->rx.rec_seq, prot->rec_seq_size);
859*4882a593Smuzhiyun 		tls_bigint_increment(rcd_sn, prot->rec_seq_size);
860*4882a593Smuzhiyun 
861*4882a593Smuzhiyun 		tls_device_resync_rx(tls_ctx, sk, tcp_sk(sk)->copied_seq,
862*4882a593Smuzhiyun 				     rcd_sn);
863*4882a593Smuzhiyun 	}
864*4882a593Smuzhiyun }
865*4882a593Smuzhiyun 
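/* A record was only partially decrypted by the device, so the software path
 * cannot process it as-is.  Run the AEAD decrypt over the whole record into a
 * bounce buffer (because AES-GCM is a counter-mode cipher, this turns the
 * already-plaintext regions back into ciphertext) and copy those regions over
 * the device-decrypted parts of the skb, leaving one fully encrypted record
 * for the software path.  The -EBADMSG from the necessarily failing tag check
 * is ignored.
 */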
866*4882a593Smuzhiyun static int tls_device_reencrypt(struct sock *sk, struct sk_buff *skb)
867*4882a593Smuzhiyun {
868*4882a593Smuzhiyun 	struct strp_msg *rxm = strp_msg(skb);
869*4882a593Smuzhiyun 	int err = 0, offset = rxm->offset, copy, nsg, data_len, pos;
870*4882a593Smuzhiyun 	struct sk_buff *skb_iter, *unused;
871*4882a593Smuzhiyun 	struct scatterlist sg[1];
872*4882a593Smuzhiyun 	char *orig_buf, *buf;
873*4882a593Smuzhiyun 
874*4882a593Smuzhiyun 	orig_buf = kmalloc(rxm->full_len + TLS_HEADER_SIZE +
875*4882a593Smuzhiyun 			   TLS_CIPHER_AES_GCM_128_IV_SIZE, sk->sk_allocation);
876*4882a593Smuzhiyun 	if (!orig_buf)
877*4882a593Smuzhiyun 		return -ENOMEM;
878*4882a593Smuzhiyun 	buf = orig_buf;
879*4882a593Smuzhiyun 
880*4882a593Smuzhiyun 	nsg = skb_cow_data(skb, 0, &unused);
881*4882a593Smuzhiyun 	if (unlikely(nsg < 0)) {
882*4882a593Smuzhiyun 		err = nsg;
883*4882a593Smuzhiyun 		goto free_buf;
884*4882a593Smuzhiyun 	}
885*4882a593Smuzhiyun 
886*4882a593Smuzhiyun 	sg_init_table(sg, 1);
887*4882a593Smuzhiyun 	sg_set_buf(&sg[0], buf,
888*4882a593Smuzhiyun 		   rxm->full_len + TLS_HEADER_SIZE +
889*4882a593Smuzhiyun 		   TLS_CIPHER_AES_GCM_128_IV_SIZE);
890*4882a593Smuzhiyun 	err = skb_copy_bits(skb, offset, buf,
891*4882a593Smuzhiyun 			    TLS_HEADER_SIZE + TLS_CIPHER_AES_GCM_128_IV_SIZE);
892*4882a593Smuzhiyun 	if (err)
893*4882a593Smuzhiyun 		goto free_buf;
894*4882a593Smuzhiyun 
895*4882a593Smuzhiyun 	/* We are interested only in the decryption output, not the authentication result */
896*4882a593Smuzhiyun 	err = decrypt_skb(sk, skb, sg);
897*4882a593Smuzhiyun 	if (err != -EBADMSG)
898*4882a593Smuzhiyun 		goto free_buf;
899*4882a593Smuzhiyun 	else
900*4882a593Smuzhiyun 		err = 0;
901*4882a593Smuzhiyun 
902*4882a593Smuzhiyun 	data_len = rxm->full_len - TLS_CIPHER_AES_GCM_128_TAG_SIZE;
903*4882a593Smuzhiyun 
904*4882a593Smuzhiyun 	if (skb_pagelen(skb) > offset) {
905*4882a593Smuzhiyun 		copy = min_t(int, skb_pagelen(skb) - offset, data_len);
906*4882a593Smuzhiyun 
907*4882a593Smuzhiyun 		if (skb->decrypted) {
908*4882a593Smuzhiyun 			err = skb_store_bits(skb, offset, buf, copy);
909*4882a593Smuzhiyun 			if (err)
910*4882a593Smuzhiyun 				goto free_buf;
911*4882a593Smuzhiyun 		}
912*4882a593Smuzhiyun 
913*4882a593Smuzhiyun 		offset += copy;
914*4882a593Smuzhiyun 		buf += copy;
915*4882a593Smuzhiyun 	}
916*4882a593Smuzhiyun 
917*4882a593Smuzhiyun 	pos = skb_pagelen(skb);
918*4882a593Smuzhiyun 	skb_walk_frags(skb, skb_iter) {
919*4882a593Smuzhiyun 		int frag_pos;
920*4882a593Smuzhiyun 
921*4882a593Smuzhiyun 		/* Practically all frags must belong to msg if reencrypt
922*4882a593Smuzhiyun 		 * is needed with current strparser and coalescing logic,
923*4882a593Smuzhiyun 		 * but strparser may "get optimized", so let's be safe.
924*4882a593Smuzhiyun 		 */
925*4882a593Smuzhiyun 		if (pos + skb_iter->len <= offset)
926*4882a593Smuzhiyun 			goto done_with_frag;
927*4882a593Smuzhiyun 		if (pos >= data_len + rxm->offset)
928*4882a593Smuzhiyun 			break;
929*4882a593Smuzhiyun 
930*4882a593Smuzhiyun 		frag_pos = offset - pos;
931*4882a593Smuzhiyun 		copy = min_t(int, skb_iter->len - frag_pos,
932*4882a593Smuzhiyun 			     data_len + rxm->offset - offset);
933*4882a593Smuzhiyun 
934*4882a593Smuzhiyun 		if (skb_iter->decrypted) {
935*4882a593Smuzhiyun 			err = skb_store_bits(skb_iter, frag_pos, buf, copy);
936*4882a593Smuzhiyun 			if (err)
937*4882a593Smuzhiyun 				goto free_buf;
938*4882a593Smuzhiyun 		}
939*4882a593Smuzhiyun 
940*4882a593Smuzhiyun 		offset += copy;
941*4882a593Smuzhiyun 		buf += copy;
942*4882a593Smuzhiyun done_with_frag:
943*4882a593Smuzhiyun 		pos += skb_iter->len;
944*4882a593Smuzhiyun 	}
945*4882a593Smuzhiyun 
946*4882a593Smuzhiyun free_buf:
947*4882a593Smuzhiyun 	kfree(orig_buf);
948*4882a593Smuzhiyun 	return err;
949*4882a593Smuzhiyun }
950*4882a593Smuzhiyun 
951*4882a593Smuzhiyun int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx,
952*4882a593Smuzhiyun 			 struct sk_buff *skb, struct strp_msg *rxm)
953*4882a593Smuzhiyun {
954*4882a593Smuzhiyun 	struct tls_offload_context_rx *ctx = tls_offload_ctx_rx(tls_ctx);
955*4882a593Smuzhiyun 	int is_decrypted = skb->decrypted;
956*4882a593Smuzhiyun 	int is_encrypted = !is_decrypted;
957*4882a593Smuzhiyun 	struct sk_buff *skb_iter;
958*4882a593Smuzhiyun 
959*4882a593Smuzhiyun 	/* Check if all the data is decrypted already */
960*4882a593Smuzhiyun 	skb_walk_frags(skb, skb_iter) {
961*4882a593Smuzhiyun 		is_decrypted &= skb_iter->decrypted;
962*4882a593Smuzhiyun 		is_encrypted &= !skb_iter->decrypted;
963*4882a593Smuzhiyun 	}
964*4882a593Smuzhiyun 
965*4882a593Smuzhiyun 	trace_tls_device_decrypted(sk, tcp_sk(sk)->copied_seq - rxm->full_len,
966*4882a593Smuzhiyun 				   tls_ctx->rx.rec_seq, rxm->full_len,
967*4882a593Smuzhiyun 				   is_encrypted, is_decrypted);
968*4882a593Smuzhiyun 
969*4882a593Smuzhiyun 	ctx->sw.decrypted |= is_decrypted;
970*4882a593Smuzhiyun 
971*4882a593Smuzhiyun 	if (unlikely(test_bit(TLS_RX_DEV_DEGRADED, &tls_ctx->flags))) {
972*4882a593Smuzhiyun 		if (likely(is_encrypted || is_decrypted))
973*4882a593Smuzhiyun 			return 0;
974*4882a593Smuzhiyun 
975*4882a593Smuzhiyun 		/* After tls_device_down disables the offload, the next SKB will
976*4882a593Smuzhiyun 		 * likely have initial fragments decrypted, and final ones not
977*4882a593Smuzhiyun 		 * decrypted. We need to reencrypt that single SKB.
978*4882a593Smuzhiyun 		 */
979*4882a593Smuzhiyun 		return tls_device_reencrypt(sk, skb);
980*4882a593Smuzhiyun 	}
981*4882a593Smuzhiyun 
982*4882a593Smuzhiyun 	/* Return immediately if the record is either entirely plaintext or
983*4882a593Smuzhiyun 	 * entirely ciphertext. Otherwise handle reencrypt partially decrypted
984*4882a593Smuzhiyun 	 * record.
985*4882a593Smuzhiyun 	 */
986*4882a593Smuzhiyun 	if (is_decrypted) {
987*4882a593Smuzhiyun 		ctx->resync_nh_reset = 1;
988*4882a593Smuzhiyun 		return 0;
989*4882a593Smuzhiyun 	}
990*4882a593Smuzhiyun 	if (is_encrypted) {
991*4882a593Smuzhiyun 		tls_device_core_ctrl_rx_resync(tls_ctx, ctx, sk, skb);
992*4882a593Smuzhiyun 		return 0;
993*4882a593Smuzhiyun 	}
994*4882a593Smuzhiyun 
995*4882a593Smuzhiyun 	ctx->resync_nh_reset = 1;
996*4882a593Smuzhiyun 	return tls_device_reencrypt(sk, skb);
997*4882a593Smuzhiyun }
998*4882a593Smuzhiyun 
999*4882a593Smuzhiyun static void tls_device_attach(struct tls_context *ctx, struct sock *sk,
1000*4882a593Smuzhiyun 			      struct net_device *netdev)
1001*4882a593Smuzhiyun {
1002*4882a593Smuzhiyun 	if (sk->sk_destruct != tls_device_sk_destruct) {
1003*4882a593Smuzhiyun 		refcount_set(&ctx->refcount, 1);
1004*4882a593Smuzhiyun 		dev_hold(netdev);
1005*4882a593Smuzhiyun 		ctx->netdev = netdev;
1006*4882a593Smuzhiyun 		spin_lock_irq(&tls_device_lock);
1007*4882a593Smuzhiyun 		list_add_tail(&ctx->list, &tls_device_list);
1008*4882a593Smuzhiyun 		spin_unlock_irq(&tls_device_lock);
1009*4882a593Smuzhiyun 
1010*4882a593Smuzhiyun 		ctx->sk_destruct = sk->sk_destruct;
1011*4882a593Smuzhiyun 		smp_store_release(&sk->sk_destruct, tls_device_sk_destruct);
1012*4882a593Smuzhiyun 	}
1013*4882a593Smuzhiyun }
1014*4882a593Smuzhiyun 
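/* tls_set_device_offload() - set up TLS TX device offload for a connected TCP
 * socket.  Builds the TX offload context and the start-marker record,
 * registers the clean-acked callback, and hands the crypto state to the
 * netdev via ->tls_dev_add().  Only TLS 1.2 with AES-GCM-128 is accepted
 * here.
 *
 * For reference, the usual userspace sequence that can end up in this path
 * (a sketch, assuming a connected TCP socket "fd" and a filled-in
 * struct tls12_crypto_info_aes_gcm_128 "crypto_info"):
 *
 *	setsockopt(fd, SOL_TCP, TCP_ULP, "tls", sizeof("tls"));
 *	setsockopt(fd, SOL_TLS, TLS_TX, &crypto_info, sizeof(crypto_info));
 *
 * Device offload is then attempted automatically when the route's netdev
 * advertises NETIF_F_HW_TLS_TX and accepts the key; otherwise the stack
 * falls back to software TLS.
 */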
1015*4882a593Smuzhiyun int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
1016*4882a593Smuzhiyun {
1017*4882a593Smuzhiyun 	u16 nonce_size, tag_size, iv_size, rec_seq_size;
1018*4882a593Smuzhiyun 	struct tls_context *tls_ctx = tls_get_ctx(sk);
1019*4882a593Smuzhiyun 	struct tls_prot_info *prot = &tls_ctx->prot_info;
1020*4882a593Smuzhiyun 	struct tls_record_info *start_marker_record;
1021*4882a593Smuzhiyun 	struct tls_offload_context_tx *offload_ctx;
1022*4882a593Smuzhiyun 	struct tls_crypto_info *crypto_info;
1023*4882a593Smuzhiyun 	struct net_device *netdev;
1024*4882a593Smuzhiyun 	char *iv, *rec_seq;
1025*4882a593Smuzhiyun 	struct sk_buff *skb;
1026*4882a593Smuzhiyun 	__be64 rcd_sn;
1027*4882a593Smuzhiyun 	int rc;
1028*4882a593Smuzhiyun 
1029*4882a593Smuzhiyun 	if (!ctx)
1030*4882a593Smuzhiyun 		return -EINVAL;
1031*4882a593Smuzhiyun 
1032*4882a593Smuzhiyun 	if (ctx->priv_ctx_tx)
1033*4882a593Smuzhiyun 		return -EEXIST;
1034*4882a593Smuzhiyun 
1035*4882a593Smuzhiyun 	start_marker_record = kmalloc(sizeof(*start_marker_record), GFP_KERNEL);
1036*4882a593Smuzhiyun 	if (!start_marker_record)
1037*4882a593Smuzhiyun 		return -ENOMEM;
1038*4882a593Smuzhiyun 
1039*4882a593Smuzhiyun 	offload_ctx = kzalloc(TLS_OFFLOAD_CONTEXT_SIZE_TX, GFP_KERNEL);
1040*4882a593Smuzhiyun 	if (!offload_ctx) {
1041*4882a593Smuzhiyun 		rc = -ENOMEM;
1042*4882a593Smuzhiyun 		goto free_marker_record;
1043*4882a593Smuzhiyun 	}
1044*4882a593Smuzhiyun 
1045*4882a593Smuzhiyun 	crypto_info = &ctx->crypto_send.info;
1046*4882a593Smuzhiyun 	if (crypto_info->version != TLS_1_2_VERSION) {
1047*4882a593Smuzhiyun 		rc = -EOPNOTSUPP;
1048*4882a593Smuzhiyun 		goto free_offload_ctx;
1049*4882a593Smuzhiyun 	}
1050*4882a593Smuzhiyun 
1051*4882a593Smuzhiyun 	switch (crypto_info->cipher_type) {
1052*4882a593Smuzhiyun 	case TLS_CIPHER_AES_GCM_128:
1053*4882a593Smuzhiyun 		nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
1054*4882a593Smuzhiyun 		tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE;
1055*4882a593Smuzhiyun 		iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
1056*4882a593Smuzhiyun 		iv = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->iv;
1057*4882a593Smuzhiyun 		rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE;
1058*4882a593Smuzhiyun 		rec_seq =
1059*4882a593Smuzhiyun 		 ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->rec_seq;
1060*4882a593Smuzhiyun 		break;
1061*4882a593Smuzhiyun 	default:
1062*4882a593Smuzhiyun 		rc = -EINVAL;
1063*4882a593Smuzhiyun 		goto free_offload_ctx;
1064*4882a593Smuzhiyun 	}
1065*4882a593Smuzhiyun 
1066*4882a593Smuzhiyun 	/* Sanity-check the rec_seq_size for stack allocations */
1067*4882a593Smuzhiyun 	if (rec_seq_size > TLS_MAX_REC_SEQ_SIZE) {
1068*4882a593Smuzhiyun 		rc = -EINVAL;
1069*4882a593Smuzhiyun 		goto free_offload_ctx;
1070*4882a593Smuzhiyun 	}
1071*4882a593Smuzhiyun 
1072*4882a593Smuzhiyun 	prot->version = crypto_info->version;
1073*4882a593Smuzhiyun 	prot->cipher_type = crypto_info->cipher_type;
1074*4882a593Smuzhiyun 	prot->prepend_size = TLS_HEADER_SIZE + nonce_size;
1075*4882a593Smuzhiyun 	prot->tag_size = tag_size;
1076*4882a593Smuzhiyun 	prot->overhead_size = prot->prepend_size + prot->tag_size;
1077*4882a593Smuzhiyun 	prot->iv_size = iv_size;
1078*4882a593Smuzhiyun 	ctx->tx.iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
1079*4882a593Smuzhiyun 			     GFP_KERNEL);
1080*4882a593Smuzhiyun 	if (!ctx->tx.iv) {
1081*4882a593Smuzhiyun 		rc = -ENOMEM;
1082*4882a593Smuzhiyun 		goto free_offload_ctx;
1083*4882a593Smuzhiyun 	}
1084*4882a593Smuzhiyun 
1085*4882a593Smuzhiyun 	memcpy(ctx->tx.iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
1086*4882a593Smuzhiyun 
1087*4882a593Smuzhiyun 	prot->rec_seq_size = rec_seq_size;
1088*4882a593Smuzhiyun 	ctx->tx.rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL);
1089*4882a593Smuzhiyun 	if (!ctx->tx.rec_seq) {
1090*4882a593Smuzhiyun 		rc = -ENOMEM;
1091*4882a593Smuzhiyun 		goto free_iv;
1092*4882a593Smuzhiyun 	}
1093*4882a593Smuzhiyun 
1094*4882a593Smuzhiyun 	rc = tls_sw_fallback_init(sk, offload_ctx, crypto_info);
1095*4882a593Smuzhiyun 	if (rc)
1096*4882a593Smuzhiyun 		goto free_rec_seq;
1097*4882a593Smuzhiyun 
1098*4882a593Smuzhiyun 	/* start at rec_seq - 1 to account for the start marker record */
1099*4882a593Smuzhiyun 	memcpy(&rcd_sn, ctx->tx.rec_seq, sizeof(rcd_sn));
1100*4882a593Smuzhiyun 	offload_ctx->unacked_record_sn = be64_to_cpu(rcd_sn) - 1;
1101*4882a593Smuzhiyun 
1102*4882a593Smuzhiyun 	start_marker_record->end_seq = tcp_sk(sk)->write_seq;
1103*4882a593Smuzhiyun 	start_marker_record->len = 0;
1104*4882a593Smuzhiyun 	start_marker_record->num_frags = 0;
1105*4882a593Smuzhiyun 
1106*4882a593Smuzhiyun 	INIT_LIST_HEAD(&offload_ctx->records_list);
1107*4882a593Smuzhiyun 	list_add_tail(&start_marker_record->list, &offload_ctx->records_list);
1108*4882a593Smuzhiyun 	spin_lock_init(&offload_ctx->lock);
1109*4882a593Smuzhiyun 	sg_init_table(offload_ctx->sg_tx_data,
1110*4882a593Smuzhiyun 		      ARRAY_SIZE(offload_ctx->sg_tx_data));
1111*4882a593Smuzhiyun 
1112*4882a593Smuzhiyun 	clean_acked_data_enable(inet_csk(sk), &tls_icsk_clean_acked);
1113*4882a593Smuzhiyun 	ctx->push_pending_record = tls_device_push_pending_record;
1114*4882a593Smuzhiyun 
1115*4882a593Smuzhiyun 	/* TLS offload is greatly simplified if we don't send
1116*4882a593Smuzhiyun 	 * SKBs where only part of the payload needs to be encrypted.
1117*4882a593Smuzhiyun 	 * So mark the last skb in the write queue as end of record.
1118*4882a593Smuzhiyun 	 */
1119*4882a593Smuzhiyun 	skb = tcp_write_queue_tail(sk);
1120*4882a593Smuzhiyun 	if (skb)
1121*4882a593Smuzhiyun 		TCP_SKB_CB(skb)->eor = 1;
1122*4882a593Smuzhiyun 
1123*4882a593Smuzhiyun 	netdev = get_netdev_for_sock(sk);
1124*4882a593Smuzhiyun 	if (!netdev) {
1125*4882a593Smuzhiyun 		pr_err_ratelimited("%s: netdev not found\n", __func__);
1126*4882a593Smuzhiyun 		rc = -EINVAL;
1127*4882a593Smuzhiyun 		goto disable_cad;
1128*4882a593Smuzhiyun 	}
1129*4882a593Smuzhiyun 
1130*4882a593Smuzhiyun 	if (!(netdev->features & NETIF_F_HW_TLS_TX)) {
1131*4882a593Smuzhiyun 		rc = -EOPNOTSUPP;
1132*4882a593Smuzhiyun 		goto release_netdev;
1133*4882a593Smuzhiyun 	}
1134*4882a593Smuzhiyun 
1135*4882a593Smuzhiyun 	/* Avoid offloading if the device is down.
1136*4882a593Smuzhiyun 	 * We don't want to offload new flows after
1137*4882a593Smuzhiyun 	 * the NETDEV_DOWN event.
1138*4882a593Smuzhiyun 	 *
1139*4882a593Smuzhiyun 	 * device_offload_lock is taken in tls_device's NETDEV_DOWN
1140*4882a593Smuzhiyun 	 * handler, thus protecting against the device going down before
1141*4882a593Smuzhiyun 	 * ctx has been added to tls_device_list.
1142*4882a593Smuzhiyun 	 */
1143*4882a593Smuzhiyun 	down_read(&device_offload_lock);
1144*4882a593Smuzhiyun 	if (!(netdev->flags & IFF_UP)) {
1145*4882a593Smuzhiyun 		rc = -EINVAL;
1146*4882a593Smuzhiyun 		goto release_lock;
1147*4882a593Smuzhiyun 	}
1148*4882a593Smuzhiyun 
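	/* The driver is told to start offloading at the current write_seq:
	 * the start marker above ends exactly there, so every byte from this
	 * sequence number onwards belongs to a device-offloaded record.
	 */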
1149*4882a593Smuzhiyun 	ctx->priv_ctx_tx = offload_ctx;
1150*4882a593Smuzhiyun 	rc = netdev->tlsdev_ops->tls_dev_add(netdev, sk, TLS_OFFLOAD_CTX_DIR_TX,
1151*4882a593Smuzhiyun 					     &ctx->crypto_send.info,
1152*4882a593Smuzhiyun 					     tcp_sk(sk)->write_seq);
1153*4882a593Smuzhiyun 	trace_tls_device_offload_set(sk, TLS_OFFLOAD_CTX_DIR_TX,
1154*4882a593Smuzhiyun 				     tcp_sk(sk)->write_seq, rec_seq, rc);
1155*4882a593Smuzhiyun 	if (rc)
1156*4882a593Smuzhiyun 		goto release_lock;
1157*4882a593Smuzhiyun 
1158*4882a593Smuzhiyun 	tls_device_attach(ctx, sk, netdev);
1159*4882a593Smuzhiyun 	up_read(&device_offload_lock);
1160*4882a593Smuzhiyun 
1161*4882a593Smuzhiyun 	/* Following this assignment, tls_is_sk_tx_device_offloaded()
1162*4882a593Smuzhiyun 	 * will return true and the context may be accessed
1163*4882a593Smuzhiyun 	 * by the netdev's xmit function.
1164*4882a593Smuzhiyun 	 */
1165*4882a593Smuzhiyun 	smp_store_release(&sk->sk_validate_xmit_skb, tls_validate_xmit_skb);
1166*4882a593Smuzhiyun 	dev_put(netdev);
1167*4882a593Smuzhiyun 
1168*4882a593Smuzhiyun 	return 0;
1169*4882a593Smuzhiyun 
1170*4882a593Smuzhiyun release_lock:
1171*4882a593Smuzhiyun 	up_read(&device_offload_lock);
1172*4882a593Smuzhiyun release_netdev:
1173*4882a593Smuzhiyun 	dev_put(netdev);
1174*4882a593Smuzhiyun disable_cad:
1175*4882a593Smuzhiyun 	clean_acked_data_disable(inet_csk(sk));
1176*4882a593Smuzhiyun 	crypto_free_aead(offload_ctx->aead_send);
1177*4882a593Smuzhiyun free_rec_seq:
1178*4882a593Smuzhiyun 	kfree(ctx->tx.rec_seq);
1179*4882a593Smuzhiyun free_iv:
1180*4882a593Smuzhiyun 	kfree(ctx->tx.iv);
1181*4882a593Smuzhiyun free_offload_ctx:
1182*4882a593Smuzhiyun 	kfree(offload_ctx);
1183*4882a593Smuzhiyun 	ctx->priv_ctx_tx = NULL;
1184*4882a593Smuzhiyun free_marker_record:
1185*4882a593Smuzhiyun 	kfree(start_marker_record);
1186*4882a593Smuzhiyun 	return rc;
1187*4882a593Smuzhiyun }
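
/* For context, a minimal sketch (not part of this file) of the userspace
 * setup that leads into tls_set_device_offload(); error handling is omitted
 * and the key material is assumed to come from a completed TLS handshake:
 *
 *	struct tls12_crypto_info_aes_gcm_128 ci = {
 *		.info.version     = TLS_1_2_VERSION,
 *		.info.cipher_type = TLS_CIPHER_AES_GCM_128,
 *	};
 *	(fill ci.key, ci.iv, ci.salt and ci.rec_seq from the handshake)
 *	setsockopt(fd, SOL_TCP, TCP_ULP, "tls", sizeof("tls"));
 *	setsockopt(fd, SOL_TLS, TLS_TX, &ci, sizeof(ci));
 *
 * When the route's netdev advertises NETIF_F_HW_TLS_TX, the TLS core tries
 * the device offload path first and falls back to the software path on
 * failure.
 */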
1188*4882a593Smuzhiyun 
1189*4882a593Smuzhiyun int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx)
1190*4882a593Smuzhiyun {
1191*4882a593Smuzhiyun 	struct tls12_crypto_info_aes_gcm_128 *info;
1192*4882a593Smuzhiyun 	struct tls_offload_context_rx *context;
1193*4882a593Smuzhiyun 	struct net_device *netdev;
1194*4882a593Smuzhiyun 	int rc = 0;
1195*4882a593Smuzhiyun 
1196*4882a593Smuzhiyun 	if (ctx->crypto_recv.info.version != TLS_1_2_VERSION)
1197*4882a593Smuzhiyun 		return -EOPNOTSUPP;
1198*4882a593Smuzhiyun 
1199*4882a593Smuzhiyun 	netdev = get_netdev_for_sock(sk);
1200*4882a593Smuzhiyun 	if (!netdev) {
1201*4882a593Smuzhiyun 		pr_err_ratelimited("%s: netdev not found\n", __func__);
1202*4882a593Smuzhiyun 		return -EINVAL;
1203*4882a593Smuzhiyun 	}
1204*4882a593Smuzhiyun 
1205*4882a593Smuzhiyun 	if (!(netdev->features & NETIF_F_HW_TLS_RX)) {
1206*4882a593Smuzhiyun 		rc = -EOPNOTSUPP;
1207*4882a593Smuzhiyun 		goto release_netdev;
1208*4882a593Smuzhiyun 	}
1209*4882a593Smuzhiyun 
1210*4882a593Smuzhiyun 	/* Avoid offloading if the device is down.
1211*4882a593Smuzhiyun 	 * We don't want to offload new flows after
1212*4882a593Smuzhiyun 	 * the NETDEV_DOWN event.
1213*4882a593Smuzhiyun 	 *
1214*4882a593Smuzhiyun 	 * device_offload_lock is taken in tls_device's NETDEV_DOWN
1215*4882a593Smuzhiyun 	 * handler, thus protecting against the device going down before
1216*4882a593Smuzhiyun 	 * ctx has been added to tls_device_list.
1217*4882a593Smuzhiyun 	 */
1218*4882a593Smuzhiyun 	down_read(&device_offload_lock);
1219*4882a593Smuzhiyun 	if (!(netdev->flags & IFF_UP)) {
1220*4882a593Smuzhiyun 		rc = -EINVAL;
1221*4882a593Smuzhiyun 		goto release_lock;
1222*4882a593Smuzhiyun 	}
1223*4882a593Smuzhiyun 
1224*4882a593Smuzhiyun 	context = kzalloc(TLS_OFFLOAD_CONTEXT_SIZE_RX, GFP_KERNEL);
1225*4882a593Smuzhiyun 	if (!context) {
1226*4882a593Smuzhiyun 		rc = -ENOMEM;
1227*4882a593Smuzhiyun 		goto release_lock;
1228*4882a593Smuzhiyun 	}
1229*4882a593Smuzhiyun 	context->resync_nh_reset = 1;
1230*4882a593Smuzhiyun 
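	/* Device RX offload still needs the software receive path set up:
	 * the NIC decrypts records in-line and marks them as such, but the
	 * skbs keep flowing through the regular TLS receive code, which
	 * decrypts in software any record the device did not handle.
	 */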
1231*4882a593Smuzhiyun 	ctx->priv_ctx_rx = context;
1232*4882a593Smuzhiyun 	rc = tls_set_sw_offload(sk, ctx, 0);
1233*4882a593Smuzhiyun 	if (rc)
1234*4882a593Smuzhiyun 		goto release_ctx;
1235*4882a593Smuzhiyun 
1236*4882a593Smuzhiyun 	rc = netdev->tlsdev_ops->tls_dev_add(netdev, sk, TLS_OFFLOAD_CTX_DIR_RX,
1237*4882a593Smuzhiyun 					     &ctx->crypto_recv.info,
1238*4882a593Smuzhiyun 					     tcp_sk(sk)->copied_seq);
1239*4882a593Smuzhiyun 	info = (void *)&ctx->crypto_recv.info;
1240*4882a593Smuzhiyun 	trace_tls_device_offload_set(sk, TLS_OFFLOAD_CTX_DIR_RX,
1241*4882a593Smuzhiyun 				     tcp_sk(sk)->copied_seq, info->rec_seq, rc);
1242*4882a593Smuzhiyun 	if (rc)
1243*4882a593Smuzhiyun 		goto free_sw_resources;
1244*4882a593Smuzhiyun 
1245*4882a593Smuzhiyun 	tls_device_attach(ctx, sk, netdev);
1246*4882a593Smuzhiyun 	up_read(&device_offload_lock);
1247*4882a593Smuzhiyun 
1248*4882a593Smuzhiyun 	dev_put(netdev);
1249*4882a593Smuzhiyun 
1250*4882a593Smuzhiyun 	return 0;
1251*4882a593Smuzhiyun 
1252*4882a593Smuzhiyun free_sw_resources:
1253*4882a593Smuzhiyun 	up_read(&device_offload_lock);
1254*4882a593Smuzhiyun 	tls_sw_free_resources_rx(sk);
1255*4882a593Smuzhiyun 	down_read(&device_offload_lock);
1256*4882a593Smuzhiyun release_ctx:
1257*4882a593Smuzhiyun 	ctx->priv_ctx_rx = NULL;
1258*4882a593Smuzhiyun release_lock:
1259*4882a593Smuzhiyun 	up_read(&device_offload_lock);
1260*4882a593Smuzhiyun release_netdev:
1261*4882a593Smuzhiyun 	dev_put(netdev);
1262*4882a593Smuzhiyun 	return rc;
1263*4882a593Smuzhiyun }
1264*4882a593Smuzhiyun 
1265*4882a593Smuzhiyun void tls_device_offload_cleanup_rx(struct sock *sk)
1266*4882a593Smuzhiyun {
1267*4882a593Smuzhiyun 	struct tls_context *tls_ctx = tls_get_ctx(sk);
1268*4882a593Smuzhiyun 	struct net_device *netdev;
1269*4882a593Smuzhiyun 
1270*4882a593Smuzhiyun 	down_read(&device_offload_lock);
1271*4882a593Smuzhiyun 	netdev = tls_ctx->netdev;
1272*4882a593Smuzhiyun 	if (!netdev)
1273*4882a593Smuzhiyun 		goto out;
1274*4882a593Smuzhiyun 
1275*4882a593Smuzhiyun 	netdev->tlsdev_ops->tls_dev_del(netdev, tls_ctx,
1276*4882a593Smuzhiyun 					TLS_OFFLOAD_CTX_DIR_RX);
1277*4882a593Smuzhiyun 
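	/* If TX offload is still active, the netdev reference is shared with
	 * the TX half and will be released by the TX teardown path; only
	 * record via TLS_RX_DEV_CLOSED that the RX half has already been
	 * deleted, so tls_device_down() skips a second tls_dev_del(RX).
	 */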
1278*4882a593Smuzhiyun 	if (tls_ctx->tx_conf != TLS_HW) {
1279*4882a593Smuzhiyun 		dev_put(netdev);
1280*4882a593Smuzhiyun 		tls_ctx->netdev = NULL;
1281*4882a593Smuzhiyun 	} else {
1282*4882a593Smuzhiyun 		set_bit(TLS_RX_DEV_CLOSED, &tls_ctx->flags);
1283*4882a593Smuzhiyun 	}
1284*4882a593Smuzhiyun out:
1285*4882a593Smuzhiyun 	up_read(&device_offload_lock);
1286*4882a593Smuzhiyun 	tls_sw_release_resources_rx(sk);
1287*4882a593Smuzhiyun }
1288*4882a593Smuzhiyun 
1289*4882a593Smuzhiyun static int tls_device_down(struct net_device *netdev)
1290*4882a593Smuzhiyun {
1291*4882a593Smuzhiyun 	struct tls_context *ctx, *tmp;
1292*4882a593Smuzhiyun 	unsigned long flags;
1293*4882a593Smuzhiyun 	LIST_HEAD(list);
1294*4882a593Smuzhiyun 
1295*4882a593Smuzhiyun 	/* Request a write lock to block new offload attempts */
1296*4882a593Smuzhiyun 	down_write(&device_offload_lock);
1297*4882a593Smuzhiyun 
1298*4882a593Smuzhiyun 	spin_lock_irqsave(&tls_device_lock, flags);
1299*4882a593Smuzhiyun 	list_for_each_entry_safe(ctx, tmp, &tls_device_list, list) {
1300*4882a593Smuzhiyun 		if (ctx->netdev != netdev ||
1301*4882a593Smuzhiyun 		    !refcount_inc_not_zero(&ctx->refcount))
1302*4882a593Smuzhiyun 			continue;
1303*4882a593Smuzhiyun 
1304*4882a593Smuzhiyun 		list_move(&ctx->list, &list);
1305*4882a593Smuzhiyun 	}
1306*4882a593Smuzhiyun 	spin_unlock_irqrestore(&tls_device_lock, flags);
1307*4882a593Smuzhiyun 
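	/* Tear the harvested contexts down outside tls_device_lock: the work
	 * below sleeps (synchronize_net() and the driver callbacks), which is
	 * not allowed under a spinlock held with IRQs disabled.
	 */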
1308*4882a593Smuzhiyun 	list_for_each_entry_safe(ctx, tmp, &list, list) {
1309*4882a593Smuzhiyun 		/* Stop offloaded TX and switch to the fallback.
1310*4882a593Smuzhiyun 		 * tls_is_sk_tx_device_offloaded will return false.
1311*4882a593Smuzhiyun 		 */
1312*4882a593Smuzhiyun 		WRITE_ONCE(ctx->sk->sk_validate_xmit_skb, tls_validate_xmit_skb_sw);
1313*4882a593Smuzhiyun 
1314*4882a593Smuzhiyun 		/* Stop the RX and TX resync.
1315*4882a593Smuzhiyun 		 * tls_dev_resync must not be called after tls_dev_del.
1316*4882a593Smuzhiyun 		 */
1317*4882a593Smuzhiyun 		WRITE_ONCE(ctx->netdev, NULL);
1318*4882a593Smuzhiyun 
1319*4882a593Smuzhiyun 		/* Start skipping the RX resync logic completely. */
1320*4882a593Smuzhiyun 		set_bit(TLS_RX_DEV_DEGRADED, &ctx->flags);
1321*4882a593Smuzhiyun 
1322*4882a593Smuzhiyun 		/* Sync with inflight packets. After this point:
1323*4882a593Smuzhiyun 		 * TX: no non-encrypted packets will be passed to the driver.
1324*4882a593Smuzhiyun 		 * RX: resync requests from the driver will be ignored.
1325*4882a593Smuzhiyun 		 */
1326*4882a593Smuzhiyun 		synchronize_net();
1327*4882a593Smuzhiyun 
1328*4882a593Smuzhiyun 		/* Release the offload context on the driver side. */
1329*4882a593Smuzhiyun 		if (ctx->tx_conf == TLS_HW)
1330*4882a593Smuzhiyun 			netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
1331*4882a593Smuzhiyun 							TLS_OFFLOAD_CTX_DIR_TX);
1332*4882a593Smuzhiyun 		if (ctx->rx_conf == TLS_HW &&
1333*4882a593Smuzhiyun 		    !test_bit(TLS_RX_DEV_CLOSED, &ctx->flags))
1334*4882a593Smuzhiyun 			netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
1335*4882a593Smuzhiyun 							TLS_OFFLOAD_CTX_DIR_RX);
1336*4882a593Smuzhiyun 
1337*4882a593Smuzhiyun 		dev_put(netdev);
1338*4882a593Smuzhiyun 
1339*4882a593Smuzhiyun 		/* Move the context to a separate list for two reasons:
1340*4882a593Smuzhiyun 		 * 1. When the context is deallocated, list_del is called.
1341*4882a593Smuzhiyun 		 * 2. It's no longer an offloaded context, so we don't want to
1342*4882a593Smuzhiyun 		 *    run offload-specific code on this context.
1343*4882a593Smuzhiyun 		 */
1344*4882a593Smuzhiyun 		spin_lock_irqsave(&tls_device_lock, flags);
1345*4882a593Smuzhiyun 		list_move_tail(&ctx->list, &tls_device_down_list);
1346*4882a593Smuzhiyun 		spin_unlock_irqrestore(&tls_device_lock, flags);
1347*4882a593Smuzhiyun 
1348*4882a593Smuzhiyun 		/* Device contexts for RX and TX will be freed on sk_destruct
1349*4882a593Smuzhiyun 		 * by tls_device_free_ctx. rx_conf and tx_conf stay in TLS_HW.
1350*4882a593Smuzhiyun 		 * Now release the ref taken above.
1351*4882a593Smuzhiyun 		 */
1352*4882a593Smuzhiyun 		if (refcount_dec_and_test(&ctx->refcount)) {
1353*4882a593Smuzhiyun 			/* sk_destruct ran after tls_device_down took a ref, and
1354*4882a593Smuzhiyun 			 * it returned early. Complete the destruction here.
1355*4882a593Smuzhiyun 			 */
1356*4882a593Smuzhiyun 			list_del(&ctx->list);
1357*4882a593Smuzhiyun 			tls_device_free_ctx(ctx);
1358*4882a593Smuzhiyun 		}
1359*4882a593Smuzhiyun 	}
1360*4882a593Smuzhiyun 
1361*4882a593Smuzhiyun 	up_write(&device_offload_lock);
1362*4882a593Smuzhiyun 
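	/* Make sure any teardown deferred to tls_device_gc_work that may
	 * still reference this netdev has completed before the NETDEV_DOWN
	 * notifier returns.
	 */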
1363*4882a593Smuzhiyun 	flush_work(&tls_device_gc_work);
1364*4882a593Smuzhiyun 
1365*4882a593Smuzhiyun 	return NOTIFY_DONE;
1366*4882a593Smuzhiyun }
1367*4882a593Smuzhiyun 
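/* For reference, the driver callbacks checked by the notifier below live in
 * struct tlsdev_ops (see include/net/tls.h for the authoritative definition);
 * roughly:
 *
 *	struct tlsdev_ops {
 *		int  (*tls_dev_add)(struct net_device *netdev, struct sock *sk,
 *				    enum tls_offload_ctx_dir direction,
 *				    struct tls_crypto_info *crypto_info,
 *				    u32 start_offload_tcp_sn);
 *		void (*tls_dev_del)(struct net_device *netdev,
 *				    struct tls_context *tls_ctx,
 *				    enum tls_offload_ctx_dir direction);
 *		int  (*tls_dev_resync)(struct net_device *netdev,
 *				       struct sock *sk, u32 seq, u8 *rcd_sn,
 *				       enum tls_offload_ctx_dir direction);
 *	};
 *
 * A driver advertising NETIF_F_HW_TLS_RX must also implement tls_dev_resync,
 * which is what the NETDEV_REGISTER / NETDEV_FEAT_CHANGE check enforces.
 */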
1368*4882a593Smuzhiyun static int tls_dev_event(struct notifier_block *this, unsigned long event,
1369*4882a593Smuzhiyun 			 void *ptr)
1370*4882a593Smuzhiyun {
1371*4882a593Smuzhiyun 	struct net_device *dev = netdev_notifier_info_to_dev(ptr);
1372*4882a593Smuzhiyun 
1373*4882a593Smuzhiyun 	if (!dev->tlsdev_ops &&
1374*4882a593Smuzhiyun 	    !(dev->features & (NETIF_F_HW_TLS_RX | NETIF_F_HW_TLS_TX)))
1375*4882a593Smuzhiyun 		return NOTIFY_DONE;
1376*4882a593Smuzhiyun 
1377*4882a593Smuzhiyun 	switch (event) {
1378*4882a593Smuzhiyun 	case NETDEV_REGISTER:
1379*4882a593Smuzhiyun 	case NETDEV_FEAT_CHANGE:
1380*4882a593Smuzhiyun 		if ((dev->features & NETIF_F_HW_TLS_RX) &&
1381*4882a593Smuzhiyun 		    !dev->tlsdev_ops->tls_dev_resync)
1382*4882a593Smuzhiyun 			return NOTIFY_BAD;
1383*4882a593Smuzhiyun 
1384*4882a593Smuzhiyun 		if (dev->tlsdev_ops &&
1385*4882a593Smuzhiyun 		     dev->tlsdev_ops->tls_dev_add &&
1386*4882a593Smuzhiyun 		     dev->tlsdev_ops->tls_dev_del)
1387*4882a593Smuzhiyun 			return NOTIFY_DONE;
1388*4882a593Smuzhiyun 		else
1389*4882a593Smuzhiyun 			return NOTIFY_BAD;
1390*4882a593Smuzhiyun 	case NETDEV_DOWN:
1391*4882a593Smuzhiyun 		return tls_device_down(dev);
1392*4882a593Smuzhiyun 	}
1393*4882a593Smuzhiyun 	return NOTIFY_DONE;
1394*4882a593Smuzhiyun }
1395*4882a593Smuzhiyun 
1396*4882a593Smuzhiyun static struct notifier_block tls_dev_notifier = {
1397*4882a593Smuzhiyun 	.notifier_call	= tls_dev_event,
1398*4882a593Smuzhiyun };
1399*4882a593Smuzhiyun 
1400*4882a593Smuzhiyun int __init tls_device_init(void)
1401*4882a593Smuzhiyun {
1402*4882a593Smuzhiyun 	return register_netdevice_notifier(&tls_dev_notifier);
1403*4882a593Smuzhiyun }
1404*4882a593Smuzhiyun 
1405*4882a593Smuzhiyun void __exit tls_device_cleanup(void)
1406*4882a593Smuzhiyun {
1407*4882a593Smuzhiyun 	unregister_netdevice_notifier(&tls_dev_notifier);
1408*4882a593Smuzhiyun 	flush_work(&tls_device_gc_work);
1409*4882a593Smuzhiyun 	clean_acked_data_flush();
1410*4882a593Smuzhiyun }
1411