This is an automated email from the ASF dual-hosted git repository.

richox pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/auron.git


The following commit(s) were added to refs/heads/master by this push:
     new b159ce7f [AURON-1316] Support trim in cast expression (#1317)
b159ce7f is described below

commit b159ce7f187bc6793c02387036c105a46d2b48fb
Author: huang qiwei <[email protected]>
AuthorDate: Sun Sep 21 23:18:46 2025 +0800

    [AURON-1316] Support trim in cast expression (#1317)
    
    * Support trim in cast expression
    
    * re-format
    
    * re-format
    
    * Support new setting
    
    * Revert "re-format"
    
    This reverts commit 42928bd385eb537fae334cce99a1a1fdc7ab48ff.
    
    * Revert "re-format"
    
    This reverts commit 2d2a0295e973ecdc7473b2376ec7f3956efb1c89.
    
    * Revert "Support trim in cast expression"
    
    This reverts commit dddbd8a9ef24d692e1dada5f2a4d92cb868b92ed.
    
    * Add some comments
    
    * re-format
    
    * re-format
    
    * update converter
    
    * reformat
    
    * Update to support boolean trim
---
 .../datafusion-ext-commons/src/arrow/cast.rs       |  2 +
 .../spark/sql/auron/NativeConvertersSuite.scala    | 83 ++++++++++++++++++++++
 .../java/org/apache/spark/sql/auron/AuronConf.java |  3 +
 .../apache/spark/sql/auron/NativeConverters.scala  | 13 +++-
 4 files changed, 100 insertions(+), 1 deletion(-)

diff --git a/native-engine/datafusion-ext-commons/src/arrow/cast.rs 
b/native-engine/datafusion-ext-commons/src/arrow/cast.rs
index 9f92c570..afa759fe 100644
--- a/native-engine/datafusion-ext-commons/src/arrow/cast.rs
+++ b/native-engine/datafusion-ext-commons/src/arrow/cast.rs
@@ -284,6 +284,8 @@ fn try_cast_string_array_to_date(array: &dyn Array) -> 
Result<ArrayRef> {
 }
 
 // this implementation is original copied from spark UTF8String.scala
+// The original implementation included trimming logic, but it was omitted here
+// since Auron’s NativeConverters will handle trimming.
 fn to_integer<T: Bounded + FromPrimitive + Integer + Signed + Copy>(input: 
&str) -> Option<T> {
     let bytes = input.as_bytes();
 
diff --git 
a/spark-extension-shims-spark3/src/test/scala/org/apache/spark/sql/auron/NativeConvertersSuite.scala
 
b/spark-extension-shims-spark3/src/test/scala/org/apache/spark/sql/auron/NativeConvertersSuite.scala
new file mode 100644
index 00000000..d5a1a60c
--- /dev/null
+++ 
b/spark-extension-shims-spark3/src/test/scala/org/apache/spark/sql/auron/NativeConvertersSuite.scala
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.auron
+
+import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.catalyst.expressions.Cast
+import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.types.{BooleanType, DataType, IntegerType, 
StringType}
+
+import org.apache.auron.protobuf.ScalarFunction
+
+class NativeConvertersSuite extends QueryTest with BaseAuronSQLSuite with 
AuronSQLTestHelper {
+
+  private def assertTrimmedCast(rawValue: String, targetType: DataType): Unit 
= {
+    val expr = Cast(Literal.create(rawValue, StringType), targetType)
+    val nativeExpr = NativeConverters.convertExpr(expr)
+
+    assert(nativeExpr.hasTryCast)
+    val childExpr = nativeExpr.getTryCast.getExpr
+    assert(childExpr.hasScalarFunction)
+    val scalarFn = childExpr.getScalarFunction
+    assert(scalarFn.getFun == ScalarFunction.Trim)
+    assert(scalarFn.getArgsCount == 1 && scalarFn.getArgs(0).hasLiteral)
+  }
+
+  private def assertNonTrimmedCast(rawValue: String, targetType: DataType): 
Unit = {
+    val expr = Cast(Literal.create(rawValue, StringType), targetType)
+    val nativeExpr = NativeConverters.convertExpr(expr)
+
+    assert(nativeExpr.hasTryCast)
+    val childExpr = nativeExpr.getTryCast.getExpr
+    assert(!childExpr.hasScalarFunction)
+    assert(childExpr.hasLiteral)
+  }
+
+  test("cast from string to numeric adds trim wrapper before native cast when 
enabled") {
+    withSQLConf(AuronConf.CAST_STRING_TRIM_ENABLE.key -> "true") {
+      assertTrimmedCast(" 42 ", IntegerType)
+    }
+  }
+
+  test("cast from string to boolean adds trim wrapper before native cast when 
enabled") {
+    withSQLConf(AuronConf.CAST_STRING_TRIM_ENABLE.key -> "true") {
+      assertTrimmedCast(" true ", BooleanType)
+    }
+  }
+
+  test("cast trim disabled via auron conf") {
+    withEnvConf(AuronConf.CAST_STRING_TRIM_ENABLE.key -> "false") {
+      assertNonTrimmedCast(" 42 ", IntegerType)
+    }
+  }
+
+  test("cast trim disabled via auron conf for boolean cast") {
+    withEnvConf(AuronConf.CAST_STRING_TRIM_ENABLE.key -> "false") {
+      assertNonTrimmedCast(" true ", BooleanType)
+    }
+  }
+
+  test("cast with non-string child remains unchanged") {
+    val expr = Cast(Literal(1.5), IntegerType)
+    val nativeExpr = NativeConverters.convertExpr(expr)
+
+    assert(nativeExpr.hasTryCast)
+    val childExpr = nativeExpr.getTryCast.getExpr
+    assert(!childExpr.hasScalarFunction)
+    assert(childExpr.hasLiteral)
+  }
+}
diff --git 
a/spark-extension/src/main/java/org/apache/spark/sql/auron/AuronConf.java 
b/spark-extension/src/main/java/org/apache/spark/sql/auron/AuronConf.java
index a3d82ae0..d45ccefd 100644
--- a/spark-extension/src/main/java/org/apache/spark/sql/auron/AuronConf.java
+++ b/spark-extension/src/main/java/org/apache/spark/sql/auron/AuronConf.java
@@ -52,6 +52,9 @@ public enum AuronConf {
     // TypedImperativeAggregate one row mem use size
     
UDAF_FALLBACK_ESTIM_ROW_SIZE("spark.auron.udafFallback.typedImperativeEstimatedRowSize",
 256),
 
+    /// enable trimming string inputs before casting to numeric/boolean types
+    CAST_STRING_TRIM_ENABLE("spark.auron.cast.trimString", true),
+
     /// ignore corrupted input files
     IGNORE_CORRUPTED_FILES("spark.files.ignoreCorruptFiles", false),
 
diff --git 
a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
 
b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
index 78745a74..3ec49807 100644
--- 
a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
+++ 
b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
@@ -70,6 +70,7 @@ import org.apache.spark.sql.types.IntegerType
 import org.apache.spark.sql.types.LongType
 import org.apache.spark.sql.types.MapType
 import org.apache.spark.sql.types.NullType
+import org.apache.spark.sql.types.NumericType
 import org.apache.spark.sql.types.ShortType
 import org.apache.spark.sql.types.StringType
 import org.apache.spark.sql.types.StructField
@@ -447,11 +448,21 @@ object NativeConverters extends Logging {
       case cast: Cast
           if !Seq(cast.dataType, cast.child.dataType).exists(t =>
             t.isInstanceOf[TimestampType] || t.isInstanceOf[DateType]) =>
+        val castChild =
+          if (cast.child.dataType == StringType &&
+            (cast.dataType.isInstanceOf[NumericType] || cast.dataType
+              .isInstanceOf[BooleanType]) &&
+            AuronConf.CAST_STRING_TRIM_ENABLE.booleanConf()) {
+            // converting Cast(str as num) to Cast(StringTrim(str) as num) if 
enabled
+            StringTrim(cast.child)
+          } else {
+            cast.child
+          }
         buildExprNode {
           _.setTryCast(
             pb.PhysicalTryCastNode
               .newBuilder()
-              .setExpr(convertExprWithFallback(cast.child, isPruningExpr, 
fallback))
+              .setExpr(convertExprWithFallback(castChild, isPruningExpr, 
fallback))
               .setArrowType(convertDataType(cast.dataType))
               .build())
         }

Reply via email to