This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 22e5938aefc7 [SPARK-46946][SQL] Supporting broadcast of multiple 
filtering keys in DynamicPruning
22e5938aefc7 is described below

commit 22e5938aefc784f50218a86e013e4c2247271072
Author: Thang Long VU <long...@databricks.com>
AuthorDate: Fri Feb 2 19:55:34 2024 +0800

    [SPARK-46946][SQL] Supporting broadcast of multiple filtering keys in 
DynamicPruning
    
    ### What changes were proposed in this pull request?
    
    This PR extends `DynamicPruningSubquery` to support broadcasting
    multiple filtering keys (instead of only one, as before). Most of the PR
    simply generalises the single broadcast key index to a sequence of indices.
    
    **Note:** This PR does not yet create any `DynamicPruningSubquery` with
    multiple filtering keys. The change is made so that supporting DPP for
    null-safe equality, or for multiple equality predicates, is easier in the future.
    
    In a null-safe equality join, the join condition `key1 <=> key2` is
    transformed into `Coalesce(key1, Literal.default(key1.dataType)) =
    Coalesce(key2, Literal.default(key2.dataType)) AND IsNull(key1) = IsNull(key2)`.
    To achieve the highest pruning efficiency, we broadcast both keys,
    `Coalesce(key, Literal.default(key.dataType))` and `IsNull(key)`, and use them
    to prune the other side at the same time.
    
    Previously, `DynamicPruningSubquery` had only one broadcasting key, and DPP
    was supported only for a single `EqualTo` join predicate. This PR extends the
    subquery to multiple broadcasting keys. Note that DPP is still not supported
    for multiple join predicates.
    
    Put another way, we currently never insert a single DPP filter that covers
    multiple join predicates at once. We only potentially insert a DPP filter
    for one individual equality join predicate.
    
    ### Why are the changes needed?
    
    To make it easier to support DPP for null-safe equality predicates, or for
    multiple equality predicates, in the future.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Added unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #44988 from longvu-db/multiple-broadcast-filtering-keys.
    
    Authored-by: Thang Long VU <long...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../sql/catalyst/expressions/DynamicPruning.scala  | 12 +--
 .../expressions/DynamicPruningSubquerySuite.scala  | 89 ++++++++++++++++++++++
 .../execution/SubqueryAdaptiveBroadcastExec.scala  |  2 +-
 .../sql/execution/SubqueryBroadcastExec.scala      | 37 ++++-----
 .../PlanAdaptiveDynamicPruningFilters.scala        |  8 +-
 .../adaptive/PlanAdaptiveSubqueries.scala          |  4 +-
 .../dynamicpruning/PartitionPruning.scala          | 15 ++--
 .../dynamicpruning/PlanDynamicPruningFilters.scala |  9 ++-
 .../spark/sql/DynamicPartitionPruningSuite.scala   |  2 +-
 9 files changed, 138 insertions(+), 40 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala
index ec6925eaa984..cc24a982d5d8 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala
@@ -37,13 +37,13 @@ trait DynamicPruning extends Predicate
  *  beneficial and so it should be executed even if it cannot reuse the 
results of the
  *  broadcast through ReuseExchange; otherwise, it will use the filter only if 
it
  *  can reuse the results of the broadcast through ReuseExchange
- * @param broadcastKeyIndex the index of the filtering key collected from the 
broadcast
+ * @param broadcastKeyIndices the indices of the filtering keys collected from 
the broadcast
  */
 case class DynamicPruningSubquery(
     pruningKey: Expression,
     buildQuery: LogicalPlan,
     buildKeys: Seq[Expression],
-    broadcastKeyIndex: Int,
+    broadcastKeyIndices: Seq[Int],
     onlyInBroadcast: Boolean,
     exprId: ExprId = NamedExpression.newExprId,
     hint: Option[HintInfo] = None)
@@ -67,10 +67,12 @@ case class DynamicPruningSubquery(
       buildQuery.resolved &&
       buildKeys.nonEmpty &&
       buildKeys.forall(_.resolved) &&
-      broadcastKeyIndex >= 0 &&
-      broadcastKeyIndex < buildKeys.size &&
+      broadcastKeyIndices.forall(idx => idx >= 0 && idx < buildKeys.size) &&
       buildKeys.forall(_.references.subsetOf(buildQuery.outputSet)) &&
-      pruningKey.dataType == buildKeys(broadcastKeyIndex).dataType
+      // DynamicPruningSubquery should only have a single broadcasting key 
since
+      // there are no usage for multiple broadcasting keys at the moment.
+      broadcastKeyIndices.size == 1 &&
+      child.dataType == buildKeys(broadcastKeyIndices.head).dataType
   }
 
   final override def nodePatternsInternal(): Seq[TreePattern] = 
