This is an automated email from the ASF dual-hosted git repository.
richox pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/auron.git
The following commit(s) were added to refs/heads/master by this push:
new b159ce7f [AURON-1316] Support trim in cast expression (#1317)
b159ce7f is described below
commit b159ce7f187bc6793c02387036c105a46d2b48fb
Author: huang qiwei <[email protected]>
AuthorDate: Sun Sep 21 23:18:46 2025 +0800
[AURON-1316] Support trim in cast expression (#1317)
* Support trim in cast expression
* re-format
* re-format
* Support new setting
* Revert "re-format"
This reverts commit 42928bd385eb537fae334cce99a1a1fdc7ab48ff.
* Revert "re-format"
This reverts commit 2d2a0295e973ecdc7473b2376ec7f3956efb1c89.
* Revert "Support trim in cast expression"
This reverts commit dddbd8a9ef24d692e1dada5f2a4d92cb868b92ed.
* Add some comments
* re-format
* re-format
* update converter
* reformat
* Update to support boolean trim
---
.../datafusion-ext-commons/src/arrow/cast.rs | 2 +
.../spark/sql/auron/NativeConvertersSuite.scala | 83 ++++++++++++++++++++++
.../java/org/apache/spark/sql/auron/AuronConf.java | 3 +
.../apache/spark/sql/auron/NativeConverters.scala | 13 +++-
4 files changed, 100 insertions(+), 1 deletion(-)
diff --git a/native-engine/datafusion-ext-commons/src/arrow/cast.rs
b/native-engine/datafusion-ext-commons/src/arrow/cast.rs
index 9f92c570..afa759fe 100644
--- a/native-engine/datafusion-ext-commons/src/arrow/cast.rs
+++ b/native-engine/datafusion-ext-commons/src/arrow/cast.rs
@@ -284,6 +284,8 @@ fn try_cast_string_array_to_date(array: &dyn Array) ->
Result<ArrayRef> {
}
// this implementation is original copied from spark UTF8String.scala
+// The original implementation included trimming logic, but it was omitted here
+// since Auron's NativeConverters will handle trimming.
fn to_integer<T: Bounded + FromPrimitive + Integer + Signed + Copy>(input:
&str) -> Option<T> {
let bytes = input.as_bytes();
diff --git
a/spark-extension-shims-spark3/src/test/scala/org/apache/spark/sql/auron/NativeConvertersSuite.scala
b/spark-extension-shims-spark3/src/test/scala/org/apache/spark/sql/auron/NativeConvertersSuite.scala
new file mode 100644
index 00000000..d5a1a60c
--- /dev/null
+++
b/spark-extension-shims-spark3/src/test/scala/org/apache/spark/sql/auron/NativeConvertersSuite.scala
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.auron
+
+import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.catalyst.expressions.Cast
+import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.types.{BooleanType, DataType, IntegerType,
StringType}
+
+import org.apache.auron.protobuf.ScalarFunction
+
+class NativeConvertersSuite extends QueryTest with BaseAuronSQLSuite with
AuronSQLTestHelper {
+
+ private def assertTrimmedCast(rawValue: String, targetType: DataType): Unit
= {
+ val expr = Cast(Literal.create(rawValue, StringType), targetType)
+ val nativeExpr = NativeConverters.convertExpr(expr)
+
+ assert(nativeExpr.hasTryCast)
+ val childExpr = nativeExpr.getTryCast.getExpr
+ assert(childExpr.hasScalarFunction)
+ val scalarFn = childExpr.getScalarFunction
+ assert(scalarFn.getFun == ScalarFunction.Trim)
+ assert(scalarFn.getArgsCount == 1 && scalarFn.getArgs(0).hasLiteral)
+ }
+
+ private def assertNonTrimmedCast(rawValue: String, targetType: DataType):
Unit = {
+ val expr = Cast(Literal.create(rawValue, StringType), targetType)
+ val nativeExpr = NativeConverters.convertExpr(expr)
+
+ assert(nativeExpr.hasTryCast)
+ val childExpr = nativeExpr.getTryCast.getExpr
+ assert(!childExpr.hasScalarFunction)
+ assert(childExpr.hasLiteral)
+ }
+
+ test("cast from string to numeric adds trim wrapper before native cast when
enabled") {
+ withSQLConf(AuronConf.CAST_STRING_TRIM_ENABLE.key -> "true") {
+ assertTrimmedCast(" 42 ", IntegerType)
+ }
+ }
+
+ test("cast from string to boolean adds trim wrapper before native cast when
enabled") {
+ withSQLConf(AuronConf.CAST_STRING_TRIM_ENABLE.key -> "true") {
+ assertTrimmedCast(" true ", BooleanType)
+ }
+ }
+
+ test("cast trim disabled via auron conf") {
+ withEnvConf(AuronConf.CAST_STRING_TRIM_ENABLE.key -> "false") {
+ assertNonTrimmedCast(" 42 ", IntegerType)
+ }
+ }
+
+ test("cast trim disabled via auron conf for boolean cast") {
+ withEnvConf(AuronConf.CAST_STRING_TRIM_ENABLE.key -> "false") {
+ assertNonTrimmedCast(" true ", BooleanType)
+ }
+ }
+
+ test("cast with non-string child remains unchanged") {
+ val expr = Cast(Literal(1.5), IntegerType)
+ val nativeExpr = NativeConverters.convertExpr(expr)
+
+ assert(nativeExpr.hasTryCast)
+ val childExpr = nativeExpr.getTryCast.getExpr
+ assert(!childExpr.hasScalarFunction)
+ assert(childExpr.hasLiteral)
+ }
+}
diff --git
a/spark-extension/src/main/java/org/apache/spark/sql/auron/AuronConf.java
b/spark-extension/src/main/java/org/apache/spark/sql/auron/AuronConf.java
index a3d82ae0..d45ccefd 100644
--- a/spark-extension/src/main/java/org/apache/spark/sql/auron/AuronConf.java
+++ b/spark-extension/src/main/java/org/apache/spark/sql/auron/AuronConf.java
@@ -52,6 +52,9 @@ public enum AuronConf {
// TypedImperativeAggregate one row mem use size
UDAF_FALLBACK_ESTIM_ROW_SIZE("spark.auron.udafFallback.typedImperativeEstimatedRowSize",
256),
+ /// enable trimming string inputs before casting to numeric/boolean types
+ CAST_STRING_TRIM_ENABLE("spark.auron.cast.trimString", true),
+
/// ignore corrupted input files
IGNORE_CORRUPTED_FILES("spark.files.ignoreCorruptFiles", false),
diff --git
a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
index 78745a74..3ec49807 100644
---
a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
+++
b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
@@ -70,6 +70,7 @@ import org.apache.spark.sql.types.IntegerType
import org.apache.spark.sql.types.LongType
import org.apache.spark.sql.types.MapType
import org.apache.spark.sql.types.NullType
+import org.apache.spark.sql.types.NumericType
import org.apache.spark.sql.types.ShortType
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.types.StructField
@@ -447,11 +448,21 @@ object NativeConverters extends Logging {
case cast: Cast
if !Seq(cast.dataType, cast.child.dataType).exists(t =>
t.isInstanceOf[TimestampType] || t.isInstanceOf[DateType]) =>
+ val castChild =
+ if (cast.child.dataType == StringType &&
+ (cast.dataType.isInstanceOf[NumericType] || cast.dataType
+ .isInstanceOf[BooleanType]) &&
+ AuronConf.CAST_STRING_TRIM_ENABLE.booleanConf()) {
+          // converting Cast(str as num) to Cast(StringTrim(str) as num) if
enabled
+ StringTrim(cast.child)
+ } else {
+ cast.child
+ }
buildExprNode {
_.setTryCast(
pb.PhysicalTryCastNode
.newBuilder()
- .setExpr(convertExprWithFallback(cast.child, isPruningExpr,
fallback))
+ .setExpr(convertExprWithFallback(castChild, isPruningExpr,
fallback))
.setArrowType(convertDataType(cast.dataType))
.build())
}