This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/sedona-db.git


The following commit(s) were added to refs/heads/main by this push:
     new bfb46a6  Making scan limit of RecordBatchReaderExec handle row limit efficiently (#36)
bfb46a6 is described below

commit bfb46a6d4da9491bd85b897217fe9386ce68c549
Author: Kristin Cowalcijk <[email protected]>
AuthorDate: Mon Sep 8 15:42:24 2025 +0800

    Making scan limit of RecordBatchReaderExec handle row limit efficiently (#36)
    
    * Making limit of RecordBatchReaderExec handle row limit efficiently
    
    * Add test cases for scan limit in RecordBatchReaderExec
---
 rust/sedona/src/record_batch_reader_provider.rs | 226 +++++++++++++++++++++---
 1 file changed, 197 insertions(+), 29 deletions(-)

diff --git a/rust/sedona/src/record_batch_reader_provider.rs b/rust/sedona/src/record_batch_reader_provider.rs
index 1d90a8e..1832e93 100644
--- a/rust/sedona/src/record_batch_reader_provider.rs
+++ b/rust/sedona/src/record_batch_reader_provider.rs
@@ -14,7 +14,6 @@
 // KIND, either express or implied.  See the License for the
 // specific language governing permissions and limitations
 // under the License.
-use std::sync::RwLock;
 use std::{any::Any, fmt::Debug, sync::Arc};
 
 use arrow_array::RecordBatchReader;
@@ -33,6 +32,7 @@ use datafusion::{
     prelude::Expr,
 };
 use datafusion_common::DataFusionError;
+use parking_lot::Mutex;
 use sedona_common::sedona_internal_err;
 
 /// A [TableProvider] wrapping a [RecordBatchReader]
@@ -42,7 +42,7 @@ use sedona_common::sedona_internal_err;
 /// such that extension types are preserved in DataFusion internals (i.e.,
 /// it is intended for scanning external tables as SedonaDB).
 pub struct RecordBatchReaderProvider {
-    reader: RwLock<Option<Box<dyn RecordBatchReader + Send>>>,
+    reader: Mutex<Option<Box<dyn RecordBatchReader + Send>>>,
     schema: SchemaRef,
 }
 
@@ -52,7 +52,7 @@ impl RecordBatchReaderProvider {
     pub fn new(reader: Box<dyn RecordBatchReader + Send>) -> Self {
         let schema = reader.schema();
         Self {
-            reader: RwLock::new(Some(reader)),
+            reader: Mutex::new(Some(reader)),
             schema,
         }
     }
@@ -88,10 +88,8 @@ impl TableProvider for RecordBatchReaderProvider {
         _filters: &[Expr],
         limit: Option<usize>,
     ) -> Result<Arc<dyn ExecutionPlan>> {
-        let mut writable_reader = self.reader.try_write().map_err(|_| {
-            DataFusionError::Internal("Failed to acquire lock on RecordBatchReader".to_string())
-        })?;
-        if let Some(reader) = writable_reader.take() {
+        let mut reader_guard = self.reader.lock();
+        if let Some(reader) = reader_guard.take() {
             Ok(Arc::new(RecordBatchReaderExec::new(reader, limit)))
         } else {
             sedona_internal_err!("Can't scan RecordBatchReader provider more than once")
@@ -99,15 +97,69 @@ impl TableProvider for RecordBatchReaderProvider {
     }
 }
 
+/// An iterator that limits the number of rows from a RecordBatchReader
+struct RowLimitedIterator {
+    reader: Option<Box<dyn RecordBatchReader + Send>>,
+    limit: usize,
+    rows_consumed: usize,
+}
+
+impl RowLimitedIterator {
+    fn new(reader: Box<dyn RecordBatchReader + Send>, limit: usize) -> Self {
+        Self {
+            reader: Some(reader),
+            limit,
+            rows_consumed: 0,
+        }
+    }
+}
+
+impl Iterator for RowLimitedIterator {
+    type Item = Result<arrow_array::RecordBatch>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        // Check if we have already consumed enough rows
+        if self.rows_consumed >= self.limit {
+            self.reader = None;
+            return None;
+        }
+
+        let reader = self.reader.as_mut()?;
+        match reader.next() {
+            Some(Ok(batch)) => {
+                let batch_rows = batch.num_rows();
+
+                if self.rows_consumed + batch_rows <= self.limit {
+                    // Batch fits within limit, consume it entirely
+                    self.rows_consumed += batch_rows;
+                    Some(Ok(batch))
+                } else {
+                    // Batch would exceed limit, need to truncate it
+                    let rows_to_take = self.limit - self.rows_consumed;
+                    self.rows_consumed = self.limit;
+                    self.reader = None;
+                    Some(Ok(batch.slice(0, rows_to_take)))
+                }
+            }
+            Some(Err(e)) => {
+                self.reader = None;
+                Some(Err(DataFusionError::from(e)))
+            }
+            None => {
+                self.reader = None;
+                None
+            }
+        }
+    }
+}
+
 struct RecordBatchReaderExec {
-    reader: RwLock<Option<Box<dyn RecordBatchReader + Send>>>,
+    reader: Mutex<Option<Box<dyn RecordBatchReader + Send>>>,
     schema: SchemaRef,
     properties: PlanProperties,
     limit: Option<usize>,
 }
 
-unsafe impl Sync for RecordBatchReaderExec {}
-
 impl RecordBatchReaderExec {
     fn new(reader: Box<dyn RecordBatchReader + Send>, limit: Option<usize>) -> Self {
         let schema = reader.schema();
@@ -119,7 +171,7 @@ impl RecordBatchReaderExec {
         );
 
         Self {
-            reader: RwLock::new(Some(reader)),
+            reader: Mutex::new(Some(reader)),
             schema,
             properties,
             limit,
@@ -177,29 +229,35 @@ impl ExecutionPlan for RecordBatchReaderExec {
         _partition: usize,
         _context: Arc<TaskContext>,
     ) -> Result<SendableRecordBatchStream> {
-        let mut writable_reader = self.reader.try_write().map_err(|_| {
-            DataFusionError::Internal("Failed to acquire lock on RecordBatchReader".to_string())
-        })?;
+        let mut reader_guard = self.reader.lock();
 
-        let reader = if let Some(reader) = writable_reader.take() {
+        let reader = if let Some(reader) = reader_guard.take() {
             reader
         } else {
             return sedona_internal_err!("Can't scan RecordBatchReader provider more than once");
         };
 
-        let limit = self.limit;
-
-        // Create a stream from the RecordBatchReader iterator
-        let iter = reader
-            .map(|item| match item {
-                Ok(batch) => Ok(batch),
-                Err(e) => Err(DataFusionError::from(e)),
-            })
-            .take(limit.unwrap_or(usize::MAX));
-
-        let stream = Box::pin(futures::stream::iter(iter));
-        let record_batch_stream = RecordBatchStreamAdapter::new(self.schema.clone(), stream);
-        Ok(Box::pin(record_batch_stream))
+        match self.limit {
+            Some(limit) => {
+                // Create a row-limited iterator that properly handles row counting
+                let iter = RowLimitedIterator::new(reader, limit);
+                let stream = Box::pin(futures::stream::iter(iter));
+                let record_batch_stream =
+                    RecordBatchStreamAdapter::new(self.schema.clone(), stream);
+                Ok(Box::pin(record_batch_stream))
+            }
+            None => {
+                // No limit, just convert the reader directly to a stream
+                let iter = reader.map(|item| match item {
+                    Ok(batch) => Ok(batch),
+                    Err(e) => Err(DataFusionError::from(e)),
+                });
+                let stream = Box::pin(futures::stream::iter(iter));
+                let record_batch_stream =
+                    RecordBatchStreamAdapter::new(self.schema.clone(), stream);
+                Ok(Box::pin(record_batch_stream))
+            }
+        }
     }
 }
 
@@ -208,12 +266,40 @@ mod test {
 
     use arrow_array::{RecordBatch, RecordBatchIterator};
     use arrow_schema::{DataType, Field, Schema};
-    use datafusion::prelude::SessionContext;
+    use datafusion::prelude::{DataFrame, SessionContext};
+    use rstest::rstest;
     use sedona_schema::datatypes::WKB_GEOMETRY;
     use sedona_testing::create::create_array_storage;
 
     use super::*;
 
+    fn create_test_batch(size: usize, start_id: i32) -> RecordBatch {
+        let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]);
+        let ids: Vec<i32> = (start_id..start_id + size as i32).collect();
+        RecordBatch::try_new(
+            Arc::new(schema),
+            vec![Arc::new(arrow_array::Int32Array::from(ids))],
+        )
+        .unwrap()
+    }
+
+    fn create_test_reader(batch_sizes: Vec<usize>) -> Box<dyn RecordBatchReader + Send> {
+        let mut start_id = 0i32;
+        let batches: Vec<RecordBatch> = batch_sizes
+            .into_iter()
+            .map(|size| {
+                let batch = create_test_batch(size, start_id);
+                start_id += size as i32;
+                batch
+            })
+            .collect();
+        let schema = batches[0].schema();
+        Box::new(RecordBatchIterator::new(
+            batches.into_iter().map(Ok),
+            schema,
+        ))
+    }
+
     #[tokio::test]
     async fn provider() {
         let ctx = SessionContext::new();
@@ -244,4 +330,86 @@ mod test {
         let results = df.collect().await.unwrap();
         assert_eq!(results, vec![batch])
     }
+
+    #[rstest]
+    #[case(vec![10, 20, 30], None, 60)] // No limit
+    #[case(vec![10, 20, 30], Some(5), 5)] // Limit within first batch
+    #[case(vec![10, 20, 30], Some(10), 10)] // Limit exactly at first batch boundary
+    #[case(vec![10, 20, 30], Some(15), 15)] // Limit within second batch
+    #[case(vec![10, 20, 30], Some(30), 30)] // Limit at second batch boundary
+    #[case(vec![10, 20, 30], Some(45), 45)] // Limit within third batch
+    #[case(vec![10, 20, 30], Some(60), 60)] // Limit at total rows
+    #[case(vec![10, 20, 30], Some(100), 60)] // Limit exceeds total rows
+    #[case(vec![0, 5, 0, 3], Some(6), 6)] // Empty batches mixed in, limit within data
+    #[case(vec![0, 5, 0, 3], Some(8), 8)] // Empty batches mixed in, limit equals total
+    #[case(vec![0, 5, 0, 3], None, 8)] // Empty batches mixed in, no limit
+    #[tokio::test]
+    async fn test_scan_with_row_limit(
+        #[case] batch_sizes: Vec<usize>,
+        #[case] limit: Option<usize>,
+        #[case] expected_rows: usize,
+    ) {
+        let ctx = SessionContext::new();
+
+        // Verify that the RecordBatchReaderExec node in the execution plan should contain the correct limit
+        let physical_plan = read_test_table_with_limit(&ctx, batch_sizes.clone(), limit)
+            .unwrap()
+            .create_physical_plan()
+            .await
+            .unwrap();
+        let reader_exec = find_record_batch_reader_exec(physical_plan.as_ref())
+            .expect("The plan should contain RecordBatchReaderExec");
+        assert_eq!(reader_exec.limit, limit);
+
+        let df = read_test_table_with_limit(&ctx, batch_sizes, limit).unwrap();
+        let results = df.collect().await.unwrap();
+        let total_rows: usize = results.iter().map(|batch| batch.num_rows()).sum();
+        assert_eq!(total_rows, expected_rows);
+
+        // Verify row values are correct (sequential IDs starting from 0)
+        if expected_rows > 0 {
+            let mut expected_id = 0i32;
+            for batch in results.iter() {
+                let id_array = batch
+                    .column(0)
+                    .as_any()
+                    .downcast_ref::<arrow_array::Int32Array>()
+                    .unwrap();
+                for i in 0..id_array.len() {
+                    assert_eq!(id_array.value(i), expected_id);
+                    expected_id += 1;
+                }
+            }
+        }
+    }
+
+    fn read_test_table_with_limit(
+        ctx: &SessionContext,
+        batch_sizes: Vec<usize>,
+        limit: Option<usize>,
+    ) -> Result<DataFrame> {
+        let reader = create_test_reader(batch_sizes);
+        let provider = Arc::new(RecordBatchReaderProvider::new(reader));
+        let df = ctx.read_table(provider)?;
+        if let Some(limit) = limit {
+            df.limit(0, Some(limit))
+        } else {
+            Ok(df)
+        }
+    }
+
+    // Navigate through the plan structure to find our RecordBatchReaderExec
+    fn find_record_batch_reader_exec(plan: &dyn ExecutionPlan) -> Option<&RecordBatchReaderExec> {
+        if let Some(reader_exec) = plan.as_any().downcast_ref::<RecordBatchReaderExec>() {
+            return Some(reader_exec);
+        }
+
+        // Recursively search children
+        for child in plan.children() {
+            if let Some(reader_exec) = find_record_batch_reader_exec(child.as_ref()) {
+                return Some(reader_exec);
+            }
+        }
+        None
+    }
 }

Reply via email to