This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-ray.git


The following commit(s) were added to refs/heads/main by this push:
     new 2783613  chore: Make query stage / shuffle code easier to understand (#54)
2783613 is described below

commit 27836132c97b3324b4ed5969ac9fd08751fbc8af
Author: Andy Grove <[email protected]>
AuthorDate: Wed Dec 18 23:27:36 2024 -0700

    chore: Make query stage / shuffle code easier to understand (#54)
---
 datafusion_ray/context.py       |  7 +++---
 src/planner.rs                  |  2 +-
 src/query_stage.rs              | 42 +++++++++++++++------------------
 src/shuffle/codec.rs            |  2 +-
 src/shuffle/writer.rs           | 10 ++++----
 testdata/expected-plans/q1.txt  |  2 +-
 testdata/expected-plans/q10.txt |  2 +-
 testdata/expected-plans/q11.txt |  2 +-
 testdata/expected-plans/q12.txt |  2 +-
 testdata/expected-plans/q13.txt |  2 +-
 testdata/expected-plans/q16.txt |  2 +-
 testdata/expected-plans/q18.txt |  2 +-
 testdata/expected-plans/q2.txt  |  2 +-
 testdata/expected-plans/q20.txt |  2 +-
 testdata/expected-plans/q21.txt |  2 +-
 testdata/expected-plans/q22.txt |  2 +-
 testdata/expected-plans/q3.txt  |  2 +-
 testdata/expected-plans/q4.txt  |  2 +-
 testdata/expected-plans/q5.txt  |  2 +-
 testdata/expected-plans/q7.txt  |  2 +-
 testdata/expected-plans/q8.txt  |  2 +-
 testdata/expected-plans/q9.txt  |  2 +-
 tests/test_context.py           | 52 ++++++++++++++++++++---------------------
 23 files changed, 72 insertions(+), 77 deletions(-)

diff --git a/datafusion_ray/context.py b/datafusion_ray/context.py
index 0070220..8d354ff 100644
--- a/datafusion_ray/context.py
+++ b/datafusion_ray/context.py
@@ -50,7 +50,7 @@ def execute_query_stage(
 
     # if the query stage has a single output partition then we need to execute for the output
     # partition, otherwise we need to execute in parallel for each input partition
-    concurrency = stage.get_input_partition_count()
+    concurrency = stage.get_execution_partition_count()
     output_partitions_count = stage.get_output_partition_count()
     if output_partitions_count == 1:
         # reduce stage
@@ -159,5 +159,6 @@ class DatafusionRayContext:
         )
         _, partitions = ray.get(future)
         # assert len(partitions) == 1, len(partitions)
-        result_set = ray.get(partitions[0])
-        return result_set
+        record_batches = ray.get(partitions[0])
+        # filter out empty batches
+        return [batch for batch in record_batches if batch.num_rows > 0]
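
For context, a rough usage sketch of the patched Python API (not part of the commit; it assumes the examples/tips.csv file and the Ray/DataFusion setup used by the tests further below):

    import ray
    from datafusion import SessionContext, SessionConfig
    from datafusion_ray.context import DatafusionRayContext

    ray.init()  # assumes a local Ray runtime is available
    # fixed partition count, matching the test fixture added below
    df_ctx = SessionContext(config=SessionConfig().with_target_partitions(4))
    ctx = DatafusionRayContext(df_ctx)
    df_ctx.register_csv("tips", "examples/tips.csv", has_header=True)

    # ctx.sql() now returns only non-empty record batches
    record_batches = ctx.sql("SELECT * FROM tips")
    print(sum(batch.num_rows for batch in record_batches))  # 244 for the tips example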
diff --git a/src/planner.rs b/src/planner.rs
index 954d8e2..c1e7b41 100644
--- a/src/planner.rs
+++ b/src/planner.rs
@@ -399,7 +399,7 @@ mod test {
             let query_stage = graph.query_stages.get(&id).unwrap();
             output.push_str(&format!(
                 "Query Stage #{id} ({} -> {}):\n{}\n",
-                query_stage.get_input_partition_count(),
+                query_stage.get_execution_partition_count(),
                 query_stage.get_output_partition_count(),
                 displayable(query_stage.plan.as_ref()).indent(false)
             ));
diff --git a/src/query_stage.rs b/src/query_stage.rs
index 05c090b..a5c9a08 100644
--- a/src/query_stage.rs
+++ b/src/query_stage.rs
@@ -16,7 +16,7 @@
 // under the License.
 
 use crate::context::serialize_execution_plan;
-use crate::shuffle::{ShuffleCodec, ShuffleReaderExec};
+use crate::shuffle::{ShuffleCodec, ShuffleReaderExec, ShuffleWriterExec};
 use datafusion::error::Result;
 use datafusion::physical_plan::{ExecutionPlan, ExecutionPlanProperties, Partitioning};
 use datafusion::prelude::SessionContext;
@@ -60,8 +60,8 @@ impl PyQueryStage {
         self.stage.get_child_stage_ids()
     }
 
-    pub fn get_input_partition_count(&self) -> usize {
-        self.stage.get_input_partition_count()
+    pub fn get_execution_partition_count(&self) -> usize {
+        self.stage.get_execution_partition_count()
     }
 
     pub fn get_output_partition_count(&self) -> usize {
@@ -75,16 +75,6 @@ pub struct QueryStage {
     pub plan: Arc<dyn ExecutionPlan>,
 }
 
-fn _get_output_partition_count(plan: &dyn ExecutionPlan) -> usize {
-    // UnknownPartitioning and HashPartitioning with empty expressions will
-    // both return 1 partition.
-    match plan.properties().output_partitioning() {
-        Partitioning::UnknownPartitioning(_) => 1,
-        Partitioning::Hash(expr, _) if expr.is_empty() => 1,
-        p => p.partition_count(),
-    }
-}
-
 impl QueryStage {
     pub fn new(id: usize, plan: Arc<dyn ExecutionPlan>) -> Self {
         Self { id, plan }
@@ -96,21 +86,27 @@ impl QueryStage {
         ids
     }
 
-    /// Get the input partition count. This is the same as the number of concurrent tasks
-    /// when we schedule this query stage for execution
-    pub fn get_input_partition_count(&self) -> usize {
-        if self.plan.children().is_empty() {
-            // leaf node (file scan)
-            self.plan.output_partitioning().partition_count()
+    /// Get the number of partitions that can be executed in parallel
+    pub fn get_execution_partition_count(&self) -> usize {
+        if let Some(shuffle) = self.plan.as_any().downcast_ref::<ShuffleWriterExec>() {
+            // use the partitioning of the input to the shuffle write because we are
+            // really executing that and then using the shuffle writer to repartition
+            // the output
+            shuffle.input_plan.output_partitioning().partition_count()
         } else {
-            self.plan.children()[0]
-                .output_partitioning()
-                .partition_count()
+            // for any other plan, use its output partitioning
+            self.plan.output_partitioning().partition_count()
         }
     }
 
     pub fn get_output_partition_count(&self) -> usize {
-        _get_output_partition_count(self.plan.as_ref())
+        // UnknownPartitioning and HashPartitioning with empty expressions will
+        // both return 1 partition.
+        match self.plan.properties().output_partitioning() {
+            Partitioning::UnknownPartitioning(_) => 1,
+            Partitioning::Hash(expr, _) if expr.is_empty() => 1,
+            p => p.partition_count(),
+        }
     }
 }
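
The renamed get_execution_partition_count answers "how many parallel tasks does this stage need?". A rough pseudocode sketch of the rule above (is_shuffle_writer, input_plan and partition_count stand in for the Rust downcast and accessors; this is not the real API):

    def execution_partition_count(plan):
        """Number of parallel tasks to launch for a query stage (sketch only)."""
        if is_shuffle_writer(plan):  # hypothetical stand-in for the ShuffleWriterExec downcast
            # the tasks really execute the writer's input plan, one task per input
            # partition; the writer then repartitions each task's output
            return partition_count(input_plan(plan))
        # any other plan runs with its own output partitioning, e.g. a final
        # SortPreservingMergeExec stage has a single output partition
        return partition_count(plan)

This is why the expected-plan files below change from "(2 -> 1)" to "(1 -> 1)" for the final stages: their root is a SortPreservingMergeExec with one output partition, so they are now scheduled as a single task instead of one task per input partition.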
 
diff --git a/src/shuffle/codec.rs b/src/shuffle/codec.rs
index 79af0b8..0420428 100644
--- a/src/shuffle/codec.rs
+++ b/src/shuffle/codec.rs
@@ -102,7 +102,7 @@ impl PhysicalExtensionCodec for ShuffleCodec {
             };
             PlanType::ShuffleReader(reader)
         } else if let Some(writer) = node.as_any().downcast_ref::<ShuffleWriterExec>() {
-            let plan = PhysicalPlanNode::try_from_physical_plan(writer.plan.clone(), self)?;
+            let plan = PhysicalPlanNode::try_from_physical_plan(writer.input_plan.clone(), self)?;
             let partitioning =
                 encode_partitioning_scheme(writer.properties().output_partitioning())?;
             let writer = ShuffleWriterExecNode {
diff --git a/src/shuffle/writer.rs b/src/shuffle/writer.rs
index 069f99d..0e0f984 100644
--- a/src/shuffle/writer.rs
+++ b/src/shuffle/writer.rs
@@ -47,7 +47,7 @@ use std::sync::Arc;
 #[derive(Debug)]
 pub struct ShuffleWriterExec {
     pub stage_id: usize,
-    pub(crate) plan: Arc<dyn ExecutionPlan>,
+    pub(crate) input_plan: Arc<dyn ExecutionPlan>,
     /// Output partitioning
     properties: PlanProperties,
     /// Directory to write shuffle files from
@@ -84,7 +84,7 @@ impl ShuffleWriterExec {
 
         Self {
             stage_id,
-            plan,
+            input_plan: plan,
             properties,
             shuffle_dir: shuffle_dir.to_string(),
             metrics: ExecutionPlanMetricsSet::new(),
@@ -98,11 +98,11 @@ impl ExecutionPlan for ShuffleWriterExec {
     }
 
     fn schema(&self) -> SchemaRef {
-        self.plan.schema()
+        self.input_plan.schema()
     }
 
     fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
-        vec![&self.plan]
+        vec![&self.input_plan]
     }
 
     fn with_new_children(
@@ -122,7 +122,7 @@ impl ExecutionPlan for ShuffleWriterExec {
             self.stage_id
         );
 
-        let mut stream = self.plan.execute(input_partition, context)?;
+        let mut stream = self.input_plan.execute(input_partition, context)?;
         let write_time =
             MetricBuilder::new(&self.metrics).subset_time("write_time", input_partition);
         let repart_time =
diff --git a/testdata/expected-plans/q1.txt b/testdata/expected-plans/q1.txt
index 282d5da..6f78394 100644
--- a/testdata/expected-plans/q1.txt
+++ b/testdata/expected-plans/q1.txt
@@ -42,7 +42,7 @@ ShuffleWriterExec(stage_id=1, output_partitioning=Hash([Column { name: "l_return
         CoalesceBatchesExec: target_batch_size=8192
           ShuffleReaderExec(stage_id=0, input_partitioning=Hash([Column { name: "l_returnflag", index: 0 }, Column { name: "l_linestatus", index: 1 }], 2))
 
-Query Stage #2 (2 -> 1):
+Query Stage #2 (1 -> 1):
 SortPreservingMergeExec: [l_returnflag@0 ASC NULLS LAST, l_linestatus@1 ASC NULLS LAST]
   ShuffleReaderExec(stage_id=1, input_partitioning=Hash([Column { name: "l_returnflag", index: 0 }, Column { name: "l_linestatus", index: 1 }], 2))
 
diff --git a/testdata/expected-plans/q10.txt b/testdata/expected-plans/q10.txt
index 046f69e..3825561 100644
--- a/testdata/expected-plans/q10.txt
+++ b/testdata/expected-plans/q10.txt
@@ -117,7 +117,7 @@ ShuffleWriterExec(stage_id=7, output_partitioning=Hash([Column { name: "c_custke
         CoalesceBatchesExec: target_batch_size=8192
           ShuffleReaderExec(stage_id=6, input_partitioning=Hash([Column { name: "c_custkey", index: 0 }, Column { name: "c_name", index: 1 }, Column { name: "c_acctbal", index: 2 }, Column { name: "c_phone", index: 3 }, Column { name: "n_name", index: 4 }, Column { name: "c_address", index: 5 }, Column { name: "c_comment", index: 6 }], 2))
 
-Query Stage #8 (2 -> 1):
+Query Stage #8 (1 -> 1):
 SortPreservingMergeExec: [revenue@2 DESC], fetch=20
   ShuffleReaderExec(stage_id=7, input_partitioning=Hash([Column { name: "c_custkey", index: 0 }, Column { name: "c_name", index: 1 }, Column { name: "c_acctbal", index: 3 }, Column { name: "c_phone", index: 6 }, Column { name: "n_name", index: 4 }, Column { name: "c_address", index: 5 }, Column { name: "c_comment", index: 7 }], 2))
 
diff --git a/testdata/expected-plans/q11.txt b/testdata/expected-plans/q11.txt
index 74f74d7..2972d52 100644
--- a/testdata/expected-plans/q11.txt
+++ b/testdata/expected-plans/q11.txt
@@ -167,7 +167,7 @@ ShuffleWriterExec(stage_id=10, output_partitioning=Hash([Column { name: "ps_part
           CoalesceBatchesExec: target_batch_size=8192
             ShuffleReaderExec(stage_id=9, input_partitioning=Hash([Column { name: "ps_partkey", index: 0 }], 2))
 
-Query Stage #11 (2 -> 1):
+Query Stage #11 (1 -> 1):
 SortPreservingMergeExec: [value@1 DESC]
   ShuffleReaderExec(stage_id=10, input_partitioning=Hash([Column { name: "ps_partkey", index: 0 }], 2))
 
diff --git a/testdata/expected-plans/q12.txt b/testdata/expected-plans/q12.txt
index c7ae269..4cf0596 100644
--- a/testdata/expected-plans/q12.txt
+++ b/testdata/expected-plans/q12.txt
@@ -65,7 +65,7 @@ ShuffleWriterExec(stage_id=3, output_partitioning=Hash([Column { name: "l_shipmo
         CoalesceBatchesExec: target_batch_size=8192
           ShuffleReaderExec(stage_id=2, input_partitioning=Hash([Column { name: "l_shipmode", index: 0 }], 2))
 
-Query Stage #4 (2 -> 1):
+Query Stage #4 (1 -> 1):
 SortPreservingMergeExec: [l_shipmode@0 ASC NULLS LAST]
   ShuffleReaderExec(stage_id=3, input_partitioning=Hash([Column { name: "l_shipmode", index: 0 }], 2))
 
diff --git a/testdata/expected-plans/q13.txt b/testdata/expected-plans/q13.txt
index 366db12..da7e93a 100644
--- a/testdata/expected-plans/q13.txt
+++ b/testdata/expected-plans/q13.txt
@@ -70,7 +70,7 @@ ShuffleWriterExec(stage_id=3, output_partitioning=Hash([Column { name: "c_count"
         CoalesceBatchesExec: target_batch_size=8192
           ShuffleReaderExec(stage_id=2, input_partitioning=Hash([Column { name: "c_count", index: 0 }], 2))
 
-Query Stage #4 (2 -> 1):
+Query Stage #4 (1 -> 1):
 SortPreservingMergeExec: [custdist@1 DESC, c_count@0 DESC]
   ShuffleReaderExec(stage_id=3, input_partitioning=Hash([Column { name: "c_count", index: 0 }], 2))
 
diff --git a/testdata/expected-plans/q16.txt b/testdata/expected-plans/q16.txt
index 24ecb18..b26e9a4 100644
--- a/testdata/expected-plans/q16.txt
+++ b/testdata/expected-plans/q16.txt
@@ -107,7 +107,7 @@ ShuffleWriterExec(stage_id=6, output_partitioning=Hash([Column { name: "p_brand"
         CoalesceBatchesExec: target_batch_size=8192
           ShuffleReaderExec(stage_id=5, input_partitioning=Hash([Column { name: "p_brand", index: 0 }, Column { name: "p_type", index: 1 }, Column { name: "p_size", index: 2 }], 2))
 
-Query Stage #7 (2 -> 1):
+Query Stage #7 (1 -> 1):
 SortPreservingMergeExec: [supplier_cnt@3 DESC, p_brand@0 ASC NULLS LAST, p_type@1 ASC NULLS LAST, p_size@2 ASC NULLS LAST]
   ShuffleReaderExec(stage_id=6, input_partitioning=Hash([Column { name: "p_brand", index: 0 }, Column { name: "p_type", index: 1 }, Column { name: "p_size", index: 2 }], 2))
 
diff --git a/testdata/expected-plans/q18.txt b/testdata/expected-plans/q18.txt
index 30179d0..a5d28e8 100644
--- a/testdata/expected-plans/q18.txt
+++ b/testdata/expected-plans/q18.txt
@@ -104,7 +104,7 @@ ShuffleWriterExec(stage_id=6, output_partitioning=Hash([Column { name: "c_name",
       CoalesceBatchesExec: target_batch_size=8192
         ShuffleReaderExec(stage_id=5, input_partitioning=Hash([Column { name: "c_name", index: 0 }, Column { name: "c_custkey", index: 1 }, Column { name: "o_orderkey", index: 2 }, Column { name: "o_orderdate", index: 3 }, Column { name: "o_totalprice", index: 4 }], 2))
 
-Query Stage #7 (2 -> 1):
+Query Stage #7 (1 -> 1):
 SortPreservingMergeExec: [o_totalprice@4 DESC, o_orderdate@3 ASC NULLS LAST], fetch=100
   ShuffleReaderExec(stage_id=6, input_partitioning=Hash([Column { name: "c_name", index: 0 }, Column { name: "c_custkey", index: 1 }, Column { name: "o_orderkey", index: 2 }, Column { name: "o_orderdate", index: 3 }, Column { name: "o_totalprice", index: 4 }], 2))
 
diff --git a/testdata/expected-plans/q2.txt b/testdata/expected-plans/q2.txt
index bc0713c..9778441 100644
--- a/testdata/expected-plans/q2.txt
+++ b/testdata/expected-plans/q2.txt
@@ -252,7 +252,7 @@ ShuffleWriterExec(stage_id=17, output_partitioning=Hash([Column { name: "p_partk
           CoalesceBatchesExec: target_batch_size=8192
             ShuffleReaderExec(stage_id=16, input_partitioning=Hash([Column { name: "ps_partkey", index: 1 }, Column { name: "min(partsupp.ps_supplycost)", index: 0 }], 2))
 
-Query Stage #18 (2 -> 1):
+Query Stage #18 (1 -> 1):
 SortPreservingMergeExec: [s_acctbal@0 DESC, n_name@2 ASC NULLS LAST, s_name@1 ASC NULLS LAST, p_partkey@3 ASC NULLS LAST], fetch=100
   ShuffleReaderExec(stage_id=17, input_partitioning=Hash([Column { name: "p_partkey", index: 3 }], 2))
 
diff --git a/testdata/expected-plans/q20.txt b/testdata/expected-plans/q20.txt
index 13b21c8..e1bc54c 100644
--- a/testdata/expected-plans/q20.txt
+++ b/testdata/expected-plans/q20.txt
@@ -142,7 +142,7 @@ ShuffleWriterExec(stage_id=8, output_partitioning=Hash([], 2))
         CoalesceBatchesExec: target_batch_size=8192
           ShuffleReaderExec(stage_id=7, input_partitioning=Hash([Column { name: "ps_suppkey", index: 0 }], 2))
 
-Query Stage #9 (2 -> 1):
+Query Stage #9 (1 -> 1):
 SortPreservingMergeExec: [s_name@0 ASC NULLS LAST]
   ShuffleReaderExec(stage_id=8, input_partitioning=Hash([], 2))
 
diff --git a/testdata/expected-plans/q21.txt b/testdata/expected-plans/q21.txt
index b88bccc..8d6798f 100644
--- a/testdata/expected-plans/q21.txt
+++ b/testdata/expected-plans/q21.txt
@@ -172,7 +172,7 @@ ShuffleWriterExec(stage_id=10, output_partitioning=Hash([Column { name: "s_name"
         CoalesceBatchesExec: target_batch_size=8192
           ShuffleReaderExec(stage_id=9, input_partitioning=Hash([Column { name: "s_name", index: 0 }], 2))
 
-Query Stage #11 (2 -> 1):
+Query Stage #11 (1 -> 1):
 SortPreservingMergeExec: [numwait@1 DESC, s_name@0 ASC NULLS LAST], fetch=100
   ShuffleReaderExec(stage_id=10, input_partitioning=Hash([Column { name: "s_name", index: 0 }], 2))
 
diff --git a/testdata/expected-plans/q22.txt b/testdata/expected-plans/q22.txt
index da693fb..7ad4ae1 100644
--- a/testdata/expected-plans/q22.txt
+++ b/testdata/expected-plans/q22.txt
@@ -91,7 +91,7 @@ ShuffleWriterExec(stage_id=4, output_partitioning=Hash([Column { name: "cntrycod
         CoalesceBatchesExec: target_batch_size=8192
           ShuffleReaderExec(stage_id=3, input_partitioning=Hash([Column { name: "cntrycode", index: 0 }], 2))
 
-Query Stage #5 (2 -> 1):
+Query Stage #5 (1 -> 1):
 SortPreservingMergeExec: [cntrycode@0 ASC NULLS LAST]
   ShuffleReaderExec(stage_id=4, input_partitioning=Hash([Column { name: "cntrycode", index: 0 }], 2))
 
diff --git a/testdata/expected-plans/q3.txt b/testdata/expected-plans/q3.txt
index f9039d3..3af2ea0 100644
--- a/testdata/expected-plans/q3.txt
+++ b/testdata/expected-plans/q3.txt
@@ -97,7 +97,7 @@ ShuffleWriterExec(stage_id=5, output_partitioning=Hash([Column { name: "l_orderk
         CoalesceBatchesExec: target_batch_size=8192
           ShuffleReaderExec(stage_id=4, input_partitioning=Hash([Column { name: "l_orderkey", index: 0 }, Column { name: "o_orderdate", index: 1 }, Column { name: "o_shippriority", index: 2 }], 2))
 
-Query Stage #6 (2 -> 1):
+Query Stage #6 (1 -> 1):
 SortPreservingMergeExec: [revenue@1 DESC, o_orderdate@2 ASC NULLS LAST], fetch=10
   ShuffleReaderExec(stage_id=5, input_partitioning=Hash([Column { name: "l_orderkey", index: 0 }, Column { name: "o_orderdate", index: 2 }, Column { name: "o_shippriority", index: 3 }], 2))
 
diff --git a/testdata/expected-plans/q4.txt b/testdata/expected-plans/q4.txt
index 20460e4..2504483 100644
--- a/testdata/expected-plans/q4.txt
+++ b/testdata/expected-plans/q4.txt
@@ -70,7 +70,7 @@ ShuffleWriterExec(stage_id=3, output_partitioning=Hash([Column { name: "o_orderp
         CoalesceBatchesExec: target_batch_size=8192
           ShuffleReaderExec(stage_id=2, input_partitioning=Hash([Column { name: "o_orderpriority", index: 0 }], 2))
 
-Query Stage #4 (2 -> 1):
+Query Stage #4 (1 -> 1):
 SortPreservingMergeExec: [o_orderpriority@0 ASC NULLS LAST]
   ShuffleReaderExec(stage_id=3, input_partitioning=Hash([Column { name: "o_orderpriority", index: 0 }], 2))
 
diff --git a/testdata/expected-plans/q5.txt b/testdata/expected-plans/q5.txt
index 2bacb27..3e66ddb 100644
--- a/testdata/expected-plans/q5.txt
+++ b/testdata/expected-plans/q5.txt
@@ -167,7 +167,7 @@ ShuffleWriterExec(stage_id=11, output_partitioning=Hash([Column { name: "n_name"
         CoalesceBatchesExec: target_batch_size=8192
           ShuffleReaderExec(stage_id=10, input_partitioning=Hash([Column { name: "n_name", index: 0 }], 2))
 
-Query Stage #12 (2 -> 1):
+Query Stage #12 (1 -> 1):
 SortPreservingMergeExec: [revenue@1 DESC]
   ShuffleReaderExec(stage_id=11, input_partitioning=Hash([Column { name: "n_name", index: 0 }], 2))
 
diff --git a/testdata/expected-plans/q7.txt b/testdata/expected-plans/q7.txt
index 43bc031..9321b1b 100644
--- a/testdata/expected-plans/q7.txt
+++ b/testdata/expected-plans/q7.txt
@@ -176,7 +176,7 @@ ShuffleWriterExec(stage_id=11, output_partitioning=Hash([Column { name: "supp_na
         CoalesceBatchesExec: target_batch_size=8192
           ShuffleReaderExec(stage_id=10, input_partitioning=Hash([Column { name: "supp_nation", index: 0 }, Column { name: "cust_nation", index: 1 }, Column { name: "l_year", index: 2 }], 2))
 
-Query Stage #12 (2 -> 1):
+Query Stage #12 (1 -> 1):
 SortPreservingMergeExec: [supp_nation@0 ASC NULLS LAST, cust_nation@1 ASC NULLS LAST, l_year@2 ASC NULLS LAST]
   ShuffleReaderExec(stage_id=11, input_partitioning=Hash([Column { name: "supp_nation", index: 0 }, Column { name: "cust_nation", index: 1 }, Column { name: "l_year", index: 2 }], 2))
 
diff --git a/testdata/expected-plans/q8.txt b/testdata/expected-plans/q8.txt
index e9f5b91..c7ec1ec 100644
--- a/testdata/expected-plans/q8.txt
+++ b/testdata/expected-plans/q8.txt
@@ -230,7 +230,7 @@ ShuffleWriterExec(stage_id=15, output_partitioning=Hash([Column { name: "o_year"
         CoalesceBatchesExec: target_batch_size=8192
           ShuffleReaderExec(stage_id=14, input_partitioning=Hash([Column { name: "o_year", index: 0 }], 2))
 
-Query Stage #16 (2 -> 1):
+Query Stage #16 (1 -> 1):
 SortPreservingMergeExec: [o_year@0 ASC NULLS LAST]
   ShuffleReaderExec(stage_id=15, input_partitioning=Hash([Column { name: "o_year", index: 0 }], 2))
 
diff --git a/testdata/expected-plans/q9.txt b/testdata/expected-plans/q9.txt
index 2c713b3..fa087f1 100644
--- a/testdata/expected-plans/q9.txt
+++ b/testdata/expected-plans/q9.txt
@@ -166,7 +166,7 @@ ShuffleWriterExec(stage_id=11, output_partitioning=Hash([Column { name: "nation"
         CoalesceBatchesExec: target_batch_size=8192
           ShuffleReaderExec(stage_id=10, input_partitioning=Hash([Column { name: "nation", index: 0 }, Column { name: "o_year", index: 1 }], 2))
 
-Query Stage #12 (2 -> 1):
+Query Stage #12 (1 -> 1):
 SortPreservingMergeExec: [nation@0 ASC NULLS LAST, o_year@1 DESC]
   ShuffleReaderExec(stage_id=11, input_partitioning=Hash([Column { name: "nation", index: 0 }, Column { name: "o_year", index: 1 }], 2))
 
diff --git a/tests/test_context.py b/tests/test_context.py
index ecc3324..602f761 100644
--- a/tests/test_context.py
+++ b/tests/test_context.py
@@ -17,42 +17,42 @@
 
 from datafusion_ray.context import DatafusionRayContext
 from datafusion import SessionContext, SessionConfig, RuntimeConfig, col, lit, functions as F
+import pytest
 
+@pytest.fixture
+def df_ctx():
+    """Fixture to create a DataFusion context."""
+    # used fixed partition count so that tests are deterministic on different environments
+    config = SessionConfig().with_target_partitions(4)
+    return SessionContext(config=config)
 
-def test_basic_query_succeed():
-    df_ctx = SessionContext()
-    ctx = DatafusionRayContext(df_ctx)
+@pytest.fixture
+def ctx(df_ctx):
+    """Fixture to create a Datafusion Ray context."""
+    return DatafusionRayContext(df_ctx)
+
+def test_basic_query_succeed(df_ctx, ctx):
     df_ctx.register_csv("tips", "examples/tips.csv", has_header=True)
-    # TODO why does this return a single batch and not a list of batches?
     record_batches = ctx.sql("SELECT * FROM tips")
-    assert record_batches[0].num_rows == 244
+    assert len(record_batches) <= 4
+    num_rows = sum(batch.num_rows for batch in record_batches)
+    assert num_rows == 244
 
-def test_aggregate_csv():
-    df_ctx = SessionContext()
-    ctx = DatafusionRayContext(df_ctx)
+def test_aggregate_csv(df_ctx, ctx):
     df_ctx.register_csv("tips", "examples/tips.csv", has_header=True)
     record_batches = ctx.sql("select sex, smoker, avg(tip/total_bill) as 
tip_pct from tips group by sex, smoker")
-    assert isinstance(record_batches, list)
-    # TODO why does this return many empty batches?
-    num_rows = 0
-    for record_batch in record_batches:
-        num_rows += record_batch.num_rows
+    assert len(record_batches) <= 4
+    num_rows = sum(batch.num_rows for batch in record_batches)
     assert num_rows == 4
 
-def test_aggregate_parquet():
-    df_ctx = SessionContext()
-    ctx = DatafusionRayContext(df_ctx)
+def test_aggregate_parquet(df_ctx, ctx):
     df_ctx.register_parquet("tips", "examples/tips.parquet")
     record_batches = ctx.sql("select sex, smoker, avg(tip/total_bill) as 
tip_pct from tips group by sex, smoker")
-    # TODO why does this return many empty batches?
-    num_rows = 0
-    for record_batch in record_batches:
-        num_rows += record_batch.num_rows
+    assert len(record_batches) <= 4
+    num_rows = sum(batch.num_rows for batch in record_batches)
     assert num_rows == 4
 
-def test_aggregate_parquet_dataframe():
-    df_ctx = SessionContext()
-    ray_ctx = DatafusionRayContext(df_ctx)
+def test_aggregate_parquet_dataframe(df_ctx, ctx):
     df = df_ctx.read_parquet(f"examples/tips.parquet")
     df = (
         df.aggregate(
@@ -62,12 +62,10 @@ def test_aggregate_parquet_dataframe():
         .filter(col("day") != lit("Dinner"))
         .aggregate([col("sex"), col("smoker")], 
[F.avg(col("tip_pct")).alias("avg_pct")])
     )
-    ray_results = ray_ctx.plan(df.execution_plan())
+    ray_results = ctx.plan(df.execution_plan())
     df_ctx.create_dataframe([ray_results]).show()
 
 
-def test_no_result_query():
-    df_ctx = SessionContext()
-    ctx = DatafusionRayContext(df_ctx)
+def test_no_result_query(df_ctx, ctx):
     df_ctx.register_csv("tips", "examples/tips.csv", has_header=True)
     ctx.sql("CREATE VIEW tips_view AS SELECT * FROM tips")

