This is an automated email from the ASF dual-hosted git repository.

rduan pushed a commit to branch v2.0.0-preview
in repository https://gitbox.apache.org/repos/asf/incubator-teaclave-sgx-sdk.git


The following commit(s) were added to refs/heads/v2.0.0-preview by this push:
     new e4899117 Update SpinMutex and SpinRwLock
e4899117 is described below

commit e48991177dfb5ccb80f8504c4ca05fc16b96986f
Author: volcano <[email protected]>
AuthorDate: Thu Sep 29 11:06:07 2022 +0800

    Update SpinMutex and SpinRwLock
---
 sgx_trts/src/sync/mutex.rs  |  67 ++++++++++++++++--------
 sgx_trts/src/sync/rwlock.rs | 122 +++++++++++++++++++++++++++++---------------
 2 files changed, 126 insertions(+), 63 deletions(-)

diff --git a/sgx_trts/src/sync/mutex.rs b/sgx_trts/src/sync/mutex.rs
index 5ace3bba..4d3f0d96 100644
--- a/sgx_trts/src/sync/mutex.rs
+++ b/sgx_trts/src/sync/mutex.rs
@@ -27,7 +27,7 @@ use sgx_types::marker::ContiguousMemory;
 
 pub struct SpinMutex<T: ?Sized> {
     lock: AtomicBool,
-    value: UnsafeCell<T>,
+    data: UnsafeCell<T>,
 }
 
 unsafe impl<T: ContiguousMemory> ContiguousMemory for SpinMutex<T> {}
