// SPDX-License-Identifier: GPL-2.0-only
/*
 * vhost transport for vsock
 *
 * Copyright (C) 2013-2015 Red Hat, Inc.
 * Author: Asias He <asias@redhat.com>
 *         Stefan Hajnoczi <stefanha@redhat.com>
 */
#include <linux/miscdevice.h>
#include <linux/atomic.h>
#include <linux/module.h>
#include <linux/mutex.h>
#include <linux/vmalloc.h>
#include <net/sock.h>
#include <linux/virtio_vsock.h>
#include <linux/vhost.h>
#include <linux/hashtable.h>

#include <net/af_vsock.h>
#include "vhost.h"

#define VHOST_VSOCK_DEFAULT_HOST_CID	2
/* Max number of bytes transferred before requeueing the job.
 * Using this limit prevents one virtqueue from starving others. */
#define VHOST_VSOCK_WEIGHT 0x80000
/* Max number of packets transferred before requeueing the job.
 * Using this limit prevents one virtqueue from starving others with
 * small pkts.
 */
#define VHOST_VSOCK_PKT_WEIGHT 256

enum {
	VHOST_VSOCK_FEATURES = VHOST_FEATURES,
};

/* Used to track all the vhost_vsock instances on the system. */
static DEFINE_MUTEX(vhost_vsock_mutex);
static DEFINE_READ_MOSTLY_HASHTABLE(vhost_vsock_hash, 8);

struct vhost_vsock {
	struct vhost_dev dev;
	struct vhost_virtqueue vqs[2];

	/* Link to global vhost_vsock_hash, writes use vhost_vsock_mutex */
	struct hlist_node hash;

	struct vhost_work send_pkt_work;
	spinlock_t send_pkt_list_lock;
	struct list_head send_pkt_list;	/* host->guest pending packets */

	atomic_t queued_replies;

	u32 guest_cid;
};

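/* The host side of this transport always uses the well-known host CID. */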
static u32 vhost_transport_get_local_cid(void)
{
	return VHOST_VSOCK_DEFAULT_HOST_CID;
}

/* Callers that dereference the return value must hold vhost_vsock_mutex or the
 * RCU read lock.
 */
static struct vhost_vsock *vhost_vsock_get(u32 guest_cid)
{
	struct vhost_vsock *vsock;

	hash_for_each_possible_rcu(vhost_vsock_hash, vsock, hash, guest_cid) {
		u32 other_cid = vsock->guest_cid;

		/* Skip instances that have no CID yet */
		if (other_cid == 0)
			continue;

		if (other_cid == guest_cid)
			return vsock;
	}

	return NULL;
}

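/* Copy packets queued on send_pkt_list into the buffers the guest has made
 * available on its RX virtqueue.  A packet larger than the current buffer is
 * split across multiple buffers, and the loop is bounded by the vhost weight
 * limits so one virtqueue cannot starve the others.
 */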
static void
vhost_transport_do_send_pkt(struct vhost_vsock *vsock,
			    struct vhost_virtqueue *vq)
{
	struct vhost_virtqueue *tx_vq = &vsock->vqs[VSOCK_VQ_TX];
	int pkts = 0, total_len = 0;
	bool added = false;
	bool restart_tx = false;

	mutex_lock(&vq->mutex);

	if (!vhost_vq_get_backend(vq))
		goto out;

	/* Avoid further vmexits, we're already processing the virtqueue */
	vhost_disable_notify(&vsock->dev, vq);

	do {
		struct virtio_vsock_pkt *pkt;
		struct iov_iter iov_iter;
		unsigned out, in;
		size_t nbytes;
		size_t iov_len, payload_len;
		int head;

		spin_lock_bh(&vsock->send_pkt_list_lock);
		if (list_empty(&vsock->send_pkt_list)) {
			spin_unlock_bh(&vsock->send_pkt_list_lock);
			vhost_enable_notify(&vsock->dev, vq);
			break;
		}

		pkt = list_first_entry(&vsock->send_pkt_list,
				       struct virtio_vsock_pkt, list);
		list_del_init(&pkt->list);
		spin_unlock_bh(&vsock->send_pkt_list_lock);

		head = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
					 &out, &in, NULL, NULL);
		if (head < 0) {
			spin_lock_bh(&vsock->send_pkt_list_lock);
			list_add(&pkt->list, &vsock->send_pkt_list);
			spin_unlock_bh(&vsock->send_pkt_list_lock);
			break;
		}

		if (head == vq->num) {
			spin_lock_bh(&vsock->send_pkt_list_lock);
			list_add(&pkt->list, &vsock->send_pkt_list);
			spin_unlock_bh(&vsock->send_pkt_list_lock);

			/* We cannot finish yet if more buffers snuck in while
			 * re-enabling notify.
			 */
			if (unlikely(vhost_enable_notify(&vsock->dev, vq))) {
				vhost_disable_notify(&vsock->dev, vq);
				continue;
			}
			break;
		}

		if (out) {
			virtio_transport_free_pkt(pkt);
			vq_err(vq, "Expected 0 output buffers, got %u\n", out);
			break;
		}

		iov_len = iov_length(&vq->iov[out], in);
		if (iov_len < sizeof(pkt->hdr)) {
			virtio_transport_free_pkt(pkt);
			vq_err(vq, "Buffer len [%zu] too small\n", iov_len);
			break;
		}

		iov_iter_init(&iov_iter, READ, &vq->iov[out], in, iov_len);
		payload_len = pkt->len - pkt->off;

		/* If the packet is greater than the space available in the
		 * buffer, we split it using multiple buffers.
		 */
		if (payload_len > iov_len - sizeof(pkt->hdr))
			payload_len = iov_len - sizeof(pkt->hdr);

		/* Set the correct length in the header */
		pkt->hdr.len = cpu_to_le32(payload_len);

		nbytes = copy_to_iter(&pkt->hdr, sizeof(pkt->hdr), &iov_iter);
		if (nbytes != sizeof(pkt->hdr)) {
			virtio_transport_free_pkt(pkt);
			vq_err(vq, "Faulted on copying pkt hdr\n");
			break;
		}

		nbytes = copy_to_iter(pkt->buf + pkt->off, payload_len,
				      &iov_iter);
		if (nbytes != payload_len) {
			virtio_transport_free_pkt(pkt);
			vq_err(vq, "Faulted on copying pkt buf\n");
			break;
		}

		/* Deliver to monitoring devices all packets that we
		 * will transmit.
		 */
		virtio_transport_deliver_tap_pkt(pkt);

		vhost_add_used(vq, head, sizeof(pkt->hdr) + payload_len);
		added = true;

		pkt->off += payload_len;
		total_len += payload_len;

		/* If we didn't send all the payload we can requeue the packet
		 * to send it with the next available buffer.
		 */
		if (pkt->off < pkt->len) {
			/* We are queueing the same virtio_vsock_pkt to handle
			 * the remaining bytes, and we want to deliver it
			 * to monitoring devices in the next iteration.
			 */
			pkt->tap_delivered = false;

			spin_lock_bh(&vsock->send_pkt_list_lock);
			list_add(&pkt->list, &vsock->send_pkt_list);
			spin_unlock_bh(&vsock->send_pkt_list_lock);
		} else {
			if (pkt->reply) {
				int val;

				val = atomic_dec_return(&vsock->queued_replies);

				/* Do we have resources to resume tx
				 * processing?
				 */
				if (val + 1 == tx_vq->num)
					restart_tx = true;
			}

			virtio_transport_free_pkt(pkt);
		}
	} while (likely(!vhost_exceeds_weight(vq, ++pkts, total_len)));
	if (added)
		vhost_signal(&vsock->dev, vq);

out:
	mutex_unlock(&vq->mutex);

	if (restart_tx)
		vhost_poll_queue(&tx_vq->poll);
}

static void vhost_transport_send_pkt_work(struct vhost_work *work)
{
	struct vhost_virtqueue *vq;
	struct vhost_vsock *vsock;

	vsock = container_of(work, struct vhost_vsock, send_pkt_work);
	vq = &vsock->vqs[VSOCK_VQ_RX];

	vhost_transport_do_send_pkt(vsock, vq);
}

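/* Queue a packet for delivery to the guest identified by hdr.dst_cid and kick
 * the send worker.  This is the .send_pkt callback used by the core transport.
 */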
static int
vhost_transport_send_pkt(struct virtio_vsock_pkt *pkt)
{
	struct vhost_vsock *vsock;
	int len = pkt->len;

	rcu_read_lock();

	/* Find the vhost_vsock according to guest context id */
	vsock = vhost_vsock_get(le64_to_cpu(pkt->hdr.dst_cid));
	if (!vsock) {
		rcu_read_unlock();
		virtio_transport_free_pkt(pkt);
		return -ENODEV;
	}

	if (pkt->reply)
		atomic_inc(&vsock->queued_replies);

	spin_lock_bh(&vsock->send_pkt_list_lock);
	list_add_tail(&pkt->list, &vsock->send_pkt_list);
	spin_unlock_bh(&vsock->send_pkt_list_lock);

	vhost_work_queue(&vsock->dev, &vsock->send_pkt_work);

	rcu_read_unlock();
	return len;
}

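/* Drop any packets still queued for @vsk and free them.  Cancelled reply
 * packets are subtracted from queued_replies so a throttled TX virtqueue can
 * be restarted.
 */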
static int
vhost_transport_cancel_pkt(struct vsock_sock *vsk)
{
	struct vhost_vsock *vsock;
	struct virtio_vsock_pkt *pkt, *n;
	int cnt = 0;
	int ret = -ENODEV;
	LIST_HEAD(freeme);

	rcu_read_lock();

	/* Find the vhost_vsock according to guest context id */
	vsock = vhost_vsock_get(vsk->remote_addr.svm_cid);
	if (!vsock)
		goto out;

	spin_lock_bh(&vsock->send_pkt_list_lock);
	list_for_each_entry_safe(pkt, n, &vsock->send_pkt_list, list) {
		if (pkt->vsk != vsk)
			continue;
		list_move(&pkt->list, &freeme);
	}
	spin_unlock_bh(&vsock->send_pkt_list_lock);

	list_for_each_entry_safe(pkt, n, &freeme, list) {
		if (pkt->reply)
			cnt++;
		list_del(&pkt->list);
		virtio_transport_free_pkt(pkt);
	}

	if (cnt) {
		struct vhost_virtqueue *tx_vq = &vsock->vqs[VSOCK_VQ_TX];
		int new_cnt;

		new_cnt = atomic_sub_return(cnt, &vsock->queued_replies);
		if (new_cnt + cnt >= tx_vq->num && new_cnt < tx_vq->num)
			vhost_poll_queue(&tx_vq->poll);
	}

	ret = 0;
out:
	rcu_read_unlock();
	return ret;
}

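/* Read one packet (header plus optional payload) from the descriptor chain
 * the guest posted on its TX virtqueue.  Returns NULL on malformed input or
 * allocation failure.
 */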
static struct virtio_vsock_pkt *
vhost_vsock_alloc_pkt(struct vhost_virtqueue *vq,
		      unsigned int out, unsigned int in)
{
	struct virtio_vsock_pkt *pkt;
	struct iov_iter iov_iter;
	size_t nbytes;
	size_t len;

	if (in != 0) {
		vq_err(vq, "Expected 0 input buffers, got %u\n", in);
		return NULL;
	}

	pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
	if (!pkt)
		return NULL;

	len = iov_length(vq->iov, out);
	iov_iter_init(&iov_iter, WRITE, vq->iov, out, len);

	nbytes = copy_from_iter(&pkt->hdr, sizeof(pkt->hdr), &iov_iter);
	if (nbytes != sizeof(pkt->hdr)) {
		vq_err(vq, "Expected %zu bytes for pkt->hdr, got %zu bytes\n",
		       sizeof(pkt->hdr), nbytes);
		kfree(pkt);
		return NULL;
	}

	if (le16_to_cpu(pkt->hdr.type) == VIRTIO_VSOCK_TYPE_STREAM)
		pkt->len = le32_to_cpu(pkt->hdr.len);

	/* No payload */
	if (!pkt->len)
		return pkt;

	/* The pkt is too big */
	if (pkt->len > VIRTIO_VSOCK_MAX_PKT_BUF_SIZE) {
		kfree(pkt);
		return NULL;
	}

	pkt->buf = kvmalloc(pkt->len, GFP_KERNEL);
	if (!pkt->buf) {
		kfree(pkt);
		return NULL;
	}

	pkt->buf_len = pkt->len;

	nbytes = copy_from_iter(pkt->buf, pkt->len, &iov_iter);
	if (nbytes != pkt->len) {
		vq_err(vq, "Expected %u byte payload, got %zu bytes\n",
		       pkt->len, nbytes);
		virtio_transport_free_pkt(pkt);
		return NULL;
	}

	return pkt;
}

/* Is there space left for replies to rx packets? */
static bool vhost_vsock_more_replies(struct vhost_vsock *vsock)
{
	struct vhost_virtqueue *vq = &vsock->vqs[VSOCK_VQ_TX];
	int val;

	smp_rmb(); /* paired with atomic_inc() and atomic_dec_return() */
	val = atomic_read(&vsock->queued_replies);

	return val < vq->num;
}

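/* Callbacks plugged into the AF_VSOCK core.  Most of the data path is shared
 * with the guest-side virtio transport; only CID lookup, packet transmission
 * and cancellation are vhost specific.
 */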
static struct virtio_transport vhost_transport = {
	.transport = {
		.module = THIS_MODULE,

		.get_local_cid = vhost_transport_get_local_cid,

		.init = virtio_transport_do_socket_init,
		.destruct = virtio_transport_destruct,
		.release = virtio_transport_release,
		.connect = virtio_transport_connect,
		.shutdown = virtio_transport_shutdown,
		.cancel_pkt = vhost_transport_cancel_pkt,

		.dgram_enqueue = virtio_transport_dgram_enqueue,
		.dgram_dequeue = virtio_transport_dgram_dequeue,
		.dgram_bind = virtio_transport_dgram_bind,
		.dgram_allow = virtio_transport_dgram_allow,

		.stream_enqueue = virtio_transport_stream_enqueue,
		.stream_dequeue = virtio_transport_stream_dequeue,
		.stream_has_data = virtio_transport_stream_has_data,
		.stream_has_space = virtio_transport_stream_has_space,
		.stream_rcvhiwat = virtio_transport_stream_rcvhiwat,
		.stream_is_active = virtio_transport_stream_is_active,
		.stream_allow = virtio_transport_stream_allow,

		.notify_poll_in = virtio_transport_notify_poll_in,
		.notify_poll_out = virtio_transport_notify_poll_out,
		.notify_recv_init = virtio_transport_notify_recv_init,
		.notify_recv_pre_block = virtio_transport_notify_recv_pre_block,
		.notify_recv_pre_dequeue = virtio_transport_notify_recv_pre_dequeue,
		.notify_recv_post_dequeue = virtio_transport_notify_recv_post_dequeue,
		.notify_send_init = virtio_transport_notify_send_init,
		.notify_send_pre_block = virtio_transport_notify_send_pre_block,
		.notify_send_pre_enqueue = virtio_transport_notify_send_pre_enqueue,
		.notify_send_post_enqueue = virtio_transport_notify_send_post_enqueue,
		.notify_buffer_size = virtio_transport_notify_buffer_size,
	},

	.send_pkt = vhost_transport_send_pkt,
};

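/* The guest kicked the TX virtqueue: read guest->host packets and hand them
 * to the core transport, throttling when too many replies are pending.
 */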
static void vhost_vsock_handle_tx_kick(struct vhost_work *work)
{
	struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
						  poll.work);
	struct vhost_vsock *vsock = container_of(vq->dev, struct vhost_vsock,
						 dev);
	struct virtio_vsock_pkt *pkt;
	int head, pkts = 0, total_len = 0;
	unsigned int out, in;
	bool added = false;

	mutex_lock(&vq->mutex);

	if (!vhost_vq_get_backend(vq))
		goto out;

	vhost_disable_notify(&vsock->dev, vq);
	do {
		u32 len;

		if (!vhost_vsock_more_replies(vsock)) {
			/* Stop tx until the device processes already
			 * pending replies.  Leave tx virtqueue
			 * callbacks disabled.
			 */
			goto no_more_replies;
		}

		head = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
					 &out, &in, NULL, NULL);
		if (head < 0)
			break;

		if (head == vq->num) {
			if (unlikely(vhost_enable_notify(&vsock->dev, vq))) {
				vhost_disable_notify(&vsock->dev, vq);
				continue;
			}
			break;
		}

		pkt = vhost_vsock_alloc_pkt(vq, out, in);
		if (!pkt) {
			vq_err(vq, "Faulted on pkt\n");
			continue;
		}

		len = pkt->len;

		/* Deliver to monitoring devices all received packets */
		virtio_transport_deliver_tap_pkt(pkt);

		/* Only accept correctly addressed packets */
		if (le64_to_cpu(pkt->hdr.src_cid) == vsock->guest_cid &&
		    le64_to_cpu(pkt->hdr.dst_cid) ==
		    vhost_transport_get_local_cid())
			virtio_transport_recv_pkt(&vhost_transport, pkt);
		else
			virtio_transport_free_pkt(pkt);

		len += sizeof(pkt->hdr);
		vhost_add_used(vq, head, 0);
		total_len += len;
		added = true;
	} while (likely(!vhost_exceeds_weight(vq, ++pkts, total_len)));

no_more_replies:
	if (added)
		vhost_signal(&vsock->dev, vq);

out:
	mutex_unlock(&vq->mutex);
}

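/* The guest kicked the RX virtqueue, i.e. it made new buffers available, so
 * retry sending any pending host->guest packets.
 */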
static void vhost_vsock_handle_rx_kick(struct vhost_work *work)
{
	struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
						  poll.work);
	struct vhost_vsock *vsock = container_of(vq->dev, struct vhost_vsock,
						 dev);

	vhost_transport_do_send_pkt(vsock, vq);
}

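/* Called for VHOST_VSOCK_SET_RUNNING with a non-zero argument: attach this
 * instance as the backend of both virtqueues and start processing.
 */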
static int vhost_vsock_start(struct vhost_vsock *vsock)
{
	struct vhost_virtqueue *vq;
	size_t i;
	int ret;

	mutex_lock(&vsock->dev.mutex);

	ret = vhost_dev_check_owner(&vsock->dev);
	if (ret)
		goto err;

	for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
		vq = &vsock->vqs[i];

		mutex_lock(&vq->mutex);

		if (!vhost_vq_access_ok(vq)) {
			ret = -EFAULT;
			goto err_vq;
		}

		if (!vhost_vq_get_backend(vq)) {
			vhost_vq_set_backend(vq, vsock);
			ret = vhost_vq_init_access(vq);
			if (ret)
				goto err_vq;
		}

		mutex_unlock(&vq->mutex);
	}

	/* Some packets may have been queued before the device was started,
	 * let's kick the send worker to send them.
	 */
	vhost_work_queue(&vsock->dev, &vsock->send_pkt_work);

	mutex_unlock(&vsock->dev.mutex);
	return 0;

err_vq:
	vhost_vq_set_backend(vq, NULL);
	mutex_unlock(&vq->mutex);

	for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
		vq = &vsock->vqs[i];

		mutex_lock(&vq->mutex);
		vhost_vq_set_backend(vq, NULL);
		mutex_unlock(&vq->mutex);
	}
err:
	mutex_unlock(&vsock->dev.mutex);
	return ret;
}

static int vhost_vsock_stop(struct vhost_vsock *vsock, bool check_owner)
{
	size_t i;
	int ret = 0;

	mutex_lock(&vsock->dev.mutex);

	if (check_owner) {
		ret = vhost_dev_check_owner(&vsock->dev);
		if (ret)
			goto err;
	}

	for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
		struct vhost_virtqueue *vq = &vsock->vqs[i];

		mutex_lock(&vq->mutex);
		vhost_vq_set_backend(vq, NULL);
		mutex_unlock(&vq->mutex);
	}

err:
	mutex_unlock(&vsock->dev.mutex);
	return ret;
}

static void vhost_vsock_free(struct vhost_vsock *vsock)
{
	kvfree(vsock);
}

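/* Each open of the misc device creates an independent vhost_vsock instance. */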
static int vhost_vsock_dev_open(struct inode *inode, struct file *file)
{
	struct vhost_virtqueue **vqs;
	struct vhost_vsock *vsock;
	int ret;

	/* This struct is large and allocation could fail, fall back to vmalloc
	 * if there is no other way.
	 */
	vsock = kvmalloc(sizeof(*vsock), GFP_KERNEL | __GFP_RETRY_MAYFAIL);
	if (!vsock)
		return -ENOMEM;

	vqs = kmalloc_array(ARRAY_SIZE(vsock->vqs), sizeof(*vqs), GFP_KERNEL);
	if (!vqs) {
		ret = -ENOMEM;
		goto out;
	}

	vsock->guest_cid = 0; /* no CID assigned yet */

	atomic_set(&vsock->queued_replies, 0);

	vqs[VSOCK_VQ_TX] = &vsock->vqs[VSOCK_VQ_TX];
	vqs[VSOCK_VQ_RX] = &vsock->vqs[VSOCK_VQ_RX];
	vsock->vqs[VSOCK_VQ_TX].handle_kick = vhost_vsock_handle_tx_kick;
	vsock->vqs[VSOCK_VQ_RX].handle_kick = vhost_vsock_handle_rx_kick;

	vhost_dev_init(&vsock->dev, vqs, ARRAY_SIZE(vsock->vqs),
		       UIO_MAXIOV, VHOST_VSOCK_PKT_WEIGHT,
		       VHOST_VSOCK_WEIGHT, true, NULL);

	file->private_data = vsock;
	spin_lock_init(&vsock->send_pkt_list_lock);
	INIT_LIST_HEAD(&vsock->send_pkt_list);
	vhost_work_init(&vsock->send_pkt_work, vhost_transport_send_pkt_work);
	return 0;

out:
	vhost_vsock_free(vsock);
	return ret;
}

static void vhost_vsock_flush(struct vhost_vsock *vsock)
{
	int i;

	for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++)
		if (vsock->vqs[i].handle_kick)
			vhost_poll_flush(&vsock->vqs[i].poll);
	vhost_work_flush(&vsock->dev, &vsock->send_pkt_work);
}

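/* Reset connected sockets whose peer CID no longer has a vhost_vsock
 * instance, so that blocked callers see ECONNRESET instead of hanging.
 */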
static void vhost_vsock_reset_orphans(struct sock *sk)
{
	struct vsock_sock *vsk = vsock_sk(sk);

	/* vmci_transport.c doesn't take sk_lock here either.  At least we're
	 * under vsock_table_lock so the sock cannot disappear while we're
	 * executing.
	 */

	/* If the peer is still valid, no need to reset connection */
	if (vhost_vsock_get(vsk->remote_addr.svm_cid))
		return;

	/* If the close timeout is pending, let it expire.  This avoids races
	 * with the timeout callback.
	 */
	if (vsk->close_work_scheduled)
		return;

	sock_set_flag(sk, SOCK_DONE);
	vsk->peer_shutdown = SHUTDOWN_MASK;
	sk->sk_state = SS_UNCONNECTED;
	sk->sk_err = ECONNRESET;
	sk->sk_error_report(sk);
}

static int vhost_vsock_dev_release(struct inode *inode, struct file *file)
{
	struct vhost_vsock *vsock = file->private_data;

	mutex_lock(&vhost_vsock_mutex);
	if (vsock->guest_cid)
		hash_del_rcu(&vsock->hash);
	mutex_unlock(&vhost_vsock_mutex);

	/* Wait for other CPUs to finish using vsock */
	synchronize_rcu();

	/* Iterating over all connections for all CIDs to find orphans is
	 * inefficient.  Room for improvement here. */
	vsock_for_each_connected_socket(vhost_vsock_reset_orphans);

	/* Don't check the owner, because we are in the release path, so we
	 * need to stop the vsock device in any case.
	 * vhost_vsock_stop() can not fail in this case, so we don't need to
	 * check the return code.
	 */
	vhost_vsock_stop(vsock, false);
	vhost_vsock_flush(vsock);
	vhost_dev_stop(&vsock->dev);

	spin_lock_bh(&vsock->send_pkt_list_lock);
	while (!list_empty(&vsock->send_pkt_list)) {
		struct virtio_vsock_pkt *pkt;

		pkt = list_first_entry(&vsock->send_pkt_list,
				       struct virtio_vsock_pkt, list);
		list_del_init(&pkt->list);
		virtio_transport_free_pkt(pkt);
	}
	spin_unlock_bh(&vsock->send_pkt_list_lock);

	vhost_dev_cleanup(&vsock->dev);
	kfree(vsock->dev.vqs);
	vhost_vsock_free(vsock);
	return 0;
}

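/* VHOST_VSOCK_SET_GUEST_CID: assign the guest's context ID, refusing reserved
 * values and CIDs already claimed by another instance or transport.
 */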
static int vhost_vsock_set_cid(struct vhost_vsock *vsock, u64 guest_cid)
{
	struct vhost_vsock *other;

	/* Refuse reserved CIDs */
	if (guest_cid <= VMADDR_CID_HOST ||
	    guest_cid == U32_MAX)
		return -EINVAL;

	/* 64-bit CIDs are not yet supported */
	if (guest_cid > U32_MAX)
		return -EINVAL;

	/* Refuse if CID is assigned to the guest->host transport (i.e. nested
	 * VM), to make the loopback work.
	 */
	if (vsock_find_cid(guest_cid))
		return -EADDRINUSE;

	/* Refuse if CID is already in use */
	mutex_lock(&vhost_vsock_mutex);
	other = vhost_vsock_get(guest_cid);
	if (other && other != vsock) {
		mutex_unlock(&vhost_vsock_mutex);
		return -EADDRINUSE;
	}

	if (vsock->guest_cid)
		hash_del_rcu(&vsock->hash);

	vsock->guest_cid = guest_cid;
	hash_add_rcu(vhost_vsock_hash, &vsock->hash, vsock->guest_cid);
	mutex_unlock(&vhost_vsock_mutex);

	return 0;
}

static int vhost_vsock_set_features(struct vhost_vsock *vsock, u64 features)
{
	struct vhost_virtqueue *vq;
	int i;

	if (features & ~VHOST_VSOCK_FEATURES)
		return -EOPNOTSUPP;

	mutex_lock(&vsock->dev.mutex);
	if ((features & (1 << VHOST_F_LOG_ALL)) &&
	    !vhost_log_access_ok(&vsock->dev)) {
		mutex_unlock(&vsock->dev.mutex);
		return -EFAULT;
	}

	for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
		vq = &vsock->vqs[i];
		mutex_lock(&vq->mutex);
		vq->acked_features = features;
		mutex_unlock(&vq->mutex);
	}
	mutex_unlock(&vsock->dev.mutex);
	return 0;
}

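/* Illustrative userspace sequence (a sketch, not part of this driver): a VMM
 * typically open()s /dev/vhost-vsock, calls VHOST_SET_OWNER, configures guest
 * memory and the two virtqueues through the generic vhost ioctls, assigns a
 * CID with VHOST_VSOCK_SET_GUEST_CID and finally enables the device with
 * VHOST_VSOCK_SET_RUNNING.
 */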
static long vhost_vsock_dev_ioctl(struct file *f, unsigned int ioctl,
				  unsigned long arg)
{
	struct vhost_vsock *vsock = f->private_data;
	void __user *argp = (void __user *)arg;
	u64 guest_cid;
	u64 features;
	int start;
	int r;

	switch (ioctl) {
	case VHOST_VSOCK_SET_GUEST_CID:
		if (copy_from_user(&guest_cid, argp, sizeof(guest_cid)))
			return -EFAULT;
		return vhost_vsock_set_cid(vsock, guest_cid);
	case VHOST_VSOCK_SET_RUNNING:
		if (copy_from_user(&start, argp, sizeof(start)))
			return -EFAULT;
		if (start)
			return vhost_vsock_start(vsock);
		else
			return vhost_vsock_stop(vsock, true);
	case VHOST_GET_FEATURES:
		features = VHOST_VSOCK_FEATURES;
		if (copy_to_user(argp, &features, sizeof(features)))
			return -EFAULT;
		return 0;
	case VHOST_SET_FEATURES:
		if (copy_from_user(&features, argp, sizeof(features)))
			return -EFAULT;
		return vhost_vsock_set_features(vsock, features);
	default:
		mutex_lock(&vsock->dev.mutex);
		r = vhost_dev_ioctl(&vsock->dev, ioctl, argp);
		if (r == -ENOIOCTLCMD)
			r = vhost_vring_ioctl(&vsock->dev, ioctl, argp);
		else
			vhost_vsock_flush(vsock);
		mutex_unlock(&vsock->dev.mutex);
		return r;
	}
}

static const struct file_operations vhost_vsock_fops = {
	.owner = THIS_MODULE,
	.open = vhost_vsock_dev_open,
	.release = vhost_vsock_dev_release,
	.llseek = noop_llseek,
	.unlocked_ioctl = vhost_vsock_dev_ioctl,
	.compat_ioctl = compat_ptr_ioctl,
};

static struct miscdevice vhost_vsock_misc = {
	.minor = VHOST_VSOCK_MINOR,
	.name = "vhost-vsock",
	.fops = &vhost_vsock_fops,
};

static int __init vhost_vsock_init(void)
{
	int ret;

	ret = vsock_core_register(&vhost_transport.transport,
				  VSOCK_TRANSPORT_F_H2G);
	if (ret < 0)
		return ret;
	return misc_register(&vhost_vsock_misc);
}

static void __exit vhost_vsock_exit(void)
{
	misc_deregister(&vhost_vsock_misc);
	vsock_core_unregister(&vhost_transport.transport);
}

module_init(vhost_vsock_init);
module_exit(vhost_vsock_exit);
MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Asias He");
MODULE_DESCRIPTION("vhost transport for vsock");
MODULE_ALIAS_MISCDEV(VHOST_VSOCK_MINOR);
MODULE_ALIAS("devname:vhost-vsock");