This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new ce38812  feat: Support HashJoin operator (#194)
ce38812 is described below

commit ce3881245ac938092e3fc38285713b8f88abc375
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Tue Mar 19 22:11:24 2024 -0700

    feat: Support HashJoin operator (#194)
    
    * feat: Support HashJoin
    
    * Add comment
    
    * Clean up test
    
    * Fix join filter
    
    * Fix clippy
    
    * Use consistent function with sort merge join
    
    * Add note about left semi and left anti joins
    
    * For review
    
    * Merging
    
    * Move tests
    
    * Add a function to parse join parameters
---
 core/src/execution/datafusion/planner.rs           | 373 +++++++++++++++++----
 core/src/execution/proto/operator.proto            |   8 +
 .../apache/comet/CometSparkSessionExtensions.scala |  23 +-
 .../org/apache/comet/serde/QueryPlanSerde.scala    |  46 ++-
 .../org/apache/spark/sql/comet/operators.scala     |  38 +++
 .../org/apache/comet/exec/CometJoinSuite.scala     |  86 +++++
 6 files changed, 508 insertions(+), 66 deletions(-)

diff --git a/core/src/execution/datafusion/planner.rs 
b/core/src/execution/datafusion/planner.rs
index eda8692..c8869c5 100644
--- a/core/src/execution/datafusion/planner.rs
+++ b/core/src/execution/datafusion/planner.rs
@@ -17,7 +17,7 @@
 
 //! Converts Spark physical plan to DataFusion physical plan
 
-use std::{str::FromStr, sync::Arc};
+use std::{collections::HashMap, str::FromStr, sync::Arc};
 
 use arrow_schema::{DataType, Field, Schema, TimeUnit};
 use datafusion::{
@@ -37,14 +37,17 @@ use datafusion::{
     physical_plan::{
         aggregates::{AggregateMode as DFAggregateMode, PhysicalGroupBy},
         filter::FilterExec,
-        joins::SortMergeJoinExec,
+        joins::{utils::JoinFilter, HashJoinExec, PartitionMode, 
SortMergeJoinExec},
         limit::LocalLimitExec,
         projection::ProjectionExec,
         sorts::sort::SortExec,
         ExecutionPlan, Partitioning,
     },
 };
-use datafusion_common::{JoinType as DFJoinType, ScalarValue};
+use datafusion_common::{
+    tree_node::{TreeNode, TreeNodeRewriter, VisitRecursion},
+    JoinType as DFJoinType, ScalarValue,
+};
 use itertools::Itertools;
 use jni::objects::GlobalRef;
 use num::{BigInt, ToPrimitive};
@@ -89,6 +92,14 @@ type PhyAggResult = Result<Vec<Arc<dyn AggregateExpr>>, 
ExecutionError>;
 type PhyExprResult = Result<Vec<(Arc<dyn PhysicalExpr>, String)>, 
ExecutionError>;
 type PartitionPhyExprResult = Result<Vec<Arc<dyn PhysicalExpr>>, 
ExecutionError>;
 
+struct JoinParameters {
+    pub left: Arc<dyn ExecutionPlan>,
+    pub right: Arc<dyn ExecutionPlan>,
+    pub join_on: Vec<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)>,
+    pub join_filter: Option<JoinFilter>,
+    pub join_type: DFJoinType,
+}
+
 pub const TEST_EXEC_CONTEXT_ID: i64 = -1;
 
 /// The query planner for converting Spark query plans to DataFusion query 
plans.
@@ -870,50 +881,22 @@ impl PhysicalPlanner {
                 ))
             }
             OpStruct::SortMergeJoin(join) => {
-                assert!(children.len() == 2);
-                let (mut left_scans, left) = self.create_plan(&children[0], 
inputs)?;
-                let (mut right_scans, right) = self.create_plan(&children[1], 
inputs)?;
-
-                left_scans.append(&mut right_scans);
-
-                let left_join_exprs = join
-                    .left_join_keys
-                    .iter()
-                    .map(|expr| self.create_expr(expr, left.schema()))
-                    .collect::<Result<Vec<_>, _>>()?;
-                let right_join_exprs = join
-                    .right_join_keys
-                    .iter()
-                    .map(|expr| self.create_expr(expr, right.schema()))
-                    .collect::<Result<Vec<_>, _>>()?;
-
-                let join_on = left_join_exprs
-                    .into_iter()
-                    .zip(right_join_exprs)
-                    .collect::<Vec<_>>();
-
-                let join_type = match join.join_type.try_into() {
-                    Ok(JoinType::Inner) => DFJoinType::Inner,
-                    Ok(JoinType::LeftOuter) => DFJoinType::Left,
-                    Ok(JoinType::RightOuter) => DFJoinType::Right,
-                    Ok(JoinType::FullOuter) => DFJoinType::Full,
-                    Ok(JoinType::LeftSemi) => DFJoinType::LeftSemi,
-                    Ok(JoinType::RightSemi) => DFJoinType::RightSemi,
-                    Ok(JoinType::LeftAnti) => DFJoinType::LeftAnti,
-                    Ok(JoinType::RightAnti) => DFJoinType::RightAnti,
-                    Err(_) => {
-                        return Err(ExecutionError::GeneralError(format!(
-                            "Unsupported join type: {:?}",
-                            join.join_type
-                        )));
-                    }
-                };
+                let (join_params, scans) = self.parse_join_parameters(
+                    inputs,
+                    children,
+                    &join.left_join_keys,
+                    &join.right_join_keys,
+                    join.join_type,
+                    &None,
+                )?;
 
                 let sort_options = join
                     .sort_options
                     .iter()
                     .map(|sort_option| {
-                        let sort_expr = self.create_sort_expr(sort_option, 
left.schema()).unwrap();
+                        let sort_expr = self
+                            .create_sort_expr(sort_option, 
join_params.left.schema())
+                            .unwrap();
                         SortOptions {
                             descending: sort_expr.options.descending,
                             nulls_first: sort_expr.options.nulls_first,
@@ -921,38 +904,175 @@ impl PhysicalPlanner {
                     })
                     .collect();
 
-                // DataFusion `SortMergeJoinExec` operator keeps the input 
batch internally. We need
-                // to copy the input batch to avoid the data corruption from 
reusing the input
-                // batch.
-                let left = if can_reuse_input_batch(&left) {
-                    Arc::new(CopyExec::new(left))
-                } else {
-                    left
-                };
-
-                let right = if can_reuse_input_batch(&right) {
-                    Arc::new(CopyExec::new(right))
-                } else {
-                    right
-                };
-
                 let join = Arc::new(SortMergeJoinExec::try_new(
-                    left,
-                    right,
-                    join_on,
-                    None,
-                    join_type,
+                    join_params.left,
+                    join_params.right,
+                    join_params.join_on,
+                    join_params.join_filter,
+                    join_params.join_type,
                     sort_options,
                     // null doesn't equal to null in Spark join key. If the 
join key is
                     // `EqualNullSafe`, Spark will rewrite it during planning.
                     false,
                 )?);
 
-                Ok((left_scans, join))
+                Ok((scans, join))
+            }
+            OpStruct::HashJoin(join) => {
+                let (join_params, scans) = self.parse_join_parameters(
+                    inputs,
+                    children,
+                    &join.left_join_keys,
+                    &join.right_join_keys,
+                    join.join_type,
+                    &join.condition,
+                )?;
+                let join = Arc::new(HashJoinExec::try_new(
+                    join_params.left,
+                    join_params.right,
+                    join_params.join_on,
+                    join_params.join_filter,
+                    &join_params.join_type,
+                    PartitionMode::Partitioned,
+                    // null doesn't equal to null in Spark join key. If the 
join key is
+                    // `EqualNullSafe`, Spark will rewrite it during planning.
+                    false,
+                )?);
+                Ok((scans, join))
             }
         }
     }
 
