xref: /OK3568_Linux_fs/kernel/net/tls/tls_main.c (revision 4882a59341e53eb6f0b4789bf948001014eff981)
1*4882a593Smuzhiyun /*
2*4882a593Smuzhiyun  * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved.
3*4882a593Smuzhiyun  * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved.
4*4882a593Smuzhiyun  *
5*4882a593Smuzhiyun  * This software is available to you under a choice of one of two
6*4882a593Smuzhiyun  * licenses.  You may choose to be licensed under the terms of the GNU
7*4882a593Smuzhiyun  * General Public License (GPL) Version 2, available from the file
8*4882a593Smuzhiyun  * COPYING in the main directory of this source tree, or the
9*4882a593Smuzhiyun  * OpenIB.org BSD license below:
10*4882a593Smuzhiyun  *
11*4882a593Smuzhiyun  *     Redistribution and use in source and binary forms, with or
12*4882a593Smuzhiyun  *     without modification, are permitted provided that the following
13*4882a593Smuzhiyun  *     conditions are met:
14*4882a593Smuzhiyun  *
15*4882a593Smuzhiyun  *      - Redistributions of source code must retain the above
16*4882a593Smuzhiyun  *        copyright notice, this list of conditions and the following
17*4882a593Smuzhiyun  *        disclaimer.
18*4882a593Smuzhiyun  *
19*4882a593Smuzhiyun  *      - Redistributions in binary form must reproduce the above
20*4882a593Smuzhiyun  *        copyright notice, this list of conditions and the following
21*4882a593Smuzhiyun  *        disclaimer in the documentation and/or other materials
22*4882a593Smuzhiyun  *        provided with the distribution.
23*4882a593Smuzhiyun  *
24*4882a593Smuzhiyun  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
25*4882a593Smuzhiyun  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
26*4882a593Smuzhiyun  * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
27*4882a593Smuzhiyun  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
28*4882a593Smuzhiyun  * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
29*4882a593Smuzhiyun  * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
30*4882a593Smuzhiyun  * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
31*4882a593Smuzhiyun  * SOFTWARE.
32*4882a593Smuzhiyun  */
33*4882a593Smuzhiyun 
34*4882a593Smuzhiyun #include <linux/module.h>
35*4882a593Smuzhiyun 
36*4882a593Smuzhiyun #include <net/tcp.h>
37*4882a593Smuzhiyun #include <net/inet_common.h>
38*4882a593Smuzhiyun #include <linux/highmem.h>
39*4882a593Smuzhiyun #include <linux/netdevice.h>
40*4882a593Smuzhiyun #include <linux/sched/signal.h>
41*4882a593Smuzhiyun #include <linux/inetdevice.h>
42*4882a593Smuzhiyun #include <linux/inet_diag.h>
43*4882a593Smuzhiyun 
44*4882a593Smuzhiyun #include <net/snmp.h>
45*4882a593Smuzhiyun #include <net/tls.h>
46*4882a593Smuzhiyun #include <net/tls_toe.h>
47*4882a593Smuzhiyun 
MODULE_AUTHOR("Mellanox Technologies");
MODULE_DESCRIPTION("Transport Layer Security Support");
MODULE_LICENSE("Dual BSD/GPL");
MODULE_ALIAS_TCP_ULP("tls");

/*
 * First index of tls_prots[] / tls_proto_ops[]: the address family of the
 * socket (update_sk_prot() picks TLSV6 for AF_INET6, TLSV4 otherwise).
 */
enum {
	TLSV4,
	TLSV6,
	TLS_NUM_PROTS,
};

/*
 * Base TCP proto each per-family table was last built from, and the mutex
 * serialising rebuilds of that family's tables (see tls_build_proto()).
 */
static const struct proto *saved_tcpv6_prot;
static DEFINE_MUTEX(tcpv6_prot_mutex);
static const struct proto *saved_tcpv4_prot;
static DEFINE_MUTEX(tcpv4_prot_mutex);
/* Indexed [ip_ver][tx_conf][rx_conf], as used by update_sk_prot(). */
static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
static struct proto_ops tls_proto_ops[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
			 const struct proto *base);
67*4882a593Smuzhiyun 
update_sk_prot(struct sock * sk,struct tls_context * ctx)68*4882a593Smuzhiyun void update_sk_prot(struct sock *sk, struct tls_context *ctx)
69*4882a593Smuzhiyun {
70*4882a593Smuzhiyun 	int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;
71*4882a593Smuzhiyun 
72*4882a593Smuzhiyun 	WRITE_ONCE(sk->sk_prot,
73*4882a593Smuzhiyun 		   &tls_prots[ip_ver][ctx->tx_conf][ctx->rx_conf]);
74*4882a593Smuzhiyun 	WRITE_ONCE(sk->sk_socket->ops,
75*4882a593Smuzhiyun 		   &tls_proto_ops[ip_ver][ctx->tx_conf][ctx->rx_conf]);
76*4882a593Smuzhiyun }
77*4882a593Smuzhiyun 
wait_on_pending_writer(struct sock * sk,long * timeo)78*4882a593Smuzhiyun int wait_on_pending_writer(struct sock *sk, long *timeo)
79*4882a593Smuzhiyun {
80*4882a593Smuzhiyun 	int rc = 0;
81*4882a593Smuzhiyun 	DEFINE_WAIT_FUNC(wait, woken_wake_function);
82*4882a593Smuzhiyun 
83*4882a593Smuzhiyun 	add_wait_queue(sk_sleep(sk), &wait);
84*4882a593Smuzhiyun 	while (1) {
85*4882a593Smuzhiyun 		if (!*timeo) {
86*4882a593Smuzhiyun 			rc = -EAGAIN;
87*4882a593Smuzhiyun 			break;
88*4882a593Smuzhiyun 		}
89*4882a593Smuzhiyun 
90*4882a593Smuzhiyun 		if (signal_pending(current)) {
91*4882a593Smuzhiyun 			rc = sock_intr_errno(*timeo);
92*4882a593Smuzhiyun 			break;
93*4882a593Smuzhiyun 		}
94*4882a593Smuzhiyun 
95*4882a593Smuzhiyun 		if (sk_wait_event(sk, timeo, !sk->sk_write_pending, &wait))
96*4882a593Smuzhiyun 			break;
97*4882a593Smuzhiyun 	}
98*4882a593Smuzhiyun 	remove_wait_queue(sk_sleep(sk), &wait);
99*4882a593Smuzhiyun 	return rc;
100*4882a593Smuzhiyun }
101*4882a593Smuzhiyun 
/*
 * tls_push_sg() - hand a scatterlist chain of record data to TCP
 * @sk:           socket to send on
 * @ctx:          TLS context; tracks the partially sent record
 * @sg:           first scatterlist entry to send
 * @first_offset: bytes of @sg that were already sent and must be skipped
 * @flags:        MSG_* flags for do_tcp_sendpages()
 *
 * Every entry except the last is sent with MSG_SENDPAGE_NOTLAST added.
 * Each entry that is sent completely has its page reference dropped and
 * its length uncharged from @sk.
 *
 * If do_tcp_sendpages() returns a non-positive value, the current entry
 * and the offset reached within it (relative to sg->offset) are saved in
 * ctx->partially_sent_record / ctx->partially_sent_offset, and that value
 * is returned. Otherwise returns 0 once the whole chain has been sent.
 */