@@ -36,50 +36,73 @@ unsafe impl<T: ?Sized + Send> Sync for SpinMutex<T> {}
 unsafe impl<T: ?Sized + Send> Send for SpinMutex<T> {}
 
 pub struct SpinMutexGuard<'a, T: ?Sized + 'a> {
-    lock: &'a SpinMutex<T>,
+    lock: &'a AtomicBool,
+    data: &'a mut T,
 }
 
 impl<T: ?Sized> !Send for SpinMutexGuard<'_, T> {}
 unsafe impl<T: ?Sized + Sync> Sync for SpinMutexGuard<'_, T> {}
 
 impl<T> SpinMutex<T> {
-    pub const fn new(value: T) -> Self {
+    pub const fn new(data: T) -> Self {
         SpinMutex {
-            value: UnsafeCell::new(value),
             lock: AtomicBool::new(false),
+            data: UnsafeCell::new(data),
         }
     }
 
     #[inline]
     pub fn into_inner(self) -> T {
-        let SpinMutex { value, .. } = self;
-        value.into_inner()
+        let SpinMutex { data, .. } = self;
+        data.into_inner()
     }
 }
 
 impl<T: ?Sized> SpinMutex<T> {
     #[inline]
     pub fn lock(&self) -> SpinMutexGuard<'_, T> {
-        loop {
-            match self.try_lock() {
-                None => {
-                    while self.lock.load(Ordering::Relaxed) {
-                        spin_loop()
-                    }
-                }
-                Some(guard) => return guard,
+        while self
+            .lock
+            .compare_exchange_weak(false, true, Ordering::Acquire, Ordering::Relaxed)
+            .is_err()
+        {
+            while self.is_locked() {
+                spin_loop();
             }
         }
+
+        SpinMutexGuard {
+            lock: &self.lock,
+            data: unsafe { &mut *self.data.get() },
+        }
+    }
+
+    #[inline]
+    pub fn is_locked(&self) -> bool {
+        self.lock.load(Ordering::Relaxed)
+    }
+
+    #[inline]
+    pub fn unlock(guard: SpinMutexGuard<'_, T>) {
+        drop(guard);
+    }
+
+    #[inline]
+    pub unsafe fn force_unlock(&self) {
+        self.lock.store(false, Ordering::Release);
     }
 
     #[inline]
-    pub fn try_lock(&self) -> Option<SpinMutexGuard<'_, T>> {
+    pub fn try_lock(&self) -> Option<SpinMutexGuard<T>> {
         if self
             .lock
-            .compare_exchange(false, true, Ordering::Acquire, Ordering::Acquire)
+            .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
             .is_ok()
         {
-            Some(SpinMutexGuard { lock: self })
+            Some(SpinMutexGuard {
+                lock: &self.lock,
+                data: unsafe { &mut *self.data.get() },
+            })
         } else {
             None
         }
@@ -87,7 +110,7 @@ impl<T: ?Sized> SpinMutex<T> {
 
     #[inline]
     pub fn get_mut(&mut self) -> &mut T {
-        unsafe { &mut *self.value.get() }
+        unsafe { &mut *self.data.get() }
     }
 }
 
@@ -118,19 +141,19 @@ impl<T: ?Sized> Deref for SpinMutexGuard<'_, T> {
     type Target = T;
 
     fn deref(&self) -> &T {
-        unsafe { &*self.lock.value.get() }
+        self.data
     }
 }
 
 impl<T: ?Sized> DerefMut for SpinMutexGuard<'_, T> {
     fn deref_mut(&mut self) -> &mut T {
-        unsafe { &mut *self.lock.value.get() }
+        self.data
     }
 }
 
 impl<'a, T: ?Sized> Drop for SpinMutexGuard<'a, T> {
     fn drop(&mut self) {
-        self.lock.lock.store(false, Ordering::Release)
+        self.lock.store(false, Ordering::Release)
     }
 }
 
@@ -153,6 +176,6 @@ impl RawMutex for SpinMutex<()> {
 
     #[inline]
     unsafe fn unlock(&self) {
-        drop(SpinMutexGuard { lock: self });
+        self.force_unlock();
     }
 }
diff --git a/sgx_trts/src/sync/rwlock.rs b/sgx_trts/src/sync/rwlock.rs
index 26328751..1dbded74 100644
--- a/sgx_trts/src/sync/rwlock.rs
+++ b/sgx_trts/src/sync/rwlock.rs
@@ -27,7 +27,7 @@ use sgx_types::marker::ContiguousMemory;
 
 pub struct SpinRwLock<T: ?Sized> {
     lock: AtomicUsize,
-    value: UnsafeCell<T>,
+    data: UnsafeCell<T>,
 }
 
 const READER: usize = 1 << 1;
@@ -38,8 +38,9 @@ unsafe impl<T: ContiguousMemory> ContiguousMemory for SpinRwLock<T> {}
 unsafe impl<T: ?Sized + Send> Send for SpinRwLock<T> {}
 unsafe impl<T: ?Sized + Send + Sync> Sync for SpinRwLock<T> {}
 
-pub struct SpinRwLockReadGuard<'a, T: ?Sized + 'a> {
-    lock: &'a SpinRwLock<T>,
+pub struct SpinRwLockReadGuard<'a, T: 'a + ?Sized> {
+    lock: &'a AtomicUsize,
+    data: &'a T,
 }
 
 impl<T: ?Sized> !Send for SpinRwLockReadGuard<'_, T> {}
@@ -47,7 +48,8 @@ impl<T: ?Sized> !Send for SpinRwLockReadGuard<'_, T> {}
 unsafe impl<T: ?Sized + Sync> Sync for SpinRwLockReadGuard<'_, T> {}
 
 pub struct SpinRwLockWriteGuard<'a, T: ?Sized + 'a> {
-    lock: &'a SpinRwLock<T>,
+    inner: &'a SpinRwLock<T>,
+    data: &'a mut T,
 }
 
 impl<T: ?Sized> !Send for SpinRwLockWriteGuard<'_, T> {}
@@ -56,17 +58,17 @@ unsafe impl<T: ?Sized + Sync> Sync for SpinRwLockWriteGuard<'_, T> {}
 
 impl<T> SpinRwLock<T> {
     #[inline]
-    pub const fn new(user_data: T) -> SpinRwLock<T> {
+    pub const fn new(data: T) -> SpinRwLock<T> {
         SpinRwLock {
             lock: AtomicUsize::new(0),
-            value: UnsafeCell::new(user_data),
+            data: UnsafeCell::new(data),
         }
     }
 
     #[inline]
     pub fn into_inner(self) -> T {
-        let SpinRwLock { value, .. } = self;
-        value.into_inner()
+        let SpinRwLock { data, .. } = self;
+        data.into_inner()
     }
 }
 
@@ -81,6 +83,16 @@ impl<T: ?Sized> SpinRwLock<T> {
         }
     }
 
+    #[inline]
+    pub fn write(&self) -> SpinRwLockWriteGuard<T> {
+        loop {
+            match self.try_write_internal(false) {
+                Some(guard) => return guard,
+                None => spin_loop(),
+            }
+        }
+    }
+
     #[inline]
     pub fn try_read(&self) -> Option<SpinRwLockReadGuard<T>> {
         let value = self.lock.fetch_add(READER, Ordering::Acquire);
@@ -88,28 +100,34 @@ impl<T: ?Sized> SpinRwLock<T> {
             self.lock.fetch_sub(READER, Ordering::Release);
             None
         } else {
-            Some(SpinRwLockReadGuard { lock: self })
+            Some(SpinRwLockReadGuard {
+                lock: &self.lock,
+                data: unsafe { &*self.data.get() },
+            })
         }
     }
 
     #[inline]
-    pub fn write(&self) -> SpinRwLockWriteGuard<T> {
-        loop {
-            match self.try_write() {
-                Some(guard) => return guard,
-                None => spin_loop(),
-            }
-        }
+    pub fn try_write(&self) -> Option<SpinRwLockWriteGuard<T>> {
+        self.try_write_internal(true)
     }
 
-    #[inline]
-    pub fn try_write(&self) -> Option<SpinRwLockWriteGuard<T>> {
-        if self
-            .lock
-            .compare_exchange(0, WRITER, Ordering::Acquire, Ordering::Relaxed)
-            .is_ok()
+    #[inline(always)]
+    fn try_write_internal(&self, strong: bool) -> Option<SpinRwLockWriteGuard<T>> {
+        if compare_exchange(
+            &self.lock,
+            0,
+            WRITER,
+            Ordering::Acquire,
+            Ordering::Relaxed,
+            strong,
+        )
+        .is_ok()
         {
-            Some(SpinRwLockWriteGuard { lock: self })
+            Some(SpinRwLockWriteGuard {
+                inner: self,
+                data: unsafe { &mut *self.data.get() },
+            })
         } else {
             None
         }
@@ -117,7 +135,7 @@ impl<T: ?Sized> SpinRwLock<T> {
 
     #[inline]
     pub fn get_mut(&mut self) -> &mut T {
-        unsafe { &mut *self.value.get() }
+        unsafe { &mut *self.data.get() }
     }
 
     #[inline]
@@ -155,7 +173,7 @@ impl<T: fmt::Debug> fmt::Debug for SpinRwLockReadGuard<'_, T> {
 impl<T: fmt::Debug> fmt::Debug for SpinRwLockWriteGuard<'_, T> {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         f.debug_struct("SpinRwLockWriteGuard")
-            .field("lock", &self.lock)
+            .field("lock", &self.inner)
             .finish()
     }
 }
@@ -172,61 +190,80 @@ impl<T> From<T> for SpinRwLock<T> {
     }
 }
 
-impl<T: ?Sized> Deref for SpinRwLockReadGuard<'_, T> {
+impl<'rwlock, T: ?Sized> Deref for SpinRwLockReadGuard<'rwlock, T> {
     type Target = T;
 
     fn deref(&self) -> &T {
-        unsafe { &*self.lock.value.get() }
+        self.data
     }
 }
 
-impl<T: ?Sized> Deref for SpinRwLockWriteGuard<'_, T> {
+impl<'rwlock, T: ?Sized> Deref for SpinRwLockWriteGuard<'rwlock, T> {
     type Target = T;
 
     fn deref(&self) -> &T {
-        unsafe { &*self.lock.value.get() }
+        self.data
     }
 }
 
-impl<T: ?Sized> DerefMut for SpinRwLockWriteGuard<'_, T> {
+impl<'rwlock, T: ?Sized> DerefMut for SpinRwLockWriteGuard<'rwlock, T> {
     fn deref_mut(&mut self) -> &mut T {
-        unsafe { &mut *self.lock.value.get() }
+        self.data
     }
 }
 
 impl<T: ?Sized> Drop for SpinRwLockReadGuard<'_, T> {
     fn drop(&mut self) {
-        debug_assert!(self.lock.lock.load(Ordering::Relaxed) & !WRITER > 0);
-        self.lock.lock.fetch_sub(READER, Ordering::Release);
+        debug_assert!(self.lock.load(Ordering::Relaxed) & !WRITER > 0);
+        self.lock.fetch_sub(READER, Ordering::Release);
     }
 }
 
 impl<T: ?Sized> Drop for SpinRwLockWriteGuard<'_, T> {
     fn drop(&mut self) {
-        debug_assert_eq!(self.lock.lock.load(Ordering::Relaxed) & WRITER, WRITER);
-        self.lock.lock.fetch_and(!WRITER, Ordering::Release);
+        debug_assert_eq!(self.inner.lock.load(Ordering::Relaxed) & WRITER, WRITER);
+        self.inner.lock.fetch_and(!WRITER, Ordering::Release);
+    }
+}
+
+#[inline(always)]
+fn compare_exchange(
+    atomic: &AtomicUsize,
+    current: usize,
+    new: usize,
+    success: Ordering,
+    failure: Ordering,
+    strong: bool,
+) -> Result<usize, usize> {
+    if strong {
+        atomic.compare_exchange(current, new, success, failure)
+    } else {
+        atomic.compare_exchange_weak(current, new, success, failure)
     }
 }
 
 impl RawRwLock for SpinRwLock<()> {
-    #[inline]
+    #[inline(always)]
     fn read(&self) {
         mem::forget(self.read());
     }
 
-    #[inline]
+    #[inline(always)]
     fn try_read(&self) -> bool {
         self.try_read().map(mem::forget).is_some()
     }
 
     #[inline]
     unsafe fn read_unlock(&self) {
-        drop(SpinRwLockReadGuard { lock: self });
+        drop(SpinRwLockReadGuard {
+            lock: &self.lock,
+            data: &(),
+        });
     }
 
-    #[inline]
+    #[inline(always)]
     fn write(&self) {
-        mem::forget(self.write());
+        core::mem::forget(self.write());
     }
 
     #[inline]
@@ -235,6 +272,9 @@ impl RawRwLock for SpinRwLock<()> {
     }
 
     unsafe fn write_unlock(&self) {
-        drop(SpinRwLockWriteGuard { lock: self });
+        drop(SpinRwLockWriteGuard {
+            inner: self,
+            data: &mut (),
+        });
     }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to