On 2024/10/29 0:26, Dan Williams wrote:
Yi Yang wrote:
A NULL pointer dereference occurs when nd_dax_alloc() returns NULL. Fix it by
adding a check on the return value of nd_dax_alloc().

Fixes: c5ed9268643c ("libnvdimm, dax: autodetect support")
Signed-off-by: Yi Yang <yiyan...@huawei.com>
---
  drivers/nvdimm/dax_devs.c | 4 ++++
  1 file changed, 4 insertions(+)

diff --git a/drivers/nvdimm/dax_devs.c b/drivers/nvdimm/dax_devs.c
index 6b4922de3047..70a7e401f90d 100644
--- a/drivers/nvdimm/dax_devs.c
+++ b/drivers/nvdimm/dax_devs.c
@@ -106,6 +106,10 @@ int nd_dax_probe(struct device *dev, struct nd_namespace_common *ndns)
nvdimm_bus_lock(&ndns->dev);
        nd_dax = nd_dax_alloc(nd_region);
+       if (!nd_dax) {
+               nvdimm_bus_unlock(&ndns->dev);
+               return -ENOMEM;
+       }
        nd_pfn = &nd_dax->nd_pfn;

No, this isn't a NULL pointer dereference, but the code is indeed
unreasonably subtle.

If nd_dax is NULL, then nd_pfn is NULL as well, because nd_dax is just a
type wrapper around nd_pfn (the nd_pfn member sits at offset zero, so
&nd_dax->nd_pfn also evaluates to NULL).

When nd_pfn is NULL, nd_pfn_devinit() fails and returns NULL, which the
caller already handles with its "if (!dax_dev) return -ENOMEM;" check.

To make this explicit, I think it needs something like this:

---

diff --git a/drivers/nvdimm/dax_devs.c b/drivers/nvdimm/dax_devs.c
index 6b4922de3047..37b743acbb7b 100644
--- a/drivers/nvdimm/dax_devs.c
+++ b/drivers/nvdimm/dax_devs.c
@@ -106,12 +106,12 @@ int nd_dax_probe(struct device *dev, struct nd_namespace_common *ndns)
nvdimm_bus_lock(&ndns->dev);
        nd_dax = nd_dax_alloc(nd_region);
-       nd_pfn = &nd_dax->nd_pfn;
-       dax_dev = nd_pfn_devinit(nd_pfn, ndns);
+       dax_dev = nd_dax_devinit(nd_dax, ndns);
        nvdimm_bus_unlock(&ndns->dev);
        if (!dax_dev)
                return -ENOMEM;
        pfn_sb = devm_kmalloc(dev, sizeof(*pfn_sb), GFP_KERNEL);
+       nd_pfn = &nd_dax->nd_pfn;
        nd_pfn->pfn_sb = pfn_sb;
        rc = nd_pfn_validate(nd_pfn, DAX_SIG);
        dev_dbg(dev, "dax: %s\n", rc == 0 ? dev_name(dax_dev) : "<none>");
diff --git a/drivers/nvdimm/nd.h b/drivers/nvdimm/nd.h
index 2dbb1dca17b5..5ca06e9a2d29 100644
--- a/drivers/nvdimm/nd.h
+++ b/drivers/nvdimm/nd.h
@@ -600,6 +600,13 @@ struct nd_dax *to_nd_dax(struct device *dev);
  int nd_dax_probe(struct device *dev, struct nd_namespace_common *ndns);
  bool is_nd_dax(const struct device *dev);
  struct device *nd_dax_create(struct nd_region *nd_region);
+static inline struct device *nd_dax_devinit(struct nd_dax *nd_dax,
+                                           struct nd_namespace_common *ndns)
+{
+       if (!nd_dax)
+               return NULL;
+       return nd_pfn_devinit(&nd_dax->nd_pfn, ndns);
+}
  #else
  static inline int nd_dax_probe(struct device *dev,
                struct nd_namespace_common *ndns)

.


LGTM.
Your version is better.

Best regards,
Yiyang

--


Reply via email to