This is an automated email from the ASF dual-hosted git repository.
comphead pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 5062735a7 feat: implement_comet_native_lpad_expr (#2102)
5062735a7 is described below
commit 5062735a71bc3ff5abe90586f868fae493761af2
Author: B Vadlamani <[email protected]>
AuthorDate: Thu Oct 2 13:52:05 2025 -0700
feat: implement_comet_native_lpad_expr (#2102)
* implement lpad expression
---
native/spark-expr/src/comet_scalar_funcs.rs | 8 +++--
.../src/static_invoke/char_varchar_utils/mod.rs | 2 +-
.../char_varchar_utils/read_side_padding.rs | 42 ++++++++++++++++++----
native/spark-expr/src/static_invoke/mod.rs | 2 +-
.../org/apache/comet/serde/QueryPlanSerde.scala | 1 +
.../scala/org/apache/comet/serde/strings.scala | 31 +++++++++++++++-
.../org/apache/comet/CometExpressionSuite.scala | 18 ++++++++++
7 files changed, 93 insertions(+), 11 deletions(-)
diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs
index 4b863927e..393f57662 100644
--- a/native/spark-expr/src/comet_scalar_funcs.rs
+++ b/native/spark-expr/src/comet_scalar_funcs.rs
@@ -20,8 +20,8 @@ use crate::math_funcs::checked_arithmetic::{checked_add, checked_div, checked_mu
use crate::math_funcs::modulo_expr::spark_modulo;
use crate::{
spark_array_repeat, spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor,
- spark_hex, spark_isnan, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad,
- spark_unhex, spark_unscaled_value, EvalMode, SparkBitwiseCount, SparkBitwiseNot,
+ spark_hex, spark_isnan, spark_lpad, spark_make_decimal, spark_read_side_padding, spark_round,
+ spark_rpad, spark_unhex, spark_unscaled_value, EvalMode, SparkBitwiseCount, SparkBitwiseNot,
SparkDateTrunc, SparkStringSpace,
};
use arrow::datatypes::DataType;
@@ -114,6 +114,10 @@ pub fn create_comet_physical_fun_with_eval_mode(
let func = Arc::new(spark_rpad);
make_comet_scalar_udf!("rpad", func, without data_type)
}
+ "lpad" => {
+ let func = Arc::new(spark_lpad);
+ make_comet_scalar_udf!("lpad", func, without data_type)
+ }
"round" => {
make_comet_scalar_udf!("round", spark_round, data_type)
}
diff --git a/native/spark-expr/src/static_invoke/char_varchar_utils/mod.rs b/native/spark-expr/src/static_invoke/char_varchar_utils/mod.rs
index 0a8d8f3c5..5bb94a7ad 100644
--- a/native/spark-expr/src/static_invoke/char_varchar_utils/mod.rs
+++ b/native/spark-expr/src/static_invoke/char_varchar_utils/mod.rs
@@ -17,4 +17,4 @@
mod read_side_padding;
-pub use read_side_padding::{spark_read_side_padding, spark_rpad};
+pub use read_side_padding::{spark_lpad, spark_read_side_padding, spark_rpad};
diff --git a/native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs b/native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs
index 166bb6ddf..d969b6279 100644
--- a/native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs
+++ b/native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs
@@ -28,17 +28,23 @@ use std::sync::Arc;
const SPACE: &str = " ";
/// Similar to DataFusion `rpad`, but not to truncate when the string is already longer than length
pub fn spark_read_side_padding(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
- spark_read_side_padding2(args, false)
+ spark_read_side_padding2(args, false, false)
}
/// Custom `rpad` because DataFusion's `rpad` has differences in unicode handling
pub fn spark_rpad(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
- spark_read_side_padding2(args, true)
+ spark_read_side_padding2(args, true, false)
+}
+
+/// Custom `lpad` because DataFusion's `lpad` has differences in unicode handling
+pub fn spark_lpad(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
+ spark_read_side_padding2(args, true, true)
}
fn spark_read_side_padding2(
args: &[ColumnarValue],
truncate: bool,
+ is_left_pad: bool,
) -> Result<ColumnarValue, DataFusionError> {
match args {
[ColumnarValue::Array(array), ColumnarValue::Scalar(ScalarValue::Int32(Some(length)))] => {
@@ -48,12 +54,14 @@ fn spark_read_side_padding2(
truncate,
ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
SPACE,
+ is_left_pad,
),
DataType::LargeUtf8 => spark_read_side_padding_internal::<i64>(
array,
truncate,
ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
SPACE,
+ is_left_pad,
),
// Dictionary support required for SPARK-48498
DataType::Dictionary(_, value_type) => {
@@ -64,6 +72,7 @@ fn spark_read_side_padding2(
truncate,
ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
SPACE,
+ is_left_pad,
)?
} else {
spark_read_side_padding_internal::<i64>(
@@ -71,6 +80,7 @@ fn spark_read_side_padding2(
truncate,
ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
SPACE,
+ is_left_pad,
)?
};
// col consists of an array, so arg of to_array() is not used. Can be anything
@@ -91,12 +101,14 @@ fn spark_read_side_padding2(
truncate,
ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
string,
+ is_left_pad,
),
DataType::LargeUtf8 => spark_read_side_padding_internal::<i64>(
array,
truncate,
ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
string,
+ is_left_pad,
),
// Dictionary support required for SPARK-48498
DataType::Dictionary(_, value_type) => {
@@ -107,6 +119,7 @@ fn spark_read_side_padding2(
truncate,
ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
SPACE,
+ is_left_pad,
)?
} else {
spark_read_side_padding_internal::<i64>(
@@ -114,6 +127,7 @@ fn spark_read_side_padding2(
truncate,
ColumnarValue::Scalar(ScalarValue::Int32(Some(*length))),
SPACE,
+ is_left_pad,
)?
};
// col consists of an array, so arg of to_array() is not used. Can be anything
@@ -122,7 +136,7 @@ fn spark_read_side_padding2(
Ok(ColumnarValue::Array(make_array(result.into())))
}
other => Err(DataFusionError::Internal(format!(
- "Unsupported data type {other:?} for function
rpad/read_side_padding",
+ "Unsupported data type {other:?} for function
rpad/lpad/read_side_padding",
))),
}
}
@@ -132,15 +146,17 @@ fn spark_read_side_padding2(
truncate,
ColumnarValue::Array(Arc::<dyn Array>::clone(array_int)),
SPACE,
+ is_left_pad,
),
DataType::LargeUtf8 => spark_read_side_padding_internal::<i64>(
array,
truncate,
ColumnarValue::Array(Arc::<dyn Array>::clone(array_int)),
SPACE,
+ is_left_pad,
),
other => Err(DataFusionError::Internal(format!(
- "Unsupported data type {other:?} for function
rpad/read_side_padding",
+ "Unsupported data type {other:?} for function
rpad/lpad/read_side_padding",
))),
},
[ColumnarValue::Array(array), ColumnarValue::Array(array_int), ColumnarValue::Scalar(ScalarValue::Utf8(Some(string)))] => {
@@ -150,12 +166,14 @@ fn spark_read_side_padding2(
truncate,
ColumnarValue::Array(Arc::<dyn Array>::clone(array_int)),
string,
+ is_left_pad,
),
DataType::LargeUtf8 => spark_read_side_padding_internal::<i64>(
array,
truncate,
ColumnarValue::Array(Arc::<dyn Array>::clone(array_int)),
string,
+ is_left_pad,
),
other => Err(DataFusionError::Internal(format!(
"Unsupported data type {other:?} for function
rpad/read_side_padding",
@@ -163,7 +181,7 @@ fn spark_read_side_padding2(
}
}
other => Err(DataFusionError::Internal(format!(
- "Unsupported arguments {other:?} for function
rpad/read_side_padding",
+ "Unsupported arguments {other:?} for function
rpad/lpad/read_side_padding",
))),
}
}
@@ -173,6 +191,7 @@ fn spark_read_side_padding_internal<T: OffsetSizeTrait>(
truncate: bool,
pad_type: ColumnarValue,
pad_string: &str,
+ is_left_pad: bool,
) -> Result<ColumnarValue, DataFusionError> {
let string_array = as_generic_string_array::<T>(array)?;
match pad_type {
@@ -191,6 +210,7 @@ fn spark_read_side_padding_internal<T: OffsetSizeTrait>(
length.unwrap() as usize,
truncate,
pad_string,
+ is_left_pad,
)?),
_ => builder.append_null(),
}
@@ -212,6 +232,7 @@ fn spark_read_side_padding_internal<T: OffsetSizeTrait>(
length,
truncate,
pad_string,
+ is_left_pad,
)?),
_ => builder.append_null(),
}
@@ -226,6 +247,7 @@ fn add_padding_string(
length: usize,
truncate: bool,
pad_string: &str,
+ is_left_pad: bool,
) -> Result<String, DataFusionError> {
// It looks Spark's UTF8String is closer to chars rather than graphemes
// https://stackoverflow.com/a/46290728
@@ -250,6 +272,14 @@ fn add_padding_string(
} else {
let pad_needed = length - char_len;
let pad: String = pad_string.chars().cycle().take(pad_needed).collect();
- Ok(string + &pad)
+ let mut result = String::with_capacity(string.len() + pad.len());
+ if is_left_pad {
+ result.push_str(&pad);
+ result.push_str(&string);
+ } else {
+ result.push_str(&string);
+ result.push_str(&pad);
+ }
+ Ok(result)
}
}
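
For reference, here is a minimal sketch of how the new entry point behaves at the ColumnarValue level: Spark's lpad pads on the left up to the target length and truncates when the input is already longer. The crate name and the exact DataFusion import paths below are assumptions and may vary by version; only the exported spark_lpad function comes from this commit.

    use std::sync::Arc;
    use arrow::array::{Array, ArrayRef, StringArray};
    use datafusion::error::Result;
    use datafusion::logical_expr::ColumnarValue;
    use datafusion::scalar::ScalarValue;
    use datafusion_comet_spark_expr::spark_lpad; // assumed crate-root re-export

    fn main() -> Result<()> {
        let input: ArrayRef = Arc::new(StringArray::from(vec![Some("hi"), Some("hello"), None]));
        let args = [
            ColumnarValue::Array(input),
            ColumnarValue::Scalar(ScalarValue::Int32(Some(4))),
        ];
        // Two-argument form pads with a single space, like Spark's lpad(str, len).
        if let ColumnarValue::Array(out) = spark_lpad(&args)? {
            let out = out.as_any().downcast_ref::<StringArray>().unwrap();
            assert_eq!(out.value(0), "  hi"); // shorter input: left-padded with spaces
            assert_eq!(out.value(1), "hell"); // longer input: truncated, like Spark
            assert!(out.is_null(2)); // null in, null out
        }
        Ok(())
    }
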
diff --git a/native/spark-expr/src/static_invoke/mod.rs b/native/spark-expr/src/static_invoke/mod.rs
index 39735f156..6a2176b5f 100644
--- a/native/spark-expr/src/static_invoke/mod.rs
+++ b/native/spark-expr/src/static_invoke/mod.rs
@@ -17,4 +17,4 @@
mod char_varchar_utils;
-pub use char_varchar_utils::{spark_read_side_padding, spark_rpad};
+pub use char_varchar_utils::{spark_lpad, spark_read_side_padding, spark_rpad};
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 892d8bca6..bb05015c2 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -175,6 +175,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
classOf[StringRepeat] -> CometStringRepeat,
classOf[StringReplace] -> CometScalarFunction("replace"),
classOf[StringRPad] -> CometStringRPad,
+ classOf[StringLPad] -> CometStringLPad,
classOf[StringSpace] -> CometScalarFunction("string_space"),
classOf[StringTranslate] -> CometScalarFunction("translate"),
classOf[StringTrim] -> CometScalarFunction("trim"),
diff --git a/spark/src/main/scala/org/apache/comet/serde/strings.scala b/spark/src/main/scala/org/apache/comet/serde/strings.scala
index 36df9ed1c..9c85d8d6c 100644
--- a/spark/src/main/scala/org/apache/comet/serde/strings.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/strings.scala
@@ -21,7 +21,7 @@ package org.apache.comet.serde
import java.util.Locale
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Expression, InitCap, Like, Literal, Lower, RLike, StringRepeat, StringRPad, Substring, Upper}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Expression, InitCap, Like, Literal, Lower, RLike, StringLPad, StringRepeat, StringRPad, Substring, Upper}
import org.apache.spark.sql.types.{DataTypes, LongType, StringType}
import org.apache.comet.CometConf
@@ -168,6 +168,35 @@ object CometStringRPad extends CometExpressionSerde[StringRPad] {
}
}
+object CometStringLPad extends CometExpressionSerde[StringLPad] {
+
+ /**
+ * Convert a Spark expression into a protocol buffer representation that can be passed into
+ * native code.
+ *
+ * @param expr
+ * The Spark expression.
+ * @param inputs
+ * The input attributes.
+ * @param binding
+ * Whether the attributes are bound (this is only relevant in aggregate expressions).
+ * @return
+ * Protocol buffer representation, or None if the expression could not be converted. In this
+ * case it is expected that the input expression will have been tagged with reasons why it
+ * could not be converted.
+ */
+ override def convert(
+ expr: StringLPad,
+ inputs: Seq[Attribute],
+ binding: Boolean): Option[Expr] = {
+ scalarFunctionExprToProto(
+ "lpad",
+ exprToProtoInternal(expr.str, inputs, binding),
+ exprToProtoInternal(expr.len, inputs, binding),
+ exprToProtoInternal(expr.pad, inputs, binding))
+ }
+}
+
trait CommonStringExprs {
def stringDecode(
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 07663ea91..f391d52f7 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -431,6 +431,24 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}
+ test("test lpad expression support") {
+ val data = Seq(("IfIWasARoadIWouldBeBent", 10), ("తెలుగు", 2))
+ withParquetTable(data, "t1") {
+ val res = sql("select lpad(_1,_2) , lpad(_1,2) from t1 order by _1")
+ checkSparkAnswerAndOperator(res)
+ }
+ }
+
+ test("LPAD with character support other than default space") {
+ val data = Seq(("IfIWasARoadIWouldBeBent", 10), ("hi", 2))
+ withParquetTable(data, "t1") {
+ val res = sql(
+ """ select lpad(_1,_2,'?'), lpad(_1,_2,'??') , lpad(_1,2, '??'),
hex(lpad(unhex('aabb'), 5)),
+ rpad(_1, 5, '??') from t1 order by _1 """.stripMargin)
+ checkSparkAnswerAndOperator(res)
+ }
+ }
+
test("dictionary arithmetic") {
// TODO: test ANSI mode
withSQLConf(SQLConf.ANSI_ENABLED.key -> "false", "parquet.enable.dictionary" -> "true") {