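Add registration_data_with(), which takes a for<'a> closure receiving Pin<&'a F::Of<'a>> and therefore works with any ForLt type. Taking a for<'a> closure rather than returning a direct reference prevents callers from choosing a concrete lifetime for the data, which is required for soundness with non-covariant ForLt types: the closure must be valid for every 'a, so it cannot pick a short lifetime to store references into the data, and its return type R is fixed outside the binder, so no borrow of the data can escape.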
Extract the common null-check, TypeId-check and KBox-borrow logic into a private registration_data_pinned() helper shared by both registration_data_with() and the existing registration_data(). Relax Registration's bound from CovariantForLt to ForLt so that non-covariant types can be registered. Signed-off-by: Danilo Krummrich <[email protected]> --- rust/kernel/auxiliary.rs | 89 ++++++++++++++++++++++++++++++---------- 1 file changed, 68 insertions(+), 21 deletions(-) diff --git a/rust/kernel/auxiliary.rs b/rust/kernel/auxiliary.rs index 40a0af74a8e5..81549a3e347e 100644 --- a/rust/kernel/auxiliary.rs +++ b/rust/kernel/auxiliary.rs @@ -21,6 +21,7 @@ prelude::*, types::{ CovariantForLt, + ForLt, ForeignOwnable, Opaque, // }, @@ -270,18 +271,15 @@ pub fn parent(&self) -> &device::Device<device::Bound> { unsafe { parent.as_bound() } } - /// Returns a pinned reference to the registration data set by the registering (parent) driver. + /// Returns the stored registration data as a pinned `'static` reference. /// - /// `F` is the [`CovariantForLt`](trait@CovariantForLt) encoding of the data type. The returned - /// reference has its lifetime shortened from `'static` to `&self`'s borrow lifetime via - /// [`CovariantForLt::cast_ref`]. + /// Performs null and [`TypeId`] checks, then borrows the stored [`KBox`]. /// - /// Returns [`EINVAL`] if `F` does not match the type used by the parent driver when calling - /// [`Registration::new()`]. + /// # Safety /// - /// Returns [`ENOENT`] if no registration data has been set, e.g. when the device was - /// registered by a C driver. - pub fn registration_data<F: CovariantForLt + 'static>(&self) -> Result<Pin<&F::Of<'_>>> { + /// The returned `'static` lifetime was transmuted from the device's bound lifetime during + /// registration. Callers must shorten it before exposing it. 
+ unsafe fn registration_data_pinned<F: ForLt + 'static>(&self) -> Result<Pin<&F::Of<'static>>> { // SAFETY: By the type invariant, `self.as_raw()` is a valid `struct auxiliary_device`. let ptr = unsafe { (*self.as_raw()).registration_data_rust }; if ptr.is_null() { @@ -306,10 +304,58 @@ pub fn registration_data<F: CovariantForLt + 'static>(&self) -> Result<Pin<&F::O let wrapper = unsafe { Pin::<KBox<RegistrationData<F::Of<'static>>>>::borrow(ptr) }; // SAFETY: `data` is a structurally pinned field of `RegistrationData`. - let pinned: Pin<&F::Of<'_>> = unsafe { wrapper.map_unchecked(|w| &w.data) }; + Ok(unsafe { wrapper.map_unchecked(|w| &w.data) }) + } + + /// Access the registration data set by the registering (parent) driver through a closure. + /// + /// `F` is the [`ForLt`](trait@ForLt) encoding of the data type. The closure receives a pinned + /// reference to the registration data. + /// + /// For covariant types that implement [`trait@CovariantForLt`], prefer + /// [`registration_data`](Self::registration_data) which returns a direct reference. + /// + /// Returns [`EINVAL`] if `F` does not match the type used by the parent driver when calling + /// [`Registration::new()`]. + /// + /// Returns [`ENOENT`] if no registration data has been set, e.g. when the device was + /// registered by a C driver. + pub fn registration_data_with<F: ForLt + 'static, R>( + &self, + f: impl for<'a> FnOnce(Pin<&'a F::Of<'a>>) -> R, + ) -> Result<R> { + // SAFETY: The `'static` lifetime is shortened by `cast_ref_unchecked` below before the + // reference reaches `f`. As `f` must accept any `'a`, it cannot pick a concrete short + // lifetime to store references into the data, and `R` cannot name `'a`. + let pinned = unsafe { self.registration_data_pinned::<F>()? }; + + // SAFETY: See above; the closure's HRTB makes the round-trip sound. + let short = unsafe { F::cast_ref_unchecked(pinned.get_ref()) }; + + // SAFETY: The data was pinned before the lifetime was shortened; pinning is + // orthogonal to lifetimes. 
+ Ok(f(unsafe { Pin::new_unchecked(short) })) + } - // SAFETY: The data was pinned when stored; `cast_ref` only shortens - // the lifetime, so the pinning guarantee is preserved. + /// Returns a pinned reference to the registration data set by the registering (parent) driver. + /// + /// This method is only available when `F` implements [`trait@CovariantForLt`], which guarantees + /// safe lifetime shortening via [`CovariantForLt::cast_ref`]. + /// + /// For non-covariant types, use the closure-based [`Self::registration_data_with`]. + /// + /// Returns [`EINVAL`] if `F` does not match the type used by the parent driver when calling + /// [`Registration::new()`]. + /// + /// Returns [`ENOENT`] if no registration data has been set, e.g. when the device was + /// registered by a C driver. + pub fn registration_data<F: CovariantForLt + 'static>(&self) -> Result<Pin<&F::Of<'_>>> { + // SAFETY: CovariantForLt guarantees covariance, so cast_ref safely shortens the + // `'static` lifetime. + let pinned = unsafe { self.registration_data_pinned::<F>()? }; + + // SAFETY: The data was pinned before the lifetime was shortened; pinning is orthogonal + // to lifetimes. Ok(unsafe { Pin::new_unchecked(F::cast_ref(pinned.get_ref())) }) } } @@ -399,22 +445,23 @@ struct RegistrationData<T> { /// This type represents the registration of a [`struct auxiliary_device`]. When its parent device /// is unbound, the corresponding auxiliary device will be unregistered from the system. /// -/// The type parameter `F` is a [`CovariantForLt`](trait@CovariantForLt) encoding of the -/// registration data type. For non-lifetime-parameterized types, use -/// [`CovariantForLt!(T)`](macro@CovariantForLt). -/// The data can be accessed by the auxiliary driver through [`Device::registration_data()`]. +/// The type parameter `F` is a [`ForLt`](trait@ForLt) encoding of the registration +/// data type. For non-lifetime-parameterized types, use [`ForLt!(T)`](macro@ForLt). 
+/// +/// The data can be accessed by the auxiliary driver through [`Device::registration_data_with()`] +/// or, if `F` is [`CovariantForLt`](trait@CovariantForLt), [`Device::registration_data()`]. /// /// # Invariants /// /// `self.adev` always holds a valid pointer to an initialized and registered /// [`struct auxiliary_device`] whose `registration_data_rust` field points to a /// valid `Pin<KBox<RegistrationData<F::Of<'static>>>>`. -pub struct Registration<'a, F: CovariantForLt + 'static> { +pub struct Registration<'a, F: ForLt + 'static> { adev: NonNull<bindings::auxiliary_device>, _phantom: PhantomData<F::Of<'a>>, } -impl<'a, F: CovariantForLt> Registration<'a, F> +impl<'a, F: ForLt> Registration<'a, F> where for<'b> F::Of<'b>: Send + Sync, { @@ -526,7 +573,7 @@ pub fn new<E>( } } -impl<F: CovariantForLt> Drop for Registration<'_, F> { +impl<F: ForLt> Drop for Registration<'_, F> { fn drop(&mut self) { // SAFETY: By the type invariant of `Self`, `self.adev.as_ptr()` is a valid registered // `struct auxiliary_device`. @@ -548,7 +595,7 @@ fn drop(&mut self) { } // SAFETY: A `Registration` of a `struct auxiliary_device` can be released from any thread. -unsafe impl<F: CovariantForLt> Send for Registration<'_, F> where for<'a> F::Of<'a>: Send {} +unsafe impl<F: ForLt> Send for Registration<'_, F> where for<'a> F::Of<'a>: Send {} // SAFETY: `Registration` does not expose any methods or fields that need synchronization. -unsafe impl<F: CovariantForLt> Sync for Registration<'_, F> where for<'a> F::Of<'a>: Send {} +unsafe impl<F: ForLt> Sync for Registration<'_, F> where for<'a> F::Of<'a>: Send {} -- 2.54.0
