This is an automated email from the ASF dual-hosted git repository.

philo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 15f4cde02b [GLUTEN-8115][CORE] Refine the BuildSideRelation transform 
to support all scenarios (#8116)
15f4cde02b is described below

commit 15f4cde02bf965e70bae5079c40aae78ad80e45a
Author: Kaifei Yi <[email protected]>
AuthorDate: Mon Dec 9 14:24:23 2024 +0800

    [GLUTEN-8115][CORE] Refine the BuildSideRelation transform to support all 
scenarios (#8116)
---
 .../org/apache/gluten/utils/PlanNodesUtil.scala    | 49 +++++++++++-----------
 .../joins/ClickHouseBuildSideRelation.scala        | 17 ++++++--
 .../backendsapi/velox/VeloxSparkPlanExecApi.scala  |  2 +-
 .../spark/sql/execution/BroadcastUtils.scala       |  6 ++-
 .../sql/execution/ColumnarBuildSideRelation.scala  | 38 ++++++++++-------
 .../execution/ColumnarSubqueryBroadcastExec.scala  | 17 +++++++-
 .../sql/execution/joins/BuildSideRelation.scala    | 11 ++++-
 7 files changed, 90 insertions(+), 50 deletions(-)

diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/PlanNodesUtil.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/PlanNodesUtil.scala
index 9dcb7ee3c4..d6511f7a4a 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/PlanNodesUtil.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/PlanNodesUtil.scala
@@ -22,13 +22,17 @@ import org.apache.gluten.substrait.expression.ExpressionNode
 import org.apache.gluten.substrait.plan.{PlanBuilder, PlanNode}
 import org.apache.gluten.substrait.rel.RelBuilder
 
-import org.apache.spark.sql.catalyst.expressions.{Attribute, 
AttributeReference, BoundReference, Expression}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, 
Expression}
 
 import com.google.common.collect.Lists
 
+import java.util
+
+import scala.collection.JavaConverters._
+
 object PlanNodesUtil {
 
-  def genProjectionsPlanNode(key: Expression, output: Seq[Attribute]): 
PlanNode = {
+  def genProjectionsPlanNode(key: Seq[Expression], output: Seq[Attribute]): 
PlanNode = {
     val context = new SubstraitContext
 
     var operatorId = 
context.nextOperatorId("ClickHouseBuildSideRelationReadIter")
@@ -36,41 +40,36 @@ object PlanNodesUtil {
     val nameList = ConverterUtils.collectAttributeNamesWithExprId(output)
     val readRel = RelBuilder.makeReadRelForInputIterator(typeList, nameList, 
context, operatorId)
 
-    // replace attribute to BoundRefernce according to the output
-    val newBoundRefKey = key.transformDown {
-      case expression: AttributeReference =>
-        val columnInOutput = output.zipWithIndex.filter {
-          p: (Attribute, Int) => p._1.exprId == expression.exprId || p._1.name 
== expression.name
-        }
-        if (columnInOutput.isEmpty) {
-          throw new IllegalStateException(
-            s"Key $expression not found from build side relation output: 
$output")
-        }
-        if (columnInOutput.size != 1) {
-          throw new IllegalStateException(
-            s"More than one key $expression found from build side relation 
output: $output")
-        }
-        val boundReference = columnInOutput.head
-        BoundReference(boundReference._2, boundReference._1.dataType, 
boundReference._1.nullable)
-      case other => other
-    }
-
     // project
     operatorId = 
context.nextOperatorId("ClickHouseBuildSideRelationProjection")
     val args = context.registeredFunction
 
     val columnarProjExpr = ExpressionConverter
-      .replaceWithExpressionTransformer(newBoundRefKey, attributeSeq = output)
+      .replaceWithExpressionTransformer(key, attributeSeq = output)
 
     val projExprNodeList = new java.util.ArrayList[ExpressionNode]()
-    projExprNodeList.add(columnarProjExpr.doTransform(args))
+    columnarProjExpr.foreach(e => projExprNodeList.add(e.doTransform(args)))
 
     PlanBuilder.makePlan(
       context,
       Lists.newArrayList(
         RelBuilder.makeProjectRel(readRel, projExprNodeList, context, 
operatorId, output.size)),
-      Lists.newArrayList(
-        
ConverterUtils.genColumnNameWithExprId(ConverterUtils.getAttrFromExpr(key)))
+      Lists.newArrayList(genColumnNameWithExprId(key, output))
     )
   }
+
+  private def genColumnNameWithExprId(
+      key: Seq[Expression],
+      output: Seq[Attribute]): util.List[String] = {
+    key
+      .map {
+        k =>
+          val reference = k.collectFirst { case BoundReference(ordinal, _, _) 
=> output(ordinal) }
+          assert(reference.isDefined)
+          reference.get
+      }
+      .map(ConverterUtils.genColumnNameWithExprId)
+      .toList
+      .asJava
+  }
 }
diff --git 
a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/joins/ClickHouseBuildSideRelation.scala
 
b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/joins/ClickHouseBuildSideRelation.scala
index 92887f16d7..668525ba0a 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/joins/ClickHouseBuildSideRelation.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/joins/ClickHouseBuildSideRelation.scala
@@ -22,8 +22,8 @@ import org.apache.gluten.vectorized._
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
-import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, 
UnsafeProjection}
+import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, 
IdentityBroadcastMode}
 import org.apache.spark.sql.execution.utils.CHExecUtil
 import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.storage.CHShuffleReadStreamFactory
@@ -72,7 +72,7 @@ case class ClickHouseBuildSideRelation(
   }
 
   /**
-   * Transform columnar broadcast value to Array[InternalRow] by key and 
distinct.
+   * Transform columnar broadcast value to Array[InternalRow] by key.
    *
    * @return
    */
@@ -80,10 +80,18 @@ case class ClickHouseBuildSideRelation(
     // native block reader
     val blockReader = new 
CHStreamReader(CHShuffleReadStreamFactory.create(batches, true))
     val broadCastIter: Iterator[ColumnarBatch] = 
IteratorUtil.createBatchIterator(blockReader)
+
+    val transformProjections = mode match {
+      case HashedRelationBroadcastMode(k, _) => k
+      case IdentityBroadcastMode => output
+    }
+
     // Expression compute, return block iterator
     val expressionEval = new SimpleExpressionEval(
       new ColumnarNativeIterator(broadCastIter.asJava),
-      PlanNodesUtil.genProjectionsPlanNode(key, output))
+      PlanNodesUtil.genProjectionsPlanNode(transformProjections, output))
+
+    val proj = UnsafeProjection.create(Seq(key))
 
     try {
       // convert columnar to row
@@ -95,6 +103,7 @@ case class ClickHouseBuildSideRelation(
           } else {
             CHExecUtil
               .getRowIterFromSparkRowInfo(block, batch.numColumns(), 
batch.numRows())
+              .map(proj)
               .map(row => row.copy())
           }
       }.toArray
diff --git 
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
 
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
index 2df2e2718e..d837ac4234 100644
--- 
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
+++ 
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
@@ -632,7 +632,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
     }
     numOutputRows += serialized.map(_.getNumRows).sum
     dataSize += rawSize
-    ColumnarBuildSideRelation(child.output, serialized.map(_.getSerialized))
+    ColumnarBuildSideRelation(child.output, serialized.map(_.getSerialized), 
mode)
   }
 
   override def doCanonicalizeForBroadcastMode(mode: BroadcastMode): 
BroadcastMode = {
diff --git 
a/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastUtils.scala
 
b/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastUtils.scala
index 11a8cc9809..c5323d4f8d 100644
--- 
a/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastUtils.scala
+++ 
b/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastUtils.scala
@@ -106,7 +106,8 @@ object BroadcastUtils {
           }
           ColumnarBuildSideRelation(
             SparkShimLoader.getSparkShims.attributesFromStruct(schema),
-            serialized)
+            serialized,
+            mode)
         }
         // Rebroadcast Velox relation.
         context.broadcast(toRelation).asInstanceOf[Broadcast[T]]
