This is an automated email from the ASF dual-hosted git repository.
philo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
     new 15f4cde02b [GLUTEN-8115][CORE] Refine the BuildSideRelation transform to support all scenarios (#8116)
15f4cde02b is described below
commit 15f4cde02bf965e70bae5079c40aae78ad80e45a
Author: Kaifei Yi <[email protected]>
AuthorDate: Mon Dec 9 14:24:23 2024 +0800
    [GLUTEN-8115][CORE] Refine the BuildSideRelation transform to support all scenarios (#8116)
---
.../org/apache/gluten/utils/PlanNodesUtil.scala | 49 +++++++++++-----------
.../joins/ClickHouseBuildSideRelation.scala | 17 ++++++--
.../backendsapi/velox/VeloxSparkPlanExecApi.scala | 2 +-
.../spark/sql/execution/BroadcastUtils.scala | 6 ++-
.../sql/execution/ColumnarBuildSideRelation.scala | 38 ++++++++++-------
.../execution/ColumnarSubqueryBroadcastExec.scala | 17 +++++++-
.../sql/execution/joins/BuildSideRelation.scala | 11 ++++-
7 files changed, 90 insertions(+), 50 deletions(-)
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/PlanNodesUtil.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/PlanNodesUtil.scala
index 9dcb7ee3c4..d6511f7a4a 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/PlanNodesUtil.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/PlanNodesUtil.scala
@@ -22,13 +22,17 @@ import org.apache.gluten.substrait.expression.ExpressionNode
import org.apache.gluten.substrait.plan.{PlanBuilder, PlanNode}
import org.apache.gluten.substrait.rel.RelBuilder
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BoundReference, Expression}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Expression}
import com.google.common.collect.Lists
+import java.util
+
+import scala.collection.JavaConverters._
+
object PlanNodesUtil {
- def genProjectionsPlanNode(key: Expression, output: Seq[Attribute]): PlanNode = {
+ def genProjectionsPlanNode(key: Seq[Expression], output: Seq[Attribute]): PlanNode = {
val context = new SubstraitContext
var operatorId = context.nextOperatorId("ClickHouseBuildSideRelationReadIter")
@@ -36,41 +40,36 @@ object PlanNodesUtil {
val nameList = ConverterUtils.collectAttributeNamesWithExprId(output)
val readRel = RelBuilder.makeReadRelForInputIterator(typeList, nameList, context, operatorId)
- // replace attribute to BoundReference according to the output
- val newBoundRefKey = key.transformDown {
- case expression: AttributeReference =>
- val columnInOutput = output.zipWithIndex.filter {
- p: (Attribute, Int) => p._1.exprId == expression.exprId || p._1.name == expression.name
- }
- if (columnInOutput.isEmpty) {
- throw new IllegalStateException(
- s"Key $expression not found from build side relation output: $output")
- }
- if (columnInOutput.size != 1) {
- throw new IllegalStateException(
- s"More than one key $expression found from build side relation output: $output")
- }
- val boundReference = columnInOutput.head
- BoundReference(boundReference._2, boundReference._1.dataType, boundReference._1.nullable)
- case other => other
- }
-
// project
operatorId = context.nextOperatorId("ClickHouseBuildSideRelationProjection")
val args = context.registeredFunction
val columnarProjExpr = ExpressionConverter
- .replaceWithExpressionTransformer(newBoundRefKey, attributeSeq = output)
+ .replaceWithExpressionTransformer(key, attributeSeq = output)
val projExprNodeList = new java.util.ArrayList[ExpressionNode]()
- projExprNodeList.add(columnarProjExpr.doTransform(args))
+ columnarProjExpr.foreach(e => projExprNodeList.add(e.doTransform(args)))
PlanBuilder.makePlan(
context,
Lists.newArrayList(
RelBuilder.makeProjectRel(readRel, projExprNodeList, context, operatorId, output.size)),
- Lists.newArrayList(
- ConverterUtils.genColumnNameWithExprId(ConverterUtils.getAttrFromExpr(key)))
+ Lists.newArrayList(genColumnNameWithExprId(key, output))
)
}
+
+ private def genColumnNameWithExprId(
+ key: Seq[Expression],
+ output: Seq[Attribute]): util.List[String] = {
+ key
+ .map {
+ k =>
+ val reference = k.collectFirst { case BoundReference(ordinal, _, _) => output(ordinal) }
+ assert(reference.isDefined)
+ reference.get
+ }
+ .map(ConverterUtils.genColumnNameWithExprId)
+ .toList
+ .asJava
+ }
}
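For context on the new helper: each key is now expected to arrive already bound, and the helper walks the expression to recover the matching output attribute. Below is a minimal standalone sketch of that resolution step; the two-column schema and the key are hypothetical, not taken from the patch.

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BoundReference}
import org.apache.spark.sql.types.{IntegerType, StringType}

object BoundKeyResolutionSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical build-side output: (id: Int, name: String).
    val output = Seq(
      AttributeReference("id", IntegerType, nullable = false)(),
      AttributeReference("name", StringType, nullable = true)())

    // A key expression that has already been bound to ordinal 1 in `output`.
    val key = BoundReference(1, StringType, nullable = true)

    // collectFirst walks the expression tree and maps the ordinal back to
    // the corresponding output attribute, as genColumnNameWithExprId does per key.
    val resolved = key.collectFirst { case BoundReference(ordinal, _, _) => output(ordinal) }
    assert(resolved.isDefined)
    println(resolved.get.name) // name
  }
}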
diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/joins/ClickHouseBuildSideRelation.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/joins/ClickHouseBuildSideRelation.scala
index 92887f16d7..668525ba0a 100644
--- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/joins/ClickHouseBuildSideRelation.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/joins/ClickHouseBuildSideRelation.scala
@@ -22,8 +22,8 @@ import org.apache.gluten.vectorized._
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
-import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, UnsafeProjection}
+import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, IdentityBroadcastMode}
import org.apache.spark.sql.execution.utils.CHExecUtil
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.storage.CHShuffleReadStreamFactory
@@ -72,7 +72,7 @@ case class ClickHouseBuildSideRelation(
}
/**
- * Transform columnar broadcast value to Array[InternalRow] by key and distinct.
+ * Transform columnar broadcast value to Array[InternalRow] by key.
*
* @return
*/
@@ -80,10 +80,18 @@ case class ClickHouseBuildSideRelation(
// native block reader
val blockReader = new CHStreamReader(CHShuffleReadStreamFactory.create(batches, true))
val broadCastIter: Iterator[ColumnarBatch] = IteratorUtil.createBatchIterator(blockReader)
+
+ val transformProjections = mode match {
+ case HashedRelationBroadcastMode(k, _) => k
+ case IdentityBroadcastMode => output
+ }
+
// Expression compute, return block iterator
val expressionEval = new SimpleExpressionEval(
new ColumnarNativeIterator(broadCastIter.asJava),
- PlanNodesUtil.genProjectionsPlanNode(key, output))
+ PlanNodesUtil.genProjectionsPlanNode(transformProjections, output))
+
+ val proj = UnsafeProjection.create(Seq(key))
try {
// convert columnar to row
@@ -95,6 +103,7 @@ case class ClickHouseBuildSideRelation(
} else {
CHExecUtil
.getRowIterFromSparkRowInfo(block, batch.numColumns(), batch.numRows())
+ .map(proj)
.map(row => row.copy())
}
}.toArray
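The mode dispatch above is the heart of the change: under HashedRelationBroadcastMode the broadcast keys drive the projection, while IdentityBroadcastMode passes the relation's own output through unchanged. A minimal standalone sketch of that dispatch follows; the helper name and object are illustrative only.

import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, IdentityBroadcastMode}
import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode

object TransformProjectionsSketch {
  // Pick the expressions to project, based on how the relation was broadcast.
  def transformProjections(mode: BroadcastMode, output: Seq[Attribute]): Seq[Expression] =
    mode match {
      case HashedRelationBroadcastMode(keys, _) => keys // project by the join keys
      case IdentityBroadcastMode => output // pass rows through unchanged
    }
}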
diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
index 2df2e2718e..d837ac4234 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
@@ -632,7 +632,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
}
numOutputRows += serialized.map(_.getNumRows).sum
dataSize += rawSize
- ColumnarBuildSideRelation(child.output, serialized.map(_.getSerialized))
+ ColumnarBuildSideRelation(child.output, serialized.map(_.getSerialized), mode)
}
override def doCanonicalizeForBroadcastMode(mode: BroadcastMode): BroadcastMode = {
diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastUtils.scala b/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastUtils.scala
index 11a8cc9809..c5323d4f8d 100644
--- a/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastUtils.scala
+++ b/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastUtils.scala
@@ -106,7 +106,8 @@ object BroadcastUtils {
}
ColumnarBuildSideRelation(
SparkShimLoader.getSparkShims.attributesFromStruct(schema),
- serialized)
+ serialized,
+ mode)
}
// Rebroadcast Velox relation.
context.broadcast(toRelation).asInstanceOf[Broadcast[T]]
@@ -124,7 +125,8 @@ object BroadcastUtils {
}
ColumnarBuildSideRelation(
SparkShimLoader.getSparkShims.attributesFromStruct(schema),
- serialized)
+ serialized,
+ mode)
}
// Rebroadcast Velox relation.
context.broadcast(toRelation).asInstanceOf[Broadcast[T]]
diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala b/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala
index fa3d348967..977357990c 100644
--- a/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala
+++ b/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala
@@ -26,8 +26,11 @@ import org.apache.gluten.utils.ArrowAbiUtil
import org.apache.gluten.vectorized.{ColumnarBatchSerializerJniWrapper, NativeColumnarToRowJniWrapper}
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
+import org.apache.spark.sql.catalyst.plans.physical.IdentityBroadcastMode
import org.apache.spark.sql.execution.joins.BuildSideRelation
+import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.utils.SparkArrowUtil
import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -37,9 +40,19 @@ import org.apache.arrow.c.ArrowSchema
import scala.collection.JavaConverters.asScalaIteratorConverter
-case class ColumnarBuildSideRelation(output: Seq[Attribute], batches: Array[Array[Byte]])
+case class ColumnarBuildSideRelation(
+ output: Seq[Attribute],
+ batches: Array[Array[Byte]],
+ mode: BroadcastMode)
extends BuildSideRelation {
+ private def transformProjection: UnsafeProjection = {
+ mode match {
+ case HashedRelationBroadcastMode(k, _) => UnsafeProjection.create(k)
+ case IdentityBroadcastMode => UnsafeProjection.create(output, output)
+ }
+ }
+
override def deserialized: Iterator[ColumnarBatch] = {
val runtime =
Runtimes.contextInstance(BackendsApiManager.getBackendName, "BuildSideRelation#deserialized")
@@ -84,8 +97,11 @@ case class ColumnarBuildSideRelation(output: Seq[Attribute], batches: Array[Arra
override def asReadOnlyCopy(): ColumnarBuildSideRelation = this
/**
- * Transform columnar broadcast value to Array[InternalRow] by key and distinct. NOTE: This method
- * was called in Spark Driver, should manage resources carefully.
+ * Transform columnar broadcast value to Array[InternalRow] by key.
+ *
+ * NOTE:
+ *   - This method is called on the Spark driver; manage resources carefully.
+ *   - The "key" must already be a bound reference.
*/
override def transform(key: Expression): Array[InternalRow] = TaskResources.runUnsafe {
val runtime =
@@ -106,17 +122,7 @@ case class ColumnarBuildSideRelation(output: Seq[Attribute], batches: Array[Arra
var closed = false
- val exprIds = output.map(_.exprId)
- val projExpr = key.transformDown {
- case attr: AttributeReference if !exprIds.contains(attr.exprId) =>
- val i = output.count(_.name == attr.name)
- if (i != 1) {
- throw new IllegalArgumentException(s"Only one attr with the same
name is supported: $key")
- } else {
- output.find(_.name == attr.name).get
- }
- }
- val proj = UnsafeProjection.create(Seq(projExpr), output)
+ val proj = UnsafeProjection.create(Seq(key))
// Convert columnar to Row.
val jniWrapper = NativeColumnarToRowJniWrapper.create(runtime)
@@ -178,7 +184,7 @@ case class ColumnarBuildSideRelation(output: Seq[Attribute], batches: Array[Arra
rowId += 1
row
}
- }.map(proj).map(_.copy())
+ }.map(transformProjection).map(proj).map(_.copy())
}
}
}
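Because `key` now arrives pre-bound, transform can build its projection directly, with no attribute resolution against `output`. A minimal sketch of projecting by a bound key follows; the sample row and schema are hypothetical.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeProjection}
import org.apache.spark.sql.types.LongType

object BoundKeyProjectionSketch {
  def main(args: Array[String]): Unit = {
    // A key already bound to ordinal 0; nothing to resolve against the output schema.
    val proj = UnsafeProjection.create(Seq(BoundReference(0, LongType, nullable = false)))
    val row = InternalRow(42L, 7L)
    // The projection reuses one internal buffer, hence the .map(_.copy()) in the patch.
    println(proj(row).copy().getLong(0)) // 42
  }
}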
diff --git a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarSubqueryBroadcastExec.scala b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarSubqueryBroadcastExec.scala
index 6275fbb3aa..12280cc42a 100644
--- a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarSubqueryBroadcastExec.scala
+++ b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarSubqueryBroadcastExec.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.execution.joins.{BuildSideRelation, HashedRelation, HashJoin, LongHashedRelation}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.types.IntegralType
import org.apache.spark.util.ThreadUtils
import scala.concurrent.Future
@@ -64,6 +65,14 @@ case class ColumnarSubqueryBroadcastExec(
copy(name = "native-dpp", buildKeys = keys, child = child.canonicalized)
}
+ // Copied from org.apache.spark.sql.execution.joins.HashJoin#canRewriteAsLongType;
+ // we must stay consistent with it to identify a LongHashedRelation.
+ private def canRewriteAsLongType(keys: Seq[Expression]): Boolean = {
+ // TODO: support BooleanType, DateType and TimestampType
+ keys.forall(_.dataType.isInstanceOf[IntegralType]) &&
+ keys.map(_.dataType.defaultSize).sum <= 8
+ }
+
@transient
private lazy val relationFuture: Future[Array[InternalRow]] = {
// relationFuture is used in "doExecute". Therefore we can get the execution id correctly here.
@@ -78,7 +87,13 @@ case class ColumnarSubqueryBroadcastExec(
relation match {
case b: BuildSideRelation =>
// Transform columnar broadcast value to Array[InternalRow] by key.
- b.transform(buildKeys(index)).distinct
+ if (canRewriteAsLongType(buildKeys)) {
+ b.transform(HashJoin.extractKeyExprAt(buildKeys, index)).distinct
+ } else {
+ b.transform(
+ BoundReference(index, buildKeys(index).dataType, buildKeys(index).nullable))
+ .distinct
+ }
case h: HashedRelation =>
val (iter, expr) = if (h.isInstanceOf[LongHashedRelation]) {
(h.keys(), HashJoin.extractKeyExprAt(buildKeys, index))
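The long-type check mirrors HashJoin: every key data type must be integral, and together the keys must fit in eight bytes, which is the precondition for Spark packing keys into a LongHashedRelation. A minimal sketch of that decision with hypothetical bound keys:

import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression}
import org.apache.spark.sql.types.{IntegerType, IntegralType, LongType}

object CanRewriteAsLongTypeSketch {
  // Same predicate as the patch: all-integral keys totalling at most 8 bytes.
  def canRewriteAsLongType(keys: Seq[Expression]): Boolean =
    keys.forall(_.dataType.isInstanceOf[IntegralType]) &&
      keys.map(_.dataType.defaultSize).sum <= 8

  def main(args: Array[String]): Unit = {
    val int0 = BoundReference(0, IntegerType, nullable = false)
    val int1 = BoundReference(1, IntegerType, nullable = false)
    val long0 = BoundReference(0, LongType, nullable = false)
    println(canRewriteAsLongType(Seq(int0, int1))) // true: 4 + 4 bytes pack into one long
    println(canRewriteAsLongType(Seq(long0, int1))) // false: 8 + 4 bytes exceed 8
  }
}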
diff --git a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/joins/BuildSideRelation.scala b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/joins/BuildSideRelation.scala
index 60f3e2ffd9..e9dbeb560c 100644
--- a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/joins/BuildSideRelation.scala
+++ b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/joins/BuildSideRelation.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.joins
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
import org.apache.spark.sql.vectorized.ColumnarBatch
trait BuildSideRelation extends Serializable {
@@ -26,11 +27,19 @@ trait BuildSideRelation extends Serializable {
def deserialized: Iterator[ColumnarBatch]
/**
- * Transform columnar broadcasted value to Array[InternalRow] by key and distinct.
+ * Transform columnar broadcasted value to Array[InternalRow] by key.
* @return
*/
def transform(key: Expression): Array[InternalRow]
/** Returns a read-only copy of this, to be safely used in current thread. */
def asReadOnlyCopy(): BuildSideRelation
+
+ /**
+ * The BroadcastMode associated with this relation. Gluten broadcasts the original relation
+ * directly, so transforming the relation is by nature a post-processing step.
+ *
+ * Post-processing transforms can use this mode to produce rows in the desired format.
+ */
+ val mode: BroadcastMode
}
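What the new contract asks of implementations, in a minimal hypothetical form: the relation simply carries the BroadcastMode it was built under, so a later transform can re-apply the mode's projection as a post-process. The class below is a no-op illustration, not part of the patch.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
import org.apache.spark.sql.execution.joins.BuildSideRelation
import org.apache.spark.sql.vectorized.ColumnarBatch

// Hypothetical implementation, only to show where `mode` lives in the contract.
case class ExampleBuildSideRelation(override val mode: BroadcastMode) extends BuildSideRelation {
  override def deserialized: Iterator[ColumnarBatch] = Iterator.empty
  override def transform(key: Expression): Array[InternalRow] = Array.empty
  override def asReadOnlyCopy(): BuildSideRelation = this
}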
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]