Seq(DYNAMIC_PRUNING_SUBQUERY)
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruningSubquerySuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruningSubquerySuite.scala
new file mode 100644
index 000000000000..9d7d756019bd
--- /dev/null
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruningSubquerySuite.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
+import org.apache.spark.sql.types.IntegerType
+
+class DynamicPruningSubquerySuite extends SparkFunSuite {
+  private val pruningKeyExpression = Literal(1)
+
+  private val validDynamicPruningSubquery = DynamicPruningSubquery(
+    pruningKey = pruningKeyExpression,
+    buildQuery = Project(Seq(AttributeReference("id", IntegerType)()),
+      LocalRelation(AttributeReference("id", IntegerType)())),
+    buildKeys = Seq(pruningKeyExpression),
+    broadcastKeyIndices = Seq(0),
+    onlyInBroadcast = false
+  )
+
+  test("pruningKey data type matches single buildKey") {
+    val dynamicPruningSubquery = validDynamicPruningSubquery
+      .copy(buildKeys = Seq(Literal(2023)))
+    assert(dynamicPruningSubquery.resolved == true)
+  }
+
+  test("pruningKey data type is a Struct and matches with Struct buildKey") {
+    val dynamicPruningSubquery = validDynamicPruningSubquery
+      .copy(pruningKey = CreateStruct(Seq(Literal(1), Literal.FalseLiteral)),
+        buildKeys = Seq(CreateStruct(Seq(Literal(2), Literal.TrueLiteral))))
+    assert(dynamicPruningSubquery.resolved == true)
+  }
+
+  test("multiple buildKeys but only one broadcastKeyIndex") {
+    val dynamicPruningSubquery = validDynamicPruningSubquery
+      .copy(buildKeys = Seq(Literal(0), Literal(2), Literal(0), Literal(9)),
+        broadcastKeyIndices = Seq(1))
+    assert(dynamicPruningSubquery.resolved == true)
+  }
+
+  test("pruningKey data type does not match the single buildKey") {
+    val dynamicPruningSubquery = validDynamicPruningSubquery.copy(
+      pruningKey = Literal.TrueLiteral,
+      buildKeys = Seq(Literal(2013)))
+    assert(dynamicPruningSubquery.resolved == false)
+  }
+
+  test("pruningKey data type is a Struct but mismatch with Struct buildKey") {
+    val dynamicPruningSubquery = validDynamicPruningSubquery
+      .copy(pruningKey = CreateStruct(Seq(Literal(1), Literal.FalseLiteral)),
+        buildKeys = Seq(CreateStruct(Seq(Literal.TrueLiteral, Literal(2)))))
+    assert(dynamicPruningSubquery.resolved == false)
+  }
+
+  test("DynamicPruningSubquery should only have a single broadcasting key") {
+    val dynamicPruningSubquery = validDynamicPruningSubquery
+      .copy(buildKeys = Seq(Literal(2025), Literal(2), Literal(1809)),
+        broadcastKeyIndices = Seq(0, 2))
+    assert(dynamicPruningSubquery.resolved == false)
+  }
+
+  test("duplicates in broadcastKeyIndices, and also should not be allowed") {
+    val dynamicPruningSubquery = validDynamicPruningSubquery
+      .copy(buildKeys = Seq(Literal(2)),
+        broadcastKeyIndices = Seq(0, 0))
+    assert(dynamicPruningSubquery.resolved == false)
+  }
+
+  test("broadcastKeyIndex out of bounds") {
+    val dynamicPruningSubquery = validDynamicPruningSubquery
+      .copy(broadcastKeyIndices = Seq(1))
+    assert(dynamicPruningSubquery.resolved == false)
+  }
+}
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SubqueryAdaptiveBroadcastExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SubqueryAdaptiveBroadcastExec.scala
index e7092ee91d76..555f4f41d3cd 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SubqueryAdaptiveBroadcastExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SubqueryAdaptiveBroadcastExec.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors
  */
 case class SubqueryAdaptiveBroadcastExec(
     name: String,
-    index: Int,
+    indices: Seq[Int],
     onlyInBroadcast: Boolean,
     @transient buildPlan: LogicalPlan,
     buildKeys: Seq[Expression],
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SubqueryBroadcastExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SubqueryBroadcastExec.scala
index 05657fe62e8e..9e7c1193c8ae 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SubqueryBroadcastExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SubqueryBroadcastExec.scala
@@ -34,32 +34,34 @@ import org.apache.spark.util.ThreadUtils
 
 /**
  * Physical plan for a custom subquery that collects and transforms the 
broadcast key values.
- * This subquery retrieves the partition key from the broadcast results based 
on the type of
- * [[HashedRelation]] returned. If the key is packed inside a Long, we extract 
it through
+ * This subquery retrieves the partition keys from the broadcast results based 
on the type of
+ * [[HashedRelation]] returned. If a key is packed inside a Long, we extract 
it through
  * bitwise operations, otherwise we return it from the appropriate index of 
the [[UnsafeRow]].
  *
- * @param index the index of the join key in the list of keys from the build 
side
+ * @param indices the indices of the join keys in the list of keys from the 
build side
  * @param buildKeys the join keys from the build side of the join used
  * @param child the BroadcastExchange or the AdaptiveSparkPlan with 
BroadcastQueryStageExec
  *              from the build side of the join
  */
 case class SubqueryBroadcastExec(
     name: String,
-    index: Int,
+    indices: Seq[Int],
     buildKeys: Seq[Expression],
     child: SparkPlan) extends BaseSubqueryExec with UnaryExecNode {
 
   // `SubqueryBroadcastExec` is only used with `InSubqueryExec`. No one would 
reference this output,
   // so the exprId doesn't matter here. But it's important to correctly report 
the output length, so
-  // that `InSubqueryExec` can know it's the single-column execution mode, not 
multi-column.
+  // that `InSubqueryExec` can know whether it's the single-column or 
multi-column execution mode.
   override def output: Seq[Attribute] = {
-    val key = buildKeys(index)
-    val name = key match {
-      case n: NamedExpression => n.name
-      case Cast(n: NamedExpression, _, _, _) => n.name
-      case _ => "key"
+    indices.map { idx =>
+      val key = buildKeys(idx)
+      val name = key match {
+        case n: NamedExpression => n.name
+        case Cast(n: NamedExpression, _, _, _) => n.name
+        case _ => s"key_$idx"
+      }
+      AttributeReference(name, key.dataType, key.nullable)()
     }
-    Seq(AttributeReference(name, key.dataType, key.nullable)())
   }
 
   override lazy val metrics = Map(
@@ -69,7 +71,7 @@ case class SubqueryBroadcastExec(
 
   override def doCanonicalize(): SparkPlan = {
     val keys = buildKeys.map(k => QueryPlan.normalizeExpressions(k, 
child.output))
-    SubqueryBroadcastExec("dpp", index, keys, child.canonicalized)
+    SubqueryBroadcastExec("dpp", indices, keys, child.canonicalized)
   }
 
   @transient
@@ -84,14 +86,15 @@ case class SubqueryBroadcastExec(
         val beforeCollect = System.nanoTime()
 
         val broadcastRelation = child.executeBroadcast[HashedRelation]().value
-        val (iter, expr) = if 
(broadcastRelation.isInstanceOf[LongHashedRelation]) {
-          (broadcastRelation.keys(), HashJoin.extractKeyExprAt(buildKeys, 
index))
+        val exprs = if (broadcastRelation.isInstanceOf[LongHashedRelation]) {
+          indices.map { idx => HashJoin.extractKeyExprAt(buildKeys, idx) }
         } else {
-          (broadcastRelation.keys(),
-            BoundReference(index, buildKeys(index).dataType, 
buildKeys(index).nullable))
+          indices.map { idx =>
+            BoundReference(idx, buildKeys(idx).dataType, 
buildKeys(idx).nullable) }
         }
 
-        val proj = UnsafeProjection.create(expr)
+        val proj = UnsafeProjection.create(exprs)
+        val iter = broadcastRelation.keys()
         val keyIter = iter.map(proj).map(_.copy())
 
         val rows = if (broadcastRelation.keyIsUnique) {
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveDynamicPruningFilters.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveDynamicPruningFilters.scala
index 9a780c11eefa..3d35abff3c53 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveDynamicPruningFilters.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveDynamicPruningFilters.scala
@@ -39,7 +39,7 @@ case class PlanAdaptiveDynamicPruningFilters(
     plan.transformAllExpressionsWithPruning(
       _.containsAllPatterns(DYNAMIC_PRUNING_EXPRESSION, IN_SUBQUERY_EXEC)) {
       case DynamicPruningExpression(InSubqueryExec(
-          value, SubqueryAdaptiveBroadcastExec(name, index, onlyInBroadcast, 
buildPlan, buildKeys,
+          value, SubqueryAdaptiveBroadcastExec(name, indices, onlyInBroadcast, 
buildPlan, buildKeys,
           adaptivePlan: AdaptiveSparkPlanExec), exprId, _, _, _)) =>
         val packedKeys = BindReferences.bindReferences(
           HashJoin.rewriteKeyExpr(buildKeys), adaptivePlan.executedPlan.output)
@@ -61,14 +61,14 @@ case class PlanAdaptiveDynamicPruningFilters(
           val newAdaptivePlan = adaptivePlan.copy(inputPlan = exchange)
 
           val broadcastValues = SubqueryBroadcastExec(
-            name, index, buildKeys, newAdaptivePlan)
+            name, indices, buildKeys, newAdaptivePlan)
           DynamicPruningExpression(InSubqueryExec(value, broadcastValues, 
exprId))
         } else if (onlyInBroadcast) {
           DynamicPruningExpression(Literal.TrueLiteral)
         } else {
           // we need to apply an aggregate on the buildPlan in order to be 
column pruned
-          val alias = Alias(buildKeys(index), buildKeys(index).toString)()
-          val aggregate = Aggregate(Seq(alias), Seq(alias), buildPlan)
+          val aliases = indices.map(idx => Alias(buildKeys(idx), 
buildKeys(idx).toString)())
+          val aggregate = Aggregate(aliases, aliases, buildPlan)
 
           val session = adaptivePlan.context.session
           val sparkPlan = QueryExecution.prepareExecutedPlan(
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala
index 7816fbd52c0a..df4d89586758 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala
@@ -47,9 +47,9 @@ case class PlanAdaptiveSubqueries(
         val subquery = SubqueryExec(s"subquery#${exprId.id}", 
subqueryMap(exprId.id))
         InSubqueryExec(expr, subquery, exprId, isDynamicPruning = false)
       case expressions.DynamicPruningSubquery(value, buildPlan,
-          buildKeys, broadcastKeyIndex, onlyInBroadcast, exprId, _) =>
+          buildKeys, broadcastKeyIndices, onlyInBroadcast, exprId, _) =>
         val name = s"dynamicpruning#${exprId.id}"
-        val subquery = SubqueryAdaptiveBroadcastExec(name, broadcastKeyIndex, 
onlyInBroadcast,
+        val subquery = SubqueryAdaptiveBroadcastExec(name, 
broadcastKeyIndices, onlyInBroadcast,
           buildPlan, buildKeys, subqueryMap(exprId.id))
         DynamicPruningExpression(InSubqueryExec(value, subquery, exprId))
     }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala
index 4e52137b7427..ef22c0ab44e4 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala
@@ -103,13 +103,16 @@ object PartitionPruning extends Rule[LogicalPlan] with 
PredicateHelper with Join
   private def insertPredicate(
       pruningKey: Expression,
       pruningPlan: LogicalPlan,
-      filteringKey: Expression,
+      filteringKeys: Seq[Expression],
       filteringPlan: LogicalPlan,
       joinKeys: Seq[Expression],
       partScan: LogicalPlan): LogicalPlan = {
     val reuseEnabled = conf.exchangeReuseEnabled
-    val index = joinKeys.indexOf(filteringKey)
-    lazy val hasBenefit = pruningHasBenefit(pruningKey, partScan, 
filteringKey, filteringPlan)
+    require(filteringKeys.size == 1, "DPP Filters should only have a single 
broadcasting key " +
+      "since there are no usage for multiple broadcasting keys at the moment.")
+    val indices = Seq(joinKeys.indexOf(filteringKeys.head))
+    lazy val hasBenefit = pruningHasBenefit(
+      pruningKey, partScan, filteringKeys.head, filteringPlan)
     if (reuseEnabled || hasBenefit) {
       // insert a DynamicPruning wrapper to identify the subquery during query 
planning
       Filter(
@@ -117,7 +120,7 @@ object PartitionPruning extends Rule[LogicalPlan] with 
PredicateHelper with Join
           pruningKey,
           filteringPlan,
           joinKeys,
-          index,
+          indices,
           conf.dynamicPartitionPruningReuseBroadcastOnly || !hasBenefit),
         pruningPlan)
     } else {
@@ -255,12 +258,12 @@ object PartitionPruning extends Rule[LogicalPlan] with 
PredicateHelper with Join
             var filterableScan = getFilterableTableScan(l, left)
             if (filterableScan.isDefined && canPruneLeft(joinType) &&
                 hasPartitionPruningFilter(right)) {
-              newLeft = insertPredicate(l, newLeft, r, right, rightKeys, 
filterableScan.get)
+              newLeft = insertPredicate(l, newLeft, Seq(r), right, rightKeys, 
filterableScan.get)
             } else {
               filterableScan = getFilterableTableScan(r, right)
               if (filterableScan.isDefined && canPruneRight(joinType) &&
                   hasPartitionPruningFilter(left) ) {
-                newRight = insertPredicate(r, newRight, l, left, leftKeys, 
filterableScan.get)
+                newRight = insertPredicate(r, newRight, Seq(l), left, 
leftKeys, filterableScan.get)
               }
             }
           case _ =>
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala
index fef92edbce64..3a08b13be013 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala
@@ -51,7 +51,7 @@ case class PlanDynamicPruningFilters(sparkSession: 
SparkSession) extends Rule[Sp
 
     
plan.transformAllExpressionsWithPruning(_.containsPattern(DYNAMIC_PRUNING_SUBQUERY))
 {
       case DynamicPruningSubquery(
-          value, buildPlan, buildKeys, broadcastKeyIndex, onlyInBroadcast, 
exprId, _) =>
+          value, buildPlan, buildKeys, broadcastKeyIndices, onlyInBroadcast, 
exprId, _) =>
         val sparkPlan = QueryExecution.createSparkPlan(
           sparkSession, sparkSession.sessionState.planner, buildPlan)
         // Using `sparkPlan` is a little hacky as it is based on the 
assumption that this rule is
@@ -73,15 +73,16 @@ case class PlanDynamicPruningFilters(sparkSession: 
SparkSession) extends Rule[Sp
           val name = s"dynamicpruning#${exprId.id}"
           // place the broadcast adaptor for reusing the broadcast results on 
the probe side
           val broadcastValues =
-            SubqueryBroadcastExec(name, broadcastKeyIndex, buildKeys, exchange)
+            SubqueryBroadcastExec(name, broadcastKeyIndices, buildKeys, 
exchange)
           DynamicPruningExpression(InSubqueryExec(value, broadcastValues, 
exprId))
         } else if (onlyInBroadcast) {
           // it is not worthwhile to execute the query, so we fall-back to a 
true literal
           DynamicPruningExpression(Literal.TrueLiteral)
         } else {
           // we need to apply an aggregate on the buildPlan in order to be 
column pruned
-          val alias = Alias(buildKeys(broadcastKeyIndex), 
buildKeys(broadcastKeyIndex).toString)()
-          val aggregate = Aggregate(Seq(alias), Seq(alias), buildPlan)
+          val aliases = broadcastKeyIndices.map(idx =>
+            Alias(buildKeys(idx), buildKeys(idx).toString)())
+          val aggregate = Aggregate(aliases, aliases, buildPlan)
           DynamicPruningExpression(expressions.InSubquery(
             Seq(value), ListQuery(aggregate, numCols = 
aggregate.output.length)))
         }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala
index 50dcb9d71897..2c24cc7d570b 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala
@@ -246,7 +246,7 @@ abstract class DynamicPartitionPruningSuiteBase
 
     val buf = 
collectDynamicPruningExpressions(df.queryExecution.executedPlan).collect {
       case InSubqueryExec(_, b: SubqueryBroadcastExec, _, _, _, _) =>
-        b.index
+        b.indices.map(idx => b.buildKeys(idx))
     }
     assert(buf.distinct.size == n)
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to