This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.1
in repository https://gitbox.apache.org/repos/asf/spark.git

The following commit(s) were added to refs/heads/branch-3.1 by this push:
     new 687494a  [SPARK-38286][SQL] Union's maxRows and maxRowsPerPartition may overflow
687494a is described below

commit 687494a65b769c55ea3e33f09bdafccf06802fa7
Author: Ruifeng Zheng <ruife...@foxmail.com>
AuthorDate: Thu Feb 24 10:49:52 2022 +0800

[SPARK-38286][SQL] Union's maxRows and maxRowsPerPartition may overflow

### What changes were proposed in this pull request?
Check Union's maxRows and maxRowsPerPartition for overflow.

### Why are the changes needed?
Union's maxRows and maxRowsPerPartition may overflow:

case 1:
```
scala> val df1 = spark.range(0, Long.MaxValue, 1, 1)
df1: org.apache.spark.sql.Dataset[Long] = [id: bigint]

scala> val df2 = spark.range(0, 100, 1, 10)
df2: org.apache.spark.sql.Dataset[Long] = [id: bigint]

scala> val union = df1.union(df2)
union: org.apache.spark.sql.Dataset[Long] = [id: bigint]

scala> union.queryExecution.logical.maxRowsPerPartition
res19: Option[Long] = Some(-9223372036854775799)

scala> union.queryExecution.logical.maxRows
res20: Option[Long] = Some(-9223372036854775709)
```

case 2:
```
scala> val n = 2000000
n: Int = 2000000

scala> val df1 = spark.range(0, n, 1, 1).selectExpr("id % 5 as key1", "id as value1")
df1: org.apache.spark.sql.DataFrame = [key1: bigint, value1: bigint]

scala> val df2 = spark.range(0, n, 1, 2).selectExpr("id % 3 as key2", "id as value2")
df2: org.apache.spark.sql.DataFrame = [key2: bigint, value2: bigint]

scala> val df3 = spark.range(0, n, 1, 3).selectExpr("id % 4 as key3", "id as value3")
df3: org.apache.spark.sql.DataFrame = [key3: bigint, value3: bigint]

scala> val joined = df1.join(df2, col("key1") === col("key2")).join(df3, col("key1") === col("key3"))
joined: org.apache.spark.sql.DataFrame = [key1: bigint, value1: bigint ... 4 more fields]

scala> val unioned = joined.select(col("key1"), col("value3")).union(joined.select(col("key1"), col("value2")))
unioned: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [key1: bigint, value3: bigint]

scala> unioned.queryExecution.optimizedPlan.maxRows
res32: Option[Long] = Some(-2446744073709551616)

scala> unioned.queryExecution.optimizedPlan.maxRows
res33: Option[Long] = Some(-2446744073709551616)
```

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
added testsuite

Closes #35609 from zhengruifeng/union_maxRows_validate.
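For context on the negative values quoted above: summing the children's row limits in plain Long arithmetic silently wraps around in 64-bit two's complement, e.g. Long.MaxValue + 100 wraps to -9223372036854775709 (res20) and Long.MaxValue + 10 to -9223372036854775799 (res19). The following is a minimal, self-contained Scala sketch of that wrap-around and of the BigInt-guarded summation this patch switches to; the object and method names (MaxRowsOverflowSketch, naiveSum, guardedSum) are illustrative only and are not part of the Spark code base.

```
object MaxRowsOverflowSketch {
  // Summation in plain Long, as the pre-patch code effectively did via
  // children.flatMap(_.maxRows).sum: the total silently wraps around on overflow.
  def naiveSum(limits: Seq[Long]): Long = limits.sum

  // Overflow-aware summation mirroring the patch: accumulate in BigInt and
  // report None as soon as a child has no limit or the total no longer fits in a Long.
  def guardedSum(limits: Seq[Option[Long]]): Option[Long] = {
    var sum = BigInt(0)
    limits.foreach {
      case Some(n) =>
        sum += n
        if (!sum.isValidLong) {
          return None
        }
      case None =>
        return None
    }
    Some(sum.toLong)
  }

  def main(args: Array[String]): Unit = {
    // Case 1 above: one child bounded by Long.MaxValue rows, the other by 100.
    println(naiveSum(Seq(Long.MaxValue, 100L)))                // -9223372036854775709 (wrapped)
    println(guardedSum(Seq(Some(Long.MaxValue), Some(100L))))  // None: sum exceeds Long range
    println(guardedSum(Seq(Some(3L), Some(4L))))               // Some(7)
    println(guardedSum(Seq(Some(3L), None)))                   // None: a child with unknown maxRows
  }
}
```
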
Authored-by: Ruifeng Zheng <ruife...@foxmail.com>
Signed-off-by: Wenchen Fan <wenc...@databricks.com>
(cherry picked from commit 683bc46ff9a791ab6b9cd3cb95be6bbc368121e0)
Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../plans/logical/basicLogicalOperators.scala      | 30 ++++++++++++++++------
 .../sql/catalyst/plans/LogicalPlanSuite.scala      |  9 +++++++
 2 files changed, 31 insertions(+), 8 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 5149ec9..1d8d4e1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -235,22 +235,36 @@ case class Union(
   assert(!allowMissingCol || byName, "`allowMissingCol` can be true only if `byName` is true.")
 
   override def maxRows: Option[Long] = {
-    if (children.exists(_.maxRows.isEmpty)) {
-      None
-    } else {
-      Some(children.flatMap(_.maxRows).sum)
+    var sum = BigInt(0)
+    children.foreach { child =>
+      if (child.maxRows.isDefined) {
+        sum += child.maxRows.get
+        if (!sum.isValidLong) {
+          return None
+        }
+      } else {
+        return None
+      }
     }
+    Some(sum.toLong)
   }
 
   /**
    * Note the definition has assumption about how union is implemented physically.
    */
   override def maxRowsPerPartition: Option[Long] = {
-    if (children.exists(_.maxRowsPerPartition.isEmpty)) {
-      None
-    } else {
-      Some(children.flatMap(_.maxRowsPerPartition).sum)
+    var sum = BigInt(0)
+    children.foreach { child =>
+      if (child.maxRowsPerPartition.isDefined) {
+        sum += child.maxRowsPerPartition.get
+        if (!sum.isValidLong) {
+          return None
+        }
+      } else {
+        return None
+      }
     }
+    Some(sum.toLong)
   }
 
   def duplicateResolved: Boolean = {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala
index 8445239..630abca 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.plans
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Literal, NamedExpression}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.types.IntegerType
@@ -93,4 +94,12 @@ class LogicalPlanSuite extends SparkFunSuite {
       OneRowRelation())
     assert(result.sameResult(expected))
   }
+
+  test("SPARK-38286: Union's maxRows and maxRowsPerPartition may overflow") {
+    val query1 = Range(0, Long.MaxValue, 1, 1)
+    val query2 = Range(0, 100, 1, 10)
+    val query = query1.union(query2)
+    assert(query.maxRows.isEmpty)
+    assert(query.maxRowsPerPartition.isEmpty)
+  }
 }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
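As a closing note: with this patch applied, re-running case 1 from the description should report None instead of a wrapped negative value, which is what the new LogicalPlanSuite test asserts at the logical-plan level. A rough spark-shell sanity check, assuming a running `spark` session exactly as in the original repro:

```
// Case 1 revisited after the fix: the sums that previously wrapped to
// negative Longs should now be reported as unknown (None).
val df1 = spark.range(0, Long.MaxValue, 1, 1)
val df2 = spark.range(0, 100, 1, 10)
val union = df1.union(df2)

assert(union.queryExecution.logical.maxRows.isEmpty)
assert(union.queryExecution.logical.maxRowsPerPartition.isEmpty)
```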