pepijnve commented on code in PR #19757:
URL: https://github.com/apache/datafusion/pull/19757#discussion_r2683365301


##########
datafusion/physical-optimizer/src/ensure_coop.rs:
##########
@@ -67,23 +67,44 @@ impl PhysicalOptimizerRule for EnsureCooperative {
         plan: Arc<dyn ExecutionPlan>,
         _config: &ConfigOptions,
     ) -> Result<Arc<dyn ExecutionPlan>> {
-        plan.transform_up(|plan| {
-            let is_leaf = plan.children().is_empty();
-            let is_exchange = plan.properties().evaluation_type == 
EvaluationType::Eager;
-            if (is_leaf || is_exchange)
-                && plan.properties().scheduling_type != 
SchedulingType::Cooperative
-            {
-                // Wrap non-cooperative leaves or eager evaluation roots in a 
cooperative exec to
-                // ensure the plans they participate in are properly 
cooperative.
-                Ok(Transformed::new(
-                    Arc::new(CooperativeExec::new(Arc::clone(&plan))),
-                    true,
-                    TreeNodeRecursion::Continue,
-                ))
-            } else {
+        use std::cell::Cell;
+
+        // Track depth: 0 means not under any CooperativeExec
+        // Using Cell to allow interior mutability from multiple closures
+        let coop_depth = Cell::new(0usize);
+
+        plan.transform_down_up(
+            // Down phase: Track when entering CooperativeExec subtrees
+            |plan| {
+                if plan.as_any().downcast_ref::<CooperativeExec>().is_some() {
+                    coop_depth.set(coop_depth.get() + 1);
+                }
                 Ok(Transformed::no(plan))
-            }
-        })
+            },
+            // Up phase: Wrap nodes with CooperativeExec if needed, then 
restore depth
+            |plan| {
+                let is_cooperative =
+                    plan.properties().scheduling_type == 
SchedulingType::Cooperative;
+                let is_leaf = plan.children().is_empty();
+                let is_exchange =
+                    plan.properties().evaluation_type == EvaluationType::Eager;
+
+                // Wrap if:
+                // 1. Node is a leaf or exchange point
+                // 2. Node is not already cooperative
+                // 3. Not under any CooperativeExec (depth == 0)
+                if (is_leaf || is_exchange) && !is_cooperative && 
coop_depth.get() == 0 {

Review Comment:
   Here's some very contrived test code (in the details section below) that
   illustrates this. The code will output:
   
   ```
   aggr Lazy NonCooperative
     filter Lazy NonCooperative
       exch Eager Cooperative
         filter Lazy NonCooperative
           CooperativeExec
             exch Eager NonCooperative
               filter Lazy NonCooperative
                 scan Lazy NonCooperative
   ```
   
   Notice that there's a `coop` missing around the final scan.
   
   Before this change, the code produced the output below (with the incorrect
   double coop). The double coop is not intentional, but the two layers of coop are.
   
   ```
   aggr Lazy NonCooperative
     filter Lazy NonCooperative
       exch Eager Cooperative
         filter Lazy NonCooperative
           CooperativeExec
             CooperativeExec
               exch Eager NonCooperative
                 filter Lazy NonCooperative
                   CooperativeExec
                     scan Lazy NonCooperative
   ```
   
   <details>
   
   ```
   /// Builds a contrived plan containing two eager exchange nodes, the lower of
   /// which is already wrapped in a `CooperativeExec`, runs `EnsureCooperative`
   /// over it and prints the optimized plan in indented form.
   ///
   /// The test makes no assertions; it only prints the plan for inspection.
   #[tokio::test]
   async fn test_exchange() {
       // Lower pipeline: scan -> filter -> eager, non-cooperative exchange.
       let leaf_scan = Arc::new(DummyExec::new(
           String::from("scan"),
           None,
           SchedulingType::NonCooperative,
           EvaluationType::Lazy,
       ));
       let lower_filter = Arc::new(DummyExec::new(
           String::from("filter"),
           Some(leaf_scan),
           SchedulingType::NonCooperative,
           EvaluationType::Lazy,
       ));
       let lower_exchange = Arc::new(DummyExec::new(
           String::from("exch"),
           Some(lower_filter),
           SchedulingType::NonCooperative,
           EvaluationType::Eager,
       ));

       // The lower exchange is wrapped by hand before the optimizer rule runs.
       let wrapped_exchange = Arc::new(CooperativeExec::new(lower_exchange));

       // Upper pipeline: filter -> eager, cooperative exchange -> filter -> aggr.
       let middle_filter = Arc::new(DummyExec::new(
           String::from("filter"),
           Some(wrapped_exchange),
           SchedulingType::NonCooperative,
           EvaluationType::Lazy,
       ));
       let upper_exchange = Arc::new(DummyExec::new(
           String::from("exch"),
           Some(middle_filter),
           SchedulingType::Cooperative,
           EvaluationType::Eager,
       ));
       let upper_filter = Arc::new(DummyExec::new(
           String::from("filter"),
           Some(upper_exchange),
           SchedulingType::NonCooperative,
           EvaluationType::Lazy,
       ));
       let root = Arc::new(DummyExec::new(
           String::from("aggr"),
           Some(upper_filter),
           SchedulingType::NonCooperative,
           EvaluationType::Lazy,
       ));

       let options = ConfigOptions::new();
       let plan: Arc<dyn ExecutionPlan> = root;
       let optimized = EnsureCooperative::new().optimize(plan, &options).unwrap();

       let rendered = displayable(optimized.as_ref()).indent(true).to_string();

       println!("{}", rendered);
   }
   
   /// Minimal stand-in execution plan node used to build plan trees by hand in
   /// the test above. It has at most one input and its `execute` is left as
   /// `todo!()`.
   #[derive(Debug)]
   struct DummyExec {
       /// Label returned by `name()` and printed first by `fmt_as`.
       name: String,
       /// The single optional child; `None` makes this node a leaf.
       input: Option<Arc<dyn ExecutionPlan>>,
       /// Scheduling type this node was created with; printed by `fmt_as` and
       /// also applied to `properties`.
       scheduling_type: SchedulingType,
       /// Evaluation type this node was created with; printed by `fmt_as` and
       /// also applied to `properties`.
       evaluation_type: EvaluationType,
       /// Plan properties built once in `new` and returned by `properties()`.
       properties: PlanProperties,
   }
   
   impl DummyExec {
       /// Creates a dummy node with the given display name, optional single
       /// input, scheduling type and evaluation type.
       ///
       /// The plan properties use an empty schema, a single unknown partition,
       /// incremental emission and bounded input, with the supplied scheduling
       /// and evaluation types applied on top.
       fn new(
           name: String,
           input: Option<Arc<dyn ExecutionPlan>>,
           scheduling_type: SchedulingType,
           evaluation_type: EvaluationType,
       ) -> Self {
           let base = PlanProperties::new(
               EquivalenceProperties::new(Arc::new(Schema::empty())),
               Partitioning::UnknownPartitioning(1),
               EmissionType::Incremental,
               Boundedness::Bounded,
           );
           let properties = base
               .with_scheduling_type(scheduling_type)
               .with_evaluation_type(evaluation_type);

           Self {
               name,
               input,
               scheduling_type,
               evaluation_type,
               properties,
           }
       }
   }
   
   impl DisplayAs for DummyExec {
       /// Writes `<name> <evaluation type> <scheduling type>`, using the `Debug`
       /// form of the two enum values. The requested display format is ignored.
       fn fmt_as(&self, _: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
           let Self {
               name,
               evaluation_type,
               scheduling_type,
               ..
           } = self;
           write!(f, "{} {:?} {:?}", name, evaluation_type, scheduling_type)
       }
   }
   
   impl ExecutionPlan for DummyExec {
       /// Returns the label this node was created with.
       fn name(&self) -> &str {
           &self.name
       }

       /// Exposes the node as `Any` so callers can downcast it.
       fn as_any(&self) -> &dyn Any {
           self
       }

       /// Returns the properties computed in `DummyExec::new`.
       fn properties(&self) -> &PlanProperties {
           &self.properties
       }

       /// Returns the single input if present, otherwise an empty list.
       fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
           self.input.iter().collect()
       }

       /// Rebuilds the node with the first of `children` as its input (or no
       /// input when `children` is empty), keeping the name, scheduling type and
       /// evaluation type. Any children after the first are ignored.
       fn with_new_children(
           self: Arc<Self>,
           children: Vec<Arc<dyn ExecutionPlan>>,
       ) -> Result<Arc<dyn ExecutionPlan>> {
           let first_child = children.into_iter().next();
           let rebuilt = Self::new(
               self.name.clone(),
               first_child,
               self.scheduling_type,
               self.evaluation_type,
           );
           Ok(Arc::new(rebuilt))
       }

       /// Not implemented: this node only exists to shape plan trees, so calling
       /// this panics via `todo!()`.
       fn execute(
           &self,
           _partition: usize,
           _context: Arc<TaskContext>,
       ) -> Result<SendableRecordBatchStream> {
           todo!()
       }
   }
   ```
   
   </details>



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to