This is an automated email from the ASF dual-hosted git repository.
arnavbalyan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new da6cb16ac3 [GLUTEN-9719][VL] Unify output definitions for
Broadcast/Hash/Columnar shuffle joins
da6cb16ac3 is described below
commit da6cb16ac3e1f4d0a37032a65f52818de24cc93d
Author: Arnav Balyan <[email protected]>
AuthorDate: Tue May 27 21:08:50 2025 +0530
[GLUTEN-9719][VL] Unify output definitions for Broadcast/Hash/Columnar
shuffle joins
---
.../BroadcastNestedLoopJoinExecTransformer.scala | 22 ++----------
.../gluten/execution/JoinExecTransformer.scala | 40 +++-------------------
.../org/apache/gluten/execution/JoinUtils.scala | 13 ++++---
3 files changed, 15 insertions(+), 60 deletions(-)
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala
index b9e124b608..d1d7bb308f 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala
@@ -25,7 +25,7 @@ import org.apache.gluten.utils.SubstraitUtil
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight,
BuildSide}
-import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter,
InnerLike, JoinType, LeftExistence, LeftOuter, RightOuter}
+import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter,
InnerLike, JoinType, LeftOuter, RightOuter}
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.execution.{ExplainUtils, SparkPlan}
import org.apache.spark.sql.execution.joins.BaseJoinExec
@@ -79,24 +79,8 @@ abstract class BroadcastNestedLoopJoinExecTransformer(
BackendsApiManager.getMetricsApiInstance.genNestedLoopJoinTransformerMetricsUpdater(metrics)
}
- override def output: Seq[Attribute] = {
- joinType match {
- case _: InnerLike =>
- left.output ++ right.output
- case LeftOuter =>
- left.output ++ right.output.map(_.withNullability(true))
- case RightOuter =>
- left.output.map(_.withNullability(true)) ++ right.output
- case j: ExistenceJoin =>
- left.output :+ j.exists
- case LeftExistence(_) =>
- left.output
- case FullOuter =>
- left.output.map(_.withNullability(true)) ++
right.output.map(_.withNullability(true))
- case x =>
- throw new IllegalArgumentException(s"${getClass.getSimpleName} not
take $x as the JoinType")
- }
- }
+ override def output: Seq[Attribute] =
+ JoinUtils.getDirectJoinOutputSeq(joinType, left.output, right.output,
getClass.getSimpleName)
override def outputPartitioning: Partitioning = buildSide match {
case BuildLeft =>
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinExecTransformer.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinExecTransformer.scala
index 8bac2752e3..374812bb59 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinExecTransformer.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinExecTransformer.scala
@@ -71,24 +71,8 @@ trait ColumnarShuffledJoin extends BaseJoinExec {
throw new IllegalArgumentException(s"ShuffledJoin should not take $x as
the JoinType")
}
- override def output: Seq[Attribute] = {
- joinType match {
- case _: InnerLike =>
- left.output ++ right.output
- case LeftOuter =>
- left.output ++ right.output.map(_.withNullability(true))
- case RightOuter =>
- left.output.map(_.withNullability(true)) ++ right.output
- case FullOuter =>
- (left.output ++ right.output).map(_.withNullability(true))
- case j: ExistenceJoin =>
- left.output :+ j.exists
- case LeftExistence(_) =>
- left.output
- case x =>
- throw new IllegalArgumentException(s"${getClass.getSimpleName} not
take $x as the JoinType")
- }
- }
+ override def output: Seq[Attribute] =
+ JoinUtils.getDirectJoinOutputSeq(joinType, left.output, right.output,
getClass.getSimpleName)
}
/** Performs a hash join of two child relations by first shuffling the data
using the join keys. */
@@ -370,24 +354,8 @@ abstract class BroadcastHashJoinExecTransformerBase(
isNullAwareAntiJoin: Boolean)
extends HashJoinLikeExecTransformer {
- override def output: Seq[Attribute] = {
- joinType match {
- case _: InnerLike =>
- left.output ++ right.output
- case LeftOuter =>
- left.output ++ right.output.map(_.withNullability(true))
- case RightOuter =>
- left.output.map(_.withNullability(true)) ++ right.output
- case FullOuter =>
- (left.output ++ right.output).map(_.withNullability(true))
- case j: ExistenceJoin =>
- left.output :+ j.exists
- case LeftExistence(_) =>
- left.output
- case x =>
- throw new IllegalArgumentException(s"${getClass.getSimpleName} not
take $x as the JoinType")
- }
- }
+ override def output: Seq[Attribute] =
+ JoinUtils.getDirectJoinOutputSeq(joinType, left.output, right.output,
getClass.getSimpleName)
override def requiredChildDistribution: Seq[Distribution] = {
val mode = HashedRelationBroadcastMode(buildKeyExprs, isNullAwareAntiJoin)
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinUtils.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinUtils.scala
index 5304aa1c46..e6bc3484bc 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinUtils.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/execution/JoinUtils.scala
@@ -133,7 +133,8 @@ object JoinUtils {
private def getDirectJoinOutput(
joinType: JoinType,
leftOutput: Seq[Attribute],
- rightOutput: Seq[Attribute]): (Seq[Attribute], Seq[Attribute]) = {
+ rightOutput: Seq[Attribute],
+ callerClassName: String = null): (Seq[Attribute], Seq[Attribute]) = {
joinType match {
case _: InnerLike =>
(leftOutput, rightOutput)
@@ -149,15 +150,17 @@ object JoinUtils {
// LeftSemi | LeftAnti | ExistenceJoin.
(leftOutput, Nil)
case x =>
- throw new IllegalArgumentException(s"${getClass.getSimpleName} not
take $x as the JoinType")
+ val joinClass =
Option(callerClassName).getOrElse(this.getClass.getSimpleName)
+ throw new IllegalArgumentException(s"$joinClass not take $x as the
JoinType")
}
}
- private def getDirectJoinOutputSeq(
+ def getDirectJoinOutputSeq(
joinType: JoinType,
leftOutput: Seq[Attribute],
- rightOutput: Seq[Attribute]): Seq[Attribute] = {
- val (left, right) = getDirectJoinOutput(joinType, leftOutput, rightOutput)
+ rightOutput: Seq[Attribute],
+ joinClassName: String = null): Seq[Attribute] = {
+ val (left, right) = getDirectJoinOutput(joinType, leftOutput, rightOutput,
joinClassName)
left ++ right
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@gluten.apache.org
For additional commands, e-mail: commits-help@gluten.apache.org