drivers/vhost/net.c  +252 −41

@@ -74,6 +74,22 @@ static int move_iovec_hdr(struct iovec *from, struct iovec *to,
 	}
 	return seg;
 }
+/* Copy iovec entries for len bytes from iovec. */
+static void copy_iovec_hdr(const struct iovec *from, struct iovec *to,
+			   size_t len, int iovcount)
+{
+	int seg = 0;
+	size_t size;
+	while (len && seg < iovcount) {
+		size = min(from->iov_len, len);
+		to->iov_base = from->iov_base;
+		to->iov_len = size;
+		len -= size;
+		++from;
+		++to;
+		++seg;
+	}
+}

 /* Caller must have TX VQ lock */
 static void tx_poll_stop(struct vhost_net *net)

@@ -129,7 +145,7 @@ static void handle_tx(struct vhost_net *net)
 	if (wmem < sock->sk->sk_sndbuf / 2)
 		tx_poll_stop(net);

-	hdr_size = vq->hdr_size;
+	hdr_size = vq->vhost_hlen;

 	for (;;) {
 		head = vhost_get_vq_desc(&net->dev, vq, vq->iov,

@@ -172,7 +188,7 @@ static void handle_tx(struct vhost_net *net)
 		/* TODO: Check specific error and bomb out unless ENOBUFS? */
 		err = sock->ops->sendmsg(NULL, sock, &msg, len);
 		if (unlikely(err < 0)) {
-			vhost_discard_vq_desc(vq);
+			vhost_discard_vq_desc(vq, 1);
 			tx_poll_start(net, sock);
 			break;
 		}

@@ -191,9 +207,82 @@ static void handle_tx(struct vhost_net *net)
 	unuse_mm(net->dev.mm);
 }

+static int peek_head_len(struct sock *sk)
+{
+	struct sk_buff *head;
+	int len = 0;
+
+	lock_sock(sk);
+	head = skb_peek(&sk->sk_receive_queue);
+	if (head)
+		len = head->len;
+	release_sock(sk);
+	return len;
+}
+
+/* This is a multi-buffer version of vhost_get_desc, that works if
+ *	vq has read descriptors only.
+ * @vq		- the relevant virtqueue
+ * @datalen	- data length we'll be reading
+ * @iovcount	- returned count of io vectors we fill
+ * @log		- vhost log
+ * @log_num	- log offset
+ *	returns number of buffer heads allocated, negative on error
+ */
+static int get_rx_bufs(struct vhost_virtqueue *vq,
+		       struct vring_used_elem *heads,
+		       int datalen,
+		       unsigned *iovcount,
+		       struct vhost_log *log,
+		       unsigned *log_num)
+{
+	unsigned int out, in;
+	int seg = 0;
+	int headcount = 0;
+	unsigned d;
+	int r, nlogs = 0;
+
+	while (datalen > 0) {
+		if (unlikely(headcount >= VHOST_NET_MAX_SG)) {
+			r = -ENOBUFS;
+			goto err;
+		}
+		d = vhost_get_vq_desc(vq->dev, vq, vq->iov + seg,
+				      ARRAY_SIZE(vq->iov) - seg, &out,
+				      &in, log, log_num);
+		if (d == vq->num) {
+			r = 0;
+			goto err;
+		}
+		if (unlikely(out || in <= 0)) {
+			vq_err(vq, "unexpected descriptor format for RX: "
+				"out %d, in %d\n", out, in);
+			r = -EINVAL;
+			goto err;
+		}
+		if (unlikely(log)) {
+			nlogs += *log_num;
+			log += *log_num;
+		}
+		heads[headcount].id = d;
+		heads[headcount].len = iov_length(vq->iov + seg, in);
+		datalen -= heads[headcount].len;
+		++headcount;
+		seg += in;
+	}
+	heads[headcount - 1].len += datalen;
+	*iovcount = seg;
+	if (unlikely(log))
+		*log_num = nlogs;
+	return headcount;
+err:
+	vhost_discard_vq_desc(vq, headcount);
+	return r;
+}
+
 /* Expects to be always run from workqueue - which acts as
  * read-size critical section for our kind of RCU. */
-static void handle_rx(struct vhost_net *net)
+static void handle_rx_big(struct vhost_net *net)
 {
 	struct vhost_virtqueue *vq = &net->dev.vqs[VHOST_NET_VQ_RX];
 	unsigned out, in, log, s;

@@ -223,7 +312,7 @@ static void handle_rx(struct vhost_net *net)
 	use_mm(net->dev.mm);
 	mutex_lock(&vq->mutex);
 	vhost_disable_notify(vq);
-	hdr_size = vq->hdr_size;
+	hdr_size = vq->vhost_hlen;

 	vq_log = unlikely(vhost_has_feature(&net->dev, VHOST_F_LOG_ALL)) ?
 		vq->log : NULL;

@@ -270,14 +359,14 @@ static void handle_rx(struct vhost_net *net)
 			 len, MSG_DONTWAIT | MSG_TRUNC);
 		/* TODO: Check specific error and bomb out unless EAGAIN? */
 		if (err < 0) {
-			vhost_discard_vq_desc(vq);
+			vhost_discard_vq_desc(vq, 1);
 			break;
 		}
 		/* TODO: Should check and handle checksum. */
 		if (err > len) {
 			pr_debug("Discarded truncated rx packet: "
 				 " len %d > %zd\n", err, len);
-			vhost_discard_vq_desc(vq);
+			vhost_discard_vq_desc(vq, 1);
 			continue;
 		}
 		len = err;
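A note on get_rx_bufs() above: it keeps taking guest buffers until the packet fits, then `heads[headcount - 1].len += datalen` (datalen is zero or negative by that point) trims the last entry so the reported total equals the packet size. Below is a minimal userspace model of just that accumulation step — the buffer sizes are invented, and this is a sketch, not the kernel code:

/* Model of get_rx_bufs() accumulation: take buffers until the packet
 * fits, then trim the reported length of the last one. */
#include <stdio.h>

struct used_elem { int id; int len; };

static int take_bufs(const int *bufsz, int nbufs, int datalen,
		     struct used_elem *heads)
{
	int headcount = 0;

	while (datalen > 0) {
		if (headcount == nbufs)
			return -1;		/* -ENOBUFS in the kernel */
		heads[headcount].id = headcount;
		heads[headcount].len = bufsz[headcount];
		datalen -= bufsz[headcount];	/* may go negative */
		++headcount;
	}
	/* datalen is now <= 0: shrink the last entry by the overshoot. */
	heads[headcount - 1].len += datalen;
	return headcount;
}

int main(void)
{
	int bufsz[] = { 1500, 1500, 1500 };	/* invented guest buffers */
	struct used_elem heads[3];
	int n = take_bufs(bufsz, 3, 2000, heads);

	for (int i = 0; i < n; i++)
		printf("buf %d: used %d bytes\n", heads[i].id, heads[i].len);
	/* prints: buf 0: used 1500 bytes / buf 1: used 500 bytes */
	return 0;
}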
@@ -302,54 +391,175 @@ static void handle_rx(struct vhost_net *net)
 	unuse_mm(net->dev.mm);
 }

-static void handle_tx_kick(struct work_struct *work)
+/* Expects to be always run from workqueue - which acts as
+ * read-size critical section for our kind of RCU. */
+static void handle_rx_mergeable(struct vhost_net *net)
 {
-	struct vhost_virtqueue *vq;
-	struct vhost_net *net;
-	vq = container_of(work, struct vhost_virtqueue, poll.work);
-	net = container_of(vq->dev, struct vhost_net, dev);
+	struct vhost_virtqueue *vq = &net->dev.vqs[VHOST_NET_VQ_RX];
+	unsigned uninitialized_var(in), log;
+	struct vhost_log *vq_log;
+	struct msghdr msg = {
+		.msg_name = NULL,
+		.msg_namelen = 0,
+		.msg_control = NULL, /* FIXME: get and handle RX aux data. */
+		.msg_controllen = 0,
+		.msg_iov = vq->iov,
+		.msg_flags = MSG_DONTWAIT,
+	};
+
+	struct virtio_net_hdr_mrg_rxbuf hdr = {
+		.hdr.flags = 0,
+		.hdr.gso_type = VIRTIO_NET_HDR_GSO_NONE
+	};
+
+	size_t total_len = 0;
+	int err, headcount;
+	size_t vhost_hlen, sock_hlen;
+	size_t vhost_len, sock_len;
+	struct socket *sock = rcu_dereference(vq->private_data);
+	if (!sock || skb_queue_empty(&sock->sk->sk_receive_queue))
+		return;
+
+	use_mm(net->dev.mm);
+	mutex_lock(&vq->mutex);
+	vhost_disable_notify(vq);
+	vhost_hlen = vq->vhost_hlen;
+	sock_hlen = vq->sock_hlen;
+
+	vq_log = unlikely(vhost_has_feature(&net->dev, VHOST_F_LOG_ALL)) ?
+		vq->log : NULL;
+
+	while ((sock_len = peek_head_len(sock->sk))) {
+		sock_len += sock_hlen;
+		vhost_len = sock_len + vhost_hlen;
+		headcount = get_rx_bufs(vq, vq->heads, vhost_len,
+					&in, vq_log, &log);
+		/* On error, stop handling until the next kick. */
+		if (unlikely(headcount < 0))
+			break;
+		/* OK, now we need to know about added descriptors. */
+		if (!headcount) {
+			if (unlikely(vhost_enable_notify(vq))) {
+				/* They have slipped one in as we were
+				 * doing that: check again. */
+				vhost_disable_notify(vq);
+				continue;
+			}
+			/* Nothing new?  Wait for eventfd to tell us
+			 * they refilled. */
+			break;
+		}
+		/* We don't need to be notified again. */
+		if (unlikely((vhost_hlen)))
+			/* Skip header. TODO: support TSO. */
+			move_iovec_hdr(vq->iov, vq->hdr, vhost_hlen, in);
+		else
+			/* Copy the header for use in VIRTIO_NET_F_MRG_RXBUF:
+			 * needed because sendmsg can modify msg_iov. */
+			copy_iovec_hdr(vq->iov, vq->hdr, sock_hlen, in);
+		msg.msg_iovlen = in;
+		err = sock->ops->recvmsg(NULL, sock, &msg,
+					 sock_len, MSG_DONTWAIT | MSG_TRUNC);
+		/* Userspace might have consumed the packet meanwhile:
+		 * it's not supposed to do this usually, but might be hard
+		 * to prevent. Discard data we got (if any) and keep going. */
+		if (unlikely(err != sock_len)) {
+			pr_debug("Discarded rx packet: "
+				 " len %d, expected %zd\n", err, sock_len);
+			vhost_discard_vq_desc(vq, headcount);
+			continue;
+		}
+		if (unlikely(vhost_hlen) &&
+		    memcpy_toiovecend(vq->hdr, (unsigned char *)&hdr, 0,
+				      vhost_hlen)) {
+			vq_err(vq, "Unable to write vnet_hdr at addr %p\n",
+			       vq->iov->iov_base);
+			break;
+		}
+		/* TODO: Should check and handle checksum. */
+		if (vhost_has_feature(&net->dev, VIRTIO_NET_F_MRG_RXBUF) &&
+		    memcpy_toiovecend(vq->hdr, (unsigned char *)&headcount,
+				      offsetof(typeof(hdr), num_buffers),
+				      sizeof hdr.num_buffers)) {
+			vq_err(vq, "Failed num_buffers write");
+			vhost_discard_vq_desc(vq, headcount);
+			break;
+		}
+		vhost_add_used_and_signal_n(&net->dev, vq, vq->heads,
+					    headcount);
+		if (unlikely(vq_log))
+			vhost_log_write(vq, vq_log, log, vhost_len);
+		total_len += vhost_len;
+		if (unlikely(total_len >= VHOST_NET_WEIGHT)) {
+			vhost_poll_queue(&vq->poll);
+			break;
+		}
+	}
+
+	mutex_unlock(&vq->mutex);
+	unuse_mm(net->dev.mm);
+}
+
+static void handle_rx(struct vhost_net *net)
+{
+	if (vhost_has_feature(&net->dev, VIRTIO_NET_F_MRG_RXBUF))
+		handle_rx_mergeable(net);
+	else
+		handle_rx_big(net);
+}
+
+static void handle_tx_kick(struct vhost_work *work)
+{
+	struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
+						  poll.work);
+	struct vhost_net *net = container_of(vq->dev, struct vhost_net, dev);
+
 	handle_tx(net);
 }

-static void handle_rx_kick(struct work_struct *work)
+static void handle_rx_kick(struct vhost_work *work)
 {
-	struct vhost_virtqueue *vq;
-	struct vhost_net *net;
-	vq = container_of(work, struct vhost_virtqueue, poll.work);
-	net = container_of(vq->dev, struct vhost_net, dev);
+	struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
+						  poll.work);
+	struct vhost_net *net = container_of(vq->dev, struct vhost_net, dev);
+
 	handle_rx(net);
 }

-static void handle_tx_net(struct work_struct *work)
+static void handle_tx_net(struct vhost_work *work)
 {
-	struct vhost_net *net;
-	net = container_of(work, struct vhost_net, poll[VHOST_NET_VQ_TX].work);
+	struct vhost_net *net = container_of(work, struct vhost_net,
+					     poll[VHOST_NET_VQ_TX].work);
+
 	handle_tx(net);
 }

-static void handle_rx_net(struct work_struct *work)
+static void handle_rx_net(struct vhost_work *work)
 {
-	struct vhost_net *net;
-	net = container_of(work, struct vhost_net, poll[VHOST_NET_VQ_RX].work);
+	struct vhost_net *net = container_of(work, struct vhost_net,
+					     poll[VHOST_NET_VQ_RX].work);
+
 	handle_rx(net);
 }

 static int vhost_net_open(struct inode *inode, struct file *f)
 {
 	struct vhost_net *n = kmalloc(sizeof *n, GFP_KERNEL);
+	struct vhost_dev *dev;
 	int r;
+
 	if (!n)
 		return -ENOMEM;
+
+	dev = &n->dev;
 	n->vqs[VHOST_NET_VQ_TX].handle_kick = handle_tx_kick;
 	n->vqs[VHOST_NET_VQ_RX].handle_kick = handle_rx_kick;
-	r = vhost_dev_init(&n->dev, n->vqs, VHOST_NET_VQ_MAX);
+	r = vhost_dev_init(dev, n->vqs, VHOST_NET_VQ_MAX);
 	if (r < 0) {
 		kfree(n);
 		return r;
 	}

-	vhost_poll_init(n->poll + VHOST_NET_VQ_TX, handle_tx_net, POLLOUT);
-	vhost_poll_init(n->poll + VHOST_NET_VQ_RX, handle_rx_net, POLLIN);
+	vhost_poll_init(n->poll + VHOST_NET_VQ_TX, handle_tx_net, POLLOUT, dev);
+	vhost_poll_init(n->poll + VHOST_NET_VQ_RX, handle_rx_net, POLLIN, dev);
 	n->tx_poll_state = VHOST_NET_POLL_DISABLED;

 	f->private_data = n;
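peek_head_len(), used by the receive loop above, learns the size of the next queued packet without consuming it. Userspace gets the same effect on a datagram socket with MSG_PEEK | MSG_TRUNC; here is a small self-contained demo (loopback UDP, made-up payload — an illustration of the technique, not kernel code):

/* recv() with MSG_PEEK|MSG_TRUNC returns the real datagram length
 * while leaving the packet on the queue. */
#include <arpa/inet.h>
#include <stdio.h>
#include <sys/socket.h>
#include <unistd.h>

int main(void)
{
	struct sockaddr_in addr = { .sin_family = AF_INET };
	socklen_t alen = sizeof addr;
	int s = socket(AF_INET, SOCK_DGRAM, 0);
	char byte;
	ssize_t len;

	addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
	bind(s, (struct sockaddr *)&addr, sizeof addr);	/* kernel picks a port */
	getsockname(s, (struct sockaddr *)&addr, &alen);
	sendto(s, "hello vhost", 11, 0, (struct sockaddr *)&addr, sizeof addr);

	/* Peek: the queue is untouched, but we learn the full length. */
	len = recv(s, &byte, 1, MSG_PEEK | MSG_TRUNC);
	printf("next packet: %zd bytes\n", len);	/* 11 */

	len = recv(s, (char[32]){ 0 }, 32, 0);		/* now consume it */
	printf("consumed: %zd bytes\n", len);		/* 11 */
	close(s);
	return 0;
}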
@@ -573,9 +783,21 @@ static long vhost_net_reset_owner(struct vhost_net *n)

 static int vhost_net_set_features(struct vhost_net *n, u64 features)
 {
-	size_t hdr_size = features & (1 << VHOST_NET_F_VIRTIO_NET_HDR) ?
-		sizeof(struct virtio_net_hdr) : 0;
+	size_t vhost_hlen, sock_hlen, hdr_len;
 	int i;
+
+	hdr_len = (features & (1 << VIRTIO_NET_F_MRG_RXBUF)) ?
+			sizeof(struct virtio_net_hdr_mrg_rxbuf) :
+			sizeof(struct virtio_net_hdr);
+	if (features & (1 << VHOST_NET_F_VIRTIO_NET_HDR)) {
+		/* vhost provides vnet_hdr */
+		vhost_hlen = hdr_len;
+		sock_hlen = 0;
+	} else {
+		/* socket provides vnet_hdr */
+		vhost_hlen = 0;
+		sock_hlen = hdr_len;
+	}
 	mutex_lock(&n->dev.mutex);
 	if ((features & (1 << VHOST_F_LOG_ALL)) &&
 	    !vhost_log_access_ok(&n->dev)) {

@@ -586,7 +808,8 @@ static int vhost_net_set_features(struct vhost_net *n, u64 features)
 	smp_wmb();
 	for (i = 0; i < VHOST_NET_VQ_MAX; ++i) {
 		mutex_lock(&n->vqs[i].mutex);
-		n->vqs[i].hdr_size = hdr_size;
+		n->vqs[i].vhost_hlen = vhost_hlen;
+		n->vqs[i].sock_hlen = sock_hlen;
 		mutex_unlock(&n->vqs[i].mutex);
 	}
 	vhost_net_flush(n);

@@ -656,25 +879,13 @@ static struct miscdevice vhost_net_misc = {

 static int vhost_net_init(void)
 {
-	int r = vhost_init();
-	if (r)
-		goto err_init;
-	r = misc_register(&vhost_net_misc);
-	if (r)
-		goto err_reg;
-	return 0;
-err_reg:
-	vhost_cleanup();
-err_init:
-	return r;
+	return misc_register(&vhost_net_misc);
 }
 module_init(vhost_net_init);

 static void vhost_net_exit(void)
 {
 	misc_deregister(&vhost_net_misc);
-	vhost_cleanup();
 }
 module_exit(vhost_net_exit);
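The mergeable-buffers path writes the vnet header in two steps: the virtio_net_hdr goes out first, and num_buffers is patched in afterwards at its offset, because headcount is only known once the buffers have been filled. A standalone sketch of that offsetof() trick follows — the struct layout mirrors linux/virtio_net.h, everything else is invented for the demo:

/* Two-step header write: copy the header, then patch num_buffers in
 * place, as handle_rx_mergeable() does with memcpy_toiovecend(). */
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>

struct virtio_net_hdr {
	uint8_t  flags;
	uint8_t  gso_type;
	uint16_t hdr_len;
	uint16_t gso_size;
	uint16_t csum_start;
	uint16_t csum_offset;
};

struct virtio_net_hdr_mrg_rxbuf {
	struct virtio_net_hdr hdr;
	uint16_t num_buffers;	/* filled in after the buffers are known */
};

int main(void)
{
	uint8_t guest_buf[sizeof(struct virtio_net_hdr_mrg_rxbuf)] = { 0 };
	struct virtio_net_hdr_mrg_rxbuf hdr = { .hdr.gso_type = 0 /* NONE */ };
	uint16_t headcount = 3, out;
	size_t off = offsetof(struct virtio_net_hdr_mrg_rxbuf, num_buffers);

	memcpy(guest_buf, &hdr, sizeof hdr);		/* step 1: header */
	memcpy(guest_buf + off, &headcount, sizeof headcount); /* step 2 */

	memcpy(&out, guest_buf + off, sizeof out);
	printf("num_buffers at offset %zu = %u\n", off, out);	/* 10, 3 */
	return 0;
}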
drivers/vhost/vhost.c  +196 −32

@@ -17,12 +17,13 @@
 #include <linux/mm.h>
 #include <linux/miscdevice.h>
 #include <linux/mutex.h>
-#include <linux/workqueue.h>
 #include <linux/rcupdate.h>
 #include <linux/poll.h>
 #include <linux/file.h>
 #include <linux/highmem.h>
 #include <linux/slab.h>
+#include <linux/kthread.h>
+#include <linux/cgroup.h>

 #include <linux/net.h>
 #include <linux/if_packet.h>

@@ -37,8 +38,6 @@ enum {
 	VHOST_MEMORY_F_LOG = 0x1,
 };

-static struct workqueue_struct *vhost_workqueue;
-
 static void vhost_poll_func(struct file *file, wait_queue_head_t *wqh,
 			    poll_table *pt)
 {

@@ -52,23 +51,31 @@ static void vhost_poll_func(struct file *file, wait_queue_head_t *wqh,
 static int vhost_poll_wakeup(wait_queue_t *wait, unsigned mode, int sync,
 			     void *key)
 {
-	struct vhost_poll *poll;
-	poll = container_of(wait, struct vhost_poll, wait);
+	struct vhost_poll *poll = container_of(wait, struct vhost_poll, wait);
+
 	if (!((unsigned long)key & poll->mask))
 		return 0;

-	queue_work(vhost_workqueue, &poll->work);
+	vhost_poll_queue(poll);
 	return 0;
 }

 /* Init poll structure */
-void vhost_poll_init(struct vhost_poll *poll, work_func_t func,
-		     unsigned long mask)
+void vhost_poll_init(struct vhost_poll *poll, vhost_work_fn_t fn,
+		     unsigned long mask, struct vhost_dev *dev)
 {
-	INIT_WORK(&poll->work, func);
+	struct vhost_work *work = &poll->work;
+
 	init_waitqueue_func_entry(&poll->wait, vhost_poll_wakeup);
 	init_poll_funcptr(&poll->table, vhost_poll_func);
 	poll->mask = mask;
+	poll->dev = dev;
+
+	INIT_LIST_HEAD(&work->node);
+	work->fn = fn;
+	init_waitqueue_head(&work->done);
+	work->flushing = 0;
+	work->queue_seq = work->done_seq = 0;
 }

 /* Start polling a file. We add ourselves to file's wait queue. The caller must

@@ -92,12 +99,40 @@ void vhost_poll_stop(struct vhost_poll *poll)
  * locks that are also used by the callback. */
 void vhost_poll_flush(struct vhost_poll *poll)
 {
-	flush_work(&poll->work);
+	struct vhost_work *work = &poll->work;
+	unsigned seq;
+	int left;
+	int flushing;
+
+	spin_lock_irq(&poll->dev->work_lock);
+	seq = work->queue_seq;
+	work->flushing++;
+	spin_unlock_irq(&poll->dev->work_lock);
+	wait_event(work->done, ({
+		   spin_lock_irq(&poll->dev->work_lock);
+		   left = seq - work->done_seq <= 0;
+		   spin_unlock_irq(&poll->dev->work_lock);
+		   left;
+	}));
+	spin_lock_irq(&poll->dev->work_lock);
+	flushing = --work->flushing;
+	spin_unlock_irq(&poll->dev->work_lock);
+	BUG_ON(flushing < 0);
 }

 void vhost_poll_queue(struct vhost_poll *poll)
 {
-	queue_work(vhost_workqueue, &poll->work);
+	struct vhost_dev *dev = poll->dev;
+	struct vhost_work *work = &poll->work;
+	unsigned long flags;
+
+	spin_lock_irqsave(&dev->work_lock, flags);
+	if (list_empty(&work->node)) {
+		list_add_tail(&work->node, &dev->work_list);
+		work->queue_seq++;
+		wake_up_process(dev->worker);
+	}
+	spin_unlock_irqrestore(&dev->work_lock, flags);
 }

 static void vhost_vq_reset(struct vhost_dev *dev,
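The new flush scheme replaces flush_work() with sequence counters: vhost_poll_flush() snapshots queue_seq and waits until done_seq has caught up to the snapshot. The catch-up test is the interesting part; below is a standalone illustration of a wraparound-safe form of that test (a model of the intent using a signed cast, not the kernel expression verbatim):

/* "Has done_seq reached the snapshot?" — the signed cast keeps the
 * comparison correct even when the unsigned counters wrap. */
#include <stdio.h>

static int caught_up(unsigned snapshot, unsigned done_seq)
{
	return (int)(snapshot - done_seq) <= 0;
}

int main(void)
{
	printf("%d\n", caught_up(5, 4));		/* 0: item 5 not done yet */
	printf("%d\n", caught_up(5, 5));		/* 1: done */
	printf("%d\n", caught_up(5, 7));		/* 1: worker ran past it */
	printf("%d\n", caught_up(0xffffffffu, 1));	/* 1: done_seq wrapped past */
	return 0;
}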
@@ -114,7 +149,8 @@ static void vhost_vq_reset(struct vhost_dev *dev,
 	vq->used_flags = 0;
 	vq->log_used = false;
 	vq->log_addr = -1ull;
-	vq->hdr_size = 0;
+	vq->vhost_hlen = 0;
+	vq->sock_hlen = 0;
 	vq->private_data = NULL;
 	vq->log_base = NULL;
 	vq->error_ctx = NULL;

@@ -125,10 +161,51 @@ static void vhost_vq_reset(struct vhost_dev *dev,
 	vq->log_ctx = NULL;
 }

+static int vhost_worker(void *data)
+{
+	struct vhost_dev *dev = data;
+	struct vhost_work *work = NULL;
+	unsigned uninitialized_var(seq);
+
+	for (;;) {
+		/* mb paired w/ kthread_stop */
+		set_current_state(TASK_INTERRUPTIBLE);
+
+		spin_lock_irq(&dev->work_lock);
+		if (work) {
+			work->done_seq = seq;
+			if (work->flushing)
+				wake_up_all(&work->done);
+		}
+
+		if (kthread_should_stop()) {
+			spin_unlock_irq(&dev->work_lock);
+			__set_current_state(TASK_RUNNING);
+			return 0;
+		}
+		if (!list_empty(&dev->work_list)) {
+			work = list_first_entry(&dev->work_list,
+						struct vhost_work, node);
+			list_del_init(&work->node);
+			seq = work->queue_seq;
+		} else
+			work = NULL;
+		spin_unlock_irq(&dev->work_lock);
+
+		if (work) {
+			__set_current_state(TASK_RUNNING);
+			work->fn(work);
+		} else
+			schedule();
+	}
+}
+
 long vhost_dev_init(struct vhost_dev *dev,
 		    struct vhost_virtqueue *vqs, int nvqs)
 {
 	int i;
+
 	dev->vqs = vqs;
 	dev->nvqs = nvqs;
 	mutex_init(&dev->mutex);

@@ -136,6 +213,9 @@ long vhost_dev_init(struct vhost_dev *dev,
 	dev->log_file = NULL;
 	dev->memory = NULL;
 	dev->mm = NULL;
+	spin_lock_init(&dev->work_lock);
+	INIT_LIST_HEAD(&dev->work_list);
+	dev->worker = NULL;

 	for (i = 0; i < dev->nvqs; ++i) {
 		dev->vqs[i].dev = dev;

@@ -143,9 +223,9 @@ long vhost_dev_init(struct vhost_dev *dev,
 		vhost_vq_reset(dev, dev->vqs + i);
 		if (dev->vqs[i].handle_kick)
 			vhost_poll_init(&dev->vqs[i].poll,
-					dev->vqs[i].handle_kick, POLLIN);
+					dev->vqs[i].handle_kick, POLLIN, dev);
 	}
 	return 0;
 }

@@ -159,12 +239,36 @@ long vhost_dev_check_owner(struct vhost_dev *dev)
 /* Caller should have device mutex */
 static long vhost_dev_set_owner(struct vhost_dev *dev)
 {
+	struct task_struct *worker;
+	int err;
+
 	/* Is there an owner already? */
-	if (dev->mm)
-		return -EBUSY;
+	if (dev->mm) {
+		err = -EBUSY;
+		goto err_mm;
+	}
+
 	/* No owner, become one */
 	dev->mm = get_task_mm(current);
+	worker = kthread_create(vhost_worker, dev, "vhost-%d", current->pid);
+	if (IS_ERR(worker)) {
+		err = PTR_ERR(worker);
+		goto err_worker;
+	}
+
+	dev->worker = worker;
+	err = cgroup_attach_task_current_cg(worker);
+	if (err)
+		goto err_cgroup;
+	wake_up_process(worker);	/* avoid contributing to loadavg */
+
 	return 0;
+err_cgroup:
+	kthread_stop(worker);
+err_worker:
+	if (dev->mm)
+		mmput(dev->mm);
+	dev->mm = NULL;
+err_mm:
+	return err;
 }
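vhost_worker() is the classic single-consumer loop: sleep while the list is empty, dequeue under the lock, run the item with the lock dropped. A userspace analogue using pthreads in place of the kernel's set_current_state()/schedule() idiom — illustrative only, and it omits the dedup that vhost_poll_queue() gets from list_empty(&work->node):

/* One worker thread draining a FIFO of work items. */
#include <pthread.h>
#include <stdio.h>

struct work {
	struct work *next;
	void (*fn)(struct work *w);
};

static struct work *head, **tail = &head;
static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
static pthread_cond_t kicked = PTHREAD_COND_INITIALIZER;
static int stop;

static void *worker(void *arg)
{
	(void)arg;
	for (;;) {
		struct work *w;

		pthread_mutex_lock(&lock);
		while (!head && !stop)
			pthread_cond_wait(&kicked, &lock);	/* schedule() */
		if (stop && !head) {
			pthread_mutex_unlock(&lock);
			return NULL;				/* kthread_stop */
		}
		w = head;					/* dequeue */
		head = w->next;
		if (!head)
			tail = &head;
		pthread_mutex_unlock(&lock);

		w->fn(w);	/* run outside the lock, like work->fn(work) */
	}
}

static void queue_work(struct work *w)
{
	pthread_mutex_lock(&lock);
	w->next = NULL;
	*tail = w;
	tail = &w->next;
	pthread_cond_signal(&kicked);		/* wake_up_process() */
	pthread_mutex_unlock(&lock);
}

static void hello(struct work *w) { (void)w; puts("handled kick"); }

int main(void)
{
	pthread_t t;
	struct work w = { .fn = hello };

	pthread_create(&t, NULL, worker, NULL);
	queue_work(&w);

	pthread_mutex_lock(&lock);
	stop = 1;
	pthread_cond_signal(&kicked);
	pthread_mutex_unlock(&lock);
	pthread_join(t, NULL);
	return 0;
}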
@@ -217,6 +321,9 @@ void vhost_dev_cleanup(struct vhost_dev *dev)
 	if (dev->mm)
 		mmput(dev->mm);
 	dev->mm = NULL;
+
+	WARN_ON(!list_empty(&dev->work_list));
+	kthread_stop(dev->worker);
 }

@@ -995,9 +1102,9 @@ int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq,
 }

 /* Reverse the effect of vhost_get_vq_desc. Useful for error handling. */
-void vhost_discard_vq_desc(struct vhost_virtqueue *vq)
+void vhost_discard_vq_desc(struct vhost_virtqueue *vq, int n)
 {
-	vq->last_avail_idx--;
+	vq->last_avail_idx -= n;
 }

 /* After we've used one of their buffers, we tell them about it. We'll then

@@ -1042,6 +1149,67 @@ int vhost_add_used(struct vhost_virtqueue *vq, unsigned int head, int len)
 	return 0;
 }

+static int __vhost_add_used_n(struct vhost_virtqueue *vq,
+			      struct vring_used_elem *heads,
+			      unsigned count)
+{
+	struct vring_used_elem __user *used;
+	int start;
+
+	start = vq->last_used_idx % vq->num;
+	used = vq->used->ring + start;
+	if (copy_to_user(used, heads, count * sizeof *used)) {
+		vq_err(vq, "Failed to write used");
+		return -EFAULT;
+	}
+	if (unlikely(vq->log_used)) {
+		/* Make sure data is seen before log. */
+		smp_wmb();
+		/* Log used ring entry write. */
+		log_write(vq->log_base,
+			  vq->log_addr +
+			   ((void __user *)used - (void __user *)vq->used),
+			  count * sizeof *used);
+	}
+	vq->last_used_idx += count;
+	return 0;
+}
+
+/* After we've used one of their buffers, we tell them about it.  We'll then
+ * want to notify the guest, using eventfd. */
+int vhost_add_used_n(struct vhost_virtqueue *vq, struct vring_used_elem *heads,
+		     unsigned count)
+{
+	int start, n, r;
+
+	start = vq->last_used_idx % vq->num;
+	n = vq->num - start;
+	if (n < count) {
+		r = __vhost_add_used_n(vq, heads, n);
+		if (r < 0)
+			return r;
+		heads += n;
+		count -= n;
+	}
+	r = __vhost_add_used_n(vq, heads, count);
+
+	/* Make sure buffer is written before we update index. */
+	smp_wmb();
+	if (put_user(vq->last_used_idx, &vq->used->idx)) {
+		vq_err(vq, "Failed to increment used idx");
+		return -EFAULT;
+	}
+	if (unlikely(vq->log_used)) {
+		/* Log used index update. */
+		log_write(vq->log_base,
+			  vq->log_addr + offsetof(struct vring_used, idx),
+			  sizeof vq->used->idx);
+		if (vq->log_ctx)
+			eventfd_signal(vq->log_ctx, 1);
+	}
+	return r;
+}
+
 /* This actually signals the guest, using eventfd. */
 void vhost_signal(struct vhost_dev *dev, struct vhost_virtqueue *vq)
 {

@@ -1076,6 +1244,15 @@ void vhost_add_used_and_signal(struct vhost_dev *dev,
 	vhost_signal(dev, vq);
 }

+/* multi-buffer version of vhost_add_used_and_signal */
+void vhost_add_used_and_signal_n(struct vhost_dev *dev,
+				 struct vhost_virtqueue *vq,
+				 struct vring_used_elem *heads, unsigned count)
+{
+	vhost_add_used_n(vq, heads, count);
+	vhost_signal(dev, vq);
+}
+
 /* OK, now we need to know about added descriptors. */
 bool vhost_enable_notify(struct vhost_virtqueue *vq)
 {
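vhost_add_used_n() has to cope with the used ring wrapping: when count entries would run past the end of the ring, the copy is split into a tail piece and a piece starting again at slot 0, which is exactly what the n < count test above does. A toy model with an 8-slot ring and invented values:

/* Split a multi-entry write at the ring boundary, like
 * vhost_add_used_n() does with __vhost_add_used_n(). */
#include <stdio.h>
#include <string.h>

#define RING 8

static unsigned last_used_idx;	/* free-running, like vq->last_used_idx */
static int ring[RING];

/* One contiguous copy starting at last_used_idx % RING. */
static void __add_used_n(const int *heads, unsigned count)
{
	memcpy(ring + last_used_idx % RING, heads, count * sizeof *ring);
	last_used_idx += count;
}

static void add_used_n(const int *heads, unsigned count)
{
	unsigned n = RING - last_used_idx % RING;	/* slots to the edge */

	if (n < count) {		/* would run past the end: split */
		__add_used_n(heads, n);
		heads += n;
		count -= n;
	}
	__add_used_n(heads, count);
}

int main(void)
{
	int heads[5] = { 10, 11, 12, 13, 14 };

	last_used_idx = 6;		/* two slots left before the edge */
	add_used_n(heads, 5);		/* writes slots 6,7 then 0,1,2 */
	for (int i = 0; i < RING; i++)
		printf("slot %d: %d\n", i, ring[i]);
	return 0;
}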
@@ -1100,7 +1277,7 @@ bool vhost_enable_notify(struct vhost_virtqueue *vq)
 		return false;
 	}

-	return avail_idx != vq->last_avail_idx;
+	return avail_idx != vq->avail_idx;
 }

 /* We don't need to be notified again. */

@@ -1115,16 +1292,3 @@ void vhost_disable_notify(struct vhost_virtqueue *vq)
 		vq_err(vq, "Failed to enable notification at %p: %d\n",
 		       &vq->used->flags, r);
 }
-
-int vhost_init(void)
-{
-	vhost_workqueue = create_singlethread_workqueue("vhost");
-	if (!vhost_workqueue)
-		return -ENOMEM;
-	return 0;
-}
-
-void vhost_cleanup(void)
-{
-	destroy_workqueue(vhost_workqueue);
-}

drivers/vhost/vhost.h  +37 −18

@@ -5,13 +5,13 @@
 #include <linux/vhost.h>
 #include <linux/mm.h>
 #include <linux/mutex.h>
-#include <linux/workqueue.h>
 #include <linux/poll.h>
 #include <linux/file.h>
 #include <linux/skbuff.h>
 #include <linux/uio.h>
 #include <linux/virtio_config.h>
 #include <linux/virtio_ring.h>
+#include <asm/atomic.h>

 struct vhost_device;

@@ -20,19 +20,31 @@ enum {
 	VHOST_NET_MAX_SG = MAX_SKB_FRAGS + 2,
 };

+struct vhost_work;
+typedef void (*vhost_work_fn_t)(struct vhost_work *work);
+
+struct vhost_work {
+	struct list_head	  node;
+	vhost_work_fn_t		  fn;
+	wait_queue_head_t	  done;
+	int			  flushing;
+	unsigned		  queue_seq;
+	unsigned		  done_seq;
+};
+
 /* Poll a file (eventfd or socket) */
 /* Note: there's nothing vhost specific about this structure. */
 struct vhost_poll {
 	poll_table                table;
 	wait_queue_head_t        *wqh;
 	wait_queue_t              wait;
-	/* struct which will handle all actual work. */
-	struct work_struct        work;
+	struct vhost_work	  work;
 	unsigned long		  mask;
+	struct vhost_dev	 *dev;
 };

-void vhost_poll_init(struct vhost_poll *poll, work_func_t func,
-		     unsigned long mask);
+void vhost_poll_init(struct vhost_poll *poll, vhost_work_fn_t fn,
+		     unsigned long mask, struct vhost_dev *dev);
 void vhost_poll_start(struct vhost_poll *poll, struct file *file);
 void vhost_poll_stop(struct vhost_poll *poll);
 void vhost_poll_flush(struct vhost_poll *poll);

@@ -63,7 +75,7 @@ struct vhost_virtqueue {
 	struct vhost_poll poll;

 	/* The routine to call when the Guest pings us, or timeout. */
-	work_func_t handle_kick;
+	vhost_work_fn_t handle_kick;

 	/* Last available index we saw. */
 	u16 last_avail_idx;

@@ -84,13 +96,15 @@ struct vhost_virtqueue {
 	struct iovec indirect[VHOST_NET_MAX_SG];
 	struct iovec iov[VHOST_NET_MAX_SG];
 	struct iovec hdr[VHOST_NET_MAX_SG];
-	size_t hdr_size;
+	size_t vhost_hlen;
+	size_t sock_hlen;
+	struct vring_used_elem heads[VHOST_NET_MAX_SG];
 	/* We use a kind of RCU to access private pointer.
-	 * All readers access it from workqueue, which makes it possible to
-	 * flush the workqueue instead of synchronize_rcu. Therefore readers do
+	 * All readers access it from worker, which makes it possible to
+	 * flush the vhost_work instead of synchronize_rcu. Therefore readers do
 	 * not need to call rcu_read_lock/rcu_read_unlock: the beginning of
-	 * work item execution acts instead of rcu_read_lock() and the end of
-	 * work item execution acts instead of rcu_read_lock().
+	 * vhost_work execution acts instead of rcu_read_lock() and the end of
+	 * vhost_work execution acts instead of rcu_read_lock().
 	 * Writers use virtqueue mutex. */
 	void *private_data;

 	/* Log write descriptors */

@@ -110,6 +124,9 @@ struct vhost_dev {
 	int nvqs;
 	struct file *log_file;
 	struct eventfd_ctx *log_ctx;
+	spinlock_t work_lock;
+	struct list_head work_list;
+	struct task_struct *worker;
 };

 long vhost_dev_init(struct vhost_dev *, struct vhost_virtqueue *vqs, int nvqs);

@@ -124,21 +141,22 @@ int vhost_get_vq_desc(struct vhost_dev *, struct vhost_virtqueue *,
 		      struct iovec iov[], unsigned int iov_count,
 		      unsigned int *out_num, unsigned int *in_num,
 		      struct vhost_log *log, unsigned int *log_num);
-void vhost_discard_vq_desc(struct vhost_virtqueue *);
+void vhost_discard_vq_desc(struct vhost_virtqueue *, int n);

 int vhost_add_used(struct vhost_virtqueue *, unsigned int head, int len);
-void vhost_signal(struct vhost_dev *, struct vhost_virtqueue *);
+int vhost_add_used_n(struct vhost_virtqueue *, struct vring_used_elem *heads,
+		     unsigned count);
 void vhost_add_used_and_signal(struct vhost_dev *, struct vhost_virtqueue *,
-			       unsigned int head, int len);
+			       unsigned int id, int len);
+void vhost_add_used_and_signal_n(struct vhost_dev *, struct vhost_virtqueue *,
+			       struct vring_used_elem *heads, unsigned count);
+void vhost_signal(struct vhost_dev *, struct vhost_virtqueue *);
 void vhost_disable_notify(struct vhost_virtqueue *);
 bool vhost_enable_notify(struct vhost_virtqueue *);

 int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log,
 		    unsigned int log_num, u64 len);
-int vhost_init(void);
-void vhost_cleanup(void);

 #define vq_err(vq, fmt, ...) do {                                  \
 		pr_debug(pr_fmt(fmt), ##__VA_ARGS__);       \
 		if ((vq)->error_ctx)                               \

@@ -149,7 +167,8 @@ enum {
 	VHOST_FEATURES = (1 << VIRTIO_F_NOTIFY_ON_EMPTY) |
 			 (1 << VIRTIO_RING_F_INDIRECT_DESC) |
 			 (1 << VHOST_F_LOG_ALL) |
-			 (1 << VHOST_NET_F_VIRTIO_NET_HDR),
+			 (1 << VHOST_NET_F_VIRTIO_NET_HDR) |
+			 (1 << VIRTIO_NET_F_MRG_RXBUF),
 };

 static inline int vhost_has_feature(struct vhost_dev *dev, int bit)

include/linux/cgroup.h  +7 −0

@@ -570,6 +570,7 @@ struct task_struct *cgroup_iter_next(struct cgroup *cgrp,
 void cgroup_iter_end(struct cgroup *cgrp, struct cgroup_iter *it);
 int cgroup_scan_tasks(struct cgroup_scanner *scan);
 int cgroup_attach_task(struct cgroup *, struct task_struct *);
+int cgroup_attach_task_current_cg(struct task_struct *);

 /*
  * CSS ID is ID for cgroup_subsys_state structs under subsys. This only works

@@ -626,6 +627,12 @@ static inline int cgroupstats_build(struct cgroupstats *stats,
 	return -EINVAL;
 }

+/* No cgroups - nothing to do */
+static inline int cgroup_attach_task_current_cg(struct task_struct *t)
+{
+	return 0;
+}
+
 #endif /* !CONFIG_CGROUPS */

 #endif /* _LINUX_CGROUP_H */

kernel/cgroup.c  +23 −0

@@ -1788,6 +1788,29 @@ int cgroup_attach_task(struct cgroup *cgrp, struct task_struct *tsk)
 	return retval;
 }

+/**
+ * cgroup_attach_task_current_cg - attach task 'tsk' to current task's cgroup
+ * @tsk: the task to be attached
+ */
+int cgroup_attach_task_current_cg(struct task_struct *tsk)
+{
+	struct cgroupfs_root *root;
+	struct cgroup *cur_cg;
+	int retval = 0;
+
+	cgroup_lock();
+	for_each_active_root(root) {
+		cur_cg = task_cgroup_from_root(current, root);
+		retval = cgroup_attach_task(cur_cg, tsk);
+		if (retval)
+			break;
+	}
+	cgroup_unlock();
+
+	return retval;
+}
+EXPORT_SYMBOL_GPL(cgroup_attach_task_current_cg);
+
 /*
  * Attach task with pid 'pid' to cgroup 'cgrp'. Call with cgroup_mutex
  * held. May take task_lock of task
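Attaching the worker with cgroup_attach_task_current_cg() means the vhost-<pid> thread's CPU time is charged to the same cgroups as its owner, which the old global workqueue could not guarantee. One way to observe the effect is to compare /proc/<pid>/cgroup for the owner process and the worker thread; the paths should match. A trivial reader — the worker pid below is a placeholder, not a value from this patch:

/* Dump cgroup membership for two pids via the standard /proc interface. */
#include <stdio.h>

static void dump_cgroups(const char *pid)
{
	char path[64], line[256];
	FILE *f;

	snprintf(path, sizeof path, "/proc/%s/cgroup", pid);
	f = fopen(path, "r");
	if (!f) {
		perror(path);
		return;
	}
	while (fgets(line, sizeof line, f))
		printf("%s: %s", pid, line);
	fclose(f);
}

int main(void)
{
	dump_cgroups("self");
	dump_cgroups("12345");	/* placeholder: pid of the vhost-<owner> thread */
	return 0;
}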