int tls_push_sg(struct sock *sk,
		struct tls_context *ctx,
		struct scatterlist *sg,
		u16 first_offset,
		int flags)
{
	int sendpage_flags = flags | MSG_SENDPAGE_NOTLAST;
	int ret = 0;
	struct page *p;
	size_t size;
	int offset = first_offset;

	size = sg->length - offset;
	offset += sg->offset;

	/* Makes tls_write_space() forward to the lower-layer handler only. */
	ctx->in_tcp_sendpages = true;
	while (1) {
		if (sg_is_last(sg))
			sendpage_flags = flags;

		/* is sending application-limited? */
		tcp_rate_check_app_limited(sk);
		p = sg_page(sg);
retry:
		ret = do_tcp_sendpages(sk, p, offset, size, sendpage_flags);

		if (ret != size) {
			/* Short write: advance within this entry and try again. */
			if (ret > 0) {
				offset += ret;
				size -= ret;
				goto retry;
			}

			/* Stalled/failed: remember where to resume (entry-relative). */
			offset -= sg->offset;
			ctx->partially_sent_offset = offset;
			ctx->partially_sent_record = (void *)sg;
			ctx->in_tcp_sendpages = false;
			return ret;
		}

		/* Entry fully sent: release it and move to the next one. */
		put_page(p);
		sk_mem_uncharge(sk, sg->length);
		sg = sg_next(sg);
		if (!sg)
			break;

		offset = sg->offset;
		size = sg->length;
	}

	ctx->in_tcp_sendpages = false;

	return 0;
}
156*4882a593Smuzhiyun 
tls_handle_open_record(struct sock * sk,int flags)157*4882a593Smuzhiyun static int tls_handle_open_record(struct sock *sk, int flags)
158*4882a593Smuzhiyun {
159*4882a593Smuzhiyun 	struct tls_context *ctx = tls_get_ctx(sk);
160*4882a593Smuzhiyun 
161*4882a593Smuzhiyun 	if (tls_is_pending_open_record(ctx))
162*4882a593Smuzhiyun 		return ctx->push_pending_record(sk, flags);
163*4882a593Smuzhiyun 
164*4882a593Smuzhiyun 	return 0;
165*4882a593Smuzhiyun }
166*4882a593Smuzhiyun 
tls_proccess_cmsg(struct sock * sk,struct msghdr * msg,unsigned char * record_type)167*4882a593Smuzhiyun int tls_proccess_cmsg(struct sock *sk, struct msghdr *msg,
168*4882a593Smuzhiyun 		      unsigned char *record_type)
169*4882a593Smuzhiyun {
170*4882a593Smuzhiyun 	struct cmsghdr *cmsg;
171*4882a593Smuzhiyun 	int rc = -EINVAL;
172*4882a593Smuzhiyun 
173*4882a593Smuzhiyun 	for_each_cmsghdr(cmsg, msg) {
174*4882a593Smuzhiyun 		if (!CMSG_OK(msg, cmsg))
175*4882a593Smuzhiyun 			return -EINVAL;
176*4882a593Smuzhiyun 		if (cmsg->cmsg_level != SOL_TLS)
177*4882a593Smuzhiyun 			continue;
178*4882a593Smuzhiyun 
179*4882a593Smuzhiyun 		switch (cmsg->cmsg_type) {
180*4882a593Smuzhiyun 		case TLS_SET_RECORD_TYPE:
181*4882a593Smuzhiyun 			if (cmsg->cmsg_len < CMSG_LEN(sizeof(*record_type)))
182*4882a593Smuzhiyun 				return -EINVAL;
183*4882a593Smuzhiyun 
184*4882a593Smuzhiyun 			if (msg->msg_flags & MSG_MORE)
185*4882a593Smuzhiyun 				return -EINVAL;
186*4882a593Smuzhiyun 
187*4882a593Smuzhiyun 			rc = tls_handle_open_record(sk, msg->msg_flags);
188*4882a593Smuzhiyun 			if (rc)
189*4882a593Smuzhiyun 				return rc;
190*4882a593Smuzhiyun 
191*4882a593Smuzhiyun 			*record_type = *(unsigned char *)CMSG_DATA(cmsg);
192*4882a593Smuzhiyun 			rc = 0;
193*4882a593Smuzhiyun 			break;
194*4882a593Smuzhiyun 		default:
195*4882a593Smuzhiyun 			return -EINVAL;
196*4882a593Smuzhiyun 		}
197*4882a593Smuzhiyun 	}
198*4882a593Smuzhiyun 
199*4882a593Smuzhiyun 	return rc;
200*4882a593Smuzhiyun }
201*4882a593Smuzhiyun 
tls_push_partial_record(struct sock * sk,struct tls_context * ctx,int flags)202*4882a593Smuzhiyun int tls_push_partial_record(struct sock *sk, struct tls_context *ctx,
203*4882a593Smuzhiyun 			    int flags)
204*4882a593Smuzhiyun {
205*4882a593Smuzhiyun 	struct scatterlist *sg;
206*4882a593Smuzhiyun 	u16 offset;
207*4882a593Smuzhiyun 
208*4882a593Smuzhiyun 	sg = ctx->partially_sent_record;
209*4882a593Smuzhiyun 	offset = ctx->partially_sent_offset;
210*4882a593Smuzhiyun 
211*4882a593Smuzhiyun 	ctx->partially_sent_record = NULL;
212*4882a593Smuzhiyun 	return tls_push_sg(sk, ctx, sg, offset, flags);
213*4882a593Smuzhiyun }
214*4882a593Smuzhiyun 
tls_free_partial_record(struct sock * sk,struct tls_context * ctx)215*4882a593Smuzhiyun void tls_free_partial_record(struct sock *sk, struct tls_context *ctx)
216*4882a593Smuzhiyun {
217*4882a593Smuzhiyun 	struct scatterlist *sg;
218*4882a593Smuzhiyun 
219*4882a593Smuzhiyun 	for (sg = ctx->partially_sent_record; sg; sg = sg_next(sg)) {
220*4882a593Smuzhiyun 		put_page(sg_page(sg));
221*4882a593Smuzhiyun 		sk_mem_uncharge(sk, sg->length);
222*4882a593Smuzhiyun 	}
223*4882a593Smuzhiyun 	ctx->partially_sent_record = NULL;
224*4882a593Smuzhiyun }
225*4882a593Smuzhiyun 
tls_write_space(struct sock * sk)226*4882a593Smuzhiyun static void tls_write_space(struct sock *sk)
227*4882a593Smuzhiyun {
228*4882a593Smuzhiyun 	struct tls_context *ctx = tls_get_ctx(sk);
229*4882a593Smuzhiyun 
230*4882a593Smuzhiyun 	/* If in_tcp_sendpages call lower protocol write space handler
231*4882a593Smuzhiyun 	 * to ensure we wake up any waiting operations there. For example
232*4882a593Smuzhiyun 	 * if do_tcp_sendpages where to call sk_wait_event.
233*4882a593Smuzhiyun 	 */
234*4882a593Smuzhiyun 	if (ctx->in_tcp_sendpages) {
235*4882a593Smuzhiyun 		ctx->sk_write_space(sk);
236*4882a593Smuzhiyun 		return;
237*4882a593Smuzhiyun 	}
238*4882a593Smuzhiyun 
239*4882a593Smuzhiyun #ifdef CONFIG_TLS_DEVICE
240*4882a593Smuzhiyun 	if (ctx->tx_conf == TLS_HW)
241*4882a593Smuzhiyun 		tls_device_write_space(sk, ctx);
242*4882a593Smuzhiyun 	else
243*4882a593Smuzhiyun #endif
244*4882a593Smuzhiyun 		tls_sw_write_space(sk, ctx);
245*4882a593Smuzhiyun 
246*4882a593Smuzhiyun 	ctx->sk_write_space(sk);
247*4882a593Smuzhiyun }
248*4882a593Smuzhiyun 
249*4882a593Smuzhiyun /**
250*4882a593Smuzhiyun  * tls_ctx_free() - free TLS ULP context
251*4882a593Smuzhiyun  * @sk:  socket to with @ctx is attached
252*4882a593Smuzhiyun  * @ctx: TLS context structure
253*4882a593Smuzhiyun  *
254*4882a593Smuzhiyun  * Free TLS context. If @sk is %NULL caller guarantees that the socket
255*4882a593Smuzhiyun  * to which @ctx was attached has no outstanding references.
256*4882a593Smuzhiyun  */
tls_ctx_free(struct sock * sk,struct tls_context * ctx)257*4882a593Smuzhiyun void tls_ctx_free(struct sock *sk, struct tls_context *ctx)
258*4882a593Smuzhiyun {
259*4882a593Smuzhiyun 	if (!ctx)
260*4882a593Smuzhiyun 		return;
261*4882a593Smuzhiyun 
262*4882a593Smuzhiyun 	memzero_explicit(&ctx->crypto_send, sizeof(ctx->crypto_send));
263*4882a593Smuzhiyun 	memzero_explicit(&ctx->crypto_recv, sizeof(ctx->crypto_recv));
264*4882a593Smuzhiyun 	mutex_destroy(&ctx->tx_lock);
265*4882a593Smuzhiyun 
266*4882a593Smuzhiyun 	if (sk)
267*4882a593Smuzhiyun 		kfree_rcu(ctx, rcu);
268*4882a593Smuzhiyun 	else
269*4882a593Smuzhiyun 		kfree(ctx);
270*4882a593Smuzhiyun }
271*4882a593Smuzhiyun 
tls_sk_proto_cleanup(struct sock * sk,struct tls_context * ctx,long timeo)272*4882a593Smuzhiyun static void tls_sk_proto_cleanup(struct sock *sk,
273*4882a593Smuzhiyun 				 struct tls_context *ctx, long timeo)
274*4882a593Smuzhiyun {
275*4882a593Smuzhiyun 	if (unlikely(sk->sk_write_pending) &&
276*4882a593Smuzhiyun 	    !wait_on_pending_writer(sk, &timeo))
277*4882a593Smuzhiyun 		tls_handle_open_record(sk, 0);
278*4882a593Smuzhiyun 
279*4882a593Smuzhiyun 	/* We need these for tls_sw_fallback handling of other packets */
280*4882a593Smuzhiyun 	if (ctx->tx_conf == TLS_SW) {
281*4882a593Smuzhiyun 		kfree(ctx->tx.rec_seq);
282*4882a593Smuzhiyun 		kfree(ctx->tx.iv);
283*4882a593Smuzhiyun 		tls_sw_release_resources_tx(sk);
284*4882a593Smuzhiyun 		TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW);
285*4882a593Smuzhiyun 	} else if (ctx->tx_conf == TLS_HW) {
286*4882a593Smuzhiyun 		tls_device_free_resources_tx(sk);
287*4882a593Smuzhiyun 		TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE);
288*4882a593Smuzhiyun 	}
289*4882a593Smuzhiyun 
290*4882a593Smuzhiyun 	if (ctx->rx_conf == TLS_SW) {
291*4882a593Smuzhiyun 		tls_sw_release_resources_rx(sk);
292*4882a593Smuzhiyun 		TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW);
293*4882a593Smuzhiyun 	} else if (ctx->rx_conf == TLS_HW) {
294*4882a593Smuzhiyun 		tls_device_offload_cleanup_rx(sk);
295*4882a593Smuzhiyun 		TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE);
296*4882a593Smuzhiyun 	}
297*4882a593Smuzhiyun }
298*4882a593Smuzhiyun 
/*
 * tls_sk_proto_close() - proto->close handler for TLS sockets
 * @sk:      socket being closed
 * @timeout: passed unchanged to the underlying protocol's close()
 *
 * Tears down the TLS state and restores the original proto and write-space
 * callback, then calls the underlying protocol's close(). The context is
 * freed here only when neither direction uses TLS_HW.
 * NOTE(review): for HW offload the context is presumably released by the
 * device code; confirm against tls_device.c.
 */
