This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 9fd575ae46f [SPARK-46274][SQL] Fix Range operator computeStats() to
check long validity before converting
9fd575ae46f is described below
commit 9fd575ae46f8a4dbd7da18887a44c693d8788332
Author: Nick Young <[email protected]>
AuthorDate: Wed Dec 6 15:20:19 2023 -0800
[SPARK-46274][SQL] Fix Range operator computeStats() to check long validity
before converting
### What changes were proposed in this pull request?
The Range operator's `computeStats()` function unsafely converts a `BigInt` to a
`Long`, which causes issues downstream in statistics estimation. This PR adds a
bounds check (`isValidLong`) before converting, to avoid crashing.
### Why are the changes needed?
Without the bounds check, downstream statistics estimation crashes and fails
loudly; fixing this avoids the crash and helps keep the code clean.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Added a unit test (`range with invalid long value`) to `BasicStatsEstimationSuite`.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #44191 from n-young-db/range-compute-stats.
Authored-by: Nick Young <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../catalyst/plans/logical/basicLogicalOperators.scala | 12 +++++++-----
.../statsEstimation/BasicStatsEstimationSuite.scala | 16 ++++++++++++++++
2 files changed, 23 insertions(+), 5 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index c66ead30ab3..497f485b67f 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -1083,10 +1083,12 @@ case class Range(
if (numElements == 0) {
Statistics(sizeInBytes = 0, rowCount = Some(0))
} else {
- val (minVal, maxVal) = if (step > 0) {
- (start, start + (numElements - 1) * step)
+ val (minVal, maxVal) = if (!numElements.isValidLong) {
+ (None, None)
+ } else if (step > 0) {
+ (Some(start), Some(start + (numElements.toLong - 1) * step))
} else {
- (start + (numElements - 1) * step, start)
+ (Some(start + (numElements.toLong - 1) * step), Some(start))
}
val histogram = if (conf.histogramEnabled) {
@@ -1097,8 +1099,8 @@ case class Range(
val colStat = ColumnStat(
distinctCount = Some(numElements),
- max = Some(maxVal),
- min = Some(minVal),
+ max = maxVal,
+ min = minVal,
nullCount = Some(0),
avgLen = Some(LongType.defaultSize),
maxLen = Some(LongType.defaultSize),
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala
index e421d5f3929..63410447948 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala
@@ -177,6 +177,22 @@ class BasicStatsEstimationSuite extends PlanTest with
StatsEstimationTestBase {
expectedStatsCboOff = rangeStats, extraConfig)
}
+test("range with invalid long value") {
+ val numElements = BigInt(Long.MaxValue) - BigInt(Long.MinValue)
+ val range = Range(Long.MinValue, Long.MaxValue, 1, None)
+ val rangeAttrs = AttributeMap(range.output.map(attr =>
+ (attr, ColumnStat(
+ distinctCount = Some(numElements),
+ nullCount = Some(0),
+ maxLen = Some(LongType.defaultSize),
+ avgLen = Some(LongType.defaultSize)))))
+ val rangeStats = Statistics(
+ sizeInBytes = numElements * 8,
+ rowCount = Some(numElements),
+ attributeStats = rangeAttrs)
+ checkStats(range, rangeStats, rangeStats)
+}
+
test("windows") {
val windows = plan.window(Seq(min(attribute).as("sum_attr")),
Seq(attribute), Nil)
val windowsStats = Statistics(sizeInBytes = plan.size.get * (4 + 4 + 8) /
(4 + 8))
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]