+    fn parse_join_parameters(
+        &self,
+        inputs: &mut Vec<Arc<GlobalRef>>,
+        children: &[Operator],
+        left_join_keys: &[Expr],
+        right_join_keys: &[Expr],
+        join_type: i32,
+        condition: &Option<Expr>,
+    ) -> Result<(JoinParameters, Vec<ScanExec>), ExecutionError> {
+        assert!(children.len() == 2);
+        let (mut left_scans, left) = self.create_plan(&children[0], inputs)?;
+        let (mut right_scans, right) = self.create_plan(&children[1], inputs)?;
+
+        left_scans.append(&mut right_scans);
+
+        let left_join_exprs: Vec<_> = left_join_keys
+            .iter()
+            .map(|expr| self.create_expr(expr, left.schema()))
+            .collect::<Result<Vec<_>, _>>()?;
+        let right_join_exprs: Vec<_> = right_join_keys
+            .iter()
+            .map(|expr| self.create_expr(expr, right.schema()))
+            .collect::<Result<Vec<_>, _>>()?;
+
+        let join_on = left_join_exprs
+            .into_iter()
+            .zip(right_join_exprs)
+            .collect::<Vec<_>>();
+
+        let join_type = match join_type.try_into() {
+            Ok(JoinType::Inner) => DFJoinType::Inner,
+            Ok(JoinType::LeftOuter) => DFJoinType::Left,
+            Ok(JoinType::RightOuter) => DFJoinType::Right,
+            Ok(JoinType::FullOuter) => DFJoinType::Full,
+            Ok(JoinType::LeftSemi) => DFJoinType::LeftSemi,
+            Ok(JoinType::RightSemi) => DFJoinType::RightSemi,
+            Ok(JoinType::LeftAnti) => DFJoinType::LeftAnti,
+            Ok(JoinType::RightAnti) => DFJoinType::RightAnti,
+            Err(_) => {
+                return Err(ExecutionError::GeneralError(format!(
+                    "Unsupported join type: {:?}",
+                    join_type
+                )));
+            }
+        };
+
+        // Handle join filter as DataFusion `JoinFilter` struct
+        let join_filter = if let Some(expr) = condition {
+            let left_schema = left.schema();
+            let right_schema = right.schema();
+            let left_fields = left_schema.fields();
+            let right_fields = right_schema.fields();
+            let all_fields: Vec<_> = left_fields
+                .into_iter()
+                .chain(right_fields)
+                .cloned()
+                .collect();
+            let full_schema = Arc::new(Schema::new(all_fields));
+
+            let physical_expr = self.create_expr(expr, full_schema)?;
+            let (left_field_indices, right_field_indices) =
+                expr_to_columns(&physical_expr, left_fields.len(), 
right_fields.len())?;
+            let column_indices = JoinFilter::build_column_indices(
+                left_field_indices.clone(),
+                right_field_indices.clone(),
+            );
+
+            let filter_fields: Vec<Field> = left_field_indices
+                .clone()
+                .into_iter()
+                .map(|i| left.schema().field(i).clone())
+                .chain(
+                    right_field_indices
+                        .clone()
+                        .into_iter()
+                        .map(|i| right.schema().field(i).clone()),
+                )
+                .collect_vec();
+
+            let filter_schema = Schema::new_with_metadata(filter_fields, 
HashMap::new());
+
+            // Rewrite the physical expression to use the new column indices.
+            // DataFusion's join filter is bound to intermediate schema which 
contains
+            // only the fields used in the filter expression. But the Spark's 
join filter
+            // expression is bound to the full schema. We need to rewrite the 
physical
+            // expression to use the new column indices.
+            let rewritten_physical_expr = rewrite_physical_expr(
+                physical_expr,
+                left_schema.fields.len(),
+                right_schema.fields.len(),
+                &left_field_indices,
+                &right_field_indices,
+            )?;
+
+            Some(JoinFilter::new(
+                rewritten_physical_expr,
+                column_indices,
+                filter_schema,
+            ))
+        } else {
+            None
+        };
+
+        // DataFusion Join operators keep the input batch internally. We need
+        // to copy the input batch to avoid the data corruption from reusing 
the input
+        // batch.
+        let left = if can_reuse_input_batch(&left) {
+            Arc::new(CopyExec::new(left))
+        } else {
+            left
+        };
+
+        let right = if can_reuse_input_batch(&right) {
+            Arc::new(CopyExec::new(right))
+        } else {
+            right
+        };
+
+        Ok((
+            JoinParameters {
+                left,
+                right,
+                join_on,
+                join_type,
+                join_filter,
+            },
+            left_scans,
+        ))
+    }
+
     /// Create a DataFusion physical aggregate expression from Spark physical 
