1*4882a593Smuzhiyun /*
2*4882a593Smuzhiyun * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved.
3*4882a593Smuzhiyun * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved.
4*4882a593Smuzhiyun *
5*4882a593Smuzhiyun * This software is available to you under a choice of one of two
6*4882a593Smuzhiyun * licenses. You may choose to be licensed under the terms of the GNU
7*4882a593Smuzhiyun * General Public License (GPL) Version 2, available from the file
8*4882a593Smuzhiyun * COPYING in the main directory of this source tree, or the
9*4882a593Smuzhiyun * OpenIB.org BSD license below:
10*4882a593Smuzhiyun *
11*4882a593Smuzhiyun * Redistribution and use in source and binary forms, with or
12*4882a593Smuzhiyun * without modification, are permitted provided that the following
13*4882a593Smuzhiyun * conditions are met:
14*4882a593Smuzhiyun *
15*4882a593Smuzhiyun * - Redistributions of source code must retain the above
16*4882a593Smuzhiyun * copyright notice, this list of conditions and the following
17*4882a593Smuzhiyun * disclaimer.
18*4882a593Smuzhiyun *
19*4882a593Smuzhiyun * - Redistributions in binary form must reproduce the above
20*4882a593Smuzhiyun * copyright notice, this list of conditions and the following
21*4882a593Smuzhiyun * disclaimer in the documentation and/or other materials
22*4882a593Smuzhiyun * provided with the distribution.
23*4882a593Smuzhiyun *
24*4882a593Smuzhiyun * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
25*4882a593Smuzhiyun * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
26*4882a593Smuzhiyun * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
27*4882a593Smuzhiyun * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
28*4882a593Smuzhiyun * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
29*4882a593Smuzhiyun * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
30*4882a593Smuzhiyun * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
31*4882a593Smuzhiyun * SOFTWARE.
32*4882a593Smuzhiyun */
33*4882a593Smuzhiyun
34*4882a593Smuzhiyun #include <linux/module.h>
35*4882a593Smuzhiyun
36*4882a593Smuzhiyun #include <net/tcp.h>
37*4882a593Smuzhiyun #include <net/inet_common.h>
38*4882a593Smuzhiyun #include <linux/highmem.h>
39*4882a593Smuzhiyun #include <linux/netdevice.h>
40*4882a593Smuzhiyun #include <linux/sched/signal.h>
41*4882a593Smuzhiyun #include <linux/inetdevice.h>
42*4882a593Smuzhiyun #include <linux/inet_diag.h>
43*4882a593Smuzhiyun
44*4882a593Smuzhiyun #include <net/snmp.h>
45*4882a593Smuzhiyun #include <net/tls.h>
46*4882a593Smuzhiyun #include <net/tls_toe.h>
47*4882a593Smuzhiyun
48*4882a593Smuzhiyun MODULE_AUTHOR("Mellanox Technologies");
49*4882a593Smuzhiyun MODULE_DESCRIPTION("Transport Layer Security Support");
50*4882a593Smuzhiyun MODULE_LICENSE("Dual BSD/GPL");
51*4882a593Smuzhiyun MODULE_ALIAS_TCP_ULP("tls");
52*4882a593Smuzhiyun
53*4882a593Smuzhiyun enum {
54*4882a593Smuzhiyun TLSV4,
55*4882a593Smuzhiyun TLSV6,
56*4882a593Smuzhiyun TLS_NUM_PROTS,
57*4882a593Smuzhiyun };
58*4882a593Smuzhiyun
59*4882a593Smuzhiyun static const struct proto *saved_tcpv6_prot;
60*4882a593Smuzhiyun static DEFINE_MUTEX(tcpv6_prot_mutex);
61*4882a593Smuzhiyun static const struct proto *saved_tcpv4_prot;
62*4882a593Smuzhiyun static DEFINE_MUTEX(tcpv4_prot_mutex);
63*4882a593Smuzhiyun static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
64*4882a593Smuzhiyun static struct proto_ops tls_proto_ops[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
65*4882a593Smuzhiyun static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
66*4882a593Smuzhiyun const struct proto *base);
67*4882a593Smuzhiyun
update_sk_prot(struct sock * sk,struct tls_context * ctx)68*4882a593Smuzhiyun void update_sk_prot(struct sock *sk, struct tls_context *ctx)
69*4882a593Smuzhiyun {
70*4882a593Smuzhiyun int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;
71*4882a593Smuzhiyun
72*4882a593Smuzhiyun WRITE_ONCE(sk->sk_prot,
73*4882a593Smuzhiyun &tls_prots[ip_ver][ctx->tx_conf][ctx->rx_conf]);
74*4882a593Smuzhiyun WRITE_ONCE(sk->sk_socket->ops,
75*4882a593Smuzhiyun &tls_proto_ops[ip_ver][ctx->tx_conf][ctx->rx_conf]);
76*4882a593Smuzhiyun }
77*4882a593Smuzhiyun
wait_on_pending_writer(struct sock * sk,long * timeo)78*4882a593Smuzhiyun int wait_on_pending_writer(struct sock *sk, long *timeo)
79*4882a593Smuzhiyun {
80*4882a593Smuzhiyun int rc = 0;
81*4882a593Smuzhiyun DEFINE_WAIT_FUNC(wait, woken_wake_function);
82*4882a593Smuzhiyun
83*4882a593Smuzhiyun add_wait_queue(sk_sleep(sk), &wait);
84*4882a593Smuzhiyun while (1) {
85*4882a593Smuzhiyun if (!*timeo) {
86*4882a593Smuzhiyun rc = -EAGAIN;
87*4882a593Smuzhiyun break;
88*4882a593Smuzhiyun }
89*4882a593Smuzhiyun
90*4882a593Smuzhiyun if (signal_pending(current)) {
91*4882a593Smuzhiyun rc = sock_intr_errno(*timeo);
92*4882a593Smuzhiyun break;
93*4882a593Smuzhiyun }
94*4882a593Smuzhiyun
95*4882a593Smuzhiyun if (sk_wait_event(sk, timeo, !sk->sk_write_pending, &wait))
96*4882a593Smuzhiyun break;
97*4882a593Smuzhiyun }
98*4882a593Smuzhiyun remove_wait_queue(sk_sleep(sk), &wait);
99*4882a593Smuzhiyun return rc;
100*4882a593Smuzhiyun }
101*4882a593Smuzhiyun
/*
 * Hand the scatterlist chain starting at @sg to TCP, beginning
 * @first_offset bytes into the first entry.
 *
 * Every entry except the last is sent with MSG_SENDPAGE_NOTLAST added to
 * @flags. A short (positive) send is retried on the same page from where it
 * stopped. Each fully queued entry has its page reference dropped and its
 * length uncharged from the socket.
 *
 * If do_tcp_sendpages() returns 0 or an error, the current entry and the
 * offset within it are stored in ctx->partially_sent_record/_offset
 * (consumed later by tls_push_partial_record()), and that value is
 * returned. Returns 0 once the whole chain has been queued.
 *
 * ctx->in_tcp_sendpages is set for the duration so that tls_write_space()
 * only forwards wakeups to the original write-space handler.
 */
