This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new 9d54a95  [SPARK-38286][SQL] Union's maxRows and maxRowsPerPartition may overflow
9d54a95 is described below

commit 9d54a95d1cd64085ffb39b68931e88990e86aa07
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Thu Feb 24 10:49:52 2022 +0800

    [SPARK-38286][SQL] Union's maxRows and maxRowsPerPartition may overflow
    
    check Union's maxRows and maxRowsPerPartition
    
    Union's maxRows and maxRowsPerPartition may overflow:
    
    case 1:
    ```
    scala> val df1 = spark.range(0, Long.MaxValue, 1, 1)
    df1: org.apache.spark.sql.Dataset[Long] = [id: bigint]
    
    scala> val df2 = spark.range(0, 100, 1, 10)
    df2: org.apache.spark.sql.Dataset[Long] = [id: bigint]
    
    scala> val union = df1.union(df2)
    union: org.apache.spark.sql.Dataset[Long] = [id: bigint]
    
    scala> union.queryExecution.logical.maxRowsPerPartition
    res19: Option[Long] = Some(-9223372036854775799)
    
    scala> union.queryExecution.logical.maxRows
    res20: Option[Long] = Some(-9223372036854775709)
    ```
    
    case 2:
    ```
    scala> val n = 2000000
    n: Int = 2000000
    
    scala> val df1 = spark.range(0, n, 1, 1).selectExpr("id % 5 as key1", "id as value1")
    df1: org.apache.spark.sql.DataFrame = [key1: bigint, value1: bigint]
    
    scala> val df2 = spark.range(0, n, 1, 2).selectExpr("id % 3 as key2", "id as value2")
    df2: org.apache.spark.sql.DataFrame = [key2: bigint, value2: bigint]
    
    scala> val df3 = spark.range(0, n, 1, 3).selectExpr("id % 4 as key3", "id as value3")
    df3: org.apache.spark.sql.DataFrame = [key3: bigint, value3: bigint]
    
    scala> val joined = df1.join(df2, col("key1") === col("key2")).join(df3, col("key1") === col("key3"))
    joined: org.apache.spark.sql.DataFrame = [key1: bigint, value1: bigint ... 4 more fields]
    
    scala> val unioned = joined.select(col("key1"), col("value3")).union(joined.select(col("key1"), col("value2")))
    unioned: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [key1: bigint, value3: bigint]
    
    scala> unioned.queryExecution.optimizedPlan.maxRows
    res32: Option[Long] = Some(-2446744073709551616)
    
    scala> unioned.queryExecution.optimizedPlan.maxRows
    res33: Option[Long] = Some(-2446744073709551616)
    ```
    
    No
    
    Added a test suite.
    
    Closes #35609 from zhengruifeng/union_maxRows_validate.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
    (cherry picked from commit 683bc46ff9a791ab6b9cd3cb95be6bbc368121e0)
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../plans/logical/basicLogicalOperators.scala      | 30 ++++++++++++++++------
 .../sql/catalyst/plans/LogicalPlanSuite.scala      |  9 +++++++
 2 files changed, 31 insertions(+), 8 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 4153c84..1aa1e79 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -222,22 +222,36 @@ object Union {
  */
 case class Union(children: Seq[LogicalPlan]) extends LogicalPlan {
   override def maxRows: Option[Long] = {
-    if (children.exists(_.maxRows.isEmpty)) {
-      None
-    } else {
-      Some(children.flatMap(_.maxRows).sum)
+    var sum = BigInt(0)
+    children.foreach { child =>
+      if (child.maxRows.isDefined) {
+        sum += child.maxRows.get
+        if (!sum.isValidLong) {
+          return None
+        }
+      } else {
+        return None
+      }
     }
+    Some(sum.toLong)
   }
 
   /**
    * Note the definition has assumption about how union is implemented physically.
    */
   override def maxRowsPerPartition: Option[Long] = {
-    if (children.exists(_.maxRowsPerPartition.isEmpty)) {
-      None
-    } else {
-      Some(children.flatMap(_.maxRowsPerPartition).sum)
+    var sum = BigInt(0)
+    children.foreach { child =>
+      if (child.maxRowsPerPartition.isDefined) {
+        sum += child.maxRowsPerPartition.get
+        if (!sum.isValidLong) {
+          return None
+        }
+      } else {
+        return None
+      }
     }
+    Some(sum.toLong)
   }
 
   def duplicateResolved: Boolean = {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala
index 8445239..630abca 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.plans
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Literal, NamedExpression}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.types.IntegerType
@@ -93,4 +94,12 @@ class LogicalPlanSuite extends SparkFunSuite {
       OneRowRelation())
     assert(result.sameResult(expected))
   }
+
+  test("SPARK-38286: Union's maxRows and maxRowsPerPartition may overflow") {
+    val query1 = Range(0, Long.MaxValue, 1, 1)
+    val query2 = Range(0, 100, 1, 10)
+    val query = query1.union(query2)
+    assert(query.maxRows.isEmpty)
+    assert(query.maxRowsPerPartition.isEmpty)
+  }
 }

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to