aggregate expression
     fn create_agg_expr(
         &self,
@@ -1143,6 +1263,133 @@ fn can_reuse_input_batch(op: &Arc<dyn ExecutionPlan>) 
-> bool {
         || op.as_any().downcast_ref::<FilterExec>().is_some()
 }
 
+/// Collects the indices of the columns in the input schema that are used in the expression
+/// and returns them as a pair of vectors, one for the left side and one for the right side.
+///
+/// `expr` is bound to the concatenation of the left schema (`left_field_len` fields) followed by
+/// the right schema (`right_field_len` fields). Column indices below `left_field_len` are
+/// returned in the left vector unchanged; the remaining indices are returned in the right vector
+/// relative to the start of the right schema. Both vectors are sorted and deduplicated, so a
+/// column referenced several times in the expression appears only once.
+///
+/// Returns an error if a column index is `>= left_field_len + right_field_len`, i.e. falls
+/// outside the combined schema.
+fn expr_to_columns(
+    expr: &Arc<dyn PhysicalExpr>,
+    left_field_len: usize,
+    right_field_len: usize,
+) -> Result<(Vec<usize>, Vec<usize>), ExecutionError> {
+    let mut left_field_indices: Vec<usize> = vec![];
+    let mut right_field_indices: Vec<usize> = vec![];
+
+    expr.apply(&mut |expr| {
+        Ok({
+            if let Some(column) = expr.as_any().downcast_ref::<Column>() {
+                // Valid indices are `0..left_field_len + right_field_len`; an index equal to the
+                // combined length is already out of range.
+                if column.index() >= left_field_len + right_field_len {
+                    return Err(DataFusionError::Internal(format!(
+                        "Column index {} out of range",
+                        column.index()
+                    )));
+                } else if column.index() < left_field_len {
+                    left_field_indices.push(column.index());
+                } else {
+                    right_field_indices.push(column.index() - left_field_len);
+                }
+            }
+            VisitRecursion::Continue
+        })
+    })?;
+
+    // A column can be visited more than once (e.g. `a > b AND a < 10`). Deduplicating keeps the
+    // intermediate join-filter schema free of repeated fields.
+    left_field_indices.sort_unstable();
+    left_field_indices.dedup();
+    right_field_indices.sort_unstable();
+    right_field_indices.dedup();
+
+    Ok((left_field_indices, right_field_indices))
+}
+
+/// A physical join filter rewriter which rewrites the column indices in the expression
+/// to use the new column indices. See `rewrite_physical_expr`.
+struct JoinFilterRewriter<'a> {
+    /// Number of fields in the left input schema.
+    left_field_len: usize,
+    /// Number of fields in the right input schema.
+    right_field_len: usize,
+    /// Left-side column indices referenced by the filter. A column's position in this slice
+    /// becomes its index in the intermediate filter schema.
+    left_field_indices: &'a [usize],
+    /// Right-side column indices referenced by the filter, relative to the right schema. They
+    /// are placed after all left-side columns in the intermediate filter schema.
+    right_field_indices: &'a [usize],
+}
+
+impl JoinFilterRewriter<'_> {
+    fn new<'a>(
+        left_field_len: usize,
+        right_field_len: usize,
+        left_field_indices: &'a [usize],
+        right_field_indices: &'a [usize],
+    ) -> JoinFilterRewriter<'a> {
+        JoinFilterRewriter {
+            left_field_len,
+            right_field_len,
+            left_field_indices,
+            right_field_indices,
+        }
+    }
+}
+
+impl TreeNodeRewriter for JoinFilterRewriter<'_> {
+    type N = Arc<dyn PhysicalExpr>;
+
+    fn mutate(&mut self, node: Self::N) -> datafusion_common::Result<Self::N> {
+        let new_expr: Arc<dyn PhysicalExpr> =
+            if let Some(column) = node.as_any().downcast_ref::<Column>() {
+                if column.index() < self.left_field_len {
+                    // left side
+                    let new_index = self
+                        .left_field_indices
+                        .iter()
+                        .position(|&x| x == column.index())
+                        .ok_or_else(|| {
+                            DataFusionError::Internal(format!(
+                                "Column index {} not found in left field 
indices",
+                                column.index()
+                            ))
+                        })?;
+                    Arc::new(Column::new(column.name(), new_index))
+                } else if column.index() < self.left_field_len + 
self.right_field_len {
+                    // right side
+                    let new_index = self
+                        .right_field_indices
+                        .iter()
+                        .position(|&x| x + self.left_field_len == 
column.index())
+                        .ok_or_else(|| {
+                            DataFusionError::Internal(format!(
+                                "Column index {} not found in right field 
indices",
+                                column.index()
+                            ))
+                        })?;
+                    Arc::new(Column::new(
+                        column.name(),
+                        new_index + self.left_field_indices.len(),
+                    ))
+                } else {
+                    return Err(DataFusionError::Internal(format!(
+                        "Column index {} out of range",
+                        column.index()
+                    )));
+                }
+            } else {
+                node.clone()
+            };
+        Ok(new_expr)
+    }
+}
+
+/// Rewrites the physical expression to use the new column indices.
+/// This is necessary when the physical expression is used in a join filter, 
as the column
+/// indices are different from the original schema.
+fn rewrite_physical_expr(
+    expr: Arc<dyn PhysicalExpr>,
+    left_field_len: usize,
+    right_field_len: usize,
+    left_field_indices: &[usize],
+    right_field_indices: &[usize],
+) -> Result<Arc<dyn PhysicalExpr>, ExecutionError> {
+    let mut rewriter = JoinFilterRewriter::new(
+        left_field_len,
+        right_field_len,
+        left_field_indices,
+        right_field_indices,
+    );
+
+    Ok(expr.rewrite(&mut rewriter)?)
+}
+
 #[cfg(test)]
 mod tests {
     use std::{sync::Arc, task::Poll};
diff --git a/core/src/execution/proto/operator.proto 
b/core/src/execution/proto/operator.proto
index 0b7888d..6080c56 100644
--- a/core/src/execution/proto/operator.proto
+++ b/core/src/execution/proto/operator.proto
@@ -41,6 +41,7 @@ message Operator {
     ShuffleWriter shuffle_writer = 106;
     Expand expand = 107;
     SortMergeJoin sort_merge_join = 108;
+    HashJoin hash_join = 109;
   }
 }
 