static void tls_sk_proto_close(struct sock *sk, long timeout)
{
	struct inet_connection_sock *icsk = inet_csk(sk);
	struct tls_context *ctx = tls_get_ctx(sk);
	/* Wait up to the socket's send timeout for pending writers. */
	long timeo = sock_sndtimeo(sk, 0);
	bool free_ctx;

	/* Cancel deferred SW TX work before taking the socket lock. */
	if (ctx->tx_conf == TLS_SW)
		tls_sw_cancel_work_tx(ctx);

	lock_sock(sk);
	free_ctx = ctx->tx_conf != TLS_HW && ctx->rx_conf != TLS_HW;

	if (ctx->tx_conf != TLS_BASE || ctx->rx_conf != TLS_BASE)
		tls_sk_proto_cleanup(sk, ctx, timeo);

	/* Detach TLS (ULP data, proto, write_space) under sk_callback_lock. */
	write_lock_bh(&sk->sk_callback_lock);
	if (free_ctx)
		rcu_assign_pointer(icsk->icsk_ulp_data, NULL);
	WRITE_ONCE(sk->sk_prot, ctx->sk_proto);
	if (sk->sk_write_space == tls_write_space)
		sk->sk_write_space = ctx->sk_write_space;
	write_unlock_bh(&sk->sk_callback_lock);
	release_sock(sk);
	/* Release remaining SW state and stop the RX strparser (SW or HW). */
	if (ctx->tx_conf == TLS_SW)
		tls_sw_free_ctx_tx(ctx);
	if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW)
		tls_sw_strparser_done(ctx);
	if (ctx->rx_conf == TLS_SW)
		tls_sw_free_ctx_rx(ctx);
	ctx->sk_proto->close(sk, timeout);

	/* sk is non-NULL, so tls_ctx_free() defers the free via kfree_rcu(). */
	if (free_ctx)
		tls_ctx_free(sk, ctx);
}
334*4882a593Smuzhiyun 
/*
 * do_tls_getsockopt_conf() - report the crypto state of one direction
 * @sk:     TLS socket
 * @optval: user buffer that receives the crypto info
 * @optlen: user pointer holding the size of @optval
 * @tx:     non-zero for TLS_TX (send side), zero for TLS_RX (receive side)
 *
 * If the buffer is exactly sizeof(struct tls_crypto_info), only the generic
 * header (version and cipher type) is copied out. Otherwise the length must
 * match the cipher-specific structure exactly. The current IV and record
 * sequence number are then copied from the cipher context into the stored
 * crypto info under the socket lock, and the whole structure is copied to
 * user space.
 *
 * Every cipher accepted by do_tls_setsockopt_conf() (AES-GCM-128,
 * AES-GCM-256, AES-CCM-128) can be read back.
 *
 * Return: 0 on success, -EFAULT on a bad user pointer, -EINVAL for a NULL
 * buffer, a bad (including negative) length or an unsupported cipher,
 * -EBUSY when no context or no key is configured for this direction.
 */
