From 3ad6f93e98d6df25d0667d847d3ab9cbdccb3eae Mon Sep 17 00:00:00 2001
From: Al Viro <viro@zeniv.linux.org.uk>
Date: Mon, 3 Jul 2017 20:14:56 -0400
Subject: [PATCH] annotate poll-related wait keys

__poll_t is also used as wait key in some waitqueues.
Verify that wait_..._poll() gets __poll_t as key and
provide a helper for wakeup functions to get back to
that __poll_t value.

Signed-off-by: Al Viro <viro@zeniv.linux.org.uk>
---
 drivers/vfio/virqfd.c |  2 +-
 drivers/vhost/vhost.c |  4 ++--
 fs/eventpoll.c        |  9 +++++----
 fs/select.c           |  2 +-
 include/linux/wait.h  | 10 ++++++----
 mm/memcontrol.c       |  2 +-
 net/core/datagram.c   |  4 +---
 net/unix/af_unix.c    |  2 +-
 virt/kvm/eventfd.c    |  2 +-
 9 files changed, 19 insertions(+), 18 deletions(-)

diff --git a/drivers/vfio/virqfd.c b/drivers/vfio/virqfd.c
index d18b10ff119e6..8cc4b48ff1273 100644
--- a/drivers/vfio/virqfd.c
+++ b/drivers/vfio/virqfd.c
@@ -46,7 +46,7 @@ static void virqfd_deactivate(struct virqfd *virqfd)
 static int virqfd_wakeup(wait_queue_entry_t *wait, unsigned mode, int sync, void *key)
 {
 	struct virqfd *virqfd = container_of(wait, struct virqfd, wait);
-	unsigned long flags = (unsigned long)key;
+	__poll_t flags = key_to_poll(key);
 
 	if (flags & POLLIN) {
 		/* An event has been signaled, call function */
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index c18e70bd0466a..7aad77be0b461 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -170,7 +170,7 @@ static int vhost_poll_wakeup(wait_queue_entry_t *wait, unsigned mode, int sync,
 {
 	struct vhost_poll *poll = container_of(wait, struct vhost_poll, wait);
 
-	if (!((unsigned long)key & poll->mask))
+	if (!(key_to_poll(key) & poll->mask))
 		return 0;
 
 	vhost_poll_queue(poll);
@@ -211,7 +211,7 @@ int vhost_poll_start(struct vhost_poll *poll, struct file *file)
 
 	mask = file->f_op->poll(file, &poll->table);
 	if (mask)
-		vhost_poll_wakeup(&poll->wait, 0, 0, (void *)(uintptr_t)mask);
+		vhost_poll_wakeup(&poll->wait, 0, 0, poll_to_key(mask));
 	if (mask & POLLERR) {
 		if (poll->wqh)
 			remove_wait_queue(poll->wqh, &poll->wait);
diff --git a/fs/eventpoll.c b/fs/eventpoll.c
index afd548ebc3282..21e6fee00e8b2 100644
--- a/fs/eventpoll.c
+++ b/fs/eventpoll.c
@@ -1117,6 +1117,7 @@ static int ep_poll_callback(wait_queue_entry_t *wait, unsigned mode, int sync, v
 	unsigned long flags;
 	struct epitem *epi = ep_item_from_wait(wait);
 	struct eventpoll *ep = epi->ep;
+	__poll_t pollflags = key_to_poll(key);
 	int ewake = 0;
 
 	spin_lock_irqsave(&ep->lock, flags);
@@ -1138,7 +1139,7 @@ static int ep_poll_callback(wait_queue_entry_t *wait, unsigned mode, int sync, v
 	 * callback. We need to be able to handle both cases here, hence the
 	 * test for "key" != NULL before the event match test.
 	 */
-	if (key && !((unsigned long) key & epi->event.events))
+	if (pollflags && !(pollflags & epi->event.events))
 		goto out_unlock;
 
 	/*
@@ -1175,8 +1176,8 @@ static int ep_poll_callback(wait_queue_entry_t *wait, unsigned mode, int sync, v
 	 */
 	if (waitqueue_active(&ep->wq)) {
 		if ((epi->event.events & EPOLLEXCLUSIVE) &&
-					!((unsigned long)key & POLLFREE)) {
-			switch ((unsigned long)key & EPOLLINOUT_BITS) {
+					!(pollflags & POLLFREE)) {
+			switch (pollflags & EPOLLINOUT_BITS) {
 			case POLLIN:
 				if (epi->event.events & POLLIN)
 					ewake = 1;
@@ -1205,7 +1206,7 @@ out_unlock:
 	if (!(epi->event.events & EPOLLEXCLUSIVE))
 		ewake = 1;
 
-	if ((unsigned long)key & POLLFREE) {
+	if (pollflags & POLLFREE) {
 		/*
 		 * If we race with ep_remove_wait_queue() it can miss
 		 * ->whead = NULL and do another remove_wait_queue() after
diff --git a/fs/select.c b/fs/select.c
index b2bf84be50569..ffc16fd3673e1 100644
--- a/fs/select.c
+++ b/fs/select.c
@@ -212,7 +212,7 @@ static int pollwake(wait_queue_entry_t *wait, unsigned mode, int sync, void *key
 	struct poll_table_entry *entry;
 
 	entry = container_of(wait, struct poll_table_entry, wait);
-	if (key && !((unsigned long)key & entry->key))
+	if (key && !(key_to_poll(key) & entry->key))
 		return 0;
 	return __pollwake(wait, mode, sync, key);
 }
diff --git a/include/linux/wait.h b/include/linux/wait.h
index 158715445ffb6..55a611486bac1 100644
--- a/include/linux/wait.h
+++ b/include/linux/wait.h
@@ -206,14 +206,16 @@ void __wake_up_sync(struct wait_queue_head *wq_head, unsigned int mode, int nr);
 /*
  * Wakeup macros to be used to report events to the targets.
  */
+#define poll_to_key(m) ((void *)(__force uintptr_t)(__poll_t)(m))
+#define key_to_poll(m) ((__force __poll_t)(uintptr_t)(void *)(m))
 #define wake_up_poll(x, m)							\
-	__wake_up(x, TASK_NORMAL, 1, (void *) (m))
+	__wake_up(x, TASK_NORMAL, 1, poll_to_key(m))
 #define wake_up_locked_poll(x, m)						\
-	__wake_up_locked_key((x), TASK_NORMAL, (void *) (m))
+	__wake_up_locked_key((x), TASK_NORMAL, poll_to_key(m))
 #define wake_up_interruptible_poll(x, m)					\
-	__wake_up(x, TASK_INTERRUPTIBLE, 1, (void *) (m))
+	__wake_up(x, TASK_INTERRUPTIBLE, 1, poll_to_key(m))
 #define wake_up_interruptible_sync_poll(x, m)					\
-	__wake_up_sync_key((x), TASK_INTERRUPTIBLE, 1, (void *) (m))
+	__wake_up_sync_key((x), TASK_INTERRUPTIBLE, 1, poll_to_key(m))
 
 #define ___wait_cond_timeout(condition)						\
 ({										\
diff --git a/mm/memcontrol.c b/mm/memcontrol.c
index 50e6906314f8d..006aa27f4fb44 100644
--- a/mm/memcontrol.c
+++ b/mm/memcontrol.c
@@ -3777,7 +3777,7 @@ static int memcg_event_wake(wait_queue_entry_t *wait, unsigned mode,
 	struct mem_cgroup_event *event =
 		container_of(wait, struct mem_cgroup_event, wait);
 	struct mem_cgroup *memcg = event->memcg;
-	unsigned long flags = (unsigned long)key;
+	__poll_t flags = key_to_poll(key);
 
 	if (flags & POLLHUP) {
 		/*
diff --git a/net/core/datagram.c b/net/core/datagram.c
index 522873ed120bd..000da13c01f2f 100644
--- a/net/core/datagram.c
+++ b/net/core/datagram.c
@@ -72,12 +72,10 @@ static inline int connection_based(struct sock *sk)
 static int receiver_wake_function(wait_queue_entry_t *wait, unsigned int mode, int sync,
 				  void *key)
 {
-	unsigned long bits = (unsigned long)key;
-
 	/*
 	 * Avoid a wakeup if event not interesting for us
 	 */
-	if (bits && !(bits & (POLLIN | POLLERR)))
+	if (key && !(key_to_poll(key) & (POLLIN | POLLERR)))
 		return 0;
 	return autoremove_wake_function(wait, mode, sync, key);
 }
diff --git a/net/unix/af_unix.c b/net/unix/af_unix.c
index a9ee634f3c428..72957961ac220 100644
--- a/net/unix/af_unix.c
+++ b/net/unix/af_unix.c
@@ -367,7 +367,7 @@ static int unix_dgram_peer_wake_relay(wait_queue_entry_t *q, unsigned mode, int
 	/* relaying can only happen while the wq still exists */
 	u_sleep = sk_sleep(&u->sk);
 	if (u_sleep)
-		wake_up_interruptible_poll(u_sleep, key);
+		wake_up_interruptible_poll(u_sleep, key_to_poll(key));
 
 	return 0;
 }
diff --git a/virt/kvm/eventfd.c b/virt/kvm/eventfd.c
index a1f68ed999d8a..a334399fafec5 100644
--- a/virt/kvm/eventfd.c
+++ b/virt/kvm/eventfd.c
@@ -188,7 +188,7 @@ irqfd_wakeup(wait_queue_entry_t *wait, unsigned mode, int sync, void *key)
 {
 	struct kvm_kernel_irqfd *irqfd =
 		container_of(wait, struct kvm_kernel_irqfd, wait);
-	unsigned long flags = (unsigned long)key;
+	__poll_t flags = key_to_poll(key);
 	struct kvm_kernel_irq_routing_entry irq;
 	struct kvm *kvm = irqfd->kvm;
 	unsigned seq;
-- 
2.39.5