This is an automated email from the ASF dual-hosted git repository.

jayzhan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 18193e6224 chore: Add SessionState to MockContextProvider just like SessionContextProvider (#11940)
18193e6224 is described below

commit 18193e6224603c92ce1ab16136ffcd926ca267b5
Author: Dharan Aditya <dharan.adi...@gmail.com>
AuthorDate: Mon Aug 12 19:56:47 2024 +0530

    chore: Add SessionState to MockContextProvider just like SessionContextProvider (#11940)
    
    * refactor: mock context provider to match public API
    
    * lower udaf names
    
    * cleanup
    
    * typos
    
    Co-authored-by: Jay Zhan <jayzhan...@gmail.com>
    
    * more typos
    
    Co-authored-by: Jay Zhan <jayzhan...@gmail.com>
    
    * typos
    
    * refactor func name
    
    ---------
    
    Co-authored-by: Jay Zhan <jayzhan...@gmail.com>
---
 datafusion/sql/tests/cases/plan_to_sql.rs | 40 +++++++++++++++---------
 datafusion/sql/tests/common/mod.rs        | 52 +++++++++++++++++--------------
 datafusion/sql/tests/sql_integration.rs   | 52 ++++++++++++++++++-------------
 3 files changed, 83 insertions(+), 61 deletions(-)

diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs
index 179fc108e6..ed23fada0c 100644
--- a/datafusion/sql/tests/cases/plan_to_sql.rs
+++ b/datafusion/sql/tests/cases/plan_to_sql.rs
@@ -33,7 +33,7 @@ use datafusion_functions::core::planner::CoreFunctionPlanner;
 use sqlparser::dialect::{Dialect, GenericDialect, MySqlDialect};
 use sqlparser::parser::Parser;
 
-use crate::common::MockContextProvider;
+use crate::common::{MockContextProvider, MockSessionState};
 
 #[test]
 fn roundtrip_expr() {
@@ -59,8 +59,8 @@ fn roundtrip_expr() {
     let roundtrip = |table, sql: &str| -> Result<String> {
         let dialect = GenericDialect {};
         let sql_expr = Parser::new(&dialect).try_with_sql(sql)?.parse_expr()?;
-
-        let context = MockContextProvider::default().with_udaf(sum_udaf());
+        let state = MockSessionState::default().with_aggregate_function(sum_udaf());
+        let context = MockContextProvider { state };
         let schema = context.get_table_source(table)?.schema();
         let df_schema = DFSchema::try_from(schema.as_ref().clone())?;
         let sql_to_rel = SqlToRel::new(&context);
@@ -156,11 +156,11 @@ fn roundtrip_statement() -> Result<()> {
         let statement = Parser::new(&dialect)
             .try_with_sql(query)?
             .parse_statement()?;
-
-        let context = MockContextProvider::default()
-            .with_udaf(sum_udaf())
-            .with_udaf(count_udaf())
+        let state = MockSessionState::default()
+            .with_aggregate_function(sum_udaf())
+            .with_aggregate_function(count_udaf())
             .with_expr_planner(Arc::new(CoreFunctionPlanner::default()));
+        let context = MockContextProvider { state };
         let sql_to_rel = SqlToRel::new(&context);
         let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap();
 
@@ -189,8 +189,10 @@ fn roundtrip_crossjoin() -> Result<()> {
         .try_with_sql(query)?
         .parse_statement()?;
 
-    let context = MockContextProvider::default()
+    let state = MockSessionState::default()
         .with_expr_planner(Arc::new(CoreFunctionPlanner::default()));
+
+    let context = MockContextProvider { state };
     let sql_to_rel = SqlToRel::new(&context);
     let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap();
 
@@ -412,10 +414,12 @@ fn roundtrip_statement_with_dialect() -> Result<()> {
             .try_with_sql(query.sql)?
             .parse_statement()?;
 
-        let context = MockContextProvider::default()
-            .with_expr_planner(Arc::new(CoreFunctionPlanner::default()))
-            .with_udaf(max_udaf())
-            .with_udaf(min_udaf());
+        let state = MockSessionState::default()
+            .with_aggregate_function(max_udaf())
+            .with_aggregate_function(min_udaf())
+            .with_expr_planner(Arc::new(CoreFunctionPlanner::default()));
+
+        let context = MockContextProvider { state };
         let sql_to_rel = SqlToRel::new(&context);
         let plan = sql_to_rel
             .sql_statement_to_plan(statement)
@@ -443,7 +447,9 @@ fn test_unnest_logical_plan() -> Result<()> {
         .try_with_sql(query)?
         .parse_statement()?;
 
-    let context = MockContextProvider::default();
+    let context = MockContextProvider {
+        state: MockSessionState::default(),
+    };
     let sql_to_rel = SqlToRel::new(&context);
     let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap();
 
@@ -516,7 +522,9 @@ fn test_pretty_roundtrip() -> Result<()> {
 
     let df_schema = DFSchema::try_from(schema)?;
 
-    let context = MockContextProvider::default();
+    let context = MockContextProvider {
+        state: MockSessionState::default(),
+    };
     let sql_to_rel = SqlToRel::new(&context);
 
     let unparser = Unparser::default().with_pretty(true);
@@ -589,7 +597,9 @@ fn sql_round_trip(query: &str, expect: &str) {
         .parse_statement()
         .unwrap();
 
-    let context = MockContextProvider::default();
+    let context = MockContextProvider {
+        state: MockSessionState::default(),
+    };
     let sql_to_rel = SqlToRel::new(&context);
     let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap();
 
diff --git a/datafusion/sql/tests/common/mod.rs b/datafusion/sql/tests/common/mod.rs
index 374aa9db67..fe0e5f7283 100644
--- a/datafusion/sql/tests/common/mod.rs
+++ b/datafusion/sql/tests/common/mod.rs
@@ -50,36 +50,40 @@ impl Display for MockCsvType {
 }
 
 #[derive(Default)]
-pub(crate) struct MockContextProvider {
-    options: ConfigOptions,
-    udfs: HashMap<String, Arc<ScalarUDF>>,
-    udafs: HashMap<String, Arc<AggregateUDF>>,
+pub(crate) struct MockSessionState {
+    scalar_functions: HashMap<String, Arc<ScalarUDF>>,
+    aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
     expr_planners: Vec<Arc<dyn ExprPlanner>>,
+    pub config_options: ConfigOptions,
 }
 
-impl MockContextProvider {
-    // Suppressing dead code warning, as this is used in integration test crates
-    #[allow(dead_code)]
-    pub(crate) fn options_mut(&mut self) -> &mut ConfigOptions {
-        &mut self.options
+impl MockSessionState {
+    pub fn with_expr_planner(mut self, expr_planner: Arc<dyn ExprPlanner>) -> Self {
+        self.expr_planners.push(expr_planner);
+        self
     }
 
-    #[allow(dead_code)]
-    pub(crate) fn with_udf(mut self, udf: ScalarUDF) -> Self {
-        self.udfs.insert(udf.name().to_string(), Arc::new(udf));
+    pub fn with_scalar_function(mut self, scalar_function: Arc<ScalarUDF>) -> Self {
+        self.scalar_functions
+            .insert(scalar_function.name().to_string(), scalar_function);
         self
     }
 
-    pub(crate) fn with_udaf(mut self, udaf: Arc<AggregateUDF>) -> Self {
+    pub fn with_aggregate_function(
+        mut self,
+        aggregate_function: Arc<AggregateUDF>,
+    ) -> Self {
         // TODO: change to to_string() if all the function name is converted to lowercase
-        self.udafs.insert(udaf.name().to_lowercase(), udaf);
+        self.aggregate_functions.insert(
+            aggregate_function.name().to_string().to_lowercase(),
+            aggregate_function,
+        );
         self
     }
+}
 
-    pub(crate) fn with_expr_planner(mut self, planner: Arc<dyn ExprPlanner>) -> Self {
-        self.expr_planners.push(planner);
-        self
-    }
+pub(crate) struct MockContextProvider {
+    pub(crate) state: MockSessionState,
 }
 
 impl ContextProvider for MockContextProvider {
@@ -202,11 +206,11 @@ impl ContextProvider for MockContextProvider {
     }
 
     fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
-        self.udfs.get(name).cloned()
+        self.state.scalar_functions.get(name).cloned()
     }
 
     fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
-        self.udafs.get(name).cloned()
+        self.state.aggregate_functions.get(name).cloned()
     }
 
     fn get_variable_type(&self, _: &[String]) -> Option<DataType> {
@@ -218,7 +222,7 @@ impl ContextProvider for MockContextProvider {
     }
 
     fn options(&self) -> &ConfigOptions {
-        &self.options
+        &self.state.config_options
     }
 
     fn get_file_type(
@@ -237,11 +241,11 @@ impl ContextProvider for MockContextProvider {
     }
 
     fn udf_names(&self) -> Vec<String> {
-        self.udfs.keys().cloned().collect()
+        self.state.scalar_functions.keys().cloned().collect()
     }
 
     fn udaf_names(&self) -> Vec<String> {
-        self.udafs.keys().cloned().collect()
+        self.state.aggregate_functions.keys().cloned().collect()
     }
 
     fn udwf_names(&self) -> Vec<String> {
@@ -249,7 +253,7 @@ impl ContextProvider for MockContextProvider {
     }
 
     fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {
-        &self.expr_planners
+        &self.state.expr_planners
     }
 }
 
diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs
index 4d7e608056..5a0317c47c 100644
--- a/datafusion/sql/tests/sql_integration.rs
+++ b/datafusion/sql/tests/sql_integration.rs
@@ -41,6 +41,7 @@ use datafusion_sql::{
     planner::{ParserOptions, SqlToRel},
 };
 
+use crate::common::MockSessionState;
 use datafusion_functions::core::planner::CoreFunctionPlanner;
 use datafusion_functions_aggregate::{
     approx_median::approx_median_udaf, count::count_udaf, min_max::max_udaf,
@@ -1495,8 +1496,9 @@ fn recursive_ctes_disabled() {
         select * from numbers;";
 
     // manually setting up test here so that we can disable recursive ctes
-    let mut context = MockContextProvider::default();
-    context.options_mut().execution.enable_recursive_ctes = false;
+    let mut state = MockSessionState::default();
+    state.config_options.execution.enable_recursive_ctes = false;
+    let context = MockContextProvider { state };
 
     let planner = SqlToRel::new_with_options(&context, ParserOptions::default());
     let result = DFParser::parse_sql_with_dialect(sql, &GenericDialect {});
@@ -2727,7 +2729,8 @@ fn logical_plan_with_options(sql: &str, options: ParserOptions) -> Result<Logica
 }
 
 fn logical_plan_with_dialect(sql: &str, dialect: &dyn Dialect) -> Result<LogicalPlan> {
-    let context = MockContextProvider::default().with_udaf(sum_udaf());
+    let state = MockSessionState::default().with_aggregate_function(sum_udaf());
+    let context = MockContextProvider { state };
     let planner = SqlToRel::new(&context);
     let result = DFParser::parse_sql_with_dialect(sql, dialect);
     let mut ast = result?;
@@ -2739,39 +2742,44 @@ fn logical_plan_with_dialect_and_options(
     dialect: &dyn Dialect,
     options: ParserOptions,
 ) -> Result<LogicalPlan> {
-    let context = MockContextProvider::default()
-        .with_udf(unicode::character_length().as_ref().clone())
-        .with_udf(string::concat().as_ref().clone())
-        .with_udf(make_udf(
+    let state = MockSessionState::default()
+        .with_scalar_function(Arc::new(unicode::character_length().as_ref().clone()))
+        .with_scalar_function(Arc::new(string::concat().as_ref().clone()))
+        .with_scalar_function(Arc::new(make_udf(
             "nullif",
             vec![DataType::Int32, DataType::Int32],
             DataType::Int32,
-        ))
-        .with_udf(make_udf(
+        )))
+        .with_scalar_function(Arc::new(make_udf(
             "round",
             vec![DataType::Float64, DataType::Int64],
             DataType::Float32,
-        ))
-        .with_udf(make_udf(
+        )))
+        .with_scalar_function(Arc::new(make_udf(
             "arrow_cast",
             vec![DataType::Int64, DataType::Utf8],
             DataType::Float64,
-        ))
-        .with_udf(make_udf(
+        )))
+        .with_scalar_function(Arc::new(make_udf(
             "date_trunc",
             vec![DataType::Utf8, DataType::Timestamp(Nanosecond, None)],
             DataType::Int32,
-        ))
-        .with_udf(make_udf("sqrt", vec![DataType::Int64], DataType::Int64))
-        .with_udaf(sum_udaf())
-        .with_udaf(approx_median_udaf())
-        .with_udaf(count_udaf())
-        .with_udaf(avg_udaf())
-        .with_udaf(min_udaf())
-        .with_udaf(max_udaf())
-        .with_udaf(grouping_udaf())
+        )))
+        .with_scalar_function(Arc::new(make_udf(
+            "sqrt",
+            vec![DataType::Int64],
+            DataType::Int64,
+        )))
+        .with_aggregate_function(sum_udaf())
+        .with_aggregate_function(approx_median_udaf())
+        .with_aggregate_function(count_udaf())
+        .with_aggregate_function(avg_udaf())
+        .with_aggregate_function(min_udaf())
+        .with_aggregate_function(max_udaf())
+        .with_aggregate_function(grouping_udaf())
         .with_expr_planner(Arc::new(CoreFunctionPlanner::default()));
 
+    let context = MockContextProvider { state };
     let planner = SqlToRel::new_with_options(&context, options);
     let result = DFParser::parse_sql_with_dialect(sql, dialect);
     let mut ast = result?;


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org
For additional commands, e-mail: commits-h...@datafusion.apache.org

Reply via email to