static int do_tls_getsockopt_conf(struct sock *sk, char __user *optval,
				  int __user *optlen, int tx)
{
	int rc = 0;
	struct tls_context *ctx = tls_get_ctx(sk);
	struct tls_crypto_info *crypto_info;
	struct cipher_context *cctx;
	int len;

	if (get_user(len, optlen))
		return -EFAULT;

	/*
	 * Reject a negative len explicitly instead of relying on its implicit
	 * conversion to a huge size_t in the comparison that follows.
	 */
	if (!optval || len < 0 || (len < sizeof(*crypto_info))) {
		rc = -EINVAL;
		goto out;
	}

	if (!ctx) {
		rc = -EBUSY;
		goto out;
	}

	/* get user crypto info */
	if (tx) {
		crypto_info = &ctx->crypto_send.info;
		cctx = &ctx->tx;
	} else {
		crypto_info = &ctx->crypto_recv.info;
		cctx = &ctx->rx;
	}

	if (!TLS_CRYPTO_INFO_READY(crypto_info)) {
		rc = -EBUSY;
		goto out;
	}

	/* A header-sized buffer asks only for version and cipher type. */
	if (len == sizeof(*crypto_info)) {
		if (copy_to_user(optval, crypto_info, sizeof(*crypto_info)))
			rc = -EFAULT;
		goto out;
	}

	/* cctx->iv holds the salt followed by the IV; skip the salt. */
	switch (crypto_info->cipher_type) {
	case TLS_CIPHER_AES_GCM_128: {
		struct tls12_crypto_info_aes_gcm_128 *
		  crypto_info_aes_gcm_128 =
		  container_of(crypto_info,
			       struct tls12_crypto_info_aes_gcm_128,
			       info);

		if (len != sizeof(*crypto_info_aes_gcm_128)) {
			rc = -EINVAL;
			goto out;
		}
		lock_sock(sk);
		memcpy(crypto_info_aes_gcm_128->iv,
		       cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
		       TLS_CIPHER_AES_GCM_128_IV_SIZE);
		memcpy(crypto_info_aes_gcm_128->rec_seq, cctx->rec_seq,
		       TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE);
		release_sock(sk);
		if (copy_to_user(optval,
				 crypto_info_aes_gcm_128,
				 sizeof(*crypto_info_aes_gcm_128)))
			rc = -EFAULT;
		break;
	}
	case TLS_CIPHER_AES_GCM_256: {
		struct tls12_crypto_info_aes_gcm_256 *
		  crypto_info_aes_gcm_256 =
		  container_of(crypto_info,
			       struct tls12_crypto_info_aes_gcm_256,
			       info);

		if (len != sizeof(*crypto_info_aes_gcm_256)) {
			rc = -EINVAL;
			goto out;
		}
		lock_sock(sk);
		memcpy(crypto_info_aes_gcm_256->iv,
		       cctx->iv + TLS_CIPHER_AES_GCM_256_SALT_SIZE,
		       TLS_CIPHER_AES_GCM_256_IV_SIZE);
		memcpy(crypto_info_aes_gcm_256->rec_seq, cctx->rec_seq,
		       TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE);
		release_sock(sk);
		if (copy_to_user(optval,
				 crypto_info_aes_gcm_256,
				 sizeof(*crypto_info_aes_gcm_256)))
			rc = -EFAULT;
		break;
	}
	case TLS_CIPHER_AES_CCM_128: {
		/*
		 * Accepted by setsockopt, so it must be readable too. Uses the
		 * same salt-then-IV layout of cctx->iv as the GCM cases above.
		 */
		struct tls12_crypto_info_aes_ccm_128 *
		  crypto_info_aes_ccm_128 =
		  container_of(crypto_info,
			       struct tls12_crypto_info_aes_ccm_128,
			       info);

		if (len != sizeof(*crypto_info_aes_ccm_128)) {
			rc = -EINVAL;
			goto out;
		}
		lock_sock(sk);
		memcpy(crypto_info_aes_ccm_128->iv,
		       cctx->iv + TLS_CIPHER_AES_CCM_128_SALT_SIZE,
		       TLS_CIPHER_AES_CCM_128_IV_SIZE);
		memcpy(crypto_info_aes_ccm_128->rec_seq, cctx->rec_seq,
		       TLS_CIPHER_AES_CCM_128_REC_SEQ_SIZE);
		release_sock(sk);
		if (copy_to_user(optval,
				 crypto_info_aes_ccm_128,
				 sizeof(*crypto_info_aes_ccm_128)))
			rc = -EFAULT;
		break;
	}
	default:
		rc = -EINVAL;
		break;
	}

out:
	return rc;
}
433*4882a593Smuzhiyun 
/*
 * do_tls_getsockopt() - dispatch a SOL_TLS getsockopt request
 * @sk:      TLS socket
 * @optname: TLS_TX or TLS_RX
 * @optval:  user output buffer
 * @optlen:  user pointer to the buffer length
 *
 * Return: result of do_tls_getsockopt_conf(), or -ENOPROTOOPT for any
 * other option name.
 */
static int do_tls_getsockopt(struct sock *sk, int optname,
			     char __user *optval, int __user *optlen)
{
	if (optname != TLS_TX && optname != TLS_RX)
		return -ENOPROTOOPT;

	return do_tls_getsockopt_conf(sk, optval, optlen, optname == TLS_TX);
}
451*4882a593Smuzhiyun 
/*
 * tls_getsockopt() - proto->getsockopt handler for TLS sockets
 *
 * SOL_TLS requests are handled here. Every other level is forwarded to the
 * protocol that TLS was layered on top of (ctx->sk_proto).
 */
