This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch branch-4.1
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.1 by this push:
new 7d9320085999 [SPARK-52767][SQL] Optimize maxRows and
maxRowsPerPartition for join and union
7d9320085999 is described below
commit 7d932008599927797f7e902ed10abc466675c331
Author: zml1206 <[email protected]>
AuthorDate: Tue Nov 18 22:35:43 2025 +0800
[SPARK-52767][SQL] Optimize maxRows and maxRowsPerPartition for join and
union
### What changes were proposed in this pull request?
Ensure that `maxRows` and `maxRowsPerPartition` are calculated at most once per node.
### Why are the changes needed?
Improve performance, especially when there are dozens of joins and unions.
Before this PR, the number of `maxRows` evaluations for join/union nodes grew
exponentially with the number of joins/unions in the plan.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Local test: a 28-table join took 36s before this PR and 4s after; a 29-table
join took 67s before and 5s after.
```
Seq(1).toDF("a").write.mode("overwrite").parquet("tmp/t1")
spark.read.parquet("tmp/t1").createOrReplaceTempView("t")
val t1 = System.currentTimeMillis()
spark.sql(
"""
|select a,count(1) from (
|select t1.a from (select distinct a from t) t1
|join t t2 on t1.a=t2.a
|join t t3 on t1.a=t3.a
|join t t4 on t1.a=t4.a
|join t t5 on t1.a=t5.a
|join t t6 on t1.a=t6.a
|join t t7 on t1.a=t7.a
|join t t8 on t1.a=t8.a
|join t t9 on t1.a=t9.a
|join t t10 on t1.a=t10.a
|join t t11 on t1.a=t11.a
|join t t12 on t1.a=t12.a
|join t t13 on t1.a=t13.a
|join t t14 on t1.a=t14.a
|join t t15 on t1.a=t15.a
|join t t16 on t1.a=t16.a
|join t t17 on t1.a=t17.a
|join t t18 on t1.a=t18.a
|join t t19 on t1.a=t19.a
|join t t20 on t1.a=t20.a
|join t t21 on t1.a=t21.a
|join t t22 on t1.a=t22.a
|join t t23 on t1.a=t23.a
|join t t24 on t1.a=t24.a
|join t t25 on t1.a=t25.a
|join t t26 on t1.a=t26.a
|join t t27 on t1.a=t27.a
|join t t28 on t1.a=t28.a
|) group by a
|""".stripMargin).show
```
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #51451 from zml1206/SPARK-52767.
Authored-by: zml1206 <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
(cherry picked from commit aa387f32158a98260f7b9b16dc87feb64b504ab4)
Signed-off-by: Wenchen Fan <[email protected]>
---
.../plans/logical/basicLogicalOperators.scala | 44 +++++++++++-----------
1 file changed, 21 insertions(+), 23 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 142420ee258a..b87d018f2ab1 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -580,19 +580,18 @@ case class Union(
allowMissingCol: Boolean = false) extends UnionBase {
assert(!allowMissingCol || byName, "`allowMissingCol` can be true only if
`byName` is true.")
- override def maxRows: Option[Long] = {
- var sum = BigInt(0)
- children.foreach { child =>
- if (child.maxRows.isDefined) {
- sum += child.maxRows.get
- if (!sum.isValidLong) {
- return None
+ override lazy val maxRows: Option[Long] = {
+ val sum = children.foldLeft(Option(BigInt(0))) {
+ case (Some(acc), child) =>
+ child.maxRows match {
+ case Some(n) =>
+ val newSum = acc + n
+ if (newSum.isValidLong) Some(newSum) else None
+ case None => None
}
- } else {
- return None
- }
+ case (None, _) => None
}
- Some(sum.toLong)
+ sum.map(_.toLong)
}
final override val nodePatterns: Seq[TreePattern] = Seq(UNION)
@@ -600,19 +599,18 @@ case class Union(
/**
* Note the definition has assumption about how union is implemented
physically.
*/
- override def maxRowsPerPartition: Option[Long] = {
- var sum = BigInt(0)
- children.foreach { child =>
- if (child.maxRowsPerPartition.isDefined) {
- sum += child.maxRowsPerPartition.get
- if (!sum.isValidLong) {
- return None
+ override lazy val maxRowsPerPartition: Option[Long] = {
+ val sum = children.foldLeft(Option(BigInt(0))) {
+ case (Some(acc), child) =>
+ child.maxRowsPerPartition match {
+ case Some(n) =>
+ val newSum = acc + n
+ if (newSum.isValidLong) Some(newSum) else None
+ case None => None
}
- } else {
- return None
- }
+ case (None, _) => None
}
- Some(sum.toLong)
+ sum.map(_.toLong)
}
private def duplicatesResolvedPerBranch: Boolean =
@@ -666,7 +664,7 @@ case class Join(
hint: JoinHint)
extends BinaryNode with PredicateHelper {
- override def maxRows: Option[Long] = {
+ override lazy val maxRows: Option[Long] = {
joinType match {
case Inner | Cross | FullOuter | LeftOuter | RightOuter | LeftSingle
if left.maxRows.isDefined && right.maxRows.isDefined =>
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]