This is an automated email from the ASF dual-hosted git repository.
philo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 3099799e02 [GLUTEN-7749][VL] Trim ISOControl characters when casting
string to integral type (#7806)
3099799e02 is described below
commit 3099799e023bdd85d53fcf1e95e9a0e661adb24f
Author: Zhen Wang <[email protected]>
AuthorDate: Wed Nov 6 11:44:15 2024 +0800
[GLUTEN-7749][VL] Trim ISOControl characters when casting string to
integral type (#7806)
---
.../backendsapi/velox/VeloxSparkPlanExecApi.scala | 43 ++++++------
.../apache/spark/sql/GlutenDataFrameSuite.scala | 80 +++++++++++++---------
2 files changed, 67 insertions(+), 56 deletions(-)
diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
index 81564a4401..7901374f6b 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
@@ -723,30 +723,29 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
val trimParaSepStr = "\u2029"
// Needs to be trimmed for casting to float/double/decimal
val trimSpaceStr = ('\u0000' to '\u0020').toList.mkString
+ // ISOControl characters, refer java.lang.Character.isISOControl(int)
+ val isoControlStr = (('\u0000' to '\u001F') ++ ('\u007F' to
'\u009F')).toList.mkString
// scalastyle:on nonascii
- c.dataType match {
- case BinaryType | _: ArrayType | _: MapType | _: StructType | _:
UserDefinedType[_] =>
- c
- case FloatType | DoubleType | _: DecimalType =>
- c.child.dataType match {
- case StringType if GlutenConfig.getConf.castFromVarcharAddTrimNode =>
- val trimNode = StringTrim(c.child, Some(Literal(trimSpaceStr)))
- c.withNewChildren(Seq(trimNode)).asInstanceOf[Cast]
- case _ =>
- c
- }
- case _ =>
- c.child.dataType match {
- case StringType if GlutenConfig.getConf.castFromVarcharAddTrimNode =>
- val trimNode = StringTrim(
- c.child,
- Some(
- Literal(trimWhitespaceStr +
- trimSpaceSepStr + trimLineSepStr + trimParaSepStr)))
- c.withNewChildren(Seq(trimNode)).asInstanceOf[Cast]
- case _ =>
- c
+ if (GlutenConfig.getConf.castFromVarcharAddTrimNode && c.child.dataType ==
StringType) {
+ val trimStr = c.dataType match {
+ case BinaryType | _: ArrayType | _: MapType | _: StructType | _:
UserDefinedType[_] =>
+ None
+ case FloatType | DoubleType | _: DecimalType =>
+ Some(trimSpaceStr)
+ case _ =>
+ Some(
+ (trimWhitespaceStr + trimSpaceSepStr + trimLineSepStr
+ + trimParaSepStr + isoControlStr).toSet.mkString
+ )
+ }
+ trimStr
+ .map {
+ trim =>
+ c.withNewChildren(Seq(StringTrim(c.child,
Some(Literal(trim))))).asInstanceOf[Cast]
}
+ .getOrElse(c)
+ } else {
+ c
}
}
diff --git a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/GlutenDataFrameSuite.scala b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/GlutenDataFrameSuite.scala
index 4008f862e1..3b2db7117f 100644
--- a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/GlutenDataFrameSuite.scala
+++ b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/GlutenDataFrameSuite.scala
@@ -16,6 +16,7 @@
*/
package org.apache.spark.sql
+import org.apache.gluten.GlutenConfig
import org.apache.gluten.execution.{ProjectExecTransformer,
WholeStageTransformer}
import org.apache.spark.SparkException
@@ -323,41 +324,52 @@ class GlutenDataFrameSuite extends DataFrameSuite with GlutenSQLTestsTrait {
}
testGluten("Allow leading/trailing whitespace in string before casting") {
- def checkResult(df: DataFrame, expectedResult: Seq[Row]): Unit = {
- checkAnswer(df, expectedResult)
-
assert(find(df.queryExecution.executedPlan)(_.isInstanceOf[ProjectExecTransformer]).isDefined)
- }
+ withSQLConf(GlutenConfig.CAST_FROM_VARCHAR_ADD_TRIM_NODE.key -> "true") {
+ def checkResult(df: DataFrame, expectedResult: Seq[Row]): Unit = {
+ checkAnswer(df, expectedResult)
+ assert(
+
find(df.queryExecution.executedPlan)(_.isInstanceOf[ProjectExecTransformer]).isDefined)
+ }
- // scalastyle:off nonascii
- Seq(" 123", "123 ", " 123 ", "\u2000123\n\n\n", "123\r\r\r", "123\f\f\f",
"123\u000C")
- .toDF("col1")
- .createOrReplaceTempView("t1")
- // scalastyle:on nonascii
- val expectedIntResult = Row(123) :: Row(123) ::
- Row(123) :: Row(123) :: Row(123) :: Row(123) :: Row(123) :: Nil
- var df = spark.sql("select cast(col1 as int) from t1")
- checkResult(df, expectedIntResult)
- df = spark.sql("select cast(col1 as long) from t1")
- checkResult(df, expectedIntResult)
-
- Seq(" 123.5", "123.5 ", " 123.5 ", "123.5\n\n\n", "123.5\r\r\r",
"123.5\f\f\f", "123.5\u000C")
- .toDF("col1")
- .createOrReplaceTempView("t1")
- val expectedFloatResult = Row(123.5) :: Row(123.5) ::
- Row(123.5) :: Row(123.5) :: Row(123.5) :: Row(123.5) :: Row(123.5) :: Nil
- df = spark.sql("select cast(col1 as float) from t1")
- checkResult(df, expectedFloatResult)
- df = spark.sql("select cast(col1 as double) from t1")
- checkResult(df, expectedFloatResult)
-
- // scalastyle:off nonascii
- val rawData =
- Seq(" abc", "abc ", " abc ", "\u2000abc\n\n\n", "abc\r\r\r",
"abc\f\f\f", "abc\u000C")
- // scalastyle:on nonascii
- rawData.toDF("col1").createOrReplaceTempView("t1")
- val expectedBinaryResult = rawData.map(d =>
Row(d.getBytes(StandardCharsets.UTF_8))).seq
- df = spark.sql("select cast(col1 as binary) from t1")
- checkResult(df, expectedBinaryResult)
+ // scalastyle:off nonascii
+ Seq(
+ " 123",
+ "123 ",
+ " 123 ",
+ "\u2000123\n\n\n",
+ "123\r\r\r",
+ "123\f\f\f",
+ "123\u000C",
+ "123\u0000")
+ .toDF("col1")
+ .createOrReplaceTempView("t1")
+ // scalastyle:on nonascii
+ val expectedIntResult = Row(123) :: Row(123) ::
+ Row(123) :: Row(123) :: Row(123) :: Row(123) :: Row(123) :: Row(123)
:: Nil
+ var df = spark.sql("select cast(col1 as int) from t1")
+ checkResult(df, expectedIntResult)
+ df = spark.sql("select cast(col1 as long) from t1")
+ checkResult(df, expectedIntResult)
+
+ Seq(" 123.5", "123.5 ", " 123.5 ", "123.5\n\n\n", "123.5\r\r\r",
"123.5\f\f\f", "123.5\u000C")
+ .toDF("col1")
+ .createOrReplaceTempView("t1")
+ val expectedFloatResult = Row(123.5) :: Row(123.5) ::
+ Row(123.5) :: Row(123.5) :: Row(123.5) :: Row(123.5) :: Row(123.5) ::
Nil
+ df = spark.sql("select cast(col1 as float) from t1")
+ checkResult(df, expectedFloatResult)
+ df = spark.sql("select cast(col1 as double) from t1")
+ checkResult(df, expectedFloatResult)
+
+ // scalastyle:off nonascii
+ val rawData =
+ Seq(" abc", "abc ", " abc ", "\u2000abc\n\n\n", "abc\r\r\r",
"abc\f\f\f", "abc\u000C")
+ // scalastyle:on nonascii
+ rawData.toDF("col1").createOrReplaceTempView("t1")
+ val expectedBinaryResult = rawData.map(d =>
Row(d.getBytes(StandardCharsets.UTF_8))).seq
+ df = spark.sql("select cast(col1 as binary) from t1")
+ checkResult(df, expectedBinaryResult)
+ }
}
private def withExpr(newExpr: Expression): Column = new Column(newExpr)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]