@@ -124,7 +125,8 @@ object BroadcastUtils {
           }
           ColumnarBuildSideRelation(
             SparkShimLoader.getSparkShims.attributesFromStruct(schema),
-            serialized)
+            serialized,
+            mode)
         }
         // Rebroadcast Velox relation.
         context.broadcast(toRelation).asInstanceOf[Broadcast[T]]
diff --git 
a/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala
 
b/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala
index fa3d348967..977357990c 100644
--- 
a/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala
+++ 
b/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala
@@ -26,8 +26,11 @@ import org.apache.gluten.utils.ArrowAbiUtil
 import org.apache.gluten.vectorized.{ColumnarBatchSerializerJniWrapper, 
NativeColumnarToRowJniWrapper}
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, 
AttributeReference, Expression, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, 
UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
+import org.apache.spark.sql.catalyst.plans.physical.IdentityBroadcastMode
 import org.apache.spark.sql.execution.joins.BuildSideRelation
+import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.utils.SparkArrowUtil
 import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -37,9 +40,19 @@ import org.apache.arrow.c.ArrowSchema
 
 import scala.collection.JavaConverters.asScalaIteratorConverter
 
-case class ColumnarBuildSideRelation(output: Seq[Attribute], batches: 
Array[Array[Byte]])
+case class ColumnarBuildSideRelation(
+    output: Seq[Attribute],
+    batches: Array[Array[Byte]],
+    mode: BroadcastMode)
   extends BuildSideRelation {
 
+  private def transformProjection: UnsafeProjection = {
+    mode match {
+      case HashedRelationBroadcastMode(k, _) => UnsafeProjection.create(k)
+      case IdentityBroadcastMode => UnsafeProjection.create(output, output)
+    }
+  }
+
   override def deserialized: Iterator[ColumnarBatch] = {
     val runtime =
       Runtimes.contextInstance(BackendsApiManager.getBackendName, 
"BuildSideRelation#deserialized")
@@ -84,8 +97,11 @@ case class ColumnarBuildSideRelation(output: Seq[Attribute], 
batches: Array[Arra
   override def asReadOnlyCopy(): ColumnarBuildSideRelation = this
 
   /**
-   * Transform columnar broadcast value to Array[InternalRow] by key and 
distinct. NOTE: This method
-   * was called in Spark Driver, should manage resources carefully.
+   * Transform columnar broadcast value to Array[InternalRow] by key.
+   *
+   * NOTE:
+   *   - This method is called in the Spark driver, so it must manage resources carefully.
+   *   - The "key" must already be a bound reference (e.g. a BoundReference).
    */
   override def transform(key: Expression): Array[InternalRow] = 
TaskResources.runUnsafe {
     val runtime =
@@ -106,17 +122,7 @@ case class ColumnarBuildSideRelation(output: 
Seq[Attribute], batches: Array[Arra
 
     var closed = false
 
-    val exprIds = output.map(_.exprId)
-    val projExpr = key.transformDown {
-      case attr: AttributeReference if !exprIds.contains(attr.exprId) =>
-        val i = output.count(_.name == attr.name)
-        if (i != 1) {
-          throw new IllegalArgumentException(s"Only one attr with the same 
name is supported: $key")
-        } else {
-          output.find(_.name == attr.name).get
-        }
-    }
-    val proj = UnsafeProjection.create(Seq(projExpr), output)
+    val proj = UnsafeProjection.create(Seq(key))
 
     // Convert columnar to Row.
     val jniWrapper = NativeColumnarToRowJniWrapper.create(runtime)
@@ -178,7 +184,7 @@ case class ColumnarBuildSideRelation(output: 
Seq[Attribute], batches: Array[Arra
                 rowId += 1
                 row
               }
-            }.map(proj).map(_.copy())
+            }.map(transformProjection).map(proj).map(_.copy())
           }
         }
       }
diff --git 
a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarSubqueryBroadcastExec.scala
 
b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarSubqueryBroadcastExec.scala
index 6275fbb3aa..12280cc42a 100644
--- 
a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarSubqueryBroadcastExec.scala
+++ 
b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarSubqueryBroadcastExec.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.execution.joins.{BuildSideRelation, 
HashedRelation, HashJoin, LongHashedRelation}
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.types.IntegralType
 import org.apache.spark.util.ThreadUtils
 
 import scala.concurrent.Future
@@ -64,6 +65,14 @@ case class ColumnarSubqueryBroadcastExec(
     copy(name = "native-dpp", buildKeys = keys, child = child.canonicalized)
   }
 
+  // Copied from org.apache.spark.sql.execution.joins.HashJoin#canRewriteAsLongType.
+  // Keep it consistent with Spark so that a LongHashedRelation is identified the same way.
+  private def canRewriteAsLongType(keys: Seq[Expression]): Boolean = {
+    // TODO: support BooleanType, DateType and TimestampType
+    keys.forall(_.dataType.isInstanceOf[IntegralType]) &&
+    keys.map(_.dataType.defaultSize).sum <= 8
+  }
+
   @transient
   private lazy val relationFuture: Future[Array[InternalRow]] = {
     // relationFuture is used in "doExecute". Therefore we can get the 
execution id correctly here.
@@ -78,7 +87,13 @@ case class ColumnarSubqueryBroadcastExec(
             relation match {
               case b: BuildSideRelation =>
                 // Transform columnar broadcast value to Array[InternalRow] by 
key.
-                b.transform(buildKeys(index)).distinct
+                if (canRewriteAsLongType(buildKeys)) {
+                  b.transform(HashJoin.extractKeyExprAt(buildKeys, 
index)).distinct
+                } else {
+                  b.transform(
+                    BoundReference(index, buildKeys(index).dataType, 
buildKeys(index).nullable))
+                    .distinct
+                }
               case h: HashedRelation =>
                 val (iter, expr) = if (h.isInstanceOf[LongHashedRelation]) {
                   (h.keys(), HashJoin.extractKeyExprAt(buildKeys, index))
diff --git 
a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/joins/BuildSideRelation.scala
 
b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/joins/BuildSideRelation.scala
index 60f3e2ffd9..e9dbeb560c 100644
--- 
a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/joins/BuildSideRelation.scala
+++ 
b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/joins/BuildSideRelation.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.joins
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
 trait BuildSideRelation extends Serializable {
@@ -26,11 +27,19 @@ trait BuildSideRelation extends Serializable {
   def deserialized: Iterator[ColumnarBatch]
 
   /**
-   * Transform columnar broadcasted value to Array[InternalRow] by key and 
distinct.
+   * Transform columnar broadcasted value to Array[InternalRow] by key.
    * @return
    */
   def transform(key: Expression): Array[InternalRow]
 
   /** Returns a read-only copy of this, to be safely used in current thread. */
   def asReadOnlyCopy(): BuildSideRelation
+
+  /**
+   * The broadcast mode associated with this relation. Gluten broadcasts the original relation
+   * directly, so transforming it into rows is a post-processing step.
+   *
+   * Post-processing transforms (e.g. `transform`) use this mode to obtain the desired row format.
+   */
+  val mode: BroadcastMode
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to