After commit 9a6d6a2ddabb ("ata: make ata port as parent device of scsi
host"), manually unbinding or removing the driver causes a use-after-free.

Unbinding unconditionally invokes devres_release_all(), which calls
ata_host_release() and frees the ata_host/ata_port memory while the
ata_port is still referenced as the parent of the SCSI host. When the
SCSI host is finally released, scsi_host_dev_release() calls
put_device(parent) and accesses the freed ata_port memory.

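Add reference counting to struct ata_host: the devres entry and every
ata_port transport device each hold a reference, so ata_host and its
ports are freed only after the last holder drops its reference.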

Bug report: https://lkml.org/lkml/2017/11/1/945
Fixes: 9a6d6a2ddabb ("ata: make ata port as parent device of scsi host")
Cc: Tejun Heo <[email protected]>
Cc: Lin Ming <[email protected]>
Cc: [email protected]
Cc: [email protected]
Signed-off-by: Taras Kondratiuk <[email protected]>
---

Based on v4.16-rc4.

 drivers/ata/libata-core.c      | 44 ++++++++++++++++++++++++++++++++++--------
 drivers/ata/libata-transport.c |  4 ++++
 drivers/ata/libata.h           |  2 ++
 include/linux/libata.h         |  1 +
 4 files changed, 43 insertions(+), 8 deletions(-)

diff --git a/drivers/ata/libata-core.c b/drivers/ata/libata-core.c
index 3c09122bf038..ee67f4b113c5 100644
--- a/drivers/ata/libata-core.c
+++ b/drivers/ata/libata-core.c
@@ -6006,7 +6006,8 @@ struct ata_port *ata_port_alloc(struct ata_host *host)
        return ap;
 }
 
-static void ata_host_release(struct device *gendev, void *res)
+/* devres release: put each port's SCSI host, then drop the devres host reference */
+static void ata_devres_release(struct device *gendev, void *res)
 {
        struct ata_host *host = dev_get_drvdata(gendev);
        int i;
@@ -6020,13 +6021,43 @@ static void ata_host_release(struct device *gendev, void *res)
                if (ap->scsi_host)
                        scsi_host_put(ap->scsi_host);
 
+       }
+
+       dev_set_drvdata(gendev, NULL);
+       ata_host_put(host);
+}
+
+/* kref release callback: frees every allocated port and the host itself */
+static void ata_host_release(struct kref *kref)
+{
+       struct ata_host *host = container_of(kref, struct ata_host, kref);
+       int i;
+
+       for (i = 0; i < host->n_ports; i++) {
+               struct ata_port *ap = host->ports[i];
+
+               /* ports[] may be partly populated if ata_host_alloc() failed */
+               if (!ap)
+                       continue;
+
                kfree(ap->pmp_link);
                kfree(ap->slave_link);
                kfree(ap);
                host->ports[i] = NULL;
        }
+       kfree(host);
+}
 
-       dev_set_drvdata(gendev, NULL);
+/* Take a reference: @host and its ports stay allocated until ata_host_put() */
+void ata_host_get(struct ata_host *host)
+{
+       kref_get(&host->kref);
+}
+
+/* Drop a reference; the last one frees @host via ata_host_release() */
+void ata_host_put(struct ata_host *host)
+{
+       kref_put(&host->kref, ata_host_release);
 }
 
 /**
@@ -6054,26 +6085,37 @@ struct ata_host *ata_host_alloc(struct device *dev, int max_ports)
        struct ata_host *host;
        size_t sz;
        int i;
+       void *dr;
 
        DPRINTK("ENTER\n");
 
-       if (!devres_open_group(dev, NULL, GFP_KERNEL))
-               return NULL;
-
        /* alloc a container for our list of ATA ports (buses) */
        sz = sizeof(struct ata_host) + (max_ports + 1) * sizeof(void *);
-       /* alloc a container for our list of ATA ports (buses) */
-       host = devres_alloc(ata_host_release, sz, GFP_KERNEL);
+       host = kzalloc(sz, GFP_KERNEL);
        if (!host)
+               return NULL;
+
+       if (!devres_open_group(dev, NULL, GFP_KERNEL)) {
+               kfree(host);
+               return NULL;
+       }
+
+       dr = devres_alloc(ata_devres_release, 0, GFP_KERNEL);
+       if (!dr) {
+               /* ata_devres_release() is not registered yet, free host here */
+               kfree(host);
                goto err_out;
+       }
 
-       devres_add(dev, host);
+       devres_add(dev, dr);
        dev_set_drvdata(dev, host);
 
        spin_lock_init(&host->lock);
        mutex_init(&host->eh_mutex);
        host->dev = dev;
        host->n_ports = max_ports;
+       /* initial reference belongs to the devres entry (ata_devres_release) */
+       kref_init(&host->kref);
 
        /* allocate ports bound to this host */
        for (i = 0; i < max_ports; i++) {
diff --git a/drivers/ata/libata-transport.c b/drivers/ata/libata-transport.c
index 19e6e539a061..a0b0b4d986f2 100644
--- a/drivers/ata/libata-transport.c
+++ b/drivers/ata/libata-transport.c
@@ -224,6 +224,10 @@ static DECLARE_TRANSPORT_CLASS(ata_port_class,
 
 static void ata_tport_release(struct device *dev)
 {
+       struct ata_port *ap = tdev_to_port(dev);
+
+       /* drop the host reference taken in ata_tport_add() */
+       ata_host_put(ap->host);
 }
 
 /**
@@ -284,6 +288,11 @@ int ata_tport_add(struct device *parent,
        dev->type = &ata_port_type;
 
        dev->parent = parent;
+       /*
+        * Pinned until ata_tport_release(); put_device() on the error path
+        * triggers that release, so the error path must not put the host again.
+        */
+       ata_host_get(ap->host);
        dev->release = ata_tport_release;
        dev_set_name(dev, "ata%d", ap->print_id);
        transport_setup_device(dev);
diff --git a/drivers/ata/libata.h b/drivers/ata/libata.h
index f953cb4bb1ba..9e21c49cf6be 100644
--- a/drivers/ata/libata.h
+++ b/drivers/ata/libata.h
@@ -100,6 +100,8 @@ extern int ata_port_probe(struct ata_port *ap);
 extern void __ata_port_probe(struct ata_port *ap);
 extern unsigned int ata_read_log_page(struct ata_device *dev, u8 log,
                                      u8 page, void *buf, unsigned int sectors);
+extern void ata_host_get(struct ata_host *host);
+extern void ata_host_put(struct ata_host *host);
 
 #define to_ata_port(d) container_of(d, struct ata_port, tdev)
 
diff --git a/include/linux/libata.h b/include/linux/libata.h
index ed9826b21c5e..1795fecdea17 100644
--- a/include/linux/libata.h
+++ b/include/linux/libata.h
@@ -617,6 +617,7 @@ struct ata_host {
        void                    *private_data;
        struct ata_port_operations *ops;
        unsigned long           flags;
+       struct kref             kref;
 
        struct mutex            eh_mutex;
        struct task_struct      *eh_owner;
-- 
2.10.3.dirty

Reply via email to