This is an automated email from the ASF dual-hosted git repository.
csy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/auron.git
The following commit(s) were added to refs/heads/master by this push:
new eeac1892 [AURON #1680] initCap semantics are aligned with Spark (#1681)
eeac1892 is described below
commit eeac18921648a2d863cac93e8b772705387b1438
Author: Thomas <[email protected]>
AuthorDate: Tue Dec 16 19:01:25 2025 +0800
[AURON #1680] initCap semantics are aligned with Spark (#1681)
# Which issue does this PR close?
Closes #1680
# Rationale for this change
The current initcap implementation uses DataFusion's initcap, which does
not match Spark's semantics. Spark uses space-only word boundaries and
title-cases the first letter while lowercasing the rest.
# What changes are included in this PR?
+ Implement a new initcap native function aligned with Spark, similar to
Spark's implementation logic: `
string.asInstanceOf[UTF8String].toLowerCase.toTitleCase`.
+ Refactor and expand initcap unit tests, adding corner cases.
# Are there any user-facing changes?
Yes. initcap results will now match Spark's semantics.
# How was this patch tested?
Added unit tests covering ASCII/non-ASCII, punctuation, space-only
boundaries, and edge cases.
---
native-engine/auron-serde/proto/auron.proto | 2 +-
native-engine/auron-serde/src/from_proto.rs | 2 +-
native-engine/datafusion-ext-functions/src/lib.rs | 2 +
.../datafusion-ext-functions/src/spark_initcap.rs | 119 +++++++++++++++++++++
.../scala/org.apache.auron/AuronQuerySuite.scala | 62 +++++++----
.../apache/spark/sql/auron/NativeConverters.scala | 4 +-
6 files changed, 164 insertions(+), 27 deletions(-)
diff --git a/native-engine/auron-serde/proto/auron.proto b/native-engine/auron-serde/proto/auron.proto
index 29432c1e..29e9f113 100644
--- a/native-engine/auron-serde/proto/auron.proto
+++ b/native-engine/auron-serde/proto/auron.proto
@@ -229,7 +229,7 @@ enum ScalarFunction {
ConcatWithSeparator=27;
DatePart=28;
DateTrunc=29;
- InitCap=30;
+ // InitCap=30;
Left=31;
Lpad=32;
Lower=33;
diff --git a/native-engine/auron-serde/src/from_proto.rs b/native-engine/auron-serde/src/from_proto.rs
index 31082e8b..0caaad6c 100644
--- a/native-engine/auron-serde/src/from_proto.rs
+++ b/native-engine/auron-serde/src/from_proto.rs
@@ -803,7 +803,7 @@ impl From<protobuf::ScalarFunction> for Arc<ScalarUDF> {
ScalarFunction::CharacterLength => f::unicode::character_length(),
ScalarFunction::Chr => f::string::chr(),
ScalarFunction::ConcatWithSeparator => f::string::concat_ws(),
- ScalarFunction::InitCap => f::unicode::initcap(),
+ // ScalarFunction::InitCap => f::unicode::initcap(),
ScalarFunction::Left => f::unicode::left(),
ScalarFunction::Lpad => f::unicode::lpad(),
ScalarFunction::Random => f::math::random(),
diff --git a/native-engine/datafusion-ext-functions/src/lib.rs b/native-engine/datafusion-ext-functions/src/lib.rs
index cad5198d..e1989de4 100644
--- a/native-engine/datafusion-ext-functions/src/lib.rs
+++ b/native-engine/datafusion-ext-functions/src/lib.rs
@@ -24,6 +24,7 @@ mod spark_crypto;
mod spark_dates;
pub mod spark_get_json_object;
mod spark_hash;
+mod spark_initcap;
mod spark_isnan;
mod spark_make_array;
mod spark_make_decimal;
@@ -64,6 +65,7 @@ pub fn create_auron_ext_function(name: &str) -> Result<ScalarFunctionImplementat
"Spark_StringConcatWs" => Arc::new(spark_strings::string_concat_ws),
"Spark_StringLower" => Arc::new(spark_strings::string_lower),
"Spark_StringUpper" => Arc::new(spark_strings::string_upper),
+ "Spark_InitCap" => Arc::new(spark_initcap::string_initcap),
"Spark_Year" => Arc::new(spark_dates::spark_year),
"Spark_Month" => Arc::new(spark_dates::spark_month),
"Spark_Day" => Arc::new(spark_dates::spark_day),
diff --git a/native-engine/datafusion-ext-functions/src/spark_initcap.rs b/native-engine/datafusion-ext-functions/src/spark_initcap.rs
new file mode 100644
index 00000000..7c5218fc
--- /dev/null
+++ b/native-engine/datafusion-ext-functions/src/spark_initcap.rs
@@ -0,0 +1,119 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::sync::Arc;
+
+use arrow::array::{ArrayRef, StringArray};
+use datafusion::{
+ common::{Result, ScalarValue, cast::as_string_array},
+ logical_expr::ColumnarValue,
+};
+use datafusion_ext_commons::df_execution_err;
+
+pub fn string_initcap(args: &[ColumnarValue]) -> Result<ColumnarValue> {
+ match &args[0] {
+ ColumnarValue::Array(array) => {
+ let input_array = as_string_array(array)?;
+ let output_array =
+ StringArray::from_iter(input_array.into_iter().map(|s| s.map(|s| initcap(s))));
+ Ok(ColumnarValue::Array(Arc::new(output_array) as ArrayRef))
+ }
+ ColumnarValue::Scalar(ScalarValue::Utf8(Some(str))) => {
+ Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(initcap(str)))))
+ }
+ _ => df_execution_err!("Unsupported args {args:?} for `string_initcap`"),
+ }
+}
+
+fn initcap(input: &str) -> String {
+ let mut out = String::with_capacity(input.len());
+ let mut prev_is_space = true; // i == 0 or chars[i-1] == ' '
+
+ if input.is_ascii() {
+ // ASCII
+ for ch in input.chars() {
+ if prev_is_space && ch.is_ascii_alphanumeric() {
+ out.push(ch.to_ascii_uppercase());
+ } else {
+ out.push(ch.to_ascii_lowercase());
+ };
+ prev_is_space = ch == ' ';
+ }
+ } else {
+ // Non-ASCII
+ for ch in input.chars() {
+ if prev_is_space && ch.is_alphabetic() {
+ out.extend(ch.to_uppercase());
+ } else {
+ out.extend(ch.to_lowercase());
+ }
+ prev_is_space = ch == ' ';
+ }
+ }
+ out
+}
+
+#[cfg(test)]
+mod test {
+ use std::sync::Arc;
+
+ use arrow::array::{ArrayRef, StringArray};
+ use datafusion::{
+ common::{Result, ScalarValue},
+ physical_plan::ColumnarValue,
+ };
+
+ use crate::spark_initcap::string_initcap;
+
+ #[test]
+ fn test_initcap_array() -> Result<()> {
+ let input_data = vec![
+ None,
+ Some(""),
+ Some("hI THOmAS"),
+ Some("James-Smith"),
+ Some("michael rose"),
+ Some("a1b2 c3D4"),
+ Some(" ---abc--- ABC --ABC-- a-b A B eB Ac c d"),
+ Some(" 世 界 世界 "),
+ ];
+ let input_columnar_value = ColumnarValue::Array(Arc::new(StringArray::from(input_data)));
+
+ let result = string_initcap(&vec![input_columnar_value])?.into_array(6)?;
+
+ let expected_data = vec![
+ None,
+ Some(""),
+ Some("Hi Thomas"),
+ Some("James-smith"),
+ Some("Michael Rose"),
+ Some("A1b2 C3d4"),
+ Some(" ---abc--- Abc --abc-- A-b A B Eb Ac C D"),
+ Some(" 世 界 世界 "),
+ ];
+ let expected: ArrayRef = Arc::new(StringArray::from(expected_data));
+ assert_eq!(&result, &expected);
+ Ok(())
+ }
+
+ #[test]
+ fn test_initcap_scalar() -> Result<()> {
+ let input_columnar_value = ColumnarValue::Scalar(ScalarValue::from("abC c3D4"));
+ let result = string_initcap(&vec![input_columnar_value])?.into_array(1)?;
+ let expected: ArrayRef = Arc::new(StringArray::from(vec![Some("Abc C3d4")]));
+ assert_eq!(&result, &expected);
+ Ok(())
+ }
+}
diff --git a/spark-extension-shims-spark/src/test/scala/org.apache.auron/AuronQuerySuite.scala b/spark-extension-shims-spark/src/test/scala/org.apache.auron/AuronQuerySuite.scala
index aeff61b2..5d0c420e 100644
--- a/spark-extension-shims-spark/src/test/scala/org.apache.auron/AuronQuerySuite.scala
+++ b/spark-extension-shims-spark/src/test/scala/org.apache.auron/AuronQuerySuite.scala
@@ -316,37 +316,53 @@ class AuronQuerySuite extends AuronQueryTest with BaseAuronSQLSuite with AuronSQ
}
test("initcap basic") {
- Seq(
- ("select initcap('spark sql')", Row("Spark Sql")),
- ("select initcap('SPARK')", Row("Spark")),
- ("select initcap('sPaRk')", Row("Spark")),
- ("select initcap('')", Row("")),
- ("select initcap(null)", Row(null))).foreach { case (q, expected) =>
- checkAnswer(sql(q), Seq(expected))
+ withTable("initcap_basic_tbl") {
+ sql(s"CREATE TABLE initcap_basic_tbl(id INT, txt STRING) USING parquet")
+ sql(s"""
+ |INSERT INTO initcap_basic_tbl VALUES
+ | (1, 'spark sql'),
+ | (2, 'SPARK'),
+ | (3, 'sPaRk'),
+ | (4, ''),
+ | (5, NULL)
+ """.stripMargin)
+ checkSparkAnswerAndOperator("select id, initcap(txt) from initcap_basic_tbl")
}
}
test("initcap: word boundaries and punctuation") {
- Seq(
- ("select initcap('hello world')", Row("Hello World")),
- ("select initcap('hello_world')", Row("Hello_world")),
- ("select initcap('über-alles')", Row("Über-alles")),
- ("select initcap('foo.bar/baz')", Row("Foo.bar/baz")),
- ("select initcap('v2Ray is COOL')", Row("V2ray Is Cool")),
- ("select initcap('rock''n''roll')", Row("Rocknroll")),
- ("select initcap('hi\\tthere')", Row("Hi\tthere")),
- ("select initcap('hi\\nthere')", Row("Hi\nthere"))).foreach { case (q, expected) =>
- checkAnswer(sql(q), Seq(expected))
+ withTable("initcap_bound_tbl") {
+ sql(s"CREATE TABLE initcap_bound_tbl(id INT, txt STRING) USING parquet")
+ sql(s"""
+ |INSERT INTO initcap_bound_tbl VALUES
+ | (1, 'hello world'),
+ | (2, 'hello_world'),
+ | (3, 'über-alles'),
+ | (4, 'foo.bar/baz'),
+ | (5, 'v2Ray is COOL'),
+ | (6, 'rock''n''roll'),
+ | (7, 'hi\tthere'),
+ | (8, 'hi\nthere')
+ """.stripMargin)
+ checkSparkAnswerAndOperator("select id, initcap(txt) from initcap_bound_tbl")
}
}
test("initcap: mixed cases and edge cases") {
- Seq(
- ("select initcap('a1b2 c3D4')", Row("A1b2 C3d4")),
- ("select initcap('---abc---')", Row("---abc---")),
- ("select initcap(' multiple spaces ')", Row(" Multiple Spaces "))).foreach {
- case (q, expected) =>
- checkAnswer(sql(q), Seq(expected))
+ withTable("initcap_mixed_tbl") {
+ sql(s"CREATE TABLE initcap_mixed_tbl(id INT, txt STRING) USING parquet")
+ sql(s"""
+ |INSERT INTO initcap_mixed_tbl VALUES
+ | (1, 'a1b2 c3D4'),
+ | (2, '---abc--- ABC --ABC-- 世界 世 界 '),
+ | (3, ' multiple spaces '),
+ | (4, 'AbCdE aBcDe'),
+ | (5, ' A B A b '),
+ | (6, 'aBćDe ab世De AbĆdE aB世De ÄBĆΔE'),
+ | (7, 'i\u0307onic FIDELİO'),
+ | (8, 'a🙃B🙃c 😄 😆')
+ """.stripMargin)
+ checkSparkAnswerAndOperator("select id, initcap(txt) from initcap_mixed_tbl")
}
}
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
index 4838804a..2d711012 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
@@ -885,8 +885,8 @@ object NativeConverters extends Logging {
buildExtScalarFunction("Spark_MD5", Seq(unpackBinaryTypeCast(_1)), StringType)
case Reverse(_1) =>
buildScalarFunction(pb.ScalarFunction.Reverse, Seq(unpackBinaryTypeCast(_1)), StringType)
- case InitCap(_1) =>
- buildScalarFunction(pb.ScalarFunction.InitCap, Seq(unpackBinaryTypeCast(_1)), StringType)
+ case e: InitCap =>
+ buildExtScalarFunction("Spark_InitCap", e.children, e.dataType)
case Sha2(_1, Literal(224, _)) =>
buildExtScalarFunction("Spark_Sha224", Seq(unpackBinaryTypeCast(_1)), StringType)
case Sha2(_1, Literal(0, _)) =>