static int tls_getsockopt(struct sock *sk, int level, int optname,
			  char __user *optval, int __user *optlen)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);

	if (level == SOL_TLS)
		return do_tls_getsockopt(sk, optname, optval, optlen);

	return tls_ctx->sk_proto->getsockopt(sk, level, optname,
					     optval, optlen);
}
463*4882a593Smuzhiyun 
/*
 * do_tls_setsockopt_conf() - install crypto parameters for one direction
 * @sk:     TLS socket; do_tls_setsockopt() calls this with the socket locked
 * @optval: user-supplied cipher-specific tls12_crypto_info_* structure
 * @optlen: its length, which must match the chosen cipher's structure
 * @tx:     non-zero for TLS_TX, zero for TLS_RX
 *
 * The parameters are copied straight into ctx->crypto_send / crypto_recv.
 * Device offload is tried first, with software offload as the fallback.
 * On success the direction's conf (TLS_HW or TLS_SW) is recorded and the
 * socket's proto/ops are switched via update_sk_prot(). For TX,
 * tls_write_space is installed and the previous sk_write_space is saved.
 *
 * Return: 0 on success; -EINVAL for bad arguments, version, cipher or
 * length; -EBUSY if this direction is already configured; -EFAULT on a
 * bad user pointer; or the error from tls_set_sw_offload(). On every
 * failure after the first copy the stored crypto info is wiped.
 */
static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval,
				  unsigned int optlen, int tx)
{
	struct tls_crypto_info *crypto_info;
	struct tls_crypto_info *alt_crypto_info;
	struct tls_context *ctx = tls_get_ctx(sk);
	size_t optsize;
	int rc = 0;
	int conf;

	if (sockptr_is_null(optval) || (optlen < sizeof(*crypto_info))) {
		rc = -EINVAL;
		goto out;
	}

	/* alt_crypto_info is the other direction, used for consistency checks. */
	if (tx) {
		crypto_info = &ctx->crypto_send.info;
		alt_crypto_info = &ctx->crypto_recv.info;
	} else {
		crypto_info = &ctx->crypto_recv.info;
		alt_crypto_info = &ctx->crypto_send.info;
	}

	/* Currently we don't support set crypto info more than one time */
	if (TLS_CRYPTO_INFO_READY(crypto_info)) {
		rc = -EBUSY;
		goto out;
	}

	/* Copy only the generic header first to learn version and cipher. */
	rc = copy_from_sockptr(crypto_info, optval, sizeof(*crypto_info));
	if (rc) {
		rc = -EFAULT;
		goto err_crypto_info;
	}

	/* check version */
	if (crypto_info->version != TLS_1_2_VERSION &&
	    crypto_info->version != TLS_1_3_VERSION) {
		rc = -EINVAL;
		goto err_crypto_info;
	}

	/* Ensure that TLS version and ciphers are same in both directions */
	if (TLS_CRYPTO_INFO_READY(alt_crypto_info)) {
		if (alt_crypto_info->version != crypto_info->version ||
		    alt_crypto_info->cipher_type != crypto_info->cipher_type) {
			rc = -EINVAL;
			goto err_crypto_info;
		}
	}

	/* Expected user structure size for each supported cipher. */
	switch (crypto_info->cipher_type) {
	case TLS_CIPHER_AES_GCM_128:
		optsize = sizeof(struct tls12_crypto_info_aes_gcm_128);
		break;
	case TLS_CIPHER_AES_GCM_256: {
		optsize = sizeof(struct tls12_crypto_info_aes_gcm_256);
		break;
	}
	case TLS_CIPHER_AES_CCM_128:
		optsize = sizeof(struct tls12_crypto_info_aes_ccm_128);
		break;
	default:
		rc = -EINVAL;
		goto err_crypto_info;
	}

	if (optlen != optsize) {
		rc = -EINVAL;
		goto err_crypto_info;
	}

	/* Copy the cipher-specific remainder right after the header. */
	rc = copy_from_sockptr_offset(crypto_info + 1, optval,
				      sizeof(*crypto_info),
				      optlen - sizeof(*crypto_info));
	if (rc) {
		rc = -EFAULT;
		goto err_crypto_info;
	}

	/* Prefer device offload; fall back to the software implementation. */
	if (tx) {
		rc = tls_set_device_offload(sk, ctx);
		conf = TLS_HW;
		if (!rc) {
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXDEVICE);
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE);
		} else {
			rc = tls_set_sw_offload(sk, ctx, 1);
			if (rc)
				goto err_crypto_info;
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXSW);
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW);
			conf = TLS_SW;
		}
	} else {
		rc = tls_set_device_offload_rx(sk, ctx);
		conf = TLS_HW;
		if (!rc) {
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXDEVICE);
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE);
		} else {
			rc = tls_set_sw_offload(sk, ctx, 0);
			if (rc)
				goto err_crypto_info;
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXSW);
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW);
			conf = TLS_SW;
		}
		/* RX uses the strparser in both HW and SW modes. */
		tls_sw_strparser_arm(sk, ctx);
	}

	if (tx)
		ctx->tx_conf = conf;
	else
		ctx->rx_conf = conf;
	update_sk_prot(sk, ctx);
	if (tx) {
		ctx->sk_write_space = sk->sk_write_space;
		sk->sk_write_space = tls_write_space;
	}
	goto out;

err_crypto_info:
	/* Wipe any (partial) key material copied into the context. */
	memzero_explicit(crypto_info, sizeof(union tls_crypto_context));
out:
	return rc;
}
591*4882a593Smuzhiyun 
/*
 * do_tls_setsockopt() - dispatch a SOL_TLS setsockopt request
 * @sk:      TLS socket
 * @optname: TLS_TX or TLS_RX
 * @optval:  user-supplied crypto info
 * @optlen:  its length
 *
 * TLS_TX / TLS_RX are applied with the socket lock held.
 *
 * Return: result of do_tls_setsockopt_conf(), or -ENOPROTOOPT for any
 * other option name.
 */
static int do_tls_setsockopt(struct sock *sk, int optname, sockptr_t optval,
			     unsigned int optlen)
{
	int err;

	if (optname != TLS_TX && optname != TLS_RX)
		return -ENOPROTOOPT;

	lock_sock(sk);
	err = do_tls_setsockopt_conf(sk, optval, optlen, optname == TLS_TX);
	release_sock(sk);

	return err;
}
611*4882a593Smuzhiyun 
/*
 * tls_setsockopt() - proto->setsockopt handler for TLS sockets
 *
 * SOL_TLS requests are handled here. Every other level is forwarded to the
 * protocol that TLS was layered on top of (ctx->sk_proto).
 */
static int tls_setsockopt(struct sock *sk, int level, int optname,
			  sockptr_t optval, unsigned int optlen)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);

	if (level == SOL_TLS)
		return do_tls_setsockopt(sk, optname, optval, optlen);

	return tls_ctx->sk_proto->setsockopt(sk, level, optname,
					     optval, optlen);
}
623*4882a593Smuzhiyun 
tls_ctx_create(struct sock * sk)624*4882a593Smuzhiyun struct tls_context *tls_ctx_create(struct sock *sk)
625*4882a593Smuzhiyun {
626*4882a593Smuzhiyun 	struct inet_connection_sock *icsk = inet_csk(sk);
627*4882a593Smuzhiyun 	struct tls_context *ctx;
628*4882a593Smuzhiyun 
629*4882a593Smuzhiyun 	ctx = kzalloc(sizeof(*ctx), GFP_ATOMIC);
630*4882a593Smuzhiyun 	if (!ctx)
631*4882a593Smuzhiyun 		return NULL;
632*4882a593Smuzhiyun 
633*4882a593Smuzhiyun 	mutex_init(&ctx->tx_lock);
634*4882a593Smuzhiyun 	rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
635*4882a593Smuzhiyun 	ctx->sk_proto = READ_ONCE(sk->sk_prot);
636*4882a593Smuzhiyun 	ctx->sk = sk;
637*4882a593Smuzhiyun 	return ctx;
638*4882a593Smuzhiyun }
639*4882a593Smuzhiyun 
/*
 * build_proto_ops() - fill the [tx_conf][rx_conf] proto_ops table for one
 * address family from the base socket ops.
 * @ops:  table to fill, indexed as update_sk_prot() indexes it
 * @base: the socket's original proto_ops
 *
 * Entries are derived from earlier ones, so the assignment order matters:
 * - SW TX overrides sendpage_locked with tls_sw_sendpage_locked.
 * - SW RX overrides splice_read with tls_sw_splice_read.
 * - HW TX clears sendpage_locked.
 * - HW RX reuses the SW RX ops.
 * - TOE (TLS_HW_RECORD) uses the base ops unchanged.
 */