int tls_push_sg(struct sock *sk,
		struct tls_context *ctx,
		struct scatterlist *sg,
		u16 first_offset,
		int flags)
{
	const int more_flags = flags | MSG_SENDPAGE_NOTLAST;
	int off = first_offset;
	size_t len = sg->length - off;

	off += sg->offset;
	ctx->in_tcp_sendpages = true;

	for (;;) {
		int send_flags = sg_is_last(sg) ? flags : more_flags;
		struct page *pg;
		int sent;

		/* is sending application-limited? */
		tcp_rate_check_app_limited(sk);
		pg = sg_page(sg);

		/* Keep feeding this page until all of it has been queued. */
		for (;;) {
			sent = do_tcp_sendpages(sk, pg, off, len, send_flags);
			if (sent == len)
				break;
			if (sent <= 0) {
				/* Save the resume point relative to the entry. */
				ctx->partially_sent_offset = off - sg->offset;
				ctx->partially_sent_record = (void *)sg;
				ctx->in_tcp_sendpages = false;
				return sent;
			}
			off += sent;
			len -= sent;
		}

		put_page(pg);
		sk_mem_uncharge(sk, sg->length);

		sg = sg_next(sg);
		if (!sg)
			break;

		off = sg->offset;
		len = sg->length;
	}

	ctx->in_tcp_sendpages = false;
	return 0;
}
156*4882a593Smuzhiyun
tls_handle_open_record(struct sock * sk,int flags)157*4882a593Smuzhiyun static int tls_handle_open_record(struct sock *sk, int flags)
158*4882a593Smuzhiyun {
159*4882a593Smuzhiyun struct tls_context *ctx = tls_get_ctx(sk);
160*4882a593Smuzhiyun
161*4882a593Smuzhiyun if (tls_is_pending_open_record(ctx))
162*4882a593Smuzhiyun return ctx->push_pending_record(sk, flags);
163*4882a593Smuzhiyun
164*4882a593Smuzhiyun return 0;
165*4882a593Smuzhiyun }
166*4882a593Smuzhiyun
tls_proccess_cmsg(struct sock * sk,struct msghdr * msg,unsigned char * record_type)167*4882a593Smuzhiyun int tls_proccess_cmsg(struct sock *sk, struct msghdr *msg,
168*4882a593Smuzhiyun unsigned char *record_type)
169*4882a593Smuzhiyun {
170*4882a593Smuzhiyun struct cmsghdr *cmsg;
171*4882a593Smuzhiyun int rc = -EINVAL;
172*4882a593Smuzhiyun
173*4882a593Smuzhiyun for_each_cmsghdr(cmsg, msg) {
174*4882a593Smuzhiyun if (!CMSG_OK(msg, cmsg))
175*4882a593Smuzhiyun return -EINVAL;
176*4882a593Smuzhiyun if (cmsg->cmsg_level != SOL_TLS)
177*4882a593Smuzhiyun continue;
178*4882a593Smuzhiyun
179*4882a593Smuzhiyun switch (cmsg->cmsg_type) {
180*4882a593Smuzhiyun case TLS_SET_RECORD_TYPE:
181*4882a593Smuzhiyun if (cmsg->cmsg_len < CMSG_LEN(sizeof(*record_type)))
182*4882a593Smuzhiyun return -EINVAL;
183*4882a593Smuzhiyun
184*4882a593Smuzhiyun if (msg->msg_flags & MSG_MORE)
185*4882a593Smuzhiyun return -EINVAL;
186*4882a593Smuzhiyun
187*4882a593Smuzhiyun rc = tls_handle_open_record(sk, msg->msg_flags);
188*4882a593Smuzhiyun if (rc)
189*4882a593Smuzhiyun return rc;
190*4882a593Smuzhiyun
191*4882a593Smuzhiyun *record_type = *(unsigned char *)CMSG_DATA(cmsg);
192*4882a593Smuzhiyun rc = 0;
193*4882a593Smuzhiyun break;
194*4882a593Smuzhiyun default:
195*4882a593Smuzhiyun return -EINVAL;
196*4882a593Smuzhiyun }
197*4882a593Smuzhiyun }
198*4882a593Smuzhiyun
199*4882a593Smuzhiyun return rc;
200*4882a593Smuzhiyun }
201*4882a593Smuzhiyun
tls_push_partial_record(struct sock * sk,struct tls_context * ctx,int flags)202*4882a593Smuzhiyun int tls_push_partial_record(struct sock *sk, struct tls_context *ctx,
203*4882a593Smuzhiyun int flags)
204*4882a593Smuzhiyun {
205*4882a593Smuzhiyun struct scatterlist *sg;
206*4882a593Smuzhiyun u16 offset;
207*4882a593Smuzhiyun
208*4882a593Smuzhiyun sg = ctx->partially_sent_record;
209*4882a593Smuzhiyun offset = ctx->partially_sent_offset;
210*4882a593Smuzhiyun
211*4882a593Smuzhiyun ctx->partially_sent_record = NULL;
212*4882a593Smuzhiyun return tls_push_sg(sk, ctx, sg, offset, flags);
213*4882a593Smuzhiyun }
214*4882a593Smuzhiyun
tls_free_partial_record(struct sock * sk,struct tls_context * ctx)215*4882a593Smuzhiyun void tls_free_partial_record(struct sock *sk, struct tls_context *ctx)
216*4882a593Smuzhiyun {
217*4882a593Smuzhiyun struct scatterlist *sg;
218*4882a593Smuzhiyun
219*4882a593Smuzhiyun for (sg = ctx->partially_sent_record; sg; sg = sg_next(sg)) {
220*4882a593Smuzhiyun put_page(sg_page(sg));
221*4882a593Smuzhiyun sk_mem_uncharge(sk, sg->length);
222*4882a593Smuzhiyun }
223*4882a593Smuzhiyun ctx->partially_sent_record = NULL;
224*4882a593Smuzhiyun }
225*4882a593Smuzhiyun
tls_write_space(struct sock * sk)226*4882a593Smuzhiyun static void tls_write_space(struct sock *sk)
227*4882a593Smuzhiyun {
228*4882a593Smuzhiyun struct tls_context *ctx = tls_get_ctx(sk);
229*4882a593Smuzhiyun
230*4882a593Smuzhiyun /* If in_tcp_sendpages call lower protocol write space handler
231*4882a593Smuzhiyun * to ensure we wake up any waiting operations there. For example
232*4882a593Smuzhiyun * if do_tcp_sendpages where to call sk_wait_event.
233*4882a593Smuzhiyun */
234*4882a593Smuzhiyun if (ctx->in_tcp_sendpages) {
235*4882a593Smuzhiyun ctx->sk_write_space(sk);
236*4882a593Smuzhiyun return;
237*4882a593Smuzhiyun }
238*4882a593Smuzhiyun
239*4882a593Smuzhiyun #ifdef CONFIG_TLS_DEVICE
240*4882a593Smuzhiyun if (ctx->tx_conf == TLS_HW)
241*4882a593Smuzhiyun tls_device_write_space(sk, ctx);
242*4882a593Smuzhiyun else
243*4882a593Smuzhiyun #endif
244*4882a593Smuzhiyun tls_sw_write_space(sk, ctx);
245*4882a593Smuzhiyun
246*4882a593Smuzhiyun ctx->sk_write_space(sk);
247*4882a593Smuzhiyun }
248*4882a593Smuzhiyun
249*4882a593Smuzhiyun /**
250*4882a593Smuzhiyun * tls_ctx_free() - free TLS ULP context
251*4882a593Smuzhiyun * @sk: socket to with @ctx is attached
252*4882a593Smuzhiyun * @ctx: TLS context structure
253*4882a593Smuzhiyun *
254*4882a593Smuzhiyun * Free TLS context. If @sk is %NULL caller guarantees that the socket
255*4882a593Smuzhiyun * to which @ctx was attached has no outstanding references.
256*4882a593Smuzhiyun */
tls_ctx_free(struct sock * sk,struct tls_context * ctx)257*4882a593Smuzhiyun void tls_ctx_free(struct sock *sk, struct tls_context *ctx)
258*4882a593Smuzhiyun {
259*4882a593Smuzhiyun if (!ctx)
260*4882a593Smuzhiyun return;
261*4882a593Smuzhiyun
262*4882a593Smuzhiyun memzero_explicit(&ctx->crypto_send, sizeof(ctx->crypto_send));
263*4882a593Smuzhiyun memzero_explicit(&ctx->crypto_recv, sizeof(ctx->crypto_recv));
264*4882a593Smuzhiyun mutex_destroy(&ctx->tx_lock);
265*4882a593Smuzhiyun
266*4882a593Smuzhiyun if (sk)
267*4882a593Smuzhiyun kfree_rcu(ctx, rcu);
268*4882a593Smuzhiyun else
269*4882a593Smuzhiyun kfree(ctx);
270*4882a593Smuzhiyun }
271*4882a593Smuzhiyun
tls_sk_proto_cleanup(struct sock * sk,struct tls_context * ctx,long timeo)272*4882a593Smuzhiyun static void tls_sk_proto_cleanup(struct sock *sk,
273*4882a593Smuzhiyun struct tls_context *ctx, long timeo)
274*4882a593Smuzhiyun {
275*4882a593Smuzhiyun if (unlikely(sk->sk_write_pending) &&
276*4882a593Smuzhiyun !wait_on_pending_writer(sk, &timeo))
277*4882a593Smuzhiyun tls_handle_open_record(sk, 0);
278*4882a593Smuzhiyun
279*4882a593Smuzhiyun /* We need these for tls_sw_fallback handling of other packets */
280*4882a593Smuzhiyun if (ctx->tx_conf == TLS_SW) {
281*4882a593Smuzhiyun kfree(ctx->tx.rec_seq);
282*4882a593Smuzhiyun kfree(ctx->tx.iv);
283*4882a593Smuzhiyun tls_sw_release_resources_tx(sk);
284*4882a593Smuzhiyun TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW);
285*4882a593Smuzhiyun } else if (ctx->tx_conf == TLS_HW) {
286*4882a593Smuzhiyun tls_device_free_resources_tx(sk);
287*4882a593Smuzhiyun TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE);
288*4882a593Smuzhiyun }
289*4882a593Smuzhiyun
290*4882a593Smuzhiyun if (ctx->rx_conf == TLS_SW) {
291*4882a593Smuzhiyun tls_sw_release_resources_rx(sk);
292*4882a593Smuzhiyun TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW);
293*4882a593Smuzhiyun } else if (ctx->rx_conf == TLS_HW) {
294*4882a593Smuzhiyun tls_device_offload_cleanup_rx(sk);
295*4882a593Smuzhiyun TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE);
296*4882a593Smuzhiyun }
297*4882a593Smuzhiyun }
298*4882a593Smuzhiyun
tls_sk_proto_close(struct sock * sk,long timeout)299*4882a593Smuzhiyun static void tls_sk_proto_close(struct sock *sk, long timeout)
300*4882a593Smuzhiyun {
301*4882a593Smuzhiyun struct inet_connection_sock *icsk = inet_csk(sk);
302*4882a593Smuzhiyun struct tls_context *ctx = tls_get_ctx(sk);
303*4882a593Smuzhiyun long timeo = sock_sndtimeo(sk, 0);
304*4882a593Smuzhiyun bool free_ctx;
305*4882a593Smuzhiyun
306*4882a593Smuzhiyun if (ctx->tx_conf == TLS_SW)
307*4882a593Smuzhiyun tls_sw_cancel_work_tx(ctx);
308*4882a593Smuzhiyun
309*4882a593Smuzhiyun lock_sock(sk);
310*4882a593Smuzhiyun free_ctx = ctx->tx_conf != TLS_HW && ctx->rx_conf != TLS_HW;
311*4882a593Smuzhiyun
312*4882a593Smuzhiyun if (ctx->tx_conf != TLS_BASE || ctx->rx_conf != TLS_BASE)
313*4882a593Smuzhiyun tls_sk_proto_cleanup(sk, ctx, timeo);
314*4882a593Smuzhiyun
315*4882a593Smuzhiyun write_lock_bh(&sk->sk_callback_lock);
316*4882a593Smuzhiyun if (free_ctx)
317*4882a593Smuzhiyun rcu_assign_pointer(icsk->icsk_ulp_data, NULL);
318*4882a593Smuzhiyun WRITE_ONCE(sk->sk_prot, ctx->sk_proto);
319*4882a593Smuzhiyun if (sk->sk_write_space == tls_write_space)
320*4882a593Smuzhiyun sk->sk_write_space = ctx->sk_write_space;
321*4882a593Smuzhiyun write_unlock_bh(&sk->sk_callback_lock);
322*4882a593Smuzhiyun release_sock(sk);
323*4882a593Smuzhiyun if (ctx->tx_conf == TLS_SW)
324*4882a593Smuzhiyun tls_sw_free_ctx_tx(ctx);
325*4882a593Smuzhiyun if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW)
326*4882a593Smuzhiyun tls_sw_strparser_done(ctx);
327*4882a593Smuzhiyun if (ctx->rx_conf == TLS_SW)
328*4882a593Smuzhiyun tls_sw_free_ctx_rx(ctx);
329*4882a593Smuzhiyun ctx->sk_proto->close(sk, timeout);
330*4882a593Smuzhiyun
331*4882a593Smuzhiyun if (free_ctx)
332*4882a593Smuzhiyun tls_ctx_free(sk, ctx);
333*4882a593Smuzhiyun }
334*4882a593Smuzhiyun
/*
 * Copy the TX (@tx != 0) or RX crypto configuration of @sk to user space.
 *
 * If *@optlen equals sizeof(struct tls_crypto_info), only that generic
 * header (version and cipher type) is returned.
 *
 * Otherwise *@optlen must match the cipher-specific structure exactly. The
 * IV and record sequence are first refreshed from the cipher context under
 * the socket lock, then the whole structure is copied out. Every cipher
 * accepted by do_tls_setsockopt_conf() is handled here, including
 * AES-CCM-128.
 *
 * Returns:
 *  - 0 on success;
 *  - -EFAULT if @optlen or @optval cannot be accessed;
 *  - -EINVAL for a NULL @optval, a wrong length or an unknown cipher;
 *  - -EBUSY if there is no TLS context or the direction is not configured.
 */
