The file operations check the context status under ctx->status_mutex,
but drop the lock before they go on to use the context. This allows
afu_release() running on another CPU to free the context in the
meantime, leading to a use-after-free vulnerability.

The race window exists in afu_ioctl(), afu_mmap(), afu_poll() and
afu_read() between the status check and context usage. During device
hot-unplug or rapid open/close cycles, this can cause kernel crashes.

Introduce reference counting via kref to prevent the premature free.
ocxl_context_get(), called with status_mutex held, checks that the
context is not CLOSED and acquires a reference in one step. File
operations hold this reference for their duration, ensuring the
context remains valid even if another thread calls afu_release(). If
the context is already CLOSED, they bail out early with -EIO
(EPOLLERR for afu_poll()).

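ocxl_context_alloc() initializes the refcount to 1; this initial
reference is held for the file's lifetime. afu_release() drops it, and
the context is freed when the last reference goes away.
ocxl_context_free() is replaced by ocxl_context_put(), with the
teardown moved into the kref release callback. Preserve the existing
-EBUSY behavior, where the context is intentionally leaked on detach
timeout.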

Reported-by: Yuhao Jiang <[email protected]>
Fixes: 5ef3166e8a32 ("ocxl: Driver code for 'generic' opencapi devices")
Cc: [email protected]
Signed-off-by: Yuhao Jiang <[email protected]>
---
 drivers/misc/ocxl/context.c       |  69 ++++++++++++++----
 drivers/misc/ocxl/file.c          | 113 +++++++++++++++++++++++-------
 drivers/misc/ocxl/ocxl_internal.h |   4 ++
 3 files changed, 144 insertions(+), 42 deletions(-)

diff --git a/drivers/misc/ocxl/context.c b/drivers/misc/ocxl/context.c
index cded7d1caf32..e154adc972a5 100644
--- a/drivers/misc/ocxl/context.c
+++ b/drivers/misc/ocxl/context.c
@@ -28,6 +28,7 @@ int ocxl_context_alloc(struct ocxl_context **context, struct ocxl_afu *afu,
 
        ctx->pasid = pasid;
        ctx->status = OPENED;
+       kref_init(&ctx->kref);
        mutex_init(&ctx->status_mutex);
        ctx->mapping = mapping;
        mutex_init(&ctx->mapping_lock);
@@ -47,6 +48,59 @@ int ocxl_context_alloc(struct ocxl_context **context, struct ocxl_afu *afu,
 }
 EXPORT_SYMBOL_GPL(ocxl_context_alloc);
 
+/**
+ * ocxl_context_get() - Get a reference to the context if not closed
+ * @ctx: The context
+ *
+ * Atomically checks if context status is not CLOSED and acquires a reference.
+ * Must be called with ctx->status_mutex held.
+ *
+ * Return: true if reference acquired, false if context is CLOSED
+ */
+bool ocxl_context_get(struct ocxl_context *ctx)
+{
+       lockdep_assert_held(&ctx->status_mutex);
+
+       if (ctx->status == CLOSED)
+               return false;
+
+       kref_get(&ctx->kref);
+       return true;
+}
+EXPORT_SYMBOL_GPL(ocxl_context_get);
+
+/*
+ * kref release callback - called when last reference is dropped
+ */
+static void ocxl_context_release(struct kref *kref)
+{
+       struct ocxl_context *ctx = container_of(kref, struct ocxl_context,
+                                                kref);
+
+       mutex_lock(&ctx->afu->contexts_lock);
+       ctx->afu->pasid_count--;
+       idr_remove(&ctx->afu->contexts_idr, ctx->pasid);
+       mutex_unlock(&ctx->afu->contexts_lock);
+
+       ocxl_afu_irq_free_all(ctx);
+       idr_destroy(&ctx->irq_idr);
+       /* reference to the AFU taken in ocxl_context_alloc() */
+       ocxl_afu_put(ctx->afu);
+       kfree(ctx);
+}
+
+/**
+ * ocxl_context_put() - Release a reference to the context
+ * @ctx: The context
+ *
+ * Decrements the reference count. When it reaches zero, the context is freed.
+ */
+void ocxl_context_put(struct ocxl_context *ctx)
+{
+       kref_put(&ctx->kref, ocxl_context_release);
+}
+EXPORT_SYMBOL_GPL(ocxl_context_put);
+
 /*
  * Callback for when a translation fault triggers an error
  * data:       a pointer to the context which triggered the fault
@@ -279,18 +333,3 @@ void ocxl_context_detach_all(struct ocxl_afu *afu)
        }
        mutex_unlock(&afu->contexts_lock);
 }
-
-void ocxl_context_free(struct ocxl_context *ctx)
-{
-       mutex_lock(&ctx->afu->contexts_lock);
-       ctx->afu->pasid_count--;
-       idr_remove(&ctx->afu->contexts_idr, ctx->pasid);
-       mutex_unlock(&ctx->afu->contexts_lock);
-
-       ocxl_afu_irq_free_all(ctx);
-       idr_destroy(&ctx->irq_idr);
-       /* reference to the AFU taken in ocxl_context_alloc() */
-       ocxl_afu_put(ctx->afu);
-       kfree(ctx);
-}
-EXPORT_SYMBOL_GPL(ocxl_context_free);
diff --git a/drivers/misc/ocxl/file.c b/drivers/misc/ocxl/file.c
index 7eb74711ac96..c08724e7ff1e 100644
--- a/drivers/misc/ocxl/file.c
+++ b/drivers/misc/ocxl/file.c
@@ -204,17 +204,21 @@ static long afu_ioctl(struct file *file, unsigned int cmd,
        int irq_id;
        u64 irq_offset;
        long rc;
-       bool closed;
-
-       pr_debug("%s for context %d, command %s\n", __func__, ctx->pasid,
-               CMD_STR(cmd));
 
+       /*
+        * Hold a reference to the context for the duration of this operation.
+        * We check the status and acquire the reference atomically under the
+        * status_mutex to ensure the context remains valid.
+        */
        mutex_lock(&ctx->status_mutex);
-       closed = (ctx->status == CLOSED);
+       if (!ocxl_context_get(ctx)) {
+               mutex_unlock(&ctx->status_mutex);
+               return -EIO;
+       }
        mutex_unlock(&ctx->status_mutex);
 
-       if (closed)
-               return -EIO;
+       pr_debug("%s for context %d, command %s\n", __func__, ctx->pasid,
+               CMD_STR(cmd));
 
        switch (cmd) {
        case OCXL_IOCTL_ATTACH:
@@ -230,7 +234,7 @@ static long afu_ioctl(struct file *file, unsigned int cmd,
                                        sizeof(irq_offset));
                        if (rc) {
                                ocxl_afu_irq_free(ctx, irq_id);
-                               return -EFAULT;
+                               rc = -EFAULT;
                        }
                }
                break;
@@ -238,8 +242,10 @@ static long afu_ioctl(struct file *file, unsigned int cmd,
        case OCXL_IOCTL_IRQ_FREE:
                rc = copy_from_user(&irq_offset, (u64 __user *) args,
                                sizeof(irq_offset));
-               if (rc)
-                       return -EFAULT;
+               if (rc) {
+                       rc = -EFAULT;
+                       break;
+               }
                irq_id = ocxl_irq_offset_to_id(ctx, irq_offset);
                rc = ocxl_afu_irq_free(ctx, irq_id);
                break;
@@ -247,14 +253,20 @@ static long afu_ioctl(struct file *file, unsigned int cmd,
        case OCXL_IOCTL_IRQ_SET_FD:
                rc = copy_from_user(&irq_fd, (u64 __user *) args,
                                sizeof(irq_fd));
-               if (rc)
-                       return -EFAULT;
-               if (irq_fd.reserved)
-                       return -EINVAL;
+               if (rc) {
+                       rc = -EFAULT;
+                       break;
+               }
+               if (irq_fd.reserved) {
+                       rc = -EINVAL;
+                       break;
+               }
                irq_id = ocxl_irq_offset_to_id(ctx, irq_fd.irq_offset);
                ev_ctx = eventfd_ctx_fdget(irq_fd.eventfd);
-               if (IS_ERR(ev_ctx))
-                       return PTR_ERR(ev_ctx);
+               if (IS_ERR(ev_ctx)) {
+                       rc = PTR_ERR(ev_ctx);
+                       break;
+               }
                rc = ocxl_irq_set_handler(ctx, irq_id, irq_handler, irq_free, ev_ctx);
                if (rc)
                        eventfd_ctx_put(ev_ctx);
@@ -280,6 +292,8 @@ static long afu_ioctl(struct file *file, unsigned int cmd,
        default:
                rc = -EINVAL;
        }
+
+       ocxl_context_put(ctx);
        return rc;
 }
 
@@ -292,9 +306,23 @@ static long afu_compat_ioctl(struct file *file, unsigned int cmd,
 static int afu_mmap(struct file *file, struct vm_area_struct *vma)
 {
        struct ocxl_context *ctx = file->private_data;
+       int rc;
+
+       /*
+        * Hold a reference during mmap setup to ensure the context
+        * remains valid.
+        */
+       mutex_lock(&ctx->status_mutex);
+       if (!ocxl_context_get(ctx)) {
+               mutex_unlock(&ctx->status_mutex);
+               return -EIO;
+       }
+       mutex_unlock(&ctx->status_mutex);
 
        pr_debug("%s for context %d\n", __func__, ctx->pasid);
-       return ocxl_context_mmap(ctx, vma);
+       rc = ocxl_context_mmap(ctx, vma);
+       ocxl_context_put(ctx);
+       return rc;
 }
 
 static bool has_xsl_error(struct ocxl_context *ctx)
@@ -324,21 +352,31 @@ static unsigned int afu_poll(struct file *file, struct poll_table_struct *wait)
 {
        struct ocxl_context *ctx = file->private_data;
        unsigned int mask = 0;
-       bool closed;
+
+       /*
+        * Hold a reference to the context while checking for events.
+        */
+       mutex_lock(&ctx->status_mutex);
+       if (!ocxl_context_get(ctx)) {
+               mutex_unlock(&ctx->status_mutex);
+               return EPOLLERR;
+       }
+       mutex_unlock(&ctx->status_mutex);
 
        pr_debug("%s for context %d\n", __func__, ctx->pasid);
 
        poll_wait(file, &ctx->events_wq, wait);
 
-       mutex_lock(&ctx->status_mutex);
-       closed = (ctx->status == CLOSED);
-       mutex_unlock(&ctx->status_mutex);
-
        if (afu_events_pending(ctx))
                mask = EPOLLIN | EPOLLRDNORM;
-       else if (closed)
-               mask = EPOLLERR;
+       else {
+               mutex_lock(&ctx->status_mutex);
+               if (ctx->status == CLOSED)
+                       mask = EPOLLERR;
+               mutex_unlock(&ctx->status_mutex);
+       }
 
+       ocxl_context_put(ctx);
        return mask;
 }
 
@@ -410,6 +448,16 @@ static ssize_t afu_read(struct file *file, char __user *buf, size_t count,
                        AFU_EVENT_BODY_MAX_SIZE))
                return -EINVAL;
 
+       /*
+        * Hold a reference to the context for the duration of the read operation.
+        */
+       mutex_lock(&ctx->status_mutex);
+       if (!ocxl_context_get(ctx)) {
+               mutex_unlock(&ctx->status_mutex);
+               return -EIO;
+       }
+       mutex_unlock(&ctx->status_mutex);
+
        for (;;) {
                prepare_to_wait(&ctx->events_wq, &event_wait,
                                TASK_INTERRUPTIBLE);
@@ -422,11 +470,13 @@ static ssize_t afu_read(struct file *file, char __user *buf, size_t count,
 
                if (file->f_flags & O_NONBLOCK) {
                        finish_wait(&ctx->events_wq, &event_wait);
+                       ocxl_context_put(ctx);
                        return -EAGAIN;
                }
 
                if (signal_pending(current)) {
                        finish_wait(&ctx->events_wq, &event_wait);
+                       ocxl_context_put(ctx);
                        return -ERESTARTSYS;
                }
 
@@ -437,19 +487,24 @@ static ssize_t afu_read(struct file *file, char __user *buf, size_t count,
 
        if (has_xsl_error(ctx)) {
                used = append_xsl_error(ctx, &header, buf + sizeof(header));
-               if (used < 0)
+               if (used < 0) {
+                       ocxl_context_put(ctx);
                        return used;
+               }
        }
 
        if (!afu_events_pending(ctx))
                header.flags |= OCXL_KERNEL_EVENT_FLAG_LAST;
 
-       if (copy_to_user(buf, &header, sizeof(header)))
+       if (copy_to_user(buf, &header, sizeof(header))) {
+               ocxl_context_put(ctx);
                return -EFAULT;
+       }
 
        used += sizeof(header);
 
        rc = used;
+       ocxl_context_put(ctx);
        return rc;
 }
 
@@ -464,8 +519,12 @@ static int afu_release(struct inode *inode, struct file *file)
        ctx->mapping = NULL;
        mutex_unlock(&ctx->mapping_lock);
        wake_up_all(&ctx->events_wq);
+       /*
+        * Drop the initial reference from afu_open(). The context will be
+        * freed when all references are released.
+        */
        if (rc != -EBUSY)
-               ocxl_context_free(ctx);
+               ocxl_context_put(ctx);
        return 0;
 }
 
diff --git a/drivers/misc/ocxl/ocxl_internal.h b/drivers/misc/ocxl/ocxl_internal.h
index d2028d6c6f08..6eab7806b43d 100644
--- a/drivers/misc/ocxl/ocxl_internal.h
+++ b/drivers/misc/ocxl/ocxl_internal.h
@@ -5,6 +5,7 @@
 
 #include <linux/pci.h>
 #include <linux/cdev.h>
+#include <linux/kref.h>
 #include <linux/list.h>
 #include <misc/ocxl.h>
 
@@ -68,6 +69,7 @@ struct ocxl_xsl_error {
 };
 
 struct ocxl_context {
+       struct kref kref;
        struct ocxl_afu *afu;
        int pasid;
        struct mutex status_mutex;
@@ -140,6 +142,8 @@ int ocxl_link_update_pe(void *link_handle, int pasid, __u16 tid);
 
 int ocxl_context_mmap(struct ocxl_context *ctx,
                        struct vm_area_struct *vma);
+bool ocxl_context_get(struct ocxl_context *ctx);
+void ocxl_context_put(struct ocxl_context *ctx);
 void ocxl_context_detach_all(struct ocxl_afu *afu);
 
 int ocxl_sysfs_register_afu(struct ocxl_file_info *info);
-- 
2.34.1


Reply via email to