static void build_proto_ops(struct proto_ops ops[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
			    const struct proto_ops *base)
{
	ops[TLS_BASE][TLS_BASE] = *base;

	ops[TLS_SW  ][TLS_BASE] = ops[TLS_BASE][TLS_BASE];
	ops[TLS_SW  ][TLS_BASE].sendpage_locked	= tls_sw_sendpage_locked;

	ops[TLS_BASE][TLS_SW  ] = ops[TLS_BASE][TLS_BASE];
	ops[TLS_BASE][TLS_SW  ].splice_read	= tls_sw_splice_read;

	ops[TLS_SW  ][TLS_SW  ] = ops[TLS_SW  ][TLS_BASE];
	ops[TLS_SW  ][TLS_SW  ].splice_read	= tls_sw_splice_read;

#ifdef CONFIG_TLS_DEVICE
	ops[TLS_HW  ][TLS_BASE] = ops[TLS_BASE][TLS_BASE];
	ops[TLS_HW  ][TLS_BASE].sendpage_locked	= NULL;

	ops[TLS_HW  ][TLS_SW  ] = ops[TLS_BASE][TLS_SW  ];
	ops[TLS_HW  ][TLS_SW  ].sendpage_locked	= NULL;

	/* HW RX keeps the SW RX ops (splice_read = tls_sw_splice_read). */
	ops[TLS_BASE][TLS_HW  ] = ops[TLS_BASE][TLS_SW  ];

	ops[TLS_SW  ][TLS_HW  ] = ops[TLS_SW  ][TLS_SW  ];

	ops[TLS_HW  ][TLS_HW  ] = ops[TLS_HW  ][TLS_SW  ];
	ops[TLS_HW  ][TLS_HW  ].sendpage_locked	= NULL;
#endif
#ifdef CONFIG_TLS_TOE
	ops[TLS_HW_RECORD][TLS_HW_RECORD] = *base;
#endif
}
tls_build_proto(struct sock * sk)673*4882a593Smuzhiyun static void tls_build_proto(struct sock *sk)
674*4882a593Smuzhiyun {
675*4882a593Smuzhiyun 	int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;
676*4882a593Smuzhiyun 	struct proto *prot = READ_ONCE(sk->sk_prot);
677*4882a593Smuzhiyun 
678*4882a593Smuzhiyun 	/* Build IPv6 TLS whenever the address of tcpv6 _prot changes */
679*4882a593Smuzhiyun 	if (ip_ver == TLSV6 &&
680*4882a593Smuzhiyun 	    unlikely(prot != smp_load_acquire(&saved_tcpv6_prot))) {
681*4882a593Smuzhiyun 		mutex_lock(&tcpv6_prot_mutex);
682*4882a593Smuzhiyun 		if (likely(prot != saved_tcpv6_prot)) {
683*4882a593Smuzhiyun 			build_protos(tls_prots[TLSV6], prot);
684*4882a593Smuzhiyun 			build_proto_ops(tls_proto_ops[TLSV6],
685*4882a593Smuzhiyun 					sk->sk_socket->ops);
686*4882a593Smuzhiyun 			smp_store_release(&saved_tcpv6_prot, prot);
687*4882a593Smuzhiyun 		}
688*4882a593Smuzhiyun 		mutex_unlock(&tcpv6_prot_mutex);
689*4882a593Smuzhiyun 	}
690*4882a593Smuzhiyun 
691*4882a593Smuzhiyun 	if (ip_ver == TLSV4 &&
692*4882a593Smuzhiyun 	    unlikely(prot != smp_load_acquire(&saved_tcpv4_prot))) {
693*4882a593Smuzhiyun 		mutex_lock(&tcpv4_prot_mutex);
694*4882a593Smuzhiyun 		if (likely(prot != saved_tcpv4_prot)) {
695*4882a593Smuzhiyun 			build_protos(tls_prots[TLSV4], prot);
696*4882a593Smuzhiyun 			build_proto_ops(tls_proto_ops[TLSV4],
697*4882a593Smuzhiyun 					sk->sk_socket->ops);
698*4882a593Smuzhiyun 			smp_store_release(&saved_tcpv4_prot, prot);
699*4882a593Smuzhiyun 		}
700*4882a593Smuzhiyun 		mutex_unlock(&tcpv4_prot_mutex);
701*4882a593Smuzhiyun 	}
702*4882a593Smuzhiyun }
703*4882a593Smuzhiyun 
build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],const struct proto * base)704*4882a593Smuzhiyun static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
705*4882a593Smuzhiyun 			 const struct proto *base)
706*4882a593Smuzhiyun {
707*4882a593Smuzhiyun 	prot[TLS_BASE][TLS_BASE] = *base;
708*4882a593Smuzhiyun 	prot[TLS_BASE][TLS_BASE].setsockopt	= tls_setsockopt;
709*4882a593Smuzhiyun 	prot[TLS_BASE][TLS_BASE].getsockopt	= tls_getsockopt;
710*4882a593Smuzhiyun 	prot[TLS_BASE][TLS_BASE].close		= tls_sk_proto_close;
711*4882a593Smuzhiyun 
712*4882a593Smuzhiyun 	prot[TLS_SW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
713*4882a593Smuzhiyun 	prot[TLS_SW][TLS_BASE].sendmsg		= tls_sw_sendmsg;
714*4882a593Smuzhiyun 	prot[TLS_SW][TLS_BASE].sendpage		= tls_sw_sendpage;
715*4882a593Smuzhiyun 
716*4882a593Smuzhiyun 	prot[TLS_BASE][TLS_SW] = prot[TLS_BASE][TLS_BASE];
717*4882a593Smuzhiyun 	prot[TLS_BASE][TLS_SW].recvmsg		  = tls_sw_recvmsg;
718*4882a593Smuzhiyun 	prot[TLS_BASE][TLS_SW].stream_memory_read = tls_sw_stream_read;
719*4882a593Smuzhiyun 	prot[TLS_BASE][TLS_SW].close		  = tls_sk_proto_close;
720*4882a593Smuzhiyun 
721*4882a593Smuzhiyun 	prot[TLS_SW][TLS_SW] = prot[TLS_SW][TLS_BASE];
722*4882a593Smuzhiyun 	prot[TLS_SW][TLS_SW].recvmsg		= tls_sw_recvmsg;
723*4882a593Smuzhiyun 	prot[TLS_SW][TLS_SW].stream_memory_read	= tls_sw_stream_read;
724*4882a593Smuzhiyun 	prot[TLS_SW][TLS_SW].close		= tls_sk_proto_close;
725*4882a593Smuzhiyun 
726*4882a593Smuzhiyun #ifdef CONFIG_TLS_DEVICE
727*4882a593Smuzhiyun 	prot[TLS_HW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
728*4882a593Smuzhiyun 	prot[TLS_HW][TLS_BASE].sendmsg		= tls_device_sendmsg;
729*4882a593Smuzhiyun 	prot[TLS_HW][TLS_BASE].sendpage		= tls_device_sendpage;
730*4882a593Smuzhiyun 
731*4882a593Smuzhiyun 	prot[TLS_HW][TLS_SW] = prot[TLS_BASE][TLS_SW];
732*4882a593Smuzhiyun 	prot[TLS_HW][TLS_SW].sendmsg		= tls_device_sendmsg;
733*4882a593Smuzhiyun 	prot[TLS_HW][TLS_SW].sendpage		= tls_device_sendpage;
734*4882a593Smuzhiyun 
735*4882a593Smuzhiyun 	prot[TLS_BASE][TLS_HW] = prot[TLS_BASE][TLS_SW];
736*4882a593Smuzhiyun 
737*4882a593Smuzhiyun 	prot[TLS_SW][TLS_HW] = prot[TLS_SW][TLS_SW];
738*4882a593Smuzhiyun 
739*4882a593Smuzhiyun 	prot[TLS_HW][TLS_HW] = prot[TLS_HW][TLS_SW];
740*4882a593Smuzhiyun #endif
741*4882a593Smuzhiyun #ifdef CONFIG_TLS_TOE
742*4882a593Smuzhiyun 	prot[TLS_HW_RECORD][TLS_HW_RECORD] = *base;
743*4882a593Smuzhiyun 	prot[TLS_HW_RECORD][TLS_HW_RECORD].hash		= tls_toe_hash;
744*4882a593Smuzhiyun 	prot[TLS_HW_RECORD][TLS_HW_RECORD].unhash	= tls_toe_unhash;
745*4882a593Smuzhiyun #endif
746*4882a593Smuzhiyun }
747*4882a593Smuzhiyun 
tls_init(struct sock * sk)748*4882a593Smuzhiyun static int tls_init(struct sock *sk)
749*4882a593Smuzhiyun {
750*4882a593Smuzhiyun 	struct tls_context *ctx;
751*4882a593Smuzhiyun 	int rc = 0;
752*4882a593Smuzhiyun 
753*4882a593Smuzhiyun 	tls_build_proto(sk);
754*4882a593Smuzhiyun 
755*4882a593Smuzhiyun #ifdef CONFIG_TLS_TOE
756*4882a593Smuzhiyun 	if (tls_toe_bypass(sk))
757*4882a593Smuzhiyun 		return 0;
758*4882a593Smuzhiyun #endif
759*4882a593Smuzhiyun 
760*4882a593Smuzhiyun 	/* The TLS ulp is currently supported only for TCP sockets
761*4882a593Smuzhiyun 	 * in ESTABLISHED state.
762*4882a593Smuzhiyun 	 * Supporting sockets in LISTEN state will require us
763*4882a593Smuzhiyun 	 * to modify the accept implementation to clone rather then
764*4882a593Smuzhiyun 	 * share the ulp context.
765*4882a593Smuzhiyun 	 */
766*4882a593Smuzhiyun 	if (sk->sk_state != TCP_ESTABLISHED)
767*4882a593Smuzhiyun 		return -ENOTCONN;
768*4882a593Smuzhiyun 
769*4882a593Smuzhiyun 	/* allocate tls context */
770*4882a593Smuzhiyun 	write_lock_bh(&sk->sk_callback_lock);
771*4882a593Smuzhiyun 	ctx = tls_ctx_create(sk);
772*4882a593Smuzhiyun 	if (!ctx) {
773*4882a593Smuzhiyun 		rc = -ENOMEM;
774*4882a593Smuzhiyun 		goto out;
775*4882a593Smuzhiyun 	}
776*4882a593Smuzhiyun 
777*4882a593Smuzhiyun 	ctx->tx_conf = TLS_BASE;
778*4882a593Smuzhiyun 	ctx->rx_conf = TLS_BASE;
779*4882a593Smuzhiyun 	update_sk_prot(sk, ctx);
780*4882a593Smuzhiyun out:
781*4882a593Smuzhiyun 	write_unlock_bh(&sk->sk_callback_lock);
782*4882a593Smuzhiyun 	return rc;
783*4882a593Smuzhiyun }
784*4882a593Smuzhiyun 
tls_update(struct sock * sk,struct proto * p,void (* write_space)(struct sock * sk))785*4882a593Smuzhiyun static void tls_update(struct sock *sk, struct proto *p,
786*4882a593Smuzhiyun 		       void (*write_space)(struct sock *sk))
787*4882a593Smuzhiyun {
788*4882a593Smuzhiyun 	struct tls_context *ctx;
789*4882a593Smuzhiyun 
790*4882a593Smuzhiyun 	ctx = tls_get_ctx(sk);
791*4882a593Smuzhiyun 	if (likely(ctx)) {
792*4882a593Smuzhiyun 		ctx->sk_write_space = write_space;
793*4882a593Smuzhiyun 		ctx->sk_proto = p;
794*4882a593Smuzhiyun 	} else {
795*4882a593Smuzhiyun 		/* Pairs with lockless read in sk_clone_lock(). */
796*4882a593Smuzhiyun 		WRITE_ONCE(sk->sk_prot, p);
797*4882a593Smuzhiyun 		sk->sk_write_space = write_space;
798*4882a593Smuzhiyun 	}
799*4882a593Smuzhiyun }
800*4882a593Smuzhiyun 
tls_get_info(const struct sock * sk,struct sk_buff * skb)801*4882a593Smuzhiyun static int tls_get_info(const struct sock *sk, struct sk_buff *skb)
802*4882a593Smuzhiyun {
803*4882a593Smuzhiyun 	u16 version, cipher_type;
804*4882a593Smuzhiyun 	struct tls_context *ctx;
805*4882a593Smuzhiyun 	struct nlattr *start;
806*4882a593Smuzhiyun 	int err;
807*4882a593Smuzhiyun 
808*4882a593Smuzhiyun 	start = nla_nest_start_noflag(skb, INET_ULP_INFO_TLS);
809*4882a593Smuzhiyun 	if (!start)
810*4882a593Smuzhiyun 		return -EMSGSIZE;
811*4882a593Smuzhiyun 
812*4882a593Smuzhiyun 	rcu_read_lock();
813*4882a593Smuzhiyun 	ctx = rcu_dereference(inet_csk(sk)->icsk_ulp_data);
814*4882a593Smuzhiyun 	if (!ctx) {
815*4882a593Smuzhiyun 		err = 0;
816*4882a593Smuzhiyun 		goto nla_failure;
817*4882a593Smuzhiyun 	}
818*4882a593Smuzhiyun 	version = ctx->prot_info.version;
819*4882a593Smuzhiyun 	if (version) {
820*4882a593Smuzhiyun 		err = nla_put_u16(skb, TLS_INFO_VERSION, version);
821*4882a593Smuzhiyun 		if (err)
822*4882a593Smuzhiyun 			goto nla_failure;
823*4882a593Smuzhiyun 	}
824*4882a593Smuzhiyun 	cipher_type = ctx->prot_info.cipher_type;
825*4882a593Smuzhiyun 	if (cipher_type) {
826*4882a593Smuzhiyun 		err = nla_put_u16(skb, TLS_INFO_CIPHER, cipher_type);
827*4882a593Smuzhiyun 		if (err)
828*4882a593Smuzhiyun 			goto nla_failure;
829*4882a593Smuzhiyun 	}
830*4882a593Smuzhiyun 	err = nla_put_u16(skb, TLS_INFO_TXCONF, tls_user_config(ctx, true));
831*4882a593Smuzhiyun 	if (err)
832*4882a593Smuzhiyun 		goto nla_failure;
833*4882a593Smuzhiyun 
834*4882a593Smuzhiyun 	err = nla_put_u16(skb, TLS_INFO_RXCONF, tls_user_config(ctx, false));
835*4882a593Smuzhiyun 	if (err)
836*4882a593Smuzhiyun 		goto nla_failure;
837*4882a593Smuzhiyun 
838*4882a593Smuzhiyun 	rcu_read_unlock();
839*4882a593Smuzhiyun 	nla_nest_end(skb, start);
840*4882a593Smuzhiyun 	return 0;
841*4882a593Smuzhiyun 
842*4882a593Smuzhiyun nla_failure:
843*4882a593Smuzhiyun 	rcu_read_unlock();
844*4882a593Smuzhiyun 	nla_nest_cancel(skb, start);
845*4882a593Smuzhiyun 	return err;
846*4882a593Smuzhiyun }
847*4882a593Smuzhiyun 
tls_get_info_size(const struct sock * sk)848*4882a593Smuzhiyun static size_t tls_get_info_size(const struct sock *sk)
849*4882a593Smuzhiyun {
850*4882a593Smuzhiyun 	size_t size = 0;
851*4882a593Smuzhiyun 
852*4882a593Smuzhiyun 	size += nla_total_size(0) +		/* INET_ULP_INFO_TLS */
853*4882a593Smuzhiyun 		nla_total_size(sizeof(u16)) +	/* TLS_INFO_VERSION */
854*4882a593Smuzhiyun 		nla_total_size(sizeof(u16)) +	/* TLS_INFO_CIPHER */
855*4882a593Smuzhiyun 		nla_total_size(sizeof(u16)) +	/* TLS_INFO_RXCONF */
856*4882a593Smuzhiyun 		nla_total_size(sizeof(u16)) +	/* TLS_INFO_TXCONF */
857*4882a593Smuzhiyun 		0;
858*4882a593Smuzhiyun 
859*4882a593Smuzhiyun 	return size;
860*4882a593Smuzhiyun }
861*4882a593Smuzhiyun 
tls_init_net(struct net * net)862*4882a593Smuzhiyun static int __net_init tls_init_net(struct net *net)
863*4882a593Smuzhiyun {
864*4882a593Smuzhiyun 	int err;
865*4882a593Smuzhiyun 
866*4882a593Smuzhiyun 	net->mib.tls_statistics = alloc_percpu(struct linux_tls_mib);
867*4882a593Smuzhiyun 	if (!net->mib.tls_statistics)
868*4882a593Smuzhiyun 		return -ENOMEM;
869*4882a593Smuzhiyun 
870*4882a593Smuzhiyun 	err = tls_proc_init(net);
871*4882a593Smuzhiyun 	if (err)
872*4882a593Smuzhiyun 		goto err_free_stats;
873*4882a593Smuzhiyun 
874*4882a593Smuzhiyun 	return 0;
875*4882a593Smuzhiyun err_free_stats:
876*4882a593Smuzhiyun 	free_percpu(net->mib.tls_statistics);
877*4882a593Smuzhiyun 	return err;
878*4882a593Smuzhiyun }
879*4882a593Smuzhiyun 
tls_exit_net(struct net * net)880*4882a593Smuzhiyun static void __net_exit tls_exit_net(struct net *net)
881*4882a593Smuzhiyun {
882*4882a593Smuzhiyun 	tls_proc_fini(net);
883*4882a593Smuzhiyun 	free_percpu(net->mib.tls_statistics);
884*4882a593Smuzhiyun }
885*4882a593Smuzhiyun 
886*4882a593Smuzhiyun static struct pernet_operations tls_proc_ops = {
887*4882a593Smuzhiyun 	.init = tls_init_net,
888*4882a593Smuzhiyun 	.exit = tls_exit_net,
889*4882a593Smuzhiyun };
890*4882a593Smuzhiyun 
891*4882a593Smuzhiyun static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = {
892*4882a593Smuzhiyun 	.name			= "tls",
893*4882a593Smuzhiyun 	.owner			= THIS_MODULE,
894*4882a593Smuzhiyun 	.init			= tls_init,
895*4882a593Smuzhiyun 	.update			= tls_update,
896*4882a593Smuzhiyun 	.get_info		= tls_get_info,
897*4882a593Smuzhiyun 	.get_info_size		= tls_get_info_size,
898*4882a593Smuzhiyun };
899*4882a593Smuzhiyun 
tls_register(void)900*4882a593Smuzhiyun static int __init tls_register(void)
901*4882a593Smuzhiyun {
902*4882a593Smuzhiyun 	int err;
903*4882a593Smuzhiyun 
904*4882a593Smuzhiyun 	err = register_pernet_subsys(&tls_proc_ops);
905*4882a593Smuzhiyun 	if (err)
906*4882a593Smuzhiyun 		return err;
907*4882a593Smuzhiyun 
908*4882a593Smuzhiyun 	err = tls_device_init();
909*4882a593Smuzhiyun 	if (err) {
910*4882a593Smuzhiyun 		unregister_pernet_subsys(&tls_proc_ops);
911*4882a593Smuzhiyun 		return err;
912*4882a593Smuzhiyun 	}
913*4882a593Smuzhiyun 
914*4882a593Smuzhiyun 	tcp_register_ulp(&tcp_tls_ulp_ops);
915*4882a593Smuzhiyun 
916*4882a593Smuzhiyun 	return 0;
917*4882a593Smuzhiyun }
918*4882a593Smuzhiyun 
tls_unregister(void)919*4882a593Smuzhiyun static void __exit tls_unregister(void)
920*4882a593Smuzhiyun {
921*4882a593Smuzhiyun 	tcp_unregister_ulp(&tcp_tls_ulp_ops);
922*4882a593Smuzhiyun 	tls_device_cleanup();
923*4882a593Smuzhiyun 	unregister_pernet_subsys(&tls_proc_ops);
924*4882a593Smuzhiyun }
925*4882a593Smuzhiyun 
/* Module entry and exit points. */
module_init(tls_register);
module_exit(tls_unregister);
928