This is an automated email from the ASF dual-hosted git repository.

csy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/auron.git


The following commit(s) were added to refs/heads/master by this push:
     new eeac1892 [AURON #1680] initCap semantics are aligned with Spark (#1681)
eeac1892 is described below

commit eeac18921648a2d863cac93e8b772705387b1438
Author: Thomas <[email protected]>
AuthorDate: Tue Dec 16 19:01:25 2025 +0800

    [AURON #1680] initCap semantics are aligned with Spark (#1681)
    
    # Which issue does this PR close?
    
    Closes #1680
    
    # Rationale for this change
    The current initcap implementation uses DataFusion's initcap, which does
    not match Spark's semantics. Spark uses space-only word boundaries and
    title-cases the first letter while lowercasing the rest.
    
    # What changes are included in this PR?
    + Implement a new initcap native function aligned with Spark, mirroring
    Spark's implementation logic:
    `string.asInstanceOf[UTF8String].toLowerCase.toTitleCase`.
    + Refactor and expand initcap unit tests, adding corner cases.
    
    # Are there any user-facing changes?
    Yes. initcap results will now match Spark's semantics.
    
    # How was this patch tested?
    Added unit tests covering ASCII/non-ASCII, punctuation, space-only
    boundaries, and edge cases.
---
 native-engine/auron-serde/proto/auron.proto        |   2 +-
 native-engine/auron-serde/src/from_proto.rs        |   2 +-
 native-engine/datafusion-ext-functions/src/lib.rs  |   2 +
 .../datafusion-ext-functions/src/spark_initcap.rs  | 119 +++++++++++++++++++++
 .../scala/org.apache.auron/AuronQuerySuite.scala   |  62 +++++++----
 .../apache/spark/sql/auron/NativeConverters.scala  |   4 +-
 6 files changed, 164 insertions(+), 27 deletions(-)

diff --git a/native-engine/auron-serde/proto/auron.proto b/native-engine/auron-serde/proto/auron.proto
index 29432c1e..29e9f113 100644
--- a/native-engine/auron-serde/proto/auron.proto
+++ b/native-engine/auron-serde/proto/auron.proto
@@ -229,7 +229,7 @@ enum ScalarFunction {
   ConcatWithSeparator=27;
   DatePart=28;
   DateTrunc=29;
-  InitCap=30;
+  // InitCap=30;
   Left=31;
   Lpad=32;
   Lower=33;
diff --git a/native-engine/auron-serde/src/from_proto.rs b/native-engine/auron-serde/src/from_proto.rs
index 31082e8b..0caaad6c 100644
--- a/native-engine/auron-serde/src/from_proto.rs
+++ b/native-engine/auron-serde/src/from_proto.rs
@@ -803,7 +803,7 @@ impl From<protobuf::ScalarFunction> for Arc<ScalarUDF> {
             ScalarFunction::CharacterLength => f::unicode::character_length(),
             ScalarFunction::Chr => f::string::chr(),
             ScalarFunction::ConcatWithSeparator => f::string::concat_ws(),
-            ScalarFunction::InitCap => f::unicode::initcap(),
+            // ScalarFunction::InitCap => f::unicode::initcap(),
             ScalarFunction::Left => f::unicode::left(),
             ScalarFunction::Lpad => f::unicode::lpad(),
             ScalarFunction::Random => f::math::random(),
diff --git a/native-engine/datafusion-ext-functions/src/lib.rs b/native-engine/datafusion-ext-functions/src/lib.rs
index cad5198d..e1989de4 100644
--- a/native-engine/datafusion-ext-functions/src/lib.rs
+++ b/native-engine/datafusion-ext-functions/src/lib.rs
@@ -24,6 +24,7 @@ mod spark_crypto;
 mod spark_dates;
 pub mod spark_get_json_object;
 mod spark_hash;
+mod spark_initcap;
 mod spark_isnan;
 mod spark_make_array;
 mod spark_make_decimal;
@@ -64,6 +65,7 @@ pub fn create_auron_ext_function(name: &str) -> Result<ScalarFunctionImplementat
         "Spark_StringConcatWs" => Arc::new(spark_strings::string_concat_ws),
         "Spark_StringLower" => Arc::new(spark_strings::string_lower),
         "Spark_StringUpper" => Arc::new(spark_strings::string_upper),
+        "Spark_InitCap" => Arc::new(spark_initcap::string_initcap),
         "Spark_Year" => Arc::new(spark_dates::spark_year),
         "Spark_Month" => Arc::new(spark_dates::spark_month),
         "Spark_Day" => Arc::new(spark_dates::spark_day),
diff --git a/native-engine/datafusion-ext-functions/src/spark_initcap.rs b/native-engine/datafusion-ext-functions/src/spark_initcap.rs
new file mode 100644
index 00000000..7c5218fc
--- /dev/null
+++ b/native-engine/datafusion-ext-functions/src/spark_initcap.rs
@@ -0,0 +1,119 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::sync::Arc;
+
+use arrow::array::{ArrayRef, StringArray};
+use datafusion::{
+    common::{Result, ScalarValue, cast::as_string_array},
+    logical_expr::ColumnarValue,
+};
+use datafusion_ext_commons::df_execution_err;
+
+/// Spark-compatible `initcap`: upper-cases the first character of every
+/// space-delimited word and lower-cases the rest of the string.
+///
+/// `args[0]` must be a `Utf8` array or a `Utf8` scalar; any further arguments
+/// are ignored. Null array elements and a null scalar both produce null, as
+/// Spark does for `initcap(NULL)`.
+///
+/// # Errors
+/// Returns an execution error when no argument is supplied, when the array
+/// argument cannot be downcast to a `StringArray`, or when the scalar argument
+/// is not `Utf8`.
+pub fn string_initcap(args: &[ColumnarValue]) -> Result<ColumnarValue> {
+    match args.first() {
+        Some(ColumnarValue::Array(array)) => {
+            let input_array = as_string_array(array)?;
+            let output_array =
+                StringArray::from_iter(input_array.iter().map(|s| s.map(initcap)));
+            Ok(ColumnarValue::Array(Arc::new(output_array) as ArrayRef))
+        }
+        // A null `Utf8` scalar is valid input: propagate it as null rather
+        // than falling through to the "unsupported" error below.
+        Some(ColumnarValue::Scalar(ScalarValue::Utf8(value))) => Ok(ColumnarValue::Scalar(
+            ScalarValue::Utf8(value.as_deref().map(initcap)),
+        )),
+        _ => df_execution_err!("Unsupported args {args:?} for `string_initcap`"),
+    }
+}
+
+/// Spark's `initcap` for one string: the Rust counterpart of
+/// `UTF8String.toLowerCase.toTitleCase`.
+///
+/// The whole string is lower-cased, then every character at the start of the
+/// string or directly after a space (`' '`) is title-cased. Tabs, newlines and
+/// punctuation are not word boundaries.
+fn initcap(input: &str) -> String {
+    let mut prev_is_space = true; // i == 0 or the previous char is ' '
+
+    if input.is_ascii() {
+        // ASCII fast path: both case mappings are one byte to one byte.
+        let mut out = String::with_capacity(input.len());
+        for ch in input.chars() {
+            let mapped = if prev_is_space {
+                ch.to_ascii_uppercase()
+            } else {
+                ch.to_ascii_lowercase()
+            };
+            out.push(mapped);
+            prev_is_space = ch == ' ';
+        }
+        return out;
+    }
+
+    // Lower-case the whole string in one call rather than char by char, so the
+    // context-sensitive and expanding mappings agree with Java's
+    // `String.toLowerCase`: a word-final 'Σ' becomes 'ς' and 'İ' becomes
+    // "i\u{307}". NOTE(review): assumes the JVM default locale is not Turkic.
+    let lowered = input.to_lowercase();
+    let mut out = String::with_capacity(lowered.len());
+    for ch in lowered.chars() {
+        out.push(if prev_is_space { title_case_char(ch) } else { ch });
+        prev_is_space = ch == ' ';
+    }
+    out
+}
+
+/// Approximates Java's `Character.toTitleCase(char)`, which Spark applies to
+/// the first UTF-16 code unit of each word.
+///
+/// Unlike Rust's `char::to_uppercase`, this is a one-to-one mapping: characters
+/// whose upper-case form expands to several characters (e.g. 'ß' -> "SS") and
+/// characters outside the BMP are returned unchanged.
+fn title_case_char(ch: char) -> char {
+    match ch {
+        // Spark title-cases UTF-16 code units; a surrogate maps to itself, so
+        // supplementary-plane characters are never changed.
+        _ if u32::from(ch) > 0xFFFF => ch,
+        // Latin digraphs have a dedicated title-case form (e.g. 'ǆ' -> 'ǅ').
+        '\u{01C4}'..='\u{01C6}' => '\u{01C5}',
+        '\u{01C7}'..='\u{01C9}' => '\u{01C8}',
+        '\u{01CA}'..='\u{01CC}' => '\u{01CB}',
+        '\u{01F1}'..='\u{01F3}' => '\u{01F2}',
+        // Georgian Mkhedruli letters have an upper-case (Mtavruli) form but
+        // title-case to themselves.
+        '\u{10D0}'..='\u{10FA}' | '\u{10FD}'..='\u{10FF}' => ch,
+        // Greek letters with ypogegrammeni: the simple title-case mapping adds
+        // 8 to the code point, while the full upper-case mapping expands.
+        '\u{1F80}'..='\u{1F87}' | '\u{1F90}'..='\u{1F97}' | '\u{1FA0}'..='\u{1FA7}' => {
+            char::from_u32(u32::from(ch) + 8).unwrap_or(ch)
+        }
+        '\u{1FB3}' => '\u{1FBC}',
+        '\u{1FC3}' => '\u{1FCC}',
+        '\u{1FF3}' => '\u{1FFC}',
+        _ => {
+            // Accept the upper-case form only when it is a single character.
+            let mut upper = ch.to_uppercase();
+            match (upper.next(), upper.next()) {
+                (Some(single), None) => single,
+                _ => ch,
+            }
+        }
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use std::sync::Arc;
+
+    use arrow::array::{ArrayRef, StringArray};
+    use datafusion::{
+        common::{Result, ScalarValue},
+        physical_plan::ColumnarValue,
+    };
+
+    use crate::spark_initcap::string_initcap;
+
+    /// Array input: nulls stay null, only spaces start a new word, and every
+    /// other character of a word is lower-cased.
+    #[test]
+    fn test_initcap_array() -> Result<()> {
+        let input_data = vec![
+            None,
+            Some(""),
+            Some("hI THOmAS"),
+            Some("James-Smith"),
+            Some("michael rose"),
+            Some("a1b2   c3D4"),
+            Some(" ---abc--- ABC --ABC-- a-b A B eB Ac c d"),
+            Some(" 世  界  世界 "),
+            Some("ÄBĆΔE aBćDe"),
+        ];
+        // `into_array` takes the row count; derive it instead of hard-coding it.
+        let num_rows = input_data.len();
+        let input_columnar_value =
+            ColumnarValue::Array(Arc::new(StringArray::from(input_data)));
+
+        let result = string_initcap(&[input_columnar_value])?.into_array(num_rows)?;
+
+        let expected_data = vec![
+            None,
+            Some(""),
+            Some("Hi Thomas"),
+            Some("James-smith"),
+            Some("Michael Rose"),
+            Some("A1b2   C3d4"),
+            Some(" ---abc--- Abc --abc-- A-b A B Eb Ac C D"),
+            Some(" 世  界  世界 "),
+            Some("Äbćδe Abćde"),
+        ];
+        let expected: ArrayRef = Arc::new(StringArray::from(expected_data));
+        assert_eq!(&result, &expected);
+        Ok(())
+    }
+
+    /// Scalar input yields the same result as the equivalent one-row array.
+    #[test]
+    fn test_initcap_scalar() -> Result<()> {
+        let input_columnar_value = ColumnarValue::Scalar(ScalarValue::from("abC c3D4"));
+        let result = string_initcap(&[input_columnar_value])?.into_array(1)?;
+        let expected: ArrayRef = Arc::new(StringArray::from(vec![Some("Abc C3d4")]));
+        assert_eq!(&result, &expected);
+        Ok(())
+    }
+}
diff --git a/spark-extension-shims-spark/src/test/scala/org.apache.auron/AuronQuerySuite.scala b/spark-extension-shims-spark/src/test/scala/org.apache.auron/AuronQuerySuite.scala
index aeff61b2..5d0c420e 100644
--- a/spark-extension-shims-spark/src/test/scala/org.apache.auron/AuronQuerySuite.scala
+++ b/spark-extension-shims-spark/src/test/scala/org.apache.auron/AuronQuerySuite.scala
@@ -316,37 +316,53 @@ class AuronQuerySuite extends AuronQueryTest with BaseAuronSQLSuite with AuronSQ
   }
 
   test("initcap basic") {
-    Seq(
-      ("select initcap('spark sql')", Row("Spark Sql")),
-      ("select initcap('SPARK')", Row("Spark")),
-      ("select initcap('sPaRk')", Row("Spark")),
-      ("select initcap('')", Row("")),
-      ("select initcap(null)", Row(null))).foreach { case (q, expected) =>
-      checkAnswer(sql(q), Seq(expected))
+    withTable("initcap_basic_tbl") {
+      sql(s"CREATE TABLE initcap_basic_tbl(id INT, txt STRING) USING parquet")
+      sql(s"""
+           |INSERT INTO initcap_basic_tbl VALUES
+           | (1, 'spark sql'),
+           | (2, 'SPARK'),
+           | (3, 'sPaRk'),
+           | (4, ''),
+           | (5, NULL)
+        """.stripMargin)
+      checkSparkAnswerAndOperator("select id, initcap(txt) from 
initcap_basic_tbl")
     }
   }
 
   test("initcap: word boundaries and punctuation") {
-    Seq(
-      ("select initcap('hello world')", Row("Hello World")),
-      ("select initcap('hello_world')", Row("Hello_world")),
-      ("select initcap('über-alles')", Row("Über-alles")),
-      ("select initcap('foo.bar/baz')", Row("Foo.bar/baz")),
-      ("select initcap('v2Ray is COOL')", Row("V2ray Is Cool")),
-      ("select initcap('rock''n''roll')", Row("Rocknroll")),
-      ("select initcap('hi\\tthere')", Row("Hi\tthere")),
-      ("select initcap('hi\\nthere')", Row("Hi\nthere"))).foreach { case (q, 
expected) =>
-      checkAnswer(sql(q), Seq(expected))
+    withTable("initcap_bound_tbl") {
+      sql(s"CREATE TABLE initcap_bound_tbl(id INT, txt STRING) USING parquet")
+      sql(s"""
+           |INSERT INTO initcap_bound_tbl VALUES
+           | (1, 'hello world'),
+           | (2, 'hello_world'),
+           | (3, 'über-alles'),
+           | (4, 'foo.bar/baz'),
+           | (5, 'v2Ray is COOL'),
+           | (6, 'rock''n''roll'),
+           | (7, 'hi\tthere'),
+           | (8, 'hi\nthere')
+        """.stripMargin)
+      checkSparkAnswerAndOperator("select id, initcap(txt) from 
initcap_bound_tbl")
     }
   }
 
   test("initcap: mixed cases and edge cases") {
-    Seq(
-      ("select initcap('a1b2 c3D4')", Row("A1b2 C3d4")),
-      ("select initcap('---abc---')", Row("---abc---")),
-      ("select initcap('  multiple   spaces ')", Row("  Multiple   Spaces 
"))).foreach {
-      case (q, expected) =>
-        checkAnswer(sql(q), Seq(expected))
+    withTable("initcap_mixed_tbl") {
+      sql(s"CREATE TABLE initcap_mixed_tbl(id INT, txt STRING) USING parquet")
+      sql(s"""
+           |INSERT INTO initcap_mixed_tbl VALUES
+           | (1, 'a1b2 c3D4'),
+           | (2, '---abc--- ABC --ABC-- 世界 世 界 '),
+           | (3, ' multiple   spaces '),
+           | (4, 'AbCdE aBcDe'),
+           | (5, ' A B A b '),
+           | (6, 'aBćDe  ab世De AbĆdE aB世De ÄBĆΔE'),
+           | (7, 'i\u0307onic  FIDELİO'),
+           | (8, 'a🙃B🙃c  😄 😆')
+        """.stripMargin)
+      checkSparkAnswerAndOperator("select id, initcap(txt) from 
initcap_mixed_tbl")
     }
   }
 
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
index 4838804a..2d711012 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
@@ -885,8 +885,8 @@ object NativeConverters extends Logging {
         buildExtScalarFunction("Spark_MD5", Seq(unpackBinaryTypeCast(_1)), 
StringType)
       case Reverse(_1) =>
         buildScalarFunction(pb.ScalarFunction.Reverse, 
Seq(unpackBinaryTypeCast(_1)), StringType)
-      case InitCap(_1) =>
-        buildScalarFunction(pb.ScalarFunction.InitCap, 
Seq(unpackBinaryTypeCast(_1)), StringType)
+      case e: InitCap =>
+        buildExtScalarFunction("Spark_InitCap", e.children, e.dataType)
       case Sha2(_1, Literal(224, _)) =>
         buildExtScalarFunction("Spark_Sha224", Seq(unpackBinaryTypeCast(_1)), 
StringType)
       case Sha2(_1, Literal(0, _)) =>

Reply via email to