static int do_tls_getsockopt_conf(struct sock *sk, char __user *optval,
				  int __user *optlen, int tx)
{
	int rc = 0;
	struct tls_context *ctx = tls_get_ctx(sk);
	struct tls_crypto_info *crypto_info;
	struct cipher_context *cctx;
	int len;

	if (get_user(len, optlen))
		return -EFAULT;

	if (!optval || (len < sizeof(*crypto_info))) {
		rc = -EINVAL;
		goto out;
	}

	if (!ctx) {
		rc = -EBUSY;
		goto out;
	}

	/* get user crypto info */
	if (tx) {
		crypto_info = &ctx->crypto_send.info;
		cctx = &ctx->tx;
	} else {
		crypto_info = &ctx->crypto_recv.info;
		cctx = &ctx->rx;
	}

	if (!TLS_CRYPTO_INFO_READY(crypto_info)) {
		rc = -EBUSY;
		goto out;
	}

	/* A header-sized buffer asks for version and cipher type only. */
	if (len == sizeof(*crypto_info)) {
		if (copy_to_user(optval, crypto_info, sizeof(*crypto_info)))
			rc = -EFAULT;
		goto out;
	}

	switch (crypto_info->cipher_type) {
	case TLS_CIPHER_AES_GCM_128: {
		struct tls12_crypto_info_aes_gcm_128 *
		  crypto_info_aes_gcm_128 =
		  container_of(crypto_info,
			       struct tls12_crypto_info_aes_gcm_128,
			       info);

		if (len != sizeof(*crypto_info_aes_gcm_128)) {
			rc = -EINVAL;
			goto out;
		}
		lock_sock(sk);
		/* cctx->iv holds salt || iv; report only the iv part. */
		memcpy(crypto_info_aes_gcm_128->iv,
		       cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
		       TLS_CIPHER_AES_GCM_128_IV_SIZE);
		memcpy(crypto_info_aes_gcm_128->rec_seq, cctx->rec_seq,
		       TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE);
		release_sock(sk);
		if (copy_to_user(optval,
				 crypto_info_aes_gcm_128,
				 sizeof(*crypto_info_aes_gcm_128)))
			rc = -EFAULT;
		break;
	}
	case TLS_CIPHER_AES_GCM_256: {
		struct tls12_crypto_info_aes_gcm_256 *
		  crypto_info_aes_gcm_256 =
		  container_of(crypto_info,
			       struct tls12_crypto_info_aes_gcm_256,
			       info);

		if (len != sizeof(*crypto_info_aes_gcm_256)) {
			rc = -EINVAL;
			goto out;
		}
		lock_sock(sk);
		memcpy(crypto_info_aes_gcm_256->iv,
		       cctx->iv + TLS_CIPHER_AES_GCM_256_SALT_SIZE,
		       TLS_CIPHER_AES_GCM_256_IV_SIZE);
		memcpy(crypto_info_aes_gcm_256->rec_seq, cctx->rec_seq,
		       TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE);
		release_sock(sk);
		if (copy_to_user(optval,
				 crypto_info_aes_gcm_256,
				 sizeof(*crypto_info_aes_gcm_256)))
			rc = -EFAULT;
		break;
	}
	case TLS_CIPHER_AES_CCM_128: {
		/*
		 * setsockopt accepts this cipher, so getsockopt must be able to
		 * report it too; previously this fell through to -EINVAL.
		 */
		struct tls12_crypto_info_aes_ccm_128 *
		  crypto_info_aes_ccm_128 =
		  container_of(crypto_info,
			       struct tls12_crypto_info_aes_ccm_128,
			       info);

		if (len != sizeof(*crypto_info_aes_ccm_128)) {
			rc = -EINVAL;
			goto out;
		}
		lock_sock(sk);
		memcpy(crypto_info_aes_ccm_128->iv,
		       cctx->iv + TLS_CIPHER_AES_CCM_128_SALT_SIZE,
		       TLS_CIPHER_AES_CCM_128_IV_SIZE);
		memcpy(crypto_info_aes_ccm_128->rec_seq, cctx->rec_seq,
		       TLS_CIPHER_AES_CCM_128_REC_SEQ_SIZE);
		release_sock(sk);
		if (copy_to_user(optval,
				 crypto_info_aes_ccm_128,
				 sizeof(*crypto_info_aes_ccm_128)))
			rc = -EFAULT;
		break;
	}
	default:
		rc = -EINVAL;
		break;
	}

out:
	return rc;
}
433*4882a593Smuzhiyun
/*
 * Dispatch a SOL_TLS getsockopt. Only TLS_TX and TLS_RX are recognised;
 * any other option name yields -ENOPROTOOPT.
 */