@@ -89,6 +90,13 @@ message Expand {
   int32 num_expr_per_project = 3;
 }
 
+message HashJoin {
+  repeated spark.spark_expression.Expr left_join_keys = 1;
+  repeated spark.spark_expression.Expr right_join_keys = 2;
+  JoinType join_type = 3;
+  optional spark.spark_expression.Expr condition = 4;
+}
+
 message SortMergeJoin {
   repeated spark.spark_expression.Expr left_join_keys = 1;
   repeated spark.spark_expression.Expr right_join_keys = 2;
diff --git 
a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala 
b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
index 6199f47..1380ee9 100644
--- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -38,7 +38,7 @@ import 
org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
 import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, 
ReusedExchangeExec, ShuffleExchangeExec}
-import org.apache.spark.sql.execution.joins.SortMergeJoinExec
+import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, 
SortMergeJoinExec}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
@@ -335,6 +335,27 @@ class CometSparkSessionExtensions
               op
           }
 
+        case op: ShuffledHashJoinExec
+            if isCometOperatorEnabled(conf, "hash_join") &&
+              op.children.forall(isCometNative(_)) =>
+          val newOp = transform1(op)
+          newOp match {
+            case Some(nativeOp) =>
+              CometHashJoinExec(
+                nativeOp,
+                op,
+                op.leftKeys,
+                op.rightKeys,
+                op.joinType,
+                op.condition,
+                op.buildSide,
+                op.left,
+                op.right,
+                SerializedPlan(None))
+            case None =>
+              op
+          }
+
         case op: SortMergeJoinExec
             if isCometOperatorEnabled(conf, "sort_merge_join") &&
               op.children.forall(isCometNative(_)) =>
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala 
b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 6c04fe3..bf2510b 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -25,7 +25,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions._
 import 
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, 
Average, BitAndAgg, BitOrAgg, BitXorAgg, Count, Final, First, Last, Max, Min, 
Partial, Sum}
 import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
