Add the `BoundedInt` type, which restricts the number of bits that a given integer value may use. This is useful for carrying range guarantees when setting bitfields.
Alongside this type, many `From` implementations and a `TryIntoBounded` conversion trait are
provided to reduce friction when using it with regular integer types. Proxy implementations of
common integer traits are also provided.

Signed-off-by: Alexandre Courbot <[email protected]>
---
 rust/kernel/lib.rs |   1 +
 rust/kernel/num.rs | 499 +++++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 500 insertions(+)

diff --git a/rust/kernel/lib.rs b/rust/kernel/lib.rs
index fcffc3988a90..21c1f452ee6a 100644
--- a/rust/kernel/lib.rs
+++ b/rust/kernel/lib.rs
@@ -101,6 +101,7 @@ pub mod mm;
 #[cfg(CONFIG_NET)]
 pub mod net;
+pub mod num;
 pub mod of;
 #[cfg(CONFIG_PM_OPP)]
 pub mod opp;
diff --git a/rust/kernel/num.rs b/rust/kernel/num.rs
new file mode 100644
index 000000000000..b2aad95ce51c
--- /dev/null
+++ b/rust/kernel/num.rs
@@ -0,0 +1,499 @@
+// SPDX-License-Identifier: GPL-2.0
+
+//! Numerical types for the kernel.
+
+use kernel::prelude::*;
+
+/// Integer type for which only the bits `0..NUM_BITS` are valid.
+///
+/// # Invariants
+///
+/// Stored values are represented with at most `NUM_BITS` bits.
+#[repr(transparent)]
+#[derive(Clone, Copy, Debug, Default, Hash)]
+pub struct BoundedInt<T, const NUM_BITS: u32>(T);
+
+/// Returns `true` if `$value` can be represented with at most `$num_bits` bits on `$type`.
+///
+/// `$num_bits` is taken as a token tree (a const parameter name or an integer literal) so that it
+/// can be forwarded verbatim as the const argument of [`Boundable`].
+macro_rules! is_in_bounds {
+    ($value:expr, $type:ty, $num_bits:tt) => {{
+        let v = $value;
+        (v & <$type as Boundable<$num_bits>>::MASK) == v
+    }};
+}
+
+/// Trait for primitive integer types that can be used with `BoundedInt`.
+pub trait Boundable<const NUM_BITS: u32>
+where
+    Self: Sized + Copy + core::ops::BitAnd<Output = Self> + core::cmp::PartialEq,
+    Self: TryInto<u8> + TryInto<u16> + TryInto<u32> + TryInto<u64>,
+{
+    /// Mask of the valid bits for this type.
+    ///
+    /// The implementations below build this mask with the `genmask_*` helpers over
+    /// `0..=(NUM_BITS - 1)`, so evaluating it fails the build if `NUM_BITS` is zero or larger than
+    /// the width of `Self`. Code relying on `Self` being able to hold `NUM_BITS` bits must make
+    /// sure this constant gets evaluated.
+    const MASK: Self;
+
+    /// Returns `true` if `value` can be represented with at most `NUM_BITS` bits.
+    ///
+    /// TODO: post-RFC: replace this with a left-shift followed by right-shift operation. This will
+    /// allow us to handle signed values as well.
+    fn is_in_bounds(value: Self) -> bool {
+        is_in_bounds!(value, Self, NUM_BITS)
+    }
+}
+
+impl<const NUM_BITS: u32> Boundable<NUM_BITS> for u8 {
+    const MASK: u8 = crate::bits::genmask_u8(0..=(NUM_BITS - 1));
+}
+
+impl<const NUM_BITS: u32> Boundable<NUM_BITS> for u16 {
+    const MASK: u16 = crate::bits::genmask_u16(0..=(NUM_BITS - 1));
+}
+
+impl<const NUM_BITS: u32> Boundable<NUM_BITS> for u32 {
+    const MASK: u32 = crate::bits::genmask_u32(0..=(NUM_BITS - 1));
+}
+
+impl<const NUM_BITS: u32> Boundable<NUM_BITS> for u64 {
+    const MASK: u64 = crate::bits::genmask_u64(0..=(NUM_BITS - 1));
+}
+
+impl<T, const NUM_BITS: u32> BoundedInt<T, NUM_BITS>
+where
+    T: Boundable<NUM_BITS>,
+{
+    /// Checks that `value` is valid for this type at compile-time and builds a new value.
+    ///
+    /// This relies on [`build_assert!`] to perform validation at compile-time. If `value` cannot
+    /// be inferred to be in bounds at compile-time, use the fallible [`Self::try_new`] instead.
+    ///
+    /// When possible, use one of the `new_const` methods instead of this method as it statically
+    /// validates `value` instead of relying on the compiler's optimizations.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// use kernel::num::BoundedInt;
+    ///
+    /// # fn some_number() -> u32 { 0xffffffff }
+    ///
+    /// assert_eq!(BoundedInt::<u8, 1>::new(1).get(), 1);
+    /// assert_eq!(BoundedInt::<u16, 8>::new(0xff).get(), 0xff);
+    ///
+    /// // Triggers a build error as `0x1ff` doesn't fit into 8 bits.
+    /// // assert_eq!(BoundedInt::<u32, 8>::new(0x1ff).get(), 0x1ff);
+    ///
+    /// let v: u32 = some_number();
+    /// // Triggers a build error as `v` cannot be asserted to fit within 4 bits...
+    /// // let _ = BoundedInt::<u32, 4>::new(v);
+    /// // ... but this works as the compiler can assert the range from the mask.
+    /// let _ = BoundedInt::<u32, 4>::new(v & 0xf);
+    /// ```
+    pub fn new(value: T) -> Self {
+        crate::build_assert!(
+            T::is_in_bounds(value),
+            "Provided parameter is larger than maximal supported value"
+        );
+
+        // INVARIANT: `value` has been asserted above to fit within `NUM_BITS` bits.
+        Self(value)
+    }
+
+    /// Attempts to convert `value` into a value bounded by `NUM_BITS`.
+    ///
+    /// # Errors
+    ///
+    /// Returns `EOVERFLOW` if `value` cannot be represented within `NUM_BITS` bits.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// use kernel::num::BoundedInt;
+    ///
+    /// assert_eq!(BoundedInt::<u8, 1>::try_new(1).map(|v| v.get()), Ok(1));
+    /// assert_eq!(BoundedInt::<u16, 8>::try_new(0xff).map(|v| v.get()), Ok(0xff));
+    ///
+    /// // `0x1ff` doesn't fit into 8 bits.
+    /// assert_eq!(BoundedInt::<u32, 8>::try_new(0x1ff), Err(EOVERFLOW));
+    /// ```
+    pub fn try_new(value: T) -> Result<Self> {
+        if !T::is_in_bounds(value) {
+            Err(EOVERFLOW)
+        } else {
+            // INVARIANT: `value` has just been checked to fit within `NUM_BITS` bits.
+            Ok(Self(value))
+        }
+    }
+
+    /// Returns the contained value as a primitive type.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// use kernel::num::BoundedInt;
+    ///
+    /// let v = BoundedInt::<u32, 4>::new_const::<7>();
+    /// assert_eq!(v.get(), 7u32);
+    /// ```
+    pub fn get(self) -> T {
+        // This branch is never taken. It lets the optimizer assume that the returned value fits
+        // within `NUM_BITS` bits.
+        if !T::is_in_bounds(self.0) {
+            // SAFETY: Per the invariants, `self.0` cannot have bits set outside of `MASK`, so
+            // this block will never be reached.
+            unsafe { core::hint::unreachable_unchecked() }
+        }
+        self.0
+    }
+
+    /// Increases the number of bits usable for `self`.
+    ///
+    /// This operation cannot fail at runtime; it fails to build if `NEW_NUM_BITS` is smaller
+    /// than `NUM_BITS`.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// use kernel::num::BoundedInt;
+    ///
+    /// let v = BoundedInt::<u32, 4>::new_const::<7>();
+    /// let larger_v = v.enlarge::<12>();
+    /// // The contained values are equal even though `larger_v` has a bigger capacity.
+    /// assert_eq!(larger_v, v);
+    /// ```
+    pub const fn enlarge<const NEW_NUM_BITS: u32>(self) -> BoundedInt<T, NEW_NUM_BITS>
+    where
+        T: Boundable<NEW_NUM_BITS>,
+        T: Copy,
+    {
+        build_assert!(NEW_NUM_BITS >= NUM_BITS);
+
+        // INVARIANT: the value did fit within `NUM_BITS`, so it will all the more fit within
+        // `NEW_NUM_BITS` which is larger.
+        BoundedInt(self.0)
+    }
+
+    /// Shrinks the number of bits usable for `self`.
+    ///
+    /// # Errors
+    ///
+    /// Returns `EOVERFLOW` if the value of `self` cannot be represented within `NEW_NUM_BITS`.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// use kernel::num::BoundedInt;
+    ///
+    /// let v = BoundedInt::<u32, 12>::new_const::<7>();
+    /// let smaller_v = v.shrink::<4>()?;
+    /// // The contained values are equal even though `smaller_v` has a smaller capacity.
+    /// assert_eq!(smaller_v, v);
+    ///
+    /// # Ok::<(), Error>(())
+    /// ```
+    pub fn shrink<const NEW_NUM_BITS: u32>(self) -> Result<BoundedInt<T, NEW_NUM_BITS>>
+    where
+        T: Boundable<NEW_NUM_BITS>,
+        T: Copy,
+    {
+        BoundedInt::<T, NEW_NUM_BITS>::try_new(self.get())
+    }
+
+    /// Casts `self` into a `BoundedInt` using a different storage type, but using the same
+    /// number of bits for representation.
+    ///
+    /// This method cannot fail at runtime as the number of bits used for representation doesn't
+    /// change. It fails to build if `U` is too narrow to hold `NUM_BITS` bits.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// use kernel::num::BoundedInt;
+    ///
+    /// let v = BoundedInt::<u32, 4>::new_const::<7>();
+    /// let smaller_v: BoundedInt<u8, _> = v.cast();
+    /// // The contained values are equal even though `smaller_v` has a smaller storage type.
+    /// assert_eq!(u32::from(smaller_v.get()), v.get());
+    /// ```
+    pub fn cast<U>(self) -> BoundedInt<U, NUM_BITS>
+    where
+        U: TryFrom<T> + Boundable<NUM_BITS>,
+    {
+        // `U: Boundable<NUM_BITS>` holds for every `NUM_BITS`, so it alone does not prove that
+        // `U` can hold `NUM_BITS` bits. Evaluating `U::MASK` does: it fails the build when
+        // `NUM_BITS` exceeds the width of `U`. Without this, e.g. casting a `BoundedInt<u32, 12>`
+        // holding `0xfff` into a `u8` storage would reach the `unwrap_unchecked` below with an
+        // `Err`.
+        let _mask: U = <U as Boundable<NUM_BITS>>::MASK;
+
+        // SAFETY: the contained value is represented using at most `NUM_BITS` bits, and `U` is
+        // able to represent `NUM_BITS` bits (enforced at build time by evaluating `U::MASK`
+        // above), hence the conversion cannot fail.
+        let value = unsafe { U::try_from(self.0).unwrap_unchecked() };
+
+        // INVARIANT: although the storage type has changed, the value is still represented within
+        // `NUM_BITS`.
+        BoundedInt(value)
+    }
+}
+
+/// Validating the value as a const expression cannot be done in a regular generic method, as the
+/// bitwise operations we rely on to check the bounds are trait methods on `T`, which cannot be
+/// called in a const context. Thus, `new_const` is implemented for each primitive type using a
+/// macro.
+macro_rules! impl_const_new {
+    ($($type:ty)*) => {
+        $(
+            impl<const NUM_BITS: u32> BoundedInt<$type, NUM_BITS> {
+                /// Creates a bounded value for `VALUE`, statically validated.
+                ///
+                /// This method should be used instead of [`Self::new`] when the value is a constant
+                /// expression.
+                ///
+                /// # Examples
+                ///
+                /// ```
+                /// use kernel::num::BoundedInt;
+                ///
+                #[doc = ::core::concat!(
+                    "let v = BoundedInt::<",
+                    ::core::stringify!($type),
+                    ", 4>::new_const::<7>();")]
+                /// assert_eq!(v.get(), 7);
+                /// ```
+                pub const fn new_const<const VALUE: $type>() -> Self {
+                    build_assert!(is_in_bounds!(VALUE, $type, NUM_BITS));
+
+                    // INVARIANT: `VALUE` has been asserted above to fit within `NUM_BITS` bits.
+                    Self(VALUE)
+                }
+            }
+        )*
+    };
+}
+
+impl_const_new!(u8 u16 u32 u64);
+
+/// Declares a new `$trait` and implements it for all bounded types represented using `$num_bits`.
+///
+/// This is used to declare properties as traits that we can use for later implementations.
+macro_rules! impl_size_rule {
+    ($trait:ident, $($num_bits:literal)*) => {
+        trait $trait {}
+
+        $(
+            impl<T> $trait for BoundedInt<T, $num_bits> where T: Boundable<$num_bits> {}
+        )*
+    };
+}
+
+// Bit counts wide enough to hold any `u64` value (`NUM_BITS >= 64`).
+impl_size_rule!(LargerThanU64, 64);
+
+// Bit counts wide enough to hold any `u32` value (`NUM_BITS >= 32`).
+impl_size_rule!(LargerThanU32,
+    32 33 34 35 36 37 38 39
+    40 41 42 43 44 45 46 47
+    48 49 50 51 52 53 54 55
+    56 57 58 59 60 61 62 63
+);
+// Anything wide enough for a `u64` is also wide enough for a `u32`.
+impl<T> LargerThanU32 for T where T: LargerThanU64 {}
+
+// Bit counts wide enough to hold any `u16` value (`NUM_BITS >= 16`).
+impl_size_rule!(LargerThanU16,
+    16 17 18 19 20 21 22 23
+    24 25 26 27 28 29 30 31
+);
+// Anything wide enough for a `u32` is also wide enough for a `u16`.
+impl<T> LargerThanU16 for T where T: LargerThanU32 {}
+
+// Bit counts wide enough to hold any `u8` value (`NUM_BITS >= 8`).
+impl_size_rule!(LargerThanU8, 8 9 10 11 12 13 14 15);
+// Anything wide enough for a `u16` is also wide enough for a `u8`.
+impl<T> LargerThanU8 for T where T: LargerThanU16 {}
+
+// Bit counts wide enough to hold any `bool` value (`NUM_BITS >= 1`).
+impl_size_rule!(LargerThanBool, 1 2 3 4 5 6 7);
+// Anything wide enough for a `u8` is also wide enough for a `bool`.
+impl<T> LargerThanBool for T where T: LargerThanU8 {}
+
+/// Generates `From` implementations from a primitive type into a bounded integer that is
+/// guaranteed to be able to contain it.
+macro_rules! impl_from_primitive {
+    ($($type:ty => $trait:ident),*) => {
+        $(
+            impl<T, const NUM_BITS: u32> From<$type> for BoundedInt<T, NUM_BITS>
+            where
+                Self: $trait,
+                T: From<$type>,
+            {
+                fn from(value: $type) -> Self {
+                    // INVARIANT: the size-rule bound on `Self` guarantees that `NUM_BITS` is at
+                    // least the width of the source type, so any value of that type fits.
+                    Self(T::from(value))
+                }
+            }
+        )*
+    }
+}
+
+impl_from_primitive!(
+    bool => LargerThanBool,
+    u8 => LargerThanU8,
+    u16 => LargerThanU16,
+    u32 => LargerThanU32,
+    u64 => LargerThanU64
+);
+
+impl_size_rule!(FitsIntoBool, 1);
+
+impl_size_rule!(FitsIntoU8, 2 3 4 5 6 7 8);
+
+// Anything that fits into a `bool` also fits into a `u8`.
+impl<T> FitsIntoU8 for T where T: FitsIntoBool {}
+
+impl_size_rule!(FitsIntoU16, 9 10 11 12 13 14 15 16);
+
+// Anything that fits into a `u8` also fits into a `u16`.
+impl<T> FitsIntoU16 for T where T: FitsIntoU8 {}
+
+impl_size_rule!(FitsIntoU32,
+    17 18 19 20 21 22 23 24
+    25 26 27 28 29 30 31 32
+);
+
+// Anything that fits into a `u16` also fits into a `u32`.
+impl<T> FitsIntoU32 for T where T: FitsIntoU16 {}
+
+impl_size_rule!(FitsIntoU64,
+    33 34 35 36 37 38 39 40
+    41 42 43 44 45 46 47 48
+    49 50 51 52 53 54 55 56
+    57 58 59 60 61 62 63 64
+);
+
+// Anything that fits into a `u32` also fits into a `u64`.
+impl<T> FitsIntoU64 for T where T: FitsIntoU32 {}
+
+/// Generates `From` implementations from a bounded integer into a primitive type that is
+/// guaranteed to be able to contain it.
+macro_rules! impl_into_primitive {
+    ($($trait:ident => $type:ty),*) => {
+        $(
+            impl<T, const NUM_BITS: u32> From<BoundedInt<T, NUM_BITS>> for $type
+            where
+                T: Boundable<NUM_BITS>,
+                BoundedInt<T, NUM_BITS>: $trait
+            {
+                fn from(value: BoundedInt<T, NUM_BITS>) -> Self {
+                    // SAFETY: per the `BoundedInt` invariants, `value` uses at most `NUM_BITS`
+                    // bits, and the size-rule bound on `BoundedInt<T, NUM_BITS>` guarantees that
+                    // `NUM_BITS` does not exceed the width of the target type, so the conversion
+                    // cannot fail.
+                    unsafe { <T as TryInto<$type>>::try_into(value.get()).unwrap_unchecked() }
+                }
+            }
+        )*
+    }
+}
+
+impl_into_primitive!(
+    FitsIntoU8 => u8,
+    FitsIntoU16 => u16,
+    FitsIntoU32 => u32,
+    FitsIntoU64 => u64
+);
+
+// Conversion to boolean must be handled separately as `bool` has no `TryFrom` implementation
+// from integer types.
+impl<T> From<BoundedInt<T, 1>> for bool
+where
+    T: Boundable<1>,
+    BoundedInt<T, 1>: FitsIntoBool,
+    T: PartialEq + Zeroable,
+{
+    fn from(value: BoundedInt<T, 1>) -> Self {
+        // With a single valid bit, the value is either zero (`false`) or one (`true`).
+        value.get() != Zeroable::zeroed()
+    }
+}
+
+/// Fallible conversion into a [`BoundedInt`], similar to [`TryInto`].
+///
+/// A dedicated trait is used instead of `TryFrom` implementations on `BoundedInt`, as those would
+/// conflict with the blanket `TryFrom` implementation of the standard library.
+pub trait TryIntoBounded<T: Boundable<NUM_BITS>, const NUM_BITS: u32> {
+    /// Attempts to convert `self` into a value bounded by `NUM_BITS`.
+    ///
+    /// # Errors
+    ///
+    /// The blanket implementation returns `EOVERFLOW` if `self` cannot be converted into `T`, or
+    /// if the converted value cannot be represented within `NUM_BITS` bits.
+    fn try_into(self) -> Result<BoundedInt<T, NUM_BITS>>;
+}
+
+/// Any value convertible into the storage type `T` can be attempted to be converted into a
+/// bounded integer of any size.
+impl<T, U, const NUM_BITS: u32> TryIntoBounded<T, NUM_BITS> for U
+where
+    T: Boundable<NUM_BITS>,
+    U: TryInto<T>,
+{
+    fn try_into(self) -> Result<BoundedInt<T, NUM_BITS>> {
+        // Call `TryInto::try_into` explicitly: a plain `self.try_into()` shares its name with
+        // this very method and reads like a recursive call.
+        <U as TryInto<T>>::try_into(self)
+            .map_err(|_| EOVERFLOW)
+            .and_then(BoundedInt::try_new)
+    }
+}
+
+/// `BoundedInt`s can be compared if their respective storage types can be.
+impl<T, U, const NUM_BITS: u32, const NUM_BITS_U: u32> PartialEq<BoundedInt<U, NUM_BITS_U>>
+    for BoundedInt<T, NUM_BITS>
+where
+    T: Boundable<NUM_BITS> + PartialEq<U>,
+    U: Boundable<NUM_BITS_U>,
+{
+    fn eq(&self, rhs: &BoundedInt<U, NUM_BITS_U>) -> bool {
+        // Compare the unwrapped storage values; the bit capacities play no role in equality.
+        let (lhs, rhs) = (self.get(), rhs.get());
+        lhs == rhs
+    }
+}
+
+// `Eq` relies on the storage type's equality being total, which holds for the primitive unsigned
+// integers implementing `Boundable` in this file.
+impl<T, const NUM_BITS: u32> Eq for BoundedInt<T, NUM_BITS> where T: Boundable<NUM_BITS> {}
+
+/// `BoundedInt`s can be ordered if their respective storage types can be.
+impl<T, U, const NUM_BITS: u32, const NUM_BITS_U: u32> PartialOrd<BoundedInt<U, NUM_BITS_U>>
+    for BoundedInt<T, NUM_BITS>
+where
+    T: Boundable<NUM_BITS> + PartialOrd<U>,
+    U: Boundable<NUM_BITS_U>,
+{
+    fn partial_cmp(&self, rhs: &BoundedInt<U, NUM_BITS_U>) -> Option<core::cmp::Ordering> {
+        PartialOrd::partial_cmp(&self.get(), &rhs.get())
+    }
+}
+
+/// Total ordering of `BoundedInt`s sharing the same storage type and bit count.
+impl<T, const NUM_BITS: u32> Ord for BoundedInt<T, NUM_BITS>
+where
+    T: Boundable<NUM_BITS> + Ord,
+{
+    fn cmp(&self, rhs: &Self) -> core::cmp::Ordering {
+        Ord::cmp(&self.get(), &rhs.get())
+    }
+}
+
+/// Allow comparison with non-bounded values.
+impl<T, const NUM_BITS: u32> PartialEq<T> for BoundedInt<T, NUM_BITS>
+where
+    T: Boundable<NUM_BITS> + PartialEq,
+{
+    fn eq(&self, rhs: &T) -> bool {
+        PartialEq::eq(&self.get(), rhs)
+    }
+}
+
+/// Allow ordering with non-bounded values.
+impl<T, const NUM_BITS: u32> PartialOrd<T> for BoundedInt<T, NUM_BITS>
+where
+    T: Boundable<NUM_BITS>,
+    T: PartialOrd,
+{
+    fn partial_cmp(&self, other: &T) -> Option<core::cmp::Ordering> {
+        self.get().partial_cmp(other)
+    }
+}
+
+/// Implements the given `core::fmt` formatting traits for `BoundedInt` by forwarding to the
+/// storage type's implementation.
+///
+/// Format specifiers and flags (width, padding, alternate form, ...) therefore behave exactly as
+/// they do for the storage type. `Binary` and `Octal` are provided on top of the hexadecimal
+/// traits, as they are convenient when inspecting bitfields.
+macro_rules! impl_fmt_proxy {
+    ($($trait:ident)*) => {
+        $(
+            impl<T, const NUM_BITS: u32> core::fmt::$trait for BoundedInt<T, NUM_BITS>
+            where
+                T: Boundable<NUM_BITS>,
+                T: core::fmt::$trait,
+            {
+                fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+                    core::fmt::$trait::fmt(&self.0, f)
+                }
+            }
+        )*
+    };
+}
+
+impl_fmt_proxy!(Display LowerHex UpperHex Binary Octal);
-- 
2.51.0
