This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new b87f210dbd fix: reduce lock contention in distributor channels (#10026)
b87f210dbd is described below

commit b87f210dbdd90e5f65caefac1eeb053b0f0f612e
Author: Marco Neumann <[email protected]>
AuthorDate: Thu Apr 25 14:31:55 2024 +0200

    fix: reduce lock contention in distributor channels (#10026)
    
    * fix: lock contention in distributor channels
    
    Reduce lock contention in distributor channels via:
    
    - use atomic counters instead of "counter behind mutex" where
      appropriate
    - use less state
    - only lock when needed
    - move "wake" operation out of lock scopes (they are eventual operations
      anyways and many wake operations results in "futex wake" operations --
      i.e. a syscall -- which you should avoid while holding the lock)
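
    The last bullet is the central pattern of this patch: while holding a
    lock, only swap state out; drop the guard; wake afterwards. Below is a
    minimal sketch of the pattern (hypothetical names, and std's Mutex
    instead of the parking_lot one used in the actual code):

        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Mutex;
        use std::task::Waker;

        struct Shared {
            // "counter behind mutex" replaced by an atomic, so the
            // hot-path increment needs no lock at all
            n_senders: AtomicUsize,
            wakers: Mutex<Vec<Waker>>,
        }

        impl Shared {
            fn add_sender(&self) {
                self.n_senders.fetch_add(1, Ordering::SeqCst);
            }

            fn notify(&self) {
                // lock scope: just take the wakers, do no real work here
                let to_wake = {
                    let mut guard = self.wakers.lock().unwrap();
                    std::mem::take(&mut *guard)
                };
                // guard is dropped here; `wake()` may trigger a "futex
                // wake" syscall and must not run under the lock
                for waker in to_wake {
                    waker.wake();
                }
            }
        }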
    
    * refactor: add more docs and tests for distributor channels
    
    ---------
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 .../src/repartition/distributor_channels.rs        | 358 ++++++++++++++-------
 1 file changed, 245 insertions(+), 113 deletions(-)

diff --git a/datafusion/physical-plan/src/repartition/distributor_channels.rs b/datafusion/physical-plan/src/repartition/distributor_channels.rs
index e71b88467b..bad923ce9e 100644
--- a/datafusion/physical-plan/src/repartition/distributor_channels.rs
+++ b/datafusion/physical-plan/src/repartition/distributor_channels.rs
@@ -40,8 +40,12 @@
 use std::{
     collections::VecDeque,
     future::Future,
+    ops::DerefMut,
     pin::Pin,
-    sync::Arc,
+    sync::{
+        atomic::{AtomicUsize, Ordering},
+        Arc,
+    },
     task::{Context, Poll, Waker},
 };
 
@@ -52,20 +56,12 @@ pub fn channels<T>(
     n: usize,
 ) -> (Vec<DistributionSender<T>>, Vec<DistributionReceiver<T>>) {
     let channels = (0..n)
-        .map(|id| {
-            Arc::new(Mutex::new(Channel {
-                data: VecDeque::default(),
-                n_senders: 1,
-                recv_alive: true,
-                recv_wakers: Vec::default(),
-                id,
-            }))
-        })
+        .map(|id| Arc::new(Channel::new_with_one_sender(id)))
         .collect::<Vec<_>>();
-    let gate = Arc::new(Mutex::new(Gate {
-        empty_channels: n,
-        send_wakers: Vec::default(),
-    }));
+    let gate = Arc::new(Gate {
+        empty_channels: AtomicUsize::new(n),
+        send_wakers: Mutex::new(None),
+    });
     let senders = channels
         .iter()
         .map(|channel| DistributionSender {
@@ -143,8 +139,7 @@ impl<T> DistributionSender<T> {
 
 impl<T> Clone for DistributionSender<T> {
     fn clone(&self) -> Self {
-        let mut guard = self.channel.lock();
-        guard.n_senders += 1;
+        self.channel.n_senders.fetch_add(1, Ordering::SeqCst);
 
         Self {
             channel: Arc::clone(&self.channel),
@@ -155,19 +150,46 @@ impl<T> Clone for DistributionSender<T> {
 
 impl<T> Drop for DistributionSender<T> {
     fn drop(&mut self) {
-        let mut guard_channel = self.channel.lock();
-        guard_channel.n_senders -= 1;
+        let n_senders_pre = self.channel.n_senders.fetch_sub(1, Ordering::SeqCst);
+        // is this the last copy of the sender side?
+        if n_senders_pre > 1 {
+            return;
+        }
 
-        if guard_channel.n_senders == 0 {
-            // Note: the recv_alive check is so that we don't double-clear the status
-            if guard_channel.data.is_empty() && guard_channel.recv_alive {
+        let receivers = {
+            let mut state = self.channel.state.lock();
+
+            // During the shutdown of an empty channel, both the sender and the receiver side will be dropped. However we
+            // only want to decrement the "empty channels" counter once.
+            //
+            // We are within a critical section here, so we can safely assume that either the last sender or the
+            // receiver (there's only one) will be dropped first.
+            //
+            // If the last sender is dropped first, `state.data` will still exist and the sender side decrements the
+            // signal. The receiver side then MUST check the `n_senders` counter during the section and if it is zero,
+            // it infers that it is dropped afterwards and MUST NOT decrement the counter.
+            //
+            // If the receiver end is dropped first, it will infer -- based on `n_senders` -- that there are still
+            // senders and it will decrement the `empty_channels` counter. It will also set `data` to `None`. The sender
+            // side will then see that `data` is `None` and can therefore infer that the receiver end was dropped, and
+            // hence it MUST NOT decrement the `empty_channels` counter.
+            if state
+                .data
+                .as_ref()
+                .map(|data| data.is_empty())
+                .unwrap_or_default()
+            {
                 // channel is gone, so we need to clear our signal
-                let mut guard_gate = self.gate.lock();
-                guard_gate.empty_channels -= 1;
+                self.gate.decr_empty_channels();
             }
 
-            // receiver may be waiting for data, but should return `None` now since the channel is closed
-            guard_channel.wake_receivers();
+            // make sure that nobody can add wakers anymore
+            state.recv_wakers.take().expect("not closed yet")
+        };
+
+        // wake outside of lock scope
+        for recv in receivers {
+            recv.wake();
         }
     }
 }
@@ -188,33 +210,41 @@ impl<'a, T> Future for SendFuture<'a, T> {
         let this = &mut *self;
         assert!(this.element.is_some(), "polled ready future");
 
-        let mut guard_channel = this.channel.lock();
-
-        // receiver end still alive?
-        if !guard_channel.recv_alive {
-            return Poll::Ready(Err(SendError(
-                this.element.take().expect("just checked"),
-            )));
-        }
+        // lock scope
+        let to_wake = {
+            let mut guard_channel_state = this.channel.state.lock();
+
+            let Some(data) = guard_channel_state.data.as_mut() else {
+                // receiver end dead
+                return Poll::Ready(Err(SendError(
+                    this.element.take().expect("just checked"),
+                )));
+            };
+
+            // does ANY receiver need data?
+            // if so, allow sender to create another
+            if this.gate.empty_channels.load(Ordering::SeqCst) == 0 {
+                let mut guard = this.gate.send_wakers.lock();
+                if let Some(send_wakers) = guard.deref_mut() {
+                    send_wakers.push((cx.waker().clone(), this.channel.id));
+                    return Poll::Pending;
+                }
+            }
 
-        let mut guard_gate = this.gate.lock();
+            let was_empty = data.is_empty();
+            data.push_back(this.element.take().expect("just checked"));
 
-        // does ANY receiver need data?
-        // if so, allow sender to create another
-        if guard_gate.empty_channels == 0 {
-            guard_gate
-                .send_wakers
-                .push((cx.waker().clone(), guard_channel.id));
-            return Poll::Pending;
-        }
+            if was_empty {
+                this.gate.decr_empty_channels();
+                guard_channel_state.take_recv_wakers()
+            } else {
+                Vec::with_capacity(0)
+            }
+        };
 
-        let was_empty = guard_channel.data.is_empty();
-        guard_channel
-            .data
-            .push_back(this.element.take().expect("just checked"));
-        if was_empty {
-            guard_gate.empty_channels -= 1;
-            guard_channel.wake_receivers();
+        // wake outside of lock scope
+        for receiver in to_wake {
+            receiver.wake();
         }
 
         Poll::Ready(Ok(()))
@@ -243,21 +273,18 @@ impl<T> DistributionReceiver<T> {
 
 impl<T> Drop for DistributionReceiver<T> {
     fn drop(&mut self) {
-        let mut guard_channel = self.channel.lock();
-        let mut guard_gate = self.gate.lock();
-        guard_channel.recv_alive = false;
+        let mut guard_channel_state = self.channel.state.lock();
+        let data = guard_channel_state.data.take().expect("not dropped yet");
 
-        // Note: n_senders check is here so we don't double-clear the signal
-        if guard_channel.data.is_empty() && (guard_channel.n_senders > 0) {
+        // See `DistributionSender::drop` for an explanation of the drop order and when the "empty channels" counter is
+        // decremented.
+        if data.is_empty() && (self.channel.n_senders.load(Ordering::SeqCst) > 0) {
             // channel is gone, so we need to clear our signal
-            guard_gate.empty_channels -= 1;
+            self.gate.decr_empty_channels();
         }
 
-        // senders may be waiting for gate to open but should error now that the channel is closed
-        guard_gate.wake_channel_senders(guard_channel.id);
-
-        // clear potential remaining data from channel
-        guard_channel.data.clear();
+        self.gate.wake_channel_senders(self.channel.id);
     }
 }
 
@@ -275,37 +302,51 @@ impl<'a, T> Future for RecvFuture<'a, T> {
         let this = &mut *self;
         assert!(!this.rdy, "polled ready future");
 
-        let mut guard_channel = this.channel.lock();
+        let mut guard_channel_state = this.channel.state.lock();
+        let channel_state = guard_channel_state.deref_mut();
+        let data = channel_state.data.as_mut().expect("not dropped yet");
 
-        match guard_channel.data.pop_front() {
+        match data.pop_front() {
             Some(element) => {
                 // change "empty" signal for this channel?
-                if guard_channel.data.is_empty() && (guard_channel.n_senders > 0) {
-                    let mut guard_gate = this.gate.lock();
-
+                if data.is_empty() && channel_state.recv_wakers.is_some() {
                     // update counter
-                    let old_counter = guard_gate.empty_channels;
-                    guard_gate.empty_channels += 1;
+                    let old_counter =
+                        this.gate.empty_channels.fetch_add(1, Ordering::SeqCst);
 
                     // open gate?
-                    if old_counter == 0 {
-                        guard_gate.wake_all_senders();
+                    let to_wake = if old_counter == 0 {
+                        let mut guard = this.gate.send_wakers.lock();
+
+                        // check after lock to see if we should still change the state
+                        if this.gate.empty_channels.load(Ordering::SeqCst) > 0 {
+                            guard.take().unwrap_or_default()
+                        } else {
+                            Vec::with_capacity(0)
+                        }
+                    } else {
+                        Vec::with_capacity(0)
+                    };
+
+                    drop(guard_channel_state);
+
+                    // wake outside of lock scope
+                    for (waker, _channel_id) in to_wake {
+                        waker.wake();
                     }
-
-                    drop(guard_gate);
-                    drop(guard_channel);
                 }
 
                 this.rdy = true;
                 Poll::Ready(Some(element))
             }
-            None if guard_channel.n_senders == 0 => {
-                this.rdy = true;
-                Poll::Ready(None)
-            }
             None => {
-                guard_channel.recv_wakers.push(cx.waker().clone());
-                Poll::Pending
+                if let Some(recv_wakers) = channel_state.recv_wakers.as_mut() {
+                    recv_wakers.push(cx.waker().clone());
+                    Poll::Pending
+                } else {
+                    this.rdy = true;
+                    Poll::Ready(None)
+                }
             }
         }
     }
@@ -314,78 +355,122 @@ impl<'a, T> Future for RecvFuture<'a, T> {
 /// Links senders and receivers.
 #[derive(Debug)]
 struct Channel<T> {
-    /// Buffered data.
-    data: VecDeque<T>,
-
     /// Reference counter for the sender side.
-    n_senders: usize,
-
-    /// Reference "counter"/flag for the single receiver.
-    recv_alive: bool,
-
-    /// Wakers for the receiver side.
-    ///
-    /// The receiver will be pending if the [buffer](Self::data) is empty and
-    /// there are senders left (according to the [reference counter](Self::n_senders)).
-    recv_wakers: Vec<Waker>,
+    n_senders: AtomicUsize,
 
     /// Channel ID.
     ///
     /// This is used to address [send wakers](Gate::send_wakers).
     id: usize,
+
+    /// Mutable state.
+    state: Mutex<ChannelState<T>>,
 }
 
 impl<T> Channel<T> {
-    fn wake_receivers(&mut self) {
-        for waker in self.recv_wakers.drain(..) {
-            waker.wake();
+    /// Create a new channel with one sender (so we don't need to [fetch-add](AtomicUsize::fetch_add) directly afterwards).
+    fn new_with_one_sender(id: usize) -> Self {
+        Channel {
+            n_senders: AtomicUsize::new(1),
+            id,
+            state: Mutex::new(ChannelState {
+                data: Some(VecDeque::default()),
+                recv_wakers: Some(Vec::default()),
+            }),
         }
     }
 }
 
+#[derive(Debug)]
+struct ChannelState<T> {
+    /// Buffered data.
+    ///
+    /// This is [`None`] when the receiver is gone.
+    data: Option<VecDeque<T>>,
+
+    /// Wakers for the receiver side.
+    ///
+    /// The receiver will be pending if the [buffer](Self::data) is empty and
+    /// there are senders left (otherwise this is set to [`None`]).
+    recv_wakers: Option<Vec<Waker>>,
+}
+
+impl<T> ChannelState<T> {
+    /// Get all [`recv_wakers`](Self::recv_wakers) and replace them with an identically-sized buffer.
+    ///
+    /// The wakers should be woken AFTER the lock on [this state](Self) was dropped.
+    ///
+    /// # Panics
+    /// Assumes that the channel is NOT closed yet, i.e. that [`recv_wakers`](Self::recv_wakers) is not [`None`].
+    fn take_recv_wakers(&mut self) -> Vec<Waker> {
+        let to_wake = self.recv_wakers.as_mut().expect("not closed");
+        let mut tmp = Vec::with_capacity(to_wake.capacity());
+        std::mem::swap(to_wake, &mut tmp);
+        tmp
+    }
+}
+
 /// Shared channel.
 ///
 /// One or multiple senders and a single receiver will share a channel.
-type SharedChannel<T> = Arc<Mutex<Channel<T>>>;
+type SharedChannel<T> = Arc<Channel<T>>;
 
 /// The "all channels have data" gate.
 #[derive(Debug)]
 struct Gate {
     /// Number of currently empty (and still open) channels.
-    empty_channels: usize,
+    empty_channels: AtomicUsize,
 
     /// Wakers for the sender side, including their channel IDs.
-    send_wakers: Vec<(Waker, usize)>,
+    ///
+    /// This is `None` if there are empty channels.
+    send_wakers: Mutex<Option<Vec<(Waker, usize)>>>,
 }
 
 impl Gate {
-    //// Wake all senders.
+    /// Wake senders for a specific channel.
     ///
-    /// This is helpful to signal that there are some channels empty now and hence the gate was opened.
-    fn wake_all_senders(&mut self) {
-        for (waker, _id) in self.send_wakers.drain(..) {
+    /// This is helpful to signal that the receiver side is gone and the senders shall now error.
+    fn wake_channel_senders(&self, id: usize) {
+        // lock scope
+        let to_wake = {
+            let mut guard = self.send_wakers.lock();
+
+            if let Some(send_wakers) = guard.deref_mut() {
+                // `drain_filter` is unstable, so implement our own
+                let (wake, keep) =
+                    send_wakers.drain(..).partition(|(_waker, id2)| id == *id2);
+
+                *send_wakers = keep;
+
+                wake
+            } else {
+                Vec::with_capacity(0)
+            }
+        };
+
+        // wake outside of lock scope
+        for (waker, _id) in to_wake {
             waker.wake();
         }
     }
 
-    /// Wake senders for a specific channel.
-    ///
-    /// This is helpful to signal that the receiver side is gone and the 
senders shall now error.
-    fn wake_channel_senders(&mut self, id: usize) {
-        // `drain_filter` is unstable, so implement our own
-        let (wake, keep) = self
-            .send_wakers
-            .drain(..)
-            .partition(|(_waker, id2)| id == *id2);
-        self.send_wakers = keep;
-        for (waker, _id) in wake {
-            waker.wake();
+    fn decr_empty_channels(&self) {
+        let old_count = self.empty_channels.fetch_sub(1, Ordering::SeqCst);
+
+        if old_count == 1 {
+            let mut guard = self.send_wakers.lock();
+
+            // double-check state during lock
+            if self.empty_channels.load(Ordering::SeqCst) == 0 && guard.is_none() {
+                *guard = Some(Vec::new());
+            }
         }
     }
 }
 
 /// Gate shared by all senders and receivers.
-type SharedGate = Arc<Mutex<Gate>>;
+type SharedGate = Arc<Gate>;
 
 #[cfg(test)]
 mod tests {
@@ -596,6 +681,52 @@ mod tests {
         assert_eq!(counter.strong_count(), 0);
     }
 
+    /// Ensure that polling "pending" futures work even when you poll them too 
often (which happens under some circumstances).
+    #[test]
+    fn test_poll_empty_channel_twice() {
+        let (txs, mut rxs) = channels(1);
+
+        let mut recv_fut = rxs[0].recv();
+        let waker_1a = poll_pending(&mut recv_fut);
+        let waker_1b = poll_pending(&mut recv_fut);
+
+        let mut recv_fut = rxs[0].recv();
+        let waker_2 = poll_pending(&mut recv_fut);
+
+        poll_ready(&mut txs[0].send("a")).unwrap();
+        assert!(waker_1a.woken());
+        assert!(waker_1b.woken());
+        assert!(waker_2.woken());
+        assert_eq!(poll_ready(&mut recv_fut), Some("a"),);
+
+        poll_ready(&mut txs[0].send("b")).unwrap();
+        let mut send_fut = txs[0].send("c");
+        let waker_3 = poll_pending(&mut send_fut);
+        assert_eq!(poll_ready(&mut rxs[0].recv()), Some("b"),);
+        assert!(waker_3.woken());
+        poll_ready(&mut send_fut).unwrap();
+        assert_eq!(poll_ready(&mut rxs[0].recv()), Some("c"));
+
+        let mut recv_fut = rxs[0].recv();
+        let waker_4 = poll_pending(&mut recv_fut);
+
+        let mut recv_fut = rxs[0].recv();
+        let waker_5 = poll_pending(&mut recv_fut);
+
+        poll_ready(&mut txs[0].send("d")).unwrap();
+        let mut send_fut = txs[0].send("e");
+        let waker_6a = poll_pending(&mut send_fut);
+        let waker_6b = poll_pending(&mut send_fut);
+
+        assert!(waker_4.woken());
+        assert!(waker_5.woken());
+        assert_eq!(poll_ready(&mut recv_fut), Some("d"),);
+
+        assert!(waker_6a.woken());
+        assert!(waker_6b.woken());
+        poll_ready(&mut send_fut).unwrap();
+    }
+
     #[test]
     #[should_panic(expected = "polled ready future")]
     fn test_panic_poll_send_future_after_ready_ok() {
@@ -655,6 +786,7 @@ mod tests {
         poll_pending(&mut fut);
     }
 
+    /// Test [`poll_pending`] (i.e. the testing utils, not the actual library code).
     #[test]
     fn test_meta_poll_pending_waker() {
         let (tx, mut rx) = futures::channel::oneshot::channel();

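For reference, a minimal sketch of how the patched API is exercised
(modeled on the tests above; the distributor_channels module is internal
to datafusion-physical-plan, so this will not compile outside that crate):

    async fn demo() {
        // one sender/receiver pair per output partition
        let (txs, mut rxs) = channels::<&str>(2);

        // send() resolves as long as at least one channel is still empty
        txs[0].send("a").await.unwrap();
        txs[1].send("b").await.unwrap();

        // each receiver drains its own channel
        assert_eq!(rxs[0].recv().await, Some("a"));
        assert_eq!(rxs[1].recv().await, Some("b"));

        // dropping all senders closes the channels
        drop(txs);
        assert_eq!(rxs[0].recv().await, None);
    }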