-import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
+import org.apache.spark.sql.catalyst.optimizer.{BuildRight, 
NormalizeNaNAndZero}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, 
Partitioning, SinglePartition}
 import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils
@@ -36,7 +36,7 @@ import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, 
ReusedExchangeExec, ShuffleExchangeExec}
-import org.apache.spark.sql.execution.joins.SortMergeJoinExec
+import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, 
SortMergeJoinExec}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -1915,6 +1915,48 @@ object QueryPlanSerde extends Logging with 
ShimQueryPlanSerde {
           }
         }
 
+      case join: ShuffledHashJoinExec if isCometOperatorEnabled(op.conf, 
"hash_join") =>
+        if (join.buildSide == BuildRight) {
+          // DataFusion HashJoin assumes build side is always left.
+          // TODO: support BuildRight
+          return None
+        }
+
+        val condition = join.condition.map { cond =>
+          val condProto = exprToProto(cond, join.left.output ++ 
join.right.output)
+          if (condProto.isEmpty) {
+            return None
+          }
+          condProto.get
+        }
+
+        val joinType = join.joinType match {
+          case Inner => JoinType.Inner
+          case LeftOuter => JoinType.LeftOuter
+          case RightOuter => JoinType.RightOuter
+          case FullOuter => JoinType.FullOuter
+          case LeftSemi => JoinType.LeftSemi
+          case LeftAnti => JoinType.LeftAnti
+          case _ => return None // Spark doesn't support other join types
+        }
+
+        val leftKeys = join.leftKeys.map(exprToProto(_, join.left.output))
+        val rightKeys = join.rightKeys.map(exprToProto(_, join.right.output))
+
+        if (leftKeys.forall(_.isDefined) &&
+          rightKeys.forall(_.isDefined) &&
+          childOp.nonEmpty) {
+          val joinBuilder = OperatorOuterClass.HashJoin
+            .newBuilder()
+            .setJoinType(joinType)
+            .addAllLeftJoinKeys(leftKeys.map(_.get).asJava)
+            .addAllRightJoinKeys(rightKeys.map(_.get).asJava)
+          condition.foreach(joinBuilder.setCondition)
+          Some(result.setHashJoin(joinBuilder).build())
+        } else {
+          None
+        }
+
       case join: SortMergeJoinExec if isCometOperatorEnabled(op.conf, 
"sort_merge_join") =>
         // `requiredOrders` and `getKeyOrdering` are copied from Spark's 
