This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new ce38812 feat: Support HashJoin operator (#194)
ce38812 is described below
commit ce3881245ac938092e3fc38285713b8f88abc375
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Tue Mar 19 22:11:24 2024 -0700
feat: Support HashJoin operator (#194)
* feat: Support HashJoin
* Add comment
* Clean up test
* Fix join filter
* Fix clippy
* Use consistent function with sort merge join
* Add note about left semi and left anti joins
* For review
* Merging
* Move tests
* Add a function to parse join parameters
---
core/src/execution/datafusion/planner.rs | 373 +++++++++++++++++----
core/src/execution/proto/operator.proto | 8 +
.../apache/comet/CometSparkSessionExtensions.scala | 23 +-
.../org/apache/comet/serde/QueryPlanSerde.scala | 46 ++-
.../org/apache/spark/sql/comet/operators.scala | 38 +++
.../org/apache/comet/exec/CometJoinSuite.scala | 86 +++++
6 files changed, 508 insertions(+), 66 deletions(-)
diff --git a/core/src/execution/datafusion/planner.rs
b/core/src/execution/datafusion/planner.rs
index eda8692..c8869c5 100644
--- a/core/src/execution/datafusion/planner.rs
+++ b/core/src/execution/datafusion/planner.rs
@@ -17,7 +17,7 @@
//! Converts Spark physical plan to DataFusion physical plan
-use std::{str::FromStr, sync::Arc};
+use std::{collections::HashMap, str::FromStr, sync::Arc};
use arrow_schema::{DataType, Field, Schema, TimeUnit};
use datafusion::{
@@ -37,14 +37,17 @@ use datafusion::{
physical_plan::{
aggregates::{AggregateMode as DFAggregateMode, PhysicalGroupBy},
filter::FilterExec,
- joins::SortMergeJoinExec,
+ joins::{utils::JoinFilter, HashJoinExec, PartitionMode,
SortMergeJoinExec},
limit::LocalLimitExec,
projection::ProjectionExec,
sorts::sort::SortExec,
ExecutionPlan, Partitioning,
},
};
-use datafusion_common::{JoinType as DFJoinType, ScalarValue};
+use datafusion_common::{
+ tree_node::{TreeNode, TreeNodeRewriter, VisitRecursion},
+ JoinType as DFJoinType, ScalarValue,
+};
use itertools::Itertools;
use jni::objects::GlobalRef;
use num::{BigInt, ToPrimitive};
@@ -89,6 +92,14 @@ type PhyAggResult = Result<Vec<Arc<dyn AggregateExpr>>,
ExecutionError>;
type PhyExprResult = Result<Vec<(Arc<dyn PhysicalExpr>, String)>,
ExecutionError>;
type PartitionPhyExprResult = Result<Vec<Arc<dyn PhysicalExpr>>,
ExecutionError>;
+/// Inputs shared by the native join operators, assembled by
+/// `PhysicalPlanner::parse_join_parameters` from a serialized Spark join.
+struct JoinParameters {
+    /// Plan for the first (left) child; may be wrapped in a `CopyExec`.
+    pub left: Arc<dyn ExecutionPlan>,
+    /// Plan for the second (right) child; may be wrapped in a `CopyExec`.
+    pub right: Arc<dyn ExecutionPlan>,
+    /// Equi-join key pairs `(left key, right key)`, each bound to its own child's schema.
+    pub join_on: Vec<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)>,
+    /// DataFusion join type translated from the serialized Spark join type.
+    pub join_type: DFJoinType,
+    /// Optional extra join condition, rebound onto DataFusion's intermediate filter schema.
+    pub join_filter: Option<JoinFilter>,
+}
+
pub const TEST_EXEC_CONTEXT_ID: i64 = -1;
/// The query planner for converting Spark query plans to DataFusion query
plans.
@@ -870,50 +881,22 @@ impl PhysicalPlanner {
))
}
OpStruct::SortMergeJoin(join) => {
- assert!(children.len() == 2);
- let (mut left_scans, left) = self.create_plan(&children[0],
inputs)?;
- let (mut right_scans, right) = self.create_plan(&children[1],
inputs)?;
-
- left_scans.append(&mut right_scans);
-
- let left_join_exprs = join
- .left_join_keys
- .iter()
- .map(|expr| self.create_expr(expr, left.schema()))
- .collect::<Result<Vec<_>, _>>()?;
- let right_join_exprs = join
- .right_join_keys
- .iter()
- .map(|expr| self.create_expr(expr, right.schema()))
- .collect::<Result<Vec<_>, _>>()?;
-
- let join_on = left_join_exprs
- .into_iter()
- .zip(right_join_exprs)
- .collect::<Vec<_>>();
-
- let join_type = match join.join_type.try_into() {
- Ok(JoinType::Inner) => DFJoinType::Inner,
- Ok(JoinType::LeftOuter) => DFJoinType::Left,
- Ok(JoinType::RightOuter) => DFJoinType::Right,
- Ok(JoinType::FullOuter) => DFJoinType::Full,
- Ok(JoinType::LeftSemi) => DFJoinType::LeftSemi,
- Ok(JoinType::RightSemi) => DFJoinType::RightSemi,
- Ok(JoinType::LeftAnti) => DFJoinType::LeftAnti,
- Ok(JoinType::RightAnti) => DFJoinType::RightAnti,
- Err(_) => {
- return Err(ExecutionError::GeneralError(format!(
- "Unsupported join type: {:?}",
- join.join_type
- )));
- }
- };
+ let (join_params, scans) = self.parse_join_parameters(
+ inputs,
+ children,
+ &join.left_join_keys,
+ &join.right_join_keys,
+ join.join_type,
+ &None,
+ )?;
let sort_options = join
.sort_options
.iter()
.map(|sort_option| {
- let sort_expr = self.create_sort_expr(sort_option,
left.schema()).unwrap();
+ let sort_expr = self
+ .create_sort_expr(sort_option,
join_params.left.schema())
+ .unwrap();
SortOptions {
descending: sort_expr.options.descending,
nulls_first: sort_expr.options.nulls_first,
@@ -921,38 +904,175 @@ impl PhysicalPlanner {
})
.collect();
- // DataFusion `SortMergeJoinExec` operator keeps the input
batch internally. We need
- // to copy the input batch to avoid the data corruption from
reusing the input
- // batch.
- let left = if can_reuse_input_batch(&left) {
- Arc::new(CopyExec::new(left))
- } else {
- left
- };
-
- let right = if can_reuse_input_batch(&right) {
- Arc::new(CopyExec::new(right))
- } else {
- right
- };
-
let join = Arc::new(SortMergeJoinExec::try_new(
- left,
- right,
- join_on,
- None,
- join_type,
+ join_params.left,
+ join_params.right,
+ join_params.join_on,
+ join_params.join_filter,
+ join_params.join_type,
sort_options,
// null doesn't equal to null in Spark join key. If the
join key is
// `EqualNullSafe`, Spark will rewrite it during planning.
false,
)?);
- Ok((left_scans, join))
+ Ok((scans, join))
+ }
+ OpStruct::HashJoin(join) => {
+ let (join_params, scans) = self.parse_join_parameters(
+ inputs,
+ children,
+ &join.left_join_keys,
+ &join.right_join_keys,
+ join.join_type,
+ &join.condition,
+ )?;
+ let join = Arc::new(HashJoinExec::try_new(
+ join_params.left,
+ join_params.right,
+ join_params.join_on,
+ join_params.join_filter,
+ &join_params.join_type,
+ PartitionMode::Partitioned,
+ // null doesn't equal to null in Spark join key. If the
join key is
+ // `EqualNullSafe`, Spark will rewrite it during planning.
+ false,
+ )?);
+ Ok((scans, join))
}
}
}
+    /// Builds the inputs shared by the native join operators (`SortMergeJoinExec`,
+    /// `HashJoinExec`) from a serialized Spark join.
+    ///
+    /// Plans both children, binds each side's join keys against that side's output
+    /// schema, translates the protobuf `join_type` into DataFusion's `JoinType` and, when
+    /// `condition` is present, converts it into a DataFusion `JoinFilter`.
+    ///
+    /// Returns the join parameters together with the `ScanExec`s collected from both
+    /// children (left scans first, then right scans).
+    ///
+    /// # Errors
+    /// Returns an error if a child plan, a join key or the condition cannot be created,
+    /// or if `join_type` is not a recognized `JoinType` value.
+    ///
+    /// # Panics
+    /// Panics if `children` does not contain exactly two operators.
+    fn parse_join_parameters(
+        &self,
+        inputs: &mut Vec<Arc<GlobalRef>>,
+        children: &[Operator],
+        left_join_keys: &[Expr],
+        right_join_keys: &[Expr],
+        join_type: i32,
+        condition: &Option<Expr>,
+    ) -> Result<(JoinParameters, Vec<ScanExec>), ExecutionError> {
+        assert!(children.len() == 2);
+        let (mut left_scans, left) = self.create_plan(&children[0], inputs)?;
+        let (mut right_scans, right) = self.create_plan(&children[1], inputs)?;
+
+        // Report the scans of both children together: left first, then right.
+        left_scans.append(&mut right_scans);
+
+        // Each side's keys are bound against that side's own output schema.
+        let left_join_exprs = left_join_keys
+            .iter()
+            .map(|expr| self.create_expr(expr, left.schema()))
+            .collect::<Result<Vec<_>, _>>()?;
+        let right_join_exprs = right_join_keys
+            .iter()
+            .map(|expr| self.create_expr(expr, right.schema()))
+            .collect::<Result<Vec<_>, _>>()?;
+
+        // Keys are paired positionally: the i-th left key is matched with the i-th right key.
+        let join_on = left_join_exprs
+            .into_iter()
+            .zip(right_join_exprs)
+            .collect::<Vec<_>>();
+
+        let join_type = match join_type.try_into() {
+            Ok(JoinType::Inner) => DFJoinType::Inner,
+            Ok(JoinType::LeftOuter) => DFJoinType::Left,
+            Ok(JoinType::RightOuter) => DFJoinType::Right,
+            Ok(JoinType::FullOuter) => DFJoinType::Full,
+            Ok(JoinType::LeftSemi) => DFJoinType::LeftSemi,
+            Ok(JoinType::RightSemi) => DFJoinType::RightSemi,
+            Ok(JoinType::LeftAnti) => DFJoinType::LeftAnti,
+            Ok(JoinType::RightAnti) => DFJoinType::RightAnti,
+            Err(_) => {
+                return Err(ExecutionError::GeneralError(format!(
+                    "Unsupported join type: {:?}",
+                    join_type
+                )));
+            }
+        };
+
+        // Handle join filter as DataFusion `JoinFilter` struct.
+        // Spark binds the condition to the concatenation of both children's outputs,
+        // while DataFusion's join filter is bound to an intermediate schema containing
+        // only the referenced columns. Build the full schema, collect the referenced
+        // columns, then rebind the expression onto the intermediate schema.
+        let join_filter = if let Some(expr) = condition {
+            let left_schema = left.schema();
+            let right_schema = right.schema();
+            let left_fields = left_schema.fields();
+            let right_fields = right_schema.fields();
+            let all_fields: Vec<_> = left_fields
+                .iter()
+                .chain(right_fields.iter())
+                .cloned()
+                .collect();
+            let full_schema = Arc::new(Schema::new(all_fields));
+
+            let physical_expr = self.create_expr(expr, full_schema)?;
+            let (left_field_indices, right_field_indices) =
+                expr_to_columns(&physical_expr, left_fields.len(), right_fields.len())?;
+            let column_indices = JoinFilter::build_column_indices(
+                left_field_indices.clone(),
+                right_field_indices.clone(),
+            );
+
+            // Intermediate schema: the referenced left columns followed by the referenced
+            // right columns, in the same order as `column_indices`.
+            let filter_fields: Vec<Field> = left_field_indices
+                .iter()
+                .map(|&i| left_schema.field(i).clone())
+                .chain(
+                    right_field_indices
+                        .iter()
+                        .map(|&i| right_schema.field(i).clone()),
+                )
+                .collect_vec();
+            let filter_schema = Schema::new_with_metadata(filter_fields, HashMap::new());
+
+            let rewritten_physical_expr = rewrite_physical_expr(
+                physical_expr,
+                left_fields.len(),
+                right_fields.len(),
+                &left_field_indices,
+                &right_field_indices,
+            )?;
+
+            Some(JoinFilter::new(
+                rewritten_physical_expr,
+                column_indices,
+                filter_schema,
+            ))
+        } else {
+            None
+        };
+
+        // DataFusion Join operators keep the input batch internally. We need
+        // to copy the input batch to avoid the data corruption from reusing the input
+        // batch.
+        let left = if can_reuse_input_batch(&left) {
+            Arc::new(CopyExec::new(left))
+        } else {
+            left
+        };
+
+        let right = if can_reuse_input_batch(&right) {
+            Arc::new(CopyExec::new(right))
+        } else {
+            right
+        };
+
+        Ok((
+            JoinParameters {
+                left,
+                right,
+                join_on,
+                join_type,
+                join_filter,
+            },
+            left_scans,
+        ))
+    }
+
/// Create a DataFusion physical aggregate expression from Spark physical
aggregate expression
fn create_agg_expr(
&self,
@@ -1143,6 +1263,133 @@ fn can_reuse_input_batch(op: &Arc<dyn ExecutionPlan>)
-> bool {
|| op.as_any().downcast_ref::<FilterExec>().is_some()
}
+/// Splits the columns referenced by a join-filter expression into left-side and
+/// right-side indices.
+///
+/// `expr` is bound to the concatenated schema `left ++ right`: indices
+/// `0..left_field_len` refer to the left input and
+/// `left_field_len..left_field_len + right_field_len` to the right input. Right-side
+/// indices are returned relative to the right input's own schema.
+///
+/// Both returned vectors are sorted ascending and hold each index once, even when a
+/// column is referenced several times in `expr`.
+///
+/// # Errors
+/// Returns an error if `expr` references a column index outside the combined schema.
+fn expr_to_columns(
+    expr: &Arc<dyn PhysicalExpr>,
+    left_field_len: usize,
+    right_field_len: usize,
+) -> Result<(Vec<usize>, Vec<usize>), ExecutionError> {
+    let mut left_field_indices: Vec<usize> = vec![];
+    let mut right_field_indices: Vec<usize> = vec![];
+
+    expr.apply(&mut |node| {
+        if let Some(column) = node.as_any().downcast_ref::<Column>() {
+            let index = column.index();
+            // Valid indices are `0..left_field_len + right_field_len` (exclusive upper
+            // bound), so an index equal to the combined length is already out of range.
+            if index >= left_field_len + right_field_len {
+                return Err(DataFusionError::Internal(format!(
+                    "Column index {} out of range",
+                    index
+                )));
+            }
+            if index < left_field_len {
+                left_field_indices.push(index);
+            } else {
+                right_field_indices.push(index - left_field_len);
+            }
+        }
+        Ok(VisitRecursion::Continue)
+    })?;
+
+    // A column may occur several times in the filter (e.g. `a > b AND a < 10`). Keep it
+    // once so the intermediate filter schema does not contain duplicate fields.
+    left_field_indices.sort_unstable();
+    left_field_indices.dedup();
+    right_field_indices.sort_unstable();
+    right_field_indices.dedup();
+
+    Ok((left_field_indices, right_field_indices))
+}
+
+/// Tree rewriter that rebinds a join-filter expression from the concatenated
+/// `left ++ right` schema onto the intermediate schema of a DataFusion `JoinFilter`.
+/// See `rewrite_physical_expr`.
+struct JoinFilterRewriter<'a> {
+    /// Number of fields in the left input's schema.
+    left_field_len: usize,
+    /// Number of fields in the right input's schema.
+    right_field_len: usize,
+    /// Left-input column indices, in their order within the intermediate schema.
+    left_field_indices: &'a [usize],
+    /// Right-input column indices (relative to the right schema), in their order within
+    /// the intermediate schema, where they follow all left columns.
+    right_field_indices: &'a [usize],
+}
+
+impl<'a> JoinFilterRewriter<'a> {
+    /// Creates a rewriter for a filter bound to the concatenated `left ++ right` schema.
+    ///
+    /// `left_field_indices` / `right_field_indices` list the columns of each side that
+    /// make up the intermediate schema, in the order they appear there; right-side
+    /// indices are relative to the right schema.
+    fn new(
+        left_field_len: usize,
+        right_field_len: usize,
+        left_field_indices: &'a [usize],
+        right_field_indices: &'a [usize],
+    ) -> Self {
+        Self {
+            left_field_len,
+            right_field_len,
+            left_field_indices,
+            right_field_indices,
+        }
+    }
+}
+
+impl TreeNodeRewriter for JoinFilterRewriter<'_> {
+    type N = Arc<dyn PhysicalExpr>;
+
+    /// Rebinds a single `Column` node onto the intermediate filter schema; every other
+    /// node is returned unchanged.
+    ///
+    /// Left-side columns map to their position in `left_field_indices`. Right-side
+    /// columns map to their position in `right_field_indices`, offset by the number of
+    /// left columns. Errors if a column is not listed or its index is out of range.
+    fn mutate(&mut self, node: Self::N) -> datafusion_common::Result<Self::N> {
+        // Only column references carry indices; other nodes are already owned, so hand
+        // them back without cloning.
+        let column = match node.as_any().downcast_ref::<Column>() {
+            Some(column) => column,
+            None => return Ok(node),
+        };
+        let index = column.index();
+
+        let new_index = if index < self.left_field_len {
+            // left side
+            self.left_field_indices
+                .iter()
+                .position(|&x| x == index)
+                .ok_or_else(|| {
+                    DataFusionError::Internal(format!(
+                        "Column index {} not found in left field indices",
+                        index
+                    ))
+                })?
+        } else if index < self.left_field_len + self.right_field_len {
+            // right side: these columns follow all left columns in the filter schema
+            let position = self
+                .right_field_indices
+                .iter()
+                .position(|&x| x + self.left_field_len == index)
+                .ok_or_else(|| {
+                    DataFusionError::Internal(format!(
+                        "Column index {} not found in right field indices",
+                        index
+                    ))
+                })?;
+            position + self.left_field_indices.len()
+        } else {
+            return Err(DataFusionError::Internal(format!(
+                "Column index {} out of range",
+                index
+            )));
+        };
+
+        Ok(Arc::new(Column::new(column.name(), new_index)))
+    }
+}
+
+/// Rebinds every `Column` in `expr` from the concatenated `left ++ right` schema onto
+/// the intermediate schema of a DataFusion `JoinFilter`. That schema holds the
+/// referenced left columns followed by the referenced right columns.
+///
+/// The work is delegated to `JoinFilterRewriter`; any error it reports is converted
+/// into an `ExecutionError`.
+fn rewrite_physical_expr(
+    expr: Arc<dyn PhysicalExpr>,
+    left_field_len: usize,
+    right_field_len: usize,
+    left_field_indices: &[usize],
+    right_field_indices: &[usize],
+) -> Result<Arc<dyn PhysicalExpr>, ExecutionError> {
+    let rewritten = expr.rewrite(&mut JoinFilterRewriter::new(
+        left_field_len,
+        right_field_len,
+        left_field_indices,
+        right_field_indices,
+    ))?;
+    Ok(rewritten)
+}
+
#[cfg(test)]
mod tests {
use std::{sync::Arc, task::Poll};
diff --git a/core/src/execution/proto/operator.proto
b/core/src/execution/proto/operator.proto
index 0b7888d..6080c56 100644
--- a/core/src/execution/proto/operator.proto
+++ b/core/src/execution/proto/operator.proto
@@ -41,6 +41,7 @@ message Operator {
ShuffleWriter shuffle_writer = 106;
Expand expand = 107;
SortMergeJoin sort_merge_join = 108;
+ HashJoin hash_join = 109;
}
}
@@ -89,6 +90,13 @@ message Expand {
int32 num_expr_per_project = 3;
}
+// Native hash join. `left_join_keys` and `right_join_keys` are paired positionally:
+// the i-th left key is compared with the i-th right key.
+message HashJoin {
+ // Join keys bound to the left child's output.
+ repeated spark.spark_expression.Expr left_join_keys = 1;
+ // Join keys bound to the right child's output.
+ repeated spark.spark_expression.Expr right_join_keys = 2;
+ // Spark join type; the native planner rejects values it does not recognize.
+ JoinType join_type = 3;
+ // Optional extra join condition, bound to the left output followed by the right output.
+ optional spark.spark_expression.Expr condition = 4;
+}
+
message SortMergeJoin {
repeated spark.spark_expression.Expr left_join_keys = 1;
repeated spark.spark_expression.Expr right_join_keys = 2;
diff --git
a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
index 6199f47..1380ee9 100644
--- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -38,7 +38,7 @@ import
org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec,
ReusedExchangeExec, ShuffleExchangeExec}
-import org.apache.spark.sql.execution.joins.SortMergeJoinExec
+import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec,
SortMergeJoinExec}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -335,6 +335,27 @@ class CometSparkSessionExtensions
op
}
+ case op: ShuffledHashJoinExec
+ if isCometOperatorEnabled(conf, "hash_join") &&
+ op.children.forall(isCometNative(_)) =>
+ val newOp = transform1(op)
+ newOp match {
+ case Some(nativeOp) =>
+ CometHashJoinExec(
+ nativeOp,
+ op,
+ op.leftKeys,
+ op.rightKeys,
+ op.joinType,
+ op.condition,
+ op.buildSide,
+ op.left,
+ op.right,
+ SerializedPlan(None))
+ case None =>
+ op
+ }
+
case op: SortMergeJoinExec
if isCometOperatorEnabled(conf, "sort_merge_join") &&
op.children.forall(isCometNative(_)) =>
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 6c04fe3..bf2510b 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -25,7 +25,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions._
import
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression,
Average, BitAndAgg, BitOrAgg, BitXorAgg, Count, Final, First, Last, Max, Min,
Partial, Sum}
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
-import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
+import org.apache.spark.sql.catalyst.optimizer.{BuildRight,
NormalizeNaNAndZero}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning,
Partitioning, SinglePartition}
import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils
@@ -36,7 +36,7 @@ import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec,
ReusedExchangeExec, ShuffleExchangeExec}
-import org.apache.spark.sql.execution.joins.SortMergeJoinExec
+import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec,
SortMergeJoinExec}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -1915,6 +1915,48 @@ object QueryPlanSerde extends Logging with
ShimQueryPlanSerde {
}
}
+ case join: ShuffledHashJoinExec if isCometOperatorEnabled(op.conf,
"hash_join") =>
+ if (join.buildSide == BuildRight) {
+ // DataFusion HashJoin assumes build side is always left.
+ // TODO: support BuildRight
+ return None
+ }
+
+ val condition = join.condition.map { cond =>
+ val condProto = exprToProto(cond, join.left.output ++
join.right.output)
+ if (condProto.isEmpty) {
+ return None
+ }
+ condProto.get
+ }
+
+ val joinType = join.joinType match {
+ case Inner => JoinType.Inner
+ case LeftOuter => JoinType.LeftOuter
+ case RightOuter => JoinType.RightOuter
+ case FullOuter => JoinType.FullOuter
+ case LeftSemi => JoinType.LeftSemi
+ case LeftAnti => JoinType.LeftAnti
+ case _ => return None // Spark doesn't support other join types
+ }
+
+ val leftKeys = join.leftKeys.map(exprToProto(_, join.left.output))
+ val rightKeys = join.rightKeys.map(exprToProto(_, join.right.output))
+
+ if (leftKeys.forall(_.isDefined) &&
+ rightKeys.forall(_.isDefined) &&
+ childOp.nonEmpty) {
+ val joinBuilder = OperatorOuterClass.HashJoin
+ .newBuilder()
+ .setJoinType(joinType)
+ .addAllLeftJoinKeys(leftKeys.map(_.get).asJava)
+ .addAllRightJoinKeys(rightKeys.map(_.get).asJava)
+ condition.foreach(joinBuilder.setCondition)
+ Some(result.setHashJoin(joinBuilder).build())
+ } else {
+ None
+ }
+
case join: SortMergeJoinExec if isCometOperatorEnabled(op.conf,
"sort_merge_join") =>
// `requiredOrders` and `getKeyOrdering` are copied from Spark's
SortMergeJoinExec.
def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = {
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index 501ffd5..fb300a3 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -31,6 +31,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet,
Expression, NamedExpression, SortOrder}
import
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression,
AggregateMode}
+import org.apache.spark.sql.catalyst.optimizer.BuildSide
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.comet.execution.shuffle.{ArrowReaderIterator,
CometShuffleExchangeExec}
@@ -587,6 +588,43 @@ case class CometHashAggregateExec(
Objects.hashCode(groupingExpressions, aggregateExpressions, input, mode,
child)
}
+/**
+ * Comet replacement for Spark's `ShuffledHashJoinExec`, executed natively by DataFusion's
+ * `HashJoinExec`.
+ *
+ * `nativeOp` is the serialized native operator and `originalPlan` the Spark operator it
+ * replaces. The join fields mirror the original operator and are used for plan display
+ * and equality.
+ */
+case class CometHashJoinExec(
+    override val nativeOp: Operator,
+    override val originalPlan: SparkPlan,
+    leftKeys: Seq[Expression],
+    rightKeys: Seq[Expression],
+    joinType: JoinType,
+    condition: Option[Expression],
+    buildSide: BuildSide,
+    override val left: SparkPlan,
+    override val right: SparkPlan,
+    override val serializedPlanOpt: SerializedPlan)
+    extends CometBinaryExec {
+  override def withNewChildrenInternal(newLeft: SparkPlan, newRight: SparkPlan): SparkPlan =
+    this.copy(left = newLeft, right = newRight)
+
+  override def stringArgs: Iterator[Any] =
+    Iterator(leftKeys, rightKeys, joinType, condition, left, right)
+
+  // `joinType` must take part in equality: two joins with the same keys, condition and
+  // children but different join types (e.g. inner vs. full outer) produce different
+  // results and must not be treated as the same plan.
+  override def equals(obj: Any): Boolean = {
+    obj match {
+      case other: CometHashJoinExec =>
+        this.leftKeys == other.leftKeys &&
+        this.rightKeys == other.rightKeys &&
+        this.joinType == other.joinType &&
+        this.condition == other.condition &&
+        this.buildSide == other.buildSide &&
+        this.left == other.left &&
+        this.right == other.right &&
+        this.serializedPlanOpt == other.serializedPlanOpt
+      case _ =>
+        false
+    }
+  }
+
+  override def hashCode(): Int =
+    Objects.hashCode(leftKeys, rightKeys, joinType, condition, left, right)
+}
+
case class CometSortMergeJoinExec(
override val nativeOp: Operator,
override val originalPlan: SparkPlan,
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
index 73ce0e1..a64ec87 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
@@ -38,6 +38,92 @@ class CometJoinSuite extends CometTestBase {
}
}
+  test("HashJoin without join filter") {
+    withSQLConf(
+      SQLConf.PREFER_SORTMERGEJOIN.key -> "false",
+      SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      withParquetTable((0 until 10).map(i => (i, i % 5)), "tbl_a") {
+        withParquetTable((0 until 10).map(i => (i % 10, i + 2)), "tbl_b") {
+          // The hint makes `tbl_a` (the left side) the build side. Checked in order:
+          // inner join, right outer join, full outer join.
+          for (joinClause <- Seq("JOIN", "RIGHT JOIN", "FULL JOIN")) {
+            val query =
+              s"SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a $joinClause tbl_b " +
+                "ON tbl_a._2 = tbl_b._1"
+            checkSparkAnswerAndOperator(sql(query))
+          }
+
+          // TODO: Spark 3.4 returns SortMergeJoin for this query even with SHUFFLE_HASH hint.
+          // Left join with build left and right join with build right in hash join is only
+          // supported in Spark 3.5 or above. See SPARK-36612.
+          //
+          // Left join: build left
+          // sql("SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a LEFT JOIN tbl_b ON tbl_a._2 = tbl_b._1")
+
+          // TODO: DataFusion HashJoin doesn't support build right yet.
+          // Inner join: build right
+          // sql("SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_a JOIN tbl_b ON tbl_a._2 = tbl_b._1")
+          //
+          // Left join: build right
+          // sql("SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_a LEFT JOIN tbl_b ON tbl_a._2 = tbl_b._1")
+          //
+          // Right join: build right
+          // sql("SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_a RIGHT JOIN tbl_b ON tbl_a._2 = tbl_b._1")
+          //
+          // Full join: build right
+          // sql("SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_a FULL JOIN tbl_b ON tbl_a._2 = tbl_b._1")
+          //
+          // val left = sql("SELECT * FROM tbl_a")
+          // val right = sql("SELECT * FROM tbl_b")
+          //
+          // Left semi and anti joins are only supported with build right in Spark.
+          // left.join(right, left("_2") === right("_1"), "leftsemi")
+          // left.join(right, left("_2") === right("_1"), "leftanti")
+        }
+      }
+    }
+  }
+
+  test("HashJoin with join filter") {
+    withSQLConf(
+      SQLConf.PREFER_SORTMERGEJOIN.key -> "false",
+      SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      withParquetTable((0 until 10).map(i => (i, i % 5)), "tbl_a") {
+        withParquetTable((0 until 10).map(i => (i % 10, i + 2)), "tbl_b") {
+          // Equi-join on `_2 = _1` plus a non-equi condition, building on `tbl_a` (left).
+          // Checked in order: inner join, right outer join, full outer join.
+          Seq("JOIN", "RIGHT JOIN", "FULL JOIN").foreach { joinClause =>
+            val df = sql(
+              s"SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a $joinClause tbl_b " +
+                "ON tbl_a._2 = tbl_b._1 AND tbl_a._1 > tbl_b._2")
+            checkSparkAnswerAndOperator(df)
+          }
+        }
+      }
+    }
+  }
+
// TODO: Add a test for SortMergeJoin with join filter after new DataFusion
release
test("SortMergeJoin without join filter") {
withSQLConf(