static int do_tls_getsockopt(struct sock *sk, int optname,
			     char __user *optval, int __user *optlen)
{
	if (optname != TLS_TX && optname != TLS_RX)
		return -ENOPROTOOPT;

	return do_tls_getsockopt_conf(sk, optval, optlen, optname == TLS_TX);
}
451*4882a593Smuzhiyun
/*
 * proto->getsockopt for TLS sockets. SOL_TLS options are handled here;
 * every other level is forwarded to the original protocol.
 */
static int tls_getsockopt(struct sock *sk, int level, int optname,
			  char __user *optval, int __user *optlen)
{
	struct tls_context *ctx = tls_get_ctx(sk);

	if (level == SOL_TLS)
		return do_tls_getsockopt(sk, optname, optval, optlen);

	return ctx->sk_proto->getsockopt(sk, level, optname, optval, optlen);
}
463*4882a593Smuzhiyun
/*
 * Install the TX (@tx != 0) or RX crypto configuration supplied by user
 * space and switch @sk to the matching TLS proto. do_tls_setsockopt()
 * calls this with the socket locked.
 *
 * @optval holds a struct tls_crypto_info header followed by the
 * cipher-specific fields. @optlen must equal the size of that cipher's
 * structure. Each direction can be configured only once. If the other
 * direction is already configured, its version and cipher must match.
 *
 * Device offload is tried first, then the software implementation. Once
 * the header has been copied in, any failure wipes this direction's stored
 * crypto info.
 *
 * Returns:
 *  - 0 on success;
 *  - -EINVAL for bad arguments;
 *  - -EBUSY if the direction is already configured;
 *  - -EFAULT on a bad user pointer;
 *  - the error from tls_set_sw_offload().
 */
