This is an automated email from the ASF dual-hosted git repository.
yjshen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 7158b4b Merge dataframe and dataframe imp (#1998)
7158b4b is described below
commit 7158b4b36ca66f4c1e711366445e15ddb0ead8da
Author: Kaushik <[email protected]>
AuthorDate: Mon Mar 14 23:33:47 2022 -0700
Merge dataframe and dataframe imp (#1998)
* merging Dataframe trait and DataFrameImpl into Dataframe struct
Change-Id: Id15520a81c272a7666f17d5f3b707445a4af7798
* reformatting
Change-Id: I6c8cab772c206c80cd5aa5b39296d8009ab4efdb
* new DataFrame type fixes
Change-Id: I01ebf2f52e21aaf9b24e4fb7ac54904dcd55e220
* fix build
Change-Id: Iefb549ecee1648a89e1165f93a8e2fa5e785df56
* re-formatting files
Change-Id: I2c96a7e24f1f74052eee50ea8199f381bdb553df
* revise dataframe.rs
Change-Id: I3aea55073f454d9fd096df0d5d9bf7fe5533bfaa
* fix import statements
Change-Id: I72f7bf086cf8a5726490b9cdf3bd52ff4154f006
* re-formatting files
Change-Id: Ib585edb0b6c7416c48a77168d82989e309dd776b
* bug fix
Change-Id: I1af2eb7609912074b37e38f66c15d0ce75ed291a
Co-authored-by: venkata.chaganti <[email protected]>
---
ballista/rust/client/src/context.rs | 15 +-
datafusion-cli/src/context.rs | 6 +-
datafusion-examples/examples/custom_datasource.rs | 6 +-
datafusion/src/dataframe.rs | 608 ++++++++++++++++++--
datafusion/src/execution/context.rs | 57 +-
datafusion/src/execution/dataframe_impl.rs | 662 ----------------------
datafusion/src/execution/mod.rs | 1 -
datafusion/tests/dataframe_functions.rs | 2 +-
8 files changed, 616 insertions(+), 741 deletions(-)
diff --git a/ballista/rust/client/src/context.rs b/ballista/rust/client/src/context.rs
index f736247..8175a69 100644
--- a/ballista/rust/client/src/context.rs
+++ b/ballista/rust/client/src/context.rs
@@ -32,7 +32,6 @@ use datafusion::catalog::TableReference;
use datafusion::dataframe::DataFrame;
use datafusion::datasource::TableProvider;
use datafusion::error::{DataFusionError, Result};
-use datafusion::execution::dataframe_impl::DataFrameImpl;
use datafusion::logical_plan::{CreateExternalTable, LogicalPlan, TableScan};
use datafusion::prelude::{
AvroReadOptions, CsvReadOptions, ExecutionConfig, ExecutionContext,
@@ -148,7 +147,7 @@ impl BallistaContext {
&self,
path: &str,
options: AvroReadOptions<'_>,
- ) -> Result<Arc<dyn DataFrame>> {
+ ) -> Result<Arc<DataFrame>> {
// convert to absolute path because the executor likely has a different working directory
let path = PathBuf::from(path);
let path = fs::canonicalize(&path)?;
@@ -168,7 +167,7 @@ impl BallistaContext {
/// Create a DataFrame representing a Parquet table scan
/// TODO fetch schema from scheduler instead of resolving locally
- pub async fn read_parquet(&self, path: &str) -> Result<Arc<dyn DataFrame>> {
+ pub async fn read_parquet(&self, path: &str) -> Result<Arc<DataFrame>> {
// convert to absolute path because the executor likely has a different working directory
let path = PathBuf::from(path);
let path = fs::canonicalize(&path)?;
@@ -192,7 +191,7 @@ impl BallistaContext {
&self,
path: &str,
options: CsvReadOptions<'_>,
- ) -> Result<Arc<dyn DataFrame>> {
+ ) -> Result<Arc<DataFrame>> {
// convert to absolute path because the executor likely has a different working directory
let path = PathBuf::from(path);
let path = fs::canonicalize(&path)?;
@@ -291,7 +290,7 @@ impl BallistaContext {
///
/// This method is `async` because queries of type `CREATE EXTERNAL TABLE`
/// might require the schema to be inferred.
- pub async fn sql(&self, sql: &str) -> Result<Arc<dyn DataFrame>> {
+ pub async fn sql(&self, sql: &str) -> Result<Arc<DataFrame>> {
let mut ctx = {
let state = self.state.lock();
create_df_ctx_with_ballista_query_planner::<LogicalPlanNode>(
@@ -342,16 +341,16 @@ impl BallistaContext {
.has_header(*has_header),
)
.await?;
- Ok(Arc::new(DataFrameImpl::new(ctx.state, &plan)))
+ Ok(Arc::new(DataFrame::new(ctx.state, &plan)))
}
FileType::Parquet => {
self.register_parquet(name, location).await?;
- Ok(Arc::new(DataFrameImpl::new(ctx.state, &plan)))
+ Ok(Arc::new(DataFrame::new(ctx.state, &plan)))
}
FileType::Avro => {
self.register_avro(name, location, AvroReadOptions::default())
.await?;
- Ok(Arc::new(DataFrameImpl::new(ctx.state, &plan)))
+ Ok(Arc::new(DataFrame::new(ctx.state, &plan)))
}
_ => Err(DataFusionError::NotImplemented(format!(
"Unsupported file type {:?}.",
diff --git a/datafusion-cli/src/context.rs b/datafusion-cli/src/context.rs
index 0b746af..4f29af9 100644
--- a/datafusion-cli/src/context.rs
+++ b/datafusion-cli/src/context.rs
@@ -42,7 +42,7 @@ impl Context {
}
/// execute an SQL statement against the context
- pub async fn sql(&mut self, sql: &str) -> Result<Arc<dyn DataFrame>> {
+ pub async fn sql(&mut self, sql: &str) -> Result<Arc<DataFrame>> {
match self {
Context::Local(datafusion) => datafusion.sql(sql).await,
Context::Remote(ballista) => ballista.sql(sql).await,
@@ -63,7 +63,7 @@ impl BallistaContext {
.map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?;
Ok(Self(BallistaContext::remote(host, port, &config)))
}
- pub async fn sql(&mut self, sql: &str) -> Result<Arc<dyn DataFrame>> {
+ pub async fn sql(&mut self, sql: &str) -> Result<Arc<DataFrame>> {
self.0.sql(sql).await
}
}
@@ -78,7 +78,7 @@ impl BallistaContext {
.to_string(),
))
}
- pub async fn sql(&mut self, _sql: &str) -> Result<Arc<dyn DataFrame>> {
+ pub async fn sql(&mut self, _sql: &str) -> Result<Arc<DataFrame>> {
unreachable!()
}
}
diff --git a/datafusion-examples/examples/custom_datasource.rs b/datafusion-examples/examples/custom_datasource.rs
index aad153a..b3ef04d 100644
--- a/datafusion-examples/examples/custom_datasource.rs
+++ b/datafusion-examples/examples/custom_datasource.rs
@@ -16,12 +16,12 @@
// under the License.
use async_trait::async_trait;
-use datafusion::arrow::array::{Array, UInt64Builder, UInt8Builder};
+use datafusion::arrow::array::{UInt64Builder, UInt8Builder};
use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use datafusion::arrow::record_batch::RecordBatch;
+use datafusion::dataframe::DataFrame;
use datafusion::datasource::TableProvider;
use datafusion::error::{DataFusionError, Result};
-use datafusion::execution::dataframe_impl::DataFrameImpl;
use datafusion::execution::runtime_env::RuntimeEnv;
use datafusion::logical_plan::{Expr, LogicalPlanBuilder};
use datafusion::physical_plan::expressions::PhysicalSortExpr;
@@ -66,7 +66,7 @@ async fn search_accounts(
.build()
.unwrap();
- let mut dataframe = DataFrameImpl::new(ctx.state, &logical_plan)
+ let mut dataframe = DataFrame::new(ctx.state, &logical_plan)
.select_columns(&["id", "bank_account"])?;
if let Some(f) = filter {
diff --git a/datafusion/src/dataframe.rs b/datafusion/src/dataframe.rs
index 7748a83..7ea4fb5 100644
--- a/datafusion/src/dataframe.rs
+++ b/datafusion/src/dataframe.rs
@@ -20,7 +20,8 @@
use crate::arrow::record_batch::RecordBatch;
use crate::error::Result;
use crate::logical_plan::{
- DFSchema, Expr, FunctionRegistry, JoinType, LogicalPlan, Partitioning,
+ col, DFSchema, Expr, FunctionRegistry, JoinType, LogicalPlan, LogicalPlanBuilder,
+ Partitioning,
};
use parquet::file::properties::WriterProperties;
use std::sync::Arc;
@@ -28,6 +29,20 @@ use std::sync::Arc;
use crate::physical_plan::SendableRecordBatchStream;
use async_trait::async_trait;
+use crate::arrow::datatypes::Schema;
+use crate::arrow::datatypes::SchemaRef;
+use crate::arrow::util::pretty;
+use crate::datasource::TableProvider;
+use crate::datasource::TableType;
+use crate::execution::context::{ExecutionContext, ExecutionContextState};
+use crate::physical_plan::file_format::{plan_to_csv, plan_to_parquet};
+use crate::physical_plan::{collect, collect_partitioned};
+use crate::physical_plan::{execute_stream, execute_stream_partitioned, ExecutionPlan};
+use crate::scalar::ScalarValue;
+use crate::sql::utils::find_window_exprs;
+use parking_lot::Mutex;
+use std::any::Any;
+
/// DataFrame represents a logical set of rows with the same named columns.
/// Similar to a [Pandas DataFrame](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html) or
/// [Spark DataFrame](https://spark.apache.org/docs/latest/sql-programming-guide.html)
@@ -53,8 +68,28 @@ use async_trait::async_trait;
/// # Ok(())
/// # }
/// ```
-#[async_trait]
-pub trait DataFrame: Send + Sync {
+pub struct DataFrame {
+ ctx_state: Arc<Mutex<ExecutionContextState>>,
+ plan: LogicalPlan,
+}
+
+impl DataFrame {
+ /// Create a new Table based on an existing logical plan
+ pub fn new(ctx_state: Arc<Mutex<ExecutionContextState>>, plan: &LogicalPlan) -> Self {
+ Self {
+ ctx_state,
+ plan: plan.clone(),
+ }
+ }
+
+ /// Create a physical plan
+ pub async fn create_physical_plan(&self) -> Result<Arc<dyn ExecutionPlan>> {
+ let state = self.ctx_state.lock().clone();
+ let ctx = ExecutionContext::from(Arc::new(Mutex::new(state)));
+ let plan = ctx.optimize(&self.plan)?;
+ ctx.create_physical_plan(&plan).await
+ }
+
/// Filter the DataFrame by column. Returns a new DataFrame only containing the
/// specified columns.
///
@@ -69,7 +104,14 @@ pub trait DataFrame: Send + Sync {
/// # Ok(())
/// # }
/// ```
- fn select_columns(&self, columns: &[&str]) -> Result<Arc<dyn DataFrame>>;
+ pub fn select_columns(&self, columns: &[&str]) -> Result<Arc<DataFrame>> {
+ let fields = columns
+ .iter()
+ .map(|name| self.plan.schema().field_with_unqualified_name(name))
+ .collect::<Result<Vec<_>>>()?;
+ let expr: Vec<Expr> = fields.iter().map(|f| col(f.name())).collect();
+ self.select(expr)
+ }
/// Create a projection based on arbitrary expressions.
///
@@ -84,7 +126,20 @@ pub trait DataFrame: Send + Sync {
/// # Ok(())
/// # }
/// ```
- fn select(&self, expr: Vec<Expr>) -> Result<Arc<dyn DataFrame>>;
+ pub fn select(&self, expr_list: Vec<Expr>) -> Result<Arc<DataFrame>> {
+ let window_func_exprs = find_window_exprs(&expr_list);
+ let plan = if window_func_exprs.is_empty() {
+ self.to_logical_plan()
+ } else {
+ LogicalPlanBuilder::window_plan(self.to_logical_plan(),
window_func_exprs)?
+ };
+ let project_plan =
LogicalPlanBuilder::from(plan).project(expr_list)?.build()?;
+
+ Ok(Arc::new(DataFrame::new(
+ self.ctx_state.clone(),
+ &project_plan,
+ )))
+ }
/// Filter a DataFrame to only include rows that match the specified
filter expression.
///
@@ -99,7 +154,12 @@ pub trait DataFrame: Send + Sync {
/// # Ok(())
/// # }
/// ```
- fn filter(&self, expr: Expr) -> Result<Arc<dyn DataFrame>>;
+ pub fn filter(&self, predicate: Expr) -> Result<Arc<DataFrame>> {
+ let plan = LogicalPlanBuilder::from(self.to_logical_plan())
+ .filter(predicate)?
+ .build()?;
+ Ok(Arc::new(DataFrame::new(self.ctx_state.clone(), &plan)))
+ }
/// Perform an aggregate query with optional grouping expressions.
///
@@ -119,11 +179,16 @@ pub trait DataFrame: Send + Sync {
/// # Ok(())
/// # }
/// ```
- fn aggregate(
+ pub fn aggregate(
&self,
group_expr: Vec<Expr>,
aggr_expr: Vec<Expr>,
- ) -> Result<Arc<dyn DataFrame>>;
+ ) -> Result<Arc<DataFrame>> {
+ let plan = LogicalPlanBuilder::from(self.to_logical_plan())
+ .aggregate(group_expr, aggr_expr)?
+ .build()?;
+ Ok(Arc::new(DataFrame::new(self.ctx_state.clone(), &plan)))
+ }
/// Limit the number of rows returned from this DataFrame.
///
@@ -138,7 +203,12 @@ pub trait DataFrame: Send + Sync {
/// # Ok(())
/// # }
/// ```
- fn limit(&self, n: usize) -> Result<Arc<dyn DataFrame>>;
+ pub fn limit(&self, n: usize) -> Result<Arc<DataFrame>> {
+ let plan = LogicalPlanBuilder::from(self.to_logical_plan())
+ .limit(n)?
+ .build()?;
+ Ok(Arc::new(DataFrame::new(self.ctx_state.clone(), &plan)))
+ }
/// Calculate the union two [`DataFrame`]s. The two [`DataFrame`]s must
have exactly the same schema
///
@@ -153,7 +223,12 @@ pub trait DataFrame: Send + Sync {
/// # Ok(())
/// # }
/// ```
- fn union(&self, dataframe: Arc<dyn DataFrame>) -> Result<Arc<dyn
DataFrame>>;
+ pub fn union(&self, dataframe: Arc<DataFrame>) -> Result<Arc<DataFrame>> {
+ let plan = LogicalPlanBuilder::from(self.to_logical_plan())
+ .union(dataframe.to_logical_plan())?
+ .build()?;
+ Ok(Arc::new(DataFrame::new(self.ctx_state.clone(), &plan)))
+ }
/// Calculate the union distinct two [`DataFrame`]s. The two
[`DataFrame`]s must have exactly the same schema
///
@@ -169,7 +244,14 @@ pub trait DataFrame: Send + Sync {
/// # Ok(())
/// # }
/// ```
- fn distinct(&self) -> Result<Arc<dyn DataFrame>>;
+ pub fn distinct(&self) -> Result<Arc<DataFrame>> {
+ Ok(Arc::new(DataFrame::new(
+ self.ctx_state.clone(),
+ &LogicalPlanBuilder::from(self.to_logical_plan())
+ .distinct()?
+ .build()?,
+ )))
+ }
/// Sort the DataFrame by the specified sorting expressions. Any
expression can be turned into
/// a sort expression by calling its
[sort](../logical_plan/enum.Expr.html#method.sort) method.
@@ -185,7 +267,12 @@ pub trait DataFrame: Send + Sync {
/// # Ok(())
/// # }
/// ```
- fn sort(&self, expr: Vec<Expr>) -> Result<Arc<dyn DataFrame>>;
+ pub fn sort(&self, expr: Vec<Expr>) -> Result<Arc<DataFrame>> {
+ let plan = LogicalPlanBuilder::from(self.to_logical_plan())
+ .sort(expr)?
+ .build()?;
+ Ok(Arc::new(DataFrame::new(self.ctx_state.clone(), &plan)))
+ }
/// Join this DataFrame with another DataFrame using the specified columns
as join keys
///
@@ -206,13 +293,22 @@ pub trait DataFrame: Send + Sync {
/// # Ok(())
/// # }
/// ```
- fn join(
+ pub fn join(
&self,
- right: Arc<dyn DataFrame>,
+ right: Arc<DataFrame>,
join_type: JoinType,
left_cols: &[&str],
right_cols: &[&str],
- ) -> Result<Arc<dyn DataFrame>>;
+ ) -> Result<Arc<DataFrame>> {
+ let plan = LogicalPlanBuilder::from(self.to_logical_plan())
+ .join(
+ &right.to_logical_plan(),
+ join_type,
+ (left_cols.to_vec(), right_cols.to_vec()),
+ )?
+ .build()?;
+ Ok(Arc::new(DataFrame::new(self.ctx_state.clone(), &plan)))
+ }
// TODO: add join_using
@@ -229,13 +325,19 @@ pub trait DataFrame: Send + Sync {
/// # Ok(())
/// # }
/// ```
- fn repartition(
+ pub fn repartition(
&self,
partitioning_scheme: Partitioning,
- ) -> Result<Arc<dyn DataFrame>>;
+ ) -> Result<Arc<DataFrame>> {
+ let plan = LogicalPlanBuilder::from(self.to_logical_plan())
+ .repartition(partitioning_scheme)?
+ .build()?;
+ Ok(Arc::new(DataFrame::new(self.ctx_state.clone(), &plan)))
+ }
+ /// Convert the logical plan represented by this DataFrame into a physical plan and
+ /// execute it, collecting all resulting batches into memory
/// Executes this DataFrame and collects all results into a vector of RecordBatch.
- ///
/// ```
/// # use datafusion::prelude::*;
/// # use datafusion::error::Result;
@@ -247,7 +349,11 @@ pub trait DataFrame: Send + Sync {
/// # Ok(())
/// # }
/// ```
- async fn collect(&self) -> Result<Vec<RecordBatch>>;
+ pub async fn collect(&self) -> Result<Vec<RecordBatch>> {
+ let plan = self.create_physical_plan().await?;
+ let runtime = self.ctx_state.lock().runtime_env.clone();
+ Ok(collect(plan, runtime).await?)
+ }
/// Print results.
///
@@ -262,7 +368,10 @@ pub trait DataFrame: Send + Sync {
/// # Ok(())
/// # }
/// ```
- async fn show(&self) -> Result<()>;
+ pub async fn show(&self) -> Result<()> {
+ let results = self.collect().await?;
+ Ok(pretty::print_batches(&results)?)
+ }
/// Print results and limit rows.
///
@@ -277,7 +386,10 @@ pub trait DataFrame: Send + Sync {
/// # Ok(())
/// # }
/// ```
- async fn show_limit(&self, n: usize) -> Result<()>;
+ pub async fn show_limit(&self, num: usize) -> Result<()> {
+ let results = self.limit(num)?.collect().await?;
+ Ok(pretty::print_batches(&results)?)
+ }
/// Executes this DataFrame and returns a stream over a single partition
///
@@ -292,7 +404,11 @@ pub trait DataFrame: Send + Sync {
/// # Ok(())
/// # }
/// ```
- async fn execute_stream(&self) -> Result<SendableRecordBatchStream>;
+ pub async fn execute_stream(&self) -> Result<SendableRecordBatchStream> {
+ let plan = self.create_physical_plan().await?;
+ let runtime = self.ctx_state.lock().runtime_env.clone();
+ execute_stream(plan, runtime).await
+ }
/// Executes this DataFrame and collects all results into a vector of
vector of RecordBatch
/// maintaining the input partitioning.
@@ -308,7 +424,11 @@ pub trait DataFrame: Send + Sync {
/// # Ok(())
/// # }
/// ```
- async fn collect_partitioned(&self) -> Result<Vec<Vec<RecordBatch>>>;
+ pub async fn collect_partitioned(&self) -> Result<Vec<Vec<RecordBatch>>> {
+ let plan = self.create_physical_plan().await?;
+ let runtime = self.ctx_state.lock().runtime_env.clone();
+ Ok(collect_partitioned(plan, runtime).await?)
+ }
/// Executes this DataFrame and returns one stream per partition.
///
@@ -323,7 +443,13 @@ pub trait DataFrame: Send + Sync {
/// # Ok(())
/// # }
/// ```
- async fn execute_stream_partitioned(&self) ->
Result<Vec<SendableRecordBatchStream>>;
+ pub async fn execute_stream_partitioned(
+ &self,
+ ) -> Result<Vec<SendableRecordBatchStream>> {
+ let plan = self.create_physical_plan().await?;
+ let runtime = self.ctx_state.lock().runtime_env.clone();
+ Ok(execute_stream_partitioned(plan, runtime).await?)
+ }
/// Returns the schema describing the output of this DataFrame in terms of
columns returned,
/// where each column has a name, data type, and nullability attribute.
@@ -339,10 +465,14 @@ pub trait DataFrame: Send + Sync {
/// # Ok(())
/// # }
/// ```
- fn schema(&self) -> &DFSchema;
+ pub fn schema(&self) -> &DFSchema {
+ self.plan.schema()
+ }
/// Return the logical plan represented by this DataFrame.
- fn to_logical_plan(&self) -> LogicalPlan;
+ pub fn to_logical_plan(&self) -> LogicalPlan {
+ self.plan.clone()
+ }
/// Return a DataFrame with the explanation of its plan so far.
///
@@ -359,7 +489,12 @@ pub trait DataFrame: Send + Sync {
/// # Ok(())
/// # }
/// ```
- fn explain(&self, verbose: bool, analyze: bool) -> Result<Arc<dyn
DataFrame>>;
+ pub fn explain(&self, verbose: bool, analyze: bool) ->
Result<Arc<DataFrame>> {
+ let plan = LogicalPlanBuilder::from(self.to_logical_plan())
+ .explain(verbose, analyze)?
+ .build()?;
+ Ok(Arc::new(DataFrame::new(self.ctx_state.clone(), &plan)))
+ }
/// Return a `FunctionRegistry` used to plan udf's calls
///
@@ -375,7 +510,10 @@ pub trait DataFrame: Send + Sync {
/// # Ok(())
/// # }
/// ```
- fn registry(&self) -> Arc<dyn FunctionRegistry>;
+ pub fn registry(&self) -> Arc<dyn FunctionRegistry> {
+ let registry = self.ctx_state.lock().clone();
+ Arc::new(registry)
+ }
/// Calculate the intersection of two [`DataFrame`]s. The two
[`DataFrame`]s must have exactly the same schema
///
@@ -390,7 +528,14 @@ pub trait DataFrame: Send + Sync {
/// # Ok(())
/// # }
/// ```
- fn intersect(&self, dataframe: Arc<dyn DataFrame>) -> Result<Arc<dyn
DataFrame>>;
+ pub fn intersect(&self, dataframe: Arc<DataFrame>) ->
Result<Arc<DataFrame>> {
+ let left_plan = self.to_logical_plan();
+ let right_plan = dataframe.to_logical_plan();
+ Ok(Arc::new(DataFrame::new(
+ self.ctx_state.clone(),
+ &LogicalPlanBuilder::intersect(left_plan, right_plan, true)?,
+ )))
+ }
/// Calculate the exception of two [`DataFrame`]s. The two [`DataFrame`]s
must have exactly the same schema
///
@@ -405,15 +550,412 @@ pub trait DataFrame: Send + Sync {
/// # Ok(())
/// # }
/// ```
- fn except(&self, dataframe: Arc<dyn DataFrame>) -> Result<Arc<dyn
DataFrame>>;
+ pub fn except(&self, dataframe: Arc<DataFrame>) -> Result<Arc<DataFrame>> {
+ let left_plan = self.to_logical_plan();
+ let right_plan = dataframe.to_logical_plan();
+
+ Ok(Arc::new(DataFrame::new(
+ self.ctx_state.clone(),
+ &LogicalPlanBuilder::except(left_plan, right_plan, true)?,
+ )))
+ }
/// Write a `DataFrame` to a CSV file.
- async fn write_csv(&self, path: &str) -> Result<()>;
+ pub async fn write_csv(&self, path: &str) -> Result<()> {
+ let plan = self.create_physical_plan().await?;
+ let state = self.ctx_state.lock().clone();
+ let ctx = ExecutionContext::from(Arc::new(Mutex::new(state)));
+ plan_to_csv(&ctx, plan, path).await
+ }
/// Write a `DataFrame` to a Parquet file.
- async fn write_parquet(
+ pub async fn write_parquet(
&self,
path: &str,
writer_properties: Option<WriterProperties>,
- ) -> Result<()>;
+ ) -> Result<()> {
+ let plan = self.create_physical_plan().await?;
+ let state = self.ctx_state.lock().clone();
+ let ctx = ExecutionContext::from(Arc::new(Mutex::new(state)));
+ plan_to_parquet(&ctx, plan, path, writer_properties).await
+ }
+}
+
+#[async_trait]
+impl TableProvider for DataFrame {
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ fn schema(&self) -> SchemaRef {
+ let schema: Schema = self.plan.schema().as_ref().into();
+ Arc::new(schema)
+ }
+
+ fn table_type(&self) -> TableType {
+ TableType::View
+ }
+
+ async fn scan(
+ &self,
+ projection: &Option<Vec<usize>>,
+ filters: &[Expr],
+ limit: Option<usize>,
+ ) -> Result<Arc<dyn ExecutionPlan>> {
+ let expr = projection
+ .as_ref()
+ // construct projections
+ .map_or_else(
+ || Ok(Arc::new(Self::new(self.ctx_state.clone(), &self.plan))
as Arc<_>),
+ |projection| {
+ let schema =
TableProvider::schema(self).project(projection)?;
+ let names = schema
+ .fields()
+ .iter()
+ .map(|field| field.name().as_str())
+ .collect::<Vec<_>>();
+ self.select_columns(names.as_slice())
+ },
+ )?
+ // add predicates, otherwise use `true` as the predicate
+ .filter(filters.iter().cloned().fold(
+ Expr::Literal(ScalarValue::Boolean(Some(true))),
+ |acc, new| acc.and(new),
+ ))?;
+ // add a limit if given
+ Self::new(
+ self.ctx_state.clone(),
+ &limit
+ .map_or_else(|| Ok(expr.clone()), |n| expr.limit(n))?
+ .to_logical_plan(),
+ )
+ .create_physical_plan()
+ .await
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use std::vec;
+
+ use super::*;
+ use crate::execution::options::CsvReadOptions;
+ use crate::physical_plan::{window_functions, ColumnarValue};
+ use crate::{assert_batches_sorted_eq,
execution::context::ExecutionContext};
+ use crate::{logical_plan::*, test_util};
+ use arrow::datatypes::DataType;
+ use datafusion_expr::ScalarFunctionImplementation;
+ use datafusion_expr::Volatility;
+
+ #[tokio::test]
+ async fn select_columns() -> Result<()> {
+ // build plan using Table API
+
+ let t = test_table().await?;
+ let t2 = t.select_columns(&["c1", "c2", "c11"])?;
+ let plan = t2.to_logical_plan();
+
+ // build query using SQL
+ let sql_plan = create_plan("SELECT c1, c2, c11 FROM
aggregate_test_100").await?;
+
+ // the two plans should be identical
+ assert_same_plan(&plan, &sql_plan);
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn select_expr() -> Result<()> {
+ // build plan using Table API
+ let t = test_table().await?;
+ let t2 = t.select(vec![col("c1"), col("c2"), col("c11")])?;
+ let plan = t2.to_logical_plan();
+
+ // build query using SQL
+ let sql_plan = create_plan("SELECT c1, c2, c11 FROM
aggregate_test_100").await?;
+
+ // the two plans should be identical
+ assert_same_plan(&plan, &sql_plan);
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn select_with_window_exprs() -> Result<()> {
+ // build plan using Table API
+ let t = test_table().await?;
+ let first_row = Expr::WindowFunction {
+ fun: window_functions::WindowFunction::BuiltInWindowFunction(
+ window_functions::BuiltInWindowFunction::FirstValue,
+ ),
+ args: vec![col("aggregate_test_100.c1")],
+ partition_by: vec![col("aggregate_test_100.c2")],
+ order_by: vec![],
+ window_frame: None,
+ };
+ let t2 = t.select(vec![col("c1"), first_row])?;
+ let plan = t2.to_logical_plan();
+
+ let sql_plan = create_plan(
+ "select c1, first_value(c1) over (partition by c2) from
aggregate_test_100",
+ )
+ .await?;
+
+ assert_same_plan(&plan, &sql_plan);
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn aggregate() -> Result<()> {
+ // build plan using DataFrame API
+ let df = test_table().await?;
+ let group_expr = vec![col("c1")];
+ let aggr_expr = vec![
+ min(col("c12")),
+ max(col("c12")),
+ avg(col("c12")),
+ sum(col("c12")),
+ count(col("c12")),
+ count_distinct(col("c12")),
+ ];
+
+ let df: Vec<RecordBatch> = df.aggregate(group_expr,
aggr_expr)?.collect().await?;
+
+ assert_batches_sorted_eq!(
+ vec![
+
"+----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+",
+ "| c1 | MIN(aggregate_test_100.c12) |
MAX(aggregate_test_100.c12) | AVG(aggregate_test_100.c12) |
SUM(aggregate_test_100.c12) | COUNT(aggregate_test_100.c12) | COUNT(DISTINCT
aggregate_test_100.c12) |",
+
"+----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+",
+ "| a | 0.02182578039211991 | 0.9800193410444061
| 0.48754517466109415 | 10.238448667882977 | 21
| 21 |",
+ "| b | 0.04893135681998029 | 0.9185813970744787
| 0.41040709263815384 | 7.797734760124923 | 19
| 19 |",
+ "| c | 0.0494924465469434 | 0.991517828651004
| 0.6600456536439784 | 13.860958726523545 | 21
| 21 |",
+ "| d | 0.061029375346466685 | 0.9748360509016578
| 0.48855379387549824 | 8.793968289758968 | 18
| 18 |",
+ "| e | 0.01479305307777301 | 0.9965400387585364
| 0.48600669271341534 | 10.206140546981722 | 21
| 21 |",
+
"+----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+",
+ ],
+ &df
+ );
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn join() -> Result<()> {
+ let left = test_table().await?.select_columns(&["c1", "c2"])?;
+ let right = test_table_with_name("c2")
+ .await?
+ .select_columns(&["c1", "c3"])?;
+ let left_rows = left.collect().await?;
+ let right_rows = right.collect().await?;
+ let join = left.join(right, JoinType::Inner, &["c1"], &["c1"])?;
+ let join_rows = join.collect().await?;
+ assert_eq!(100, left_rows.iter().map(|x| x.num_rows()).sum::<usize>());
+ assert_eq!(100, right_rows.iter().map(|x|
x.num_rows()).sum::<usize>());
+ assert_eq!(2008, join_rows.iter().map(|x|
x.num_rows()).sum::<usize>());
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn limit() -> Result<()> {
+ // build query using Table API
+ let t = test_table().await?;
+ let t2 = t.select_columns(&["c1", "c2", "c11"])?.limit(10)?;
+ let plan = t2.to_logical_plan();
+
+ // build query using SQL
+ let sql_plan =
+ create_plan("SELECT c1, c2, c11 FROM aggregate_test_100 LIMIT
10").await?;
+
+ // the two plans should be identical
+ assert_same_plan(&plan, &sql_plan);
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn explain() -> Result<()> {
+ // build query using Table API
+ let df = test_table().await?;
+ let df = df
+ .select_columns(&["c1", "c2", "c11"])?
+ .limit(10)?
+ .explain(false, false)?;
+ let plan = df.to_logical_plan();
+
+ // build query using SQL
+ let sql_plan =
+ create_plan("EXPLAIN SELECT c1, c2, c11 FROM aggregate_test_100
LIMIT 10")
+ .await?;
+
+ // the two plans should be identical
+ assert_same_plan(&plan, &sql_plan);
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn registry() -> Result<()> {
+ let mut ctx = ExecutionContext::new();
+ register_aggregate_csv(&mut ctx, "aggregate_test_100").await?;
+
+ // declare the udf
+ let my_fn: ScalarFunctionImplementation =
+ Arc::new(|_: &[ColumnarValue]| unimplemented!("my_fn is not
implemented"));
+
+ // create and register the udf
+ ctx.register_udf(create_udf(
+ "my_fn",
+ vec![DataType::Float64],
+ Arc::new(DataType::Float64),
+ Volatility::Immutable,
+ my_fn,
+ ));
+
+ // build query with a UDF using DataFrame API
+ let df = ctx.table("aggregate_test_100")?;
+
+ let f = df.registry();
+
+ let df = df.select(vec![f.udf("my_fn")?.call(vec![col("c12")])])?;
+ let plan = df.to_logical_plan();
+
+ // build query using SQL
+ let sql_plan =
+ ctx.create_logical_plan("SELECT my_fn(c12) FROM
aggregate_test_100")?;
+
+ // the two plans should be identical
+ assert_same_plan(&plan, &sql_plan);
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn sendable() {
+ let df = test_table().await.unwrap();
+ // dataframes should be sendable between threads/tasks
+ let task = tokio::task::spawn(async move {
+ df.select_columns(&["c1"])
+ .expect("should be usable in a task")
+ });
+ task.await.expect("task completed successfully");
+ }
+
+ #[tokio::test]
+ async fn intersect() -> Result<()> {
+ let df = test_table().await?.select_columns(&["c1", "c3"])?;
+ let plan = df.intersect(df.clone())?;
+ let result = plan.to_logical_plan();
+ let expected = create_plan(
+ "SELECT c1, c3 FROM aggregate_test_100
+ INTERSECT ALL SELECT c1, c3 FROM aggregate_test_100",
+ )
+ .await?;
+ assert_same_plan(&result, &expected);
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn except() -> Result<()> {
+ let df = test_table().await?.select_columns(&["c1", "c3"])?;
+ let plan = df.except(df.clone())?;
+ let result = plan.to_logical_plan();
+ let expected = create_plan(
+ "SELECT c1, c3 FROM aggregate_test_100
+ EXCEPT ALL SELECT c1, c3 FROM aggregate_test_100",
+ )
+ .await?;
+ assert_same_plan(&result, &expected);
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn register_table() -> Result<()> {
+ let df = test_table().await?.select_columns(&["c1", "c12"])?;
+ let mut ctx = ExecutionContext::new();
+ let df_impl = Arc::new(DataFrame::new(ctx.state.clone(),
&df.to_logical_plan()));
+
+ // register a dataframe as a table
+ ctx.register_table("test_table", df_impl.clone())?;
+
+ // pull the table out
+ let table = ctx.table("test_table")?;
+
+ let group_expr = vec![col("c1")];
+ let aggr_expr = vec![sum(col("c12"))];
+
+ // check that we correctly read from the table
+ let df_results = &df_impl
+ .aggregate(group_expr.clone(), aggr_expr.clone())?
+ .collect()
+ .await?;
+ let table_results = &table.aggregate(group_expr,
aggr_expr)?.collect().await?;
+
+ assert_batches_sorted_eq!(
+ vec![
+ "+----+-----------------------------+",
+ "| c1 | SUM(aggregate_test_100.c12) |",
+ "+----+-----------------------------+",
+ "| a | 10.238448667882977 |",
+ "| b | 7.797734760124923 |",
+ "| c | 13.860958726523545 |",
+ "| d | 8.793968289758968 |",
+ "| e | 10.206140546981722 |",
+ "+----+-----------------------------+",
+ ],
+ df_results
+ );
+
+ // the results are the same as the results from the view, modulo the
leaf table name
+ assert_batches_sorted_eq!(
+ vec![
+ "+----+---------------------+",
+ "| c1 | SUM(test_table.c12) |",
+ "+----+---------------------+",
+ "| a | 10.238448667882977 |",
+ "| b | 7.797734760124923 |",
+ "| c | 13.860958726523545 |",
+ "| d | 8.793968289758968 |",
+ "| e | 10.206140546981722 |",
+ "+----+---------------------+",
+ ],
+ table_results
+ );
+ Ok(())
+ }
+ /// Compare the formatted string representation of two plans for equality
+ fn assert_same_plan(plan1: &LogicalPlan, plan2: &LogicalPlan) {
+ assert_eq!(format!("{:?}", plan1), format!("{:?}", plan2));
+ }
+
+ /// Create a logical plan from a SQL query
+ async fn create_plan(sql: &str) -> Result<LogicalPlan> {
+ let mut ctx = ExecutionContext::new();
+ register_aggregate_csv(&mut ctx, "aggregate_test_100").await?;
+ ctx.create_logical_plan(sql)
+ }
+
+ async fn test_table_with_name(name: &str) -> Result<Arc<DataFrame>> {
+ let mut ctx = ExecutionContext::new();
+ register_aggregate_csv(&mut ctx, name).await?;
+ ctx.table(name)
+ }
+
+ async fn test_table() -> Result<Arc<DataFrame>> {
+ test_table_with_name("aggregate_test_100").await
+ }
+
+ async fn register_aggregate_csv(
+ ctx: &mut ExecutionContext,
+ table_name: &str,
+ ) -> Result<()> {
+ let schema = test_util::aggr_test_schema();
+ let testdata = crate::test_util::arrow_test_data();
+ ctx.register_csv(
+ table_name,
+ &format!("{}/csv/aggregate_test_100.csv", testdata),
+ CsvReadOptions::new().schema(schema.as_ref()),
+ )
+ .await?;
+ Ok(())
+ }
}
diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs
index 5ea9b0b..49644c1 100644
--- a/datafusion/src/execution/context.rs
+++ b/datafusion/src/execution/context.rs
@@ -53,11 +53,11 @@ use crate::catalog::{
schema::{MemorySchemaProvider, SchemaProvider},
ResolvedTableReference, TableReference,
};
+use crate::dataframe::DataFrame;
use crate::datasource::listing::ListingTableConfig;
use crate::datasource::object_store::{ObjectStore, ObjectStoreRegistry};
use crate::datasource::TableProvider;
use crate::error::{DataFusionError, Result};
-use crate::execution::dataframe_impl::DataFrameImpl;
use crate::logical_plan::{
CreateExternalTable, CreateMemoryTable, DropTable, FunctionRegistry,
LogicalPlan,
LogicalPlanBuilder, UNNAMED_TABLE,
@@ -79,6 +79,7 @@ use crate::execution::runtime_env::{RuntimeConfig,
RuntimeEnv};
use crate::logical_plan::plan::Explain;
use crate::physical_plan::file_format::{plan_to_csv, plan_to_parquet};
use crate::physical_plan::planner::DefaultPhysicalPlanner;
+use crate::physical_plan::udaf::AggregateUDF;
use crate::physical_plan::udf::ScalarUDF;
use crate::physical_plan::ExecutionPlan;
use crate::physical_plan::PhysicalPlanner;
@@ -87,7 +88,6 @@ use crate::sql::{
planner::{ContextProvider, SqlToRel},
};
use crate::variable::{VarProvider, VarType};
-use crate::{dataframe::DataFrame, physical_plan::udaf::AggregateUDF};
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use parquet::file::properties::WriterProperties;
@@ -206,7 +206,7 @@ impl ExecutionContext {
///
/// This method is `async` because queries of type `CREATE EXTERNAL TABLE`
/// might require the schema to be inferred.
- pub async fn sql(&mut self, sql: &str) -> Result<Arc<dyn DataFrame>> {
+ pub async fn sql(&mut self, sql: &str) -> Result<Arc<DataFrame>> {
let plan = self.create_logical_plan(sql)?;
match plan {
LogicalPlan::CreateExternalTable(CreateExternalTable {
@@ -254,12 +254,12 @@ impl ExecutionContext {
self.register_listing_table(name, location, options,
provided_schema)
.await?;
let plan = LogicalPlanBuilder::empty(false).build()?;
- Ok(Arc::new(DataFrameImpl::new(self.state.clone(), &plan)))
+ Ok(Arc::new(DataFrame::new(self.state.clone(), &plan)))
}
LogicalPlan::CreateMemoryTable(CreateMemoryTable { name, input })
=> {
let plan = self.optimize(&input)?;
- let physical = Arc::new(DataFrameImpl::new(self.state.clone(),
&plan));
+ let physical = Arc::new(DataFrame::new(self.state.clone(),
&plan));
let batches: Vec<_> = physical.collect_partitioned().await?;
let table = Arc::new(MemTable::try_new(
@@ -269,7 +269,7 @@ impl ExecutionContext {
self.register_table(name.as_str(), table)?;
let plan = LogicalPlanBuilder::empty(false).build()?;
- Ok(Arc::new(DataFrameImpl::new(self.state.clone(), &plan)))
+ Ok(Arc::new(DataFrame::new(self.state.clone(), &plan)))
}
LogicalPlan::DropTable(DropTable {
@@ -283,11 +283,11 @@ impl ExecutionContext {
)))
} else {
let plan = LogicalPlanBuilder::empty(false).build()?;
- Ok(Arc::new(DataFrameImpl::new(self.state.clone(), &plan)))
+ Ok(Arc::new(DataFrame::new(self.state.clone(), &plan)))
}
}
- plan => Ok(Arc::new(DataFrameImpl::new(
+ plan => Ok(Arc::new(DataFrame::new(
self.state.clone(),
&self.optimize(&plan)?,
))),
@@ -358,11 +358,11 @@ impl ExecutionContext {
&mut self,
uri: impl Into<String>,
options: AvroReadOptions<'_>,
- ) -> Result<Arc<dyn DataFrame>> {
+ ) -> Result<Arc<DataFrame>> {
let uri: String = uri.into();
let (object_store, path) = self.object_store(&uri)?;
let target_partitions = self.state.lock().config.target_partitions;
- Ok(Arc::new(DataFrameImpl::new(
+ Ok(Arc::new(DataFrame::new(
self.state.clone(),
&LogicalPlanBuilder::scan_avro(
object_store,
@@ -377,8 +377,8 @@ impl ExecutionContext {
}
/// Creates an empty DataFrame.
- pub fn read_empty(&self) -> Result<Arc<dyn DataFrame>> {
- Ok(Arc::new(DataFrameImpl::new(
+ pub fn read_empty(&self) -> Result<Arc<DataFrame>> {
+ Ok(Arc::new(DataFrame::new(
self.state.clone(),
&LogicalPlanBuilder::empty(true).build()?,
)))
@@ -389,11 +389,11 @@ impl ExecutionContext {
&mut self,
uri: impl Into<String>,
options: CsvReadOptions<'_>,
- ) -> Result<Arc<dyn DataFrame>> {
+ ) -> Result<Arc<DataFrame>> {
let uri: String = uri.into();
let (object_store, path) = self.object_store(&uri)?;
let target_partitions = self.state.lock().config.target_partitions;
- Ok(Arc::new(DataFrameImpl::new(
+ Ok(Arc::new(DataFrame::new(
self.state.clone(),
&LogicalPlanBuilder::scan_csv(
object_store,
@@ -411,7 +411,7 @@ impl ExecutionContext {
pub async fn read_parquet(
&mut self,
uri: impl Into<String>,
- ) -> Result<Arc<dyn DataFrame>> {
+ ) -> Result<Arc<DataFrame>> {
let uri: String = uri.into();
let (object_store, path) = self.object_store(&uri)?;
let target_partitions = self.state.lock().config.target_partitions;
@@ -419,18 +419,15 @@ impl ExecutionContext {
LogicalPlanBuilder::scan_parquet(object_store, path, None,
target_partitions)
.await?
.build()?;
- Ok(Arc::new(DataFrameImpl::new(
- self.state.clone(),
- &logical_plan,
- )))
+ Ok(Arc::new(DataFrame::new(self.state.clone(), &logical_plan)))
}
/// Creates a DataFrame for reading a custom TableProvider.
pub fn read_table(
&mut self,
provider: Arc<dyn TableProvider>,
- ) -> Result<Arc<dyn DataFrame>> {
- Ok(Arc::new(DataFrameImpl::new(
+ ) -> Result<Arc<DataFrame>> {
+ Ok(Arc::new(DataFrame::new(
self.state.clone(),
&LogicalPlanBuilder::scan(UNNAMED_TABLE, provider, None)?.build()?,
)))
@@ -622,7 +619,7 @@ impl ExecutionContext {
pub fn table<'a>(
&self,
table_ref: impl Into<TableReference<'a>>,
- ) -> Result<Arc<dyn DataFrame>> {
+ ) -> Result<Arc<DataFrame>> {
let table_ref = table_ref.into();
let schema = self.state.lock().schema_for_ref(table_ref)?;
match schema.table(table_ref.table()) {
@@ -633,7 +630,7 @@ impl ExecutionContext {
None,
)?
.build()?;
- Ok(Arc::new(DataFrameImpl::new(self.state.clone(), &plan)))
+ Ok(Arc::new(DataFrame::new(self.state.clone(), &plan)))
}
_ => Err(DataFusionError::Plan(format!(
"No table named '{}'",
@@ -2626,7 +2623,7 @@ mod tests {
ctx.register_table("t", test::table_with_sequence(1, 1).unwrap())
.unwrap();
- // Note capitalizaton
+ // Note capitalization
let my_avg = create_udaf(
"MY_AVG",
DataType::Float64,
@@ -3188,28 +3185,28 @@ mod tests {
// See https://github.com/apache/arrow-datafusion/issues/1154
#[async_trait]
trait CallReadTrait {
- async fn call_read_csv(&self) -> Arc<dyn DataFrame>;
- async fn call_read_avro(&self) -> Arc<dyn DataFrame>;
- async fn call_read_parquet(&self) -> Arc<dyn DataFrame>;
+ async fn call_read_csv(&self) -> Arc<DataFrame>;
+ async fn call_read_avro(&self) -> Arc<DataFrame>;
+ async fn call_read_parquet(&self) -> Arc<DataFrame>;
}
struct CallRead {}
#[async_trait]
impl CallReadTrait for CallRead {
- async fn call_read_csv(&self) -> Arc<dyn DataFrame> {
+ async fn call_read_csv(&self) -> Arc<DataFrame> {
let mut ctx = ExecutionContext::new();
ctx.read_csv("dummy", CsvReadOptions::new()).await.unwrap()
}
- async fn call_read_avro(&self) -> Arc<dyn DataFrame> {
+ async fn call_read_avro(&self) -> Arc<DataFrame> {
let mut ctx = ExecutionContext::new();
ctx.read_avro("dummy", AvroReadOptions::default())
.await
.unwrap()
}
- async fn call_read_parquet(&self) -> Arc<dyn DataFrame> {
+ async fn call_read_parquet(&self) -> Arc<DataFrame> {
let mut ctx = ExecutionContext::new();
ctx.read_parquet("dummy").await.unwrap()
}
diff --git a/datafusion/src/execution/dataframe_impl.rs
b/datafusion/src/execution/dataframe_impl.rs
deleted file mode 100644
index 2af1cd4..0000000
--- a/datafusion/src/execution/dataframe_impl.rs
+++ /dev/null
@@ -1,662 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-//! Implementation of DataFrame API.
-
-use parking_lot::Mutex;
-use std::any::Any;
-use std::sync::Arc;
-
-use crate::arrow::datatypes::Schema;
-use crate::arrow::datatypes::SchemaRef;
-use crate::arrow::record_batch::RecordBatch;
-use crate::error::Result;
-use crate::execution::context::{ExecutionContext, ExecutionContextState};
-use crate::logical_plan::{
- col, DFSchema, Expr, FunctionRegistry, JoinType, LogicalPlan,
LogicalPlanBuilder,
- Partitioning,
-};
-use crate::scalar::ScalarValue;
-use crate::{
- dataframe::*,
- physical_plan::{collect, collect_partitioned},
-};
-use parquet::file::properties::WriterProperties;
-
-use crate::arrow::util::pretty;
-use crate::datasource::TableProvider;
-use crate::datasource::TableType;
-use crate::physical_plan::file_format::{plan_to_csv, plan_to_parquet};
-use crate::physical_plan::{
- execute_stream, execute_stream_partitioned, ExecutionPlan,
SendableRecordBatchStream,
-};
-use crate::sql::utils::find_window_exprs;
-use async_trait::async_trait;
-
-/// Implementation of DataFrame API
-pub struct DataFrameImpl {
- ctx_state: Arc<Mutex<ExecutionContextState>>,
- plan: LogicalPlan,
-}
-
-impl DataFrameImpl {
- /// Create a new Table based on an existing logical plan
- pub fn new(ctx_state: Arc<Mutex<ExecutionContextState>>, plan:
&LogicalPlan) -> Self {
- Self {
- ctx_state,
- plan: plan.clone(),
- }
- }
-
- /// Create a physical plan
- async fn create_physical_plan(&self) -> Result<Arc<dyn ExecutionPlan>> {
- let state = self.ctx_state.lock().clone();
- let ctx = ExecutionContext::from(Arc::new(Mutex::new(state)));
- let plan = ctx.optimize(&self.plan)?;
- ctx.create_physical_plan(&plan).await
- }
-}
-
-#[async_trait]
-impl TableProvider for DataFrameImpl {
- fn as_any(&self) -> &dyn Any {
- self
- }
-
- fn schema(&self) -> SchemaRef {
- let schema: Schema = self.plan.schema().as_ref().into();
- Arc::new(schema)
- }
-
- fn table_type(&self) -> TableType {
- TableType::View
- }
-
- async fn scan(
- &self,
- projection: &Option<Vec<usize>>,
- filters: &[Expr],
- limit: Option<usize>,
- ) -> Result<Arc<dyn ExecutionPlan>> {
- let expr = projection
- .as_ref()
- // construct projections
- .map_or_else(
- || Ok(Arc::new(Self::new(self.ctx_state.clone(), &self.plan))
as Arc<_>),
- |projection| {
- let schema =
TableProvider::schema(self).project(projection)?;
- let names = schema
- .fields()
- .iter()
- .map(|field| field.name().as_str())
- .collect::<Vec<_>>();
- self.select_columns(names.as_slice())
- },
- )?
- // add predicates, otherwise use `true` as the predicate
- .filter(filters.iter().cloned().fold(
- Expr::Literal(ScalarValue::Boolean(Some(true))),
- |acc, new| acc.and(new),
- ))?;
- // add a limit if given
- Self::new(
- self.ctx_state.clone(),
- &limit
- .map_or_else(|| Ok(expr.clone()), |n| expr.limit(n))?
- .to_logical_plan(),
- )
- .create_physical_plan()
- .await
- }
-}
-
-#[async_trait]
-impl DataFrame for DataFrameImpl {
- /// Apply a projection based on a list of column names
- fn select_columns(&self, columns: &[&str]) -> Result<Arc<dyn DataFrame>> {
- let fields = columns
- .iter()
- .map(|name| self.plan.schema().field_with_unqualified_name(name))
- .collect::<Result<Vec<_>>>()?;
- let expr: Vec<Expr> = fields.iter().map(|f| col(f.name())).collect();
- self.select(expr)
- }
-
- /// Create a projection based on arbitrary expressions
- fn select(&self, expr_list: Vec<Expr>) -> Result<Arc<dyn DataFrame>> {
- let window_func_exprs = find_window_exprs(&expr_list);
- let plan = if window_func_exprs.is_empty() {
- self.to_logical_plan()
- } else {
- LogicalPlanBuilder::window_plan(self.to_logical_plan(),
window_func_exprs)?
- };
- let project_plan =
LogicalPlanBuilder::from(plan).project(expr_list)?.build()?;
- Ok(Arc::new(DataFrameImpl::new(
- self.ctx_state.clone(),
- &project_plan,
- )))
- }
-
- /// Create a filter based on a predicate expression
- fn filter(&self, predicate: Expr) -> Result<Arc<dyn DataFrame>> {
- let plan = LogicalPlanBuilder::from(self.to_logical_plan())
- .filter(predicate)?
- .build()?;
- Ok(Arc::new(DataFrameImpl::new(self.ctx_state.clone(), &plan)))
- }
-
- /// Perform an aggregate query
- fn aggregate(
- &self,
- group_expr: Vec<Expr>,
- aggr_expr: Vec<Expr>,
- ) -> Result<Arc<dyn DataFrame>> {
- let plan = LogicalPlanBuilder::from(self.to_logical_plan())
- .aggregate(group_expr, aggr_expr)?
- .build()?;
- Ok(Arc::new(DataFrameImpl::new(self.ctx_state.clone(), &plan)))
- }
-
- /// Limit the number of rows
- fn limit(&self, n: usize) -> Result<Arc<dyn DataFrame>> {
- let plan = LogicalPlanBuilder::from(self.to_logical_plan())
- .limit(n)?
- .build()?;
- Ok(Arc::new(DataFrameImpl::new(self.ctx_state.clone(), &plan)))
- }
-
- /// Sort by specified sorting expressions
- fn sort(&self, expr: Vec<Expr>) -> Result<Arc<dyn DataFrame>> {
- let plan = LogicalPlanBuilder::from(self.to_logical_plan())
- .sort(expr)?
- .build()?;
- Ok(Arc::new(DataFrameImpl::new(self.ctx_state.clone(), &plan)))
- }
-
- /// Join with another DataFrame
- fn join(
- &self,
- right: Arc<dyn DataFrame>,
- join_type: JoinType,
- left_cols: &[&str],
- right_cols: &[&str],
- ) -> Result<Arc<dyn DataFrame>> {
- let plan = LogicalPlanBuilder::from(self.to_logical_plan())
- .join(
- &right.to_logical_plan(),
- join_type,
- (left_cols.to_vec(), right_cols.to_vec()),
- )?
- .build()?;
- Ok(Arc::new(DataFrameImpl::new(self.ctx_state.clone(), &plan)))
- }
-
- fn repartition(
- &self,
- partitioning_scheme: Partitioning,
- ) -> Result<Arc<dyn DataFrame>> {
- let plan = LogicalPlanBuilder::from(self.to_logical_plan())
- .repartition(partitioning_scheme)?
- .build()?;
- Ok(Arc::new(DataFrameImpl::new(self.ctx_state.clone(), &plan)))
- }
-
- /// Convert to logical plan
- fn to_logical_plan(&self) -> LogicalPlan {
- self.plan.clone()
- }
-
- /// Convert the logical plan represented by this DataFrame into a physical
plan and
- /// execute it, collecting all resulting batches into memory
- async fn collect(&self) -> Result<Vec<RecordBatch>> {
- let plan = self.create_physical_plan().await?;
- let runtime = self.ctx_state.lock().runtime_env.clone();
- Ok(collect(plan, runtime).await?)
- }
-
- /// Print results.
- async fn show(&self) -> Result<()> {
- let results = self.collect().await?;
- Ok(pretty::print_batches(&results)?)
- }
-
- /// Print results and limit rows.
- async fn show_limit(&self, num: usize) -> Result<()> {
- let results = self.limit(num)?.collect().await?;
- Ok(pretty::print_batches(&results)?)
- }
-
- /// Convert the logical plan represented by this DataFrame into a physical
plan and
- /// execute it, returning a stream over a single partition
- async fn execute_stream(&self) -> Result<SendableRecordBatchStream> {
- let plan = self.create_physical_plan().await?;
- let runtime = self.ctx_state.lock().runtime_env.clone();
- execute_stream(plan, runtime).await
- }
-
- /// Convert the logical plan represented by this DataFrame into a physical
plan and
- /// execute it, collecting all resulting batches into memory while
maintaining
- /// partitioning
- async fn collect_partitioned(&self) -> Result<Vec<Vec<RecordBatch>>> {
- let plan = self.create_physical_plan().await?;
- let runtime = self.ctx_state.lock().runtime_env.clone();
- Ok(collect_partitioned(plan, runtime).await?)
- }
-
- /// Convert the logical plan represented by this DataFrame into a physical
plan and
- /// execute it, returning a stream for each partition
- async fn execute_stream_partitioned(&self) ->
Result<Vec<SendableRecordBatchStream>> {
- let plan = self.create_physical_plan().await?;
- let runtime = self.ctx_state.lock().runtime_env.clone();
- Ok(execute_stream_partitioned(plan, runtime).await?)
- }
-
- /// Returns the schema from the logical plan
- fn schema(&self) -> &DFSchema {
- self.plan.schema()
- }
-
- fn explain(&self, verbose: bool, analyze: bool) -> Result<Arc<dyn
DataFrame>> {
- let plan = LogicalPlanBuilder::from(self.to_logical_plan())
- .explain(verbose, analyze)?
- .build()?;
- Ok(Arc::new(DataFrameImpl::new(self.ctx_state.clone(), &plan)))
- }
-
- fn registry(&self) -> Arc<dyn FunctionRegistry> {
- let registry = self.ctx_state.lock().clone();
- Arc::new(registry)
- }
-
- fn union(&self, dataframe: Arc<dyn DataFrame>) -> Result<Arc<dyn
DataFrame>> {
- let plan = LogicalPlanBuilder::from(self.to_logical_plan())
- .union(dataframe.to_logical_plan())?
- .build()?;
- Ok(Arc::new(DataFrameImpl::new(self.ctx_state.clone(), &plan)))
- }
-
- fn distinct(&self) -> Result<Arc<dyn DataFrame>> {
- Ok(Arc::new(DataFrameImpl::new(
- self.ctx_state.clone(),
- &LogicalPlanBuilder::from(self.to_logical_plan())
- .distinct()?
- .build()?,
- )))
- }
-
- fn intersect(&self, dataframe: Arc<dyn DataFrame>) -> Result<Arc<dyn
DataFrame>> {
- let left_plan = self.to_logical_plan();
- let right_plan = dataframe.to_logical_plan();
- Ok(Arc::new(DataFrameImpl::new(
- self.ctx_state.clone(),
- &LogicalPlanBuilder::intersect(left_plan, right_plan, true)?,
- )))
- }
-
- fn except(&self, dataframe: Arc<dyn DataFrame>) -> Result<Arc<dyn
DataFrame>> {
- let left_plan = self.to_logical_plan();
- let right_plan = dataframe.to_logical_plan();
- Ok(Arc::new(DataFrameImpl::new(
- self.ctx_state.clone(),
- &LogicalPlanBuilder::except(left_plan, right_plan, true)?,
- )))
- }
-
- async fn write_csv(&self, path: &str) -> Result<()> {
- let plan = self.create_physical_plan().await?;
- let state = self.ctx_state.lock().clone();
- let ctx = ExecutionContext::from(Arc::new(Mutex::new(state)));
- plan_to_csv(&ctx, plan, path).await
- }
-
- async fn write_parquet(
- &self,
- path: &str,
- writer_properties: Option<WriterProperties>,
- ) -> Result<()> {
- let plan = self.create_physical_plan().await?;
- let state = self.ctx_state.lock().clone();
- let ctx = ExecutionContext::from(Arc::new(Mutex::new(state)));
- plan_to_parquet(&ctx, plan, path, writer_properties).await
- }
-}
-
-#[cfg(test)]
-mod tests {
- use std::vec;
-
- use super::*;
- use crate::execution::options::CsvReadOptions;
- use crate::physical_plan::{window_functions, ColumnarValue};
- use crate::{assert_batches_sorted_eq,
execution::context::ExecutionContext};
- use crate::{logical_plan::*, test_util};
- use arrow::datatypes::DataType;
- use datafusion_expr::ScalarFunctionImplementation;
- use datafusion_expr::Volatility;
-
- #[tokio::test]
- async fn select_columns() -> Result<()> {
- // build plan using Table API
- let t = test_table().await?;
- let t2 = t.select_columns(&["c1", "c2", "c11"])?;
- let plan = t2.to_logical_plan();
-
- // build query using SQL
- let sql_plan = create_plan("SELECT c1, c2, c11 FROM
aggregate_test_100").await?;
-
- // the two plans should be identical
- assert_same_plan(&plan, &sql_plan);
-
- Ok(())
- }
-
- #[tokio::test]
- async fn select_expr() -> Result<()> {
- // build plan using Table API
- let t = test_table().await?;
- let t2 = t.select(vec![col("c1"), col("c2"), col("c11")])?;
- let plan = t2.to_logical_plan();
-
- // build query using SQL
- let sql_plan = create_plan("SELECT c1, c2, c11 FROM
aggregate_test_100").await?;
-
- // the two plans should be identical
- assert_same_plan(&plan, &sql_plan);
-
- Ok(())
- }
-
- #[tokio::test]
- async fn select_with_window_exprs() -> Result<()> {
- // build plan using Table API
- let t = test_table().await?;
- let first_row = Expr::WindowFunction {
- fun: window_functions::WindowFunction::BuiltInWindowFunction(
- window_functions::BuiltInWindowFunction::FirstValue,
- ),
- args: vec![col("aggregate_test_100.c1")],
- partition_by: vec![col("aggregate_test_100.c2")],
- order_by: vec![],
- window_frame: None,
- };
- let t2 = t.select(vec![col("c1"), first_row])?;
- let plan = t2.to_logical_plan();
-
- let sql_plan = create_plan(
- "select c1, first_value(c1) over (partition by c2) from
aggregate_test_100",
- )
- .await?;
-
- assert_same_plan(&plan, &sql_plan);
- Ok(())
- }
-
- #[tokio::test]
- async fn aggregate() -> Result<()> {
- // build plan using DataFrame API
- let df = test_table().await?;
- let group_expr = vec![col("c1")];
- let aggr_expr = vec![
- min(col("c12")),
- max(col("c12")),
- avg(col("c12")),
- sum(col("c12")),
- count(col("c12")),
- count_distinct(col("c12")),
- ];
-
- let df: Vec<RecordBatch> = df.aggregate(group_expr,
aggr_expr)?.collect().await?;
-
- assert_batches_sorted_eq!(
- vec![
-
"+----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+",
- "| c1 | MIN(aggregate_test_100.c12) |
MAX(aggregate_test_100.c12) | AVG(aggregate_test_100.c12) |
SUM(aggregate_test_100.c12) | COUNT(aggregate_test_100.c12) | COUNT(DISTINCT
aggregate_test_100.c12) |",
-
"+----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+",
- "| a | 0.02182578039211991 | 0.9800193410444061
| 0.48754517466109415 | 10.238448667882977 | 21
| 21 |",
- "| b | 0.04893135681998029 | 0.9185813970744787
| 0.41040709263815384 | 7.797734760124923 | 19
| 19 |",
- "| c | 0.0494924465469434 | 0.991517828651004
| 0.6600456536439784 | 13.860958726523545 | 21
| 21 |",
- "| d | 0.061029375346466685 | 0.9748360509016578
| 0.48855379387549824 | 8.793968289758968 | 18
| 18 |",
- "| e | 0.01479305307777301 | 0.9965400387585364
| 0.48600669271341534 | 10.206140546981722 | 21
| 21 |",
-
"+----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+",
- ],
- &df
- );
-
- Ok(())
- }
-
- #[tokio::test]
- async fn join() -> Result<()> {
- let left = test_table().await?.select_columns(&["c1", "c2"])?;
- let right = test_table_with_name("c2")
- .await?
- .select_columns(&["c1", "c3"])?;
- let left_rows = left.collect().await?;
- let right_rows = right.collect().await?;
- let join = left.join(right, JoinType::Inner, &["c1"], &["c1"])?;
- let join_rows = join.collect().await?;
- assert_eq!(100, left_rows.iter().map(|x| x.num_rows()).sum::<usize>());
- assert_eq!(100, right_rows.iter().map(|x|
x.num_rows()).sum::<usize>());
- assert_eq!(2008, join_rows.iter().map(|x|
x.num_rows()).sum::<usize>());
- Ok(())
- }
-
- #[tokio::test]
- async fn limit() -> Result<()> {
- // build query using Table API
- let t = test_table().await?;
- let t2 = t.select_columns(&["c1", "c2", "c11"])?.limit(10)?;
- let plan = t2.to_logical_plan();
-
- // build query using SQL
- let sql_plan =
- create_plan("SELECT c1, c2, c11 FROM aggregate_test_100 LIMIT
10").await?;
-
- // the two plans should be identical
- assert_same_plan(&plan, &sql_plan);
-
- Ok(())
- }
-
- #[tokio::test]
- async fn explain() -> Result<()> {
- // build query using Table API
- let df = test_table().await?;
- let df = df
- .select_columns(&["c1", "c2", "c11"])?
- .limit(10)?
- .explain(false, false)?;
- let plan = df.to_logical_plan();
-
- // build query using SQL
- let sql_plan =
- create_plan("EXPLAIN SELECT c1, c2, c11 FROM aggregate_test_100
LIMIT 10")
- .await?;
-
- // the two plans should be identical
- assert_same_plan(&plan, &sql_plan);
-
- Ok(())
- }
-
- #[tokio::test]
- async fn registry() -> Result<()> {
- let mut ctx = ExecutionContext::new();
- register_aggregate_csv(&mut ctx, "aggregate_test_100").await?;
-
- // declare the udf
- let my_fn: ScalarFunctionImplementation =
- Arc::new(|_: &[ColumnarValue]| unimplemented!("my_fn is not
implemented"));
-
- // create and register the udf
- ctx.register_udf(create_udf(
- "my_fn",
- vec![DataType::Float64],
- Arc::new(DataType::Float64),
- Volatility::Immutable,
- my_fn,
- ));
-
- // build query with a UDF using DataFrame API
- let df = ctx.table("aggregate_test_100")?;
-
- let f = df.registry();
-
- let df = df.select(vec![f.udf("my_fn")?.call(vec![col("c12")])])?;
- let plan = df.to_logical_plan();
-
- // build query using SQL
- let sql_plan =
- ctx.create_logical_plan("SELECT my_fn(c12) FROM
aggregate_test_100")?;
-
- // the two plans should be identical
- assert_same_plan(&plan, &sql_plan);
-
- Ok(())
- }
-
- #[tokio::test]
- async fn sendable() {
- let df = test_table().await.unwrap();
- // dataframes should be sendable between threads/tasks
- let task = tokio::task::spawn(async move {
- df.select_columns(&["c1"])
- .expect("should be usable in a task")
- });
- task.await.expect("task completed successfully");
- }
-
- #[tokio::test]
- async fn intersect() -> Result<()> {
- let df = test_table().await?.select_columns(&["c1", "c3"])?;
- let plan = df.intersect(df.clone())?;
- let result = plan.to_logical_plan();
- let expected = create_plan(
- "SELECT c1, c3 FROM aggregate_test_100
- INTERSECT ALL SELECT c1, c3 FROM aggregate_test_100",
- )
- .await?;
- assert_same_plan(&result, &expected);
- Ok(())
- }
-
- #[tokio::test]
- async fn except() -> Result<()> {
- let df = test_table().await?.select_columns(&["c1", "c3"])?;
- let plan = df.except(df.clone())?;
- let result = plan.to_logical_plan();
- let expected = create_plan(
- "SELECT c1, c3 FROM aggregate_test_100
- EXCEPT ALL SELECT c1, c3 FROM aggregate_test_100",
- )
- .await?;
- assert_same_plan(&result, &expected);
- Ok(())
- }
-
- #[tokio::test]
- async fn register_table() -> Result<()> {
- let df = test_table().await?.select_columns(&["c1", "c12"])?;
- let mut ctx = ExecutionContext::new();
- let df_impl =
- Arc::new(DataFrameImpl::new(ctx.state.clone(),
&df.to_logical_plan()));
-
- // register a dataframe as a table
- ctx.register_table("test_table", df_impl.clone())?;
-
- // pull the table out
- let table = ctx.table("test_table")?;
-
- let group_expr = vec![col("c1")];
- let aggr_expr = vec![sum(col("c12"))];
-
- // check that we correctly read from the table
- let df_results = &df_impl
- .aggregate(group_expr.clone(), aggr_expr.clone())?
- .collect()
- .await?;
- let table_results = &table.aggregate(group_expr,
aggr_expr)?.collect().await?;
-
- assert_batches_sorted_eq!(
- vec![
- "+----+-----------------------------+",
- "| c1 | SUM(aggregate_test_100.c12) |",
- "+----+-----------------------------+",
- "| a | 10.238448667882977 |",
- "| b | 7.797734760124923 |",
- "| c | 13.860958726523545 |",
- "| d | 8.793968289758968 |",
- "| e | 10.206140546981722 |",
- "+----+-----------------------------+",
- ],
- df_results
- );
-
- // the results are the same as the results from the view, modulo the
leaf table name
- assert_batches_sorted_eq!(
- vec![
- "+----+---------------------+",
- "| c1 | SUM(test_table.c12) |",
- "+----+---------------------+",
- "| a | 10.238448667882977 |",
- "| b | 7.797734760124923 |",
- "| c | 13.860958726523545 |",
- "| d | 8.793968289758968 |",
- "| e | 10.206140546981722 |",
- "+----+---------------------+",
- ],
- table_results
- );
- Ok(())
- }
- /// Compare the formatted string representation of two plans for equality
- fn assert_same_plan(plan1: &LogicalPlan, plan2: &LogicalPlan) {
- assert_eq!(format!("{:?}", plan1), format!("{:?}", plan2));
- }
-
- /// Create a logical plan from a SQL query
- async fn create_plan(sql: &str) -> Result<LogicalPlan> {
- let mut ctx = ExecutionContext::new();
- register_aggregate_csv(&mut ctx, "aggregate_test_100").await?;
- ctx.create_logical_plan(sql)
- }
-
- async fn test_table_with_name(name: &str) -> Result<Arc<dyn DataFrame +
'static>> {
- let mut ctx = ExecutionContext::new();
- register_aggregate_csv(&mut ctx, name).await?;
- ctx.table(name)
- }
-
- async fn test_table() -> Result<Arc<dyn DataFrame + 'static>> {
- test_table_with_name("aggregate_test_100").await
- }
-
- async fn register_aggregate_csv(
- ctx: &mut ExecutionContext,
- table_name: &str,
- ) -> Result<()> {
- let schema = test_util::aggr_test_schema();
- let testdata = crate::test_util::arrow_test_data();
- ctx.register_csv(
- table_name,
- &format!("{}/csv/aggregate_test_100.csv", testdata),
- CsvReadOptions::new().schema(schema.as_ref()),
- )
- .await?;
- Ok(())
- }
-}
diff --git a/datafusion/src/execution/mod.rs b/datafusion/src/execution/mod.rs
index 427c539..54fd298 100644
--- a/datafusion/src/execution/mod.rs
+++ b/datafusion/src/execution/mod.rs
@@ -18,7 +18,6 @@
//! DataFusion query execution
pub mod context;
-pub mod dataframe_impl;
pub(crate) mod disk_manager;
pub mod memory_manager;
pub mod options;
diff --git a/datafusion/tests/dataframe_functions.rs
b/datafusion/tests/dataframe_functions.rs
index ae521a0..1f55af4 100644
--- a/datafusion/tests/dataframe_functions.rs
+++ b/datafusion/tests/dataframe_functions.rs
@@ -35,7 +35,7 @@ use datafusion::execution::context::ExecutionContext;
use datafusion::assert_batches_eq;
-fn create_test_table() -> Result<Arc<dyn DataFrame>> {
+fn create_test_table() -> Result<Arc<DataFrame>> {
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Utf8, false),
Field::new("b", DataType::Int32, false),