Add a general purpose sum type `Either` (with no preference on either side,
unlike `Result`) and implement `Io` for it.

This is a generic version of what C's `iosys_map` provides, which can be
implemented simply like this:

    type IoSysMap<'a, T> = Either<Mmio<'a, T>, SysMem<'a, T>>;

Signed-off-by: Gary Guo <[email protected]>
---
 rust/kernel/io.rs    | 93 +++++++++++++++++++++++++++++++++++++++++++++++++++-
 rust/kernel/types.rs |  9 +++++
 2 files changed, 101 insertions(+), 1 deletion(-)

diff --git a/rust/kernel/io.rs b/rust/kernel/io.rs
index 2b238b625672..28d713eaddda 100644
--- a/rust/kernel/io.rs
+++ b/rust/kernel/io.rs
@@ -19,7 +19,8 @@
     transmute::{
         AsBytes,
         FromBytes, //
-    }, //
+    },
+    types::Either, //
 };
 
 pub mod mem;
@@ -1592,6 +1593,96 @@ pub unsafe fn project_view<U: ?Sized + KnownSize>(
     }
 }
 
+impl<'a, T: ?Sized + KnownSize, L: IoBase<'a, Target = T>, R: IoBase<'a, Target = T>> IoBase<'a>
+    for Either<L, R>
+{
+    type Backend = Either<L::Backend, R::Backend>;
+    type Target = T;
+
+    #[inline]
+    fn as_view(self) -> <Self::Backend as IoBackend>::View<'a, Self::Target> {
+        match self {
+            Either::Left(l) => Either::Left(l.as_view()),
+            Either::Right(r) => Either::Right(r.as_view()),
+        }
+    }
+}
+
+impl<L: IoBackend, R: IoBackend> IoBackend for Either<L, R> {
+    type View<'a, T: ?Sized + KnownSize> = Either<L::View<'a, T>, R::View<'a, T>>;
+
+    #[inline]
+    fn as_ptr<'a, T: ?Sized + KnownSize>(view: Self::View<'a, T>) -> *mut T {
+        match view {
+            Either::Left(l) => L::as_ptr(l),
+            Either::Right(r) => R::as_ptr(r),
+        }
+    }
+
+    #[inline]
+    unsafe fn project_view<'a, T: ?Sized + KnownSize, U: ?Sized + KnownSize>(
+        view: Self::View<'a, T>,
+        ptr: *mut U,
+    ) -> Self::View<'a, U> {
+        match view {
+            // SAFETY: Per this function's safety requirements, which are forwarded as-is to `L`.
+            Either::Left(l) => Either::Left(unsafe { L::project_view(l, ptr) }),
+            // SAFETY: Per this function's safety requirements, which are forwarded as-is to `R`.
+            Either::Right(r) => Either::Right(unsafe { R::project_view(r, ptr) }),
+        }
+    }
+}
+
+impl<T, L: IoCapable<T>, R: IoCapable<T>> IoCapable<T> for Either<L, R> {
+    #[inline]
+    fn io_read(view: Self::View<'_, T>) -> T {
+        match view {
+            Either::Left(left) => L::io_read(left),
+            Either::Right(right) => R::io_read(right),
+        }
+    }
+
+    #[inline]
+    fn io_write<'a>(view: Self::View<'a, T>, value: T) {
+        match view {
+            Either::Left(left) => L::io_write(left, value),
+            Either::Right(right) => R::io_write(right, value),
+        }
+    }
+}
+
+// SAFETY: Per the safety guarantees of `L`'s and `R`'s `IoCopyable` impls, `is_mapped` is
+// correctly implemented.
+unsafe impl<L: IoCopyable, R: IoCopyable> IoCopyable for Either<L, R> {
+    #[inline]
+    fn is_mapped<T: ?Sized + KnownSize>(view: Self::View<'_, T>) -> bool {
+        match view {
+            Either::Left(l) => L::is_mapped(l),
+            Either::Right(r) => R::is_mapped(r),
+        }
+    }
+
+    #[inline]
+    unsafe fn copy_from_io(view: Self::View<'_, [u8]>, buffer: *mut u8) {
+        match view {
+            // SAFETY: Per this function's safety requirements, which are forwarded as-is to `L`.
+            Either::Left(l) => unsafe { L::copy_from_io(l, buffer) },
+            // SAFETY: Per this function's safety requirements, which are forwarded as-is to `R`.
+            Either::Right(r) => unsafe { R::copy_from_io(r, buffer) },
+        }
+    }
+
+    #[inline]
+    unsafe fn copy_to_io(view: Self::View<'_, [u8]>, buffer: *const u8) {
+        match view {
+            // SAFETY: Per this function's safety requirements, which are forwarded as-is to `L`.
+            Either::Left(l) => unsafe { L::copy_to_io(l, buffer) },
+            // SAFETY: Per this function's safety requirements, which are forwarded as-is to `R`.
+            Either::Right(r) => unsafe { R::copy_to_io(r, buffer) },
+        }
+    }
+}
+
 /// Project an I/O type to a subview of it.
 ///
 /// The syntax is of form `io_project!(io, proj)` where `io` is an expression to a type that
diff --git a/rust/kernel/types.rs b/rust/kernel/types.rs
index ac316fd7b538..12546c312dd2 100644
--- a/rust/kernel/types.rs
+++ b/rust/kernel/types.rs
@@ -448,3 +448,12 @@ fn pin_init<E>(slot: impl PinInit<T, E>) -> impl PinInit<Self, E> {
 /// [`NotThreadSafe`]: type@NotThreadSafe
 #[allow(non_upper_case_globals)]
 pub const NotThreadSafe: NotThreadSafe = PhantomData;
+
+/// General purpose sum type with two cases, with no preference for either side.
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+pub enum Either<L, R> {
+    /// A value of type `L`.
+    Left(L),
+    /// A value of type `R`.
+    Right(R),
+}

-- 
2.54.0

Reply via email to