wiedld commented on code in PR #7379:
URL: https://github.com/apache/arrow-datafusion/pull/7379#discussion_r1312424734


##########
datafusion/core/src/physical_plan/sorts/stream.rs:
##########
@@ -207,3 +233,237 @@ impl<T: FieldArray> PartitionedStream for FieldCursorStream<T> {
         }))
     }
 }
+
+/// A wrapper around [`CursorStream<C>`] that implements [`PartitionedStream`]
+/// and provides polling of a subset of the streams.
+pub struct OffsetCursorStream<C: Cursor> {
+    streams: Arc<Mutex<BatchTrackingStream<C>>>,
+    offset: usize,
+    limit: usize,
+}
+
+impl<C: Cursor> OffsetCursorStream<C> {
+    pub fn new(
+        streams: Arc<Mutex<BatchTrackingStream<C>>>,
+        offset: usize,
+        limit: usize,
+    ) -> Self {
+        Self {
+            streams,
+            offset,
+            limit,
+        }
+    }
+}
+
+impl<C: Cursor> PartitionedStream for OffsetCursorStream<C> {
+    type Output = Result<(C, Uuid, BatchOffset)>;
+
+    fn partitions(&self) -> usize {
+        self.limit - self.offset
+    }
+
+    fn poll_next(
+        &mut self,
+        cx: &mut Context<'_>,
+        stream_idx: usize,
+    ) -> Poll<Option<Self::Output>> {
+        let stream_abs_idx = stream_idx + self.offset;
+        if stream_abs_idx >= self.limit {
+            return Poll::Ready(Some(Err(DataFusionError::Internal(format!(
+                "Invalid stream index {} for offset {} and limit {}",
+                stream_idx, self.offset, self.limit
+            )))));
+        }
+        Poll::Ready(ready!(self.streams.lock().poll_next(cx, stream_abs_idx)))
+    }
+}
+
+impl<C: Cursor> std::fmt::Debug for OffsetCursorStream<C> {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("OffsetCursorStream").finish()
+    }
+}
+
+pub struct BatchTrackingStream<C: Cursor> {
+    /// Write once, read many [`RecordBatch`]s
+    batches: HashMap<Uuid, Arc<RecordBatch>, RandomState>,
+    /// Input streams yielding [`Cursor`]s and [`RecordBatch`]es
+    streams: BatchCursorStream<C>,
+    /// Accounts for memory used by buffered batches
+    reservation: MemoryReservation,
+}
+
+impl<C: Cursor> BatchTrackingStream<C> {
+    pub fn new(streams: BatchCursorStream<C>, reservation: MemoryReservation) -> Self {
+        Self {
+            batches: HashMap::with_hasher(RandomState::new()),
+            streams,
+            reservation,
+        }
+    }
+
+    pub fn get_batches(&self, batch_ids: &[Uuid]) -> Vec<Arc<RecordBatch>> {
+        batch_ids.iter().map(|id| self.batches[id].clone()).collect()
+    }
+
+    pub fn remove_batches(&mut self, batch_ids: &[Uuid]) {
+        for id in batch_ids {
+            self.batches.remove(id);
+        }
+    }
+
+    fn poll_next(
+        &mut self,
+        cx: &mut Context<'_>,
+        stream_idx: usize,
+    ) -> Poll<Option<Result<(C, Uuid, BatchOffset)>>> {
+        Poll::Ready(ready!(self.streams.poll_next(cx, stream_idx)).map(|r| {
+            r.and_then(|(cursor, batch)| {
+                self.reservation.try_grow(batch.get_array_memory_size())?;
+                let batch_id = Uuid::new_v4();
+                self.batches.insert(batch_id, Arc::new(batch));
+                Ok((cursor, batch_id, BatchOffset(0_usize)))
+            })
+        }))
+    }
+}
+
+/// A newtype wrapper around a set of fused [`MergeStream`]
+/// that implements debug, and skips over empty inner poll results
+struct FusedMergeStreams<C>(Vec<Fuse<MergeStream<C>>>);
+
+impl<C: Cursor> FusedMergeStreams<C> {
+    fn poll_next(
+        &mut self,
+        cx: &mut Context<'_>,
+        stream_idx: usize,
+    ) -> Poll<Option<Result<(Vec<(C, Uuid, BatchOffset)>, Vec<SortOrder>)>>>
+    {
+        loop {
+            match ready!(self.0[stream_idx].poll_next_unpin(cx)) {
+                Some(Ok((_, sort_order))) if sort_order.len() == 0 => continue,
+                r => return Poll::Ready(r),
+            }
+        }
+    }
+}
+
+pub struct YieldedCursorStream<C: Cursor> {
+    // inner polled batch cursors, per stream_idx, which are partially yielded
+    cursors: Vec<Option<VecDeque<(C, Uuid, BatchOffset)>>>,
+    /// Streams being polled
+    streams: FusedMergeStreams<C>,
+}
+
+impl<C: Cursor + std::marker::Send> YieldedCursorStream<C> {
+    pub fn new(streams: Vec<MergeStream<C>>) -> Self {
+        let stream_cnt = streams.len();
+        Self {
+            cursors: (0..stream_cnt).map(|_| None).collect(),
+            streams: FusedMergeStreams(streams.into_iter().map(|s| s.fuse()).collect()),
+        }
+    }
+
+    fn incr_next_batch(
+        &mut self,
+        stream_idx: usize,
+    ) -> Option<(C, Uuid, BatchOffset)> {
+        self.cursors[stream_idx]
+            .as_mut()
+            .map(|queue| queue.pop_front())
+            .flatten()
+    }
+
+    // TODO: in order to handle sort_order, we need to either:
+    // parse further
+    // or concat the cursors
+    fn try_parse_batches(
+        &mut self,
+        stream_idx: usize,
+        cursors: Vec<(C, Uuid, BatchOffset)>,
+        sort_order: Vec<SortOrder>,
+    ) -> Result<()> {
+        let mut cursors_per_batch: HashMap<(Uuid, BatchOffset), C, RandomState> =
+            HashMap::with_capacity_and_hasher(cursors.len(), RandomState::new());

Review Comment:
   It did. Down to 3.8Gc on the same test case. TY!



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to