This is an automated email from the ASF dual-hosted git repository.
kontinuation pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/sedona-db.git
The following commit(s) were added to refs/heads/main by this push:
new f79d821  fix: Support projection pushdown for RecordBatchReader provider (fixes #186) (#197)
f79d821 is described below
commit f79d8219fdfad0356cedd56689ff380d4c540ed8
Author: Kristin Cowalcijk <[email protected]>
AuthorDate: Fri Oct 10 23:54:15 2025 +0800
fix: Support projection pushdown for RecordBatchReader provider (fixes #186) (#197)
This fixes #186.

The provider ignored the projection indices passed to
TableProvider::scan(), so the physical plan schema ([a, b]) did not match
the pushed-down logical projection ([b]). This PR implements projection
pushdown for RecordBatchReaderProvider and adds a regression test.
---
python/sedonadb/tests/test_dataframe.py | 18 +++++
rust/sedona/src/record_batch_reader_provider.rs | 97 ++++++++++++++++++++++---
2 files changed, 104 insertions(+), 11 deletions(-)
diff --git a/python/sedonadb/tests/test_dataframe.py b/python/sedonadb/tests/test_dataframe.py
index b609635..60ac995 100644
--- a/python/sedonadb/tests/test_dataframe.py
+++ b/python/sedonadb/tests/test_dataframe.py
@@ -342,6 +342,24 @@ def test_dataframe_to_parquet(con):
    )


+def test_record_batch_reader_projection(con):
+    def batches():
+        for _ in range(3):
+            yield pa.record_batch({"a": ["a", "b", "c"], "b": [1, 2, 3]})
+
+    reader = pa.RecordBatchReader.from_batches(next(batches()).schema, batches())
+    df = con.create_data_frame(reader)
+    df.to_view("temp_rbr_proj", overwrite=True)
+    try:
+        # Query the view with a projection (select only column b)
+        proj_df = con.sql("SELECT b FROM temp_rbr_proj")
+        tbl = proj_df.to_arrow_table()
+        assert tbl.column_names == ["b"]
+        assert tbl.to_pydict()["b"] == [1, 2, 3] * 3
+    finally:
+        con.drop_view("temp_rbr_proj")
+
+
def test_show(con, capsys):
    con.sql("SELECT 1 as one").show()
    expected = """
diff --git a/rust/sedona/src/record_batch_reader_provider.rs b/rust/sedona/src/record_batch_reader_provider.rs
index 1832e93..e197f89 100644
--- a/rust/sedona/src/record_batch_reader_provider.rs
+++ b/rust/sedona/src/record_batch_reader_provider.rs
@@ -84,13 +84,16 @@ impl TableProvider for RecordBatchReaderProvider {
    async fn scan(
        &self,
        _state: &dyn Session,
-        _projection: Option<&Vec<usize>>,
+        projection: Option<&Vec<usize>>,
        _filters: &[Expr],
        limit: Option<usize>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let mut reader_guard = self.reader.lock();
        if let Some(reader) = reader_guard.take() {
-            Ok(Arc::new(RecordBatchReaderExec::new(reader, limit)))
+            let projection = projection.cloned();
+            Ok(Arc::new(RecordBatchReaderExec::try_new(
+                reader, limit, projection,
+            )?))
        } else {
            sedona_internal_err!("Can't scan RecordBatchReader provider more than once")
        }
@@ -158,11 +161,25 @@ struct RecordBatchReaderExec {
    schema: SchemaRef,
    properties: PlanProperties,
    limit: Option<usize>,
+    projection: Option<Vec<usize>>,
}

impl RecordBatchReaderExec {
-    fn new(reader: Box<dyn RecordBatchReader + Send>, limit: Option<usize>) -> Self {
-        let schema = reader.schema();
+    fn try_new(
+        reader: Box<dyn RecordBatchReader + Send>,
+        limit: Option<usize>,
+        projection: Option<Vec<usize>>,
+    ) -> Result<Self> {
+        let full_schema = reader.schema();
+        let schema: SchemaRef = if let Some(indices) = projection.as_ref() {
+            SchemaRef::new(
+                full_schema
+                    .project(indices)
+                    .map_err(DataFusionError::from)?,
+            )
+        } else {
+            full_schema.clone()
+        };
        let properties = PlanProperties::new(
            EquivalenceProperties::new(schema.clone()),
            Partitioning::UnknownPartitioning(1),
@@ -170,12 +187,13 @@ impl RecordBatchReaderExec {
            Boundedness::Bounded,
        );

-        Self {
+        Ok(Self {
            reader: Mutex::new(Some(reader)),
            schema,
            properties,
            limit,
-        }
+            projection,
+        })
    }
}
@@ -186,6 +204,7 @@ impl Debug for RecordBatchReaderExec {
.field("schema", &self.schema)
.field("properties", &self.properties)
.field("limit", &self.limit)
+ .field("projection", &self.projection)
.finish()
}
}
@@ -240,7 +259,17 @@ impl ExecutionPlan for RecordBatchReaderExec {
        match self.limit {
            Some(limit) => {
                // Create a row-limited iterator that properly handles row counting
-                let iter = RowLimitedIterator::new(reader, limit);
+                let projection = self.projection.clone();
+                let iter = RowLimitedIterator::new(reader, limit).map(move |res| match res {
+                    Ok(batch) => {
+                        if let Some(indices) = projection.as_ref() {
+                            batch.project(indices).map_err(|e| e.into())
+                        } else {
+                            Ok(batch)
+                        }
+                    }
+                    Err(e) => Err(e),
+                });
                let stream = Box::pin(futures::stream::iter(iter));
                let record_batch_stream =
                    RecordBatchStreamAdapter::new(self.schema.clone(), stream);
@@ -248,9 +277,16 @@ impl ExecutionPlan for RecordBatchReaderExec {
            }
            None => {
                // No limit, just convert the reader directly to a stream
-                let iter = reader.map(|item| match item {
-                    Ok(batch) => Ok(batch),
-                    Err(e) => Err(DataFusionError::from(e)),
+                let projection = self.projection.clone();
+                let iter = reader.map(move |item| match item {
+                    Ok(batch) => {
+                        if let Some(indices) = projection.as_ref() {
+                            batch.project(indices).map_err(|e| e.into())
+                        } else {
+                            Ok(batch)
+                        }
+                    }
+                    Err(e) => Err(e.into()),
                });
                let stream = Box::pin(futures::stream::iter(iter));
                let record_batch_stream =
@@ -266,7 +302,7 @@ mod test {

    use arrow_array::{RecordBatch, RecordBatchIterator};
    use arrow_schema::{DataType, Field, Schema};
-    use datafusion::prelude::{DataFrame, SessionContext};
+    use datafusion::prelude::{col, DataFrame, SessionContext};
    use rstest::rstest;
    use sedona_schema::datatypes::WKB_GEOMETRY;
    use sedona_testing::create::create_array_storage;
@@ -383,6 +419,45 @@ mod test {
        }
    }

+    #[tokio::test]
+    async fn test_projection_pushdown() {
+        let ctx = SessionContext::new();
+
+        // Create a two-column batch
+        let schema = Schema::new(vec![
+            Field::new("a", DataType::Int32, false),
+            Field::new("b", DataType::Int32, false),
+        ]);
+        let batch = RecordBatch::try_new(
+            Arc::new(schema.clone()),
+            vec![
+                Arc::new(arrow_array::Int32Array::from(vec![1, 2, 3])),
+                Arc::new(arrow_array::Int32Array::from(vec![10, 20, 30])),
+            ],
+        )
+        .unwrap();
+
+        // Wrap it in a RecordBatchReaderProvider
+        let reader =
+            RecordBatchIterator::new(vec![batch.clone()].into_iter().map(Ok), Arc::new(schema));
+        let provider = Arc::new(RecordBatchReaderProvider::new(Box::new(reader)));
+
+        // Read the table, then select only column b (this should push the projection into scan)
+        let df = ctx.read_table(provider).unwrap();
+        let df_b = df.select(vec![col("b")]).unwrap();
+        let results = df_b.collect().await.unwrap();
+        assert_eq!(results.len(), 1);
+        let out_batch = &results[0];
+        assert_eq!(out_batch.num_columns(), 1);
+        assert_eq!(out_batch.schema().field(0).name(), "b");
+        let values = out_batch
+            .column(0)
+            .as_any()
+            .downcast_ref::<arrow_array::Int32Array>()
+            .unwrap();
+        assert_eq!(values.values(), &[10, 20, 30]);
+    }
+
    fn read_test_table_with_limit(
        ctx: &SessionContext,
        batch_sizes: Vec<usize>,