This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new cd503c3 ARROW-9516: [Rust][DataFusion] refactor of column names
cd503c3 is described below
commit cd503c3f583dab4b94c9934d525664e5897ff06d
Author: Jorge C. Leitao <[email protected]>
AuthorDate: Fri Jul 24 18:18:50 2020 -0600
ARROW-9516: [Rust][DataFusion] refactor of column names
This PR addresses ARROW-9516.
It:
1. simplifies how we construct and handle columns: all columns are now
name-based rather than index-based, which significantly simplifies the code base.
2. makes all column naming happen on the logical plan, not the physical plan.
3. gives more expressive column names to the end schema of a logical plan,
particularly to aggregated expressions (e.g. `SUM(a), SUM(b)` instead of `SUM,
SUM`).
This is currently a proof of concept: all tests pass, but no decision has yet
been made on whether we should proceed with these changes. More details are
available at [ARROW-9516](https://issues.apache.org/jira/browse/ARROW-9516).
Closes #7796 from jorgecarleitao/column_names
Authored-by: Jorge C. Leitao <[email protected]>
Signed-off-by: Andy Grove <[email protected]>
---
rust/datafusion/src/execution/context.rs | 164 +++++++++---
.../src/execution/physical_plan/expressions.rs | 187 +++++---------
.../src/execution/physical_plan/hash_aggregate.rs | 62 +++--
.../execution/physical_plan/math_expressions.rs | 10 +-
rust/datafusion/src/execution/physical_plan/mod.rs | 17 +-
.../src/execution/physical_plan/projection.rs | 25 +-
.../src/execution/physical_plan/selection.rs | 12 +-
.../datafusion/src/execution/physical_plan/sort.rs | 6 +-
rust/datafusion/src/execution/physical_plan/udf.rs | 7 -
rust/datafusion/src/execution/table_impl.rs | 16 +-
rust/datafusion/src/logicalplan.rs | 122 +++++++--
rust/datafusion/src/optimizer/mod.rs | 1 -
.../src/optimizer/projection_push_down.rs | 285 ++++++---------------
rust/datafusion/src/optimizer/resolve_columns.rs | 159 ------------
rust/datafusion/src/optimizer/type_coercion.rs | 22 +-
rust/datafusion/src/optimizer/utils.rs | 115 ++-------
rust/datafusion/src/sql/planner.rs | 133 +++++-----
rust/datafusion/tests/sql.rs | 2 +-
18 files changed, 544 insertions(+), 801 deletions(-)
diff --git a/rust/datafusion/src/execution/context.rs
b/rust/datafusion/src/execution/context.rs
index 2500163..cec0f15 100644
--- a/rust/datafusion/src/execution/context.rs
+++ b/rust/datafusion/src/execution/context.rs
@@ -36,8 +36,7 @@ use crate::execution::physical_plan::common;
use crate::execution::physical_plan::csv::{CsvExec, CsvReadOptions};
use crate::execution::physical_plan::datasource::DatasourceExec;
use crate::execution::physical_plan::expressions::{
- Alias, Avg, BinaryExpr, CastExpr, Column, Count, Literal, Max, Min,
PhysicalSortExpr,
- Sum,
+ Avg, BinaryExpr, CastExpr, Column, Count, Literal, Max, Min,
PhysicalSortExpr, Sum,
};
use crate::execution::physical_plan::hash_aggregate::HashAggregateExec;
use crate::execution::physical_plan::limit::LimitExec;
@@ -51,10 +50,11 @@ use crate::execution::physical_plan::sort::{SortExec,
SortOptions};
use crate::execution::physical_plan::udf::{ScalarFunction, ScalarFunctionExpr};
use crate::execution::physical_plan::{AggregateExpr, ExecutionPlan,
PhysicalExpr};
use crate::execution::table_impl::TableImpl;
-use crate::logicalplan::*;
+use crate::logicalplan::{
+ Expr, FunctionMeta, FunctionType, LogicalPlan, LogicalPlanBuilder,
+};
use crate::optimizer::optimizer::OptimizerRule;
use crate::optimizer::projection_push_down::ProjectionPushDown;
-use crate::optimizer::resolve_columns::ResolveColumnsRule;
use crate::optimizer::type_coercion::TypeCoercionRule;
use crate::sql::parser::{DFASTNode, DFParser, FileType};
use crate::sql::planner::{SchemaProvider, SqlToRel};
@@ -67,6 +67,15 @@ pub struct ExecutionContext {
scalar_functions: HashMap<String, Box<ScalarFunction>>,
}
+fn tuple_err<T, R>(value: (Result<T>, Result<R>)) -> Result<(T, R)> {
+ match value {
+ (Ok(e), Ok(e1)) => Ok((e, e1)),
+ (Err(e), Ok(_)) => Err(e),
+ (Ok(_), Err(e1)) => Err(e1),
+ (Err(e), Err(_)) => Err(e),
+ }
+}
+
impl ExecutionContext {
/// Create a new execution context for in-memory queries
pub fn new() -> Self {
@@ -275,7 +284,6 @@ impl ExecutionContext {
/// Optimize the logical plan by applying optimizer rules
pub fn optimize(&self, plan: &LogicalPlan) -> Result<LogicalPlan> {
let rules: Vec<Box<dyn OptimizerRule>> = vec![
- Box::new(ResolveColumnsRule::new()),
Box::new(ProjectionPushDown::new()),
Box::new(TypeCoercionRule::new(&self.scalar_functions)),
];
@@ -356,7 +364,12 @@ impl ExecutionContext {
let input_schema = input.as_ref().schema().clone();
let runtime_expr = expr
.iter()
- .map(|e| self.create_physical_expr(e, &input_schema))
+ .map(|e| {
+ tuple_err((
+ self.create_physical_expr(e, &input_schema),
+ e.name(&input_schema),
+ ))
+ })
.collect::<Result<Vec<_>>>()?;
Ok(Arc::new(ProjectionExec::try_new(runtime_expr, input)?))
}
@@ -370,17 +383,30 @@ impl ExecutionContext {
let input = self.create_physical_plan(input, batch_size)?;
let input_schema = input.as_ref().schema().clone();
- let group_expr = group_expr
+ let groups = group_expr
.iter()
- .map(|e| self.create_physical_expr(e, &input_schema))
+ .map(|e| {
+ tuple_err((
+ self.create_physical_expr(e, &input_schema),
+ e.name(&input_schema),
+ ))
+ })
.collect::<Result<Vec<_>>>()?;
- let aggr_expr = aggr_expr
+ let aggregates = aggr_expr
.iter()
- .map(|e| self.create_aggregate_expr(e, &input_schema))
+ .map(|e| {
+ tuple_err((
+ self.create_aggregate_expr(e, &input_schema),
+ e.name(&input_schema),
+ ))
+ })
.collect::<Result<Vec<_>>>()?;
- let initial_aggr =
- HashAggregateExec::try_new(group_expr, aggr_expr, input)?;
+ let initial_aggr = HashAggregateExec::try_new(
+ groups.clone(),
+ aggregates.clone(),
+ input,
+ )?;
let schema = initial_aggr.schema();
let partitions = initial_aggr.partitions()?;
@@ -389,13 +415,27 @@ impl ExecutionContext {
return Ok(Arc::new(initial_aggr));
}
- let (final_group, final_aggr) = initial_aggr.make_final_expr();
-
let merge = Arc::new(MergeExec::new(schema.clone(),
partitions));
+ // construct the expressions for the final aggregation
+ let (final_group, final_aggr) = initial_aggr.make_final_expr(
+ groups.iter().map(|x| x.1.clone()).collect(),
+ aggregates.iter().map(|x| x.1.clone()).collect(),
+ );
+
+ // construct a second aggregation, keeping the final column
name equal to the first aggregation
+ // and the expressions corresponding to the respective
aggregate
Ok(Arc::new(HashAggregateExec::try_new(
- final_group,
- final_aggr,
+ final_group
+ .iter()
+ .enumerate()
+ .map(|(i, expr)| (expr.clone(), groups[i].1.clone()))
+ .collect(),
+ final_aggr
+ .iter()
+ .enumerate()
+ .map(|(i, expr)| (expr.clone(),
aggregates[i].1.clone()))
+ .collect(),
merge,
)?))
}
@@ -455,12 +495,11 @@ impl ExecutionContext {
input_schema: &Schema,
) -> Result<Arc<dyn PhysicalExpr>> {
match e {
- Expr::Alias(expr, name) => {
- let expr = self.create_physical_expr(expr, input_schema)?;
- Ok(Arc::new(Alias::new(expr, &name)))
- }
- Expr::Column(i) => {
- Ok(Arc::new(Column::new(*i, &input_schema.field(*i).name())))
+ Expr::Alias(expr, ..) => Ok(self.create_physical_expr(expr,
input_schema)?),
+ Expr::Column(name) => {
+ // check that name exists
+ input_schema.field_with_name(&name)?;
+ Ok(Arc::new(Column::new(name)))
}
Expr::Literal(value) => Ok(Arc::new(Literal::new(value.clone()))),
Expr::BinaryExpr { left, op, right } =>
Ok(Arc::new(BinaryExpr::new(
@@ -484,7 +523,6 @@ impl ExecutionContext {
physical_args.push(self.create_physical_expr(e,
input_schema)?);
}
Ok(Arc::new(ScalarFunctionExpr::new(
- name,
Box::new(f.fun.clone()),
physical_args,
return_type,
@@ -650,6 +688,7 @@ mod tests {
use super::*;
use crate::datasource::MemTable;
use crate::execution::physical_plan::udf::ScalarUdf;
+ use crate::logicalplan::{aggregate_expr, col, scalar_function};
use crate::test;
use arrow::array::{ArrayRef, Int32Array};
use arrow::compute::add;
@@ -666,10 +705,12 @@ mod tests {
// there should be one batch per partition
assert_eq!(results.len(), partition_count);
- // each batch should contain 2 columns and 10 rows
+ // each batch should contain 2 columns and 10 rows with correct field
names
for batch in &results {
assert_eq!(batch.num_columns(), 2);
assert_eq!(batch.num_rows(), 10);
+
+ assert_eq!(field_names(batch), vec!["c1", "c2"]);
}
Ok(())
@@ -706,7 +747,7 @@ mod tests {
let table = ctx.table("test")?;
let logical_plan = LogicalPlanBuilder::from(&table.to_logical_plan())
- .project(vec![Expr::UnresolvedColumn("c2".to_string())])?
+ .project(vec![col("c2")])?
.build()?;
let optimized_plan = ctx.optimize(&logical_plan)?;
@@ -725,7 +766,7 @@ mod tests {
_ => assert!(false, "expect optimized_plan to be projection"),
}
- let expected = "Projection: #0\
+ let expected = "Projection: #c2\
\n TableScan: test projection=Some([1])";
assert_eq!(format!("{:?}", optimized_plan), expected);
@@ -747,19 +788,19 @@ mod tests {
let tmp_dir = TempDir::new("execute")?;
let ctx = create_ctx(&tmp_dir, 1)?;
- let schema = Arc::new(Schema::new(vec![Field::new(
- "state",
- DataType::Utf8,
- false,
- )]));
+ let schema = ctx.datasources.get("test").unwrap().schema();
+ assert_eq!(schema.field_with_name("c1")?.is_nullable(), false);
let plan = LogicalPlanBuilder::scan("default", "test",
schema.as_ref(), None)?
- .project(vec![col("state")])?
+ .project(vec![col("c1")])?
.build()?;
let plan = ctx.optimize(&plan)?;
let physical_plan = ctx.create_physical_plan(&Arc::new(plan), 1024)?;
- assert_eq!(physical_plan.schema().field(0).is_nullable(), false);
+ assert_eq!(
+ physical_plan.schema().field_with_name("c1")?.is_nullable(),
+ false
+ );
Ok(())
}
@@ -783,7 +824,7 @@ mod tests {
projection: None,
projected_schema: Box::new(schema.clone()),
})
- .project(vec![Expr::UnresolvedColumn("b".to_string())])?
+ .project(vec![col("b")])?
.build()?;
assert_fields_eq(&plan, vec!["b"]);
@@ -804,7 +845,7 @@ mod tests {
_ => assert!(false, "expect optimized_plan to be projection"),
}
- let expected = "Projection: #0\
+ let expected = "Projection: #b\
\n InMemoryScan: projection=Some([1])";
assert_eq!(format!("{:?}", optimized_plan), expected);
@@ -844,6 +885,9 @@ mod tests {
assert_eq!(results.len(), 1);
let batch = &results[0];
+
+ assert_eq!(field_names(batch), vec!["SUM(c1)", "SUM(c2)"]);
+
let expected: Vec<&str> = vec!["60,220"];
let mut rows = test::format_batch(&batch);
rows.sort();
@@ -858,6 +902,9 @@ mod tests {
assert_eq!(results.len(), 1);
let batch = &results[0];
+
+ assert_eq!(field_names(batch), vec!["AVG(c1)", "AVG(c2)"]);
+
let expected: Vec<&str> = vec!["1.5,5.5"];
let mut rows = test::format_batch(&batch);
rows.sort();
@@ -872,6 +919,9 @@ mod tests {
assert_eq!(results.len(), 1);
let batch = &results[0];
+
+ assert_eq!(field_names(batch), vec!["MAX(c1)", "MAX(c2)"]);
+
let expected: Vec<&str> = vec!["3,10"];
let mut rows = test::format_batch(&batch);
rows.sort();
@@ -886,6 +936,9 @@ mod tests {
assert_eq!(results.len(), 1);
let batch = &results[0];
+
+ assert_eq!(field_names(batch), vec!["MIN(c1)", "MIN(c2)"]);
+
let expected: Vec<&str> = vec!["0,1"];
let mut rows = test::format_batch(&batch);
rows.sort();
@@ -900,6 +953,9 @@ mod tests {
assert_eq!(results.len(), 1);
let batch = &results[0];
+
+ assert_eq!(field_names(batch), vec!["c1", "SUM(c2)"]);
+
let expected: Vec<&str> = vec!["0,55", "1,55", "2,55", "3,55"];
let mut rows = test::format_batch(&batch);
rows.sort();
@@ -914,6 +970,9 @@ mod tests {
assert_eq!(results.len(), 1);
let batch = &results[0];
+
+ assert_eq!(field_names(batch), vec!["c1", "AVG(c2)"]);
+
let expected: Vec<&str> = vec!["0,5.5", "1,5.5", "2,5.5", "3,5.5"];
let mut rows = test::format_batch(&batch);
rows.sort();
@@ -928,6 +987,9 @@ mod tests {
assert_eq!(results.len(), 1);
let batch = &results[0];
+
+ assert_eq!(field_names(batch), vec!["c1", "MAX(c2)"]);
+
let expected: Vec<&str> = vec!["0,10", "1,10", "2,10", "3,10"];
let mut rows = test::format_batch(&batch);
rows.sort();
@@ -942,6 +1004,9 @@ mod tests {
assert_eq!(results.len(), 1);
let batch = &results[0];
+
+ assert_eq!(field_names(batch), vec!["c1", "MIN(c2)"]);
+
let expected: Vec<&str> = vec!["0,1", "1,1", "2,1", "3,1"];
let mut rows = test::format_batch(&batch);
rows.sort();
@@ -956,6 +1021,9 @@ mod tests {
assert_eq!(results.len(), 1);
let batch = &results[0];
+
+ assert_eq!(field_names(batch), vec!["COUNT(c1)", "COUNT(c2)"]);
+
let expected: Vec<&str> = vec!["10,10"];
let mut rows = test::format_batch(&batch);
rows.sort();
@@ -969,6 +1037,9 @@ mod tests {
assert_eq!(results.len(), 1);
let batch = &results[0];
+
+ assert_eq!(field_names(batch), vec!["COUNT(c1)", "COUNT(c2)"]);
+
let expected: Vec<&str> = vec!["40,40"];
let mut rows = test::format_batch(&batch);
rows.sort();
@@ -982,6 +1053,9 @@ mod tests {
assert_eq!(results.len(), 1);
let batch = &results[0];
+
+ assert_eq!(field_names(batch), vec!["c1", "COUNT(c2)"]);
+
let expected = vec!["0,10", "1,10", "2,10", "3,10"];
let mut rows = test::format_batch(&batch);
rows.sort();
@@ -995,16 +1069,16 @@ mod tests {
let ctx = create_ctx(&tmp_dir, 1)?;
let schema = Arc::new(Schema::new(vec![
- Field::new("state", DataType::Utf8, false),
- Field::new("salary", DataType::UInt32, false),
+ Field::new("c1", DataType::Utf8, false),
+ Field::new("c2", DataType::UInt32, false),
]));
let plan = LogicalPlanBuilder::scan("default", "test",
schema.as_ref(), None)?
.aggregate(
- vec![col("state")],
- vec![aggregate_expr("SUM", col("salary"), DataType::UInt32)],
+ vec![col("c1")],
+ vec![aggregate_expr("SUM", col("c2"), DataType::UInt32)],
)?
- .project(vec![col("state"), col_index(1).alias("total_salary")])?
+ .project(vec![col("c1"), col("SUM(c2)").alias("total_salary")])?
.build()?;
let plan = ctx.optimize(&plan)?;
@@ -1131,6 +1205,7 @@ mod tests {
let batch = &result[0];
assert_eq!(3, batch.num_columns());
assert_eq!(4, batch.num_rows());
+ assert_eq!(field_names(batch), vec!["a", "b", "my_add(a,b)"]);
let a = batch
.column(0)
@@ -1166,6 +1241,15 @@ mod tests {
ctx.collect(physical_plan.as_ref())
}
+ fn field_names(result: &RecordBatch) -> Vec<String> {
+ result
+ .schema()
+ .fields()
+ .iter()
+ .map(|x| x.name().clone())
+ .collect::<Vec<String>>()
+ }
+
/// Execute SQL and return results
fn execute(sql: &str, partition_count: usize) -> Result<Vec<RecordBatch>> {
let tmp_dir = TempDir::new("execute")?;
diff --git a/rust/datafusion/src/execution/physical_plan/expressions.rs
b/rust/datafusion/src/execution/physical_plan/expressions.rs
index e3e1ea8..194f356 100644
--- a/rust/datafusion/src/execution/physical_plan/expressions.rs
+++ b/rust/datafusion/src/execution/physical_plan/expressions.rs
@@ -47,81 +47,43 @@ use arrow::compute::kernels::sort::{SortColumn,
SortOptions};
use arrow::datatypes::{DataType, Schema, TimeUnit};
use arrow::record_batch::RecordBatch;
-/// Represents an aliased expression
-pub struct Alias {
- expr: Arc<dyn PhysicalExpr>,
- alias: String,
-}
-
-impl Alias {
- /// Create a new aliased expression
- pub fn new(expr: Arc<dyn PhysicalExpr>, alias: &str) -> Self {
- Self {
- expr: expr.clone(),
- alias: alias.to_owned(),
- }
- }
-}
-
-impl PhysicalExpr for Alias {
- fn name(&self) -> String {
- self.alias.clone()
- }
-
- fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
- self.expr.data_type(input_schema)
- }
-
- fn nullable(&self, input_schema: &Schema) -> Result<bool> {
- self.expr.nullable(input_schema)
- }
-
- fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
- self.expr.evaluate(batch)
- }
-}
-
/// Represents the column at a given index in a RecordBatch
pub struct Column {
- index: usize,
name: String,
}
impl Column {
/// Create a new column expression
- pub fn new(index: usize, name: &str) -> Self {
+ pub fn new(name: &str) -> Self {
Self {
- index,
name: name.to_owned(),
}
}
}
impl PhysicalExpr for Column {
- /// Get the name to use in a schema to represent the result of this
expression
- fn name(&self) -> String {
- self.name.clone()
- }
-
/// Get the data type of this expression, given the schema of the input
fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
- Ok(input_schema.field(self.index).data_type().clone())
+ Ok(input_schema
+ .field_with_name(&self.name)?
+ .data_type()
+ .clone())
}
/// Decide whehter this expression is nullable, given the schema of the
input
fn nullable(&self, input_schema: &Schema) -> Result<bool> {
- Ok(input_schema.field(self.index).is_nullable())
+ Ok(input_schema.field_with_name(&self.name)?.is_nullable())
}
/// Evaluate the expression
fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
- Ok(batch.column(self.index).clone())
+ Ok(batch.column(batch.schema().index_of(&self.name)?).clone())
}
}
/// Create a column expression
-pub fn col(i: usize, schema: &Schema) -> Arc<dyn PhysicalExpr> {
- Arc::new(Column::new(i, &schema.field(i).name()))
+pub fn col(name: &str) -> Arc<dyn PhysicalExpr> {
+ Arc::new(Column::new(name))
}
/// SUM aggregate expression
@@ -137,10 +99,6 @@ impl Sum {
}
impl AggregateExpr for Sum {
- fn name(&self) -> String {
- "SUM".to_string()
- }
-
fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
match self.expr.data_type(input_schema)? {
DataType::Int8 | DataType::Int16 | DataType::Int32 |
DataType::Int64 => {
@@ -166,8 +124,8 @@ impl AggregateExpr for Sum {
Rc::new(RefCell::new(SumAccumulator { sum: None }))
}
- fn create_reducer(&self, column_index: usize) -> Arc<dyn AggregateExpr> {
- Arc::new(Sum::new(Arc::new(Column::new(column_index, &self.name()))))
+ fn create_reducer(&self, column_name: &str) -> Arc<dyn AggregateExpr> {
+ Arc::new(Sum::new(Arc::new(Column::new(column_name))))
}
}
@@ -333,10 +291,6 @@ impl Avg {
}
impl AggregateExpr for Avg {
- fn name(&self) -> String {
- "AVG".to_string()
- }
-
fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
match self.expr.data_type(input_schema)? {
DataType::Int8
@@ -367,8 +321,8 @@ impl AggregateExpr for Avg {
}))
}
- fn create_reducer(&self, column_index: usize) -> Arc<dyn AggregateExpr> {
- Arc::new(Avg::new(Arc::new(Column::new(column_index, &self.name()))))
+ fn create_reducer(&self, column_name: &str) -> Arc<dyn AggregateExpr> {
+ Arc::new(Avg::new(Arc::new(Column::new(column_name))))
}
}
@@ -451,10 +405,6 @@ impl Max {
}
impl AggregateExpr for Max {
- fn name(&self) -> String {
- "MAX".to_string()
- }
-
fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
match self.expr.data_type(input_schema)? {
DataType::Int8 | DataType::Int16 | DataType::Int32 |
DataType::Int64 => {
@@ -480,8 +430,8 @@ impl AggregateExpr for Max {
Rc::new(RefCell::new(MaxAccumulator { max: None }))
}
- fn create_reducer(&self, column_index: usize) -> Arc<dyn AggregateExpr> {
- Arc::new(Max::new(Arc::new(Column::new(column_index, &self.name()))))
+ fn create_reducer(&self, column_name: &str) -> Arc<dyn AggregateExpr> {
+ Arc::new(Max::new(Arc::new(Column::new(column_name))))
}
}
@@ -650,10 +600,6 @@ impl Min {
}
impl AggregateExpr for Min {
- fn name(&self) -> String {
- "MIN".to_string()
- }
-
fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
match self.expr.data_type(input_schema)? {
DataType::Int8 | DataType::Int16 | DataType::Int32 |
DataType::Int64 => {
@@ -679,8 +625,8 @@ impl AggregateExpr for Min {
Rc::new(RefCell::new(MinAccumulator { min: None }))
}
- fn create_reducer(&self, column_index: usize) -> Arc<dyn AggregateExpr> {
- Arc::new(Min::new(Arc::new(Column::new(column_index, &self.name()))))
+ fn create_reducer(&self, column_name: &str) -> Arc<dyn AggregateExpr> {
+ Arc::new(Min::new(Arc::new(Column::new(column_name))))
}
}
@@ -850,10 +796,6 @@ impl Count {
}
impl AggregateExpr for Count {
- fn name(&self) -> String {
- "COUNT".to_string()
- }
-
fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
Ok(DataType::UInt64)
}
@@ -866,8 +808,8 @@ impl AggregateExpr for Count {
Rc::new(RefCell::new(CountAccumulator { count: 0 }))
}
- fn create_reducer(&self, column_index: usize) -> Arc<dyn AggregateExpr> {
- Arc::new(Sum::new(Arc::new(Column::new(column_index, &self.name()))))
+ fn create_reducer(&self, column_name: &str) -> Arc<dyn AggregateExpr> {
+ Arc::new(Sum::new(Arc::new(Column::new(column_name))))
}
}
@@ -1024,10 +966,6 @@ impl BinaryExpr {
}
impl PhysicalExpr for BinaryExpr {
- fn name(&self) -> String {
- format!("{:?}", self.op)
- }
-
fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
self.left.data_type(input_schema)
}
@@ -1112,10 +1050,6 @@ impl NotExpr {
}
impl PhysicalExpr for NotExpr {
- fn name(&self) -> String {
- "NOT".to_string()
- }
-
fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
return Ok(DataType::Boolean);
}
@@ -1193,10 +1127,6 @@ impl CastExpr {
}
impl PhysicalExpr for CastExpr {
- fn name(&self) -> String {
- "CAST".to_string()
- }
-
fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
Ok(self.cast_type.clone())
}
@@ -1236,10 +1166,6 @@ macro_rules! build_literal_array {
}
impl PhysicalExpr for Literal {
- fn name(&self) -> String {
- "lit".to_string()
- }
-
fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
Ok(self.value.get_datatype())
}
@@ -1336,7 +1262,7 @@ mod tests {
)?;
// expression: "a < b"
- let lt = binary(col(0, &schema), Operator::Lt, col(1, &schema));
+ let lt = binary(col("a"), Operator::Lt, col("b"));
let result = lt.evaluate(&batch)?;
assert_eq!(result.len(), 5);
@@ -1367,9 +1293,9 @@ mod tests {
// expression: "a < b OR a == b"
let expr = binary(
- binary(col(0, &schema), Operator::Lt, col(1, &schema)),
+ binary(col("a"), Operator::Lt, col("b")),
Operator::Or,
- binary(col(0, &schema), Operator::Eq, col(1, &schema)),
+ binary(col("a"), Operator::Eq, col("b")),
);
let result = expr.evaluate(&batch)?;
assert_eq!(result.len(), 5);
@@ -1414,7 +1340,7 @@ mod tests {
let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
let batch = RecordBatch::try_new(Arc::new(schema.clone()),
vec![Arc::new(a)])?;
- let cast = CastExpr::try_new(col(0, &schema), &schema,
DataType::UInt32)?;
+ let cast = CastExpr::try_new(col("a"), &schema, DataType::UInt32)?;
let result = cast.evaluate(&batch)?;
assert_eq!(result.len(), 5);
@@ -1433,7 +1359,7 @@ mod tests {
let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
let batch = RecordBatch::try_new(Arc::new(schema.clone()),
vec![Arc::new(a)])?;
- let cast = CastExpr::try_new(col(0, &schema), &schema,
DataType::Utf8)?;
+ let cast = CastExpr::try_new(col("a"), &schema, DataType::Utf8)?;
let result = cast.evaluate(&batch)?;
assert_eq!(result.len(), 5);
@@ -1453,7 +1379,7 @@ mod tests {
let batch = RecordBatch::try_new(Arc::new(schema.clone()),
vec![Arc::new(a)])?;
let cast = CastExpr::try_new(
- col(0, &schema),
+ col("a"),
&schema,
DataType::Timestamp(TimeUnit::Nanosecond, None),
)?;
@@ -1472,7 +1398,7 @@ mod tests {
#[test]
fn invalid_cast() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]);
- match CastExpr::try_new(col(0, &schema), &schema, DataType::Int32) {
+ match CastExpr::try_new(col("a"), &schema, DataType::Int32) {
Err(ExecutionError::General(ref str)) => {
assert_eq!(str, "Invalid CAST from Utf8 to Int32");
Ok(())
@@ -1485,12 +1411,16 @@ mod tests {
fn sum_contract() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Int32,
false)]);
- let sum = sum(col(0, &schema));
- assert_eq!("SUM".to_string(), sum.name());
+ let sum = sum(col("a"));
assert_eq!(DataType::Int64, sum.data_type(&schema)?);
- let combiner = sum.create_reducer(0);
- assert_eq!("SUM".to_string(), combiner.name());
+ // after the aggr expression is applied, the schema changes to:
+ let schema = Schema::new(vec![
+ schema.field(0).clone(),
+ Field::new("SUM(a)", sum.data_type(&schema)?, false),
+ ]);
+
+ let combiner = sum.create_reducer("SUM(a)");
assert_eq!(DataType::Int64, combiner.data_type(&schema)?);
Ok(())
@@ -1500,12 +1430,16 @@ mod tests {
fn max_contract() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Int32,
false)]);
- let max = max(col(0, &schema));
- assert_eq!("MAX".to_string(), max.name());
+ let max = max(col("a"));
assert_eq!(DataType::Int64, max.data_type(&schema)?);
- let combiner = max.create_reducer(0);
- assert_eq!("MAX".to_string(), combiner.name());
+ // after the aggr expression is applied, the schema changes to:
+ let schema = Schema::new(vec![
+ schema.field(0).clone(),
+ Field::new("Max(a)", max.data_type(&schema)?, false),
+ ]);
+
+ let combiner = max.create_reducer("Max(a)");
assert_eq!(DataType::Int64, combiner.data_type(&schema)?);
Ok(())
@@ -1515,12 +1449,15 @@ mod tests {
fn min_contract() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Int32,
false)]);
- let min = min(col(0, &schema));
- assert_eq!("MIN".to_string(), min.name());
+ let min = min(col("a"));
assert_eq!(DataType::Int64, min.data_type(&schema)?);
- let combiner = min.create_reducer(0);
- assert_eq!("MIN".to_string(), combiner.name());
+ // after the aggr expression is applied, the schema changes to:
+ let schema = Schema::new(vec![
+ schema.field(0).clone(),
+ Field::new("MIN(a)", min.data_type(&schema)?, false),
+ ]);
+ let combiner = min.create_reducer("MIN(a)");
assert_eq!(DataType::Int64, combiner.data_type(&schema)?);
Ok(())
@@ -1529,12 +1466,16 @@ mod tests {
fn avg_contract() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Int32,
false)]);
- let avg = avg(col(0, &schema));
- assert_eq!("AVG".to_string(), avg.name());
+ let avg = avg(col("a"));
assert_eq!(DataType::Float64, avg.data_type(&schema)?);
- let combiner = avg.create_reducer(0);
- assert_eq!("AVG".to_string(), combiner.name());
+ // after the aggr expression is applied, the schema changes to:
+ let schema = Schema::new(vec![
+ schema.field(0).clone(),
+ Field::new("SUM(a)", avg.data_type(&schema)?, false),
+ ]);
+
+ let combiner = avg.create_reducer("SUM(a)");
assert_eq!(DataType::Float64, combiner.data_type(&schema)?);
Ok(())
@@ -1865,7 +1806,7 @@ mod tests {
}
fn do_sum(batch: &RecordBatch) -> Result<Option<ScalarValue>> {
- let sum = sum(col(0, &batch.schema()));
+ let sum = sum(col("a"));
let accum = sum.create_accumulator();
let input = sum.evaluate_input(batch)?;
let mut accum = accum.borrow_mut();
@@ -1876,7 +1817,7 @@ mod tests {
}
fn do_max(batch: &RecordBatch) -> Result<Option<ScalarValue>> {
- let max = max(col(0, &batch.schema()));
+ let max = max(col("a"));
let accum = max.create_accumulator();
let input = max.evaluate_input(batch)?;
let mut accum = accum.borrow_mut();
@@ -1887,7 +1828,7 @@ mod tests {
}
fn do_min(batch: &RecordBatch) -> Result<Option<ScalarValue>> {
- let min = min(col(0, &batch.schema()));
+ let min = min(col("a"));
let accum = min.create_accumulator();
let input = min.evaluate_input(batch)?;
let mut accum = accum.borrow_mut();
@@ -1898,7 +1839,7 @@ mod tests {
}
fn do_count(batch: &RecordBatch) -> Result<Option<ScalarValue>> {
- let count = count(col(0, &batch.schema()));
+ let count = count(col("a"));
let accum = count.create_accumulator();
let input = count.evaluate_input(batch)?;
let mut accum = accum.borrow_mut();
@@ -1909,7 +1850,7 @@ mod tests {
}
fn do_avg(batch: &RecordBatch) -> Result<Option<ScalarValue>> {
- let avg = avg(col(0, &batch.schema()));
+ let avg = avg(col("a"));
let accum = avg.create_accumulator();
let input = avg.evaluate_input(batch)?;
let mut accum = accum.borrow_mut();
@@ -2009,7 +1950,7 @@ mod tests {
op: Operator,
expected: PrimitiveArray<T>,
) -> Result<()> {
- let arithmetic_op = binary(col(0, schema.as_ref()), op, col(1,
schema.as_ref()));
+ let arithmetic_op = binary(col("a"), op, col("b"));
let batch = RecordBatch::try_new(schema, data)?;
let result = arithmetic_op.evaluate(&batch)?;
@@ -2039,7 +1980,7 @@ mod tests {
let batch = RecordBatch::try_new(Arc::new(schema.clone()),
vec![Arc::new(a)])?;
// expression: "!a"
- let lt = not(col(0, &schema));
+ let lt = not(col("a"));
let result = lt.evaluate(&batch)?;
assert_eq!(result.len(), 2);
diff --git a/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs
b/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs
index 441681a..19836fd 100644
--- a/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs
+++ b/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs
@@ -39,7 +39,7 @@ use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use arrow::error::Result as ArrowResult;
use arrow::record_batch::{RecordBatch, RecordBatchReader};
-use crate::execution::physical_plan::expressions::Column;
+use crate::execution::physical_plan::expressions::col;
use crate::logicalplan::ScalarValue;
use fnv::FnvHashMap;
@@ -54,26 +54,24 @@ pub struct HashAggregateExec {
impl HashAggregateExec {
/// Create a new hash aggregate execution plan
pub fn try_new(
- group_expr: Vec<Arc<dyn PhysicalExpr>>,
- aggr_expr: Vec<Arc<dyn AggregateExpr>>,
+ group_expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
+ aggr_expr: Vec<(Arc<dyn AggregateExpr>, String)>,
input: Arc<dyn ExecutionPlan>,
) -> Result<Self> {
let input_schema = input.schema();
let mut fields = Vec::with_capacity(group_expr.len() +
aggr_expr.len());
- for expr in &group_expr {
- let name = expr.name();
- fields.push(Field::new(&name, expr.data_type(&input_schema)?,
true))
+ for (expr, name) in &group_expr {
+ fields.push(Field::new(name, expr.data_type(&input_schema)?, true))
}
- for expr in &aggr_expr {
- let name = expr.name();
+ for (expr, name) in &aggr_expr {
fields.push(Field::new(&name, expr.data_type(&input_schema)?,
true))
}
let schema = Arc::new(Schema::new(fields));
Ok(HashAggregateExec {
- group_expr,
- aggr_expr,
+ group_expr: group_expr.iter().map(|x| x.0.clone()).collect(),
+ aggr_expr: aggr_expr.iter().map(|x| x.0.clone()).collect(),
input,
schema,
})
@@ -83,19 +81,15 @@ impl HashAggregateExec {
/// expressions
pub fn make_final_expr(
&self,
+ group_names: Vec<String>,
+ agg_names: Vec<String>,
) -> (Vec<Arc<dyn PhysicalExpr>>, Vec<Arc<dyn AggregateExpr>>) {
let final_group: Vec<Arc<dyn PhysicalExpr>> =
(0..self.group_expr.len())
- .map(|i| {
- Arc::new(Column::new(i, &self.group_expr[i].name()))
- as Arc<dyn PhysicalExpr>
- })
+ .map(|i| col(&group_names[i]) as Arc<dyn PhysicalExpr>)
.collect();
let final_aggr: Vec<Arc<dyn AggregateExpr>> = (0..self.aggr_expr.len())
- .map(|i| {
- let aggr = self.aggr_expr[i].create_reducer(i +
self.group_expr.len());
- aggr as Arc<dyn AggregateExpr>
- })
+ .map(|i| self.aggr_expr[i].create_reducer(&agg_names[i]))
.collect();
(final_group, final_aggr)
@@ -772,24 +766,42 @@ mod tests {
let csv =
CsvExec::try_new(&path, CsvReadOptions::new().schema(&schema),
None, 1024)?;
- let group_expr: Vec<Arc<dyn PhysicalExpr>> = vec![col(1,
schema.as_ref())];
+ let groups: Vec<(Arc<dyn PhysicalExpr>, String)> =
+ vec![(col("c2"), "c2".to_string())];
- let aggr_expr: Vec<Arc<dyn AggregateExpr>> = vec![sum(col(3,
schema.as_ref()))];
+ let aggregates: Vec<(Arc<dyn AggregateExpr>, String)> =
+ vec![(sum(col("c4")), "SUM(c4)".to_string())];
let partition_aggregate = HashAggregateExec::try_new(
- group_expr.clone(),
- aggr_expr.clone(),
+ groups.clone(),
+ aggregates.clone(),
Arc::new(csv),
)?;
let schema = partition_aggregate.schema();
let partitions = partition_aggregate.partitions()?;
- let (final_group, final_aggr) = partition_aggregate.make_final_expr();
+
+ // construct the expressions for the final aggregation
+ let (final_group, final_aggr) = partition_aggregate.make_final_expr(
+ groups.iter().map(|x| x.1.clone()).collect(),
+ aggregates.iter().map(|x| x.1.clone()).collect(),
+ );
let merge = Arc::new(MergeExec::new(schema.clone(), partitions));
- let merged_aggregate =
- HashAggregateExec::try_new(final_group, final_aggr, merge)?;
+ let merged_aggregate = HashAggregateExec::try_new(
+ final_group
+ .iter()
+ .enumerate()
+ .map(|(i, expr)| (expr.clone(), groups[i].1.clone()))
+ .collect(),
+ final_aggr
+ .iter()
+ .enumerate()
+ .map(|(i, expr)| (expr.clone(), aggregates[i].1.clone()))
+ .collect(),
+ merge,
+ )?;
let result = test::execute(&merged_aggregate)?;
assert_eq!(result.len(), 1);
diff --git a/rust/datafusion/src/execution/physical_plan/math_expressions.rs
b/rust/datafusion/src/execution/physical_plan/math_expressions.rs
index c995405..aa578b6 100644
--- a/rust/datafusion/src/execution/physical_plan/math_expressions.rs
+++ b/rust/datafusion/src/execution/physical_plan/math_expressions.rs
@@ -81,18 +81,18 @@ pub fn register_math_functions(ctx: &mut ExecutionContext) {
mod tests {
use super::*;
use crate::error::Result;
- use crate::logicalplan::{sqrt, Expr, LogicalPlanBuilder};
+ use crate::logicalplan::{col, sqrt, LogicalPlanBuilder};
use arrow::datatypes::Schema;
#[test]
fn cast_i8_input() -> Result<()> {
let schema = Schema::new(vec![Field::new("c0", DataType::Int8, true)]);
let plan = LogicalPlanBuilder::scan("", "", &schema, None)?
- .project(vec![sqrt(Expr::UnresolvedColumn("c0".to_owned()))])?
+ .project(vec![sqrt(col("c0"))])?
.build()?;
let ctx = ExecutionContext::new();
let plan = ctx.optimize(&plan)?;
- let expected = "Projection: sqrt(CAST(#0 AS Float64))\
+ let expected = "Projection: sqrt(CAST(#c0 AS Float64))\
\n TableScan: projection=Some([0])";
assert_eq!(format!("{:?}", plan), expected);
Ok(())
@@ -102,11 +102,11 @@ mod tests {
fn no_cast_f64_input() -> Result<()> {
let schema = Schema::new(vec![Field::new("c0", DataType::Float64,
true)]);
let plan = LogicalPlanBuilder::scan("", "", &schema, None)?
- .project(vec![sqrt(Expr::UnresolvedColumn("c0".to_owned()))])?
+ .project(vec![sqrt(col("c0"))])?
.build()?;
let ctx = ExecutionContext::new();
let plan = ctx.optimize(&plan)?;
- let expected = "Projection: sqrt(#0)\
+ let expected = "Projection: sqrt(#c0)\
\n TableScan: projection=Some([0])";
assert_eq!(format!("{:?}", plan), expected);
Ok(())
diff --git a/rust/datafusion/src/execution/physical_plan/mod.rs
b/rust/datafusion/src/execution/physical_plan/mod.rs
index a3b32eb..2e19178 100644
--- a/rust/datafusion/src/execution/physical_plan/mod.rs
+++ b/rust/datafusion/src/execution/physical_plan/mod.rs
@@ -24,7 +24,7 @@ use std::sync::{Arc, Mutex};
use crate::error::Result;
use crate::logicalplan::ScalarValue;
use arrow::array::ArrayRef;
-use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
+use arrow::datatypes::{DataType, Schema, SchemaRef};
use arrow::record_batch::{RecordBatch, RecordBatchReader};
/// Partition-aware execution plan for a relation
@@ -42,29 +42,18 @@ pub trait Partition: Send + Sync {
}
/// Expression that can be evaluated against a RecordBatch
+/// A Physical expression knows its type, nullability and how to evaluate
itself.
pub trait PhysicalExpr: Send + Sync {
- /// Get the name to use in a schema to represent the result of this
expression
- fn name(&self) -> String;
/// Get the data type of this expression, given the schema of the input
fn data_type(&self, input_schema: &Schema) -> Result<DataType>;
/// Decide whehter this expression is nullable, given the schema of the
input
fn nullable(&self, input_schema: &Schema) -> Result<bool>;
/// Evaluate an expression against a RecordBatch
fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef>;
- /// Generate schema Field type for this expression
- fn to_schema_field(&self, input_schema: &Schema) -> Result<Field> {
- Ok(Field::new(
- &self.name(),
- self.data_type(input_schema)?,
- self.nullable(input_schema)?,
- ))
- }
}
/// Aggregate expression that can be evaluated against a RecordBatch
pub trait AggregateExpr: Send + Sync {
- /// Get the name to use in a schema to represent the result of this
expression
- fn name(&self) -> String;
/// Get the data type of this expression, given the schema of the input
fn data_type(&self, input_schema: &Schema) -> Result<DataType>;
/// Evaluate the expression being aggregated
@@ -74,7 +63,7 @@ pub trait AggregateExpr: Send + Sync {
/// Create an aggregate expression for combining the results of
accumulators from partitions.
/// For example, to combine the results of a parallel SUM we just need to
do another SUM, but
/// to combine the results of parallel COUNT we would also use SUM.
- fn create_reducer(&self, column_index: usize) -> Arc<dyn AggregateExpr>;
+ fn create_reducer(&self, column_name: &str) -> Arc<dyn AggregateExpr>;
}
/// Aggregate accumulator
diff --git a/rust/datafusion/src/execution/physical_plan/projection.rs
b/rust/datafusion/src/execution/physical_plan/projection.rs
index 7f39ded..2c2bcb0 100644
--- a/rust/datafusion/src/execution/physical_plan/projection.rs
+++ b/rust/datafusion/src/execution/physical_plan/projection.rs
@@ -24,7 +24,7 @@ use std::sync::{Arc, Mutex};
use crate::error::{ExecutionError, Result};
use crate::execution::physical_plan::{ExecutionPlan, Partition, PhysicalExpr};
-use arrow::datatypes::{Schema, SchemaRef};
+use arrow::datatypes::{Field, Schema, SchemaRef};
use arrow::error::Result as ArrowResult;
use arrow::record_batch::{RecordBatch, RecordBatchReader};
@@ -41,20 +41,26 @@ pub struct ProjectionExec {
impl ProjectionExec {
/// Create a projection on an input
pub fn try_new(
- expr: Vec<Arc<dyn PhysicalExpr>>,
+ expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
input: Arc<dyn ExecutionPlan>,
) -> Result<Self> {
let input_schema = input.schema();
let fields: Result<Vec<_>> = expr
.iter()
- .map(|e| e.to_schema_field(&input_schema))
+ .map(|(e, name)| {
+ Ok(Field::new(
+ name,
+ e.data_type(&input_schema)?,
+ e.nullable(&input_schema)?,
+ ))
+ })
.collect();
let schema = Arc::new(Schema::new(fields?));
Ok(Self {
- expr: expr.clone(),
+ expr: expr.iter().map(|x| x.0.clone()).collect(),
schema,
input: input.clone(),
})
@@ -141,7 +147,7 @@ mod tests {
use super::*;
use crate::execution::physical_plan::csv::{CsvExec, CsvReadOptions};
- use crate::execution::physical_plan::expressions::Column;
+ use crate::execution::physical_plan::expressions::col;
use crate::test;
#[test]
@@ -154,12 +160,9 @@ mod tests {
let csv =
CsvExec::try_new(&path, CsvReadOptions::new().schema(&schema),
None, 1024)?;
- let projection = ProjectionExec::try_new(
- vec![Arc::new(Column::new(0, &schema.as_ref().field(0).name()))],
- Arc::new(csv),
- )?;
-
- assert_eq!("c1", projection.schema.field(0).name().as_str());
+ // pick column c1 and name it column c1 in the output schema
+ let projection =
+ ProjectionExec::try_new(vec![(col("c1"), "c1".to_string())],
Arc::new(csv))?;
let mut partition_count = 0;
let mut row_count = 0;
diff --git a/rust/datafusion/src/execution/physical_plan/selection.rs
b/rust/datafusion/src/execution/physical_plan/selection.rs
index b8efe42..4f021ea 100644
--- a/rust/datafusion/src/execution/physical_plan/selection.rs
+++ b/rust/datafusion/src/execution/physical_plan/selection.rs
@@ -166,17 +166,9 @@ mod tests {
CsvExec::try_new(&path, CsvReadOptions::new().schema(&schema),
None, 1024)?;
let predicate: Arc<dyn PhysicalExpr> = binary(
- binary(
- col(1, schema.as_ref()),
- Operator::Gt,
- lit(ScalarValue::UInt32(1)),
- ),
+ binary(col("c2"), Operator::Gt, lit(ScalarValue::UInt32(1))),
Operator::And,
- binary(
- col(1, schema.as_ref()),
- Operator::Lt,
- lit(ScalarValue::UInt32(4)),
- ),
+ binary(col("c2"), Operator::Lt, lit(ScalarValue::UInt32(4))),
);
let selection: Arc<dyn ExecutionPlan> =
diff --git a/rust/datafusion/src/execution/physical_plan/sort.rs
b/rust/datafusion/src/execution/physical_plan/sort.rs
index 7017e8a..c8b8dec 100644
--- a/rust/datafusion/src/execution/physical_plan/sort.rs
+++ b/rust/datafusion/src/execution/physical_plan/sort.rs
@@ -172,17 +172,17 @@ mod tests {
vec![
// c1 string column
PhysicalSortExpr {
- expr: col(0, schema.as_ref()),
+ expr: col("c1"),
options: SortOptions::default(),
},
// c2 uin32 column
PhysicalSortExpr {
- expr: col(1, schema.as_ref()),
+ expr: col("c2"),
options: SortOptions::default(),
},
// c7 uin8 column
PhysicalSortExpr {
- expr: col(6, schema.as_ref()),
+ expr: col("c7"),
options: SortOptions::default(),
},
],
diff --git a/rust/datafusion/src/execution/physical_plan/udf.rs
b/rust/datafusion/src/execution/physical_plan/udf.rs
index df970dd..a4480bc 100644
--- a/rust/datafusion/src/execution/physical_plan/udf.rs
+++ b/rust/datafusion/src/execution/physical_plan/udf.rs
@@ -61,7 +61,6 @@ impl ScalarFunction {
/// Scalar UDF Physical Expression
pub struct ScalarFunctionExpr {
- name: String,
fun: Box<ScalarUdf>,
args: Vec<Arc<dyn PhysicalExpr>>,
return_type: DataType,
@@ -70,13 +69,11 @@ pub struct ScalarFunctionExpr {
impl ScalarFunctionExpr {
/// Create a new Scalar function
pub fn new(
- name: &str,
fun: Box<ScalarUdf>,
args: Vec<Arc<dyn PhysicalExpr>>,
return_type: &DataType,
) -> Self {
Self {
- name: name.to_owned(),
fun,
args,
return_type: return_type.clone(),
@@ -85,10 +82,6 @@ impl ScalarFunctionExpr {
}
impl PhysicalExpr for ScalarFunctionExpr {
- fn name(&self) -> String {
- self.name.clone()
- }
-
fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
Ok(self.return_type.clone())
}
diff --git a/rust/datafusion/src/execution/table_impl.rs
b/rust/datafusion/src/execution/table_impl.rs
index 8f798bd..7494ba5 100644
--- a/rust/datafusion/src/execution/table_impl.rs
+++ b/rust/datafusion/src/execution/table_impl.rs
@@ -23,8 +23,7 @@ use crate::arrow::datatypes::DataType;
use crate::arrow::record_batch::RecordBatch;
use crate::error::{ExecutionError, Result};
use crate::execution::context::ExecutionContext;
-use crate::logicalplan::LogicalPlanBuilder;
-use crate::logicalplan::{Expr, LogicalPlan};
+use crate::logicalplan::{col, Expr, LogicalPlan, LogicalPlanBuilder};
use crate::table::*;
use arrow::datatypes::Schema;
@@ -48,8 +47,9 @@ impl Table for TableImpl {
.map(|name| {
self.plan
.schema()
+ // take the index to ensure that the column exists in the
schema
.index_of(name.to_owned())
- .and_then(|i| Ok(Expr::Column(i)))
+ .and_then(|_| Ok(col(name)))
.map_err(|e| e.into())
})
.collect::<Result<Vec<_>>>()?;
@@ -90,7 +90,8 @@ impl Table for TableImpl {
/// Return an expression representing a column within this table
fn col(&self, name: &str) -> Result<Expr> {
- Ok(Expr::Column(self.plan.schema().index_of(name)?))
+ self.plan.schema().index_of(name)?; // check that the column exists
+ Ok(col(name))
}
/// Create an expression to represent the min() aggregate function
@@ -141,7 +142,12 @@ impl TableImpl {
/// Determine the data type for a given expression
fn get_data_type(&self, expr: &Expr) -> Result<DataType> {
match expr {
- Expr::Column(i) =>
Ok(self.plan.schema().field(*i).data_type().clone()),
+ Expr::Column(name) => Ok(self
+ .plan
+ .schema()
+ .field_with_name(name)?
+ .data_type()
+ .clone()),
_ => Err(ExecutionError::General(format!(
"Could not determine data type for expr {:?}",
expr
diff --git a/rust/datafusion/src/logicalplan.rs
b/rust/datafusion/src/logicalplan.rs
index 032bfb9..b372597 100644
--- a/rust/datafusion/src/logicalplan.rs
+++ b/rust/datafusion/src/logicalplan.rs
@@ -179,15 +179,92 @@ impl ScalarValue {
}
}
+/// Returns a readable name of an expression based on the input schema.
+/// This function recursively transverses the expression for names such as
"CAST(a > 2)".
+fn create_name(e: &Expr, input_schema: &Schema) -> Result<String> {
+ match e {
+ Expr::Alias(_, name) => Ok(name.clone()),
+ Expr::Column(name) => Ok(name.clone()),
+ Expr::Literal(value) => Ok(format!("{:?}", value)),
+ Expr::BinaryExpr { left, op, right } => {
+ let left = create_name(left, input_schema)?;
+ let right = create_name(right, input_schema)?;
+ Ok(format!("{} {:?} {}", left, op, right))
+ }
+ Expr::Cast { expr, data_type } => {
+ let expr = create_name(expr, input_schema)?;
+ Ok(format!("CAST({} as {:?})", expr, data_type))
+ }
+ Expr::ScalarFunction { name, args, .. } => {
+ let mut names = Vec::with_capacity(args.len());
+ for e in args {
+ names.push(create_name(e, input_schema)?);
+ }
+ Ok(format!("{}({})", name, names.join(",")))
+ }
+ Expr::AggregateFunction { name, args, .. } => {
+ let mut names = Vec::with_capacity(args.len());
+ for e in args {
+ names.push(create_name(e, input_schema)?);
+ }
+ Ok(format!("{}({})", name, names.join(",")))
+ }
+ other => Err(ExecutionError::NotImplemented(format!(
+ "Physical plan does not support logical expression {:?}",
+ other
+ ))),
+ }
+}
+
+/// Returns the datatype of the expression given the input schema
+// note: the physical plan derived from an expression must match the datatype
on this function.
+fn expr_to_field(e: &Expr, input_schema: &Schema) -> Result<Field> {
+ let data_type = match e {
+ Expr::Alias(expr, ..) => expr.get_type(input_schema),
+ Expr::Column(name) =>
Ok(input_schema.field_with_name(name)?.data_type().clone()),
+ Expr::Literal(ref lit) => Ok(lit.get_datatype()),
+ Expr::ScalarFunction {
+ ref return_type, ..
+ } => Ok(return_type.clone()),
+ Expr::AggregateFunction {
+ ref return_type, ..
+ } => Ok(return_type.clone()),
+ Expr::Cast { ref data_type, .. } => Ok(data_type.clone()),
+ Expr::BinaryExpr {
+ ref left,
+ ref right,
+ ..
+ } => {
+ let left_type = left.get_type(input_schema)?;
+ let right_type = right.get_type(input_schema)?;
+ Ok(utils::get_supertype(&left_type, &right_type).unwrap())
+ }
+ _ => Err(ExecutionError::NotImplemented(format!(
+ "Cannot determine schema type for expression {:?}",
+ e
+ ))),
+ };
+
+ match data_type {
+ Ok(d) => Ok(Field::new(&e.name(input_schema)?, d, true)),
+ Err(e) => Err(e),
+ }
+}
+
+/// Create field meta-data from an expression, for use in a result set schema
+fn exprlist_to_fields(expr: &[Expr], input_schema: &Schema) ->
Result<Vec<Field>> {
+ expr.iter()
+ .map(|e| expr_to_field(e, input_schema))
+ .collect()
+}
+
/// Relation expression
#[derive(Clone, PartialEq)]
pub enum Expr {
/// An aliased expression
Alias(Box<Expr>, String),
- /// index into a value within the row or complex value
- Column(usize),
- /// Reference to column by name
- UnresolvedColumn(String),
+ /// column of a table scan
+ Column(String),
/// literal value
Literal(ScalarValue),
/// binary expression e.g. "age > 21"
@@ -248,10 +325,7 @@ impl Expr {
pub fn get_type(&self, schema: &Schema) -> Result<DataType> {
match self {
Expr::Alias(expr, _) => expr.get_type(schema),
- Expr::Column(n) => Ok(schema.field(*n).data_type().clone()),
- Expr::UnresolvedColumn(name) => {
- Ok(schema.field_with_name(&name)?.data_type().clone())
- }
+ Expr::Column(name) =>
Ok(schema.field_with_name(name)?.data_type().clone()),
Expr::Literal(l) => Ok(l.get_datatype()),
Expr::Cast { data_type, .. } => Ok(data_type.clone()),
Expr::ScalarFunction { return_type, .. } =>
Ok(return_type.clone()),
@@ -283,6 +357,13 @@ impl Expr {
}
}
+ /// Return the name of this expression
+ ///
+ /// This represents how a column with this expression is named when no
alias is chosen
+ pub fn name(&self, input_schema: &Schema) -> Result<String> {
+ create_name(self, input_schema)
+ }
+
/// Perform a type cast on the expression value.
///
/// Will `Err` if the type cast cannot be performed.
@@ -368,14 +449,9 @@ impl Expr {
}
}
-/// Create a column expression based on a column index
-pub fn col_index(index: usize) -> Expr {
- Expr::Column(index)
-}
-
/// Create a column expression based on a column name
pub fn col(name: &str) -> Expr {
- Expr::UnresolvedColumn(name.to_owned())
+ Expr::Column(name.to_owned())
}
/// Whether it can be represented as a literal expression
@@ -475,8 +551,7 @@ impl fmt::Debug for Expr {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Expr::Alias(expr, alias) => write!(f, "{:?} AS {}", expr, alias),
- Expr::Column(i) => write!(f, "#{}", i),
- Expr::UnresolvedColumn(name) => write!(f, "#{}", name),
+ Expr::Column(name) => write!(f, "#{}", name),
Expr::Literal(v) => write!(f, "{:?}", v),
Expr::Cast { expr, data_type } => {
write!(f, "CAST({:?} AS {:?})", expr, data_type)
@@ -925,7 +1000,7 @@ impl LogicalPlanBuilder {
(0..expr.len()).for_each(|i| match &expr[i] {
Expr::Wildcard => {
(0..input_schema.fields().len())
- .for_each(|i| expr_vec.push(col_index(i).clone()));
+ .for_each(|i|
expr_vec.push(col(input_schema.field(i).name())));
}
_ => expr_vec.push(expr[i].clone()),
});
@@ -934,10 +1009,8 @@ impl LogicalPlanBuilder {
expr.clone()
};
- let schema = Schema::new(utils::exprlist_to_fields(
- &projected_expr,
- input_schema.as_ref(),
- )?);
+ let schema =
+ Schema::new(exprlist_to_fields(&projected_expr,
input_schema.as_ref())?);
Ok(Self::from(&LogicalPlan::Projection {
expr: projected_expr,
@@ -974,11 +1047,10 @@ impl LogicalPlanBuilder {
/// Apply an aggregate
pub fn aggregate(&self, group_expr: Vec<Expr>, aggr_expr: Vec<Expr>) ->
Result<Self> {
- let mut all_fields: Vec<Expr> = group_expr.clone();
- aggr_expr.iter().for_each(|x| all_fields.push(x.clone()));
+ let mut all_expr: Vec<Expr> = group_expr.clone();
+ aggr_expr.iter().for_each(|x| all_expr.push(x.clone()));
- let aggr_schema =
- Schema::new(utils::exprlist_to_fields(&all_fields,
self.plan.schema())?);
+ let aggr_schema = Schema::new(exprlist_to_fields(&all_expr,
self.plan.schema())?);
Ok(Self::from(&LogicalPlan::Aggregate {
input: Box::new(self.plan.clone()),
diff --git a/rust/datafusion/src/optimizer/mod.rs
b/rust/datafusion/src/optimizer/mod.rs
index 1ac97b1..e60c7db 100644
--- a/rust/datafusion/src/optimizer/mod.rs
+++ b/rust/datafusion/src/optimizer/mod.rs
@@ -20,6 +20,5 @@
pub mod optimizer;
pub mod projection_push_down;
-pub mod resolve_columns;
pub mod type_coercion;
pub mod utils;
diff --git a/rust/datafusion/src/optimizer/projection_push_down.rs
b/rust/datafusion/src/optimizer/projection_push_down.rs
index d1bba6e..e99a996 100644
--- a/rust/datafusion/src/optimizer/projection_push_down.rs
+++ b/rust/datafusion/src/optimizer/projection_push_down.rs
@@ -20,11 +20,11 @@
use crate::error::{ExecutionError, Result};
use crate::logicalplan::LogicalPlan;
-use crate::logicalplan::{Expr, LogicalPlanBuilder};
use crate::optimizer::optimizer::OptimizerRule;
use crate::optimizer::utils;
use arrow::datatypes::{Field, Schema};
-use std::collections::{HashMap, HashSet};
+use arrow::error::Result as ArrowResult;
+use std::collections::HashSet;
/// Projection Push Down optimizer rule ensures that only referenced columns
are
/// loaded into memory
@@ -32,9 +32,9 @@ pub struct ProjectionPushDown {}
impl OptimizerRule for ProjectionPushDown {
fn optimize(&mut self, plan: &LogicalPlan) -> Result<LogicalPlan> {
- let mut accum: HashSet<usize> = HashSet::new();
- let mut mapping: HashMap<usize, usize> = HashMap::new();
- self.optimize_plan(plan, &mut accum, &mut mapping, false)
+ // set of all columns refered from a scan.
+ let mut accum: HashSet<String> = HashSet::new();
+ self.optimize_plan(plan, &mut accum, false)
}
}
@@ -47,72 +47,70 @@ impl ProjectionPushDown {
fn optimize_plan(
&self,
plan: &LogicalPlan,
- accum: &mut HashSet<usize>,
- mapping: &mut HashMap<usize, usize>,
+ accum: &mut HashSet<String>,
has_projection: bool,
) -> Result<LogicalPlan> {
match plan {
- LogicalPlan::Projection { expr, input, .. } => {
+ LogicalPlan::Projection {
+ expr,
+ input,
+ schema,
+ } => {
// collect all columns referenced by projection expressions
- utils::exprlist_to_column_indices(&expr, accum)?;
+ utils::exprlist_to_column_names(&expr, accum)?;
- LogicalPlanBuilder::from(
- &self.optimize_plan(&input, accum, mapping, true)?,
- )
- .project(self.rewrite_expr_list(expr, mapping)?)?
- .build()
+ Ok(LogicalPlan::Projection {
+ expr: expr.clone(),
+ input: Box::new(self.optimize_plan(&input, accum, true)?),
+ schema: schema.clone(),
+ })
}
LogicalPlan::Selection { expr, input } => {
// collect all columns referenced by filter expression
- utils::expr_to_column_indices(expr, accum)?;
+ utils::expr_to_column_names(expr, accum)?;
- LogicalPlanBuilder::from(&self.optimize_plan(
- &input,
- accum,
- mapping,
- has_projection,
- )?)
- .filter(self.rewrite_expr(expr, mapping)?)?
- .build()
+ Ok(LogicalPlan::Selection {
+ expr: expr.clone(),
+ input: Box::new(self.optimize_plan(&input, accum,
has_projection)?),
+ })
}
LogicalPlan::Aggregate {
input,
group_expr,
aggr_expr,
- ..
+ schema,
} => {
// collect all columns referenced by grouping and aggregate
expressions
- utils::exprlist_to_column_indices(&group_expr, accum)?;
- utils::exprlist_to_column_indices(&aggr_expr, accum)?;
+ utils::exprlist_to_column_names(&group_expr, accum)?;
+ utils::exprlist_to_column_names(&aggr_expr, accum)?;
- LogicalPlanBuilder::from(&self.optimize_plan(
- &input,
- accum,
- mapping,
- has_projection,
- )?)
- .aggregate(
- self.rewrite_expr_list(group_expr, mapping)?,
- self.rewrite_expr_list(aggr_expr, mapping)?,
- )?
- .build()
+ Ok(LogicalPlan::Aggregate {
+ input: Box::new(self.optimize_plan(&input, accum,
has_projection)?),
+ group_expr: group_expr.clone(),
+ aggr_expr: aggr_expr.clone(),
+ schema: schema.clone(),
+ })
}
- LogicalPlan::Sort { expr, input, .. } => {
+ LogicalPlan::Sort {
+ expr,
+ input,
+ schema,
+ } => {
// collect all columns referenced by sort expressions
- utils::exprlist_to_column_indices(&expr, accum)?;
+ utils::exprlist_to_column_names(&expr, accum)?;
- LogicalPlanBuilder::from(&self.optimize_plan(
- &input,
- accum,
- mapping,
- has_projection,
- )?)
- .sort(self.rewrite_expr_list(expr, mapping)?)?
- .build()
+ Ok(LogicalPlan::Sort {
+ expr: expr.clone(),
+ input: Box::new(self.optimize_plan(&input, accum,
has_projection)?),
+ schema: schema.clone(),
+ })
}
- LogicalPlan::EmptyRelation { schema } =>
Ok(LogicalPlan::EmptyRelation {
+ LogicalPlan::Limit { n, input, schema } => Ok(LogicalPlan::Limit {
+ n: n.clone(),
+ input: Box::new(self.optimize_plan(&input, accum,
has_projection)?),
schema: schema.clone(),
}),
+ LogicalPlan::EmptyRelation { .. } => Ok(plan.clone()),
LogicalPlan::TableScan {
schema_name,
table_name,
@@ -124,7 +122,6 @@ impl ProjectionPushDown {
&table_schema,
projection,
accum,
- mapping,
has_projection,
)?;
@@ -133,8 +130,8 @@ impl ProjectionPushDown {
schema_name: schema_name.to_string(),
table_name: table_name.to_string(),
table_schema: table_schema.clone(),
- projected_schema: Box::new(projected_schema),
projection: Some(projection),
+ projected_schema: Box::new(projected_schema),
})
}
LogicalPlan::InMemoryScan {
@@ -143,13 +140,8 @@ impl ProjectionPushDown {
projection,
..
} => {
- let (projection, projected_schema) = get_projected_schema(
- &schema,
- projection,
- accum,
- mapping,
- has_projection,
- )?;
+ let (projection, projected_schema) =
+ get_projected_schema(&schema, projection, accum,
has_projection)?;
Ok(LogicalPlan::InMemoryScan {
data: data.clone(),
schema: schema.clone(),
@@ -165,13 +157,8 @@ impl ProjectionPushDown {
projection,
..
} => {
- let (projection, projected_schema) = get_projected_schema(
- &schema,
- projection,
- accum,
- mapping,
- has_projection,
- )?;
+ let (projection, projected_schema) =
+ get_projected_schema(&schema, projection, accum,
has_projection)?;
Ok(LogicalPlan::CsvScan {
path: path.to_owned(),
@@ -188,13 +175,8 @@ impl ProjectionPushDown {
projection,
..
} => {
- let (projection, projected_schema) = get_projected_schema(
- &schema,
- projection,
- accum,
- mapping,
- has_projection,
- )?;
+ let (projection, projected_schema) =
+ get_projected_schema(&schema, projection, accum,
has_projection)?;
Ok(LogicalPlan::ParquetScan {
path: path.to_owned(),
@@ -203,103 +185,7 @@ impl ProjectionPushDown {
projected_schema: Box::new(projected_schema),
})
}
- LogicalPlan::Limit { n, input, .. } => LogicalPlanBuilder::from(
- &self.optimize_plan(&input, accum, mapping, has_projection)?,
- )
- .limit(*n)?
- .build(),
- LogicalPlan::CreateExternalTable {
- schema,
- name,
- location,
- file_type,
- has_header,
- } => Ok(LogicalPlan::CreateExternalTable {
- schema: schema.clone(),
- name: name.to_string(),
- location: location.to_string(),
- file_type: file_type.clone(),
- has_header: *has_header,
- }),
- }
- }
-
- fn rewrite_expr_list(
- &self,
- expr: &[Expr],
- mapping: &HashMap<usize, usize>,
- ) -> Result<Vec<Expr>> {
- Ok(expr
- .iter()
- .map(|e| self.rewrite_expr(e, mapping))
- .collect::<Result<Vec<Expr>>>()?)
- }
-
- fn rewrite_expr(&self, expr: &Expr, mapping: &HashMap<usize, usize>) ->
Result<Expr> {
- match expr {
- Expr::Alias(expr, name) => Ok(Expr::Alias(
- Box::new(self.rewrite_expr(expr, mapping)?),
- name.clone(),
- )),
- Expr::Column(i) => Ok(Expr::Column(self.new_index(mapping, i)?)),
- Expr::UnresolvedColumn(_) => Err(ExecutionError::ExecutionError(
- "Columns need to be resolved before projection push down rule
can run"
- .to_owned(),
- )),
- Expr::Literal(_) => Ok(expr.clone()),
- Expr::Not(e) => Ok(Expr::Not(Box::new(self.rewrite_expr(e,
mapping)?))),
- Expr::IsNull(e) => Ok(Expr::IsNull(Box::new(self.rewrite_expr(e,
mapping)?))),
- Expr::IsNotNull(e) => {
- Ok(Expr::IsNotNull(Box::new(self.rewrite_expr(e, mapping)?)))
- }
- Expr::BinaryExpr { left, op, right } => Ok(Expr::BinaryExpr {
- left: Box::new(self.rewrite_expr(left, mapping)?),
- op: op.clone(),
- right: Box::new(self.rewrite_expr(right, mapping)?),
- }),
- Expr::Cast { expr, data_type } => Ok(Expr::Cast {
- expr: Box::new(self.rewrite_expr(expr, mapping)?),
- data_type: data_type.clone(),
- }),
- Expr::Sort {
- expr,
- asc,
- nulls_first,
- } => Ok(Expr::Sort {
- expr: Box::new(self.rewrite_expr(expr, mapping)?),
- asc: *asc,
- nulls_first: *nulls_first,
- }),
- Expr::AggregateFunction {
- name,
- args,
- return_type,
- } => Ok(Expr::AggregateFunction {
- name: name.to_string(),
- args: self.rewrite_expr_list(args, mapping)?,
- return_type: return_type.clone(),
- }),
- Expr::ScalarFunction {
- name,
- args,
- return_type,
- } => Ok(Expr::ScalarFunction {
- name: name.to_string(),
- args: self.rewrite_expr_list(args, mapping)?,
- return_type: return_type.clone(),
- }),
- Expr::Wildcard => Err(ExecutionError::General(
- "Wildcard expressions are not valid in a logical query
plan".to_owned(),
- )),
- }
- }
-
- fn new_index(&self, mapping: &HashMap<usize, usize>, i: &usize) ->
Result<usize> {
- match mapping.get(i) {
- Some(j) => Ok(*j),
- _ => Err(ExecutionError::InternalError(
- "Internal error computing new column index".to_string(),
- )),
+ LogicalPlan::CreateExternalTable { .. } => Ok(plan.clone()),
}
}
}
@@ -307,8 +193,7 @@ impl ProjectionPushDown {
fn get_projected_schema(
table_schema: &Schema,
projection: &Option<Vec<usize>>,
- accum: &HashSet<usize>,
- mapping: &mut HashMap<usize, usize>,
+ accum: &HashSet<String>,
has_projection: bool,
) -> Result<(Vec<usize>, Schema)> {
if projection.is_some() {
@@ -318,8 +203,15 @@ fn get_projected_schema(
}
// once we reach the table scan, we can use the accumulated set of column
- // indexes as the projection in the table scan
- let mut projection = accum.iter().map(|i| *i).collect::<Vec<usize>>();
+ // names to construct the set of column indexes in the scan
+ //
+ // we discard non-existing columns because some column names are not part
of the schema,
+ // e.g. when the column derives from an aggregation
+ let mut projection: Vec<usize> = accum
+ .iter()
+ .map(|name| table_schema.index_of(name))
+ .filter_map(ArrowResult::ok)
+ .collect();
if projection.is_empty() {
if has_projection {
@@ -346,21 +238,6 @@ fn get_projected_schema(
projected_fields.push(table_schema.fields()[*i].clone());
}
- // now that the table scan is returning a different schema we need to
- // create a mapping from the original column index to the
- // new column index so that we can rewrite expressions as
- // we walk back up the tree
-
- if mapping.len() != 0 {
- return Err(ExecutionError::InternalError("illegal state".to_string()));
- }
-
- for i in 0..table_schema.fields().len() {
- if let Some(n) = projection.iter().position(|v| *v == i) {
- mapping.insert(i, n);
- }
- }
-
Ok((projection, Schema::new(projected_fields)))
}
@@ -368,8 +245,8 @@ fn get_projected_schema(
mod tests {
use super::*;
- use crate::logicalplan::lit;
- use crate::logicalplan::Expr::*;
+ use crate::logicalplan::{col, lit};
+ use crate::logicalplan::{Expr, LogicalPlanBuilder};
use crate::test::*;
use arrow::datatypes::DataType;
@@ -378,10 +255,10 @@ mod tests {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(&table_scan)
- .aggregate(vec![], vec![max(Column(1))])?
+ .aggregate(vec![], vec![max(col("b"))])?
.build()?;
- let expected = "Aggregate: groupBy=[[]], aggr=[[MAX(#0)]]\
+ let expected = "Aggregate: groupBy=[[]], aggr=[[MAX(#b)]]\
\n TableScan: test projection=Some([1])";
assert_optimized_plan_eq(&plan, expected);
@@ -394,10 +271,10 @@ mod tests {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(&table_scan)
- .aggregate(vec![Column(2)], vec![max(Column(1))])?
+ .aggregate(vec![col("c")], vec![max(col("b"))])?
.build()?;
- let expected = "Aggregate: groupBy=[[#1]], aggr=[[MAX(#0)]]\
+ let expected = "Aggregate: groupBy=[[#c]], aggr=[[MAX(#b)]]\
\n TableScan: test projection=Some([1, 2])";
assert_optimized_plan_eq(&plan, expected);
@@ -410,12 +287,12 @@ mod tests {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(&table_scan)
- .filter(Column(2))?
- .aggregate(vec![], vec![max(Column(1))])?
+ .filter(col("c"))?
+ .aggregate(vec![], vec![max(col("b"))])?
.build()?;
- let expected = "Aggregate: groupBy=[[]], aggr=[[MAX(#0)]]\
- \n Selection: #1\
+ let expected = "Aggregate: groupBy=[[]], aggr=[[MAX(#b)]]\
+ \n Selection: #c\
\n TableScan: test projection=Some([1, 2])";
assert_optimized_plan_eq(&plan, expected);
@@ -428,13 +305,13 @@ mod tests {
let table_scan = test_table_scan()?;
let projection = LogicalPlanBuilder::from(&table_scan)
- .project(vec![Cast {
- expr: Box::new(Column(2)),
+ .project(vec![Expr::Cast {
+ expr: Box::new(col("c")),
data_type: DataType::Float64,
}])?
.build()?;
- let expected = "Projection: CAST(#0 AS Float64)\
+ let expected = "Projection: CAST(#c AS Float64)\
\n TableScan: test projection=Some([2])";
assert_optimized_plan_eq(&projection, expected);
@@ -449,12 +326,12 @@ mod tests {
assert_fields_eq(&table_scan, vec!["a", "b", "c"]);
let plan = LogicalPlanBuilder::from(&table_scan)
- .project(vec![Column(0), Column(1)])?
+ .project(vec![col("a"), col("b")])?
.build()?;
assert_fields_eq(&plan, vec!["a", "b"]);
- let expected = "Projection: #0, #1\
+ let expected = "Projection: #a, #b\
\n TableScan: test projection=Some([0, 1])";
assert_optimized_plan_eq(&plan, expected);
@@ -469,14 +346,14 @@ mod tests {
assert_fields_eq(&table_scan, vec!["a", "b", "c"]);
let plan = LogicalPlanBuilder::from(&table_scan)
- .project(vec![Column(2), Column(0)])?
+ .project(vec![col("c"), col("a")])?
.limit(5)?
.build()?;
assert_fields_eq(&plan, vec!["c", "a"]);
let expected = "Limit: 5\
- \n Projection: #1, #0\
+ \n Projection: #c, #a\
\n TableScan: test projection=Some([0, 2])";
assert_optimized_plan_eq(&plan, expected);
diff --git a/rust/datafusion/src/optimizer/resolve_columns.rs
b/rust/datafusion/src/optimizer/resolve_columns.rs
deleted file mode 100644
index 61b2e81..0000000
--- a/rust/datafusion/src/optimizer/resolve_columns.rs
+++ /dev/null
@@ -1,159 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-//! Optimizer rule to replace UnresolvedColumns with Columns
-
-use crate::error::Result;
-use crate::logicalplan::LogicalPlan;
-use crate::logicalplan::{Expr, LogicalPlanBuilder};
-use crate::optimizer::optimizer::OptimizerRule;
-use arrow::datatypes::Schema;
-
-/// Replace UnresolvedColumns with Columns
-pub struct ResolveColumnsRule {}
-
-impl ResolveColumnsRule {
- #[allow(missing_docs)]
- pub fn new() -> Self {
- Self {}
- }
-}
-
-impl OptimizerRule for ResolveColumnsRule {
- fn optimize(&mut self, plan: &LogicalPlan) -> Result<LogicalPlan> {
- match plan {
- LogicalPlan::Projection { input, expr, .. } => {
- Ok(LogicalPlanBuilder::from(&self.optimize(input.as_ref())?)
- .project(rewrite_expr_list(expr, &input.schema())?)?
- .build()?)
- }
- LogicalPlan::Selection { expr, input } =>
Ok(LogicalPlanBuilder::from(input)
- .filter(rewrite_expr(expr, &input.schema())?)?
- .build()?),
- LogicalPlan::Aggregate {
- input,
- group_expr,
- aggr_expr,
- ..
- } => Ok(LogicalPlanBuilder::from(input)
- .aggregate(
- rewrite_expr_list(group_expr, &input.schema())?,
- rewrite_expr_list(aggr_expr, &input.schema())?,
- )?
- .build()?),
- LogicalPlan::Sort { input, expr, .. } => {
- Ok(LogicalPlanBuilder::from(&self.optimize(input)?)
- .sort(rewrite_expr_list(expr, &input.schema())?)?
- .build()?)
- }
- _ => Ok(plan.clone()),
- }
- }
-}
-
-fn rewrite_expr_list(expr: &[Expr], schema: &Schema) -> Result<Vec<Expr>> {
- Ok(expr
- .iter()
- .map(|e| rewrite_expr(e, schema))
- .collect::<Result<Vec<_>>>()?)
-}
-
-fn rewrite_expr(expr: &Expr, schema: &Schema) -> Result<Expr> {
- match expr {
- Expr::Alias(expr, alias) => Ok(rewrite_expr(&expr,
schema)?.alias(&alias)),
- Expr::UnresolvedColumn(name) =>
Ok(Expr::Column(schema.index_of(&name)?)),
- Expr::BinaryExpr { left, op, right } => Ok(Expr::BinaryExpr {
- left: Box::new(rewrite_expr(&left, schema)?),
- op: op.clone(),
- right: Box::new(rewrite_expr(&right, schema)?),
- }),
- Expr::Not(expr) => Ok(Expr::Not(Box::new(rewrite_expr(&expr,
schema)?))),
- Expr::IsNotNull(expr) => {
- Ok(Expr::IsNotNull(Box::new(rewrite_expr(&expr, schema)?)))
- }
- Expr::IsNull(expr) => Ok(Expr::IsNull(Box::new(rewrite_expr(&expr,
schema)?))),
- Expr::Cast { expr, data_type } => Ok(Expr::Cast {
- expr: Box::new(rewrite_expr(&expr, schema)?),
- data_type: data_type.clone(),
- }),
- Expr::Sort {
- expr,
- asc,
- nulls_first,
- } => Ok(Expr::Sort {
- expr: Box::new(rewrite_expr(&expr, schema)?),
- asc: asc.clone(),
- nulls_first: nulls_first.clone(),
- }),
- Expr::ScalarFunction {
- name,
- args,
- return_type,
- } => Ok(Expr::ScalarFunction {
- name: name.clone(),
- args: rewrite_expr_list(args, schema)?,
- return_type: return_type.clone(),
- }),
- Expr::AggregateFunction {
- name,
- args,
- return_type,
- } => Ok(Expr::AggregateFunction {
- name: name.clone(),
- args: rewrite_expr_list(args, schema)?,
- return_type: return_type.clone(),
- }),
- _ => Ok(expr.clone()),
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
- use crate::logicalplan::col;
- use crate::test::*;
-
- #[test]
- fn aggregate_no_group_by() -> Result<()> {
- let table_scan = test_table_scan()?;
-
- let plan = LogicalPlanBuilder::from(&table_scan)
- .aggregate(vec![col("a")], vec![max(col("b"))])?
- .build()?;
-
- // plan has unresolve columns
- let expected = "Aggregate: groupBy=[[#a]], aggr=[[MAX(#b)]]\n
TableScan: test projection=None";
- assert_eq!(format!("{:?}", plan), expected);
-
- // optimized plan has resolved columns
- let expected = "Aggregate: groupBy=[[#0]], aggr=[[MAX(#1)]]\n
TableScan: test projection=None";
- assert_optimized_plan_eq(&plan, expected);
-
- Ok(())
- }
-
- fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) {
- let optimized_plan = optimize(plan).expect("failed to optimize plan");
- let formatted_plan = format!("{:?}", optimized_plan);
- assert_eq!(formatted_plan, expected);
- }
-
- fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
- let mut rule = ResolveColumnsRule::new();
- rule.optimize(plan)
- }
-}
diff --git a/rust/datafusion/src/optimizer/type_coercion.rs
b/rust/datafusion/src/optimizer/type_coercion.rs
index 7639b5b..a03a92c 100644
--- a/rust/datafusion/src/optimizer/type_coercion.rs
+++ b/rust/datafusion/src/optimizer/type_coercion.rs
@@ -132,7 +132,6 @@ impl<'a> TypeCoercionRule<'a> {
alias.to_owned(),
)),
Expr::Literal(_) => Ok(expr.clone()),
- Expr::UnresolvedColumn(_) => Ok(expr.clone()),
Expr::Not(_) => Ok(expr.clone()),
Expr::Sort { .. } => Ok(expr.clone()),
Expr::Wildcard { .. } => Err(ExecutionError::General(
@@ -183,7 +182,6 @@ mod tests {
use super::*;
use crate::execution::context::ExecutionContext;
use crate::execution::physical_plan::csv::CsvReadOptions;
- use crate::logicalplan::Expr::*;
use crate::logicalplan::{col, Operator};
use crate::test::arrow_testdata_path;
use arrow::datatypes::{DataType, Field, Schema};
@@ -214,12 +212,12 @@ mod tests {
binary_cast_test(
DataType::Int32,
DataType::Int64,
- "CAST(#0 AS Int64) Plus #1",
+ "CAST(#c0 AS Int64) Plus #c1",
);
binary_cast_test(
DataType::Int64,
DataType::Int32,
- "#0 Plus CAST(#1 AS Int64)",
+ "#c0 Plus CAST(#c1 AS Int64)",
);
}
@@ -228,12 +226,12 @@ mod tests {
binary_cast_test(
DataType::Float32,
DataType::Float64,
- "CAST(#0 AS Float64) Plus #1",
+ "CAST(#c0 AS Float64) Plus #c1",
);
binary_cast_test(
DataType::Float64,
DataType::Float32,
- "#0 Plus CAST(#1 AS Float64)",
+ "#c0 Plus CAST(#c1 AS Float64)",
);
}
@@ -242,12 +240,12 @@ mod tests {
binary_cast_test(
DataType::Int32,
DataType::Float32,
- "CAST(#0 AS Float32) Plus #1",
+ "CAST(#c0 AS Float32) Plus #c1",
);
binary_cast_test(
DataType::Float32,
DataType::Int32,
- "#0 Plus CAST(#1 AS Float32)",
+ "#c0 Plus CAST(#c1 AS Float32)",
);
}
@@ -256,12 +254,12 @@ mod tests {
binary_cast_test(
DataType::UInt32,
DataType::Int64,
- "CAST(#0 AS Int64) Plus #1",
+ "CAST(#c0 AS Int64) Plus #c1",
);
binary_cast_test(
DataType::Int64,
DataType::UInt32,
- "#0 Plus CAST(#1 AS Int64)",
+ "#c0 Plus CAST(#c1 AS Int64)",
);
}
@@ -272,9 +270,9 @@ mod tests {
]);
let expr = Expr::BinaryExpr {
- left: Box::new(Column(0)),
+ left: Box::new(col("c0")),
op: Operator::Plus,
- right: Box::new(Column(1)),
+ right: Box::new(col("c1")),
};
let ctx = ExecutionContext::new();
diff --git a/rust/datafusion/src/optimizer/utils.rs
b/rust/datafusion/src/optimizer/utils.rs
index cdbad9e..5c59803 100644
--- a/rust/datafusion/src/optimizer/utils.rs
+++ b/rust/datafusion/src/optimizer/utils.rs
@@ -19,117 +19,54 @@
use std::collections::HashSet;
-use arrow::datatypes::{DataType, Field, Schema};
+use arrow::datatypes::DataType;
use crate::error::{ExecutionError, Result};
use crate::logicalplan::Expr;
/// Recursively walk a list of expression trees, collecting the unique set of
column
-/// indexes referenced in the expression
-pub fn exprlist_to_column_indices(
+/// names referenced in the expression
+pub fn exprlist_to_column_names(
expr: &[Expr],
- accum: &mut HashSet<usize>,
+ accum: &mut HashSet<String>,
) -> Result<()> {
for e in expr {
- expr_to_column_indices(e, accum)?;
+ expr_to_column_names(e, accum)?;
}
Ok(())
}
-/// Recursively walk an expression tree, collecting the unique set of column
indexes
+/// Recursively walk an expression tree, collecting the unique set of column
names
/// referenced in the expression
-pub fn expr_to_column_indices(expr: &Expr, accum: &mut HashSet<usize>) ->
Result<()> {
+pub fn expr_to_column_names(expr: &Expr, accum: &mut HashSet<String>) ->
Result<()> {
match expr {
- Expr::Alias(expr, _) => expr_to_column_indices(expr, accum),
- Expr::Column(i) => {
- accum.insert(*i);
+ Expr::Alias(expr, _) => expr_to_column_names(expr, accum),
+ Expr::Column(name) => {
+ accum.insert(name.clone());
Ok(())
}
- Expr::UnresolvedColumn(_) => Err(ExecutionError::ExecutionError(
- "Columns need to be resolved before column indexes resolution rule
can run"
- .to_owned(),
- )),
Expr::Literal(_) => {
// not needed
Ok(())
}
- Expr::Not(e) => expr_to_column_indices(e, accum),
- Expr::IsNull(e) => expr_to_column_indices(e, accum),
- Expr::IsNotNull(e) => expr_to_column_indices(e, accum),
+ Expr::Not(e) => expr_to_column_names(e, accum),
+ Expr::IsNull(e) => expr_to_column_names(e, accum),
+ Expr::IsNotNull(e) => expr_to_column_names(e, accum),
Expr::BinaryExpr { left, right, .. } => {
- expr_to_column_indices(left, accum)?;
- expr_to_column_indices(right, accum)?;
+ expr_to_column_names(left, accum)?;
+ expr_to_column_names(right, accum)?;
Ok(())
}
- Expr::Cast { expr, .. } => expr_to_column_indices(expr, accum),
- Expr::Sort { expr, .. } => expr_to_column_indices(expr, accum),
- Expr::AggregateFunction { args, .. } =>
exprlist_to_column_indices(args, accum),
- Expr::ScalarFunction { args, .. } => exprlist_to_column_indices(args,
accum),
+ Expr::Cast { expr, .. } => expr_to_column_names(expr, accum),
+ Expr::Sort { expr, .. } => expr_to_column_names(expr, accum),
+ Expr::AggregateFunction { args, .. } => exprlist_to_column_names(args,
accum),
+ Expr::ScalarFunction { args, .. } => exprlist_to_column_names(args,
accum),
Expr::Wildcard => Err(ExecutionError::General(
"Wildcard expressions are not valid in a logical query
plan".to_owned(),
)),
}
}
-/// Create field meta-data from an expression, for use in a result set schema
-pub fn expr_to_field(e: &Expr, input_schema: &Schema) -> Result<Field> {
- match e {
- Expr::Alias(expr, name) => {
- Ok(Field::new(name, expr.get_type(input_schema)?, true))
- }
- Expr::UnresolvedColumn(name) =>
Ok(input_schema.field_with_name(&name)?.clone()),
- Expr::Column(i) => {
- let input_schema_field_count = input_schema.fields().len();
- if *i < input_schema_field_count {
- Ok(input_schema.fields()[*i].clone())
- } else {
- Err(ExecutionError::General(format!(
- "Column index {} out of bounds for input schema with {}
field(s)",
- *i, input_schema_field_count
- )))
- }
- }
- Expr::Literal(ref lit) => Ok(Field::new("lit", lit.get_datatype(),
true)),
- Expr::ScalarFunction {
- ref name,
- ref return_type,
- ..
- } => Ok(Field::new(&name, return_type.clone(), true)),
- Expr::AggregateFunction {
- ref name,
- ref return_type,
- ..
- } => Ok(Field::new(&name, return_type.clone(), true)),
- Expr::Cast { ref data_type, .. } => {
- Ok(Field::new("cast", data_type.clone(), true))
- }
- Expr::BinaryExpr {
- ref left,
- ref right,
- ..
- } => {
- let left_type = left.get_type(input_schema)?;
- let right_type = right.get_type(input_schema)?;
- Ok(Field::new(
- "binary_expr",
- get_supertype(&left_type, &right_type).unwrap(),
- true,
- ))
- }
- _ => Err(ExecutionError::NotImplemented(format!(
- "Cannot determine schema type for expression {:?}",
- e
- ))),
- }
-}
-
-/// Create field meta-data from an expression, for use in a result set schema
-pub fn exprlist_to_fields(expr: &[Expr], input_schema: &Schema) ->
Result<Vec<Field>> {
- expr.iter()
- .map(|e| expr_to_field(e, input_schema))
- .collect()
-}
-
/// Given two datatypes, determine the supertype that both types can safely be
cast to
pub fn get_supertype(l: &DataType, r: &DataType) -> Result<DataType> {
match _get_supertype(l, r) {
@@ -248,29 +185,29 @@ fn _get_supertype(l: &DataType, r: &DataType) ->
Option<DataType> {
#[cfg(test)]
mod tests {
use super::*;
- use crate::logicalplan::Expr;
+ use crate::logicalplan::col;
use arrow::datatypes::DataType;
use std::collections::HashSet;
#[test]
fn test_collect_expr() -> Result<()> {
- let mut accum: HashSet<usize> = HashSet::new();
- expr_to_column_indices(
+ let mut accum: HashSet<String> = HashSet::new();
+ expr_to_column_names(
&Expr::Cast {
- expr: Box::new(Expr::Column(3)),
+ expr: Box::new(col("a")),
data_type: DataType::Float64,
},
&mut accum,
)?;
- expr_to_column_indices(
+ expr_to_column_names(
&Expr::Cast {
- expr: Box::new(Expr::Column(3)),
+ expr: Box::new(col("a")),
data_type: DataType::Float64,
},
&mut accum,
)?;
assert_eq!(1, accum.len());
- assert!(accum.contains(&3));
+ assert!(accum.contains("a"));
Ok(())
}
}
diff --git a/rust/datafusion/src/sql/planner.rs
b/rust/datafusion/src/sql/planner.rs
index 8c74c5a..21cf870 100644
--- a/rust/datafusion/src/sql/planner.rs
+++ b/rust/datafusion/src/sql/planner.rs
@@ -173,34 +173,24 @@ impl<S: SchemaProvider> SqlToRel<S> {
.aggregate(group_expr, aggr_expr)?
.build()?;
- // wrap in projection to preserve final order of fields
- let mut projected_fields = Vec::with_capacity(group_by_count +
aggr_count);
- let mut group_expr_index = 0;
- let mut aggr_expr_index = 0;
- for i in 0..projection_expr.len() {
- if is_aggregate_expr(&projection_expr[i]) {
- projected_fields.push(group_by_count + aggr_expr_index);
- aggr_expr_index += 1;
- } else {
- projected_fields.push(group_expr_index);
- group_expr_index += 1;
- }
- }
-
- // determine if projection is needed or not
- // NOTE this would be better done later in a query optimizer rule
- let mut projection_needed = false;
- for i in 0..projected_fields.len() {
- if projected_fields[i] != i {
- projection_needed = true;
- break;
- }
- }
-
- if projection_needed {
+ // optionally wrap in projection to preserve final order of fields
+ let expected_columns: Vec<String> = projection_expr
+ .iter()
+ .map(|e| e.name(input.schema()))
+ .collect::<Result<Vec<_>>>()?;
+ let columns: Vec<String> = plan
+ .schema()
+ .fields()
+ .iter()
+ .map(|f| f.name().clone())
+ .collect::<Vec<_>>();
+ if expected_columns != columns {
self.project(
&plan,
- projected_fields.iter().map(|i| Expr::Column(*i)).collect(),
+ expected_columns
+ .iter()
+ .map(|c| Expr::Column(c.clone()))
+ .collect(),
)
} else {
Ok(plan)
@@ -273,16 +263,14 @@ impl<S: SchemaProvider> SqlToRel<S> {
alias.to_owned(),
)),
- ASTNode::SQLIdentifier(ref id) => {
- match schema.fields().iter().position(|c| c.name().eq(id)) {
- Some(index) => Ok(Expr::Column(index)),
- None => Err(ExecutionError::ExecutionError(format!(
- "Invalid identifier '{}' for schema {}",
- id,
- schema.to_string()
- ))),
- }
- }
+ ASTNode::SQLIdentifier(ref id) => match schema.field_with_name(id)
{
+ Ok(field) => Ok(Expr::Column(field.name().clone())),
+ Err(_) => Err(ExecutionError::ExecutionError(format!(
+ "Invalid identifier '{}' for schema {}",
+ id,
+ schema.to_string()
+ ))),
+ },
ASTNode::SQLWildcard => Ok(Expr::Wildcard),
@@ -483,8 +471,8 @@ mod tests {
fn select_simple_selection() {
let sql = "SELECT id, first_name, last_name \
FROM person WHERE state = 'CO'";
- let expected = "Projection: #0, #1, #2\
- \n Selection: #4 Eq Utf8(\"CO\")\
+ let expected = "Projection: #id, #first_name, #last_name\
+ \n Selection: #state Eq Utf8(\"CO\")\
\n TableScan: person projection=None";
quick_test(sql, expected);
}
@@ -493,8 +481,8 @@ mod tests {
fn select_neg_selection() {
let sql = "SELECT id, first_name, last_name \
FROM person WHERE NOT state";
- let expected = "Projection: #0, #1, #2\
- \n Selection: NOT #4\
+ let expected = "Projection: #id, #first_name, #last_name\
+ \n Selection: NOT #state\
\n TableScan: person projection=None";
quick_test(sql, expected);
}
@@ -503,8 +491,8 @@ mod tests {
fn select_compound_selection() {
let sql = "SELECT id, first_name, last_name \
FROM person WHERE state = 'CO' AND age >= 21 AND age <= 65";
- let expected = "Projection: #0, #1, #2\
- \n Selection: #4 Eq Utf8(\"CO\") And #3 GtEq Int64(21) And #3
LtEq Int64(65)\
+ let expected = "Projection: #id, #first_name, #last_name\
+ \n Selection: #state Eq Utf8(\"CO\") And #age GtEq Int64(21) And
#age LtEq Int64(65)\
\n TableScan: person projection=None";
quick_test(sql, expected);
}
@@ -513,8 +501,8 @@ mod tests {
fn test_timestamp_selection() {
let sql = "SELECT state FROM person WHERE birth_date < CAST
(158412331400600000 as timestamp)";
- let expected = "Projection: #4\
- \n Selection: #6 Lt CAST(Int64(158412331400600000) AS
Timestamp(Nanosecond, None))\
+ let expected = "Projection: #state\
+ \n Selection: #birth_date Lt CAST(Int64(158412331400600000) AS
Timestamp(Nanosecond, None))\
\n TableScan: person projection=None";
quick_test(sql, expected);
@@ -530,13 +518,13 @@ mod tests {
AND age >= 21 \
AND age < 65 \
AND age <= 65";
- let expected = "Projection: #3, #1, #2\
- \n Selection: #3 Eq Int64(21) \
- And #3 NotEq Int64(21) \
- And #3 Gt Int64(21) \
- And #3 GtEq Int64(21) \
- And #3 Lt Int64(65) \
- And #3 LtEq Int64(65)\
+ let expected = "Projection: #age, #first_name, #last_name\
+ \n Selection: #age Eq Int64(21) \
+ And #age NotEq Int64(21) \
+ And #age Gt Int64(21) \
+ And #age GtEq Int64(21) \
+ And #age Lt Int64(65) \
+ And #age LtEq Int64(65)\
\n TableScan: person projection=None";
quick_test(sql, expected);
}
@@ -545,7 +533,7 @@ mod tests {
fn select_simple_aggregate() {
quick_test(
"SELECT MIN(age) FROM person",
- "Aggregate: groupBy=[[]], aggr=[[MIN(#3)]]\
+ "Aggregate: groupBy=[[]], aggr=[[MIN(#age)]]\
\n TableScan: person projection=None",
);
}
@@ -554,7 +542,7 @@ mod tests {
fn test_sum_aggregate() {
quick_test(
"SELECT SUM(age) from person",
- "Aggregate: groupBy=[[]], aggr=[[SUM(#3)]]\
+ "Aggregate: groupBy=[[]], aggr=[[SUM(#age)]]\
\n TableScan: person projection=None",
);
}
@@ -563,7 +551,7 @@ mod tests {
fn select_simple_aggregate_with_groupby() {
quick_test(
"SELECT state, MIN(age), MAX(age) FROM person GROUP BY state",
- "Aggregate: groupBy=[[#4]], aggr=[[MIN(#3), MAX(#3)]]\
+ "Aggregate: groupBy=[[#state]], aggr=[[MIN(#age), MAX(#age)]]\
\n TableScan: person projection=None",
);
}
@@ -572,7 +560,7 @@ mod tests {
fn test_wildcard() {
quick_test(
"SELECT * from person",
- "Projection: #0, #1, #2, #3, #4, #5, #6\
+ "Projection: #id, #first_name, #last_name, #age, #state, #salary,
#birth_date\
\n TableScan: person projection=None",
);
}
@@ -588,7 +576,7 @@ mod tests {
#[test]
fn select_count_column() {
let sql = "SELECT COUNT(id) FROM person";
- let expected = "Aggregate: groupBy=[[]], aggr=[[COUNT(#0)]]\
+ let expected = "Aggregate: groupBy=[[]], aggr=[[COUNT(#id)]]\
\n TableScan: person projection=None";
quick_test(sql, expected);
}
@@ -596,7 +584,7 @@ mod tests {
#[test]
fn select_scalar_func() {
let sql = "SELECT sqrt(age) FROM person";
- let expected = "Projection: sqrt(CAST(#3 AS Float64))\
+ let expected = "Projection: sqrt(CAST(#age AS Float64))\
\n TableScan: person projection=None";
quick_test(sql, expected);
}
@@ -604,7 +592,7 @@ mod tests {
#[test]
fn select_aliased_scalar_func() {
let sql = "SELECT sqrt(age) AS square_people FROM person";
- let expected = "Projection: sqrt(CAST(#3 AS Float64)) AS square_people\
+ let expected = "Projection: sqrt(CAST(#age AS Float64)) AS
square_people\
\n TableScan: person projection=None";
quick_test(sql, expected);
}
@@ -612,8 +600,8 @@ mod tests {
#[test]
fn select_order_by() {
let sql = "SELECT id FROM person ORDER BY id";
- let expected = "Sort: #0 ASC NULLS FIRST\
- \n Projection: #0\
+ let expected = "Sort: #id ASC NULLS FIRST\
+ \n Projection: #id\
\n TableScan: person projection=None";
quick_test(sql, expected);
}
@@ -621,8 +609,8 @@ mod tests {
#[test]
fn select_order_by_desc() {
let sql = "SELECT id FROM person ORDER BY id DESC";
- let expected = "Sort: #0 DESC NULLS FIRST\
- \n Projection: #0\
+ let expected = "Sort: #id DESC NULLS FIRST\
+ \n Projection: #id\
\n TableScan: person projection=None";
quick_test(sql, expected);
}
@@ -631,15 +619,15 @@ mod tests {
fn select_order_by_nulls_last() {
quick_test(
"SELECT id FROM person ORDER BY id DESC NULLS LAST",
- "Sort: #0 DESC NULLS LAST\
- \n Projection: #0\
+ "Sort: #id DESC NULLS LAST\
+ \n Projection: #id\
\n TableScan: person projection=None",
);
quick_test(
"SELECT id FROM person ORDER BY id NULLS LAST",
- "Sort: #0 ASC NULLS LAST\
- \n Projection: #0\
+ "Sort: #id ASC NULLS LAST\
+ \n Projection: #id\
\n TableScan: person projection=None",
);
}
@@ -647,13 +635,24 @@ mod tests {
#[test]
fn select_group_by() {
let sql = "SELECT state FROM person GROUP BY state";
- let expected = "Aggregate: groupBy=[[#4]], aggr=[[]]\
+ let expected = "Aggregate: groupBy=[[#state]], aggr=[[]]\
\n TableScan: person projection=None";
quick_test(sql, expected);
}
#[test]
+ fn select_group_by_needs_projection() {
+ let sql = "SELECT COUNT(state), state FROM person GROUP BY state";
+ let expected = "\
+ Projection: #COUNT(state), #state\
+ \n Aggregate: groupBy=[[#state]], aggr=[[COUNT(#state)]]\
+ \n TableScan: person projection=None";
+
+ quick_test(sql, expected);
+ }
+
+ #[test]
fn select_7480_1() {
let sql = "SELECT c1, MIN(c12) FROM aggregate_test_100 GROUP BY c1,
c13";
let err = logical_plan(sql).expect_err("query should have failed");
diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs
index 70635e9..30a55f7 100644
--- a/rust/datafusion/tests/sql.rs
+++ b/rust/datafusion/tests/sql.rs
@@ -279,7 +279,7 @@ fn csv_query_group_by_int_count() -> Result<()> {
fn csv_query_group_by_string_min_max() -> Result<()> {
let mut ctx = ExecutionContext::new();
register_aggregate_csv(&mut ctx)?;
- let sql = "SELECT c2, MIN(c12), MAX(c12) FROM aggregate_test_100 GROUP BY
c1";
+ let sql = "SELECT c1, MIN(c12), MAX(c12) FROM aggregate_test_100 GROUP BY
c1";
let mut actual = execute(&mut ctx, sql);
actual.sort();
let expected =