SortMergeJoinExec.
         def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = {
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala 
b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index 501ffd5..fb300a3 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -31,6 +31,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, 
Expression, NamedExpression, SortOrder}
 import 
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, 
AggregateMode}
+import org.apache.spark.sql.catalyst.optimizer.BuildSide
 import org.apache.spark.sql.catalyst.plans.JoinType
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.comet.execution.shuffle.{ArrowReaderIterator, 
CometShuffleExchangeExec}
@@ -587,6 +588,43 @@ case class CometHashAggregateExec(
     Objects.hashCode(groupingExpressions, aggregateExpressions, input, mode, 
child)
 }
 
+/**
+ * Comet operator replacing Spark's `ShuffledHashJoinExec`. Execution is delegated to the native
+ * plan in `nativeOp`; the remaining fields mirror the original Spark join for planning purposes.
+ */
+case class CometHashJoinExec(
+    override val nativeOp: Operator,
+    override val originalPlan: SparkPlan,
+    leftKeys: Seq[Expression],
+    rightKeys: Seq[Expression],
+    joinType: JoinType,
+    condition: Option[Expression],
+    buildSide: BuildSide,
+    override val left: SparkPlan,
+    override val right: SparkPlan,
+    override val serializedPlanOpt: SerializedPlan)
+    extends CometBinaryExec {
+  override def withNewChildrenInternal(newLeft: SparkPlan, newRight: SparkPlan): SparkPlan =
+    this.copy(left = newLeft, right = newRight)
+
+  override def stringArgs: Iterator[Any] =
+    Iterator(leftKeys, rightKeys, joinType, condition, left, right)
+
+  // Join type is part of plan identity: two hash joins over the same keys, condition and
+  // children but with different join types (e.g. Inner vs FullOuter) produce different results
+  // and must not compare equal.
+  override def equals(obj: Any): Boolean = {
+    obj match {
+      case other: CometHashJoinExec =>
+        this.leftKeys == other.leftKeys &&
+        this.rightKeys == other.rightKeys &&
+        this.joinType == other.joinType &&
+        this.condition == other.condition &&
+        this.buildSide == other.buildSide &&
+        this.left == other.left &&
+        this.right == other.right &&
+        this.serializedPlanOpt == other.serializedPlanOpt
+      case _ =>
+        false
+    }
+  }
+
+  // Hashes a subset of the fields compared in `equals`, so equal plans still hash equally.
+  override def hashCode(): Int =
+    Objects.hashCode(leftKeys, rightKeys, joinType, condition, left, right)
+}
+
 case class CometSortMergeJoinExec(
     override val nativeOp: Operator,
     override val originalPlan: SparkPlan,
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala 
b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
index 73ce0e1..a64ec87 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
@@ -38,6 +38,92 @@ class CometJoinSuite extends CometTestBase {
     }
   }
 
