This is an automated email from the ASF dual-hosted git repository.

richox pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/auron.git


The following commit(s) were added to refs/heads/master by this push:
     new 02fe1111 [AURON #1873] Add unified join BuildSide abstraction for cross-spark-version compatibility (#1874)
02fe1111 is described below

commit 02fe11113155471a03923433d5549973ff4f8873
Author: yew1eb <[email protected]>
AuthorDate: Sat Jan 17 14:30:34 2026 +0800

    [AURON #1873] Add unified join BuildSide abstraction for cross-spark-version compatibility (#1874)
    
    <!--
    - Start the PR title with the related issue ID, e.g. '[AURON #XXXX] Short summary...'.
    -->
    # Which issue does this PR close?
    
    Closes #1873
    
    # Rationale for this change
    
    # What changes are included in this PR?
    
    # Are there any user-facing changes?
    
    # How was this patch tested?
---
 .../org/apache/spark/sql/auron/ShimsImpl.scala     | 42 +++++++++-
 .../joins/auron/plan/NativeBroadcastJoinExec.scala | 18 ++--
 .../plan/NativeShuffledHashJoinExecProvider.scala  | 28 +++----
 .../spark/sql/auron/AuronConvertStrategy.scala     |  4 +-
 .../apache/spark/sql/auron/AuronConverters.scala   | 97 +++++++++-------------
 .../scala/org/apache/spark/sql/auron/Shims.scala   |  7 +-
 .../spark/sql/auron/join/JoinBuildSides.scala      | 23 +++++
 .../auron/plan/NativeBroadcastJoinBase.scala       | 27 +++---
 .../auron/plan/NativeShuffledHashJoinBase.scala    | 11 +--
 9 files changed, 144 insertions(+), 113 deletions(-)

diff --git 
a/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala
 
b/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala
index d0667164..cb9492c9 100644
--- 
a/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala
+++ 
b/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.SQLContext
 import 
org.apache.spark.sql.auron.AuronConverters.ForceNativeExecutionWrapperBase
 import org.apache.spark.sql.auron.NativeConverters.NativeExprWrapperBase
+import org.apache.spark.sql.auron.join.JoinBuildSides.{JoinBuildLeft, 
JoinBuildRight, JoinBuildSide}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.catalog.CatalogTable
 import org.apache.spark.sql.catalyst.expressions.Attribute
@@ -96,6 +97,7 @@ import 
org.apache.spark.sql.execution.auron.plan.NativeWindowExec
 import 
org.apache.spark.sql.execution.auron.shuffle.{AuronBlockStoreShuffleReaderBase, 
AuronRssShuffleManagerBase, RssPartitionWriterBase}
 import org.apache.spark.sql.execution.datasources.PartitionedFile
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, 
ReusedExchangeExec}
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, 
BroadcastNestedLoopJoinExec, ShuffledHashJoinExec}
 import org.apache.spark.sql.execution.joins.auron.plan.NativeBroadcastJoinExec
 import 
org.apache.spark.sql.execution.joins.auron.plan.NativeShuffledHashJoinExecProvider
 import 
