This is an automated email from the ASF dual-hosted git repository.
parthc pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 74b41aa57 fix: handle inf/-inf/nan in ShimSparkErrorConverter cast
overflow (#3768)
74b41aa57 is described below
commit 74b41aa57924e5dbb3d021221efdff20be734c8f
Author: Manu Zhang <[email protected]>
AuthorDate: Sat Mar 28 00:43:10 2026 +0800
fix: handle inf/-inf/nan in ShimSparkErrorConverter cast overflow (#3768)
Normalize inf/nan literals for float/double cast overflow conversion across
Spark 3.4/3.5/4.0 and add unit tests in SparkErrorConverterSuite for
float/double inf/-inf/nan.
Co-authored-by: Codex <[email protected]>
---
.github/workflows/pr_build_linux.yml | 1 +
.github/workflows/pr_build_macos.yml | 1 +
.../sql/comet/shims/ShimSparkErrorConverter.scala | 23 ++++-
.../sql/comet/shims/ShimSparkErrorConverter.scala | 23 ++++-
.../sql/comet/shims/ShimSparkErrorConverter.scala | 23 ++++-
.../apache/comet/SparkErrorConverterSuite.scala | 104 +++++++++++++++++++++
6 files changed, 169 insertions(+), 6 deletions(-)
diff --git a/.github/workflows/pr_build_linux.yml
b/.github/workflows/pr_build_linux.yml
index 899fa6139..6811d6c2b 100644
--- a/.github/workflows/pr_build_linux.yml
+++ b/.github/workflows/pr_build_linux.yml
@@ -338,6 +338,7 @@ jobs:
org.apache.comet.CometCsvExpressionSuite
org.apache.comet.CometJsonExpressionSuite
org.apache.comet.CometDateTimeUtilsSuite
+ org.apache.comet.SparkErrorConverterSuite
org.apache.comet.expressions.conditional.CometIfSuite
org.apache.comet.expressions.conditional.CometCoalesceSuite
org.apache.comet.expressions.conditional.CometCaseWhenSuite
diff --git a/.github/workflows/pr_build_macos.yml
b/.github/workflows/pr_build_macos.yml
index 53001b04e..8362a6cfb 100644
--- a/.github/workflows/pr_build_macos.yml
+++ b/.github/workflows/pr_build_macos.yml
@@ -213,6 +213,7 @@ jobs:
org.apache.comet.CometJsonExpressionSuite
org.apache.comet.CometCsvExpressionSuite
org.apache.comet.CometDateTimeUtilsSuite
+ org.apache.comet.SparkErrorConverterSuite
org.apache.comet.expressions.conditional.CometIfSuite
org.apache.comet.expressions.conditional.CometCoalesceSuite
org.apache.comet.expressions.conditional.CometCaseWhenSuite
diff --git
a/spark/src/main/spark-3.4/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
b/spark/src/main/spark-3.4/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
index 6eee3f5bc..ba37f8c94 100644
---
a/spark/src/main/spark-3.4/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
+++
b/spark/src/main/spark-3.4/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
@@ -44,6 +44,25 @@ trait ShimSparkErrorConverter {
private def sqlCtx(context: Array[QueryContext]): SQLQueryContext =
context.headOption.map(_.asInstanceOf[SQLQueryContext]).getOrElse(null)
+ private def parseFloatLiteral(value: String): Float = {
+ value.toLowerCase match {
+ case "inf" | "+inf" | "infinity" | "+infinity" => Float.PositiveInfinity
+ case "-inf" | "-infinity" => Float.NegativeInfinity
+ case "nan" | "+nan" | "-nan" => Float.NaN
+ case _ => value.toFloat
+ }
+ }
+
+ private def parseDoubleLiteral(value: String): Double = {
+ val normalized = value.toLowerCase.stripSuffix("d")
+ normalized match {
+ case "inf" | "+inf" | "infinity" | "+infinity" => Double.PositiveInfinity
+ case "-inf" | "-infinity" => Double.NegativeInfinity
+ case "nan" | "+nan" | "-nan" => Double.NaN
+ case _ => normalized.toDouble
+ }
+ }
+
def convertErrorType(
errorType: String,
errorClass: String,
@@ -207,8 +226,8 @@ trait ShimSparkErrorConverter {
case LongType =>
val cleanStr = if (valueStr.endsWith("L")) valueStr.dropRight(1) else valueStr
cleanStr.toLong
- case FloatType => valueStr.toFloat
- case DoubleType => valueStr.toDouble
+ case FloatType => parseFloatLiteral(valueStr)
+ case DoubleType => parseDoubleLiteral(valueStr)
case StringType => UTF8String.fromString(valueStr)
case _ => valueStr
}
diff --git
a/spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
b/spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
index 75316c51e..1d140e190 100644
---
a/spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
+++
b/spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
@@ -44,6 +44,25 @@ trait ShimSparkErrorConverter {
private def sqlCtx(context: Array[QueryContext]): SQLQueryContext =
context.headOption.map(_.asInstanceOf[SQLQueryContext]).getOrElse(null)
+ private def parseFloatLiteral(value: String): Float = {
+ value.toLowerCase match {
+ case "inf" | "+inf" | "infinity" | "+infinity" => Float.PositiveInfinity
+ case "-inf" | "-infinity" => Float.NegativeInfinity
+ case "nan" | "+nan" | "-nan" => Float.NaN
+ case _ => value.toFloat
+ }
+ }
+
+ private def parseDoubleLiteral(value: String): Double = {
+ val normalized = value.toLowerCase.stripSuffix("d")
+ normalized match {
+ case "inf" | "+inf" | "infinity" | "+infinity" => Double.PositiveInfinity
+ case "-inf" | "-infinity" => Double.NegativeInfinity
+ case "nan" | "+nan" | "-nan" => Double.NaN
+ case _ => normalized.toDouble
+ }
+ }
+
def convertErrorType(
errorType: String,
errorClass: String,
@@ -205,8 +224,8 @@ trait ShimSparkErrorConverter {
case LongType =>
val cleanStr = if (valueStr.endsWith("L")) valueStr.dropRight(1) else valueStr
cleanStr.toLong
- case FloatType => valueStr.toFloat
- case DoubleType => valueStr.toDouble
+ case FloatType => parseFloatLiteral(valueStr)
+ case DoubleType => parseDoubleLiteral(valueStr)
case StringType => UTF8String.fromString(valueStr)
case _ => valueStr
}
diff --git
a/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
b/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
index fc13a58a4..a787fb801 100644
---
a/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
+++
b/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
@@ -37,6 +37,25 @@ object ShimSparkErrorConverter {
*/
trait ShimSparkErrorConverter {
+ private def parseFloatLiteral(value: String): Float = {
+ value.toLowerCase match {
+ case "inf" | "+inf" | "infinity" | "+infinity" => Float.PositiveInfinity
+ case "-inf" | "-infinity" => Float.NegativeInfinity
+ case "nan" | "+nan" | "-nan" => Float.NaN
+ case _ => value.toFloat
+ }
+ }
+
+ private def parseDoubleLiteral(value: String): Double = {
+ val normalized = value.toLowerCase.stripSuffix("d")
+ normalized match {
+ case "inf" | "+inf" | "infinity" | "+infinity" => Double.PositiveInfinity
+ case "-inf" | "-infinity" => Double.NegativeInfinity
+ case "nan" | "+nan" | "-nan" => Double.NaN
+ case _ => normalized.toDouble
+ }
+ }
+
/**
* Convert error type string and parameters to appropriate Spark exception. Version-specific
* implementations call the correct QueryExecutionErrors.* methods.
@@ -213,8 +232,8 @@ trait ShimSparkErrorConverter {
// Strip "L" suffix for BIGINT literals
val cleanStr = if (valueStr.endsWith("L")) valueStr.dropRight(1) else valueStr
cleanStr.toLong
- case FloatType => valueStr.toFloat
- case DoubleType => valueStr.toDouble
+ case FloatType => parseFloatLiteral(valueStr)
+ case DoubleType => parseDoubleLiteral(valueStr)
case StringType => UTF8String.fromString(valueStr)
case _ => valueStr // Fallback to string
}
diff --git
a/spark/src/test/scala/org/apache/comet/SparkErrorConverterSuite.scala
b/spark/src/test/scala/org/apache/comet/SparkErrorConverterSuite.scala
new file mode 100644
index 000000000..d3e2c2c64
--- /dev/null
+++ b/spark/src/test/scala/org/apache/comet/SparkErrorConverterSuite.scala
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet
+
+import org.scalatest.funsuite.AnyFunSuite
+
+class SparkErrorConverterSuite extends AnyFunSuite {
+ private def castOverflowError(fromType: String, value: String): Throwable = {
+ SparkErrorConverter
+ .convertErrorType(
+ "CastOverFlow",
+ "CAST_OVERFLOW",
+ Map("fromType" -> fromType, "toType" -> "INT", "value" -> value),
+ Array.empty,
+ null)
+      .getOrElse(fail("Expected CastOverFlow to be converted to a Spark exception"))
+ }
+
+ private def assertCastOverflowContains(
+ fromType: String,
+ value: String,
+ expectedMessagePart: String): Unit = {
+ val err = castOverflowError(fromType, value)
+ assert(
+ !err.isInstanceOf[NumberFormatException],
+ s"Unexpected parse failure for $fromType $value")
+ assert(
+ err.getMessage.contains(expectedMessagePart),
+      s"Expected '${err.getMessage}' to contain '$expectedMessagePart' for $fromType $value")
+ }
+
+  private def assertCastOverflowContainsNaN(fromType: String, value: String): Unit = {
+ val err = castOverflowError(fromType, value)
+ assert(
+ !err.isInstanceOf[NumberFormatException],
+ s"Unexpected parse failure for $fromType $value")
+ assert(
+ err.getMessage.toLowerCase.contains("nan"),
+ s"Expected '${err.getMessage}' to contain NaN for $fromType $value")
+ }
+
+  test("CastOverFlow conversion handles all float positive infinity literals") {
+ Seq("inf", "+inf", "infinity", "+infinity").foreach { value =>
+ assertCastOverflowContains("FLOAT", value, "Infinity")
+ }
+ }
+
+  test("CastOverFlow conversion handles all float negative infinity literals") {
+ Seq("-inf", "-infinity").foreach { value =>
+ assertCastOverflowContains("FLOAT", value, "-Infinity")
+ }
+ }
+
+ test("CastOverFlow conversion handles all float NaN literals") {
+ Seq("nan", "+nan", "-nan").foreach { value =>
+ assertCastOverflowContainsNaN("FLOAT", value)
+ }
+ }
+
+  test("CastOverFlow conversion handles float standard numeric literal fallback") {
+ assertCastOverflowContains("FLOAT", "1.5", "1.5")
+ }
+
+  test("CastOverFlow conversion handles all double positive infinity literals") {
+    Seq("inf", "infd", "+inf", "+infd", "infinity", "infinityd", "+infinity", "+infinityd")
+ .foreach { value =>
+ assertCastOverflowContains("DOUBLE", value, "Infinity")
+ }
+ }
+
+  test("CastOverFlow conversion handles all double negative infinity literals") {
+ Seq("-inf", "-infd", "-infinity", "-infinityd").foreach { value =>
+ assertCastOverflowContains("DOUBLE", value, "-Infinity")
+ }
+ }
+
+ test("CastOverFlow conversion handles all double NaN literals") {
+ Seq("nan", "nand", "+nan", "+nand", "-nan", "-nand").foreach { value =>
+ assertCastOverflowContainsNaN("DOUBLE", value)
+ }
+ }
+
+  test("CastOverFlow conversion handles double standard numeric literal fallback") {
+ assertCastOverflowContains("DOUBLE", "1.5", "1.5")
+ assertCastOverflowContains("DOUBLE", "1.5d", "1.5")
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]