static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval,
				  unsigned int optlen, int tx)
{
	struct tls_context *ctx = tls_get_ctx(sk);
	struct tls_crypto_info *info, *peer_info;
	size_t expected_len;
	int err, conf;

	if (sockptr_is_null(optval) || optlen < sizeof(*info))
		return -EINVAL;

	info = tx ? &ctx->crypto_send.info : &ctx->crypto_recv.info;
	peer_info = tx ? &ctx->crypto_recv.info : &ctx->crypto_send.info;

	/* Re-keying an already configured direction is not supported. */
	if (TLS_CRYPTO_INFO_READY(info))
		return -EBUSY;

	if (copy_from_sockptr(info, optval, sizeof(*info))) {
		err = -EFAULT;
		goto wipe;
	}

	if (info->version != TLS_1_2_VERSION &&
	    info->version != TLS_1_3_VERSION) {
		err = -EINVAL;
		goto wipe;
	}

	/* Both directions must agree on protocol version and cipher. */
	if (TLS_CRYPTO_INFO_READY(peer_info) &&
	    (peer_info->version != info->version ||
	     peer_info->cipher_type != info->cipher_type)) {
		err = -EINVAL;
		goto wipe;
	}

	switch (info->cipher_type) {
	case TLS_CIPHER_AES_GCM_128:
		expected_len = sizeof(struct tls12_crypto_info_aes_gcm_128);
		break;
	case TLS_CIPHER_AES_GCM_256:
		expected_len = sizeof(struct tls12_crypto_info_aes_gcm_256);
		break;
	case TLS_CIPHER_AES_CCM_128:
		expected_len = sizeof(struct tls12_crypto_info_aes_ccm_128);
		break;
	default:
		err = -EINVAL;
		goto wipe;
	}

	if (optlen != expected_len) {
		err = -EINVAL;
		goto wipe;
	}

	/* Fetch the cipher-specific remainder that follows the header. */
	if (copy_from_sockptr_offset(info + 1, optval, sizeof(*info),
				     optlen - sizeof(*info))) {
		err = -EFAULT;
		goto wipe;
	}

	if (tx) {
		err = tls_set_device_offload(sk, ctx);
		if (!err) {
			conf = TLS_HW;
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXDEVICE);
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE);
		} else {
			err = tls_set_sw_offload(sk, ctx, 1);
			if (err)
				goto wipe;
			conf = TLS_SW;
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXSW);
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW);
		}
		ctx->tx_conf = conf;
	} else {
		err = tls_set_device_offload_rx(sk, ctx);
		if (!err) {
			conf = TLS_HW;
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXDEVICE);
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE);
		} else {
			err = tls_set_sw_offload(sk, ctx, 0);
			if (err)
				goto wipe;
			conf = TLS_SW;
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXSW);
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW);
		}
		tls_sw_strparser_arm(sk, ctx);
		ctx->rx_conf = conf;
	}

	update_sk_prot(sk, ctx);
	if (tx) {
		/* Chain our write-space hook in front of the original one. */
		ctx->sk_write_space = sk->sk_write_space;
		sk->sk_write_space = tls_write_space;
	}
	return 0;

wipe:
	memzero_explicit(info, sizeof(union tls_crypto_context));
	return err;
}
591*4882a593Smuzhiyun
/*
 * Dispatch a SOL_TLS setsockopt. TLS_TX and TLS_RX are applied under the
 * socket lock; any other option name yields -ENOPROTOOPT.
 */
static int do_tls_setsockopt(struct sock *sk, int optname, sockptr_t optval,
			     unsigned int optlen)
{
	int rc;

	if (optname != TLS_TX && optname != TLS_RX)
		return -ENOPROTOOPT;

	lock_sock(sk);
	rc = do_tls_setsockopt_conf(sk, optval, optlen, optname == TLS_TX);
	release_sock(sk);

	return rc;
}
611*4882a593Smuzhiyun
/*
 * proto->setsockopt for TLS sockets. SOL_TLS options are handled here;
 * every other level is forwarded to the original protocol.
 */
