Make pci::Driver take a lifetime parameter 'bound that ties device
resources to the binding scope.

Internally, Adapter<T: Driver> becomes Adapter<F: ForLt> with the
higher-ranked trait bound for<'bound> F::Of<'bound>: Driver<'bound>;
module_pci_driver! wraps the driver type in ForLt!() so drivers don't
have to.

Signed-off-by: Danilo Krummrich <[email protected]>
---
 drivers/gpu/nova-core/driver.rs       |  9 ++-
 drivers/gpu/nova-core/nova_core.rs    |  4 +-
 rust/kernel/pci.rs                    | 80 +++++++++++++++++++--------
 samples/rust/rust_dma.rs              |  9 ++-
 samples/rust/rust_driver_auxiliary.rs | 13 +++--
 samples/rust/rust_driver_pci.rs       | 11 ++--
 6 files changed, 87 insertions(+), 39 deletions(-)

diff --git a/drivers/gpu/nova-core/driver.rs b/drivers/gpu/nova-core/driver.rs
index 8fe484d357f6..d0ccfbc8d0ea 100644
--- a/drivers/gpu/nova-core/driver.rs
+++ b/drivers/gpu/nova-core/driver.rs
@@ -50,7 +50,7 @@ pub(crate) struct NovaCore {
 kernel::pci_device_table!(
     PCI_TABLE,
     MODULE_PCI_TABLE,
-    <NovaCore as pci::Driver>::IdInfo,
+    <NovaCore as pci::Driver<'_>>::IdInfo,
     [
         // Modern NVIDIA GPUs will show up as either VGA or 3D controllers.
         (
@@ -72,11 +72,14 @@ pub(crate) struct NovaCore {
     ]
 );
 
-impl pci::Driver for NovaCore {
+impl<'bound> pci::Driver<'bound> for NovaCore {
     type IdInfo = ();
     const ID_TABLE: pci::IdTable<Self::IdInfo> = &PCI_TABLE;
 
-    fn probe(pdev: &pci::Device<Core>, _info: &Self::IdInfo) -> impl PinInit<Self, Error> {
+    fn probe(
+        pdev: &'bound pci::Device<Core>,
+        _info: &'bound Self::IdInfo,
+    ) -> impl PinInit<Self, Error> + 'bound {
         pin_init::pin_init_scope(move || {
             dev_dbg!(pdev, "Probe Nova Core GPU driver.\n");
 
diff --git a/drivers/gpu/nova-core/nova_core.rs b/drivers/gpu/nova-core/nova_core.rs
index 04a1fa6b25f8..49c093a0cb42 100644
--- a/drivers/gpu/nova-core/nova_core.rs
+++ b/drivers/gpu/nova-core/nova_core.rs
@@ -7,6 +7,7 @@
     driver::Registration,
     pci,
     prelude::*,
+    types::ForLt,
     InPlaceModule, //
 };
 
@@ -46,8 +47,9 @@ fn drop(&mut self) {
 struct NovaCoreModule {
     // Fields are dropped in declaration order, so `_driver` is dropped first,
     // then `_debugfs_guard` clears `DEBUGFS_ROOT`.
+    #[allow(clippy::type_complexity)]
     #[pin]
-    _driver: Registration<pci::Adapter<driver::NovaCore>>,
+    _driver: Registration<pci::Adapter<ForLt!(driver::NovaCore)>>,
     _debugfs_guard: DebugfsRootGuard,
 }
 
diff --git a/rust/kernel/pci.rs b/rust/kernel/pci.rs
index 6f82f2e6c74f..1335857cae94 100644
--- a/rust/kernel/pci.rs
+++ b/rust/kernel/pci.rs
@@ -58,22 +58,35 @@
 };
 
 /// An adapter for the registration of PCI drivers.
-pub struct Adapter<T: Driver>(T);
+///
+/// `F` is a [`ForLt`](trait@ForLt) type that maps lifetimes to the driver's device
+/// private data type, i.e. `F::Of<'bound>` is the driver struct
+/// parameterized by `'bound`. The macro `module_pci_driver!` generates
+/// this automatically via `ForLt!()`.
+pub struct Adapter<F>(PhantomData<F>);
 
 // SAFETY:
 // - `bindings::pci_driver` is a C type declared as `repr(C)`.
-// - `T` is the type of the driver's device private data.
+// - `F::Of<'static>` is the stored type of the driver's device private data.
 // - `struct pci_driver` embeds a `struct device_driver`.
 // - `DEVICE_DRIVER_OFFSET` is the correct byte offset to the embedded `struct device_driver`.
-unsafe impl<T: Driver + 'static> driver::DriverLayout for Adapter<T> {
+unsafe impl<F> driver::DriverLayout for Adapter<F>
+where
+    F: ForLt + 'static,
+    for<'bound> F::Of<'bound>: Driver<'bound>,
+{
     type DriverType = bindings::pci_driver;
-    type DriverData = ForLt!(T);
+    type DriverData = F;
     const DEVICE_DRIVER_OFFSET: usize = core::mem::offset_of!(Self::DriverType, driver);
 }
 
 // SAFETY: A call to `unregister` for a given instance of `DriverType` is guaranteed to be valid if
 // a preceding call to `register` has been successful.
-unsafe impl<T: Driver + 'static> driver::RegistrationOps for Adapter<T> {
+unsafe impl<F> driver::RegistrationOps for Adapter<F>
+where
+    F: ForLt + 'static,
+    for<'bound> F::Of<'bound>: Driver<'bound>,
+{
     unsafe fn register(
         pdrv: &Opaque<Self::DriverType>,
         name: &'static CStr,
@@ -84,7 +97,7 @@ unsafe fn register(
             (*pdrv.get()).name = name.as_char_ptr();
             (*pdrv.get()).probe = Some(Self::probe_callback);
             (*pdrv.get()).remove = Some(Self::remove_callback);
-            (*pdrv.get()).id_table = T::ID_TABLE.as_ptr();
+            (*pdrv.get()).id_table = <F::Of<'static> as Driver<'static>>::ID_TABLE.as_ptr();
         }
 
         // SAFETY: `pdrv` is guaranteed to be a valid `DriverType`.
@@ -99,7 +112,11 @@ unsafe fn unregister(pdrv: &Opaque<Self::DriverType>) {
     }
 }
 
-impl<T: Driver + 'static> Adapter<T> {
+impl<F> Adapter<F>
+where
+    F: ForLt + 'static,
+    for<'bound> F::Of<'bound>: Driver<'bound>,
+{
     extern "C" fn probe_callback(
         pdev: *mut bindings::pci_dev,
         id: *const bindings::pci_device_id,
@@ -113,12 +130,12 @@ extern "C" fn probe_callback(
         // SAFETY: `DeviceId` is a `#[repr(transparent)]` wrapper of `struct pci_device_id` and
         // does not add additional invariants, so it's safe to transmute.
         let id = unsafe { &*id.cast::<DeviceId>() };
-        let info = T::ID_TABLE.info(id.index());
 
         from_result(|| {
-            let data = T::probe(pdev, info);
+            let info = <F::Of<'_> as Driver<'_>>::ID_TABLE.info(id.index());
+            let data = <F::Of<'_> as Driver<'_>>::probe(pdev, info);
 
-            pdev.as_ref().set_drvdata::<ForLt!(T)>(data)?;
+            pdev.as_ref().set_drvdata::<F>(data)?;
             Ok(0)
         })
     }
@@ -131,16 +148,18 @@ extern "C" fn remove_callback(pdev: *mut bindings::pci_dev) {
         let pdev = unsafe { &*pdev.cast::<Device<device::CoreInternal>>() };
 
         // SAFETY: `remove_callback` is only ever called after a successful call to
-        // `probe_callback`, hence it's guaranteed that `Device::set_drvdata()` has been called
-        // and stored a `Pin<KBox<T>>`.
-        let data = unsafe { pdev.as_ref().drvdata_borrow::<ForLt!(T)>() };
+        // `probe_callback`, hence it's guaranteed that drvdata has been set.
+        let data = unsafe { pdev.as_ref().drvdata_borrow::<F>() };
 
-        T::unbind(pdev, data);
+        <F::Of<'_> as Driver<'_>>::unbind(pdev, data);
     }
 }
 
 /// Declares a kernel module that exposes a single PCI driver.
 ///
+/// The `type` field accepts a driver type, optionally with a lifetime placeholder `'_` for
+/// lifetime-parameterized drivers. The macro wraps it in [`ForLt!`] automatically.
+///
 /// # Examples
 ///
 ///```ignore
@@ -152,10 +171,16 @@ extern "C" fn remove_callback(pdev: *mut bindings::pci_dev) {
 ///     license: "GPL v2",
 /// }
 ///```
+///
+/// [`ForLt!`]: macro@ForLt
+/// [`ForLt`]: trait@ForLt
 #[macro_export]
 macro_rules! module_pci_driver {
-($($f:tt)*) => {
-    $crate::module_driver!(<T>, $crate::pci::Adapter<T>, { $($f)* });
+(type: $type:ty, $($rest:tt)*) => {
+    $crate::module_driver!(<T>, $crate::pci::Adapter<T>, {
+        type: $crate::types::ForLt!($type),
+        $($rest)*
+    });
 };
 }
 
@@ -261,6 +286,9 @@ macro_rules! pci_device_table {
 
 /// The PCI driver trait.
 ///
+/// Drivers implement this trait with a lifetime parameter `'bound` that ties
+/// device resources to the device scope.
+///
 /// # Examples
 ///
 ///```
@@ -271,7 +299,7 @@ macro_rules! pci_device_table {
 /// kernel::pci_device_table!(
 ///     PCI_TABLE,
 ///     MODULE_PCI_TABLE,
-///     <MyDriver as pci::Driver>::IdInfo,
+///     <MyDriver as pci::Driver<'_>>::IdInfo,
 ///     [
 ///         (
 ///             pci::DeviceId::from_id(pci::Vendor::REDHAT, bindings::PCI_ANY_ID as u32),
@@ -280,21 +308,22 @@ macro_rules! pci_device_table {
 ///     ]
 /// );
 ///
-/// impl pci::Driver for MyDriver {
+/// impl<'bound> pci::Driver<'bound> for MyDriver {
 ///     type IdInfo = ();
 ///     const ID_TABLE: pci::IdTable<Self::IdInfo> = &PCI_TABLE;
 ///
 ///     fn probe(
-///         _pdev: &pci::Device<Core>,
-///         _id_info: &Self::IdInfo,
-///     ) -> impl PinInit<Self, Error> {
+///         _pdev: &'bound pci::Device<Core>,
+///         _id_info: &'bound Self::IdInfo,
+///     ) -> impl PinInit<Self, Error> + 'bound {
 ///         Err(ENODEV)
 ///     }
 /// }
 ///```
+///
 /// Drivers must implement this trait in order to get a PCI driver registered. Please refer to the
 /// `Adapter` documentation for an example.
-pub trait Driver: Send {
+pub trait Driver<'bound>: Send {
     /// The type holding information about each device id supported by the driver.
     // TODO: Use `associated_type_defaults` once stabilized:
     //
@@ -310,7 +339,10 @@ pub trait Driver: Send {
     ///
     /// Called when a new pci device is added or discovered. Implementers should
     /// attempt to initialize the device here.
-    fn probe(dev: &Device<device::Core>, id_info: &Self::IdInfo) -> impl PinInit<Self, Error>;
+    fn probe(
+        dev: &'bound Device<device::Core>,
+        id_info: &'bound Self::IdInfo,
+    ) -> impl PinInit<Self, Error> + 'bound;
 
     /// PCI driver unbind.
     ///
@@ -322,7 +354,7 @@ pub trait Driver: Send {
     /// operations to gracefully tear down the device.
     ///
     /// Otherwise, release operations for driver resources should be performed in `Self::drop`.
-    fn unbind(dev: &Device<device::Core>, this: Pin<&Self>) {
+    fn unbind(dev: &'bound Device<device::Core>, this: Pin<&'bound Self>) {
         let _ = (dev, this);
     }
 }
diff --git a/samples/rust/rust_dma.rs b/samples/rust/rust_dma.rs
index 129bb4b39c04..e8b3e2e799f3 100644
--- a/samples/rust/rust_dma.rs
+++ b/samples/rust/rust_dma.rs
@@ -52,15 +52,18 @@ unsafe impl kernel::transmute::FromBytes for MyStruct {}
 kernel::pci_device_table!(
     PCI_TABLE,
     MODULE_PCI_TABLE,
-    <DmaSampleDriver as pci::Driver>::IdInfo,
+    <DmaSampleDriver as pci::Driver<'_>>::IdInfo,
     [(pci::DeviceId::from_id(pci::Vendor::REDHAT, 0x5), ())]
 );
 
-impl pci::Driver for DmaSampleDriver {
+impl<'bound> pci::Driver<'bound> for DmaSampleDriver {
     type IdInfo = ();
     const ID_TABLE: pci::IdTable<Self::IdInfo> = &PCI_TABLE;
 
-    fn probe(pdev: &pci::Device<Core>, _info: &Self::IdInfo) -> impl PinInit<Self, Error> {
+    fn probe(
+        pdev: &'bound pci::Device<Core>,
+        _info: &'bound Self::IdInfo,
+    ) -> impl PinInit<Self, Error> + 'bound {
         pin_init::pin_init_scope(move || {
             dev_info!(pdev, "Probe DMA test driver.\n");
 
diff --git a/samples/rust/rust_driver_auxiliary.rs b/samples/rust/rust_driver_auxiliary.rs
index 319ef734c02b..a1b42d30580e 100644
--- a/samples/rust/rust_driver_auxiliary.rs
+++ b/samples/rust/rust_driver_auxiliary.rs
@@ -14,6 +14,7 @@
     driver,
     pci,
     prelude::*,
+    types::ForLt,
     InPlaceModule, //
 };
 
@@ -59,16 +60,19 @@ struct ParentDriver {
 kernel::pci_device_table!(
     PCI_TABLE,
     MODULE_PCI_TABLE,
-    <ParentDriver as pci::Driver>::IdInfo,
+    <ParentDriver as pci::Driver<'_>>::IdInfo,
     [(pci::DeviceId::from_id(pci::Vendor::REDHAT, 0x5), ())]
 );
 
-impl pci::Driver for ParentDriver {
+impl<'bound> pci::Driver<'bound> for ParentDriver {
     type IdInfo = ();
 
     const ID_TABLE: pci::IdTable<Self::IdInfo> = &PCI_TABLE;
 
-    fn probe(pdev: &pci::Device<Core>, _info: &Self::IdInfo) -> impl PinInit<Self, Error> {
+    fn probe(
+        pdev: &'bound pci::Device<Core>,
+        _info: &'bound Self::IdInfo,
+    ) -> impl PinInit<Self, Error> + 'bound {
         Ok(Self {
             _reg0: auxiliary::Registration::new(
                 pdev.as_ref(),
@@ -116,7 +120,8 @@ fn connect(adev: &auxiliary::Device<Bound>) -> Result {
 #[pin_data]
 struct SampleModule {
     #[pin]
-    _pci_driver: driver::Registration<pci::Adapter<ParentDriver>>,
+    #[allow(clippy::type_complexity)]
+    _pci_driver: driver::Registration<pci::Adapter<ForLt!(ParentDriver)>>,
     #[pin]
     _aux_driver: driver::Registration<auxiliary::Adapter<AuxiliaryDriver>>,
 }
diff --git a/samples/rust/rust_driver_pci.rs b/samples/rust/rust_driver_pci.rs
index 47d3e84fab63..794311691d1e 100644
--- a/samples/rust/rust_driver_pci.rs
+++ b/samples/rust/rust_driver_pci.rs
@@ -77,7 +77,7 @@ struct SampleDriver {
 kernel::pci_device_table!(
     PCI_TABLE,
     MODULE_PCI_TABLE,
-    <SampleDriver as pci::Driver>::IdInfo,
+    <SampleDriver as pci::Driver<'_>>::IdInfo,
     [(
         pci::DeviceId::from_id(pci::Vendor::REDHAT, 0x5),
         TestIndex::NO_EVENTFD
@@ -138,12 +138,15 @@ fn config_space(pdev: &pci::Device<Bound>) {
     }
 }
 
-impl pci::Driver for SampleDriver {
+impl<'bound> pci::Driver<'bound> for SampleDriver {
     type IdInfo = TestIndex;
 
     const ID_TABLE: pci::IdTable<Self::IdInfo> = &PCI_TABLE;
 
-    fn probe(pdev: &pci::Device<Core>, info: &Self::IdInfo) -> impl PinInit<Self, Error> {
+    fn probe(
+        pdev: &'bound pci::Device<Core>,
+        info: &'bound Self::IdInfo,
+    ) -> impl PinInit<Self, Error> + 'bound {
         pin_init::pin_init_scope(move || {
             let vendor = pdev.vendor_id();
             dev_dbg!(
@@ -174,7 +177,7 @@ fn probe(pdev: &pci::Device<Core>, info: &Self::IdInfo) -> impl PinInit<Self, Er
         })
     }
 
-    fn unbind(pdev: &pci::Device<Core>, this: Pin<&Self>) {
+    fn unbind(pdev: &'bound pci::Device<Core>, this: Pin<&'bound Self>) {
         if let Ok(bar) = this.bar.access(pdev.as_ref()) {
             // Reset pci-testdev by writing a new test index.
             bar.write_reg(regs::TEST::zeroed().with_index(this.index));
-- 
2.54.0

Reply via email to