This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 9fd575ae46f [SPARK-46274][SQL] Fix Range operator computeStats() to check long validity before converting
9fd575ae46f is described below

commit 9fd575ae46f8a4dbd7da18887a44c693d8788332
Author: Nick Young <nick.yo...@databricks.com>
AuthorDate: Wed Dec 6 15:20:19 2023 -0800

    [SPARK-46274][SQL] Fix Range operator computeStats() to check long validity before converting
    
    ### What changes were proposed in this pull request?
    
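    The Range operator's `computeStats()` function unsafely converts a `BigInt` element count to `Long`, which causes problems downstream in statistics estimation. This change adds a bounds check (`isValidLong`) before converting and leaves the column min/max undefined when the count does not fit in a `Long`, so estimation no longer crashes.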
    
    ### Why are the changes needed?
    
    Without the fix, downstream statistics estimation crashes and fails loudly. Fixing this avoids the crash and helps keep the code clean.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    A unit test (`range with invalid long value`) was added to `BasicStatsEstimationSuite`.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #44191 from n-young-db/range-compute-stats.
    
    Authored-by: Nick Young <nick.yo...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../catalyst/plans/logical/basicLogicalOperators.scala   | 12 +++++++-----
 .../statsEstimation/BasicStatsEstimationSuite.scala      | 16 ++++++++++++++++
 2 files changed, 23 insertions(+), 5 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index c66ead30ab3..497f485b67f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -1083,10 +1083,12 @@ case class Range(
     if (numElements == 0) {
       Statistics(sizeInBytes = 0, rowCount = Some(0))
     } else {
-      val (minVal, maxVal) = if (step > 0) {
-        (start, start + (numElements - 1) * step)
+      val (minVal, maxVal) = if (!numElements.isValidLong) {
+        (None, None)
+      } else if (step > 0) {
+        (Some(start), Some(start + (numElements.toLong - 1) * step))
       } else {
-        (start + (numElements - 1) * step, start)
+        (Some(start + (numElements.toLong - 1) * step), Some(start))
       }
 
       val histogram = if (conf.histogramEnabled) {
@@ -1097,8 +1099,8 @@ case class Range(
 
       val colStat = ColumnStat(
         distinctCount = Some(numElements),
-        max = Some(maxVal),
-        min = Some(minVal),
+        max = maxVal,
+        min = minVal,
         nullCount = Some(0),
         avgLen = Some(LongType.defaultSize),
         maxLen = Some(LongType.defaultSize),
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala
index e421d5f3929..63410447948 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala
@@ -177,6 +177,22 @@ class BasicStatsEstimationSuite extends PlanTest with StatsEstimationTestBase {
         expectedStatsCboOff = rangeStats, extraConfig)
   }
 
+test("range with invalid long value") {
+  val numElements = BigInt(Long.MaxValue) - BigInt(Long.MinValue)
+  val range = Range(Long.MinValue, Long.MaxValue, 1, None)
+  val rangeAttrs = AttributeMap(range.output.map(attr =>
+    (attr, ColumnStat(
+      distinctCount = Some(numElements),
+      nullCount = Some(0),
+      maxLen = Some(LongType.defaultSize),
+      avgLen = Some(LongType.defaultSize)))))
+  val rangeStats = Statistics(
+    sizeInBytes = numElements * 8,
+    rowCount = Some(numElements),
+    attributeStats = rangeAttrs)
+  checkStats(range, rangeStats, rangeStats)
+}
+
   test("windows") {
     val windows = plan.window(Seq(min(attribute).as("sum_attr")), 
Seq(attribute), Nil)
     val windowsStats = Statistics(sizeInBytes = plan.size.get * (4 + 4 + 8) / 
(4 + 8))


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to