This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 2047d7fe85 feat: Implement LeftMark join to fix subquery correctness
issue (#13134)
2047d7fe85 is described below
commit 2047d7fe8577d7c8fb079da6c0be1b544672ade2
Author: Emil Ejbyfeldt <[email protected]>
AuthorDate: Thu Oct 31 19:52:33 2024 +0100
feat: Implement LeftMark join to fix subquery correctness issue (#13134)
* Implement LeftMark join
In https://github.com/apache/datafusion/pull/12945 the emulation of a
mark join has a bug when there are duplicate values in the subquery. This
would be fixable by adding a distinct before the join. But this patch
instead implements a LeftMark join with the desired semantics and uses
that. The LeftMark join will return a row for each row in the left input
with an additional column "mark" that is true if there was a match in
the right input and false otherwise.
Note: This patch does not implement the full null semantics for the mark
join described in
http://btw2017.informatik.uni-stuttgart.de/slidesandpapers/F1-10-37/paper_web.pdf
which will be needed if we add `ANY` subqueries. In the version in
this patch the mark column will only be true when there was a match and false
when no match was found, never `null`.
* Use mark join in decorrelate subqueries
This fixes a correctness issue in the current approach.
* Add physical plan sqllogictest
* fmt
* Fix join type in doc comment
* Minor clean ups
* Add more documentation to LeftMark join
* Remove qualification
* fix doc
---------
Co-authored-by: Andrew Lamb <[email protected]>
---
datafusion/common/src/functional_dependencies.rs | 2 +-
datafusion/common/src/join_type.rs | 21 +++
datafusion/core/src/dataframe/mod.rs | 6 +-
.../src/physical_optimizer/enforce_distribution.rs | 11 +-
.../core/src/physical_optimizer/join_selection.rs | 5 +
.../core/src/physical_optimizer/sort_pushdown.rs | 8 +-
datafusion/core/tests/fuzz_cases/join_fuzz.rs | 24 +++
datafusion/expr/src/logical_plan/builder.rs | 24 +++
datafusion/expr/src/logical_plan/plan.rs | 8 +-
datafusion/optimizer/src/analyzer/subquery.rs | 5 +-
.../src/decorrelate_predicate_subquery.rs | 52 ++-----
.../optimizer/src/optimize_projections/mod.rs | 6 +-
datafusion/optimizer/src/push_down_filter.rs | 15 +-
datafusion/optimizer/src/push_down_limit.rs | 2 +-
datafusion/physical-expr/src/equivalence/class.rs | 2 +-
datafusion/physical-plan/src/joins/hash_join.rs | 100 +++++++++++++
.../physical-plan/src/joins/nested_loop_join.rs | 32 ++++
.../physical-plan/src/joins/sort_merge_join.rs | 163 ++++++++++++++++-----
.../physical-plan/src/joins/symmetric_hash_join.rs | 30 +++-
datafusion/physical-plan/src/joins/utils.rs | 114 +++++++-------
.../proto-common/proto/datafusion_common.proto | 4 +-
datafusion/proto-common/src/from_proto/mod.rs | 1 +
datafusion/proto-common/src/generated/pbjson.rs | 6 +
datafusion/proto-common/src/generated/prost.rs | 6 +
datafusion/proto-common/src/to_proto/mod.rs | 1 +
.../proto/src/generated/datafusion_proto_common.rs | 6 +
datafusion/proto/src/logical_plan/from_proto.rs | 1 +
datafusion/proto/src/logical_plan/to_proto.rs | 1 +
datafusion/sql/src/unparser/plan.rs | 9 +-
datafusion/sqllogictest/test_files/subquery.slt | 98 ++++++++-----
datafusion/substrait/src/logical_plan/consumer.rs | 1 +
datafusion/substrait/src/logical_plan/producer.rs | 5 +-
.../tests/cases/roundtrip_logical_plan.rs | 18 +--
33 files changed, 592 insertions(+), 195 deletions(-)
diff --git a/datafusion/common/src/functional_dependencies.rs
b/datafusion/common/src/functional_dependencies.rs
index ed9a68c195..31eafc7443 100644
--- a/datafusion/common/src/functional_dependencies.rs
+++ b/datafusion/common/src/functional_dependencies.rs
@@ -334,7 +334,7 @@ impl FunctionalDependencies {
left_func_dependencies.extend(right_func_dependencies);
left_func_dependencies
}
- JoinType::LeftSemi | JoinType::LeftAnti => {
+ JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => {
// These joins preserve functional dependencies of the left
side:
left_func_dependencies
}
diff --git a/datafusion/common/src/join_type.rs
b/datafusion/common/src/join_type.rs
index d502e7836d..e98f34199b 100644
--- a/datafusion/common/src/join_type.rs
+++ b/datafusion/common/src/join_type.rs
@@ -44,6 +44,20 @@ pub enum JoinType {
LeftAnti,
/// Right Anti Join
RightAnti,
+ /// Left Mark join
+ ///
+ /// Returns one record for each record from the left input. The output
contains an additional
+ /// column "mark" which is true if there is at least one match in the
right input where the
+ /// join condition evaluates to true. Otherwise, the mark column is false.
For more details see
+ /// [1]. This join type is used to decorrelate EXISTS subqueries used
inside disjunctive
+ /// predicates.
+ ///
+ /// Note: We currently do not implement the full null semantics for
the mark join described
+ /// in [1] which will be needed if we add ANY subqueries. In our version
the mark column will
+ /// only be true when there was a match and false when no match was found, never
null.
+ ///
+ /// [1]:
http://btw2017.informatik.uni-stuttgart.de/slidesandpapers/F1-10-37/paper_web.pdf
+ LeftMark,
}
impl JoinType {
@@ -63,6 +77,7 @@ impl Display for JoinType {
JoinType::RightSemi => "RightSemi",
JoinType::LeftAnti => "LeftAnti",
JoinType::RightAnti => "RightAnti",
+ JoinType::LeftMark => "LeftMark",
};
write!(f, "{join_type}")
}
@@ -82,6 +97,7 @@ impl FromStr for JoinType {
"RIGHTSEMI" => Ok(JoinType::RightSemi),
"LEFTANTI" => Ok(JoinType::LeftAnti),
"RIGHTANTI" => Ok(JoinType::RightAnti),
+ "LEFTMARK" => Ok(JoinType::LeftMark),
_ => _not_impl_err!("The join type {s} does not exist or is not
implemented"),
}
}
@@ -101,6 +117,7 @@ impl Display for JoinSide {
match self {
JoinSide::Left => write!(f, "left"),
JoinSide::Right => write!(f, "right"),
+ JoinSide::None => write!(f, "none"),
}
}
}
@@ -113,6 +130,9 @@ pub enum JoinSide {
Left,
/// Right side of the join
Right,
+ /// Neither side of the join, used for Mark joins where the mark column
does not belong to
+ /// either side of the join
+ None,
}
impl JoinSide {
@@ -121,6 +141,7 @@ impl JoinSide {
match self {
JoinSide::Left => JoinSide::Right,
JoinSide::Right => JoinSide::Left,
+ JoinSide::None => JoinSide::None,
}
}
}
diff --git a/datafusion/core/src/dataframe/mod.rs
b/datafusion/core/src/dataframe/mod.rs
index e5d352a63c..2c71cb80d7 100644
--- a/datafusion/core/src/dataframe/mod.rs
+++ b/datafusion/core/src/dataframe/mod.rs
@@ -3864,6 +3864,7 @@ mod tests {
JoinType::RightSemi,
JoinType::LeftAnti,
JoinType::RightAnti,
+ JoinType::LeftMark,
];
let default_partition_count = SessionConfig::new().target_partitions();
@@ -3881,7 +3882,10 @@ mod tests {
let join_schema = physical_plan.schema();
match join_type {
- JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => {
+ JoinType::Left
+ | JoinType::LeftSemi
+ | JoinType::LeftAnti
+ | JoinType::LeftMark => {
let left_exprs: Vec<Arc<dyn PhysicalExpr>> = vec![
Arc::new(Column::new_with_schema("c1", &join_schema)?),
Arc::new(Column::new_with_schema("c2", &join_schema)?),
diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs
b/datafusion/core/src/physical_optimizer/enforce_distribution.rs
index aa4bcb6837..ff8f16f4ee 100644
--- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs
+++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs
@@ -328,7 +328,8 @@ fn adjust_input_keys_ordering(
JoinType::Left
| JoinType::LeftSemi
| JoinType::LeftAnti
- | JoinType::Full => vec![],
+ | JoinType::Full
+ | JoinType::LeftMark => vec![],
};
}
PartitionMode::Auto => {
@@ -1959,6 +1960,7 @@ pub(crate) mod tests {
JoinType::Full,
JoinType::LeftSemi,
JoinType::LeftAnti,
+ JoinType::LeftMark,
JoinType::RightSemi,
JoinType::RightAnti,
];
@@ -1981,7 +1983,8 @@ pub(crate) mod tests {
| JoinType::Right
| JoinType::Full
| JoinType::LeftSemi
- | JoinType::LeftAnti => {
+ | JoinType::LeftAnti
+ | JoinType::LeftMark => {
// Join on (a == c)
let top_join_on = vec![(
Arc::new(Column::new_with_schema("a",
&join.schema()).unwrap())
@@ -1999,7 +2002,7 @@ pub(crate) mod tests {
let expected = match join_type {
// Should include 3 RepartitionExecs
- JoinType::Inner | JoinType::Left | JoinType::LeftSemi
| JoinType::LeftAnti => vec![
+ JoinType::Inner | JoinType::Left | JoinType::LeftSemi
| JoinType::LeftAnti | JoinType::LeftMark => vec![
top_join_plan.as_str(),
join_plan.as_str(),
"RepartitionExec: partitioning=Hash([a@0], 10),
input_partitions=10",
@@ -2098,7 +2101,7 @@ pub(crate) mod tests {
assert_optimized!(expected, top_join.clone(), true);
assert_optimized!(expected, top_join, false);
}
- JoinType::LeftSemi | JoinType::LeftAnti => {}
+ JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark
=> {}
}
}
diff --git a/datafusion/core/src/physical_optimizer/join_selection.rs
b/datafusion/core/src/physical_optimizer/join_selection.rs
index 1c63df1f02..2bf706f33d 100644
--- a/datafusion/core/src/physical_optimizer/join_selection.rs
+++ b/datafusion/core/src/physical_optimizer/join_selection.rs
@@ -132,6 +132,9 @@ fn swap_join_type(join_type: JoinType) -> JoinType {
JoinType::RightSemi => JoinType::LeftSemi,
JoinType::LeftAnti => JoinType::RightAnti,
JoinType::RightAnti => JoinType::LeftAnti,
+ JoinType::LeftMark => {
+ unreachable!("LeftMark join type does not support swapping")
+ }
}
}
@@ -573,6 +576,7 @@ fn hash_join_convert_symmetric_subrule(
hash_join.right().equivalence_properties(),
hash_join.right().schema(),
),
+ JoinSide::None => return false,
};
let name = schema.field(*index).name();
@@ -588,6 +592,7 @@ fn hash_join_convert_symmetric_subrule(
match side {
JoinSide::Left =>
hash_join.left().output_ordering(),
JoinSide::Right =>
hash_join.right().output_ordering(),
+ JoinSide::None => unreachable!(),
}
.map(|p| p.to_vec())
})
diff --git a/datafusion/core/src/physical_optimizer/sort_pushdown.rs
b/datafusion/core/src/physical_optimizer/sort_pushdown.rs
index c7677d725b..fdbda1fe52 100644
--- a/datafusion/core/src/physical_optimizer/sort_pushdown.rs
+++ b/datafusion/core/src/physical_optimizer/sort_pushdown.rs
@@ -384,6 +384,7 @@ fn try_pushdown_requirements_to_join(
return Ok(None);
}
}
+ JoinSide::None => return Ok(None),
};
let join_type = smj.join_type();
let probe_side = SortMergeJoinExec::probe_side(&join_type);
@@ -410,6 +411,7 @@ fn try_pushdown_requirements_to_join(
JoinSide::Right => {
required_input_ordering[1] = new_req;
}
+ JoinSide::None => unreachable!(),
}
required_input_ordering
}))
@@ -421,7 +423,11 @@ fn expr_source_side(
left_columns_len: usize,
) -> Option<JoinSide> {
match join_type {
- JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full =>
{
+ JoinType::Inner
+ | JoinType::Left
+ | JoinType::Right
+ | JoinType::Full
+ | JoinType::LeftMark => {
let all_column_sides = required_exprs
.iter()
.filter_map(|r| {
diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs
b/datafusion/core/tests/fuzz_cases/join_fuzz.rs
index 5c03bc3a91..d7a3460e49 100644
--- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs
@@ -234,6 +234,30 @@ async fn test_anti_join_1k_filtered() {
.await
}
+#[tokio::test]
+async fn test_left_mark_join_1k() {
+ JoinFuzzTestCase::new(
+ make_staggered_batches(1000),
+ make_staggered_batches(1000),
+ JoinType::LeftMark,
+ None,
+ )
+ .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
+ .await
+}
+
+#[tokio::test]
+async fn test_left_mark_join_1k_filtered() {
+ JoinFuzzTestCase::new(
+ make_staggered_batches(1000),
+ make_staggered_batches(1000),
+ JoinType::LeftMark,
+ Some(Box::new(col_lt_col_filter)),
+ )
+ .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
+ .await
+}
+
type JoinFilterBuilder = Box<dyn Fn(Arc<Schema>, Arc<Schema>) -> JoinFilter>;
struct JoinFuzzTestCase {
diff --git a/datafusion/expr/src/logical_plan/builder.rs
b/datafusion/expr/src/logical_plan/builder.rs
index e50ffb59d2..b7839c4873 100644
--- a/datafusion/expr/src/logical_plan/builder.rs
+++ b/datafusion/expr/src/logical_plan/builder.rs
@@ -20,6 +20,7 @@
use std::any::Any;
use std::cmp::Ordering;
use std::collections::{HashMap, HashSet};
+use std::iter::once;
use std::sync::Arc;
use crate::dml::CopyTo;
@@ -1326,6 +1327,25 @@ pub fn change_redundant_column(fields: &Fields) ->
Vec<Field> {
})
.collect()
}
+
+fn mark_field(schema: &DFSchema) -> (Option<TableReference>, Arc<Field>) {
+ let mut table_references = schema
+ .iter()
+ .filter_map(|(qualifier, _)| qualifier)
+ .collect::<Vec<_>>();
+ table_references.dedup();
+ let table_reference = if table_references.len() == 1 {
+ table_references.pop().cloned()
+ } else {
+ None
+ };
+
+ (
+ table_reference,
+ Arc::new(Field::new("mark", DataType::Boolean, false)),
+ )
+}
+
/// Creates a schema for a join operation.
/// The fields from the left side are first
pub fn build_join_schema(
@@ -1392,6 +1412,10 @@ pub fn build_join_schema(
.map(|(q, f)| (q.cloned(), Arc::clone(f)))
.collect()
}
+ JoinType::LeftMark => left_fields
+ .map(|(q, f)| (q.cloned(), Arc::clone(f)))
+ .chain(once(mark_field(right)))
+ .collect(),
JoinType::RightSemi | JoinType::RightAnti => {
// Only use the right side for the schema
right_fields
diff --git a/datafusion/expr/src/logical_plan/plan.rs
b/datafusion/expr/src/logical_plan/plan.rs
index a301c48659..8ba2a44842 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -532,7 +532,9 @@ impl LogicalPlan {
left.head_output_expr()
}
}
- JoinType::LeftSemi | JoinType::LeftAnti =>
left.head_output_expr(),
+ JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark
=> {
+ left.head_output_expr()
+ }
JoinType::RightSemi | JoinType::RightAnti =>
right.head_output_expr(),
},
LogicalPlan::RecursiveQuery(RecursiveQuery { static_term, .. }) =>
{
@@ -1290,7 +1292,9 @@ impl LogicalPlan {
_ => None,
}
}
- JoinType::LeftSemi | JoinType::LeftAnti => left.max_rows(),
+ JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark
=> {
+ left.max_rows()
+ }
JoinType::RightSemi | JoinType::RightAnti => right.max_rows(),
},
LogicalPlan::Repartition(Repartition { input, .. }) =>
input.max_rows(),
diff --git a/datafusion/optimizer/src/analyzer/subquery.rs
b/datafusion/optimizer/src/analyzer/subquery.rs
index 0ffc954388..fa04835f09 100644
--- a/datafusion/optimizer/src/analyzer/subquery.rs
+++ b/datafusion/optimizer/src/analyzer/subquery.rs
@@ -181,7 +181,10 @@ fn check_inner_plan(inner_plan: &LogicalPlan,
can_contain_outer_ref: bool) -> Re
})?;
Ok(())
}
- JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => {
+ JoinType::Left
+ | JoinType::LeftSemi
+ | JoinType::LeftAnti
+ | JoinType::LeftMark => {
check_inner_plan(left, can_contain_outer_ref)?;
check_inner_plan(right, false)
}
diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs
b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs
index cc1687cffe..7fdad5ba4b 100644
--- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs
+++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs
@@ -17,7 +17,6 @@
//! [`DecorrelatePredicateSubquery`] converts `IN`/`EXISTS` subquery
predicates to `SEMI`/`ANTI` joins
use std::collections::BTreeSet;
-use std::iter;
use std::ops::Deref;
use std::sync::Arc;
@@ -34,11 +33,10 @@ use
datafusion_expr::expr_rewriter::create_col_from_scalar_expr;
use datafusion_expr::logical_plan::{JoinType, Subquery};
use datafusion_expr::utils::{conjunction, split_conjunction_owned};
use datafusion_expr::{
- exists, in_subquery, lit, not, not_exists, not_in_subquery, BinaryExpr,
Expr, Filter,
+ exists, in_subquery, not, not_exists, not_in_subquery, BinaryExpr, Expr,
Filter,
LogicalPlan, LogicalPlanBuilder, Operator,
};
-use itertools::chain;
use log::debug;
/// Optimizer rule for rewriting predicate(IN/EXISTS) subquery to left
semi/anti joins
@@ -138,17 +136,14 @@ fn rewrite_inner_subqueries(
Expr::Exists(Exists {
subquery: Subquery { subquery, .. },
negated,
- }) => {
- match existence_join(&cur_input, Arc::clone(&subquery), None,
negated, alias)?
- {
- Some((plan, exists_expr)) => {
- cur_input = plan;
- Ok(Transformed::yes(exists_expr))
- }
- None if negated => Ok(Transformed::no(not_exists(subquery))),
- None => Ok(Transformed::no(exists(subquery))),
+ }) => match mark_join(&cur_input, Arc::clone(&subquery), None,
negated, alias)? {
+ Some((plan, exists_expr)) => {
+ cur_input = plan;
+ Ok(Transformed::yes(exists_expr))
}
- }
+ None if negated => Ok(Transformed::no(not_exists(subquery))),
+ None => Ok(Transformed::no(exists(subquery))),
+ },
Expr::InSubquery(InSubquery {
expr,
subquery: Subquery { subquery, .. },
@@ -159,7 +154,7 @@ fn rewrite_inner_subqueries(
.map_or(plan_err!("single expression required."),
|output_expr| {
Ok(Expr::eq(*expr.clone(), output_expr))
})?;
- match existence_join(
+ match mark_join(
&cur_input,
Arc::clone(&subquery),
Some(in_predicate),
@@ -283,10 +278,6 @@ fn build_join_top(
build_join(left, subquery, in_predicate_opt, join_type, subquery_alias)
}
-/// Existence join is emulated by adding a non-nullable column to the subquery
and using a left join
-/// and checking if the column is null or not. If native support is added for
Existence/Mark then
-/// we should use that instead.
-///
/// This is used to handle the case when the subquery is embedded in a more
complex boolean
/// expression like and OR. For example
///
@@ -296,37 +287,26 @@ fn build_join_top(
///
/// ```text
/// Projection: t1.id
-/// Filter: t1.id < 0 OR __correlated_sq_1.__exists IS NOT NULL
-/// Left Join: Filter: t1.id = __correlated_sq_1.id
+/// Filter: t1.id < 0 OR __correlated_sq_1.mark
+/// LeftMark Join: Filter: t1.id = __correlated_sq_1.id
/// TableScan: t1
/// SubqueryAlias: __correlated_sq_1
-/// Projection: t2.id, true as __exists
+/// Projection: t2.id
/// TableScan: t2
-fn existence_join(
+fn mark_join(
left: &LogicalPlan,
subquery: Arc<LogicalPlan>,
in_predicate_opt: Option<Expr>,
negated: bool,
alias_generator: &Arc<AliasGenerator>,
) -> Result<Option<(LogicalPlan, Expr)>> {
- // Add non nullable column to emulate existence join
- let always_true_expr = lit(true).alias("__exists");
- let cols = chain(
- subquery.schema().columns().into_iter().map(Expr::Column),
- iter::once(always_true_expr),
- );
- let subquery = LogicalPlanBuilder::from(subquery).project(cols)?.build()?;
let alias = alias_generator.next("__correlated_sq");
- let exists_col = Expr::Column(Column::new(Some(alias.clone()),
"__exists"));
- let exists_expr = if negated {
- exists_col.is_null()
- } else {
- exists_col.is_not_null()
- };
+ let exists_col = Expr::Column(Column::new(Some(alias.clone()), "mark"));
+ let exists_expr = if negated { !exists_col } else { exists_col };
Ok(
- build_join(left, &subquery, in_predicate_opt, JoinType::Left, alias)?
+ build_join(left, &subquery, in_predicate_opt, JoinType::LeftMark,
alias)?
.map(|plan| (plan, exists_expr)),
)
}
diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs
b/datafusion/optimizer/src/optimize_projections/mod.rs
index 42eff7100f..94c04d6328 100644
--- a/datafusion/optimizer/src/optimize_projections/mod.rs
+++ b/datafusion/optimizer/src/optimize_projections/mod.rs
@@ -677,7 +677,11 @@ fn split_join_requirements(
) -> (RequiredIndicies, RequiredIndicies) {
match join_type {
// In these cases requirements are split between left/right children:
- JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full =>
{
+ JoinType::Inner
+ | JoinType::Left
+ | JoinType::Right
+ | JoinType::Full
+ | JoinType::LeftMark => {
// Decrease right side indices by `left_len` so that they point to
valid
// positions within the right child:
indices.split_off(left_len)
diff --git a/datafusion/optimizer/src/push_down_filter.rs
b/datafusion/optimizer/src/push_down_filter.rs
index 3b3693cc94..1d0fc207cb 100644
--- a/datafusion/optimizer/src/push_down_filter.rs
+++ b/datafusion/optimizer/src/push_down_filter.rs
@@ -163,7 +163,7 @@ pub(crate) fn lr_is_preserved(join_type: JoinType) ->
(bool, bool) {
JoinType::Full => (false, false),
// No columns from the right side of the join can be referenced in
output
// predicates for semi/anti joins, so whether we specify t/f doesn't
matter.
- JoinType::LeftSemi | JoinType::LeftAnti => (true, false),
+ JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => (true,
false),
// No columns from the left side of the join can be referenced in
output
// predicates for semi/anti joins, so whether we specify t/f doesn't
matter.
JoinType::RightSemi | JoinType::RightAnti => (false, true),
@@ -188,6 +188,7 @@ pub(crate) fn on_lr_is_preserved(join_type: JoinType) ->
(bool, bool) {
JoinType::LeftSemi | JoinType::RightSemi => (true, true),
JoinType::LeftAnti => (false, true),
JoinType::RightAnti => (true, false),
+ JoinType::LeftMark => (false, true),
}
}
@@ -732,11 +733,13 @@ fn infer_join_predicates_from_on_filters(
on_filters,
inferred_predicates,
),
- JoinType::Left | JoinType::LeftSemi =>
infer_join_predicates_impl::<true, false>(
- join_col_keys,
- on_filters,
- inferred_predicates,
- ),
+ JoinType::Left | JoinType::LeftSemi | JoinType::LeftMark => {
+ infer_join_predicates_impl::<true, false>(
+ join_col_keys,
+ on_filters,
+ inferred_predicates,
+ )
+ }
JoinType::Right | JoinType::RightSemi => {
infer_join_predicates_impl::<false, true>(
join_col_keys,
diff --git a/datafusion/optimizer/src/push_down_limit.rs
b/datafusion/optimizer/src/push_down_limit.rs
index ec7a0a1364..8a3aa4bb84 100644
--- a/datafusion/optimizer/src/push_down_limit.rs
+++ b/datafusion/optimizer/src/push_down_limit.rs
@@ -248,7 +248,7 @@ fn push_down_join(mut join: Join, limit: usize) ->
Transformed<Join> {
let (left_limit, right_limit) = if is_no_join_condition(&join) {
match join.join_type {
Left | Right | Full | Inner => (Some(limit), Some(limit)),
- LeftAnti | LeftSemi => (Some(limit), None),
+ LeftAnti | LeftSemi | LeftMark => (Some(limit), None),
RightAnti | RightSemi => (None, Some(limit)),
}
} else {
diff --git a/datafusion/physical-expr/src/equivalence/class.rs
b/datafusion/physical-expr/src/equivalence/class.rs
index c1851ddb22..7305bc1b0a 100644
--- a/datafusion/physical-expr/src/equivalence/class.rs
+++ b/datafusion/physical-expr/src/equivalence/class.rs
@@ -632,7 +632,7 @@ impl EquivalenceGroup {
}
result
}
- JoinType::LeftSemi | JoinType::LeftAnti => self.clone(),
+ JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark =>
self.clone(),
JoinType::RightSemi | JoinType::RightAnti =>
right_equivalences.clone(),
}
}
diff --git a/datafusion/physical-plan/src/joins/hash_join.rs
b/datafusion/physical-plan/src/joins/hash_join.rs
index 2d11e03814..c56c179c17 100644
--- a/datafusion/physical-plan/src/joins/hash_join.rs
+++ b/datafusion/physical-plan/src/joins/hash_join.rs
@@ -524,6 +524,7 @@ impl HashJoinExec {
| JoinType::Full
| JoinType::LeftAnti
| JoinType::LeftSemi
+ | JoinType::LeftMark
));
let mode = if pipeline_breaking {
@@ -3091,6 +3092,94 @@ mod tests {
Ok(())
}
+ #[apply(batch_sizes)]
+ #[tokio::test]
+ async fn join_left_mark(batch_size: usize) -> Result<()> {
+ let task_ctx = prepare_task_ctx(batch_size);
+ let left = build_table(
+ ("a1", &vec![1, 2, 3]),
+ ("b1", &vec![4, 5, 7]), // 7 does not exist on the right
+ ("c1", &vec![7, 8, 9]),
+ );
+ let right = build_table(
+ ("a2", &vec![10, 20, 30]),
+ ("b1", &vec![4, 5, 6]),
+ ("c2", &vec![70, 80, 90]),
+ );
+ let on = vec![(
+ Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
+ )];
+
+ let (columns, batches) = join_collect(
+ Arc::clone(&left),
+ Arc::clone(&right),
+ on.clone(),
+ &JoinType::LeftMark,
+ false,
+ task_ctx,
+ )
+ .await?;
+ assert_eq!(columns, vec!["a1", "b1", "c1", "mark"]);
+
+ let expected = [
+ "+----+----+----+-------+",
+ "| a1 | b1 | c1 | mark |",
+ "+----+----+----+-------+",
+ "| 1 | 4 | 7 | true |",
+ "| 2 | 5 | 8 | true |",
+ "| 3 | 7 | 9 | false |",
+ "+----+----+----+-------+",
+ ];
+ assert_batches_sorted_eq!(expected, &batches);
+
+ Ok(())
+ }
+
+ #[apply(batch_sizes)]
+ #[tokio::test]
+ async fn partitioned_join_left_mark(batch_size: usize) -> Result<()> {
+ let task_ctx = prepare_task_ctx(batch_size);
+ let left = build_table(
+ ("a1", &vec![1, 2, 3]),
+ ("b1", &vec![4, 5, 7]), // 7 does not exist on the right
+ ("c1", &vec![7, 8, 9]),
+ );
+ let right = build_table(
+ ("a2", &vec![10, 20, 30, 40]),
+ ("b1", &vec![4, 4, 5, 6]),
+ ("c2", &vec![60, 70, 80, 90]),
+ );
+ let on = vec![(
+ Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
+ )];
+
+ let (columns, batches) = partitioned_join_collect(
+ Arc::clone(&left),
+ Arc::clone(&right),
+ on.clone(),
+ &JoinType::LeftMark,
+ false,
+ task_ctx,
+ )
+ .await?;
+ assert_eq!(columns, vec!["a1", "b1", "c1", "mark"]);
+
+ let expected = [
+ "+----+----+----+-------+",
+ "| a1 | b1 | c1 | mark |",
+ "+----+----+----+-------+",
+ "| 1 | 4 | 7 | true |",
+ "| 2 | 5 | 8 | true |",
+ "| 3 | 7 | 9 | false |",
+ "+----+----+----+-------+",
+ ];
+ assert_batches_sorted_eq!(expected, &batches);
+
+ Ok(())
+ }
+
#[test]
fn join_with_hash_collision() -> Result<()> {
let mut hashmap_left = RawTable::with_capacity(2);
@@ -3476,6 +3565,15 @@ mod tests {
"| 30 | 6 | 90 |",
"+----+----+----+",
];
+ let expected_left_mark = vec![
+ "+----+----+----+-------+",
+ "| a1 | b1 | c1 | mark |",
+ "+----+----+----+-------+",
+ "| 1 | 4 | 7 | true |",
+ "| 2 | 5 | 8 | true |",
+ "| 3 | 7 | 9 | false |",
+ "+----+----+----+-------+",
+ ];
let test_cases = vec![
(JoinType::Inner, expected_inner),
@@ -3486,6 +3584,7 @@ mod tests {
(JoinType::LeftAnti, expected_left_anti),
(JoinType::RightSemi, expected_right_semi),
(JoinType::RightAnti, expected_right_anti),
+ (JoinType::LeftMark, expected_left_mark),
];
for (join_type, expected) in test_cases {
@@ -3768,6 +3867,7 @@ mod tests {
JoinType::LeftAnti,
JoinType::RightSemi,
JoinType::RightAnti,
+ JoinType::LeftMark,
];
for join_type in join_types {
diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs
b/datafusion/physical-plan/src/joins/nested_loop_join.rs
index 358ff02473..957230f513 100644
--- a/datafusion/physical-plan/src/joins/nested_loop_join.rs
+++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs
@@ -1244,6 +1244,37 @@ pub(crate) mod tests {
Ok(())
}
+ #[tokio::test]
+ async fn join_left_mark_with_filter() -> Result<()> {
+ let task_ctx = Arc::new(TaskContext::default());
+ let left = build_left_table();
+ let right = build_right_table();
+
+ let filter = prepare_join_filter();
+ let (columns, batches) = multi_partitioned_join_collect(
+ left,
+ right,
+ &JoinType::LeftMark,
+ Some(filter),
+ task_ctx,
+ )
+ .await?;
+ assert_eq!(columns, vec!["a1", "b1", "c1", "mark"]);
+ let expected = [
+ "+----+----+-----+-------+",
+ "| a1 | b1 | c1 | mark |",
+ "+----+----+-----+-------+",
+ "| 11 | 8 | 110 | false |",
+ "| 5 | 5 | 50 | true |",
+ "| 9 | 8 | 90 | false |",
+ "+----+----+-----+-------+",
+ ];
+
+ assert_batches_sorted_eq!(expected, &batches);
+
+ Ok(())
+ }
+
#[tokio::test]
async fn test_overallocation() -> Result<()> {
let left = build_table(
@@ -1269,6 +1300,7 @@ pub(crate) mod tests {
JoinType::Full,
JoinType::LeftSemi,
JoinType::LeftAnti,
+ JoinType::LeftMark,
JoinType::RightSemi,
JoinType::RightAnti,
];
diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs
b/datafusion/physical-plan/src/joins/sort_merge_join.rs
index b299b495c5..20fafcc347 100644
--- a/datafusion/physical-plan/src/joins/sort_merge_join.rs
+++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs
@@ -35,7 +35,9 @@ use std::sync::Arc;
use std::task::{Context, Poll};
use arrow::array::*;
-use arrow::compute::{self, concat_batches, filter_record_batch, take,
SortOptions};
+use arrow::compute::{
+ self, concat_batches, filter_record_batch, is_not_null, take, SortOptions,
+};
use arrow::datatypes::{DataType, SchemaRef, TimeUnit};
use arrow::error::ArrowError;
use arrow::ipc::reader::FileReader;
@@ -178,7 +180,8 @@ impl SortMergeJoinExec {
| JoinType::Left
| JoinType::Full
| JoinType::LeftAnti
- | JoinType::LeftSemi => JoinSide::Left,
+ | JoinType::LeftSemi
+ | JoinType::LeftMark => JoinSide::Left,
}
}
@@ -186,7 +189,10 @@ impl SortMergeJoinExec {
fn maintains_input_order(join_type: JoinType) -> Vec<bool> {
match join_type {
JoinType::Inner => vec![true, false],
- JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti =>
vec![true, false],
+ JoinType::Left
+ | JoinType::LeftSemi
+ | JoinType::LeftAnti
+ | JoinType::LeftMark => vec![true, false],
JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => {
vec![false, true]
}
@@ -784,6 +790,29 @@ fn get_corrected_filter_mask(
corrected_mask.extend(vec![Some(false); null_matched]);
Some(corrected_mask.finish())
}
+ JoinType::LeftMark => {
+ for i in 0..row_indices_length {
+ let last_index =
+ last_index_for_row(i, row_indices, batch_ids,
row_indices_length);
+ if filter_mask.value(i) && !seen_true {
+ seen_true = true;
+ corrected_mask.append_value(true);
+ } else if seen_true || !filter_mask.value(i) && !last_index {
+ corrected_mask.append_null(); // to be ignored and not set
to output
+ } else {
+ corrected_mask.append_value(false); // to be converted to
null joined row
+ }
+
+ if last_index {
+ seen_true = false;
+ }
+ }
+
+ // Generate null joined rows for records which have no matching
join key
+ let null_matched = expected_size - corrected_mask.len();
+ corrected_mask.extend(vec![Some(false); null_matched]);
+ Some(corrected_mask.finish())
+ }
JoinType::LeftSemi => {
for i in 0..row_indices_length {
let last_index =
@@ -860,6 +889,7 @@ impl Stream for SMJStream {
self.join_type,
JoinType::Left
| JoinType::LeftSemi
+ | JoinType::LeftMark
| JoinType::Right
| JoinType::LeftAnti
)
@@ -943,6 +973,7 @@ impl Stream for SMJStream {
| JoinType::LeftSemi
| JoinType::Right
| JoinType::LeftAnti
+ | JoinType::LeftMark
)
{
continue;
@@ -964,6 +995,7 @@ impl Stream for SMJStream {
| JoinType::LeftSemi
| JoinType::Right
| JoinType::LeftAnti
+ | JoinType::LeftMark
)
{
let out = self.filter_joined_batch()?;
@@ -1264,6 +1296,8 @@ impl SMJStream {
let mut join_streamed = false;
// Whether to join buffered rows
let mut join_buffered = false;
+ // For Mark join we store a dummy id to indicate that the row has a
match
+ let mut mark_row_as_match = false;
// determine whether we need to join streamed/buffered rows
match self.current_ordering {
@@ -1275,12 +1309,14 @@ impl SMJStream {
| JoinType::RightSemi
| JoinType::Full
| JoinType::LeftAnti
+ | JoinType::LeftMark
) {
join_streamed = !self.streamed_joined;
}
}
Ordering::Equal => {
- if matches!(self.join_type, JoinType::LeftSemi) {
+ if matches!(self.join_type, JoinType::LeftSemi |
JoinType::LeftMark) {
+ mark_row_as_match = matches!(self.join_type,
JoinType::LeftMark);
// if the join filter is specified then its needed to
output the streamed index
// only if it has not been emitted before
// the `join_filter_matched_idxs` keeps track on if
streamed index has a successful
@@ -1357,9 +1393,11 @@ impl SMJStream {
} else {
Some(self.buffered_data.scanning_batch_idx)
};
+ // For Mark join we store a dummy id to indicate that the row has a
match
+ let scanning_idx = mark_row_as_match.then_some(0);
self.streamed_batch
- .append_output_pair(scanning_batch_idx, None);
+ .append_output_pair(scanning_batch_idx, scanning_idx);
self.output_size += 1;
self.buffered_data.scanning_finish();
self.streamed_joined = true;
@@ -1461,24 +1499,25 @@ impl SMJStream {
// The row indices of joined buffered batch
let buffered_indices: UInt64Array =
chunk.buffered_indices.finish();
- let mut buffered_columns =
- if matches!(self.join_type, JoinType::LeftSemi |
JoinType::LeftAnti) {
- vec![]
- } else if let Some(buffered_idx) = chunk.buffered_batch_idx {
- get_buffered_columns(
- &self.buffered_data,
- buffered_idx,
- &buffered_indices,
- )?
- } else {
- // If buffered batch none, meaning it is null joined batch.
- // We need to create null arrays for buffered columns to
join with streamed rows.
- self.buffered_schema
- .fields()
- .iter()
- .map(|f| new_null_array(f.data_type(),
buffered_indices.len()))
- .collect::<Vec<_>>()
- };
+ let mut buffered_columns = if matches!(self.join_type,
JoinType::LeftMark) {
+ vec![Arc::new(is_not_null(&buffered_indices)?) as ArrayRef]
+ } else if matches!(self.join_type, JoinType::LeftSemi |
JoinType::LeftAnti) {
+ vec![]
+ } else if let Some(buffered_idx) = chunk.buffered_batch_idx {
+ get_buffered_columns(
+ &self.buffered_data,
+ buffered_idx,
+ &buffered_indices,
+ )?
+ } else {
+ // If buffered batch none, meaning it is null joined batch.
+ // We need to create null arrays for buffered columns to join
with streamed rows.
+ create_unmatched_columns(
+ self.join_type,
+ &self.buffered_schema,
+ buffered_indices.len(),
+ )
+ };
let streamed_columns_length = streamed_columns.len();
@@ -1489,7 +1528,7 @@ impl SMJStream {
get_filter_column(&self.filter, &buffered_columns,
&streamed_columns)
} else if matches!(
self.join_type,
- JoinType::LeftSemi | JoinType::LeftAnti
+ JoinType::LeftSemi | JoinType::LeftAnti |
JoinType::LeftMark
) {
// unwrap is safe here as we check is_some on top of if
statement
let buffered_columns = get_buffered_columns(
@@ -1517,7 +1556,6 @@ impl SMJStream {
};
let output_batch = RecordBatch::try_new(Arc::clone(&self.schema),
columns)?;
-
// Apply join filter if any
if !filter_columns.is_empty() {
if let Some(f) = &self.filter {
@@ -1553,6 +1591,7 @@ impl SMJStream {
| JoinType::LeftSemi
| JoinType::Right
| JoinType::LeftAnti
+ | JoinType::LeftMark
) {
self.output_record_batches
.batches
@@ -1691,6 +1730,7 @@ impl SMJStream {
| JoinType::LeftSemi
| JoinType::Right
| JoinType::LeftAnti
+ | JoinType::LeftMark
))
{
self.output_record_batches.batches.clear();
@@ -1721,16 +1761,18 @@ impl SMJStream {
let buffered_columns_length = self.buffered_schema.fields.len();
let streamed_columns_length = self.streamed_schema.fields.len();
- if matches!(self.join_type, JoinType::Left | JoinType::Right) {
+ if matches!(
+ self.join_type,
+ JoinType::Left | JoinType::LeftMark | JoinType::Right
+ ) {
let null_mask = compute::not(corrected_mask)?;
let null_joined_batch = filter_record_batch(&record_batch,
&null_mask)?;
- let mut buffered_columns = self
- .buffered_schema
- .fields()
- .iter()
- .map(|f| new_null_array(f.data_type(),
null_joined_batch.num_rows()))
- .collect::<Vec<_>>();
+ let mut buffered_columns = create_unmatched_columns(
+ self.join_type,
+ &self.buffered_schema,
+ null_joined_batch.num_rows(),
+ );
let columns = if matches!(self.join_type, JoinType::Right) {
let streamed_columns = null_joined_batch
@@ -1777,6 +1819,22 @@ impl SMJStream {
}
}
+fn create_unmatched_columns(
+ join_type: JoinType,
+ schema: &SchemaRef,
+ size: usize,
+) -> Vec<ArrayRef> {
+ if matches!(join_type, JoinType::LeftMark) {
+ vec![Arc::new(BooleanArray::from(vec![false; size])) as ArrayRef]
+ } else {
+ schema
+ .fields()
+ .iter()
+ .map(|f| new_null_array(f.data_type(), size))
+ .collect::<Vec<_>>()
+ }
+}
+
/// Gets the arrays which join filters are applied on.
fn get_filter_column(
join_filter: &Option<JoinFilter>,
@@ -2716,6 +2774,39 @@ mod tests {
Ok(())
}
+ #[tokio::test]
+ async fn join_left_mark() -> Result<()> {
+ let left = build_table(
+ ("a1", &vec![1, 2, 2, 3]),
+ ("b1", &vec![4, 5, 5, 7]), // 7 does not exist on the right
+ ("c1", &vec![7, 8, 8, 9]),
+ );
+ let right = build_table(
+ ("a2", &vec![10, 20, 30, 40]),
+ ("b1", &vec![4, 4, 5, 6]), // 5 is double on the right
+ ("c2", &vec![60, 70, 80, 90]),
+ );
+ let on = vec![(
+ Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
+ )];
+
+ let (_, batches) = join_collect(left, right, on, LeftMark).await?;
+ let expected = [
+ "+----+----+----+-------+",
+ "| a1 | b1 | c1 | mark |",
+ "+----+----+----+-------+",
+ "| 1 | 4 | 7 | true |",
+ "| 2 | 5 | 8 | true |",
+ "| 2 | 5 | 8 | true |",
+ "| 3 | 7 | 9 | false |",
+ "+----+----+----+-------+",
+ ];
+ // The output order is important as SMJ preserves sortedness
+ assert_batches_eq!(expected, &batches);
+ Ok(())
+ }
+
#[tokio::test]
async fn join_with_duplicated_column_names() -> Result<()> {
let left = build_table(
@@ -3047,7 +3138,7 @@ mod tests {
)];
let sort_options = vec![SortOptions::default(); on.len()];
- let join_types = vec![Inner, Left, Right, Full, LeftSemi, LeftAnti];
+ let join_types = vec![Inner, Left, Right, Full, LeftSemi, LeftAnti,
LeftMark];
// Disable DiskManager to prevent spilling
let runtime = RuntimeEnvBuilder::new()
@@ -3125,7 +3216,7 @@ mod tests {
)];
let sort_options = vec![SortOptions::default(); on.len()];
- let join_types = vec![Inner, Left, Right, Full, LeftSemi, LeftAnti];
+ let join_types = vec![Inner, Left, Right, Full, LeftSemi, LeftAnti,
LeftMark];
// Disable DiskManager to prevent spilling
let runtime = RuntimeEnvBuilder::new()
@@ -3181,7 +3272,7 @@ mod tests {
)];
let sort_options = vec![SortOptions::default(); on.len()];
- let join_types = [Inner, Left, Right, Full, LeftSemi, LeftAnti];
+ let join_types = [Inner, Left, Right, Full, LeftSemi, LeftAnti,
LeftMark];
// Enable DiskManager to allow spilling
let runtime = RuntimeEnvBuilder::new()
@@ -3282,7 +3373,7 @@ mod tests {
)];
let sort_options = vec![SortOptions::default(); on.len()];
- let join_types = [Inner, Left, Right, Full, LeftSemi, LeftAnti];
+ let join_types = [Inner, Left, Right, Full, LeftSemi, LeftAnti,
LeftMark];
// Enable DiskManager to allow spilling
let runtime = RuntimeEnvBuilder::new()
diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
index eb6a30d17e..3e0cd48da2 100644
--- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
+++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
@@ -62,6 +62,7 @@ use arrow::array::{
use arrow::compute::concat_batches;
use arrow::datatypes::{Schema, SchemaRef};
use arrow::record_batch::RecordBatch;
+use arrow_buffer::ArrowNativeType;
use datafusion_common::hash_utils::create_hashes;
use datafusion_common::utils::bisect;
use datafusion_common::{internal_err, plan_err, JoinSide, JoinType, Result};
@@ -670,7 +671,11 @@ fn need_to_produce_result_in_final(build_side: JoinSide,
join_type: JoinType) ->
if build_side == JoinSide::Left {
matches!(
join_type,
- JoinType::Left | JoinType::LeftAnti | JoinType::Full |
JoinType::LeftSemi
+ JoinType::Left
+ | JoinType::LeftAnti
+ | JoinType::Full
+ | JoinType::LeftSemi
+ | JoinType::LeftMark
)
} else {
matches!(
@@ -709,6 +714,20 @@ where
{
// Store the result in a tuple
let result = match (build_side, join_type) {
+ (JoinSide::Left, JoinType::LeftMark) => {
+ let build_indices = (0..prune_length)
+ .map(L::Native::from_usize)
+ .collect::<PrimitiveArray<L>>();
+ let probe_indices = (0..prune_length)
+ .map(|idx| {
+ // For mark join we output a dummy index 0 to indicate the
row had a match
+ visited_rows
+ .contains(&(idx + deleted_offset))
+ .then_some(R::Native::from_usize(0).unwrap())
+ })
+ .collect();
+ (build_indices, probe_indices)
+ }
// In the case of `Left` or `Right` join, or `Full` join, get the anti
indices
(JoinSide::Left, JoinType::Left | JoinType::LeftAnti)
| (JoinSide::Right, JoinType::Right | JoinType::RightAnti)
@@ -872,6 +891,7 @@ pub(crate) fn join_with_probe_batch(
JoinType::LeftAnti
| JoinType::RightAnti
| JoinType::LeftSemi
+ | JoinType::LeftMark
| JoinType::RightSemi
) {
Ok(None)
@@ -1707,6 +1727,7 @@ mod tests {
JoinType::RightSemi,
JoinType::LeftSemi,
JoinType::LeftAnti,
+ JoinType::LeftMark,
JoinType::RightAnti,
JoinType::Full
)]
@@ -1791,6 +1812,7 @@ mod tests {
JoinType::RightSemi,
JoinType::LeftSemi,
JoinType::LeftAnti,
+ JoinType::LeftMark,
JoinType::RightAnti,
JoinType::Full
)]
@@ -1855,6 +1877,7 @@ mod tests {
JoinType::RightSemi,
JoinType::LeftSemi,
JoinType::LeftAnti,
+ JoinType::LeftMark,
JoinType::RightAnti,
JoinType::Full
)]
@@ -1906,6 +1929,7 @@ mod tests {
JoinType::RightSemi,
JoinType::LeftSemi,
JoinType::LeftAnti,
+ JoinType::LeftMark,
JoinType::RightAnti,
JoinType::Full
)]
@@ -1933,6 +1957,7 @@ mod tests {
JoinType::RightSemi,
JoinType::LeftSemi,
JoinType::LeftAnti,
+ JoinType::LeftMark,
JoinType::RightAnti,
JoinType::Full
)]
@@ -2298,6 +2323,7 @@ mod tests {
JoinType::RightSemi,
JoinType::LeftSemi,
JoinType::LeftAnti,
+ JoinType::LeftMark,
JoinType::RightAnti,
JoinType::Full
)]
@@ -2380,6 +2406,7 @@ mod tests {
JoinType::RightSemi,
JoinType::LeftSemi,
JoinType::LeftAnti,
+ JoinType::LeftMark,
JoinType::RightAnti,
JoinType::Full
)]
@@ -2454,6 +2481,7 @@ mod tests {
JoinType::RightSemi,
JoinType::LeftSemi,
JoinType::LeftAnti,
+ JoinType::LeftMark,
JoinType::RightAnti,
JoinType::Full
)]
diff --git a/datafusion/physical-plan/src/joins/utils.rs
b/datafusion/physical-plan/src/joins/utils.rs
index 090cf9aa62..e7c191f983 100644
--- a/datafusion/physical-plan/src/joins/utils.rs
+++ b/datafusion/physical-plan/src/joins/utils.rs
@@ -20,6 +20,7 @@
use std::collections::HashSet;
use std::fmt::{self, Debug};
use std::future::Future;
+use std::iter::once;
use std::ops::{IndexMut, Range};
use std::sync::Arc;
use std::task::{Context, Poll};
@@ -619,6 +620,7 @@ fn output_join_field(old_field: &Field, join_type:
&JoinType, is_left: bool) ->
JoinType::RightSemi => false, // doesn't introduce nulls
JoinType::LeftAnti => false, // doesn't introduce nulls (or can it??)
JoinType::RightAnti => false, // doesn't introduce nulls (or can it??)
+ JoinType::LeftMark => false,
};
if force_nullable {
@@ -635,44 +637,10 @@ pub fn build_join_schema(
right: &Schema,
join_type: &JoinType,
) -> (Schema, Vec<ColumnIndex>) {
- let (fields, column_indices): (SchemaBuilder, Vec<ColumnIndex>) = match
join_type {
- JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right =>
{
- let left_fields = left
- .fields()
- .iter()
- .map(|f| output_join_field(f, join_type, true))
- .enumerate()
- .map(|(index, f)| {
- (
- f,
- ColumnIndex {
- index,
- side: JoinSide::Left,
- },
- )
- });
- let right_fields = right
- .fields()
- .iter()
- .map(|f| output_join_field(f, join_type, false))
- .enumerate()
- .map(|(index, f)| {
- (
- f,
- ColumnIndex {
- index,
- side: JoinSide::Right,
- },
- )
- });
-
- // left then right
- left_fields.chain(right_fields).unzip()
- }
- JoinType::LeftSemi | JoinType::LeftAnti => left
- .fields()
+ let left_fields = || {
+ left.fields()
.iter()
- .cloned()
+ .map(|f| output_join_field(f, join_type, true))
.enumerate()
.map(|(index, f)| {
(
@@ -683,11 +651,13 @@ pub fn build_join_schema(
},
)
})
- .unzip(),
- JoinType::RightSemi | JoinType::RightAnti => right
+ };
+
+ let right_fields = || {
+ right
.fields()
.iter()
- .cloned()
+ .map(|f| output_join_field(f, join_type, false))
.enumerate()
.map(|(index, f)| {
(
@@ -698,7 +668,25 @@ pub fn build_join_schema(
},
)
})
- .unzip(),
+ };
+
+ let (fields, column_indices): (SchemaBuilder, Vec<ColumnIndex>) = match
join_type {
+ JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right =>
{
+ // left then right
+ left_fields().chain(right_fields()).unzip()
+ }
+ JoinType::LeftSemi | JoinType::LeftAnti => left_fields().unzip(),
+ JoinType::LeftMark => {
+ let right_field = once((
+ Field::new("mark", arrow_schema::DataType::Boolean, false),
+ ColumnIndex {
+ index: 0,
+ side: JoinSide::None,
+ },
+ ));
+ left_fields().chain(right_field).unzip()
+ }
+ JoinType::RightSemi | JoinType::RightAnti => right_fields().unzip(),
};
let metadata = left
@@ -902,6 +890,16 @@ fn estimate_join_cardinality(
column_statistics: outer_stats.column_statistics,
})
}
+
+ JoinType::LeftMark => {
+ let num_rows = *left_stats.num_rows.get_value()?;
+ let mut column_statistics = left_stats.column_statistics;
+ column_statistics.push(ColumnStatistics::new_unknown());
+ Some(PartialJoinStatistics {
+ num_rows,
+ column_statistics,
+ })
+ }
}
}
@@ -1153,7 +1151,11 @@ impl<T: 'static> OnceFut<T> {
pub(crate) fn need_produce_result_in_final(join_type: JoinType) -> bool {
matches!(
join_type,
- JoinType::Left | JoinType::LeftAnti | JoinType::LeftSemi |
JoinType::Full
+ JoinType::Left
+ | JoinType::LeftAnti
+ | JoinType::LeftSemi
+ | JoinType::LeftMark
+ | JoinType::Full
)
}
@@ -1171,6 +1173,13 @@ pub(crate) fn get_final_indices_from_bit_map(
join_type: JoinType,
) -> (UInt64Array, UInt32Array) {
let left_size = left_bit_map.len();
+ if join_type == JoinType::LeftMark {
+ let left_indices = (0..left_size as u64).collect::<UInt64Array>();
+ let right_indices = (0..left_size)
+ .map(|idx| left_bit_map.get_bit(idx).then_some(0))
+ .collect::<UInt32Array>();
+ return (left_indices, right_indices);
+ }
let left_indices = if join_type == JoinType::LeftSemi {
(0..left_size)
.filter_map(|idx| (left_bit_map.get_bit(idx)).then_some(idx as
u64))
@@ -1254,7 +1263,10 @@ pub(crate) fn build_batch_from_indices(
let mut columns: Vec<Arc<dyn Array>> =
Vec::with_capacity(schema.fields().len());
for column_index in column_indices {
- let array = if column_index.side == build_side {
+ let array = if column_index.side == JoinSide::None {
+ // LeftMark join: the mark column is true if the probe index is not
null, otherwise it is false
+ Arc::new(compute::is_not_null(probe_indices)?)
+ } else if column_index.side == build_side {
let array = build_input_buffer.column(column_index.index);
if array.is_empty() || build_indices.null_count() ==
build_indices.len() {
// Outer join would generate a null index when finding no
match at our side.
@@ -1323,7 +1335,7 @@ pub(crate) fn adjust_indices_by_join_type(
// the left_indices will not be used later for the `right anti`
join
Ok((left_indices, right_indices))
}
- JoinType::LeftSemi | JoinType::LeftAnti => {
+ JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => {
// matched or unmatched left row will be produced in the end of
loop
// When visit the right batch, we can output the matched left row
and don't need to wait the end of loop
Ok((
@@ -1646,7 +1658,7 @@ pub(crate) fn symmetric_join_output_partitioning(
let left_partitioning = left.output_partitioning();
let right_partitioning = right.output_partitioning();
match join_type {
- JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => {
+ JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti |
JoinType::LeftMark => {
left_partitioning.clone()
}
JoinType::RightSemi | JoinType::RightAnti =>
right_partitioning.clone(),
@@ -1671,11 +1683,13 @@ pub(crate) fn asymmetric_join_output_partitioning(
left.schema().fields().len(),
),
JoinType::RightSemi | JoinType::RightAnti =>
right.output_partitioning().clone(),
- JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti |
JoinType::Full => {
- Partitioning::UnknownPartitioning(
- right.output_partitioning().partition_count(),
- )
- }
+ JoinType::Left
+ | JoinType::LeftSemi
+ | JoinType::LeftAnti
+ | JoinType::Full
+ | JoinType::LeftMark => Partitioning::UnknownPartitioning(
+ right.output_partitioning().partition_count(),
+ ),
}
}
diff --git a/datafusion/proto-common/proto/datafusion_common.proto
b/datafusion/proto-common/proto/datafusion_common.proto
index 7f8bce6b20..65cd33d523 100644
--- a/datafusion/proto-common/proto/datafusion_common.proto
+++ b/datafusion/proto-common/proto/datafusion_common.proto
@@ -84,6 +84,7 @@ enum JoinType {
LEFTANTI = 5;
RIGHTSEMI = 6;
RIGHTANTI = 7;
+ LEFTMARK = 8;
}
enum JoinConstraint {
@@ -541,9 +542,10 @@ message ParquetOptions {
string created_by = 16;
}
-enum JoinSide{
+enum JoinSide {
LEFT_SIDE = 0;
RIGHT_SIDE = 1;
+ NONE = 2;
}
message Precision{
diff --git a/datafusion/proto-common/src/from_proto/mod.rs
b/datafusion/proto-common/src/from_proto/mod.rs
index d848f795c6..a554e4ed28 100644
--- a/datafusion/proto-common/src/from_proto/mod.rs
+++ b/datafusion/proto-common/src/from_proto/mod.rs
@@ -778,6 +778,7 @@ impl From<protobuf::JoinSide> for JoinSide {
match t {
protobuf::JoinSide::LeftSide => JoinSide::Left,
protobuf::JoinSide::RightSide => JoinSide::Right,
+ protobuf::JoinSide::None => JoinSide::None,
}
}
}
diff --git a/datafusion/proto-common/src/generated/pbjson.rs
b/datafusion/proto-common/src/generated/pbjson.rs
index e8b46fbf70..e8235ef7b9 100644
--- a/datafusion/proto-common/src/generated/pbjson.rs
+++ b/datafusion/proto-common/src/generated/pbjson.rs
@@ -3761,6 +3761,7 @@ impl serde::Serialize for JoinSide {
let variant = match self {
Self::LeftSide => "LEFT_SIDE",
Self::RightSide => "RIGHT_SIDE",
+ Self::None => "NONE",
};
serializer.serialize_str(variant)
}
@@ -3774,6 +3775,7 @@ impl<'de> serde::Deserialize<'de> for JoinSide {
const FIELDS: &[&str] = &[
"LEFT_SIDE",
"RIGHT_SIDE",
+ "NONE",
];
struct GeneratedVisitor;
@@ -3816,6 +3818,7 @@ impl<'de> serde::Deserialize<'de> for JoinSide {
match value {
"LEFT_SIDE" => Ok(JoinSide::LeftSide),
"RIGHT_SIDE" => Ok(JoinSide::RightSide),
+ "NONE" => Ok(JoinSide::None),
_ => Err(serde::de::Error::unknown_variant(value, FIELDS)),
}
}
@@ -3838,6 +3841,7 @@ impl serde::Serialize for JoinType {
Self::Leftanti => "LEFTANTI",
Self::Rightsemi => "RIGHTSEMI",
Self::Rightanti => "RIGHTANTI",
+ Self::Leftmark => "LEFTMARK",
};
serializer.serialize_str(variant)
}
@@ -3857,6 +3861,7 @@ impl<'de> serde::Deserialize<'de> for JoinType {
"LEFTANTI",
"RIGHTSEMI",
"RIGHTANTI",
+ "LEFTMARK",
];
struct GeneratedVisitor;
@@ -3905,6 +3910,7 @@ impl<'de> serde::Deserialize<'de> for JoinType {
"LEFTANTI" => Ok(JoinType::Leftanti),
"RIGHTSEMI" => Ok(JoinType::Rightsemi),
"RIGHTANTI" => Ok(JoinType::Rightanti),
+ "LEFTMARK" => Ok(JoinType::Leftmark),
_ => Err(serde::de::Error::unknown_variant(value, FIELDS)),
}
}
diff --git a/datafusion/proto-common/src/generated/prost.rs
b/datafusion/proto-common/src/generated/prost.rs
index 939a4b3c2c..68e7f74c7f 100644
--- a/datafusion/proto-common/src/generated/prost.rs
+++ b/datafusion/proto-common/src/generated/prost.rs
@@ -883,6 +883,7 @@ pub enum JoinType {
Leftanti = 5,
Rightsemi = 6,
Rightanti = 7,
+ Leftmark = 8,
}
impl JoinType {
/// String value of the enum field names used in the ProtoBuf definition.
@@ -899,6 +900,7 @@ impl JoinType {
Self::Leftanti => "LEFTANTI",
Self::Rightsemi => "RIGHTSEMI",
Self::Rightanti => "RIGHTANTI",
+ Self::Leftmark => "LEFTMARK",
}
}
/// Creates an enum from field names used in the ProtoBuf definition.
@@ -912,6 +914,7 @@ impl JoinType {
"LEFTANTI" => Some(Self::Leftanti),
"RIGHTSEMI" => Some(Self::Rightsemi),
"RIGHTANTI" => Some(Self::Rightanti),
+ "LEFTMARK" => Some(Self::Leftmark),
_ => None,
}
}
@@ -1069,6 +1072,7 @@ impl CompressionTypeVariant {
pub enum JoinSide {
LeftSide = 0,
RightSide = 1,
+ None = 2,
}
impl JoinSide {
/// String value of the enum field names used in the ProtoBuf definition.
@@ -1079,6 +1083,7 @@ impl JoinSide {
match self {
Self::LeftSide => "LEFT_SIDE",
Self::RightSide => "RIGHT_SIDE",
+ Self::None => "NONE",
}
}
/// Creates an enum from field names used in the ProtoBuf definition.
@@ -1086,6 +1091,7 @@ impl JoinSide {
match value {
"LEFT_SIDE" => Some(Self::LeftSide),
"RIGHT_SIDE" => Some(Self::RightSide),
+ "NONE" => Some(Self::None),
_ => None,
}
}
diff --git a/datafusion/proto-common/src/to_proto/mod.rs
b/datafusion/proto-common/src/to_proto/mod.rs
index f9b8973e2d..02a642a4af 100644
--- a/datafusion/proto-common/src/to_proto/mod.rs
+++ b/datafusion/proto-common/src/to_proto/mod.rs
@@ -759,6 +759,7 @@ impl From<JoinSide> for protobuf::JoinSide {
match t {
JoinSide::Left => protobuf::JoinSide::LeftSide,
JoinSide::Right => protobuf::JoinSide::RightSide,
+ JoinSide::None => protobuf::JoinSide::None,
}
}
}
diff --git a/datafusion/proto/src/generated/datafusion_proto_common.rs
b/datafusion/proto/src/generated/datafusion_proto_common.rs
index 939a4b3c2c..68e7f74c7f 100644
--- a/datafusion/proto/src/generated/datafusion_proto_common.rs
+++ b/datafusion/proto/src/generated/datafusion_proto_common.rs
@@ -883,6 +883,7 @@ pub enum JoinType {
Leftanti = 5,
Rightsemi = 6,
Rightanti = 7,
+ Leftmark = 8,
}
impl JoinType {
/// String value of the enum field names used in the ProtoBuf definition.
@@ -899,6 +900,7 @@ impl JoinType {
Self::Leftanti => "LEFTANTI",
Self::Rightsemi => "RIGHTSEMI",
Self::Rightanti => "RIGHTANTI",
+ Self::Leftmark => "LEFTMARK",
}
}
/// Creates an enum from field names used in the ProtoBuf definition.
@@ -912,6 +914,7 @@ impl JoinType {
"LEFTANTI" => Some(Self::Leftanti),
"RIGHTSEMI" => Some(Self::Rightsemi),
"RIGHTANTI" => Some(Self::Rightanti),
+ "LEFTMARK" => Some(Self::Leftmark),
_ => None,
}
}
@@ -1069,6 +1072,7 @@ impl CompressionTypeVariant {
pub enum JoinSide {
LeftSide = 0,
RightSide = 1,
+ None = 2,
}
impl JoinSide {
/// String value of the enum field names used in the ProtoBuf definition.
@@ -1079,6 +1083,7 @@ impl JoinSide {
match self {
Self::LeftSide => "LEFT_SIDE",
Self::RightSide => "RIGHT_SIDE",
+ Self::None => "NONE",
}
}
/// Creates an enum from field names used in the ProtoBuf definition.
@@ -1086,6 +1091,7 @@ impl JoinSide {
match value {
"LEFT_SIDE" => Some(Self::LeftSide),
"RIGHT_SIDE" => Some(Self::RightSide),
+ "NONE" => Some(Self::None),
_ => None,
}
}
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs
b/datafusion/proto/src/logical_plan/from_proto.rs
index 27bda7dd5a..f25fb0bf25 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -213,6 +213,7 @@ impl From<protobuf::JoinType> for JoinType {
protobuf::JoinType::Rightsemi => JoinType::RightSemi,
protobuf::JoinType::Leftanti => JoinType::LeftAnti,
protobuf::JoinType::Rightanti => JoinType::RightAnti,
+ protobuf::JoinType::Leftmark => JoinType::LeftMark,
}
}
}
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs
b/datafusion/proto/src/logical_plan/to_proto.rs
index 5a6f3a32c6..8af7b19d90 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -685,6 +685,7 @@ impl From<JoinType> for protobuf::JoinType {
JoinType::RightSemi => protobuf::JoinType::Rightsemi,
JoinType::LeftAnti => protobuf::JoinType::Leftanti,
JoinType::RightAnti => protobuf::JoinType::Rightanti,
+ JoinType::LeftMark => protobuf::JoinType::Leftmark,
}
}
}
diff --git a/datafusion/sql/src/unparser/plan.rs
b/datafusion/sql/src/unparser/plan.rs
index 2c38a1d36c..6348aba490 100644
--- a/datafusion/sql/src/unparser/plan.rs
+++ b/datafusion/sql/src/unparser/plan.rs
@@ -552,7 +552,7 @@ impl Unparser<'_> {
relation,
global: false,
join_operator: self
- .join_operator_to_sql(join.join_type, join_constraint),
+ .join_operator_to_sql(join.join_type,
join_constraint)?,
};
let mut from = select.pop_from().unwrap();
from.push_join(ast_join);
@@ -855,8 +855,8 @@ impl Unparser<'_> {
&self,
join_type: JoinType,
constraint: ast::JoinConstraint,
- ) -> ast::JoinOperator {
- match join_type {
+ ) -> Result<ast::JoinOperator> {
+ Ok(match join_type {
JoinType::Inner => ast::JoinOperator::Inner(constraint),
JoinType::Left => ast::JoinOperator::LeftOuter(constraint),
JoinType::Right => ast::JoinOperator::RightOuter(constraint),
@@ -865,7 +865,8 @@ impl Unparser<'_> {
JoinType::LeftSemi => ast::JoinOperator::LeftSemi(constraint),
JoinType::RightAnti => ast::JoinOperator::RightAnti(constraint),
JoinType::RightSemi => ast::JoinOperator::RightSemi(constraint),
- }
+ JoinType::LeftMark => unimplemented!("Unparsing of Left Mark join
type"),
+ })
}
/// Convert the components of a USING clause to the USING AST. Returns
diff --git a/datafusion/sqllogictest/test_files/subquery.slt
b/datafusion/sqllogictest/test_files/subquery.slt
index 36de19f1c3..027b5ca8dc 100644
--- a/datafusion/sqllogictest/test_files/subquery.slt
+++ b/datafusion/sqllogictest/test_files/subquery.slt
@@ -1056,13 +1056,11 @@ where t1.t1_id > 40 or t1.t1_id in (select t2.t2_id
from t2 where t1.t1_int > 0)
----
logical_plan
01)Projection: t1.t1_id, t1.t1_name, t1.t1_int
-02)--Filter: t1.t1_id > Int32(40) OR __correlated_sq_1.__exists IS NOT NULL
-03)----Projection: t1.t1_id, t1.t1_name, t1.t1_int, __correlated_sq_1.__exists
-04)------Left Join: t1.t1_id = __correlated_sq_1.t2_id Filter: t1.t1_int >
Int32(0)
-05)--------TableScan: t1 projection=[t1_id, t1_name, t1_int]
-06)--------SubqueryAlias: __correlated_sq_1
-07)----------Projection: t2.t2_id, Boolean(true) AS __exists
-08)------------TableScan: t2 projection=[t2_id]
+02)--Filter: t1.t1_id > Int32(40) OR __correlated_sq_1.mark
+03)----LeftMark Join: t1.t1_id = __correlated_sq_1.t2_id Filter: t1.t1_int >
Int32(0)
+04)------TableScan: t1 projection=[t1_id, t1_name, t1_int]
+05)------SubqueryAlias: __correlated_sq_1
+06)--------TableScan: t2 projection=[t2_id]
query ITI rowsort
select t1.t1_id,
@@ -1085,13 +1083,12 @@ where t1.t1_id = 11 or t1.t1_id + 12 not in (select
t2.t2_id + 1 from t2 where t
----
logical_plan
01)Projection: t1.t1_id, t1.t1_name, t1.t1_int
-02)--Filter: t1.t1_id = Int32(11) OR __correlated_sq_1.__exists IS NULL
-03)----Projection: t1.t1_id, t1.t1_name, t1.t1_int, __correlated_sq_1.__exists
-04)------Left Join: CAST(t1.t1_id AS Int64) + Int64(12) =
__correlated_sq_1.t2.t2_id + Int64(1) Filter: t1.t1_int > Int32(0)
-05)--------TableScan: t1 projection=[t1_id, t1_name, t1_int]
-06)--------SubqueryAlias: __correlated_sq_1
-07)----------Projection: CAST(t2.t2_id AS Int64) + Int64(1), Boolean(true) AS
__exists
-08)------------TableScan: t2 projection=[t2_id]
+02)--Filter: t1.t1_id = Int32(11) OR NOT __correlated_sq_1.mark
+03)----LeftMark Join: CAST(t1.t1_id AS Int64) + Int64(12) =
__correlated_sq_1.t2.t2_id + Int64(1) Filter: t1.t1_int > Int32(0)
+04)------TableScan: t1 projection=[t1_id, t1_name, t1_int]
+05)------SubqueryAlias: __correlated_sq_1
+06)--------Projection: CAST(t2.t2_id AS Int64) + Int64(1)
+07)----------TableScan: t2 projection=[t2_id]
query ITI rowsort
select t1.t1_id,
@@ -1113,13 +1110,11 @@ where t1.t1_id > 40 or exists (select * from t2 where
t1.t1_id = t2.t2_id)
----
logical_plan
01)Projection: t1.t1_id, t1.t1_name, t1.t1_int
-02)--Filter: t1.t1_id > Int32(40) OR __correlated_sq_1.__exists IS NOT NULL
-03)----Projection: t1.t1_id, t1.t1_name, t1.t1_int, __correlated_sq_1.__exists
-04)------Left Join: t1.t1_id = __correlated_sq_1.t2_id
-05)--------TableScan: t1 projection=[t1_id, t1_name, t1_int]
-06)--------SubqueryAlias: __correlated_sq_1
-07)----------Projection: t2.t2_id, Boolean(true) AS __exists
-08)------------TableScan: t2 projection=[t2_id]
+02)--Filter: t1.t1_id > Int32(40) OR __correlated_sq_1.mark
+03)----LeftMark Join: t1.t1_id = __correlated_sq_1.t2_id
+04)------TableScan: t1 projection=[t1_id, t1_name, t1_int]
+05)------SubqueryAlias: __correlated_sq_1
+06)--------TableScan: t2 projection=[t2_id]
query ITI rowsort
select t1.t1_id,
@@ -1132,6 +1127,9 @@ where t1.t1_id > 40 or exists (select * from t2 where
t1.t1_id = t2.t2_id)
22 b 2
44 d 4
+statement ok
+set datafusion.explain.logical_plan_only = false;
+
# not_exists_subquery_to_join_with_correlated_outer_filter_disjunction
query TT
explain select t1.t1_id,
@@ -1142,13 +1140,27 @@ where t1.t1_id > 40 or not exists (select * from t2
where t1.t1_id = t2.t2_id)
----
logical_plan
01)Projection: t1.t1_id, t1.t1_name, t1.t1_int
-02)--Filter: t1.t1_id > Int32(40) OR __correlated_sq_1.__exists IS NULL
-03)----Projection: t1.t1_id, t1.t1_name, t1.t1_int, __correlated_sq_1.__exists
-04)------Left Join: t1.t1_id = __correlated_sq_1.t2_id
-05)--------TableScan: t1 projection=[t1_id, t1_name, t1_int]
-06)--------SubqueryAlias: __correlated_sq_1
-07)----------Projection: t2.t2_id, Boolean(true) AS __exists
-08)------------TableScan: t2 projection=[t2_id]
+02)--Filter: t1.t1_id > Int32(40) OR NOT __correlated_sq_1.mark
+03)----LeftMark Join: t1.t1_id = __correlated_sq_1.t2_id
+04)------TableScan: t1 projection=[t1_id, t1_name, t1_int]
+05)------SubqueryAlias: __correlated_sq_1
+06)--------TableScan: t2 projection=[t2_id]
+physical_plan
+01)CoalesceBatchesExec: target_batch_size=2
+02)--FilterExec: t1_id@0 > 40 OR NOT mark@3, projection=[t1_id@0, t1_name@1,
t1_int@2]
+03)----CoalesceBatchesExec: target_batch_size=2
+04)------HashJoinExec: mode=Partitioned, join_type=LeftMark, on=[(t1_id@0,
t2_id@0)]
+05)--------CoalesceBatchesExec: target_batch_size=2
+06)----------RepartitionExec: partitioning=Hash([t1_id@0], 4),
input_partitions=4
+07)------------RepartitionExec: partitioning=RoundRobinBatch(4),
input_partitions=1
+08)--------------MemoryExec: partitions=1, partition_sizes=[1]
+09)--------CoalesceBatchesExec: target_batch_size=2
+10)----------RepartitionExec: partitioning=Hash([t2_id@0], 4),
input_partitions=4
+11)------------RepartitionExec: partitioning=RoundRobinBatch(4),
input_partitions=1
+12)--------------MemoryExec: partitions=1, partition_sizes=[1]
+
+statement ok
+set datafusion.explain.logical_plan_only = true;
query ITI rowsort
select t1.t1_id,
@@ -1170,16 +1182,14 @@ where t1.t1_id in (select t3.t3_id from t3) and
(t1.t1_id > 40 or t1.t1_id in (s
----
logical_plan
01)Projection: t1.t1_id, t1.t1_name, t1.t1_int
-02)--Filter: t1.t1_id > Int32(40) OR __correlated_sq_2.__exists IS NOT NULL
-03)----Projection: t1.t1_id, t1.t1_name, t1.t1_int, __correlated_sq_2.__exists
-04)------Left Join: t1.t1_id = __correlated_sq_2.t2_id Filter: t1.t1_int >
Int32(0)
-05)--------LeftSemi Join: t1.t1_id = __correlated_sq_1.t3_id
-06)----------TableScan: t1 projection=[t1_id, t1_name, t1_int]
-07)----------SubqueryAlias: __correlated_sq_1
-08)------------TableScan: t3 projection=[t3_id]
-09)--------SubqueryAlias: __correlated_sq_2
-10)----------Projection: t2.t2_id, Boolean(true) AS __exists
-11)------------TableScan: t2 projection=[t2_id]
+02)--Filter: t1.t1_id > Int32(40) OR __correlated_sq_2.mark
+03)----LeftMark Join: t1.t1_id = __correlated_sq_2.t2_id Filter: t1.t1_int >
Int32(0)
+04)------LeftSemi Join: t1.t1_id = __correlated_sq_1.t3_id
+05)--------TableScan: t1 projection=[t1_id, t1_name, t1_int]
+06)--------SubqueryAlias: __correlated_sq_1
+07)----------TableScan: t3 projection=[t3_id]
+08)------SubqueryAlias: __correlated_sq_2
+09)--------TableScan: t2 projection=[t2_id]
query ITI rowsort
select t1.t1_id,
@@ -1192,6 +1202,18 @@ where t1.t1_id in (select t3.t3_id from t3) and
(t1.t1_id > 40 or t1.t1_id in (s
22 b 2
44 d 4
+# Handle duplicate values in exists query
+query ITI rowsort
+select t1.t1_id,
+ t1.t1_name,
+ t1.t1_int
+from t1
+where t1.t1_id > 40 or exists (select * from t2 cross join t3 where t1.t1_id =
t2.t2_id)
+----
+11 a 1
+22 b 2
+44 d 4
+
# Nested subqueries
query ITI rowsort
select t1.t1_id,
diff --git a/datafusion/substrait/src/logical_plan/consumer.rs
b/datafusion/substrait/src/logical_plan/consumer.rs
index 2aaf8ec0aa..289aa7b7f4 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -1226,6 +1226,7 @@ fn from_substrait_jointype(join_type: i32) ->
Result<JoinType> {
join_rel::JoinType::Outer => Ok(JoinType::Full),
join_rel::JoinType::LeftAnti => Ok(JoinType::LeftAnti),
join_rel::JoinType::LeftSemi => Ok(JoinType::LeftSemi),
+ join_rel::JoinType::LeftMark => Ok(JoinType::LeftMark),
_ => plan_err!("unsupported join type {substrait_join_type:?}"),
}
} else {
diff --git a/datafusion/substrait/src/logical_plan/producer.rs
b/datafusion/substrait/src/logical_plan/producer.rs
index 408885f706..c73029f130 100644
--- a/datafusion/substrait/src/logical_plan/producer.rs
+++ b/datafusion/substrait/src/logical_plan/producer.rs
@@ -725,7 +725,10 @@ fn to_substrait_jointype(join_type: JoinType) -> join_rel::JoinType {
JoinType::Full => join_rel::JoinType::Outer,
JoinType::LeftAnti => join_rel::JoinType::LeftAnti,
JoinType::LeftSemi => join_rel::JoinType::LeftSemi,
- JoinType::RightAnti | JoinType::RightSemi => unimplemented!(),
+ JoinType::LeftMark => join_rel::JoinType::LeftMark,
+ JoinType::RightAnti | JoinType::RightSemi => {
+ unimplemented!()
+ }
}
}
diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
index 04530dd34d..8fbdefe285 100644
--- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
@@ -453,15 +453,15 @@ async fn roundtrip_inlist_5() -> Result<()> {
// on roundtrip there is an additional projection during TableScan which
includes all column of the table,
// using assert_expected_plan here as a workaround
assert_expected_plan(
- "SELECT a, f FROM data WHERE (f IN ('a', 'b', 'c') OR a in (SELECT data2.a
FROM data2 WHERE f IN ('b', 'c', 'd')))",
- "Projection: data.a, data.f\
- \n Filter: data.f = Utf8(\"a\") OR data.f = Utf8(\"b\") OR data.f =
Utf8(\"c\") OR Boolean(true) IS NOT NULL\
- \n Projection: data.a, data.f, Boolean(true)\
- \n Left Join: data.a = data2.a\
- \n TableScan: data projection=[a, f]\
- \n Projection: data2.a, Boolean(true)\
- \n Filter: data2.f = Utf8(\"b\") OR data2.f = Utf8(\"c\") OR
data2.f = Utf8(\"d\")\
- \n TableScan: data2 projection=[a, f],
partial_filters=[data2.f = Utf8(\"b\") OR data2.f = Utf8(\"c\") OR data2.f =
Utf8(\"d\")]",
+ "SELECT a, f FROM data WHERE (f IN ('a', 'b', 'c') OR a in (SELECT
data2.a FROM data2 WHERE f IN ('b', 'c', 'd')))",
+
+ "Projection: data.a, data.f\
+ \n Filter: data.f = Utf8(\"a\") OR data.f = Utf8(\"b\") OR data.f = Utf8(\"c\") OR data2.mark\
+ \n LeftMark Join: data.a = data2.a\
+ \n TableScan: data projection=[a, f]\
+ \n Projection: data2.a\
+ \n Filter: data2.f = Utf8(\"b\") OR data2.f = Utf8(\"c\") OR data2.f = Utf8(\"d\")\
+ \n TableScan: data2 projection=[a, f], partial_filters=[data2.f = Utf8(\"b\") OR data2.f = Utf8(\"c\") OR data2.f = Utf8(\"d\")]",
true).await
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]