This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new f4468ea9a Remove panics from `common_subexpr_eliminate` (#3346)
f4468ea9a is described below

commit f4468ea9aa7e8af285cd9d02da4131f9e5a94c9f
Author: Andy Grove <[email protected]>
AuthorDate: Fri Sep 2 17:14:03 2022 -0600

    Remove panics from `common_subexpr_eliminate` (#3346)
    
    * add helper function for pop_expr
    
    * remove unwraps
---
 .../optimizer/src/common_subexpr_eliminate.rs      | 77 +++++++++++++++-------
 1 file changed, 53 insertions(+), 24 deletions(-)

diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs
index 1bbd09007..fcefd8fa7 100644
--- a/datafusion/optimizer/src/common_subexpr_eliminate.rs
+++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs
@@ -19,7 +19,7 @@
 
 use crate::{OptimizerConfig, OptimizerRule};
 use arrow::datatypes::DataType;
-use datafusion_common::{DFField, DFSchema, Result};
+use datafusion_common::{DFField, DFSchema, DataFusionError, Result};
 use datafusion_expr::{
     col,
     expr::GroupingSet,
@@ -107,7 +107,7 @@ fn optimize(
             )?;
 
             Ok(LogicalPlan::Projection(Projection::try_new_with_schema(
-                new_expr.pop().unwrap(),
+                pop_expr(&mut new_expr)?,
                 Arc::new(new_input),
                 schema.clone(),
                 alias.clone(),
@@ -139,10 +139,16 @@ fn optimize(
                 optimizer_config,
             )?;
 
-            Ok(LogicalPlan::Filter(Filter {
-                predicate: new_expr.pop().unwrap().pop().unwrap(),
-                input: Arc::new(new_input),
-            }))
+            if let Some(predicate) = pop_expr(&mut new_expr)?.pop() {
+                Ok(LogicalPlan::Filter(Filter {
+                    predicate,
+                    input: Arc::new(new_input),
+                }))
+            } else {
+                Err(DataFusionError::Internal(
+                    "Failed to pop predicate expr".to_string(),
+                ))
+            }
         }
         LogicalPlan::Window(Window {
             input,
@@ -161,7 +167,7 @@ fn optimize(
 
             Ok(LogicalPlan::Window(Window {
                 input: Arc::new(new_input),
-                window_expr: new_expr.pop().unwrap(),
+                window_expr: pop_expr(&mut new_expr)?,
                 schema: schema.clone(),
             }))
         }
@@ -182,8 +188,8 @@ fn optimize(
                 optimizer_config,
             )?;
             // note the reversed pop order.
-            let new_aggr_expr = new_expr.pop().unwrap();
-            let new_group_expr = new_expr.pop().unwrap();
+            let new_aggr_expr = pop_expr(&mut new_expr)?;
+            let new_group_expr = pop_expr(&mut new_expr)?;
 
             Ok(LogicalPlan::Aggregate(Aggregate::try_new(
                 Arc::new(new_input),
@@ -204,7 +210,7 @@ fn optimize(
             )?;
 
             Ok(LogicalPlan::Sort(Sort {
-                expr: new_expr.pop().unwrap(),
+                expr: pop_expr(&mut new_expr)?,
                 input: Arc::new(new_input),
             }))
         }
@@ -241,6 +247,12 @@ fn optimize(
     }
 }
 
+fn pop_expr(new_expr: &mut Vec<Vec<Expr>>) -> Result<Vec<Expr>> {
+    new_expr
+        .pop()
+        .ok_or_else(|| DataFusionError::Internal("Failed to pop expression".to_string()))
+}
+
 fn to_arrays(
     expr: &[Expr],
     input: &LogicalPlan,
@@ -268,12 +280,20 @@ fn build_project_plan(
     let mut fields_set = HashSet::new();
 
     for id in affected_id {
-        let (expr, _, data_type) = expr_set.get(&id).unwrap();
-        // todo: check `nullable`
-        let field = DFField::new(None, &id, data_type.clone(), true);
-        fields_set.insert(field.name().to_owned());
-        fields.push(field);
-        project_exprs.push(expr.clone().alias(&id));
+        match expr_set.get(&id) {
+            Some((expr, _, data_type)) => {
+                // todo: check `nullable`
+                let field = DFField::new(None, &id, data_type.clone(), true);
+                fields_set.insert(field.name().to_owned());
+                fields.push(field);
+                project_exprs.push(expr.clone().alias(&id));
+            }
+            _ => {
+                return Err(DataFusionError::Internal(
+                    "expr_set invalid state".to_string(),
+                ))
+            }
+        }
     }
 
     for field in input.schema().fields() {
@@ -651,13 +671,19 @@ impl ExprRewriter for CommonSubexprRewriter<'_> {
             self.curr_index += 1;
             return Ok(RewriteRecursion::Skip);
         }
-        let (_, counter, _) = self.expr_set.get(curr_id).unwrap();
-        if *counter > 1 {
-            self.affected_id.insert(curr_id.clone());
-            Ok(RewriteRecursion::Mutate)
-        } else {
-            self.curr_index += 1;
-            Ok(RewriteRecursion::Skip)
+        match self.expr_set.get(curr_id) {
+            Some((_, counter, _)) => {
+                if *counter > 1 {
+                    self.affected_id.insert(curr_id.clone());
+                    Ok(RewriteRecursion::Mutate)
+                } else {
+                    self.curr_index += 1;
+                    Ok(RewriteRecursion::Skip)
+                }
+            }
+            _ => Err(DataFusionError::Internal(
+                "expr_set invalid state".to_string(),
+            )),
         }
     }
 
@@ -670,9 +696,12 @@ impl ExprRewriter for CommonSubexprRewriter<'_> {
         let (series_number, id) = &self.id_array[self.curr_index];
         self.curr_index += 1;
        // Skip sub-node of a replaced tree, or without identifier, or is not repeated expr.
+        let expr_set_item = self.expr_set.get(id).ok_or_else(|| {
+            DataFusionError::Internal("expr_set invalid state".to_string())
+        })?;
         if *series_number < self.max_series_number
             || id.is_empty()
-            || self.expr_set.get(id).unwrap().1 <= 1
+            || expr_set_item.1 <= 1
         {
             return Ok(expr);
         }

Reply via email to