This is an automated email from the ASF dual-hosted git repository.

philo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 3099799e02 [GLUTEN-7749][VL] Trim ISOControl characters when casting string to integral type (#7806)
3099799e02 is described below

commit 3099799e023bdd85d53fcf1e95e9a0e661adb24f
Author: Zhen Wang <[email protected]>
AuthorDate: Wed Nov 6 11:44:15 2024 +0800

    [GLUTEN-7749][VL] Trim ISOControl characters when casting string to integral type (#7806)
---
 .../backendsapi/velox/VeloxSparkPlanExecApi.scala  | 43 ++++++------
 .../apache/spark/sql/GlutenDataFrameSuite.scala    | 80 +++++++++++++---------
 2 files changed, 67 insertions(+), 56 deletions(-)

diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
index 81564a4401..7901374f6b 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
@@ -723,30 +723,29 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
     val trimParaSepStr = "\u2029"
     // Needs to be trimmed for casting to float/double/decimal
     val trimSpaceStr = ('\u0000' to '\u0020').toList.mkString
+    // ISOControl characters, refer java.lang.Character.isISOControl(int)
+    val isoControlStr = (('\u0000' to '\u001F') ++ ('\u007F' to 
'\u009F')).toList.mkString
     // scalastyle:on nonascii
-    c.dataType match {
-      case BinaryType | _: ArrayType | _: MapType | _: StructType | _: 
UserDefinedType[_] =>
-        c
-      case FloatType | DoubleType | _: DecimalType =>
-        c.child.dataType match {
-          case StringType if GlutenConfig.getConf.castFromVarcharAddTrimNode =>
-            val trimNode = StringTrim(c.child, Some(Literal(trimSpaceStr)))
-            c.withNewChildren(Seq(trimNode)).asInstanceOf[Cast]
-          case _ =>
-            c
-        }
-      case _ =>
-        c.child.dataType match {
-          case StringType if GlutenConfig.getConf.castFromVarcharAddTrimNode =>
-            val trimNode = StringTrim(
-              c.child,
-              Some(
-                Literal(trimWhitespaceStr +
-                  trimSpaceSepStr + trimLineSepStr + trimParaSepStr)))
-            c.withNewChildren(Seq(trimNode)).asInstanceOf[Cast]
-          case _ =>
-            c
+    if (GlutenConfig.getConf.castFromVarcharAddTrimNode && c.child.dataType == 
StringType) {
+      val trimStr = c.dataType match {
+        case BinaryType | _: ArrayType | _: MapType | _: StructType | _: 
UserDefinedType[_] =>
+          None
+        case FloatType | DoubleType | _: DecimalType =>
+          Some(trimSpaceStr)
+        case _ =>
+          Some(
+            (trimWhitespaceStr + trimSpaceSepStr + trimLineSepStr
+              + trimParaSepStr + isoControlStr).toSet.mkString
+          )
+      }
+      trimStr
+        .map {
+          trim =>
+            c.withNewChildren(Seq(StringTrim(c.child, 
Some(Literal(trim))))).asInstanceOf[Cast]
         }
+        .getOrElse(c)
+    } else {
+      c
     }
   }
 
diff --git a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/GlutenDataFrameSuite.scala b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/GlutenDataFrameSuite.scala
index 4008f862e1..3b2db7117f 100644
--- a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/GlutenDataFrameSuite.scala
+++ b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/GlutenDataFrameSuite.scala
@@ -16,6 +16,7 @@
  */
 package org.apache.spark.sql
 
+import org.apache.gluten.GlutenConfig
 import org.apache.gluten.execution.{ProjectExecTransformer, 
WholeStageTransformer}
 
 import org.apache.spark.SparkException
@@ -323,41 +324,52 @@ class GlutenDataFrameSuite extends DataFrameSuite with GlutenSQLTestsTrait {
   }
 
   testGluten("Allow leading/trailing whitespace in string before casting") {
-    def checkResult(df: DataFrame, expectedResult: Seq[Row]): Unit = {
-      checkAnswer(df, expectedResult)
-      
assert(find(df.queryExecution.executedPlan)(_.isInstanceOf[ProjectExecTransformer]).isDefined)
-    }
+    withSQLConf(GlutenConfig.CAST_FROM_VARCHAR_ADD_TRIM_NODE.key -> "true") {
+      def checkResult(df: DataFrame, expectedResult: Seq[Row]): Unit = {
+        checkAnswer(df, expectedResult)
+        assert(
+          
find(df.queryExecution.executedPlan)(_.isInstanceOf[ProjectExecTransformer]).isDefined)
+      }
 
-    // scalastyle:off nonascii
-    Seq(" 123", "123 ", " 123 ", "\u2000123\n\n\n", "123\r\r\r", "123\f\f\f", 
"123\u000C")
-      .toDF("col1")
-      .createOrReplaceTempView("t1")
-    // scalastyle:on nonascii
-    val expectedIntResult = Row(123) :: Row(123) ::
-      Row(123) :: Row(123) :: Row(123) :: Row(123) :: Row(123) :: Nil
-    var df = spark.sql("select cast(col1 as int) from t1")
-    checkResult(df, expectedIntResult)
-    df = spark.sql("select cast(col1 as long) from t1")
-    checkResult(df, expectedIntResult)
-
-    Seq(" 123.5", "123.5 ", " 123.5 ", "123.5\n\n\n", "123.5\r\r\r", 
"123.5\f\f\f", "123.5\u000C")
-      .toDF("col1")
-      .createOrReplaceTempView("t1")
-    val expectedFloatResult = Row(123.5) :: Row(123.5) ::
-      Row(123.5) :: Row(123.5) :: Row(123.5) :: Row(123.5) :: Row(123.5) :: Nil
-    df = spark.sql("select cast(col1 as float) from t1")
-    checkResult(df, expectedFloatResult)
-    df = spark.sql("select cast(col1 as double) from t1")
-    checkResult(df, expectedFloatResult)
-
-    // scalastyle:off nonascii
-    val rawData =
-      Seq(" abc", "abc ", " abc ", "\u2000abc\n\n\n", "abc\r\r\r", 
"abc\f\f\f", "abc\u000C")
-    // scalastyle:on nonascii
-    rawData.toDF("col1").createOrReplaceTempView("t1")
-    val expectedBinaryResult = rawData.map(d => 
Row(d.getBytes(StandardCharsets.UTF_8))).seq
-    df = spark.sql("select cast(col1 as binary) from t1")
-    checkResult(df, expectedBinaryResult)
+      // scalastyle:off nonascii
+      Seq(
+        " 123",
+        "123 ",
+        " 123 ",
+        "\u2000123\n\n\n",
+        "123\r\r\r",
+        "123\f\f\f",
+        "123\u000C",
+        "123\u0000")
+        .toDF("col1")
+        .createOrReplaceTempView("t1")
+      // scalastyle:on nonascii
+      val expectedIntResult = Row(123) :: Row(123) ::
+        Row(123) :: Row(123) :: Row(123) :: Row(123) :: Row(123) :: Row(123) 
:: Nil
+      var df = spark.sql("select cast(col1 as int) from t1")
+      checkResult(df, expectedIntResult)
+      df = spark.sql("select cast(col1 as long) from t1")
+      checkResult(df, expectedIntResult)
+
+      Seq(" 123.5", "123.5 ", " 123.5 ", "123.5\n\n\n", "123.5\r\r\r", 
"123.5\f\f\f", "123.5\u000C")
+        .toDF("col1")
+        .createOrReplaceTempView("t1")
+      val expectedFloatResult = Row(123.5) :: Row(123.5) ::
+        Row(123.5) :: Row(123.5) :: Row(123.5) :: Row(123.5) :: Row(123.5) :: 
Nil
+      df = spark.sql("select cast(col1 as float) from t1")
+      checkResult(df, expectedFloatResult)
+      df = spark.sql("select cast(col1 as double) from t1")
+      checkResult(df, expectedFloatResult)
+
+      // scalastyle:off nonascii
+      val rawData =
+        Seq(" abc", "abc ", " abc ", "\u2000abc\n\n\n", "abc\r\r\r", 
"abc\f\f\f", "abc\u000C")
+      // scalastyle:on nonascii
+      rawData.toDF("col1").createOrReplaceTempView("t1")
+      val expectedBinaryResult = rawData.map(d => 
Row(d.getBytes(StandardCharsets.UTF_8))).seq
+      df = spark.sql("select cast(col1 as binary) from t1")
+      checkResult(df, expectedBinaryResult)
+    }
   }
 
   private def withExpr(newExpr: Expression): Column = new Column(newExpr)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to