This is an automated email from the ASF dual-hosted git repository.

arnavbalyan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new da6cb16ac3 [GLUTEN-9719][VL] Unify output definitions for Broadcast/Hash/Columnar shuffle joins
da6cb16ac3 is described below

commit da6cb16ac3e1f4d0a37032a65f52818de24cc93d
Author: Arnav Balyan <[email protected]>
AuthorDate: Tue May 27 21:08:50 2025 +0530

    [GLUTEN-9719][VL] Unify output definitions for Broadcast/Hash/Columnar shuffle joins
---
 .../BroadcastNestedLoopJoinExecTransformer.scala   | 22 ++----------
 .../gluten/execution/JoinExecTransformer.scala     | 40 +++-------------------
 .../org/apache/gluten/execution/JoinUtils.scala    | 13 ++++---
 3 files changed, 15 insertions(+), 60 deletions(-)

diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala
index b9e124b608..d1d7bb308f 100644
--- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala
@@ -25,7 +25,7 @@ import org.apache.gluten.utils.SubstraitUtil
 
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, 
BuildSide}
-import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter, 
InnerLike, JoinType, LeftExistence, LeftOuter, RightOuter}
+import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter, 
InnerLike, JoinType, LeftOuter, RightOuter}
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.execution.{ExplainUtils, SparkPlan}
 import org.apache.spark.sql.execution.joins.BaseJoinExec
@@ -79,24 +79,8 @@ abstract class BroadcastNestedLoopJoinExecTransformer(
     
BackendsApiManager.getMetricsApiInstance.genNestedLoopJoinTransformerMetricsUpdater(metrics)
   }
 
-  override def output: Seq[Attribute] = {
-    joinType match {
-      case _: InnerLike =>
-        left.output ++ right.output
-      case LeftOuter =>
-        left.output ++ right.output.map(_.withNullability(true))
-      case RightOuter =>
-        left.output.map(_.withNullability(true)) ++ right.output
-      case j: ExistenceJoin =>
-        left.output :+ j.exists
-      case LeftExistence(_) =>
-        left.output
-      case FullOuter =>
-        left.output.map(_.withNullability(true)) ++ 
right.output.map(_.withNullability(true))
-      case x =>
-        throw new IllegalArgumentException(s"${getClass.getSimpleName} not 
take $x as the JoinType")
-    }
-  }
+  override def output: Seq[Attribute] =
+    JoinUtils.getDirectJoinOutputSeq(joinType, left.output, right.output, 
getClass.getSimpleName)
 
   override def outputPartitioning: Partitioning = buildSide match {
     case BuildLeft =>
diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinExecTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinExecTransformer.scala
index 8bac2752e3..374812bb59 100644
--- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinExecTransformer.scala
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinExecTransformer.scala
@@ -71,24 +71,8 @@ trait ColumnarShuffledJoin extends BaseJoinExec {
       throw new IllegalArgumentException(s"ShuffledJoin should not take $x as 
the JoinType")
   }
 
-  override def output: Seq[Attribute] = {
-    joinType match {
-      case _: InnerLike =>
-        left.output ++ right.output
-      case LeftOuter =>
-        left.output ++ right.output.map(_.withNullability(true))
-      case RightOuter =>
-        left.output.map(_.withNullability(true)) ++ right.output
-      case FullOuter =>
-        (left.output ++ right.output).map(_.withNullability(true))
-      case j: ExistenceJoin =>
-        left.output :+ j.exists
-      case LeftExistence(_) =>
-        left.output
-      case x =>
-        throw new IllegalArgumentException(s"${getClass.getSimpleName} not 
take $x as the JoinType")
-    }
-  }
+  override def output: Seq[Attribute] =
+    JoinUtils.getDirectJoinOutputSeq(joinType, left.output, right.output, 
getClass.getSimpleName)
 }
 
 /** Performs a hash join of two child relations by first shuffling the data 
using the join keys. */
@@ -370,24 +354,8 @@ abstract class BroadcastHashJoinExecTransformerBase(
     isNullAwareAntiJoin: Boolean)
   extends HashJoinLikeExecTransformer {
 
-  override def output: Seq[Attribute] = {
-    joinType match {
-      case _: InnerLike =>
-        left.output ++ right.output
-      case LeftOuter =>
-        left.output ++ right.output.map(_.withNullability(true))
-      case RightOuter =>
-        left.output.map(_.withNullability(true)) ++ right.output
-      case FullOuter =>
-        (left.output ++ right.output).map(_.withNullability(true))
-      case j: ExistenceJoin =>
-        left.output :+ j.exists
-      case LeftExistence(_) =>
-        left.output
-      case x =>
-        throw new IllegalArgumentException(s"${getClass.getSimpleName} not 
take $x as the JoinType")
-    }
-  }
+  override def output: Seq[Attribute] =
+    JoinUtils.getDirectJoinOutputSeq(joinType, left.output, right.output, 
getClass.getSimpleName)
 
   override def requiredChildDistribution: Seq[Distribution] = {
     val mode = HashedRelationBroadcastMode(buildKeyExprs, isNullAwareAntiJoin)
diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinUtils.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinUtils.scala
index 5304aa1c46..e6bc3484bc 100644
--- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinUtils.scala
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinUtils.scala
@@ -133,7 +133,8 @@ object JoinUtils {
   private def getDirectJoinOutput(
       joinType: JoinType,
       leftOutput: Seq[Attribute],
-      rightOutput: Seq[Attribute]): (Seq[Attribute], Seq[Attribute]) = {
+      rightOutput: Seq[Attribute],
+      callerClassName: String = null): (Seq[Attribute], Seq[Attribute]) = {
     joinType match {
       case _: InnerLike =>
         (leftOutput, rightOutput)
@@ -149,15 +150,17 @@ object JoinUtils {
         // LeftSemi | LeftAnti | ExistenceJoin.
         (leftOutput, Nil)
       case x =>
-        throw new IllegalArgumentException(s"${getClass.getSimpleName} not 
take $x as the JoinType")
+        val joinClass = 
Option(callerClassName).getOrElse(this.getClass.getSimpleName)
+        throw new IllegalArgumentException(s"$joinClass not take $x as the 
JoinType")
     }
   }
 
-  private def getDirectJoinOutputSeq(
+  def getDirectJoinOutputSeq(
       joinType: JoinType,
       leftOutput: Seq[Attribute],
-      rightOutput: Seq[Attribute]): Seq[Attribute] = {
-    val (left, right) = getDirectJoinOutput(joinType, leftOutput, rightOutput)
+      rightOutput: Seq[Attribute],
+      joinClassName: String = null): Seq[Attribute] = {
+    val (left, right) = getDirectJoinOutput(joinType, leftOutput, rightOutput, 
joinClassName)
     left ++ right
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to