This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new e19c669855 Support User Defined Table Function (#8306)
e19c669855 is described below
commit e19c669855baa8b78ff86755803944d2ddf65536
Author: Tan Wei <[email protected]>
AuthorDate: Fri Dec 1 06:56:03 2023 +0800
Support User Defined Table Function (#8306)
* Support User Defined Table Function
Signed-off-by: veeupup <[email protected]>
* fix comments
Signed-off-by: veeupup <[email protected]>
* add udtf test
Signed-off-by: veeupup <[email protected]>
* add file header
* Simplify table function example, add some comments
* Simplify exprs
* make clippy happy
* Update datafusion/core/tests/user_defined/user_defined_table_functions.rs
---------
Signed-off-by: veeupup <[email protected]>
Co-authored-by: Andrew Lamb <[email protected]>
---
datafusion-examples/examples/simple_udtf.rs | 177 +++++++++++++++++
datafusion/core/src/datasource/function.rs | 56 ++++++
datafusion/core/src/datasource/mod.rs | 1 +
datafusion/core/src/execution/context/mod.rs | 30 ++-
datafusion/core/tests/user_defined/mod.rs | 3 +
.../user_defined/user_defined_table_functions.rs | 219 +++++++++++++++++++++
datafusion/sql/src/planner.rs | 9 +
datafusion/sql/src/relation/mod.rs | 76 +++++--
8 files changed, 550 insertions(+), 21 deletions(-)
diff --git a/datafusion-examples/examples/simple_udtf.rs
b/datafusion-examples/examples/simple_udtf.rs
new file mode 100644
index 0000000000..bce6337652
--- /dev/null
+++ b/datafusion-examples/examples/simple_udtf.rs
@@ -0,0 +1,177 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::csv::reader::Format;
+use arrow::csv::ReaderBuilder;
+use async_trait::async_trait;
+use datafusion::arrow::datatypes::SchemaRef;
+use datafusion::arrow::record_batch::RecordBatch;
+use datafusion::datasource::function::TableFunctionImpl;
+use datafusion::datasource::TableProvider;
+use datafusion::error::Result;
+use datafusion::execution::context::{ExecutionProps, SessionState};
+use datafusion::physical_plan::memory::MemoryExec;
+use datafusion::physical_plan::ExecutionPlan;
+use datafusion::prelude::SessionContext;
+use datafusion_common::{plan_err, DataFusionError, ScalarValue};
+use datafusion_expr::{Expr, TableType};
+use datafusion_optimizer::simplify_expressions::{ExprSimplifier,
SimplifyContext};
+use std::fs::File;
+use std::io::Seek;
+use std::path::Path;
+use std::sync::Arc;
+
+// To define your own table function, you only need to do the following 3 things:
+// 1. Implement your own [`TableProvider`]
+// 2. Implement your own [`TableFunctionImpl`] and return your [`TableProvider`]
+// 3. Register the function using [`SessionContext::register_udtf`]
+
+/// This example demonstrates how to register a TableFunction
+#[tokio::main]
+async fn main() -> Result<()> {
+ // create local execution context
+ let ctx = SessionContext::new();
+
+ // register the table function that will be called in SQL statements by
`read_csv`
+ ctx.register_udtf("read_csv", Arc::new(LocalCsvTableFunc {}));
+
+ let testdata = datafusion::test_util::arrow_test_data();
+ let csv_file = format!("{testdata}/csv/aggregate_test_100.csv");
+
+ // Pass 2 arguments, read csv with at most 2 rows (simplify logic makes
1+1 --> 2)
+ let df = ctx
+ .sql(format!("SELECT * FROM read_csv('{csv_file}', 1 + 1);").as_str())
+ .await?;
+ df.show().await?;
+
+ // just run, return all rows
+ let df = ctx
+ .sql(format!("SELECT * FROM read_csv('{csv_file}');").as_str())
+ .await?;
+ df.show().await?;
+
+ Ok(())
+}
+
+/// In-memory [`TableProvider`] produced by the `read_csv` table function
+/// (see [`LocalCsvTableFunc`]).
+///
+/// The table function mimics the [`read_csv`] function in DuckDB.
+///
+/// Usage: `read_csv(filename, [limit])`
+///
+/// [`read_csv`]: https://duckdb.org/docs/data/csv/overview.html
+struct LocalCsvTable {
+    /// Schema inferred from the CSV file.
+    schema: SchemaRef,
+    /// Maximum number of rows returned by `scan`; `None` returns every row.
+    limit: Option<usize>,
+    /// All record batches read from the CSV file.
+    batches: Vec<RecordBatch>,
+}
+
+#[async_trait]
+impl TableProvider for LocalCsvTable {
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
+
+    fn schema(&self) -> SchemaRef {
+        self.schema.clone()
+    }
+
+    fn table_type(&self) -> TableType {
+        TableType::Base
+    }
+
+    /// Returns a [`MemoryExec`] over the cached batches, truncated to at most
+    /// `self.limit` rows when a limit was passed to `read_csv`.
+    ///
+    /// The `filters` and `limit` hints are ignored here.
+    async fn scan(
+        &self,
+        _state: &SessionState,
+        projection: Option<&Vec<usize>>,
+        _filters: &[Expr],
+        _limit: Option<usize>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        let batches = match self.limit {
+            Some(max_rows) => {
+                let mut remaining = max_rows;
+                let mut batches = Vec::new();
+                for batch in &self.batches {
+                    // Stop as soon as the limit is met, so no empty trailing
+                    // batch is emitted when it falls on a batch boundary.
+                    if remaining == 0 {
+                        break;
+                    }
+                    let take = batch.num_rows().min(remaining);
+                    batches.push(batch.slice(0, take));
+                    remaining -= take;
+                }
+                batches
+            }
+            None => self.batches.clone(),
+        };
+        Ok(Arc::new(MemoryExec::try_new(
+            &[batches],
+            TableProvider::schema(self),
+            projection.cloned(),
+        )?))
+    }
+}
+/// Implementation of the `read_csv(filename, [limit])` table function.
+struct LocalCsvTableFunc {}
+
+impl TableFunctionImpl for LocalCsvTableFunc {
+    /// Builds a [`LocalCsvTable`] from the call's arguments.
+    ///
+    /// * `exprs[0]` must be a string literal: the path of the CSV file.
+    /// * `exprs[1]` (optional) is the row limit. It is simplified first, so a
+    ///   constant expression such as `1 + 1` is accepted. It must end up as a
+    ///   non-negative `Int64` literal.
+    ///
+    /// # Errors
+    ///
+    /// Returns a planning error in three cases: the path is missing or not a
+    /// string, the limit is not an integer, or the limit is negative.
+    /// Simplification and file-reading errors are propagated.
+    fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
+        let Some(Expr::Literal(ScalarValue::Utf8(Some(path)))) = exprs.first() else {
+            return plan_err!("read_csv requires at least one string argument");
+        };
+
+        let limit = exprs
+            .get(1)
+            .map(|expr| {
+                // Try to simplify the expression, so `1 + 2` becomes `3`, for example.
+                let execution_props = ExecutionProps::new();
+                let info = SimplifyContext::new(&execution_props);
+                let expr = ExprSimplifier::new(info).simplify(expr.clone())?;
+
+                let Expr::Literal(ScalarValue::Int64(Some(limit))) = expr else {
+                    return plan_err!("Limit must be an integer");
+                };
+                // `limit as usize` would wrap a negative value into a huge
+                // limit and silently return every row; reject it instead.
+                usize::try_from(limit)
+                    .or_else(|_| plan_err!("Limit must be non-negative, got {limit}"))
+            })
+            .transpose()?;
+
+        let (schema, batches) = read_csv_batches(path)?;
+
+        let table = LocalCsvTable {
+            schema,
+            limit,
+            batches,
+        };
+        Ok(Arc::new(table))
+    }
+}
+
+/// Reads every record batch from the CSV file at `csv_path`.
+///
+/// The first line is treated as a header row. It names the columns during
+/// schema inference and is skipped when reading data.
+///
+/// Returns the inferred schema together with the batches.
+fn read_csv_batches(csv_path: impl AsRef<Path>) -> Result<(SchemaRef, Vec<RecordBatch>)> {
+    let mut file = File::open(csv_path)?;
+    // Inference must agree with the reader below (`with_header(true)`).
+    // Otherwise the header row is treated as a data row while inferring.
+    let (schema, _) = Format::default()
+        .with_header(true)
+        .infer_schema(&mut file, None)?;
+    // Inference consumed the file; rewind so the reader starts at the top.
+    file.rewind()?;
+
+    let schema = Arc::new(schema);
+    let reader = ReaderBuilder::new(Arc::clone(&schema))
+        .with_header(true)
+        .build(file)?;
+    let batches = reader.collect::<std::result::Result<Vec<_>, _>>()?;
+    Ok((schema, batches))
+}
diff --git a/datafusion/core/src/datasource/function.rs
b/datafusion/core/src/datasource/function.rs
new file mode 100644
index 0000000000..2fd352ee4e
--- /dev/null
+++ b/datafusion/core/src/datasource/function.rs
@@ -0,0 +1,56 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! A table that uses a function to generate data
+
+use super::TableProvider;
+
+use datafusion_common::Result;
+use datafusion_expr::Expr;
+
+use std::sync::Arc;
+
+/// A user-defined table function (UDTF).
+///
+/// An implementation turns the arguments of a call such as
+/// `SELECT * FROM my_func(1, 'a')` into a [`TableProvider`] that supplies the
+/// rows. Implementations are registered via `SessionContext::register_udtf`.
+pub trait TableFunctionImpl: Send + Sync {
+    /// Builds the [`TableProvider`] for one invocation, given the call's
+    /// arguments as logical expressions.
+    fn call(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>>;
+}
+
+/// A named table function: the name used in SQL paired with the
+/// [`TableFunctionImpl`] that produces its table.
+pub struct TableFunction {
+    /// Name under which the function is invoked
+    name: String,
+    /// Implementation invoked to create the table provider
+    fun: Arc<dyn TableFunctionImpl>,
+}
+
+impl TableFunction {
+    /// Creates a table function called `name` that is backed by `fun`.
+    pub fn new(name: String, fun: Arc<dyn TableFunctionImpl>) -> Self {
+        TableFunction { name, fun }
+    }
+
+    /// Returns the name of this table function.
+    pub fn name(&self) -> &str {
+        self.name.as_str()
+    }
+
+    /// Invokes the underlying implementation with `args` and returns the
+    /// [`TableProvider`] (or error) it produces.
+    pub fn create_table_provider(
+        &self,
+        args: &[Expr],
+    ) -> Result<Arc<dyn TableProvider>> {
+        let implementation = &self.fun;
+        implementation.call(args)
+    }
+}
diff --git a/datafusion/core/src/datasource/mod.rs
b/datafusion/core/src/datasource/mod.rs
index 45f9bee6a5..2e516cc36a 100644
--- a/datafusion/core/src/datasource/mod.rs
+++ b/datafusion/core/src/datasource/mod.rs
@@ -23,6 +23,7 @@ pub mod avro_to_arrow;
pub mod default_table_source;
pub mod empty;
pub mod file_format;
+pub mod function;
pub mod listing;
pub mod listing_table_factory;
pub mod memory;
diff --git a/datafusion/core/src/execution/context/mod.rs
b/datafusion/core/src/execution/context/mod.rs
index dbebedce3c..58a4f08341 100644
--- a/datafusion/core/src/execution/context/mod.rs
+++ b/datafusion/core/src/execution/context/mod.rs
@@ -26,6 +26,7 @@ mod parquet;
use crate::{
catalog::{CatalogList, MemoryCatalogList},
datasource::{
+ function::{TableFunction, TableFunctionImpl},
listing::{ListingOptions, ListingTable},
provider::TableProviderFactory,
},
@@ -42,7 +43,7 @@ use datafusion_common::{
use datafusion_execution::registry::SerializerRegistry;
use datafusion_expr::{
logical_plan::{DdlStatement, Statement},
- StringifiedPlan, UserDefinedLogicalNode, WindowUDF,
+ Expr, StringifiedPlan, UserDefinedLogicalNode, WindowUDF,
};
pub use datafusion_physical_expr::execution_props::ExecutionProps;
use datafusion_physical_expr::var_provider::is_system_variables;
@@ -803,6 +804,14 @@ impl SessionContext {
.add_var_provider(variable_type, provider);
}
+    /// Registers a user-defined table function (UDTF) under `name`.
+    ///
+    /// Once registered it can be called from the FROM clause of SQL queries,
+    /// e.g. `SELECT * FROM name(...)`. A table function previously registered
+    /// under the same name is replaced.
+    pub fn register_udtf(&self, name: &str, fun: Arc<dyn TableFunctionImpl>) {
+        let table_function = Arc::new(TableFunction::new(name.to_owned(), fun));
+        let mut state = self.state.write();
+        state.table_functions.insert(name.to_owned(), table_function);
+    }
+
/// Registers a scalar UDF within this context.
///
/// Note in SQL queries, function names are looked up using
@@ -1241,6 +1250,8 @@ pub struct SessionState {
query_planner: Arc<dyn QueryPlanner + Send + Sync>,
/// Collection of catalogs containing schemas and ultimately TableProviders
catalog_list: Arc<dyn CatalogList>,
+ /// Table Functions
+ table_functions: HashMap<String, Arc<TableFunction>>,
/// Scalar functions that are registered with the context
scalar_functions: HashMap<String, Arc<ScalarUDF>>,
/// Aggregate functions registered in the context
@@ -1339,6 +1350,7 @@ impl SessionState {
physical_optimizers: PhysicalOptimizer::new(),
query_planner: Arc::new(DefaultQueryPlanner {}),
catalog_list,
+ table_functions: HashMap::new(),
scalar_functions: HashMap::new(),
aggregate_functions: HashMap::new(),
window_functions: HashMap::new(),
@@ -1877,6 +1889,22 @@ impl<'a> ContextProvider for SessionContextProvider<'a> {
.ok_or_else(|| plan_datafusion_err!("table '{name}' not found"))
}
+    /// Looks up the table function registered as `name` and asks it to build
+    /// a table provider from `args`. The provider is returned as a
+    /// [`TableSource`].
+    fn get_table_function_source(
+        &self,
+        name: &str,
+        args: Vec<Expr>,
+    ) -> Result<Arc<dyn TableSource>> {
+        self.state
+            .table_functions
+            .get(name)
+            .ok_or_else(|| plan_datafusion_err!("table function '{name}' not found"))
+            .and_then(|tbl_func| tbl_func.create_table_provider(&args))
+            .map(provider_as_source)
+    }
+
fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
self.state.scalar_functions().get(name).cloned()
}
diff --git a/datafusion/core/tests/user_defined/mod.rs
b/datafusion/core/tests/user_defined/mod.rs
index 09c7c3d326..6c6d966cc3 100644
--- a/datafusion/core/tests/user_defined/mod.rs
+++ b/datafusion/core/tests/user_defined/mod.rs
@@ -26,3 +26,6 @@ mod user_defined_plan;
/// Tests for User Defined Window Functions
mod user_defined_window_functions;
+
+/// Tests for User Defined Table Functions
+mod user_defined_table_functions;
diff --git a/datafusion/core/tests/user_defined/user_defined_table_functions.rs
b/datafusion/core/tests/user_defined/user_defined_table_functions.rs
new file mode 100644
index 0000000000..b5d10b1c5b
--- /dev/null
+++ b/datafusion/core/tests/user_defined/user_defined_table_functions.rs
@@ -0,0 +1,219 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::Int64Array;
+use arrow::csv::reader::Format;
+use arrow::csv::ReaderBuilder;
+use async_trait::async_trait;
+use datafusion::arrow::datatypes::SchemaRef;
+use datafusion::arrow::record_batch::RecordBatch;
+use datafusion::datasource::function::TableFunctionImpl;
+use datafusion::datasource::TableProvider;
+use datafusion::error::Result;
+use datafusion::execution::context::SessionState;
+use datafusion::execution::TaskContext;
+use datafusion::physical_plan::memory::MemoryExec;
+use datafusion::physical_plan::{collect, ExecutionPlan};
+use datafusion::prelude::SessionContext;
+use datafusion_common::{assert_batches_eq, DFSchema, ScalarValue};
+use datafusion_expr::{EmptyRelation, Expr, LogicalPlan, Projection, TableType};
+use std::fs::File;
+use std::io::Seek;
+use std::path::Path;
+use std::sync::Arc;
+
+/// End-to-end test of a user-defined table function. It registers the test
+/// `read_csv` UDTF and queries `tests/tpch-csv/nation.csv` through it twice:
+/// first with a row-limit argument of 5, then with no limit argument.
+#[tokio::test]
+async fn test_simple_read_csv_udtf() -> Result<()> {
+ let ctx = SessionContext::new();
+
+ ctx.register_udtf("read_csv", Arc::new(SimpleCsvTableFunc {}));
+
+ let csv_file = "tests/tpch-csv/nation.csv";
+ // read csv with at most 5 rows
+ let rbs = ctx
+ .sql(format!("SELECT * FROM read_csv('{csv_file}', 5);").as_str())
+ .await?
+ .collect()
+ .await?;
+
+ let excepted = [
+
"+-------------+-----------+-------------+-------------------------------------------------------------------------------------------------------------+",
+ "| n_nationkey | n_name | n_regionkey | n_comment
|",
+
"+-------------+-----------+-------------+-------------------------------------------------------------------------------------------------------------+",
+ "| 1 | ARGENTINA | 1 | al foxes promise slyly
according to the regular accounts. bold requests alon
|",
+ "| 2 | BRAZIL | 1 | y alongside of the pending
deposits. carefully special packages are about the ironic forges. slyly special
|",
+ "| 3 | CANADA | 1 | eas hang ironic, silent
packages. slyly regular packages are furiously over the tithes. fluffily bold
|",
+ "| 4 | EGYPT | 4 | y above the carefully
unusual theodolites. final dugouts are quickly across the furiously regular d
|",
+ "| 5 | ETHIOPIA | 0 | ven packages wake quickly.
regu
|",
+
"+-------------+-----------+-------------+-------------------------------------------------------------------------------------------------------------+",
];
+ assert_batches_eq!(excepted, &rbs);
+
+ // just run, return all rows
+ let rbs = ctx
+ .sql(format!("SELECT * FROM read_csv('{csv_file}');").as_str())
+ .await?
+ .collect()
+ .await?;
+ let excepted = [
+
"+-------------+-----------+-------------+--------------------------------------------------------------------------------------------------------------------+",
+ "| n_nationkey | n_name | n_regionkey | n_comment
|",
+
"+-------------+-----------+-------------+--------------------------------------------------------------------------------------------------------------------+",
+ "| 1 | ARGENTINA | 1 | al foxes promise slyly
according to the regular accounts. bold requests alon
|",
+ "| 2 | BRAZIL | 1 | y alongside of the pending
deposits. carefully special packages are about the ironic forges. slyly special
|",
+ "| 3 | CANADA | 1 | eas hang ironic, silent
packages. slyly regular packages are furiously over the tithes. fluffily bold
|",
+ "| 4 | EGYPT | 4 | y above the carefully
unusual theodolites. final dugouts are quickly across the furiously regular d
|",
+ "| 5 | ETHIOPIA | 0 | ven packages wake quickly.
regu
|",
+ "| 6 | FRANCE | 3 | refully final requests.
regular, ironi
|",
+ "| 7 | GERMANY | 3 | l platelets. regular
accounts x-ray: unusual, regular acco
|",
+ "| 8 | INDIA | 2 | ss excuses cajole slyly
across the packages. deposits print aroun
|",
+ "| 9 | INDONESIA | 2 | slyly express asymptotes.
regular deposits haggle slyly. carefully ironic hockey players sleep blithely.
carefull |",
+ "| 10 | IRAN | 4 | efully alongside of the
slyly final dependencies.
|",
+
"+-------------+-----------+-------------+--------------------------------------------------------------------------------------------------------------------+"
+ ];
+ assert_batches_eq!(excepted, &rbs);
+
+ Ok(())
+}
+
+/// In-memory table produced by the test `read_csv` table function.
+struct SimpleCsvTable {
+    /// Schema inferred from the CSV file.
+    schema: SchemaRef,
+    /// Non-path arguments of the call. The first one, if present, is
+    /// evaluated at scan time as the maximum number of rows to return.
+    exprs: Vec<Expr>,
+    /// All record batches read from the CSV file.
+    batches: Vec<RecordBatch>,
+}
+
+#[async_trait]
+impl TableProvider for SimpleCsvTable {
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
+
+    fn schema(&self) -> SchemaRef {
+        self.schema.clone()
+    }
+
+    fn table_type(&self) -> TableType {
+        TableType::Base
+    }
+
+    /// Scans the cached batches. If the table function received a limit
+    /// argument, it is evaluated here and at most that many rows are returned.
+    async fn scan(
+        &self,
+        state: &SessionState,
+        projection: Option<&Vec<usize>>,
+        _filters: &[Expr],
+        _limit: Option<usize>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        let batches = if self.exprs.is_empty() {
+            self.batches.clone()
+        } else {
+            let limit = self.interpreter_expr(state).await?;
+            // `limit as usize` would wrap a negative limit into a huge one and
+            // return every row; reject it explicitly instead.
+            let max_rows = usize::try_from(limit).map_err(|_| {
+                datafusion::error::DataFusionError::Plan(format!(
+                    "Limit must be non-negative, got {limit}"
+                ))
+            })?;
+            let mut remaining = max_rows;
+            let mut batches = Vec::new();
+            for batch in &self.batches {
+                // Stop once the limit is met so no empty trailing batch is emitted.
+                if remaining == 0 {
+                    break;
+                }
+                let take = batch.num_rows().min(remaining);
+                batches.push(batch.slice(0, take));
+                remaining -= take;
+            }
+            batches
+        };
+        Ok(Arc::new(MemoryExec::try_new(
+            &[batches],
+            TableProvider::schema(self),
+            projection.cloned(),
+        )?))
+    }
+}
+
+impl SimpleCsvTable {
+    /// Evaluates the first stored argument expression and returns it as an
+    /// `i64`.
+    ///
+    /// The expression is projected over a one-row empty relation, then planned
+    /// and executed with `state`.
+    ///
+    /// # Panics
+    ///
+    /// Panics in three cases: `self.exprs` is empty, execution yields no
+    /// batches, or the first result column is not an `Int64Array`.
+    async fn interpreter_expr(&self, state: &SessionState) -> Result<i64> {
+        use datafusion::logical_expr::expr_rewriter::normalize_col;
+        use datafusion::logical_expr::utils::columnize_expr;
+
+        // One row and no columns: the input for a constant projection.
+        let input = LogicalPlan::EmptyRelation(EmptyRelation {
+            produce_one_row: true,
+            schema: Arc::new(DFSchema::empty()),
+        });
+        let normalized = normalize_col(self.exprs[0].clone(), &input)?;
+        let projected = columnize_expr(normalized, input.schema());
+        let projection = Projection::try_new(vec![projected], Arc::new(input))?;
+        let logical_plan = LogicalPlan::Projection(projection);
+
+        let physical_plan = state.create_physical_plan(&logical_plan).await?;
+        let task_ctx = Arc::new(TaskContext::from(state));
+        let batches = collect(physical_plan, task_ctx).await?;
+
+        let values = batches[0]
+            .column(0)
+            .as_any()
+            .downcast_ref::<Int64Array>()
+            .unwrap();
+        Ok(values.value(0))
+    }
+}
+
+/// Test implementation of a `read_csv` table function.
+struct SimpleCsvTableFunc {}
+
+impl TableFunctionImpl for SimpleCsvTableFunc {
+    /// Splits the arguments into the CSV path and everything else, then loads
+    /// the file into a [`SimpleCsvTable`] that keeps the remaining arguments.
+    ///
+    /// Any non-null `Utf8` literal is taken as the path, and the last one wins.
+    /// If there is none, the path is the empty string.
+    fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
+        let (paths, other_exprs): (Vec<&Expr>, Vec<&Expr>) = exprs
+            .iter()
+            .partition(|expr| matches!(expr, Expr::Literal(ScalarValue::Utf8(Some(_)))));
+        let filepath = match paths.last() {
+            Some(Expr::Literal(ScalarValue::Utf8(Some(path)))) => path.clone(),
+            _ => String::new(),
+        };
+        let (schema, batches) = read_csv_batches(filepath)?;
+        Ok(Arc::new(SimpleCsvTable {
+            schema,
+            exprs: other_exprs.into_iter().cloned().collect(),
+            batches,
+        }))
+    }
+}
+
+/// Loads a CSV file that has a header row into memory. Returns the inferred
+/// schema together with all of the file's record batches.
+fn read_csv_batches(csv_path: impl AsRef<Path>) -> Result<(SchemaRef, Vec<RecordBatch>)> {
+    let mut file = File::open(csv_path)?;
+    let format = Format::default().with_header(true);
+    let (inferred, _) = format.infer_schema(&mut file, None)?;
+    // Schema inference read from the file; go back to the start for the reader.
+    file.rewind()?;
+
+    let schema: SchemaRef = Arc::new(inferred);
+    let batches = ReaderBuilder::new(schema.clone())
+        .with_header(true)
+        .build(file)?
+        .collect::<std::result::Result<Vec<_>, _>>()?;
+    Ok((schema, batches))
+}
diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs
index 622e5aca79..c5c30e3a22 100644
--- a/datafusion/sql/src/planner.rs
+++ b/datafusion/sql/src/planner.rs
@@ -52,6 +52,15 @@ pub trait ContextProvider {
}
/// Getter for a datasource
fn get_table_source(&self, name: TableReference) -> Result<Arc<dyn
TableSource>>;
+    /// Getter for a table function: resolves a call to the table function
+    /// `_name` with arguments `_args` into a [`TableSource`].
+    ///
+    /// The default implementation does not support table functions and always
+    /// returns a "not implemented" error. Override it to enable them.
+    fn get_table_function_source(
+        &self,
+        _name: &str,
+        _args: Vec<Expr>,
+    ) -> Result<Arc<dyn TableSource>> {
+        not_impl_err!("Table Functions are not supported")
+    }
+
/// Getter for a UDF description
fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>>;
/// Getter for a UDAF description
diff --git a/datafusion/sql/src/relation/mod.rs
b/datafusion/sql/src/relation/mod.rs
index 180743d19b..6fc7e96012 100644
--- a/datafusion/sql/src/relation/mod.rs
+++ b/datafusion/sql/src/relation/mod.rs
@@ -16,9 +16,11 @@
// under the License.
use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
-use datafusion_common::{not_impl_err, DataFusionError, Result};
+use datafusion_common::{
+ not_impl_err, plan_err, DFSchema, DataFusionError, Result, TableReference,
+};
use datafusion_expr::{LogicalPlan, LogicalPlanBuilder};
-use sqlparser::ast::TableFactor;
+use sqlparser::ast::{FunctionArg, FunctionArgExpr, TableFactor};
mod join;
@@ -30,24 +32,58 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
planner_context: &mut PlannerContext,
) -> Result<LogicalPlan> {
let (plan, alias) = match relation {
- TableFactor::Table { name, alias, .. } => {
- // normalize name and alias
- let table_ref = self.object_name_to_table_reference(name)?;
- let table_name = table_ref.to_string();
- let cte = planner_context.get_cte(&table_name);
- (
- match (
- cte,
-
self.context_provider.get_table_source(table_ref.clone()),
- ) {
- (Some(cte_plan), _) => Ok(cte_plan.clone()),
- (_, Ok(provider)) => {
- LogicalPlanBuilder::scan(table_ref, provider,
None)?.build()
- }
- (None, Err(e)) => Err(e),
- }?,
- alias,
- )
+            TableFactor::Table {
+                name, alias, args, ..
+            } => {
+                if let Some(func_args) = args {
+                    // `FROM name(args...)`: plan a call to a table function.
+                    // NOTE(review): only the first identifier of `name` is used,
+                    // so a qualified name like `s.f(...)` looks up `s`.
+                    let Some(func_ident) = name.0.first() else {
+                        return plan_err!("Table function name must not be empty");
+                    };
+                    let tbl_func_name = func_ident.value.clone();
+                    // Collect into a `Result` so that an unsupported or invalid
+                    // argument aborts planning. `flat_map` over a `Result` would
+                    // silently drop every `Err` and call the function with fewer
+                    // arguments than the query supplied.
+                    let args = func_args
+                        .into_iter()
+                        .map(|arg| {
+                            if let FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) =
+                                arg
+                            {
+                                self.sql_expr_to_logical_expr(
+                                    expr,
+                                    &DFSchema::empty(),
+                                    planner_context,
+                                )
+                            } else {
+                                plan_err!("Unsupported function argument type: {:?}", arg)
+                            }
+                        })
+                        .collect::<Result<Vec<_>>>()?;
+                    let provider = self
+                        .context_provider
+                        .get_table_function_source(&tbl_func_name, args)?;
+                    // The scan is registered under a fixed placeholder name.
+                    let plan = LogicalPlanBuilder::scan(
+                        TableReference::Bare {
+                            table: std::borrow::Cow::Borrowed("tmp_table"),
+                        },
+                        provider,
+                        None,
+                    )?
+                    .build()?;
+                    (plan, alias)
+                } else {
+                    // normalize name and alias
+                    let table_ref = self.object_name_to_table_reference(name)?;
+                    let table_name = table_ref.to_string();
+                    let cte = planner_context.get_cte(&table_name);
+                    (
+                        match (
+                            cte,
+                            self.context_provider.get_table_source(table_ref.clone()),
+                        ) {
+                            (Some(cte_plan), _) => Ok(cte_plan.clone()),
+                            (_, Ok(provider)) => {
+                                LogicalPlanBuilder::scan(table_ref, provider, None)?
+                                    .build()
+                            }
+                            (None, Err(e)) => Err(e),
+                        }?,
+                        alias,
+                    )
+                }
+            }
TableFactor::Derived {
subquery, alias, ..