The common practice in C drivers is to store pointers in the `driver_data` field of device IDs. The Rust code, however, currently stores indices in that field and carries a side table that maps each index to the corresponding pointer.
It is much simpler to just have `DeviceId` carry the pointer like C code does. However, just doing so naively would cause a "pointers cannot be cast to integers during const eval" error, as kernel_ulong_t does not have provenance while pointers do, and Rust forbids `expose_provenance` during consteval. Work around this limitation by wrapping raw IDs in `MaybeUninit`. `MaybeUninit` is allowed to host arbitrary bytes with or without provenance, so we can just then use `unsafe` to store a pointer with provenance there. This has the same effect as changing the C-side definition to use `void*` instead of `kernel_ulong_t`, but without actually changing the C side. Signed-off-by: Gary Guo <[email protected]> --- rust/kernel/acpi.rs | 4 --- rust/kernel/auxiliary.rs | 8 ++--- rust/kernel/device_id.rs | 88 +++++++++++++++++++++++++++++------------------- rust/kernel/driver.rs | 14 ++++---- rust/kernel/i2c.rs | 7 ++-- rust/kernel/of.rs | 4 --- rust/kernel/pci.rs | 11 +++--- rust/kernel/usb.rs | 7 ++-- 8 files changed, 73 insertions(+), 70 deletions(-) diff --git a/rust/kernel/acpi.rs b/rust/kernel/acpi.rs index 315f2f2af446..ea2ce61ee393 100644 --- a/rust/kernel/acpi.rs +++ b/rust/kernel/acpi.rs @@ -25,10 +25,6 @@ unsafe impl RawDeviceId for DeviceId { // SAFETY: `DRIVER_DATA_OFFSET` is the offset to the `driver_data` field. unsafe impl RawDeviceIdIndex for DeviceId { const DRIVER_DATA_OFFSET: usize = core::mem::offset_of!(bindings::acpi_device_id, driver_data); - - fn index(&self) -> usize { - self.0.driver_data - } } impl DeviceId { diff --git a/rust/kernel/auxiliary.rs b/rust/kernel/auxiliary.rs index 59787c9bff26..aa13d8866a19 100644 --- a/rust/kernel/auxiliary.rs +++ b/rust/kernel/auxiliary.rs @@ -93,7 +93,9 @@ extern "C" fn probe_callback( // SAFETY: `DeviceId` is a `#[repr(transparent)`] wrapper of `struct auxiliary_device_id` // and does not add additional invariants, so it's safe to transmute. 
let id = unsafe { &*id.cast::<DeviceId>() }; - let info = T::ID_TABLE.info(id.index()); + + // SAFETY: `id` comes from `T::ID_TABLE` which is of type `IdArray<_, T::IdInfo>`. + let info = unsafe { id.info_unchecked::<T::IdInfo>() }; from_result(|| { let data = T::probe(adev, info); @@ -169,10 +171,6 @@ unsafe impl RawDeviceId for DeviceId { unsafe impl RawDeviceIdIndex for DeviceId { const DRIVER_DATA_OFFSET: usize = core::mem::offset_of!(bindings::auxiliary_device_id, driver_data); - - fn index(&self) -> usize { - self.0.driver_data - } } /// IdTable type for auxiliary drivers. diff --git a/rust/kernel/device_id.rs b/rust/kernel/device_id.rs index 84852a2d9ad7..59453588df0e 100644 --- a/rust/kernel/device_id.rs +++ b/rust/kernel/device_id.rs @@ -5,7 +5,10 @@ //! Each bus / subsystem that matches device and driver through a bus / subsystem specific ID is //! expected to implement [`RawDeviceId`]. -use core::mem::MaybeUninit; +use core::{ + marker::PhantomData, + mem::MaybeUninit, // +}; /// Marker trait to indicate a Rust device ID type represents a corresponding C device ID type. /// @@ -47,15 +50,48 @@ pub unsafe trait RawDeviceIdIndex: RawDeviceId { /// The offset (in bytes) to the context/data field in the raw device ID. const DRIVER_DATA_OFFSET: usize; - /// The index stored at `DRIVER_DATA_OFFSET` of the implementor of the [`RawDeviceIdIndex`] - /// trait. - fn index(&self) -> usize; + /// Obtain the data pointer stored inside the device ID. + /// + /// # Safety + /// + /// `&Self` must be stored inside a `IdArray<Self, U>`. + unsafe fn info_unchecked<U>(&self) -> &'static U { + // SAFETY: By safety requirement of the trait, this is `self.driver_data as *const U` and by + // the safety requirement of the function, this is stored in `IdArray<Self, U>` so is + // convertible to `&'static U`. + unsafe { + core::ptr::from_ref(self) + .byte_add(Self::DRIVER_DATA_OFFSET) + .cast::<&U>() + .read() + } + } + + /// Obtain the data pointer stored inside the device ID. 
+ /// + /// # Safety + /// + /// `&Self` must be stored inside a `IdArray<Self, U>`, or has NULL (or 0) as driver data. + unsafe fn info_unchecked_opt<U>(&self) -> Option<&'static U> { + // SAFETY: By safety requirement of the trait, this is `self.driver_data as *const U` and by + // the safety requirement of the function, if this is stored in `IdArray<Self, U>`, this is + // convertible to `Option<&'static U>`. Otherwise it is NULL which is `None` as + // `Option<&U>`. + unsafe { + core::ptr::from_ref(self) + .byte_add(Self::DRIVER_DATA_OFFSET) + .cast::<Option<&U>>() + .read() + } + } } /// A zero-terminated device id array. #[repr(C)] pub struct RawIdArray<T: RawDeviceId, const N: usize> { - ids: [T::RawType; N], + // This is `MaybeUninit<T::RawType>` so any bytes inside it can carry provenance in CTFE. + // If this were `T::RawType`, integer fields would not be able to contain pointers. + ids: [MaybeUninit<T::RawType>; N], sentinel: MaybeUninit<T::RawType>, } @@ -68,18 +104,17 @@ pub const fn size(&self) -> usize { /// A zero-terminated device id array, followed by context data. #[repr(C)] -pub struct IdArray<T: RawDeviceId, U, const N: usize> { +pub struct IdArray<T: RawDeviceId, U: 'static, const N: usize> { raw_ids: RawIdArray<T, N>, - id_infos: [U; N], + phantom: PhantomData<&'static U>, } -impl<T: RawDeviceId + RawDeviceIdIndex, U, const N: usize> IdArray<T, U, N> { +impl<T: RawDeviceId + RawDeviceIdIndex, U: 'static, const N: usize> IdArray<T, U, N> { /// Creates a new instance of the array. /// /// The contents are derived from the given identifiers and context information. 
- pub const fn new(ids: [(T, U); N]) -> Self { + pub const fn new(ids: [(T, &'static U); N]) -> Self { let mut raw_ids = [const { MaybeUninit::<T::RawType>::uninit() }; N]; - let mut infos = [const { MaybeUninit::uninit() }; N]; let mut i = 0usize; while i < N { @@ -87,18 +122,15 @@ impl<T: RawDeviceId + RawDeviceIdIndex, U, const N: usize> IdArray<T, U, N> { // layout-wise compatible with `RawType`. raw_ids[i] = unsafe { core::mem::transmute_copy(&ids[i].0) }; // SAFETY: by the safety requirement of `RawDeviceIdIndex`, this would be effectively - // `raw_ids[i].driver_data = i;`. + // `raw_ids[i].driver_data = ids[i].1;`. unsafe { raw_ids[i] .as_mut_ptr() .byte_add(T::DRIVER_DATA_OFFSET) - .cast::<usize>() - .write(i); + .cast::<&U>() + .write(ids[i].1); } - // SAFETY: this is effectively a move: `infos[i] = ids[i].1`. We make a copy here but - // later forget `ids`. - infos[i] = MaybeUninit::new(unsafe { core::ptr::read(&ids[i].1) }); i += 1; } @@ -106,20 +138,15 @@ impl<T: RawDeviceId + RawDeviceIdIndex, U, const N: usize> IdArray<T, U, N> { Self { raw_ids: RawIdArray { - // SAFETY: this is effectively `array_assume_init`, which is unstable, so we use - // `transmute_copy` instead. We have initialized all elements of `raw_ids` so this - // `array_assume_init` is safe. - ids: unsafe { core::mem::transmute_copy(&raw_ids) }, + ids: raw_ids, sentinel: MaybeUninit::zeroed(), }, - // SAFETY: We have initialized all elements of `infos` so this `array_assume_init` is - // safe. - id_infos: unsafe { core::mem::transmute_copy(&infos) }, + phantom: PhantomData, } } } -impl<T: RawDeviceId, U, const N: usize> IdArray<T, U, N> { +impl<T: RawDeviceId, U: 'static, const N: usize> IdArray<T, U, N> { /// Reference to the contained [`RawIdArray`]. pub const fn raw_ids(&self) -> &RawIdArray<T, N> { &self.raw_ids @@ -133,7 +160,7 @@ impl<T: RawDeviceId, const N: usize> IdArray<T, (), N> { /// If the device implements [`RawDeviceIdIndex`], consider using [`IdArray::new`] instead. 
pub const fn new_without_index(ids: [T; N]) -> Self { // SAFETY: `T` is layout-wise compatible with `T::RawType`, so is the array of them. - let raw_ids: [T::RawType; N] = unsafe { core::mem::transmute_copy(&ids) }; + let raw_ids: [MaybeUninit<T::RawType>; N] = unsafe { core::mem::transmute_copy(&ids) }; core::mem::forget(ids); Self { @@ -141,7 +168,7 @@ impl<T: RawDeviceId, const N: usize> IdArray<T, (), N> { ids: raw_ids, sentinel: MaybeUninit::zeroed(), }, - id_infos: [(); N], + phantom: PhantomData, } } } @@ -155,9 +182,6 @@ impl<T: RawDeviceId, const N: usize> IdArray<T, (), N> { pub trait IdTable<T: RawDeviceId, U> { /// Obtain the pointer to the ID table. fn as_ptr(&self) -> *const T::RawType; - - /// Obtain the pointer to the driver-specific information from an index. - fn info(&self, index: usize) -> &U; } impl<T: RawDeviceId, U, const N: usize> IdTable<T, U> for IdArray<T, U, N> { @@ -166,10 +190,6 @@ fn as_ptr(&self) -> *const T::RawType { // to access the sentinel. core::ptr::from_ref(self).cast() } - - fn info(&self, index: usize) -> &U { - &self.id_infos[index] - } } /// Create device table alias for modpost. @@ -184,7 +204,7 @@ macro_rules! module_device_table { $device_id_ty, $id_info_type, { <[$device_id_ty]>::len(&[$($id,)*]) }, - > = $crate::device_id::IdArray::new([$(($id, $info),)*]); + > = $crate::device_id::IdArray::new([$(($id, &$info),)*]); $crate::module_device_table!($table_type, $table_name); }; diff --git a/rust/kernel/driver.rs b/rust/kernel/driver.rs index bf5ba0d27553..824899d76fed 100644 --- a/rust/kernel/driver.rs +++ b/rust/kernel/driver.rs @@ -107,6 +107,7 @@ use crate::{ acpi, device, + device_id::RawDeviceIdIndex, of, prelude::*, types::Opaque, @@ -350,7 +351,8 @@ fn acpi_id_info(dev: &device::Device) -> Option<&'static Self::IdInfo> { // and does not add additional invariants, so it's safe to transmute. 
let id = unsafe { &*raw_id.cast::<acpi::DeviceId>() }; - Some(table.info(<acpi::DeviceId as crate::device_id::RawDeviceIdIndex>::index(id))) + // SAFETY: `id` comes from `table` which is of type `IdArray<_, Self::IdInfo>`. + Some(unsafe { id.info_unchecked::<Self::IdInfo>() }) } } } @@ -381,9 +383,8 @@ fn of_id_info(dev: &device::Device) -> Option<&'static Self::IdInfo> { // and does not add additional invariants, so it's safe to transmute. let id = unsafe { &*raw_id.cast::<of::DeviceId>() }; - return Some(table.info( - <of::DeviceId as crate::device_id::RawDeviceIdIndex>::index(id), - )); + // SAFETY: `id` comes from `table` which is of type `IdArray<_, Self::IdInfo>`. + return Some(unsafe { id.info_unchecked::<Self::IdInfo>() }); } } @@ -412,9 +413,8 @@ fn of_id_info(dev: &device::Device) -> Option<&'static Self::IdInfo> { // and does not add additional invariants, so it's safe to transmute. let id = unsafe { &*raw_id.cast::<of::DeviceId>() }; - return Some(table.info( - <of::DeviceId as crate::device_id::RawDeviceIdIndex>::index(id), - )); + // SAFETY: `id` comes from `table` which is of type `IdArray<_, Self::IdInfo>`. + return Some(unsafe { id.info_unchecked::<Self::IdInfo>() }); } } diff --git a/rust/kernel/i2c.rs b/rust/kernel/i2c.rs index 55c89ba3a82a..9e551c7e8e41 100644 --- a/rust/kernel/i2c.rs +++ b/rust/kernel/i2c.rs @@ -65,10 +65,6 @@ unsafe impl RawDeviceId for DeviceId { // SAFETY: `DRIVER_DATA_OFFSET` is the offset to the `driver_data` field. unsafe impl RawDeviceIdIndex for DeviceId { const DRIVER_DATA_OFFSET: usize = core::mem::offset_of!(bindings::i2c_device_id, driver_data); - - fn index(&self) -> usize { - self.0.driver_data - } } /// IdTable type for I2C @@ -212,7 +208,8 @@ fn i2c_id_info(dev: &I2cClient) -> Option<&'static <Self as driver::Adapter>::Id // does not add additional invariants, so it's safe to transmute. 
let id = unsafe { &*raw_id.cast::<DeviceId>() }; - Some(table.info(<DeviceId as RawDeviceIdIndex>::index(id))) + // SAFETY: `id` comes from `table` which is of type `IdArray<_, Self::IdInfo>`. + Some(unsafe { id.info_unchecked::<T::IdInfo>() }) } } diff --git a/rust/kernel/of.rs b/rust/kernel/of.rs index 35aa6d36d309..d0318f62afd7 100644 --- a/rust/kernel/of.rs +++ b/rust/kernel/of.rs @@ -25,10 +25,6 @@ unsafe impl RawDeviceId for DeviceId { // SAFETY: `DRIVER_DATA_OFFSET` is the offset to the `data` field. unsafe impl RawDeviceIdIndex for DeviceId { const DRIVER_DATA_OFFSET: usize = core::mem::offset_of!(bindings::of_device_id, data); - - fn index(&self) -> usize { - self.0.data as usize - } } impl DeviceId { diff --git a/rust/kernel/pci.rs b/rust/kernel/pci.rs index a3dd48f76353..a630c7fc6a85 100644 --- a/rust/kernel/pci.rs +++ b/rust/kernel/pci.rs @@ -110,10 +110,13 @@ extern "C" fn probe_callback( // SAFETY: `DeviceId` is a `#[repr(transparent)]` wrapper of `struct pci_device_id` and // does not add additional invariants, so it's safe to transmute. let id = unsafe { &*id.cast::<DeviceId>() }; - let info = T::ID_TABLE.info(id.index()); + + // SAFETY: `id` comes from `T::ID_TABLE` which is of type `IdArray<_, T::IdInfo>` or + // `pci_device_id_any` which has 0 as driver_data. + let info = unsafe { id.info_unchecked_opt::<T::IdInfo>() }; from_result(|| { - let data = T::probe(pdev, Some(info)); + let data = T::probe(pdev, info); pdev.as_ref().set_drvdata(data)?; Ok(0) @@ -233,10 +236,6 @@ unsafe impl RawDeviceId for DeviceId { // SAFETY: `DRIVER_DATA_OFFSET` is the offset to the `driver_data` field. unsafe impl RawDeviceIdIndex for DeviceId { const DRIVER_DATA_OFFSET: usize = core::mem::offset_of!(bindings::pci_device_id, driver_data); - - fn index(&self) -> usize { - self.0.driver_data - } } /// `IdTable` type for PCI. 
diff --git a/rust/kernel/usb.rs b/rust/kernel/usb.rs index 500b5e0ba4ea..8aeff5011755 100644 --- a/rust/kernel/usb.rs +++ b/rust/kernel/usb.rs @@ -89,7 +89,8 @@ extern "C" fn probe_callback( // does not add additional invariants, so it's safe to transmute. let id = unsafe { &*id.cast::<DeviceId>() }; - let info = T::ID_TABLE.info(id.index()); + // SAFETY: `id` comes from `T::ID_TABLE` which is of type `IdArray<_, T::IdInfo>`. + let info = unsafe { id.info_unchecked::<T::IdInfo>() }; let data = T::probe(intf, id, info); let dev: &device::Device<device::CoreInternal<'_>> = intf.as_ref(); @@ -242,10 +243,6 @@ unsafe impl RawDeviceId for DeviceId { // SAFETY: `DRIVER_DATA_OFFSET` is the offset to the `driver_info` field. unsafe impl RawDeviceIdIndex for DeviceId { const DRIVER_DATA_OFFSET: usize = core::mem::offset_of!(bindings::usb_device_id, driver_info); - - fn index(&self) -> usize { - self.0.driver_info - } } /// [`IdTable`](kernel::device_id::IdTable) type for USB. -- 2.54.0
