This is an automated email from the ASF dual-hosted git repository.
richox pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/auron.git
The following commit(s) were added to refs/heads/master by this push:
new 02fe1111 [AURON #1873] Add unified join BuildSide abstraction for
cross-spark-version compatibility (#1874)
02fe1111 is described below
commit 02fe11113155471a03923433d5549973ff4f8873
Author: yew1eb <[email protected]>
AuthorDate: Sat Jan 17 14:30:34 2026 +0800
[AURON #1873] Add unified join BuildSide abstraction for
cross-spark-version compatibility (#1874)
<!--
- Start the PR title with the related issue ID, e.g. '[AURON #XXXX]
Short summary...'.
-->
# Which issue does this PR close?
Closes #1873
# Rationale for this change
# What changes are included in this PR?
# Are there any user-facing changes?
# How was this patch tested?
---
.../org/apache/spark/sql/auron/ShimsImpl.scala | 42 +++++++++-
.../joins/auron/plan/NativeBroadcastJoinExec.scala | 18 ++--
.../plan/NativeShuffledHashJoinExecProvider.scala | 28 +++----
.../spark/sql/auron/AuronConvertStrategy.scala | 4 +-
.../apache/spark/sql/auron/AuronConverters.scala | 97 +++++++++-------------
.../scala/org/apache/spark/sql/auron/Shims.scala | 7 +-
.../spark/sql/auron/join/JoinBuildSides.scala | 23 +++++
.../auron/plan/NativeBroadcastJoinBase.scala | 27 +++---
.../auron/plan/NativeShuffledHashJoinBase.scala | 11 +--
9 files changed, 144 insertions(+), 113 deletions(-)
diff --git a/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala b/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala
index d0667164..cb9492c9 100644
--- a/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala
+++ b/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.SQLContext
import
org.apache.spark.sql.auron.AuronConverters.ForceNativeExecutionWrapperBase
import org.apache.spark.sql.auron.NativeConverters.NativeExprWrapperBase
+import org.apache.spark.sql.auron.join.JoinBuildSides.{JoinBuildLeft,
JoinBuildRight, JoinBuildSide}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.catalog.CatalogTable
import org.apache.spark.sql.catalyst.expressions.Attribute
@@ -96,6 +97,7 @@ import
org.apache.spark.sql.execution.auron.plan.NativeWindowExec
import
org.apache.spark.sql.execution.auron.shuffle.{AuronBlockStoreShuffleReaderBase,
AuronRssShuffleManagerBase, RssPartitionWriterBase}
import org.apache.spark.sql.execution.datasources.PartitionedFile
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike,
ReusedExchangeExec}
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec,
BroadcastNestedLoopJoinExec, ShuffledHashJoinExec}
import org.apache.spark.sql.execution.joins.auron.plan.NativeBroadcastJoinExec
import
org.apache.spark.sql.execution.joins.auron.plan.NativeShuffledHashJoinExecProvider
import
org.apache.spark.sql.execution.joins.auron.plan.NativeSortMergeJoinExecProvider
@@ -227,7 +229,7 @@ class ShimsImpl extends Shims with Logging {
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
joinType: JoinType,
- broadcastSide: BroadcastSide,
+ broadcastSide: JoinBuildSide,
isNullAwareAntiJoin: Boolean): NativeBroadcastJoinBase =
NativeBroadcastJoinExec(
left,
@@ -260,7 +262,7 @@ class ShimsImpl extends Shims with Logging {
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
joinType: JoinType,
- buildSide: BuildSide,
+ buildSide: JoinBuildSide,
isSkewJoin: Boolean): SparkPlan =
NativeShuffledHashJoinExecProvider.provide(
left,
@@ -1036,6 +1038,42 @@ class ShimsImpl extends Shims with Logging {
override def getAdaptiveInputPlan(exec: AdaptiveSparkPlanExec): SparkPlan = {
exec.inputPlan
}
+
+ private def convertJoinBuildSide( // Maps the Spark build side of a supported join exec onto Auron's JoinBuildSide.
+ exec: SparkPlan, // must be a ShuffledHashJoinExec, BroadcastHashJoinExec or BroadcastNestedLoopJoinExec
+ isBuildLeft: Any => Boolean): JoinBuildSide = { // version-specific BuildLeft test; typed Any because BuildLeft is imported from different packages per Spark version
+ exec match {
+ case shj: ShuffledHashJoinExec =>
+ if (isBuildLeft(shj.buildSide)) JoinBuildLeft else JoinBuildRight // any non-left build side becomes JoinBuildRight
+ case bhj: BroadcastHashJoinExec =>
+ if (isBuildLeft(bhj.buildSide)) JoinBuildLeft else JoinBuildRight
+ case bnlj: BroadcastNestedLoopJoinExec =>
+ if (isBuildLeft(bnlj.buildSide)) JoinBuildLeft else JoinBuildRight
+ case other => throw new IllegalArgumentException(s"Unsupported SparkPlan
type: $other") // any other plan type is rejected
+ }
+ }
+
+ @sparkver("3.0") // Spark 3.0 build of getJoinBuildSide
+ override def getJoinBuildSide(exec: SparkPlan): JoinBuildSide = {
+ import org.apache.spark.sql.execution.joins.BuildLeft // BuildLeft's package in the Spark 3.0 build
+ convertJoinBuildSide(
+ exec,
+ isBuildLeft = {
+ case BuildLeft => true
+ case _ => false // anything other than BuildLeft is mapped to JoinBuildRight by the helper
+ })
+ }
+
+ @sparkver("3.1 / 3.2 / 3.3 / 3.4 / 3.5") // Spark 3.1-3.5 build of getJoinBuildSide
+ override def getJoinBuildSide(exec: SparkPlan): JoinBuildSide = {
+ import org.apache.spark.sql.catalyst.optimizer.BuildLeft // BuildLeft's package in the Spark 3.1+ builds
+ convertJoinBuildSide(
+ exec,
+ isBuildLeft = {
+ case BuildLeft => true
+ case _ => false // anything other than BuildLeft is mapped to JoinBuildRight by the helper
+ })
+ }
}
case class ForceNativeExecutionWrapper(override val child: SparkPlan)
diff --git a/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/execution/joins/auron/plan/NativeBroadcastJoinExec.scala b/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/execution/joins/auron/plan/NativeBroadcastJoinExec.scala
index 9ac6e893..fd51bec3 100644
--- a/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/execution/joins/auron/plan/NativeBroadcastJoinExec.scala
+++ b/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/execution/joins/auron/plan/NativeBroadcastJoinExec.scala
@@ -16,13 +16,11 @@
*/
package org.apache.spark.sql.execution.joins.auron.plan
+import org.apache.spark.sql.auron.join.JoinBuildSides.{JoinBuildLeft,
JoinBuildRight, JoinBuildSide}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.auron.plan.BroadcastLeft
-import org.apache.spark.sql.execution.auron.plan.BroadcastRight
-import org.apache.spark.sql.execution.auron.plan.BroadcastSide
import org.apache.spark.sql.execution.auron.plan.NativeBroadcastJoinBase
import org.apache.spark.sql.execution.joins.HashJoin
@@ -35,7 +33,7 @@ case class NativeBroadcastJoinExec(
override val leftKeys: Seq[Expression],
override val rightKeys: Seq[Expression],
override val joinType: JoinType,
- broadcastSide: BroadcastSide,
+ broadcastSide: JoinBuildSide,
isNullAwareAntiJoin: Boolean)
extends NativeBroadcastJoinBase(
left,
@@ -53,14 +51,14 @@ case class NativeBroadcastJoinExec(
@sparkver("3.1 / 3.2 / 3.3 / 3.4 / 3.5")
override def buildSide: org.apache.spark.sql.catalyst.optimizer.BuildSide =
broadcastSide match {
- case BroadcastLeft => org.apache.spark.sql.catalyst.optimizer.BuildLeft
- case BroadcastRight => org.apache.spark.sql.catalyst.optimizer.BuildRight
+ case JoinBuildLeft => org.apache.spark.sql.catalyst.optimizer.BuildLeft
+ case JoinBuildRight => org.apache.spark.sql.catalyst.optimizer.BuildRight
}
@sparkver("3.0")
override val buildSide: org.apache.spark.sql.execution.joins.BuildSide =
broadcastSide match {
- case BroadcastLeft => org.apache.spark.sql.execution.joins.BuildLeft
- case BroadcastRight => org.apache.spark.sql.execution.joins.BuildRight
+ case JoinBuildLeft => org.apache.spark.sql.execution.joins.BuildLeft
+ case JoinBuildRight => org.apache.spark.sql.execution.joins.BuildRight
}
@sparkver("3.1 / 3.2 / 3.3 / 3.4 / 3.5")
@@ -71,9 +69,9 @@ case class NativeBroadcastJoinExec(
def mode = HashedRelationBroadcastMode(buildBoundKeys, isNullAware = false)
broadcastSide match {
- case BroadcastLeft =>
+ case JoinBuildLeft =>
BroadcastDistribution(mode) :: UnspecifiedDistribution :: Nil
- case BroadcastRight =>
+ case JoinBuildRight =>
UnspecifiedDistribution :: BroadcastDistribution(mode) :: Nil
}
}
diff --git a/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/execution/joins/auron/plan/NativeShuffledHashJoinExecProvider.scala b/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/execution/joins/auron/plan/NativeShuffledHashJoinExecProvider.scala
index f742cb71..0236dd26 100644
--- a/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/execution/joins/auron/plan/NativeShuffledHashJoinExecProvider.scala
+++ b/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/execution/joins/auron/plan/NativeShuffledHashJoinExecProvider.scala
@@ -16,10 +16,10 @@
*/
package org.apache.spark.sql.execution.joins.auron.plan
+import org.apache.spark.sql.auron.join.JoinBuildSides.JoinBuildSide
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.auron.plan.BuildSide
import org.apache.spark.sql.execution.auron.plan.NativeShuffledHashJoinBase
import org.apache.spark.sql.execution.joins.HashJoin
@@ -34,7 +34,7 @@ case object NativeShuffledHashJoinExecProvider {
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
joinType: JoinType,
- buildSide: BuildSide,
+ buildSide: JoinBuildSide,
isSkewJoin: Boolean): NativeShuffledHashJoinBase = {
import org.apache.spark.rdd.RDD
@@ -47,7 +47,7 @@ case object NativeShuffledHashJoinExecProvider {
override val leftKeys: Seq[Expression],
override val rightKeys: Seq[Expression],
override val joinType: JoinType,
- buildSide: BuildSide,
+ buildSide: JoinBuildSide,
skewJoin: Boolean)
extends NativeShuffledHashJoinBase(left, right, leftKeys, rightKeys,
joinType, buildSide)
with org.apache.spark.sql.execution.joins.ShuffledJoin {
@@ -87,12 +87,11 @@ case object NativeShuffledHashJoinExecProvider {
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
joinType: JoinType,
- buildSide: BuildSide,
+ buildSide: JoinBuildSide,
isSkewJoin: Boolean): NativeShuffledHashJoinBase = {
+ import org.apache.spark.sql.auron.join.JoinBuildSides.{JoinBuildLeft,
JoinBuildRight}
import org.apache.spark.sql.catalyst.expressions.SortOrder
- import org.apache.spark.sql.execution.auron.plan.BuildLeft
- import org.apache.spark.sql.execution.auron.plan.BuildRight
import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec
case class NativeShuffledHashJoinExec(
@@ -101,7 +100,7 @@ case object NativeShuffledHashJoinExecProvider {
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
joinType: JoinType,
- buildSide: BuildSide)
+ buildSide: JoinBuildSide)
extends NativeShuffledHashJoinBase(left, right, leftKeys, rightKeys,
joinType, buildSide)
with org.apache.spark.sql.execution.joins.ShuffledJoin {
@@ -112,8 +111,8 @@ case object NativeShuffledHashJoinExecProvider {
override def outputOrdering: Seq[SortOrder] = {
val sparkBuildSide = buildSide match {
- case BuildLeft => org.apache.spark.sql.catalyst.optimizer.BuildLeft
- case BuildRight => org.apache.spark.sql.catalyst.optimizer.BuildRight
+ case JoinBuildLeft =>
org.apache.spark.sql.catalyst.optimizer.BuildLeft
+ case JoinBuildRight =>
org.apache.spark.sql.catalyst.optimizer.BuildRight
}
val shj =
ShuffledHashJoinExec(leftKeys, rightKeys, joinType, sparkBuildSide,
None, left, right)
@@ -135,12 +134,11 @@ case object NativeShuffledHashJoinExecProvider {
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
joinType: JoinType,
- buildSide: BuildSide,
+ buildSide: JoinBuildSide,
isSkewJoin: Boolean): NativeShuffledHashJoinBase = {
+ import org.apache.spark.sql.auron.join.JoinBuildSides.{JoinBuildLeft,
JoinBuildRight}
import org.apache.spark.sql.catalyst.expressions.Attribute
- import org.apache.spark.sql.execution.auron.plan.BuildLeft
- import org.apache.spark.sql.execution.auron.plan.BuildRight
import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec
case class NativeShuffledHashJoinExec(
@@ -149,7 +147,7 @@ case object NativeShuffledHashJoinExecProvider {
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
joinType: JoinType,
- buildSide: BuildSide)
+ buildSide: JoinBuildSide)
extends NativeShuffledHashJoinBase(
left,
right,
@@ -160,8 +158,8 @@ case object NativeShuffledHashJoinExecProvider {
private def shj: ShuffledHashJoinExec = {
val sparkBuildSide = buildSide match {
- case BuildLeft => org.apache.spark.sql.execution.joins.BuildLeft
- case BuildRight => org.apache.spark.sql.execution.joins.BuildRight
+ case JoinBuildLeft => org.apache.spark.sql.execution.joins.BuildLeft
+ case JoinBuildRight =>
org.apache.spark.sql.execution.joins.BuildRight
}
ShuffledHashJoinExec(leftKeys, rightKeys, joinType, sparkBuildSide,
None, left, right)
}
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/auron/AuronConvertStrategy.scala b/spark-extension/src/main/scala/org/apache/spark/sql/auron/AuronConvertStrategy.scala
index fb7471d4..77e4c663 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/auron/AuronConvertStrategy.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/auron/AuronConvertStrategy.scala
@@ -18,13 +18,13 @@ package org.apache.spark.sql.auron
import org.apache.commons.lang3.reflect.MethodUtils
import org.apache.spark.internal.Logging
+import org.apache.spark.sql.auron.join.JoinBuildSides.JoinBuildSide
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec
import org.apache.spark.sql.execution.aggregate.SortAggregateExec
-import org.apache.spark.sql.execution.auron.plan.BuildSide
import org.apache.spark.sql.execution.command.DataWritingCommandExec
import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
@@ -44,7 +44,7 @@ object AuronConvertStrategy extends Logging {
val neverConvertReasonTag: TreeNodeTag[String] =
TreeNodeTag("auron.never.convert.reason")
val childOrderingRequiredTag: TreeNodeTag[Boolean] = TreeNodeTag(
"auron.child.ordering.required")
- val joinSmallerSideTag: TreeNodeTag[BuildSide] =
TreeNodeTag("auron.join.smallerSide")
+ val joinSmallerSideTag: TreeNodeTag[JoinBuildSide] =
TreeNodeTag("auron.join.smallerSide")
def apply(exec: SparkPlan): Unit = {
exec.foreach(_.setTagValue(convertibleTag, true))
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/auron/AuronConverters.scala b/spark-extension/src/main/scala/org/apache/spark/sql/auron/AuronConverters.scala
index 491f85da..4f124bd8 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/auron/AuronConverters.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/auron/AuronConverters.scala
@@ -29,6 +29,7 @@ import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.{config, Logging}
import
org.apache.spark.sql.auron.AuronConvertStrategy.{childOrderingRequiredTag,
convertibleTag, convertStrategyTag, convertToNonNativeTag, isNeverConvert,
joinSmallerSideTag, neverConvertReasonTag}
import org.apache.spark.sql.auron.NativeConverters.{existTimestampType,
isTypeSupported, roundRobinTypeSupported, StubExpr}
+import org.apache.spark.sql.auron.join.JoinBuildSides.{JoinBuildLeft,
JoinBuildRight, JoinBuildSide}
import org.apache.spark.sql.auron.util.AuronLogUtils.logDebugPlanConversion
import org.apache.spark.sql.catalyst.expressions.AggregateWindowFunction
import org.apache.spark.sql.catalyst.expressions.Alias
@@ -53,8 +54,6 @@ import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec
import org.apache.spark.sql.execution.aggregate.SortAggregateExec
-import org.apache.spark.sql.execution.auron.plan.BroadcastLeft
-import org.apache.spark.sql.execution.auron.plan.BroadcastRight
import org.apache.spark.sql.execution.auron.plan.ConvertToNativeBase
import org.apache.spark.sql.execution.auron.plan.NativeAggBase
import org.apache.spark.sql.execution.auron.plan.NativeBroadcastExchangeBase
@@ -156,14 +155,6 @@ object AuronConverters extends Logging {
supportedShuffleManagers.exists(name.contains)
}
- // format: off
- // scalafix:off
- // necessary imports for cross spark versions build
- import org.apache.spark.sql.catalyst.plans._
- import org.apache.spark.sql.catalyst.optimizer._
- // scalafix:on
- // format: on
-
def convertSparkPlanRecursively(exec: SparkPlan): SparkPlan = {
// convert
var danglingConverted: Seq[SparkPlan] = Nil
@@ -549,15 +540,14 @@ object AuronConverters extends Logging {
"condition" -> condition))
assert(condition.isEmpty, "join condition is not supported")
- val buildSide = exec.getTagValue(joinSmallerSideTag) match {
- case Some(org.apache.spark.sql.execution.auron.plan.BuildLeft) =>
- org.apache.spark.sql.execution.auron.plan.BuildLeft
- case Some(org.apache.spark.sql.execution.auron.plan.BuildRight) =>
- org.apache.spark.sql.execution.auron.plan.BuildRight
- case None =>
+ val buildSide = exec
+ .getTagValue(joinSmallerSideTag)
+ .map(_.asInstanceOf[JoinBuildSide])
+ .getOrElse {
logWarning("JoinSmallerSideTag is missing, defaults to BuildRight")
- org.apache.spark.sql.execution.auron.plan.BuildRight
- }
+ JoinBuildRight
+ }
+
return Shims.get.createNativeShuffledHashJoinExec(
addRenameColumnsExec(convertToNative(left.children(0))),
addRenameColumnsExec(convertToNative(right.children(0))),
@@ -596,14 +586,9 @@ object AuronConverters extends Logging {
}
def convertShuffledHashJoinExec(exec: ShuffledHashJoinExec): SparkPlan = {
- val (leftKeys, rightKeys, joinType, condition, left, right, buildSide) = (
- exec.leftKeys,
- exec.rightKeys,
- exec.joinType,
- exec.condition,
- exec.left,
- exec.right,
- exec.buildSide)
+ val buildSide = Shims.get.getJoinBuildSide(exec)
+ val (leftKeys, rightKeys, joinType, condition, left, right) =
+ (exec.leftKeys, exec.rightKeys, exec.joinType, exec.condition,
exec.left, exec.right)
logDebugPlanConversion(
exec,
Seq(
@@ -620,10 +605,7 @@ object AuronConverters extends Logging {
leftKeys,
rightKeys,
joinType,
- buildSide match {
- case BuildLeft => org.apache.spark.sql.execution.auron.plan.BuildLeft
- case BuildRight =>
org.apache.spark.sql.execution.auron.plan.BuildRight
- },
+ buildSide,
getIsSkewJoinFromSHJ(exec))
} catch {
@@ -671,16 +653,17 @@ object AuronConverters extends Logging {
def isNullAwareAntiJoin(exec: BroadcastHashJoinExec): Boolean = false
def convertBroadcastHashJoinExec(exec: BroadcastHashJoinExec): SparkPlan = {
+ val buildSide = Shims.get.getJoinBuildSide(exec)
try {
- val (leftKeys, rightKeys, joinType, buildSide, condition, left, right,
naaj) = (
- exec.leftKeys,
- exec.rightKeys,
- exec.joinType,
- exec.buildSide,
- exec.condition,
- exec.left,
- exec.right,
- isNullAwareAntiJoin(exec))
+ val (leftKeys, rightKeys, joinType, condition, left, right, naaj) =
+ (
+ exec.leftKeys,
+ exec.rightKeys,
+ exec.joinType,
+ exec.condition,
+ exec.left,
+ exec.right,
+ isNullAwareAntiJoin(exec))
logDebugPlanConversion(
exec,
Seq(
@@ -693,9 +676,9 @@ object AuronConverters extends Logging {
// verify build side is native
buildSide match {
- case BuildRight =>
+ case JoinBuildRight =>
assert(NativeHelper.isNative(right), "broadcast join build side is
not native")
- case BuildLeft =>
+ case JoinBuildLeft =>
assert(NativeHelper.isNative(left), "broadcast join build side is
not native")
}
@@ -706,17 +689,14 @@ object AuronConverters extends Logging {
leftKeys,
rightKeys,
joinType,
- buildSide match {
- case BuildLeft => BroadcastLeft
- case BuildRight => BroadcastRight
- },
+ buildSide,
naaj)
} catch {
case e @ (_: NotImplementedError | _: Exception) =>
- val underlyingBroadcast = exec.buildSide match {
- case BuildLeft => Shims.get.getUnderlyingBroadcast(exec.left)
- case BuildRight => Shims.get.getUnderlyingBroadcast(exec.right)
+ val underlyingBroadcast = buildSide match {
+ case JoinBuildLeft => Shims.get.getUnderlyingBroadcast(exec.left)
+ case JoinBuildRight => Shims.get.getUnderlyingBroadcast(exec.right)
}
underlyingBroadcast.setTagValue(NativeBroadcastExchangeBase.nativeExecutionTag,
false)
throw e
@@ -724,9 +704,10 @@ object AuronConverters extends Logging {
}
def convertBroadcastNestedLoopJoinExec(exec: BroadcastNestedLoopJoinExec):
SparkPlan = {
+ val buildSide = Shims.get.getJoinBuildSide(exec)
try {
- val (joinType, buildSide, condition, left, right) =
- (exec.joinType, exec.buildSide, exec.condition, exec.left, exec.right)
+ val (joinType, condition, left, right) =
+ (exec.joinType, exec.condition, exec.left, exec.right)
logDebugPlanConversion(
exec,
Seq("joinType" -> joinType, "condition" -> condition, "buildSide" ->
buildSide))
@@ -735,9 +716,9 @@ object AuronConverters extends Logging {
// verify build side is native
buildSide match {
- case BuildRight =>
+ case JoinBuildRight =>
assert(NativeHelper.isNative(right), "broadcast join build side is
not native")
- case BuildLeft =>
+ case JoinBuildLeft =>
assert(NativeHelper.isNative(left), "broadcast join build side is
not native")
}
@@ -749,17 +730,13 @@ object AuronConverters extends Logging {
Nil,
Nil,
joinType,
- buildSide match {
- case BuildLeft => BroadcastLeft
- case BuildRight => BroadcastRight
- },
+ buildSide,
isNullAwareAntiJoin = false)
-
} catch {
case e @ (_: NotImplementedError | _: Exception) =>
- val underlyingBroadcast = exec.buildSide match {
- case BuildLeft => Shims.get.getUnderlyingBroadcast(exec.left)
- case BuildRight => Shims.get.getUnderlyingBroadcast(exec.right)
+ val underlyingBroadcast = buildSide match {
+ case JoinBuildLeft => Shims.get.getUnderlyingBroadcast(exec.left)
+ case JoinBuildRight => Shims.get.getUnderlyingBroadcast(exec.right)
}
underlyingBroadcast.setTagValue(NativeBroadcastExchangeBase.nativeExecutionTag,
false)
throw e
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/auron/Shims.scala b/spark-extension/src/main/scala/org/apache/spark/sql/auron/Shims.scala
index a0dd37ae..d2489726 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/auron/Shims.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/auron/Shims.scala
@@ -28,6 +28,7 @@ import org.apache.spark.shuffle.ShuffleHandle
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.auron.join.JoinBuildSides.JoinBuildSide
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.catalog.CatalogTable
import org.apache.spark.sql.catalyst.expressions.Attribute
@@ -86,7 +87,7 @@ abstract class Shims {
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
joinType: JoinType,
- broadcastSide: BroadcastSide,
+ broadcastSide: JoinBuildSide,
isNullAwareAntiJoin: Boolean): NativeBroadcastJoinBase
def createNativeSortMergeJoinExec(
@@ -103,7 +104,7 @@ abstract class Shims {
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
joinType: JoinType,
- buildSide: BuildSide,
+ buildSide: JoinBuildSide,
isSkewJoin: Boolean): SparkPlan
def createNativeExpandExec(
@@ -261,6 +262,8 @@ abstract class Shims {
def postTransform(plan: SparkPlan, sc: SparkContext): Unit = {}
def getAdaptiveInputPlan(exec: AdaptiveSparkPlanExec): SparkPlan
+
+ def getJoinBuildSide(exec: SparkPlan): JoinBuildSide
}
object Shims {
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/auron/join/JoinBuildSides.scala b/spark-extension/src/main/scala/org/apache/spark/sql/auron/join/JoinBuildSides.scala
new file mode 100644
index 00000000..d5077509
--- /dev/null
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/auron/join/JoinBuildSides.scala
@@ -0,0 +1,23 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.auron.join
+
+object JoinBuildSides { // Spark-version-independent join build side used by Auron's join converters and native join execs
+ sealed trait JoinBuildSide // sealed, so matches over JoinBuildLeft/JoinBuildRight are checked for exhaustiveness
+ case object JoinBuildLeft extends JoinBuildSide // the left child is the build side (e.g. the broadcast side of a broadcast join)
+ case object JoinBuildRight extends JoinBuildSide // the right child is the build side
+}
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeBroadcastJoinBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeBroadcastJoinBase.scala
index 3281947c..c9b20f19 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeBroadcastJoinBase.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeBroadcastJoinBase.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.auron.NativeHelper
import org.apache.spark.sql.auron.NativeRDD
import org.apache.spark.sql.auron.NativeSupports
import org.apache.spark.sql.auron.Shims
+import org.apache.spark.sql.auron.join.JoinBuildSides.{JoinBuildLeft,
JoinBuildRight, JoinBuildSide}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.FullOuter
import org.apache.spark.sql.catalyst.plans.JoinType
@@ -52,7 +53,7 @@ abstract class NativeBroadcastJoinBase(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
joinType: JoinType,
- broadcastSide: BroadcastSide,
+ broadcastSide: JoinBuildSide,
isNullAwareAntiJoin: Boolean)
extends BinaryExecNode
with NativeSupports {
@@ -76,8 +77,8 @@ abstract class NativeBroadcastJoinBase(
{
val baseBroadcast = broadcastSide match {
- case BroadcastLeft => Shims.get.getUnderlyingBroadcast(left)
- case BroadcastRight => Shims.get.getUnderlyingBroadcast(right)
+ case JoinBuildLeft => Shims.get.getUnderlyingBroadcast(left)
+ case JoinBuildRight => Shims.get.getUnderlyingBroadcast(right)
}
val mode = baseBroadcast match {
case b: BroadcastExchangeExec => b.mode
@@ -112,8 +113,8 @@ abstract class NativeBroadcastJoinBase(
private def nativeJoinType = NativeConverters.convertJoinType(joinType)
private def nativeBroadcastSide = broadcastSide match {
- case BroadcastLeft => pb.JoinSide.LEFT_SIDE
- case BroadcastRight => pb.JoinSide.RIGHT_SIDE
+ case JoinBuildLeft => pb.JoinSide.LEFT_SIDE
+ case JoinBuildRight => pb.JoinSide.RIGHT_SIDE
}
protected def rewriteKeyExprToLong(exprs: Seq[Expression]): Seq[Expression]
@@ -133,14 +134,14 @@ abstract class NativeBroadcastJoinBase(
val nativeJoinOn = this.nativeJoinOn
val (probedRDD, builtRDD) = broadcastSide match {
- case BroadcastLeft => (rightRDD, leftRDD)
- case BroadcastRight => (leftRDD, rightRDD)
+ case JoinBuildLeft => (rightRDD, leftRDD)
+ case JoinBuildRight => (leftRDD, rightRDD)
}
val probedShuffleReadFull = probedRDD.isShuffleReadFull && (broadcastSide
match {
- case BroadcastLeft =>
+ case JoinBuildLeft =>
Seq(FullOuter, RightOuter).contains(joinType)
- case BroadcastRight =>
+ case JoinBuildRight =>
Seq(FullOuter, LeftOuter, LeftSemi, LeftAnti).contains(joinType)
})
@@ -156,11 +157,11 @@ abstract class NativeBroadcastJoinBase(
override def index: Int = 0
}
val (leftChild, rightChild) = broadcastSide match {
- case BroadcastLeft =>
+ case JoinBuildLeft =>
(
leftRDD.nativePlan(partition0, context),
rightRDD.nativePlan(rightRDD.partitions(partition.index),
context))
- case BroadcastRight =>
+ case JoinBuildRight =>
(
leftRDD.nativePlan(leftRDD.partitions(partition.index), context),
rightRDD.nativePlan(partition0, context))
@@ -183,7 +184,3 @@ abstract class NativeBroadcastJoinBase(
friendlyName = "NativeRDD.BroadcastJoin")
}
}
-
-class BroadcastSide {}
-case object BroadcastLeft extends BroadcastSide {}
-case object BroadcastRight extends BroadcastSide {}
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeShuffledHashJoinBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeShuffledHashJoinBase.scala
index b9fc8de8..1f8a06c8 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeShuffledHashJoinBase.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeShuffledHashJoinBase.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.auron.NativeConverters
import org.apache.spark.sql.auron.NativeHelper
import org.apache.spark.sql.auron.NativeRDD
import org.apache.spark.sql.auron.NativeSupports
+import org.apache.spark.sql.auron.join.JoinBuildSides.{JoinBuildLeft,
JoinBuildRight, JoinBuildSide}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.plans.RightOuter
@@ -40,7 +41,7 @@ abstract class NativeShuffledHashJoinBase(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
joinType: JoinType,
- buildSide: BuildSide)
+ buildSide: JoinBuildSide)
extends BinaryExecNode
with NativeSupports {
@@ -79,8 +80,8 @@ abstract class NativeShuffledHashJoinBase(
private def nativeJoinType = NativeConverters.convertJoinType(joinType)
private def nativeBuildSide = buildSide match {
- case BuildLeft => pb.JoinSide.LEFT_SIDE
- case BuildRight => pb.JoinSide.RIGHT_SIDE
+ case JoinBuildLeft => pb.JoinSide.LEFT_SIDE
+ case JoinBuildRight => pb.JoinSide.RIGHT_SIDE
}
protected def rewriteKeyExprToLong(exprs: Seq[Expression]): Seq[Expression]
@@ -133,7 +134,3 @@ abstract class NativeShuffledHashJoinBase(
friendlyName = "NativeRDD.ShuffledHashJoin")
}
}
-
-class BuildSide {}
-case object BuildLeft extends BuildSide {}
-case object BuildRight extends BuildSide {}