This is an automated email from the ASF dual-hosted git repository.
wayne pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new abbfdcee64 Handle table reuse in semi and anti join (#6059)
abbfdcee64 is described below
commit abbfdcee6404f999116aa4e3c8a63cbf6236f2d9
Author: Nuttiiya Seekhao <[email protected]>
AuthorDate: Sun Apr 23 08:52:14 2023 -0400
Handle table reuse in semi and anti join (#6059)
cargo fmt
cargo clippy --fix
cleanup
---
datafusion/substrait/src/logical_plan/consumer.rs | 20 +++++++++++++++++---
datafusion/substrait/src/logical_plan/producer.rs | 15 ++++++++++++---
datafusion/substrait/tests/roundtrip_logical_plan.rs | 13 +++++++++++++
3 files changed, 42 insertions(+), 6 deletions(-)
diff --git a/datafusion/substrait/src/logical_plan/consumer.rs
b/datafusion/substrait/src/logical_plan/consumer.rs
index 6e67a5d7c2..2b8ffde422 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -305,11 +305,25 @@ pub async fn from_substrait_rel(
from_substrait_rel(ctx, join.right.as_ref().unwrap(),
extensions).await?,
);
let join_type = from_substrait_jointype(join.r#type)?;
- let schema =
- build_join_schema(left.schema(), right.schema(),
&JoinType::Inner)?;
+ // The join condition expression needs full input schema and not
the output schema from join since we lose columns from
+ // certain join types such as semi and anti joins
+ // - if left and right schemas are different, we combine (join)
the schema to include all fields
+ // - if left and right schemas are the same, we handle the
duplicate fields by using `build_join_schema()`, which discards the unused schema
+ // TODO: Handle duplicate fields error for other join types
(non-semi/anti). The current approach does not work due to Substrait's inability
+ // to encode aliases
+ let join_schema = match left.schema().join(right.schema()) {
+ Ok(schema) => Ok(schema),
+ Err(DataFusionError::SchemaError(
+ datafusion::common::SchemaError::DuplicateQualifiedField {
+ qualifier: _,
+ name: _,
+ },
+ )) => build_join_schema(left.schema(), right.schema(),
&join_type),
+ Err(e) => Err(e),
+ };
let on = from_substrait_rex(
join.expression.as_ref().unwrap(),
- &schema,
+ &join_schema?,
extensions,
)
.await?;
diff --git a/datafusion/substrait/src/logical_plan/producer.rs
b/datafusion/substrait/src/logical_plan/producer.rs
index 230221897f..17d424cea6 100644
--- a/datafusion/substrait/src/logical_plan/producer.rs
+++ b/datafusion/substrait/src/logical_plan/producer.rs
@@ -284,7 +284,16 @@ pub fn to_substrait_rel(
// join schema from left and right to maintain all necessary
columns from inputs
// note that we cannot simply use join.schema here since we
discard some input columns
// when performing semi and anti joins
- let join_schema = join.left.schema().join(join.right.schema());
+ let join_schema = match
join.left.schema().join(join.right.schema()) {
+ Ok(schema) => Ok(schema),
+ Err(DataFusionError::SchemaError(
+ datafusion::common::SchemaError::DuplicateQualifiedField {
+ qualifier: _,
+ name: _,
+ },
+ )) => Ok(join.schema.as_ref().clone()),
+ Err(e) => Err(e),
+ };
if let Some(e) = join_expression {
Ok(Box::new(Rel {
rel_type: Some(RelType::Join(Box::new(JoinRel {
@@ -1329,11 +1338,11 @@ mod test {
}
fn round_trip_literal(scalar: ScalarValue) -> Result<()> {
- println!("Checking round trip of {:?}", scalar);
+ println!("Checking round trip of {scalar:?}");
let substrait = to_substrait_literal(&scalar)?;
let Expression { rex_type: Some(RexType::Literal(substrait_literal)) }
= substrait else {
- panic!("Expected Literal expression, got {:?}", substrait);
+ panic!("Expected Literal expression, got {substrait:?}");
};
let roundtrip_scalar = from_substrait_literal(&substrait_literal)?;
diff --git a/datafusion/substrait/tests/roundtrip_logical_plan.rs
b/datafusion/substrait/tests/roundtrip_logical_plan.rs
index 2a79f61eb8..514eb463b1 100644
--- a/datafusion/substrait/tests/roundtrip_logical_plan.rs
+++ b/datafusion/substrait/tests/roundtrip_logical_plan.rs
@@ -250,6 +250,19 @@ mod tests {
.await
}
+ #[tokio::test]
+ async fn simple_intersect_table_reuse() -> Result<()> {
+ assert_expected_plan(
+ "SELECT COUNT(*) FROM (SELECT data.a FROM data INTERSECT SELECT
data.a FROM data);",
+ "Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\
+ \n LeftSemi Join: data.a = data.a\
+ \n Aggregate: groupBy=[[data.a]], aggr=[[]]\
+ \n TableScan: data projection=[a]\
+ \n TableScan: data projection=[a]",
+ )
+ .await
+ }
+
#[tokio::test]
async fn simple_window_function() -> Result<()> {
roundtrip("SELECT RANK() OVER (PARTITION BY a ORDER BY b), d, SUM(b)
OVER (PARTITION BY a) FROM data;").await