+  test("HashJoin without join filter") {
+    withSQLConf(
+      SQLConf.PREFER_SORTMERGEJOIN.key -> "false",
+      SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      withParquetTable((0 until 10).map(i => (i, i % 5)), "tbl_a") {
+        withParquetTable((0 until 10).map(i => (i % 10, i + 2)), "tbl_b") {
+          // Inner join: build left
+          val df1 =
+            sql(
+              "SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a JOIN tbl_b ON 
tbl_a._2 = tbl_b._1")
+          checkSparkAnswerAndOperator(df1)
+
+          // Right join: build left
+          val df2 =
+            sql("SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a RIGHT JOIN 
tbl_b ON tbl_a._2 = tbl_b._1")
+          checkSparkAnswerAndOperator(df2)
+
+          // Full join: build left
+          val df3 =
+            sql("SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a FULL JOIN 
tbl_b ON tbl_a._2 = tbl_b._1")
+          checkSparkAnswerAndOperator(df3)
+
+          // TODO: Spark 3.4 returns SortMergeJoin for this query even with 
SHUFFLE_HASH hint.
+          // Left join with build left and right join with build right in hash 
join is only supported
+          // in Spark 3.5 or above. See SPARK-36612.
+          //
+          // Left join: build left
+          // sql("SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a LEFT JOIN 
tbl_b ON tbl_a._2 = tbl_b._1")
+
+          // TODO: DataFusion HashJoin doesn't support build right yet.
+          // Inner join: build right
+          // sql("SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_a JOIN tbl_b ON 
tbl_a._2 = tbl_b._1")
+          //
+          // Left join: build right
+          // sql("SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_a LEFT JOIN 
tbl_b ON tbl_a._2 = tbl_b._1")
+          //
+          // Right join: build right
+          // sql("SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_a RIGHT JOIN 
tbl_b ON tbl_a._2 = tbl_b._1")
+          //
+          // Full join: build right
+          // sql("SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_a FULL JOIN 
tbl_b ON tbl_a._2 = tbl_b._1")
+          //
+          // val left = sql("SELECT * FROM tbl_a")
+          // val right = sql("SELECT * FROM tbl_b")
+          //
+          // Left semi and anti joins are only supported with build right in 
Spark.
+          // left.join(right, left("_2") === right("_1"), "leftsemi")
+          // left.join(right, left("_2") === right("_1"), "leftanti")
+        }
+      }
+    }
+  }
+
+  test("HashJoin with join filter") {
+    withSQLConf(
+      SQLConf.PREFER_SORTMERGEJOIN.key -> "false",
+      SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      withParquetTable((0 until 10).map(i => (i, i % 5)), "tbl_a") {
+        withParquetTable((0 until 10).map(i => (i % 10, i + 2)), "tbl_b") {
+          // Inner join: build left
+          val df1 =
+            sql(
+              "SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a JOIN tbl_b " +
+                "ON tbl_a._2 = tbl_b._1 AND tbl_a._1 > tbl_b._2")
+          checkSparkAnswerAndOperator(df1)
+
+          // Right join: build left
+          val df2 =
+            sql(
+              "SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a RIGHT JOIN tbl_b 
" +
+                "ON tbl_a._2 = tbl_b._1 AND tbl_a._1 > tbl_b._2")
+          checkSparkAnswerAndOperator(df2)
+
+          // Full join: build left
+          val df3 =
+            sql(
+              "SELECT /*+ SHUFFLE_HASH(tbl_a) */ * FROM tbl_a FULL JOIN tbl_b 
" +
+                "ON tbl_a._2 = tbl_b._1 AND tbl_a._1 > tbl_b._2")
+          checkSparkAnswerAndOperator(df3)
+        }
+      }
+    }
+  }
+
   // TODO: Add a test for SortMergeJoin with join filter after new DataFusion 
release
   test("SortMergeJoin without join filter") {
     withSQLConf(

Reply via email to