static int tls_setsockopt(struct sock *sk, int level, int optname,
			  sockptr_t optval, unsigned int optlen)
{
	struct tls_context *ctx = tls_get_ctx(sk);

	if (level == SOL_TLS)
		return do_tls_setsockopt(sk, optname, optval, optlen);

	return ctx->sk_proto->setsockopt(sk, level, optname, optval, optlen);
}
623*4882a593Smuzhiyun
tls_ctx_create(struct sock * sk)624*4882a593Smuzhiyun struct tls_context *tls_ctx_create(struct sock *sk)
625*4882a593Smuzhiyun {
626*4882a593Smuzhiyun struct inet_connection_sock *icsk = inet_csk(sk);
627*4882a593Smuzhiyun struct tls_context *ctx;
628*4882a593Smuzhiyun
629*4882a593Smuzhiyun ctx = kzalloc(sizeof(*ctx), GFP_ATOMIC);
630*4882a593Smuzhiyun if (!ctx)
631*4882a593Smuzhiyun return NULL;
632*4882a593Smuzhiyun
633*4882a593Smuzhiyun mutex_init(&ctx->tx_lock);
634*4882a593Smuzhiyun rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
635*4882a593Smuzhiyun ctx->sk_proto = READ_ONCE(sk->sk_prot);
636*4882a593Smuzhiyun ctx->sk = sk;
637*4882a593Smuzhiyun return ctx;
638*4882a593Smuzhiyun }
639*4882a593Smuzhiyun
build_proto_ops(struct proto_ops ops[TLS_NUM_CONFIG][TLS_NUM_CONFIG],const struct proto_ops * base)640*4882a593Smuzhiyun static void build_proto_ops(struct proto_ops ops[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
641*4882a593Smuzhiyun const struct proto_ops *base)
642*4882a593Smuzhiyun {
643*4882a593Smuzhiyun ops[TLS_BASE][TLS_BASE] = *base;
644*4882a593Smuzhiyun
645*4882a593Smuzhiyun ops[TLS_SW ][TLS_BASE] = ops[TLS_BASE][TLS_BASE];
646*4882a593Smuzhiyun ops[TLS_SW ][TLS_BASE].sendpage_locked = tls_sw_sendpage_locked;
647*4882a593Smuzhiyun
648*4882a593Smuzhiyun ops[TLS_BASE][TLS_SW ] = ops[TLS_BASE][TLS_BASE];
649*4882a593Smuzhiyun ops[TLS_BASE][TLS_SW ].splice_read = tls_sw_splice_read;
650*4882a593Smuzhiyun
651*4882a593Smuzhiyun ops[TLS_SW ][TLS_SW ] = ops[TLS_SW ][TLS_BASE];
652*4882a593Smuzhiyun ops[TLS_SW ][TLS_SW ].splice_read = tls_sw_splice_read;
653*4882a593Smuzhiyun
654*4882a593Smuzhiyun #ifdef CONFIG_TLS_DEVICE
655*4882a593Smuzhiyun ops[TLS_HW ][TLS_BASE] = ops[TLS_BASE][TLS_BASE];
656*4882a593Smuzhiyun ops[TLS_HW ][TLS_BASE].sendpage_locked = NULL;
657*4882a593Smuzhiyun
658*4882a593Smuzhiyun ops[TLS_HW ][TLS_SW ] = ops[TLS_BASE][TLS_SW ];
659*4882a593Smuzhiyun ops[TLS_HW ][TLS_SW ].sendpage_locked = NULL;
660*4882a593Smuzhiyun
661*4882a593Smuzhiyun ops[TLS_BASE][TLS_HW ] = ops[TLS_BASE][TLS_SW ];
662*4882a593Smuzhiyun
663*4882a593Smuzhiyun ops[TLS_SW ][TLS_HW ] = ops[TLS_SW ][TLS_SW ];
664*4882a593Smuzhiyun
665*4882a593Smuzhiyun ops[TLS_HW ][TLS_HW ] = ops[TLS_HW ][TLS_SW ];
666*4882a593Smuzhiyun ops[TLS_HW ][TLS_HW ].sendpage_locked = NULL;
667*4882a593Smuzhiyun #endif
668*4882a593Smuzhiyun #ifdef CONFIG_TLS_TOE
669*4882a593Smuzhiyun ops[TLS_HW_RECORD][TLS_HW_RECORD] = *base;
670*4882a593Smuzhiyun #endif
671*4882a593Smuzhiyun }
672*4882a593Smuzhiyun
tls_build_proto(struct sock * sk)673*4882a593Smuzhiyun static void tls_build_proto(struct sock *sk)
674*4882a593Smuzhiyun {
675*4882a593Smuzhiyun int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;
676*4882a593Smuzhiyun struct proto *prot = READ_ONCE(sk->sk_prot);
677*4882a593Smuzhiyun
678*4882a593Smuzhiyun /* Build IPv6 TLS whenever the address of tcpv6 _prot changes */
679*4882a593Smuzhiyun if (ip_ver == TLSV6 &&
680*4882a593Smuzhiyun unlikely(prot != smp_load_acquire(&saved_tcpv6_prot))) {
681*4882a593Smuzhiyun mutex_lock(&tcpv6_prot_mutex);
682*4882a593Smuzhiyun if (likely(prot != saved_tcpv6_prot)) {
683*4882a593Smuzhiyun build_protos(tls_prots[TLSV6], prot);
684*4882a593Smuzhiyun build_proto_ops(tls_proto_ops[TLSV6],
685*4882a593Smuzhiyun sk->sk_socket->ops);
686*4882a593Smuzhiyun smp_store_release(&saved_tcpv6_prot, prot);
687*4882a593Smuzhiyun }
688*4882a593Smuzhiyun mutex_unlock(&tcpv6_prot_mutex);
689*4882a593Smuzhiyun }
690*4882a593Smuzhiyun
691*4882a593Smuzhiyun if (ip_ver == TLSV4 &&
692*4882a593Smuzhiyun unlikely(prot != smp_load_acquire(&saved_tcpv4_prot))) {
693*4882a593Smuzhiyun mutex_lock(&tcpv4_prot_mutex);
694*4882a593Smuzhiyun if (likely(prot != saved_tcpv4_prot)) {
695*4882a593Smuzhiyun build_protos(tls_prots[TLSV4], prot);
696*4882a593Smuzhiyun build_proto_ops(tls_proto_ops[TLSV4],
697*4882a593Smuzhiyun sk->sk_socket->ops);
698*4882a593Smuzhiyun smp_store_release(&saved_tcpv4_prot, prot);
699*4882a593Smuzhiyun }
700*4882a593Smuzhiyun mutex_unlock(&tcpv4_prot_mutex);
701*4882a593Smuzhiyun }
702*4882a593Smuzhiyun }
703*4882a593Smuzhiyun
build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],const struct proto * base)704*4882a593Smuzhiyun static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
705*4882a593Smuzhiyun const struct proto *base)
706*4882a593Smuzhiyun {
707*4882a593Smuzhiyun prot[TLS_BASE][TLS_BASE] = *base;
708*4882a593Smuzhiyun prot[TLS_BASE][TLS_BASE].setsockopt = tls_setsockopt;
709*4882a593Smuzhiyun prot[TLS_BASE][TLS_BASE].getsockopt = tls_getsockopt;
710*4882a593Smuzhiyun prot[TLS_BASE][TLS_BASE].close = tls_sk_proto_close;
711*4882a593Smuzhiyun
712*4882a593Smuzhiyun prot[TLS_SW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
713*4882a593Smuzhiyun prot[TLS_SW][TLS_BASE].sendmsg = tls_sw_sendmsg;
714*4882a593Smuzhiyun prot[TLS_SW][TLS_BASE].sendpage = tls_sw_sendpage;
715*4882a593Smuzhiyun
716*4882a593Smuzhiyun prot[TLS_BASE][TLS_SW] = prot[TLS_BASE][TLS_BASE];
717*4882a593Smuzhiyun prot[TLS_BASE][TLS_SW].recvmsg = tls_sw_recvmsg;
718*4882a593Smuzhiyun prot[TLS_BASE][TLS_SW].stream_memory_read = tls_sw_stream_read;
719*4882a593Smuzhiyun prot[TLS_BASE][TLS_SW].close = tls_sk_proto_close;
720*4882a593Smuzhiyun
721*4882a593Smuzhiyun prot[TLS_SW][TLS_SW] = prot[TLS_SW][TLS_BASE];
722*4882a593Smuzhiyun prot[TLS_SW][TLS_SW].recvmsg = tls_sw_recvmsg;
723*4882a593Smuzhiyun prot[TLS_SW][TLS_SW].stream_memory_read = tls_sw_stream_read;
724*4882a593Smuzhiyun prot[TLS_SW][TLS_SW].close = tls_sk_proto_close;
725*4882a593Smuzhiyun
726*4882a593Smuzhiyun #ifdef CONFIG_TLS_DEVICE
727*4882a593Smuzhiyun prot[TLS_HW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
728*4882a593Smuzhiyun prot[TLS_HW][TLS_BASE].sendmsg = tls_device_sendmsg;
729*4882a593Smuzhiyun prot[TLS_HW][TLS_BASE].sendpage = tls_device_sendpage;
730*4882a593Smuzhiyun
731*4882a593Smuzhiyun prot[TLS_HW][TLS_SW] = prot[TLS_BASE][TLS_SW];
732*4882a593Smuzhiyun prot[TLS_HW][TLS_SW].sendmsg = tls_device_sendmsg;
733*4882a593Smuzhiyun prot[TLS_HW][TLS_SW].sendpage = tls_device_sendpage;
734*4882a593Smuzhiyun
735*4882a593Smuzhiyun prot[TLS_BASE][TLS_HW] = prot[TLS_BASE][TLS_SW];
736*4882a593Smuzhiyun
737*4882a593Smuzhiyun prot[TLS_SW][TLS_HW] = prot[TLS_SW][TLS_SW];
738*4882a593Smuzhiyun
739*4882a593Smuzhiyun prot[TLS_HW][TLS_HW] = prot[TLS_HW][TLS_SW];
740*4882a593Smuzhiyun #endif
741*4882a593Smuzhiyun #ifdef CONFIG_TLS_TOE
742*4882a593Smuzhiyun prot[TLS_HW_RECORD][TLS_HW_RECORD] = *base;
743*4882a593Smuzhiyun prot[TLS_HW_RECORD][TLS_HW_RECORD].hash = tls_toe_hash;
744*4882a593Smuzhiyun prot[TLS_HW_RECORD][TLS_HW_RECORD].unhash = tls_toe_unhash;
745*4882a593Smuzhiyun #endif
746*4882a593Smuzhiyun }
747*4882a593Smuzhiyun
tls_init(struct sock * sk)748*4882a593Smuzhiyun static int tls_init(struct sock *sk)
749*4882a593Smuzhiyun {
750*4882a593Smuzhiyun struct tls_context *ctx;
751*4882a593Smuzhiyun int rc = 0;
752*4882a593Smuzhiyun
753*4882a593Smuzhiyun tls_build_proto(sk);
754*4882a593Smuzhiyun
755*4882a593Smuzhiyun #ifdef CONFIG_TLS_TOE
756*4882a593Smuzhiyun if (tls_toe_bypass(sk))
757*4882a593Smuzhiyun return 0;
758*4882a593Smuzhiyun #endif
759*4882a593Smuzhiyun
760*4882a593Smuzhiyun /* The TLS ulp is currently supported only for TCP sockets
761*4882a593Smuzhiyun * in ESTABLISHED state.
762*4882a593Smuzhiyun * Supporting sockets in LISTEN state will require us
763*4882a593Smuzhiyun * to modify the accept implementation to clone rather then
764*4882a593Smuzhiyun * share the ulp context.
765*4882a593Smuzhiyun */
766*4882a593Smuzhiyun if (sk->sk_state != TCP_ESTABLISHED)
767*4882a593Smuzhiyun return -ENOTCONN;
768*4882a593Smuzhiyun
769*4882a593Smuzhiyun /* allocate tls context */
770*4882a593Smuzhiyun write_lock_bh(&sk->sk_callback_lock);
771*4882a593Smuzhiyun ctx = tls_ctx_create(sk);
772*4882a593Smuzhiyun if (!ctx) {
773*4882a593Smuzhiyun rc = -ENOMEM;
774*4882a593Smuzhiyun goto out;
775*4882a593Smuzhiyun }
776*4882a593Smuzhiyun
777*4882a593Smuzhiyun ctx->tx_conf = TLS_BASE;
778*4882a593Smuzhiyun ctx->rx_conf = TLS_BASE;
779*4882a593Smuzhiyun update_sk_prot(sk, ctx);
780*4882a593Smuzhiyun out:
781*4882a593Smuzhiyun write_unlock_bh(&sk->sk_callback_lock);
782*4882a593Smuzhiyun return rc;
783*4882a593Smuzhiyun }
784*4882a593Smuzhiyun
tls_update(struct sock * sk,struct proto * p,void (* write_space)(struct sock * sk))785*4882a593Smuzhiyun static void tls_update(struct sock *sk, struct proto *p,
786*4882a593Smuzhiyun void (*write_space)(struct sock *sk))
787*4882a593Smuzhiyun {
788*4882a593Smuzhiyun struct tls_context *ctx;
789*4882a593Smuzhiyun
790*4882a593Smuzhiyun ctx = tls_get_ctx(sk);
791*4882a593Smuzhiyun if (likely(ctx)) {
792*4882a593Smuzhiyun ctx->sk_write_space = write_space;
793*4882a593Smuzhiyun ctx->sk_proto = p;
794*4882a593Smuzhiyun } else {
795*4882a593Smuzhiyun /* Pairs with lockless read in sk_clone_lock(). */
796*4882a593Smuzhiyun WRITE_ONCE(sk->sk_prot, p);
797*4882a593Smuzhiyun sk->sk_write_space = write_space;
798*4882a593Smuzhiyun }
799*4882a593Smuzhiyun }
800*4882a593Smuzhiyun
tls_get_info(const struct sock * sk,struct sk_buff * skb)801*4882a593Smuzhiyun static int tls_get_info(const struct sock *sk, struct sk_buff *skb)
802*4882a593Smuzhiyun {
803*4882a593Smuzhiyun u16 version, cipher_type;
804*4882a593Smuzhiyun struct tls_context *ctx;
805*4882a593Smuzhiyun struct nlattr *start;
806*4882a593Smuzhiyun int err;
807*4882a593Smuzhiyun
808*4882a593Smuzhiyun start = nla_nest_start_noflag(skb, INET_ULP_INFO_TLS);
809*4882a593Smuzhiyun if (!start)
810*4882a593Smuzhiyun return -EMSGSIZE;
811*4882a593Smuzhiyun
812*4882a593Smuzhiyun rcu_read_lock();
813*4882a593Smuzhiyun ctx = rcu_dereference(inet_csk(sk)->icsk_ulp_data);
814*4882a593Smuzhiyun if (!ctx) {
815*4882a593Smuzhiyun err = 0;
816*4882a593Smuzhiyun goto nla_failure;
817*4882a593Smuzhiyun }
818*4882a593Smuzhiyun version = ctx->prot_info.version;
819*4882a593Smuzhiyun if (version) {
820*4882a593Smuzhiyun err = nla_put_u16(skb, TLS_INFO_VERSION, version);
821*4882a593Smuzhiyun if (err)
822*4882a593Smuzhiyun goto nla_failure;
823*4882a593Smuzhiyun }
824*4882a593Smuzhiyun cipher_type = ctx->prot_info.cipher_type;
825*4882a593Smuzhiyun if (cipher_type) {
826*4882a593Smuzhiyun err = nla_put_u16(skb, TLS_INFO_CIPHER, cipher_type);
827*4882a593Smuzhiyun if (err)
828*4882a593Smuzhiyun goto nla_failure;
829*4882a593Smuzhiyun }
830*4882a593Smuzhiyun err = nla_put_u16(skb, TLS_INFO_TXCONF, tls_user_config(ctx, true));
831*4882a593Smuzhiyun if (err)
832*4882a593Smuzhiyun goto nla_failure;
833*4882a593Smuzhiyun
834*4882a593Smuzhiyun err = nla_put_u16(skb, TLS_INFO_RXCONF, tls_user_config(ctx, false));
835*4882a593Smuzhiyun if (err)
836*4882a593Smuzhiyun goto nla_failure;
837*4882a593Smuzhiyun
838*4882a593Smuzhiyun rcu_read_unlock();
839*4882a593Smuzhiyun nla_nest_end(skb, start);
840*4882a593Smuzhiyun return 0;
841*4882a593Smuzhiyun
842*4882a593Smuzhiyun nla_failure:
843*4882a593Smuzhiyun rcu_read_unlock();
844*4882a593Smuzhiyun nla_nest_cancel(skb, start);
845*4882a593Smuzhiyun return err;
846*4882a593Smuzhiyun }
847*4882a593Smuzhiyun
tls_get_info_size(const struct sock * sk)848*4882a593Smuzhiyun static size_t tls_get_info_size(const struct sock *sk)
849*4882a593Smuzhiyun {
850*4882a593Smuzhiyun size_t size = 0;
851*4882a593Smuzhiyun
852*4882a593Smuzhiyun size += nla_total_size(0) + /* INET_ULP_INFO_TLS */
853*4882a593Smuzhiyun nla_total_size(sizeof(u16)) + /* TLS_INFO_VERSION */
854*4882a593Smuzhiyun nla_total_size(sizeof(u16)) + /* TLS_INFO_CIPHER */
855*4882a593Smuzhiyun nla_total_size(sizeof(u16)) + /* TLS_INFO_RXCONF */
856*4882a593Smuzhiyun nla_total_size(sizeof(u16)) + /* TLS_INFO_TXCONF */
857*4882a593Smuzhiyun 0;
858*4882a593Smuzhiyun
859*4882a593Smuzhiyun return size;
860*4882a593Smuzhiyun }
861*4882a593Smuzhiyun
tls_init_net(struct net * net)862*4882a593Smuzhiyun static int __net_init tls_init_net(struct net *net)
863*4882a593Smuzhiyun {
864*4882a593Smuzhiyun int err;
865*4882a593Smuzhiyun
866*4882a593Smuzhiyun net->mib.tls_statistics = alloc_percpu(struct linux_tls_mib);
867*4882a593Smuzhiyun if (!net->mib.tls_statistics)
868*4882a593Smuzhiyun return -ENOMEM;
869*4882a593Smuzhiyun
870*4882a593Smuzhiyun err = tls_proc_init(net);
871*4882a593Smuzhiyun if (err)
872*4882a593Smuzhiyun goto err_free_stats;
873*4882a593Smuzhiyun
874*4882a593Smuzhiyun return 0;
875*4882a593Smuzhiyun err_free_stats:
876*4882a593Smuzhiyun free_percpu(net->mib.tls_statistics);
877*4882a593Smuzhiyun return err;
878*4882a593Smuzhiyun }
879*4882a593Smuzhiyun
tls_exit_net(struct net * net)880*4882a593Smuzhiyun static void __net_exit tls_exit_net(struct net *net)
881*4882a593Smuzhiyun {
882*4882a593Smuzhiyun tls_proc_fini(net);
883*4882a593Smuzhiyun free_percpu(net->mib.tls_statistics);
884*4882a593Smuzhiyun }
885*4882a593Smuzhiyun
886*4882a593Smuzhiyun static struct pernet_operations tls_proc_ops = {
887*4882a593Smuzhiyun .init = tls_init_net,
888*4882a593Smuzhiyun .exit = tls_exit_net,
889*4882a593Smuzhiyun };
890*4882a593Smuzhiyun
891*4882a593Smuzhiyun static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = {
892*4882a593Smuzhiyun .name = "tls",
893*4882a593Smuzhiyun .owner = THIS_MODULE,
894*4882a593Smuzhiyun .init = tls_init,
895*4882a593Smuzhiyun .update = tls_update,
896*4882a593Smuzhiyun .get_info = tls_get_info,
897*4882a593Smuzhiyun .get_info_size = tls_get_info_size,
898*4882a593Smuzhiyun };
899*4882a593Smuzhiyun
tls_register(void)900*4882a593Smuzhiyun static int __init tls_register(void)
901*4882a593Smuzhiyun {
902*4882a593Smuzhiyun int err;
903*4882a593Smuzhiyun
904*4882a593Smuzhiyun err = register_pernet_subsys(&tls_proc_ops);
905*4882a593Smuzhiyun if (err)
906*4882a593Smuzhiyun return err;
907*4882a593Smuzhiyun
908*4882a593Smuzhiyun err = tls_device_init();
909*4882a593Smuzhiyun if (err) {
910*4882a593Smuzhiyun unregister_pernet_subsys(&tls_proc_ops);
911*4882a593Smuzhiyun return err;
912*4882a593Smuzhiyun }
913*4882a593Smuzhiyun
914*4882a593Smuzhiyun tcp_register_ulp(&tcp_tls_ulp_ops);
915*4882a593Smuzhiyun
916*4882a593Smuzhiyun return 0;
917*4882a593Smuzhiyun }
918*4882a593Smuzhiyun
tls_unregister(void)919*4882a593Smuzhiyun static void __exit tls_unregister(void)
920*4882a593Smuzhiyun {
921*4882a593Smuzhiyun tcp_unregister_ulp(&tcp_tls_ulp_ops);
922*4882a593Smuzhiyun tls_device_cleanup();
923*4882a593Smuzhiyun unregister_pernet_subsys(&tls_proc_ops);
924*4882a593Smuzhiyun }
925*4882a593Smuzhiyun
926*4882a593Smuzhiyun module_init(tls_register);
927*4882a593Smuzhiyun module_exit(tls_unregister);
928