xref: /OK3568_Linux_fs/kernel/drivers/vhost/vsock.c (revision 4882a59341e53eb6f0b4789bf948001014eff981)
// SPDX-License-Identifier: GPL-2.0-only
/*
 * vhost transport for vsock
 *
 * Copyright (C) 2013-2015 Red Hat, Inc.
 * Author: Asias He <asias@redhat.com>
 *         Stefan Hajnoczi <stefanha@redhat.com>
 */
#include <linux/miscdevice.h>
#include <linux/atomic.h>
#include <linux/module.h>
#include <linux/mutex.h>
#include <linux/vmalloc.h>
#include <net/sock.h>
#include <linux/virtio_vsock.h>
#include <linux/vhost.h>
#include <linux/hashtable.h>

#include <net/af_vsock.h>
#include "vhost.h"

#define VHOST_VSOCK_DEFAULT_HOST_CID	2
/* Max number of bytes transferred before requeueing the job.
 * Using this limit prevents one virtqueue from starving others.
 */
#define VHOST_VSOCK_WEIGHT 0x80000
/* Max number of packets transferred before requeueing the job.
 * Using this limit prevents one virtqueue from starving others with
 * small pkts.
 */
#define VHOST_VSOCK_PKT_WEIGHT 256

enum {
	VHOST_VSOCK_FEATURES = VHOST_FEATURES,
};

/* Used to track all the vhost_vsock instances on the system. */
static DEFINE_MUTEX(vhost_vsock_mutex);
static DEFINE_READ_MOSTLY_HASHTABLE(vhost_vsock_hash, 8);

struct vhost_vsock {
	struct vhost_dev dev;
	struct vhost_virtqueue vqs[2];

	/* Link to global vhost_vsock_hash, writes use vhost_vsock_mutex */
	struct hlist_node hash;

	struct vhost_work send_pkt_work;
	spinlock_t send_pkt_list_lock;
	struct list_head send_pkt_list;	/* host->guest pending packets */

	atomic_t queued_replies;

	u32 guest_cid;
};

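/* The host side of the transport always uses the well-known CID 2; guest CIDs
 * are assigned per device via VHOST_VSOCK_SET_GUEST_CID.
 */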
static u32 vhost_transport_get_local_cid(void)
{
	return VHOST_VSOCK_DEFAULT_HOST_CID;
}

/* Callers that dereference the return value must hold vhost_vsock_mutex or the
 * RCU read lock.
 */
static struct vhost_vsock *vhost_vsock_get(u32 guest_cid)
{
	struct vhost_vsock *vsock;

	hash_for_each_possible_rcu(vhost_vsock_hash, vsock, hash, guest_cid) {
		u32 other_cid = vsock->guest_cid;

		/* Skip instances that have no CID yet */
		if (other_cid == 0)
			continue;

		if (other_cid == guest_cid)
			return vsock;
	}

	return NULL;
}

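/* Deliver queued host->guest packets (vsock->send_pkt_list) into the guest's
 * RX virtqueue.  Called from the RX kick handler and from send_pkt_work; a
 * packet larger than the available buffer is sent in pieces and requeued.
 */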
static void
vhost_transport_do_send_pkt(struct vhost_vsock *vsock,
			    struct vhost_virtqueue *vq)
{
	struct vhost_virtqueue *tx_vq = &vsock->vqs[VSOCK_VQ_TX];
	int pkts = 0, total_len = 0;
	bool added = false;
	bool restart_tx = false;

	mutex_lock(&vq->mutex);

	if (!vhost_vq_get_backend(vq))
		goto out;

	/* Avoid further vmexits, we're already processing the virtqueue */
	vhost_disable_notify(&vsock->dev, vq);

	do {
		struct virtio_vsock_pkt *pkt;
		struct iov_iter iov_iter;
		unsigned out, in;
		size_t nbytes;
		size_t iov_len, payload_len;
		int head;

		spin_lock_bh(&vsock->send_pkt_list_lock);
		if (list_empty(&vsock->send_pkt_list)) {
			spin_unlock_bh(&vsock->send_pkt_list_lock);
			vhost_enable_notify(&vsock->dev, vq);
			break;
		}

		pkt = list_first_entry(&vsock->send_pkt_list,
				       struct virtio_vsock_pkt, list);
		list_del_init(&pkt->list);
		spin_unlock_bh(&vsock->send_pkt_list_lock);

		head = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
					 &out, &in, NULL, NULL);
		if (head < 0) {
			spin_lock_bh(&vsock->send_pkt_list_lock);
			list_add(&pkt->list, &vsock->send_pkt_list);
			spin_unlock_bh(&vsock->send_pkt_list_lock);
			break;
		}

		if (head == vq->num) {
			spin_lock_bh(&vsock->send_pkt_list_lock);
			list_add(&pkt->list, &vsock->send_pkt_list);
			spin_unlock_bh(&vsock->send_pkt_list_lock);

			/* We cannot finish yet if more buffers snuck in while
			 * re-enabling notify.
			 */
			if (unlikely(vhost_enable_notify(&vsock->dev, vq))) {
				vhost_disable_notify(&vsock->dev, vq);
				continue;
			}
			break;
		}

		if (out) {
			virtio_transport_free_pkt(pkt);
			vq_err(vq, "Expected 0 output buffers, got %u\n", out);
			break;
		}

		iov_len = iov_length(&vq->iov[out], in);
		if (iov_len < sizeof(pkt->hdr)) {
			virtio_transport_free_pkt(pkt);
			vq_err(vq, "Buffer len [%zu] too small\n", iov_len);
			break;
		}

		iov_iter_init(&iov_iter, READ, &vq->iov[out], in, iov_len);
		payload_len = pkt->len - pkt->off;

		/* If the packet is greater than the space available in the
		 * buffer, we split it using multiple buffers.
		 */
		if (payload_len > iov_len - sizeof(pkt->hdr))
			payload_len = iov_len - sizeof(pkt->hdr);

		/* Set the correct length in the header */
		pkt->hdr.len = cpu_to_le32(payload_len);

		nbytes = copy_to_iter(&pkt->hdr, sizeof(pkt->hdr), &iov_iter);
		if (nbytes != sizeof(pkt->hdr)) {
			virtio_transport_free_pkt(pkt);
			vq_err(vq, "Faulted on copying pkt hdr\n");
			break;
		}

		nbytes = copy_to_iter(pkt->buf + pkt->off, payload_len,
				      &iov_iter);
		if (nbytes != payload_len) {
			virtio_transport_free_pkt(pkt);
			vq_err(vq, "Faulted on copying pkt buf\n");
			break;
		}

		/* Deliver to monitoring devices all packets that we
		 * will transmit.
		 */
		virtio_transport_deliver_tap_pkt(pkt);

		vhost_add_used(vq, head, sizeof(pkt->hdr) + payload_len);
		added = true;

		pkt->off += payload_len;
		total_len += payload_len;

		/* If we didn't send all the payload we can requeue the packet
		 * to send it with the next available buffer.
		 */
		if (pkt->off < pkt->len) {
			/* We are queueing the same virtio_vsock_pkt to handle
			 * the remaining bytes, and we want to deliver it
			 * to monitoring devices in the next iteration.
			 */
			pkt->tap_delivered = false;

			spin_lock_bh(&vsock->send_pkt_list_lock);
			list_add(&pkt->list, &vsock->send_pkt_list);
			spin_unlock_bh(&vsock->send_pkt_list_lock);
		} else {
			if (pkt->reply) {
				int val;

				val = atomic_dec_return(&vsock->queued_replies);

				/* Do we have resources to resume tx
				 * processing?
				 */
				if (val + 1 == tx_vq->num)
					restart_tx = true;
			}

			virtio_transport_free_pkt(pkt);
		}
	} while (likely(!vhost_exceeds_weight(vq, ++pkts, total_len)));
	if (added)
		vhost_signal(&vsock->dev, vq);

out:
	mutex_unlock(&vq->mutex);

	if (restart_tx)
		vhost_poll_queue(&tx_vq->poll);
}

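/* Work item scheduled by vhost_transport_send_pkt(): flush pending packets
 * into the guest RX virtqueue from vhost worker context.
 */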
static void vhost_transport_send_pkt_work(struct vhost_work *work)
{
	struct vhost_virtqueue *vq;
	struct vhost_vsock *vsock;

	vsock = container_of(work, struct vhost_vsock, send_pkt_work);
	vq = &vsock->vqs[VSOCK_VQ_RX];

	vhost_transport_do_send_pkt(vsock, vq);
}

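/* Queue a packet for transmission to the guest identified by hdr.dst_cid and
 * kick the send worker.  Returns the payload length queued, or -ENODEV if no
 * vhost_vsock instance owns that CID.
 */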
static int
vhost_transport_send_pkt(struct virtio_vsock_pkt *pkt)
{
	struct vhost_vsock *vsock;
	int len = pkt->len;

	rcu_read_lock();

	/* Find the vhost_vsock according to guest context id */
	vsock = vhost_vsock_get(le64_to_cpu(pkt->hdr.dst_cid));
	if (!vsock) {
		rcu_read_unlock();
		virtio_transport_free_pkt(pkt);
		return -ENODEV;
	}

	if (pkt->reply)
		atomic_inc(&vsock->queued_replies);

	spin_lock_bh(&vsock->send_pkt_list_lock);
	list_add_tail(&pkt->list, &vsock->send_pkt_list);
	spin_unlock_bh(&vsock->send_pkt_list_lock);

	vhost_work_queue(&vsock->dev, &vsock->send_pkt_work);

	rcu_read_unlock();
	return len;
}

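/* Remove and free all packets still queued for @vsk's peer.  If any of them
 * were replies, release the queued_replies credit and, if the TX virtqueue
 * had been throttled because of it, kick it so processing can resume.
 */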
static int
vhost_transport_cancel_pkt(struct vsock_sock *vsk)
{
	struct vhost_vsock *vsock;
	struct virtio_vsock_pkt *pkt, *n;
	int cnt = 0;
	int ret = -ENODEV;
	LIST_HEAD(freeme);

	rcu_read_lock();

	/* Find the vhost_vsock according to guest context id */
	vsock = vhost_vsock_get(vsk->remote_addr.svm_cid);
	if (!vsock)
		goto out;

	spin_lock_bh(&vsock->send_pkt_list_lock);
	list_for_each_entry_safe(pkt, n, &vsock->send_pkt_list, list) {
		if (pkt->vsk != vsk)
			continue;
		list_move(&pkt->list, &freeme);
	}
	spin_unlock_bh(&vsock->send_pkt_list_lock);

	list_for_each_entry_safe(pkt, n, &freeme, list) {
		if (pkt->reply)
			cnt++;
		list_del(&pkt->list);
		virtio_transport_free_pkt(pkt);
	}

	if (cnt) {
		struct vhost_virtqueue *tx_vq = &vsock->vqs[VSOCK_VQ_TX];
		int new_cnt;

		new_cnt = atomic_sub_return(cnt, &vsock->queued_replies);
		if (new_cnt + cnt >= tx_vq->num && new_cnt < tx_vq->num)
			vhost_poll_queue(&tx_vq->poll);
	}

	ret = 0;
out:
	rcu_read_unlock();
	return ret;
}

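/* Build a virtio_vsock_pkt from the guest buffers described by the TX
 * descriptor chain: copy the header, validate the length, then copy the
 * payload.  Returns NULL on malformed descriptors or allocation failure.
 */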
static struct virtio_vsock_pkt *
vhost_vsock_alloc_pkt(struct vhost_virtqueue *vq,
		      unsigned int out, unsigned int in)
{
	struct virtio_vsock_pkt *pkt;
	struct iov_iter iov_iter;
	size_t nbytes;
	size_t len;

	if (in != 0) {
		vq_err(vq, "Expected 0 input buffers, got %u\n", in);
		return NULL;
	}

	pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
	if (!pkt)
		return NULL;

	len = iov_length(vq->iov, out);
	iov_iter_init(&iov_iter, WRITE, vq->iov, out, len);

	nbytes = copy_from_iter(&pkt->hdr, sizeof(pkt->hdr), &iov_iter);
	if (nbytes != sizeof(pkt->hdr)) {
		vq_err(vq, "Expected %zu bytes for pkt->hdr, got %zu bytes\n",
		       sizeof(pkt->hdr), nbytes);
		kfree(pkt);
		return NULL;
	}

	if (le16_to_cpu(pkt->hdr.type) == VIRTIO_VSOCK_TYPE_STREAM)
		pkt->len = le32_to_cpu(pkt->hdr.len);

	/* No payload */
	if (!pkt->len)
		return pkt;

	/* The pkt is too big */
	if (pkt->len > VIRTIO_VSOCK_MAX_PKT_BUF_SIZE) {
		kfree(pkt);
		return NULL;
	}

	pkt->buf = kvmalloc(pkt->len, GFP_KERNEL);
	if (!pkt->buf) {
		kfree(pkt);
		return NULL;
	}

	pkt->buf_len = pkt->len;

	nbytes = copy_from_iter(pkt->buf, pkt->len, &iov_iter);
	if (nbytes != pkt->len) {
		vq_err(vq, "Expected %u byte payload, got %zu bytes\n",
		       pkt->len, nbytes);
		virtio_transport_free_pkt(pkt);
		return NULL;
	}

	return pkt;
}

/* Is there space left for replies to rx packets? */
static bool vhost_vsock_more_replies(struct vhost_vsock *vsock)
{
	struct vhost_virtqueue *vq = &vsock->vqs[VSOCK_VQ_TX];
	int val;

	smp_rmb(); /* paired with atomic_inc() and atomic_dec_return() */
	val = atomic_read(&vsock->queued_replies);

	return val < vq->num;
}

static struct virtio_transport vhost_transport = {
	.transport = {
		.module                   = THIS_MODULE,

		.get_local_cid            = vhost_transport_get_local_cid,

		.init                     = virtio_transport_do_socket_init,
		.destruct                 = virtio_transport_destruct,
		.release                  = virtio_transport_release,
		.connect                  = virtio_transport_connect,
		.shutdown                 = virtio_transport_shutdown,
		.cancel_pkt               = vhost_transport_cancel_pkt,

		.dgram_enqueue            = virtio_transport_dgram_enqueue,
		.dgram_dequeue            = virtio_transport_dgram_dequeue,
		.dgram_bind               = virtio_transport_dgram_bind,
		.dgram_allow              = virtio_transport_dgram_allow,

		.stream_enqueue           = virtio_transport_stream_enqueue,
		.stream_dequeue           = virtio_transport_stream_dequeue,
		.stream_has_data          = virtio_transport_stream_has_data,
		.stream_has_space         = virtio_transport_stream_has_space,
		.stream_rcvhiwat          = virtio_transport_stream_rcvhiwat,
		.stream_is_active         = virtio_transport_stream_is_active,
		.stream_allow             = virtio_transport_stream_allow,

		.notify_poll_in           = virtio_transport_notify_poll_in,
		.notify_poll_out          = virtio_transport_notify_poll_out,
		.notify_recv_init         = virtio_transport_notify_recv_init,
		.notify_recv_pre_block    = virtio_transport_notify_recv_pre_block,
		.notify_recv_pre_dequeue  = virtio_transport_notify_recv_pre_dequeue,
		.notify_recv_post_dequeue = virtio_transport_notify_recv_post_dequeue,
		.notify_send_init         = virtio_transport_notify_send_init,
		.notify_send_pre_block    = virtio_transport_notify_send_pre_block,
		.notify_send_pre_enqueue  = virtio_transport_notify_send_pre_enqueue,
		.notify_send_post_enqueue = virtio_transport_notify_send_post_enqueue,
		.notify_buffer_size       = virtio_transport_notify_buffer_size,

	},

	.send_pkt = vhost_transport_send_pkt,
};

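/* Guest kicked the TX virtqueue: pull each descriptor chain, rebuild the
 * packet, and hand correctly addressed packets to the core vsock receive
 * path.  TX is throttled while too many reply packets are still queued.
 */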
static void vhost_vsock_handle_tx_kick(struct vhost_work *work)
{
	struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
						  poll.work);
	struct vhost_vsock *vsock = container_of(vq->dev, struct vhost_vsock,
						 dev);
	struct virtio_vsock_pkt *pkt;
	int head, pkts = 0, total_len = 0;
	unsigned int out, in;
	bool added = false;

	mutex_lock(&vq->mutex);

	if (!vhost_vq_get_backend(vq))
		goto out;

	vhost_disable_notify(&vsock->dev, vq);
	do {
		u32 len;

		if (!vhost_vsock_more_replies(vsock)) {
			/* Stop tx until the device processes already
			 * pending replies.  Leave tx virtqueue
			 * callbacks disabled.
			 */
			goto no_more_replies;
		}

		head = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
					 &out, &in, NULL, NULL);
		if (head < 0)
			break;

		if (head == vq->num) {
			if (unlikely(vhost_enable_notify(&vsock->dev, vq))) {
				vhost_disable_notify(&vsock->dev, vq);
				continue;
			}
			break;
		}

		pkt = vhost_vsock_alloc_pkt(vq, out, in);
		if (!pkt) {
			vq_err(vq, "Faulted on pkt\n");
			continue;
		}

		len = pkt->len;

		/* Deliver to monitoring devices all received packets */
		virtio_transport_deliver_tap_pkt(pkt);

		/* Only accept correctly addressed packets */
		if (le64_to_cpu(pkt->hdr.src_cid) == vsock->guest_cid &&
		    le64_to_cpu(pkt->hdr.dst_cid) ==
		    vhost_transport_get_local_cid())
			virtio_transport_recv_pkt(&vhost_transport, pkt);
		else
			virtio_transport_free_pkt(pkt);

		len += sizeof(pkt->hdr);
		vhost_add_used(vq, head, 0);
		total_len += len;
		added = true;
	} while (likely(!vhost_exceeds_weight(vq, ++pkts, total_len)));

no_more_replies:
	if (added)
		vhost_signal(&vsock->dev, vq);

out:
	mutex_unlock(&vq->mutex);
}

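/* Guest kicked the RX virtqueue: more receive buffers are available, so try
 * again to flush the host->guest packet list.
 */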
static void vhost_vsock_handle_rx_kick(struct vhost_work *work)
{
	struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
						  poll.work);
	struct vhost_vsock *vsock = container_of(vq->dev, struct vhost_vsock,
						 dev);

	vhost_transport_do_send_pkt(vsock, vq);
}

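/* VHOST_VSOCK_SET_RUNNING(1): validate virtqueue access, attach the device as
 * the backend of both virtqueues, and kick the send worker for any packets
 * queued before the device was started.
 */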
static int vhost_vsock_start(struct vhost_vsock *vsock)
{
	struct vhost_virtqueue *vq;
	size_t i;
	int ret;

	mutex_lock(&vsock->dev.mutex);

	ret = vhost_dev_check_owner(&vsock->dev);
	if (ret)
		goto err;

	for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
		vq = &vsock->vqs[i];

		mutex_lock(&vq->mutex);

		if (!vhost_vq_access_ok(vq)) {
			ret = -EFAULT;
			goto err_vq;
		}

		if (!vhost_vq_get_backend(vq)) {
			vhost_vq_set_backend(vq, vsock);
			ret = vhost_vq_init_access(vq);
			if (ret)
				goto err_vq;
		}

		mutex_unlock(&vq->mutex);
	}

	/* Some packets may have been queued before the device was started,
	 * let's kick the send worker to send them.
	 */
	vhost_work_queue(&vsock->dev, &vsock->send_pkt_work);

	mutex_unlock(&vsock->dev.mutex);
	return 0;

err_vq:
	vhost_vq_set_backend(vq, NULL);
	mutex_unlock(&vq->mutex);

	for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
		vq = &vsock->vqs[i];

		mutex_lock(&vq->mutex);
		vhost_vq_set_backend(vq, NULL);
		mutex_unlock(&vq->mutex);
	}
err:
	mutex_unlock(&vsock->dev.mutex);
	return ret;
}

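/* VHOST_VSOCK_SET_RUNNING(0) and the release path: detach the backend from
 * both virtqueues.  @check_owner is false on release, where stopping must not
 * fail even if the caller is not the vhost owner.
 */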
static int vhost_vsock_stop(struct vhost_vsock *vsock, bool check_owner)
{
	size_t i;
	int ret = 0;

	mutex_lock(&vsock->dev.mutex);

	if (check_owner) {
		ret = vhost_dev_check_owner(&vsock->dev);
		if (ret)
			goto err;
	}

	for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
		struct vhost_virtqueue *vq = &vsock->vqs[i];

		mutex_lock(&vq->mutex);
		vhost_vq_set_backend(vq, NULL);
		mutex_unlock(&vq->mutex);
	}

err:
	mutex_unlock(&vsock->dev.mutex);
	return ret;
}

static void vhost_vsock_free(struct vhost_vsock *vsock)
{
	kvfree(vsock);
}

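/* open() on /dev/vhost-vsock: allocate the per-device state, wire up the TX
 * and RX kick handlers, and initialize the vhost device.  No guest CID is
 * assigned until userspace issues VHOST_VSOCK_SET_GUEST_CID.
 */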
static int vhost_vsock_dev_open(struct inode *inode, struct file *file)
{
	struct vhost_virtqueue **vqs;
	struct vhost_vsock *vsock;
	int ret;

	/* This struct is large and allocation could fail, fall back to vmalloc
	 * if there is no other way.
	 */
	vsock = kvmalloc(sizeof(*vsock), GFP_KERNEL | __GFP_RETRY_MAYFAIL);
	if (!vsock)
		return -ENOMEM;

	vqs = kmalloc_array(ARRAY_SIZE(vsock->vqs), sizeof(*vqs), GFP_KERNEL);
	if (!vqs) {
		ret = -ENOMEM;
		goto out;
	}

	vsock->guest_cid = 0; /* no CID assigned yet */

	atomic_set(&vsock->queued_replies, 0);

	vqs[VSOCK_VQ_TX] = &vsock->vqs[VSOCK_VQ_TX];
	vqs[VSOCK_VQ_RX] = &vsock->vqs[VSOCK_VQ_RX];
	vsock->vqs[VSOCK_VQ_TX].handle_kick = vhost_vsock_handle_tx_kick;
	vsock->vqs[VSOCK_VQ_RX].handle_kick = vhost_vsock_handle_rx_kick;

	vhost_dev_init(&vsock->dev, vqs, ARRAY_SIZE(vsock->vqs),
		       UIO_MAXIOV, VHOST_VSOCK_PKT_WEIGHT,
		       VHOST_VSOCK_WEIGHT, true, NULL);

	file->private_data = vsock;
	spin_lock_init(&vsock->send_pkt_list_lock);
	INIT_LIST_HEAD(&vsock->send_pkt_list);
	vhost_work_init(&vsock->send_pkt_work, vhost_transport_send_pkt_work);
	return 0;

out:
	vhost_vsock_free(vsock);
	return ret;
}

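/* Wait for all queued virtqueue handlers and the send worker to finish. */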
static void vhost_vsock_flush(struct vhost_vsock *vsock)
{
	int i;

	for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++)
		if (vsock->vqs[i].handle_kick)
			vhost_poll_flush(&vsock->vqs[i].poll);
	vhost_work_flush(&vsock->dev, &vsock->send_pkt_work);
}

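/* Called via vsock_for_each_connected_socket() when a vhost_vsock device goes
 * away: reset connections whose peer CID no longer has a backing device.
 */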
static void vhost_vsock_reset_orphans(struct sock *sk)
{
	struct vsock_sock *vsk = vsock_sk(sk);

	/* vmci_transport.c doesn't take sk_lock here either.  At least we're
	 * under vsock_table_lock so the sock cannot disappear while we're
	 * executing.
	 */

	/* If the peer is still valid, no need to reset connection */
	if (vhost_vsock_get(vsk->remote_addr.svm_cid))
		return;

	/* If the close timeout is pending, let it expire.  This avoids races
	 * with the timeout callback.
	 */
	if (vsk->close_work_scheduled)
		return;

	sock_set_flag(sk, SOCK_DONE);
	vsk->peer_shutdown = SHUTDOWN_MASK;
	sk->sk_state = SS_UNCONNECTED;
	sk->sk_err = ECONNRESET;
	sk->sk_error_report(sk);
}

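/* close() on /dev/vhost-vsock: unhash the device, reset orphaned connections,
 * stop and flush the virtqueues, and free any packets still queued.
 */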
static int vhost_vsock_dev_release(struct inode *inode, struct file *file)
{
	struct vhost_vsock *vsock = file->private_data;

	mutex_lock(&vhost_vsock_mutex);
	if (vsock->guest_cid)
		hash_del_rcu(&vsock->hash);
	mutex_unlock(&vhost_vsock_mutex);

	/* Wait for other CPUs to finish using vsock */
	synchronize_rcu();

	/* Iterating over all connections for all CIDs to find orphans is
	 * inefficient.  Room for improvement here.
	 */
	vsock_for_each_connected_socket(vhost_vsock_reset_orphans);

	/* Don't check the owner, because we are in the release path, so we
	 * need to stop the vsock device in any case.
	 * vhost_vsock_stop() can not fail in this case, so we don't need to
	 * check the return code.
	 */
	vhost_vsock_stop(vsock, false);
	vhost_vsock_flush(vsock);
	vhost_dev_stop(&vsock->dev);

	spin_lock_bh(&vsock->send_pkt_list_lock);
	while (!list_empty(&vsock->send_pkt_list)) {
		struct virtio_vsock_pkt *pkt;

		pkt = list_first_entry(&vsock->send_pkt_list,
				struct virtio_vsock_pkt, list);
		list_del_init(&pkt->list);
		virtio_transport_free_pkt(pkt);
	}
	spin_unlock_bh(&vsock->send_pkt_list_lock);

	vhost_dev_cleanup(&vsock->dev);
	kfree(vsock->dev.vqs);
	vhost_vsock_free(vsock);
	return 0;
}

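/* VHOST_VSOCK_SET_GUEST_CID: validate and record the guest's CID and publish
 * the device in the global hash so host->guest lookups can find it.
 */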
static int vhost_vsock_set_cid(struct vhost_vsock *vsock, u64 guest_cid)
{
	struct vhost_vsock *other;

	/* Refuse reserved CIDs */
	if (guest_cid <= VMADDR_CID_HOST ||
	    guest_cid == U32_MAX)
		return -EINVAL;

	/* 64-bit CIDs are not yet supported */
	if (guest_cid > U32_MAX)
		return -EINVAL;

	/* Refuse if CID is assigned to the guest->host transport (i.e. nested
	 * VM), to make the loopback work.
	 */
	if (vsock_find_cid(guest_cid))
		return -EADDRINUSE;

	/* Refuse if CID is already in use */
	mutex_lock(&vhost_vsock_mutex);
	other = vhost_vsock_get(guest_cid);
	if (other && other != vsock) {
		mutex_unlock(&vhost_vsock_mutex);
		return -EADDRINUSE;
	}

	if (vsock->guest_cid)
		hash_del_rcu(&vsock->hash);

	vsock->guest_cid = guest_cid;
	hash_add_rcu(vhost_vsock_hash, &vsock->hash, vsock->guest_cid);
	mutex_unlock(&vhost_vsock_mutex);

	return 0;
}

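/* VHOST_SET_FEATURES: accept only the generic vhost feature bits and apply
 * the acked set to both virtqueues.
 */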
static int vhost_vsock_set_features(struct vhost_vsock *vsock, u64 features)
{
	struct vhost_virtqueue *vq;
	int i;

	if (features & ~VHOST_VSOCK_FEATURES)
		return -EOPNOTSUPP;

	mutex_lock(&vsock->dev.mutex);
	if ((features & (1 << VHOST_F_LOG_ALL)) &&
	    !vhost_log_access_ok(&vsock->dev)) {
		mutex_unlock(&vsock->dev.mutex);
		return -EFAULT;
	}

	for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
		vq = &vsock->vqs[i];
		mutex_lock(&vq->mutex);
		vq->acked_features = features;
		mutex_unlock(&vq->mutex);
	}
	mutex_unlock(&vsock->dev.mutex);
	return 0;
}

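/* Device ioctl dispatcher.  The vsock-specific requests are handled here;
 * everything else falls through to the generic vhost_dev_ioctl() /
 * vhost_vring_ioctl() handlers.
 *
 * A rough sketch of how a VMM is expected to drive this device (hypothetical
 * ordering; the exact sequence depends on the VMM and on the generic vhost
 * vring setup):
 *
 *	fd = open("/dev/vhost-vsock", O_RDWR);
 *	ioctl(fd, VHOST_SET_OWNER, NULL);
 *	ioctl(fd, VHOST_GET_FEATURES, &features);
 *	ioctl(fd, VHOST_SET_FEATURES, &features);
 *	ioctl(fd, VHOST_VSOCK_SET_GUEST_CID, &cid);	// u64 CID, must be > 2
 *	// ...configure both virtqueues with the VHOST_SET_VRING_* ioctls...
 *	ioctl(fd, VHOST_VSOCK_SET_RUNNING, &one);	// start the device
 */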
static long vhost_vsock_dev_ioctl(struct file *f, unsigned int ioctl,
				  unsigned long arg)
{
	struct vhost_vsock *vsock = f->private_data;
	void __user *argp = (void __user *)arg;
	u64 guest_cid;
	u64 features;
	int start;
	int r;

	switch (ioctl) {
	case VHOST_VSOCK_SET_GUEST_CID:
		if (copy_from_user(&guest_cid, argp, sizeof(guest_cid)))
			return -EFAULT;
		return vhost_vsock_set_cid(vsock, guest_cid);
	case VHOST_VSOCK_SET_RUNNING:
		if (copy_from_user(&start, argp, sizeof(start)))
			return -EFAULT;
		if (start)
			return vhost_vsock_start(vsock);
		else
			return vhost_vsock_stop(vsock, true);
	case VHOST_GET_FEATURES:
		features = VHOST_VSOCK_FEATURES;
		if (copy_to_user(argp, &features, sizeof(features)))
			return -EFAULT;
		return 0;
	case VHOST_SET_FEATURES:
		if (copy_from_user(&features, argp, sizeof(features)))
			return -EFAULT;
		return vhost_vsock_set_features(vsock, features);
	default:
		mutex_lock(&vsock->dev.mutex);
		r = vhost_dev_ioctl(&vsock->dev, ioctl, argp);
		if (r == -ENOIOCTLCMD)
			r = vhost_vring_ioctl(&vsock->dev, ioctl, argp);
		else
			vhost_vsock_flush(vsock);
		mutex_unlock(&vsock->dev.mutex);
		return r;
	}
}

static const struct file_operations vhost_vsock_fops = {
	.owner          = THIS_MODULE,
	.open           = vhost_vsock_dev_open,
	.release        = vhost_vsock_dev_release,
	.llseek		= noop_llseek,
	.unlocked_ioctl = vhost_vsock_dev_ioctl,
	.compat_ioctl   = compat_ptr_ioctl,
};

static struct miscdevice vhost_vsock_misc = {
	.minor = VHOST_VSOCK_MINOR,
	.name = "vhost-vsock",
	.fops = &vhost_vsock_fops,
};

static int __init vhost_vsock_init(void)
{
	int ret;

	ret = vsock_core_register(&vhost_transport.transport,
				  VSOCK_TRANSPORT_F_H2G);
	if (ret < 0)
		return ret;
	return misc_register(&vhost_vsock_misc);
}

static void __exit vhost_vsock_exit(void)
{
	misc_deregister(&vhost_vsock_misc);
	vsock_core_unregister(&vhost_transport.transport);
}

module_init(vhost_vsock_init);
module_exit(vhost_vsock_exit);
MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Asias He");
MODULE_DESCRIPTION("vhost transport for vsock");
MODULE_ALIAS_MISCDEV(VHOST_VSOCK_MINOR);
MODULE_ALIAS("devname:vhost-vsock");