org.apache.spark.sql.execution.joins.auron.plan.NativeSortMergeJoinExecProvider
@@ -227,7 +229,7 @@ class ShimsImpl extends Shims with Logging {
       leftKeys: Seq[Expression],
       rightKeys: Seq[Expression],
       joinType: JoinType,
-      broadcastSide: BroadcastSide,
+      broadcastSide: JoinBuildSide,
       isNullAwareAntiJoin: Boolean): NativeBroadcastJoinBase =
     NativeBroadcastJoinExec(
       left,
@@ -260,7 +262,7 @@ class ShimsImpl extends Shims with Logging {
       leftKeys: Seq[Expression],
       rightKeys: Seq[Expression],
       joinType: JoinType,
-      buildSide: BuildSide,
+      buildSide: JoinBuildSide,
       isSkewJoin: Boolean): SparkPlan =
     NativeShuffledHashJoinExecProvider.provide(
       left,
@@ -1036,6 +1038,42 @@ class ShimsImpl extends Shims with Logging {
   override def getAdaptiveInputPlan(exec: AdaptiveSparkPlanExec): SparkPlan = {
     exec.inputPlan
   }
+
+  private def convertJoinBuildSide(
+      exec: SparkPlan,
+      isBuildLeft: Any => Boolean): JoinBuildSide = {
+    exec match {
+      case shj: ShuffledHashJoinExec =>
+        if (isBuildLeft(shj.buildSide)) JoinBuildLeft else JoinBuildRight
+      case bhj: BroadcastHashJoinExec =>
+        if (isBuildLeft(bhj.buildSide)) JoinBuildLeft else JoinBuildRight
+      case bnlj: BroadcastNestedLoopJoinExec =>
+        if (isBuildLeft(bnlj.buildSide)) JoinBuildLeft else JoinBuildRight
+      case other => throw new IllegalArgumentException(s"Unsupported SparkPlan 
type: $other")
+    }
+  }
+
+  @sparkver("3.0")
+  override def getJoinBuildSide(exec: SparkPlan): JoinBuildSide = {
+    import org.apache.spark.sql.execution.joins.BuildLeft
+    convertJoinBuildSide(
+      exec,
+      isBuildLeft = {
+        case BuildLeft => true
+        case _ => false
+      })
+  }
+
+  @sparkver("3.1 / 3.2 / 3.3 / 3.4 / 3.5")
+  override def getJoinBuildSide(exec: SparkPlan): JoinBuildSide = {
+    import org.apache.spark.sql.catalyst.optimizer.BuildLeft
+    convertJoinBuildSide(
+      exec,
+      isBuildLeft = {
+        case BuildLeft => true
+        case _ => false
+      })
+  }
 }
 
 case class ForceNativeExecutionWrapper(override val child: SparkPlan)
diff --git 
a/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/execution/joins/auron/plan/NativeBroadcastJoinExec.scala
 
b/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/execution/joins/auron/plan/NativeBroadcastJoinExec.scala
index 9ac6e893..fd51bec3 100644
--- 
a/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/execution/joins/auron/plan/NativeBroadcastJoinExec.scala
+++ 
b/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/execution/joins/auron/plan/NativeBroadcastJoinExec.scala
@@ -16,13 +16,11 @@
  */
 package org.apache.spark.sql.execution.joins.auron.plan
 
+import org.apache.spark.sql.auron.join.JoinBuildSides.{JoinBuildLeft, 
JoinBuildRight, JoinBuildSide}
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.plans.JoinType
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.auron.plan.BroadcastLeft
-import org.apache.spark.sql.execution.auron.plan.BroadcastRight
-import org.apache.spark.sql.execution.auron.plan.BroadcastSide
 import org.apache.spark.sql.execution.auron.plan.NativeBroadcastJoinBase
 import org.apache.spark.sql.execution.joins.HashJoin
 
@@ -35,7 +33,7 @@ case class NativeBroadcastJoinExec(
     override val leftKeys: Seq[Expression],
     override val rightKeys: Seq[Expression],
     override val joinType: JoinType,
-    broadcastSide: BroadcastSide,
+    broadcastSide: JoinBuildSide,
     isNullAwareAntiJoin: Boolean)
     extends NativeBroadcastJoinBase(
       left,
@@ -53,14 +51,14 @@ case class NativeBroadcastJoinExec(
   @sparkver("3.1 / 3.2 / 3.3 / 3.4 / 3.5")
   override def buildSide: org.apache.spark.sql.catalyst.optimizer.BuildSide =
     broadcastSide match {
-      case BroadcastLeft => org.apache.spark.sql.catalyst.optimizer.BuildLeft
-      case BroadcastRight => org.apache.spark.sql.catalyst.optimizer.BuildRight
+      case JoinBuildLeft => org.apache.spark.sql.catalyst.optimizer.BuildLeft
+      case JoinBuildRight => org.apache.spark.sql.catalyst.optimizer.BuildRight
     }
 
   @sparkver("3.0")
   override val buildSide: org.apache.spark.sql.execution.joins.BuildSide = 
broadcastSide match {
-    case BroadcastLeft => org.apache.spark.sql.execution.joins.BuildLeft
-    case BroadcastRight => org.apache.spark.sql.execution.joins.BuildRight
+    case JoinBuildLeft => org.apache.spark.sql.execution.joins.BuildLeft
+    case JoinBuildRight => org.apache.spark.sql.execution.joins.BuildRight
   }
 
   @sparkver("3.1 / 3.2 / 3.3 / 3.4 / 3.5")
@@ -71,9 +69,9 @@ case class NativeBroadcastJoinExec(
 
     def mode = HashedRelationBroadcastMode(buildBoundKeys, isNullAware = false)
     broadcastSide match {
-      case BroadcastLeft =>
+      case JoinBuildLeft =>
         BroadcastDistribution(mode) :: UnspecifiedDistribution :: Nil
-      case BroadcastRight =>
+      case JoinBuildRight =>
         UnspecifiedDistribution :: BroadcastDistribution(mode) :: Nil
     }
   }
diff --git 
a/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/execution/joins/auron/plan/NativeShuffledHashJoinExecProvider.scala
 
b/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/execution/joins/auron/plan/NativeShuffledHashJoinExecProvider.scala
index f742cb71..0236dd26 100644
--- 
a/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/execution/joins/auron/plan/NativeShuffledHashJoinExecProvider.scala
+++ 
b/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/execution/joins/auron/plan/NativeShuffledHashJoinExecProvider.scala
@@ -16,10 +16,10 @@
  */
 package org.apache.spark.sql.execution.joins.auron.plan
 
+import org.apache.spark.sql.auron.join.JoinBuildSides.JoinBuildSide
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.plans.JoinType
 import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.auron.plan.BuildSide
 import org.apache.spark.sql.execution.auron.plan.NativeShuffledHashJoinBase
 import org.apache.spark.sql.execution.joins.HashJoin
 
@@ -34,7 +34,7 @@ case object NativeShuffledHashJoinExecProvider {
       leftKeys: Seq[Expression],
       rightKeys: Seq[Expression],
       joinType: JoinType,
-      buildSide: BuildSide,
+      buildSide: JoinBuildSide,
       isSkewJoin: Boolean): NativeShuffledHashJoinBase = {
 
     import org.apache.spark.rdd.RDD
@@ -47,7 +47,7 @@ case object NativeShuffledHashJoinExecProvider {
         override val leftKeys: Seq[Expression],
         override val rightKeys: Seq[Expression],
         override val joinType: JoinType,
-        buildSide: BuildSide,
+        buildSide: JoinBuildSide,
         skewJoin: Boolean)
         extends NativeShuffledHashJoinBase(left, right, leftKeys, rightKeys, 
joinType, buildSide)
         with org.apache.spark.sql.execution.joins.ShuffledJoin {
@@ -87,12 +87,11 @@ case object NativeShuffledHashJoinExecProvider {
       leftKeys: Seq[Expression],
       rightKeys: Seq[Expression],
       joinType: JoinType,
-      buildSide: BuildSide,
+      buildSide: JoinBuildSide,
       isSkewJoin: Boolean): NativeShuffledHashJoinBase = {
 
+    import org.apache.spark.sql.auron.join.JoinBuildSides.{JoinBuildLeft, 
JoinBuildRight}
     import org.apache.spark.sql.catalyst.expressions.SortOrder
-    import org.apache.spark.sql.execution.auron.plan.BuildLeft
-    import org.apache.spark.sql.execution.auron.plan.BuildRight
     import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec
 
     case class NativeShuffledHashJoinExec(
@@ -101,7 +100,7 @@ case object NativeShuffledHashJoinExecProvider {
         leftKeys: Seq[Expression],
         rightKeys: Seq[Expression],
         joinType: JoinType,
-        buildSide: BuildSide)
+        buildSide: JoinBuildSide)
         extends NativeShuffledHashJoinBase(left, right, leftKeys, rightKeys, 
joinType, buildSide)
         with org.apache.spark.sql.execution.joins.ShuffledJoin {
 
@@ -112,8 +111,8 @@ case object NativeShuffledHashJoinExecProvider {
 
       override def outputOrdering: Seq[SortOrder] = {
         val sparkBuildSide = buildSide match {
-          case BuildLeft => org.apache.spark.sql.catalyst.optimizer.BuildLeft
-          case BuildRight => org.apache.spark.sql.catalyst.optimizer.BuildRight
+          case JoinBuildLeft => 
org.apache.spark.sql.catalyst.optimizer.BuildLeft
+          case JoinBuildRight => 
org.apache.spark.sql.catalyst.optimizer.BuildRight
         }
         val shj =
           ShuffledHashJoinExec(leftKeys, rightKeys, joinType, sparkBuildSide, 
None, left, right)
@@ -135,12 +134,11 @@ case object NativeShuffledHashJoinExecProvider {
       leftKeys: Seq[Expression],
       rightKeys: Seq[Expression],
       joinType: JoinType,
-      buildSide: BuildSide,
+      buildSide: JoinBuildSide,
       isSkewJoin: Boolean): NativeShuffledHashJoinBase = {
 
+    import org.apache.spark.sql.auron.join.JoinBuildSides.{JoinBuildLeft, 
JoinBuildRight}
     import org.apache.spark.sql.catalyst.expressions.Attribute
-    import org.apache.spark.sql.execution.auron.plan.BuildLeft
-    import org.apache.spark.sql.execution.auron.plan.BuildRight
     import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec
 
     case class NativeShuffledHashJoinExec(
@@ -149,7 +147,7 @@ case object NativeShuffledHashJoinExecProvider {
         leftKeys: Seq[Expression],
         rightKeys: Seq[Expression],
         joinType: JoinType,
-        buildSide: BuildSide)
+        buildSide: JoinBuildSide)
         extends NativeShuffledHashJoinBase(
           left,
           right,
@@ -160,8 +158,8 @@ case object NativeShuffledHashJoinExecProvider {
 
       private def shj: ShuffledHashJoinExec = {
         val sparkBuildSide = buildSide match {
-          case BuildLeft => org.apache.spark.sql.execution.joins.BuildLeft
-          case BuildRight => org.apache.spark.sql.execution.joins.BuildRight
+          case JoinBuildLeft => org.apache.spark.sql.execution.joins.BuildLeft
+          case JoinBuildRight => 
org.apache.spark.sql.execution.joins.BuildRight
         }
         ShuffledHashJoinExec(leftKeys, rightKeys, joinType, sparkBuildSide, 
None, left, right)
       }
diff --git 
a/spark-extension/src/main/scala/org/apache/spark/sql/auron/AuronConvertStrategy.scala
 
b/spark-extension/src/main/scala/org/apache/spark/sql/auron/AuronConvertStrategy.scala
index fb7471d4..77e4c663 100644
--- 
a/spark-extension/src/main/scala/org/apache/spark/sql/auron/AuronConvertStrategy.scala
+++ 
b/spark-extension/src/main/scala/org/apache/spark/sql/auron/AuronConvertStrategy.scala
@@ -18,13 +18,13 @@ package org.apache.spark.sql.auron
 
 import org.apache.commons.lang3.reflect.MethodUtils
 import org.apache.spark.internal.Logging
+import org.apache.spark.sql.auron.join.JoinBuildSides.JoinBuildSide
 import org.apache.spark.sql.catalyst.trees.TreeNodeTag
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec
 import org.apache.spark.sql.execution.aggregate.SortAggregateExec
-import org.apache.spark.sql.execution.auron.plan.BuildSide
 import org.apache.spark.sql.execution.command.DataWritingCommandExec
 import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
@@ -44,7 +44,7 @@ object AuronConvertStrategy extends Logging {
   val neverConvertReasonTag: TreeNodeTag[String] = 
TreeNodeTag("auron.never.convert.reason")
   val childOrderingRequiredTag: TreeNodeTag[Boolean] = TreeNodeTag(
     "auron.child.ordering.required")
-  val joinSmallerSideTag: TreeNodeTag[BuildSide] = 
TreeNodeTag("auron.join.smallerSide")
+  val joinSmallerSideTag: TreeNodeTag[JoinBuildSide] = 
TreeNodeTag("auron.join.smallerSide")
 
   def apply(exec: SparkPlan): Unit = {
     exec.foreach(_.setTagValue(convertibleTag, true))
diff --git 
a/spark-extension/src/main/scala/org/apache/spark/sql/auron/AuronConverters.scala
 
b/spark-extension/src/main/scala/org/apache/spark/sql/auron/AuronConverters.scala
index 491f85da..4f124bd8 100644
--- 
a/spark-extension/src/main/scala/org/apache/spark/sql/auron/AuronConverters.scala
+++ 
b/spark-extension/src/main/scala/org/apache/spark/sql/auron/AuronConverters.scala
@@ -29,6 +29,7 @@ import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.internal.{config, Logging}
 import 
org.apache.spark.sql.auron.AuronConvertStrategy.{childOrderingRequiredTag, 
convertibleTag, convertStrategyTag, convertToNonNativeTag, isNeverConvert, 
joinSmallerSideTag, neverConvertReasonTag}
 import org.apache.spark.sql.auron.NativeConverters.{existTimestampType, 
isTypeSupported, roundRobinTypeSupported, StubExpr}
+import org.apache.spark.sql.auron.join.JoinBuildSides.{JoinBuildLeft, 
JoinBuildRight, JoinBuildSide}
 import org.apache.spark.sql.auron.util.AuronLogUtils.logDebugPlanConversion
 import org.apache.spark.sql.catalyst.expressions.AggregateWindowFunction
 import org.apache.spark.sql.catalyst.expressions.Alias
@@ -53,8 +54,6 @@ import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec
 import org.apache.spark.sql.execution.aggregate.SortAggregateExec
-import org.apache.spark.sql.execution.auron.plan.BroadcastLeft
-import org.apache.spark.sql.execution.auron.plan.BroadcastRight
 import org.apache.spark.sql.execution.auron.plan.ConvertToNativeBase
 import org.apache.spark.sql.execution.auron.plan.NativeAggBase
 import org.apache.spark.sql.execution.auron.plan.NativeBroadcastExchangeBase
@@ -156,14 +155,6 @@ object AuronConverters extends Logging {
     supportedShuffleManagers.exists(name.contains)
   }
 
-  // format: off
-  // scalafix:off
-  // necessary imports for cross spark versions build
-  import org.apache.spark.sql.catalyst.plans._
-  import org.apache.spark.sql.catalyst.optimizer._
-  // scalafix:on
-  // format: on
-
   def convertSparkPlanRecursively(exec: SparkPlan): SparkPlan = {
     // convert
     var danglingConverted: Seq[SparkPlan] = Nil
@@ -549,15 +540,14 @@ object AuronConverters extends Logging {
           "condition" -> condition))
       assert(condition.isEmpty, "join condition is not supported")
 
-      val buildSide = exec.getTagValue(joinSmallerSideTag) match {
-        case Some(org.apache.spark.sql.execution.auron.plan.BuildLeft) =>
-          org.apache.spark.sql.execution.auron.plan.BuildLeft
-        case Some(org.apache.spark.sql.execution.auron.plan.BuildRight) =>
-          org.apache.spark.sql.execution.auron.plan.BuildRight
-        case None =>
+      val buildSide = exec
+        .getTagValue(joinSmallerSideTag)
+        .map(_.asInstanceOf[JoinBuildSide])
+        .getOrElse {
           logWarning("JoinSmallerSideTag is missing, defaults to BuildRight")
-          org.apache.spark.sql.execution.auron.plan.BuildRight
-      }
+          JoinBuildRight
+        }
+
       return Shims.get.createNativeShuffledHashJoinExec(
         addRenameColumnsExec(convertToNative(left.children(0))),
         addRenameColumnsExec(convertToNative(right.children(0))),
@@ -596,14 +586,9 @@ object AuronConverters extends Logging {
   }
 
   def convertShuffledHashJoinExec(exec: ShuffledHashJoinExec): SparkPlan = {
-    val (leftKeys, rightKeys, joinType, condition, left, right, buildSide) = (
-      exec.leftKeys,
-      exec.rightKeys,
-      exec.joinType,
-      exec.condition,
-      exec.left,
-      exec.right,
-      exec.buildSide)
+    val buildSide = Shims.get.getJoinBuildSide(exec)
+    val (leftKeys, rightKeys, joinType, condition, left, right) =
+      (exec.leftKeys, exec.rightKeys, exec.joinType, exec.condition, 
exec.left, exec.right)
     logDebugPlanConversion(
       exec,
       Seq(
@@ -620,10 +605,7 @@ object AuronConverters extends Logging {
         leftKeys,
         rightKeys,
         joinType,
-        buildSide match {
-          case BuildLeft => org.apache.spark.sql.execution.auron.plan.BuildLeft
-          case BuildRight => 
org.apache.spark.sql.execution.auron.plan.BuildRight
-        },
+        buildSide,
         getIsSkewJoinFromSHJ(exec))
 
     } catch {
@@ -671,16 +653,17 @@ object AuronConverters extends Logging {
   def isNullAwareAntiJoin(exec: BroadcastHashJoinExec): Boolean = false
 
   def convertBroadcastHashJoinExec(exec: BroadcastHashJoinExec): SparkPlan = {
+    val buildSide = Shims.get.getJoinBuildSide(exec)
     try {
-      val (leftKeys, rightKeys, joinType, buildSide, condition, left, right, 
naaj) = (
-        exec.leftKeys,
-        exec.rightKeys,
-        exec.joinType,
-        exec.buildSide,
-        exec.condition,
-        exec.left,
-        exec.right,
-        isNullAwareAntiJoin(exec))
+      val (leftKeys, rightKeys, joinType, condition, left, right, naaj) =
+        (
+          exec.leftKeys,
+          exec.rightKeys,
+          exec.joinType,
+          exec.condition,
+          exec.left,
+          exec.right,
+          isNullAwareAntiJoin(exec))
       logDebugPlanConversion(
         exec,
         Seq(
@@ -693,9 +676,9 @@ object AuronConverters extends Logging {
 
       // verify build side is native
       buildSide match {
-        case BuildRight =>
+        case JoinBuildRight =>
           assert(NativeHelper.isNative(right), "broadcast join build side is 
not native")
-        case BuildLeft =>
+        case JoinBuildLeft =>
           assert(NativeHelper.isNative(left), "broadcast join build side is 
not native")
       }
 
@@ -706,17 +689,14 @@ object AuronConverters extends Logging {
         leftKeys,
         rightKeys,
         joinType,
-        buildSide match {
-          case BuildLeft => BroadcastLeft
-          case BuildRight => BroadcastRight
-        },
+        buildSide,
         naaj)
 
     } catch {
       case e @ (_: NotImplementedError | _: Exception) =>
-        val underlyingBroadcast = exec.buildSide match {
-          case BuildLeft => Shims.get.getUnderlyingBroadcast(exec.left)
-          case BuildRight => Shims.get.getUnderlyingBroadcast(exec.right)
+        val underlyingBroadcast = buildSide match {
+          case JoinBuildLeft => Shims.get.getUnderlyingBroadcast(exec.left)
+          case JoinBuildRight => Shims.get.getUnderlyingBroadcast(exec.right)
         }
         
underlyingBroadcast.setTagValue(NativeBroadcastExchangeBase.nativeExecutionTag, 
false)
         throw e
@@ -724,9 +704,10 @@ object AuronConverters extends Logging {
   }
 
   def convertBroadcastNestedLoopJoinExec(exec: BroadcastNestedLoopJoinExec): 
SparkPlan = {
+    val buildSide = Shims.get.getJoinBuildSide(exec)
     try {
-      val (joinType, buildSide, condition, left, right) =
-        (exec.joinType, exec.buildSide, exec.condition, exec.left, exec.right)
+      val (joinType, condition, left, right) =
+        (exec.joinType, exec.condition, exec.left, exec.right)
       logDebugPlanConversion(
         exec,
         Seq("joinType" -> joinType, "condition" -> condition, "buildSide" -> 
buildSide))
@@ -735,9 +716,9 @@ object AuronConverters extends Logging {
 
       // verify build side is native
       buildSide match {
-        case BuildRight =>
+        case JoinBuildRight =>
           assert(NativeHelper.isNative(right), "broadcast join build side is 
not native")
-        case BuildLeft =>
+        case JoinBuildLeft =>
           assert(NativeHelper.isNative(left), "broadcast join build side is 
not native")
       }
 
@@ -749,17 +730,13 @@ object AuronConverters extends Logging {
         Nil,
         Nil,
         joinType,
-        buildSide match {
-          case BuildLeft => BroadcastLeft
-          case BuildRight => BroadcastRight
-        },
+        buildSide,
         isNullAwareAntiJoin = false)
-
     } catch {
       case e @ (_: NotImplementedError | _: Exception) =>
-        val underlyingBroadcast = exec.buildSide match {
-          case BuildLeft => Shims.get.getUnderlyingBroadcast(exec.left)
-          case BuildRight => Shims.get.getUnderlyingBroadcast(exec.right)
+        val underlyingBroadcast = buildSide match {
+          case JoinBuildLeft => Shims.get.getUnderlyingBroadcast(exec.left)
+          case JoinBuildRight => Shims.get.getUnderlyingBroadcast(exec.right)
         }
         
underlyingBroadcast.setTagValue(NativeBroadcastExchangeBase.nativeExecutionTag, 
false)
         throw e
diff --git 
a/spark-extension/src/main/scala/org/apache/spark/sql/auron/Shims.scala 
b/spark-extension/src/main/scala/org/apache/spark/sql/auron/Shims.scala
index a0dd37ae..d2489726 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/auron/Shims.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/auron/Shims.scala
@@ -28,6 +28,7 @@ import org.apache.spark.shuffle.ShuffleHandle
 import org.apache.spark.shuffle.ShuffleWriteMetricsReporter
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.auron.join.JoinBuildSides.JoinBuildSide
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.catalog.CatalogTable
 import org.apache.spark.sql.catalyst.expressions.Attribute
@@ -86,7 +87,7 @@ abstract class Shims {
       leftKeys: Seq[Expression],
       rightKeys: Seq[Expression],
       joinType: JoinType,
-      broadcastSide: BroadcastSide,
+      broadcastSide: JoinBuildSide,
       isNullAwareAntiJoin: Boolean): NativeBroadcastJoinBase
 
   def createNativeSortMergeJoinExec(
@@ -103,7 +104,7 @@ abstract class Shims {
       leftKeys: Seq[Expression],
       rightKeys: Seq[Expression],
       joinType: JoinType,
-      buildSide: BuildSide,
+      buildSide: JoinBuildSide,
       isSkewJoin: Boolean): SparkPlan
 
   def createNativeExpandExec(
@@ -261,6 +262,8 @@ abstract class Shims {
   def postTransform(plan: SparkPlan, sc: SparkContext): Unit = {}
 
   def getAdaptiveInputPlan(exec: AdaptiveSparkPlanExec): SparkPlan
+
+  def getJoinBuildSide(exec: SparkPlan): JoinBuildSide
 }
 
 object Shims {
diff --git 
a/spark-extension/src/main/scala/org/apache/spark/sql/auron/join/JoinBuildSides.scala
 
b/spark-extension/src/main/scala/org/apache/spark/sql/auron/join/JoinBuildSides.scala
new file mode 100644
index 00000000..d5077509
--- /dev/null
+++ 
b/spark-extension/src/main/scala/org/apache/spark/sql/auron/join/JoinBuildSides.scala
@@ -0,0 +1,23 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.auron.join
+
+object JoinBuildSides {
+  sealed trait JoinBuildSide
+  case object JoinBuildLeft extends JoinBuildSide
+  case object JoinBuildRight extends JoinBuildSide
+}
diff --git 
a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeBroadcastJoinBase.scala
 
b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeBroadcastJoinBase.scala
index 3281947c..c9b20f19 100644
--- 
a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeBroadcastJoinBase.scala
+++ 
b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeBroadcastJoinBase.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.auron.NativeHelper
 import org.apache.spark.sql.auron.NativeRDD
 import org.apache.spark.sql.auron.NativeSupports
 import org.apache.spark.sql.auron.Shims
+import org.apache.spark.sql.auron.join.JoinBuildSides.{JoinBuildLeft, 
JoinBuildRight, JoinBuildSide}
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.plans.FullOuter
 import org.apache.spark.sql.catalyst.plans.JoinType
@@ -52,7 +53,7 @@ abstract class NativeBroadcastJoinBase(
     leftKeys: Seq[Expression],
     rightKeys: Seq[Expression],
     joinType: JoinType,
-    broadcastSide: BroadcastSide,
+    broadcastSide: JoinBuildSide,
     isNullAwareAntiJoin: Boolean)
     extends BinaryExecNode
     with NativeSupports {
@@ -76,8 +77,8 @@ abstract class NativeBroadcastJoinBase(
 
   {
     val baseBroadcast = broadcastSide match {
-      case BroadcastLeft => Shims.get.getUnderlyingBroadcast(left)
-      case BroadcastRight => Shims.get.getUnderlyingBroadcast(right)
+      case JoinBuildLeft => Shims.get.getUnderlyingBroadcast(left)
+      case JoinBuildRight => Shims.get.getUnderlyingBroadcast(right)
     }
     val mode = baseBroadcast match {
       case b: BroadcastExchangeExec => b.mode
@@ -112,8 +113,8 @@ abstract class NativeBroadcastJoinBase(
   private def nativeJoinType = NativeConverters.convertJoinType(joinType)
 
   private def nativeBroadcastSide = broadcastSide match {
-    case BroadcastLeft => pb.JoinSide.LEFT_SIDE
-    case BroadcastRight => pb.JoinSide.RIGHT_SIDE
+    case JoinBuildLeft => pb.JoinSide.LEFT_SIDE
+    case JoinBuildRight => pb.JoinSide.RIGHT_SIDE
   }
 
   protected def rewriteKeyExprToLong(exprs: Seq[Expression]): Seq[Expression]
@@ -133,14 +134,14 @@ abstract class NativeBroadcastJoinBase(
     val nativeJoinOn = this.nativeJoinOn
 
     val (probedRDD, builtRDD) = broadcastSide match {
-      case BroadcastLeft => (rightRDD, leftRDD)
-      case BroadcastRight => (leftRDD, rightRDD)
+      case JoinBuildLeft => (rightRDD, leftRDD)
+      case JoinBuildRight => (leftRDD, rightRDD)
     }
 
     val probedShuffleReadFull = probedRDD.isShuffleReadFull && (broadcastSide 
match {
-      case BroadcastLeft =>
+      case JoinBuildLeft =>
         Seq(FullOuter, RightOuter).contains(joinType)
-      case BroadcastRight =>
+      case JoinBuildRight =>
         Seq(FullOuter, LeftOuter, LeftSemi, LeftAnti).contains(joinType)
     })
 
@@ -156,11 +157,11 @@ abstract class NativeBroadcastJoinBase(
           override def index: Int = 0
         }
         val (leftChild, rightChild) = broadcastSide match {
-          case BroadcastLeft =>
+          case JoinBuildLeft =>
             (
               leftRDD.nativePlan(partition0, context),
               rightRDD.nativePlan(rightRDD.partitions(partition.index), 
context))
-          case BroadcastRight =>
+          case JoinBuildRight =>
             (
               leftRDD.nativePlan(leftRDD.partitions(partition.index), context),
               rightRDD.nativePlan(partition0, context))
@@ -183,7 +184,3 @@ abstract class NativeBroadcastJoinBase(
       friendlyName = "NativeRDD.BroadcastJoin")
   }
 }
-
-class BroadcastSide {}
-case object BroadcastLeft extends BroadcastSide {}
-case object BroadcastRight extends BroadcastSide {}
diff --git 
a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeShuffledHashJoinBase.scala
 
b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeShuffledHashJoinBase.scala
index b9fc8de8..1f8a06c8 100644
--- 
a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeShuffledHashJoinBase.scala
+++ 
b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeShuffledHashJoinBase.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.auron.NativeConverters
 import org.apache.spark.sql.auron.NativeHelper
 import org.apache.spark.sql.auron.NativeRDD
 import org.apache.spark.sql.auron.NativeSupports
+import org.apache.spark.sql.auron.join.JoinBuildSides.{JoinBuildLeft, 
JoinBuildRight, JoinBuildSide}
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.plans.JoinType
 import org.apache.spark.sql.catalyst.plans.RightOuter
@@ -40,7 +41,7 @@ abstract class NativeShuffledHashJoinBase(
     leftKeys: Seq[Expression],
     rightKeys: Seq[Expression],
     joinType: JoinType,
-    buildSide: BuildSide)
+    buildSide: JoinBuildSide)
     extends BinaryExecNode
     with NativeSupports {
 
@@ -79,8 +80,8 @@ abstract class NativeShuffledHashJoinBase(
   private def nativeJoinType = NativeConverters.convertJoinType(joinType)
 
   private def nativeBuildSide = buildSide match {
-    case BuildLeft => pb.JoinSide.LEFT_SIDE
-    case BuildRight => pb.JoinSide.RIGHT_SIDE
+    case JoinBuildLeft => pb.JoinSide.LEFT_SIDE
+    case JoinBuildRight => pb.JoinSide.RIGHT_SIDE
   }
 
   protected def rewriteKeyExprToLong(exprs: Seq[Expression]): Seq[Expression]
@@ -133,7 +134,3 @@ abstract class NativeShuffledHashJoinBase(
       friendlyName = "NativeRDD.ShuffledHashJoin")
   }
 }
-
-class BuildSide {}
-case object BuildLeft extends BuildSide {}
-case object BuildRight extends BuildSide {}

Reply via email to