This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 00922cfc4 fix: Fallback to Spark for lpad/rpad for unsupported arguments & fix negative length handling (#2630)
00922cfc4 is described below
commit 00922cfc464c934a583548727196144c695e7698
Author: Andy Grove <[email protected]>
AuthorDate: Fri Oct 24 04:18:06 2025 -0600
fix: Fallback to Spark for lpad/rpad for unsupported arguments & fix negative length handling (#2630)
---
.../char_varchar_utils/read_side_padding.rs | 21 ++--
.../scala/org/apache/comet/serde/strings.scala | 35 +++---
.../apache/comet/testing/FuzzDataGenerator.scala | 15 ++-
.../org/apache/comet/CometExpressionSuite.scala | 62 -----------
.../apache/comet/CometStringExpressionSuite.scala | 121 +++++++++++++++++++++
.../scala/org/apache/spark/sql/CometTestBase.scala | 7 ++
6 files changed, 174 insertions(+), 87 deletions(-)
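
For context on the first half of the fix: the serde changes below make Comet
decline lpad/rpad calls whose arguments it cannot handle natively, so those
expressions now fall back to Spark instead of being converted. A minimal
sketch of a query that now triggers the fallback (the table and column names
here are hypothetical):

    // `pad` is a column rather than a literal, so Comet reports
    // "Only scalar values are supported for the pad argument"
    // and Spark evaluates the expression itself.
    val df = spark.sql("SELECT rpad(str, 5, pad) FROM t1")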
diff --git a/native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs b/native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs
index d969b6279..89485ddec 100644
--- a/native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs
+++ b/native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs
@@ -204,14 +204,21 @@ fn spark_read_side_padding_internal<T: OffsetSizeTrait>(
);
for (string, length) in string_array.iter().zip(int_pad_array) {
+ let length = length.unwrap();
match string {
- Some(string) => builder.append_value(add_padding_string(
- string.parse().unwrap(),
- length.unwrap() as usize,
- truncate,
- pad_string,
- is_left_pad,
- )?),
+ Some(string) => {
+ if length >= 0 {
+ builder.append_value(add_padding_string(
+ string.parse().unwrap(),
+ length as usize,
+ truncate,
+ pad_string,
+ is_left_pad,
+ )?)
+ } else {
+ builder.append_value("");
+ }
+ }
_ => builder.append_null(),
}
}
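
The new branch matches Spark's semantics, where a negative target length
yields an empty string rather than an error (previously the negative value
was cast straight to usize). A minimal illustration in Spark SQL, assuming a
local SparkSession named `spark`:

    // Both columns evaluate to the empty string '', not NULL.
    spark.sql("SELECT rpad('abc', -1, 'x'), lpad('abc', -1, 'x')").show()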
diff --git a/spark/src/main/scala/org/apache/comet/serde/strings.scala b/spark/src/main/scala/org/apache/comet/serde/strings.scala
index c6f5a8508..3d4bacfa2 100644
--- a/spark/src/main/scala/org/apache/comet/serde/strings.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/strings.scala
@@ -162,6 +162,16 @@ object CometRLike extends CometExpressionSerde[RLike] {
object CometStringRPad extends CometExpressionSerde[StringRPad] {
+ override def getSupportLevel(expr: StringRPad): SupportLevel = {
+ if (expr.str.isInstanceOf[Literal]) {
+ return Unsupported(Some("Scalar values are not supported for the str argument"))
+ }
+ if (!expr.pad.isInstanceOf[Literal]) {
+ return Unsupported(Some("Only scalar values are supported for the pad argument"))
+ }
+ Compatible()
+ }
+
override def convert(
expr: StringRPad,
inputs: Seq[Attribute],
@@ -177,21 +187,16 @@ object CometStringRPad extends CometExpressionSerde[StringRPad] {
object CometStringLPad extends CometExpressionSerde[StringLPad] {
- /**
- * Convert a Spark expression into a protocol buffer representation that can be passed into
- * native code.
- *
- * @param expr
- * The Spark expression.
- * @param inputs
- * The input attributes.
- * @param binding
- * Whether the attributes are bound (this is only relevant in aggregate expressions).
- * @return
- * Protocol buffer representation, or None if the expression could not be converted. In this
- * case it is expected that the input expression will have been tagged with reasons why it
- * could not be converted.
- */
+ override def getSupportLevel(expr: StringLPad): SupportLevel = {
+ if (expr.str.isInstanceOf[Literal]) {
+ return Unsupported(Some("Scalar values are not supported for the str argument"))
+ }
+ if (!expr.pad.isInstanceOf[Literal]) {
+ return Unsupported(Some("Only scalar values are supported for the pad argument"))
+ }
+ Compatible()
+ }
+
override def convert(
expr: StringLPad,
inputs: Seq[Attribute],
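
The two getSupportLevel overrides above follow the existing
CometExpressionSerde pattern: inspect the expression's arguments and return
Unsupported (with a reason) or Compatible before convert is attempted. A
hypothetical serde for another string expression would take the same shape
(the expression and field names here are illustrative, not part of this
commit):

    object CometStringRepeat extends CometExpressionSerde[StringRepeat] {
      override def getSupportLevel(expr: StringRepeat): SupportLevel = {
        // Fall back to Spark unless the repeat count is a literal.
        if (!expr.times.isInstanceOf[Literal]) {
          return Unsupported(Some("Only scalar values are supported for the times argument"))
        }
        Compatible()
      }
      // convert(...) would follow, as in the pad serdes above.
    }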
diff --git a/spark/src/main/scala/org/apache/comet/testing/FuzzDataGenerator.scala b/spark/src/main/scala/org/apache/comet/testing/FuzzDataGenerator.scala
index 188da1d79..5363fda15 100644
--- a/spark/src/main/scala/org/apache/comet/testing/FuzzDataGenerator.scala
+++ b/spark/src/main/scala/org/apache/comet/testing/FuzzDataGenerator.scala
@@ -194,11 +194,13 @@ object FuzzDataGenerator {
case 1 => r.nextInt().toByte.toString
case 2 => r.nextLong().toString
case 3 => r.nextDouble().toString
- case 4 => RandomStringUtils.randomAlphabetic(8)
+ case 4 => RandomStringUtils.randomAlphabetic(options.maxStringLength)
case 5 =>
// use a constant value to trigger dictionary encoding
"dict_encode_me!"
- case _ => r.nextString(8)
+ case 6 if options.customStrings.nonEmpty =>
+ randomChoice(options.customStrings, r)
+ case _ => r.nextString(options.maxStringLength)
}
})
case DataTypes.BinaryType =>
@@ -221,6 +223,11 @@ object FuzzDataGenerator {
case _ => throw new IllegalStateException(s"Cannot generate data for $dataType yet")
}
}
+
+ private def randomChoice[T](list: Seq[T], r: Random): T = {
+ list(r.nextInt(list.length))
+ }
+
}
object SchemaGenOptions {
@@ -250,4 +257,6 @@ case class SchemaGenOptions(
case class DataGenOptions(
allowNull: Boolean = true,
generateNegativeZero: Boolean = true,
- baseDate: Long = FuzzDataGenerator.defaultBaseDate)
+ baseDate: Long = FuzzDataGenerator.defaultBaseDate,
+ customStrings: Seq[String] = Seq.empty,
+ maxStringLength: Int = 8)
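
The two new DataGenOptions fields let a test steer the fuzzer: customStrings
mixes caller-supplied edge cases into the random data, and maxStringLength
bounds generated string sizes. A minimal usage sketch, assuming a
SparkSession `spark` is in scope:

    import scala.util.Random
    import org.apache.spark.sql.types.{DataTypes, StructField, StructType}
    import org.apache.comet.testing.{DataGenOptions, FuzzDataGenerator}

    val schema = StructType(Seq(StructField("str", DataTypes.StringType, nullable = true)))
    val options = DataGenOptions(maxStringLength = 6, customStrings = Seq("é", "తెలుగు"))
    val df = FuzzDataGenerator.generateDataFrame(new Random(42), spark, schema, 100, options)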
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 1eca17dcc..ddbe7d14e 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -414,41 +414,6 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}
}
- test("Verify rpad expr support for second arg instead of just literal") {
- val data = Seq(("IfIWasARoadIWouldBeBent", 10), ("తెలుగు", 2))
- withParquetTable(data, "t1") {
- val res = sql("select rpad(_1,_2) , rpad(_1,2) from t1 order by _1")
- checkSparkAnswerAndOperator(res)
- }
- }
-
- test("RPAD with character support other than default space") {
- val data = Seq(("IfIWasARoadIWouldBeBent", 10), ("hi", 2))
- withParquetTable(data, "t1") {
- val res = sql(
- """ select rpad(_1,_2,'?'), rpad(_1,_2,'??') , rpad(_1,2, '??'),
hex(rpad(unhex('aabb'), 5)),
- rpad(_1, 5, '??') from t1 order by _1 """.stripMargin)
- checkSparkAnswerAndOperator(res)
- }
- }
-
- test("test lpad expression support") {
- val data = Seq(("IfIWasARoadIWouldBeBent", 10), ("తెలుగు", 2))
- withParquetTable(data, "t1") {
- val res = sql("select lpad(_1,_2) , lpad(_1,2) from t1 order by _1")
- checkSparkAnswerAndOperator(res)
- }
- }
-
- test("LPAD with character support other than default space") {
- val data = Seq(("IfIWasARoadIWouldBeBent", 10), ("hi", 2))
- withParquetTable(data, "t1") {
- val res = sql(
- """ select lpad(_1,_2,'?'), lpad(_1,_2,'??') , lpad(_1,2, '??'),
hex(lpad(unhex('aabb'), 5)),
- rpad(_1, 5, '??') from t1 order by _1 """.stripMargin)
- checkSparkAnswerAndOperator(res)
- }
- }
test("dictionary arithmetic") {
// TODO: test ANSI mode
@@ -2292,33 +2257,6 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}
- test("rpad") {
- val table = "rpad"
- val gen = new DataGenerator(new Random(42))
- withTable(table) {
- // generate some data
- val dataChars = "abc123"
- sql(s"create table $table(id int, name1 char(8), name2 varchar(8)) using
parquet")
- val testData = gen.generateStrings(100, dataChars, 6) ++ Seq(
- "é", // unicode 'e\\u{301}'
- "é" // unicode '\\u{e9}'
- )
- testData.zipWithIndex.foreach { x =>
- sql(s"insert into $table values(${x._2}, '${x._1}', '${x._1}')")
- }
- // test 2-arg version
- checkSparkAnswerAndOperator(
- s"SELECT id, rpad(name1, 10), rpad(name2, 10) FROM $table ORDER BY id")
- // test 3-arg version
- for (length <- Seq(2, 10)) {
- checkSparkAnswerAndOperator(
- s"SELECT id, name1, rpad(name1, $length, ' ') FROM $table ORDER BY
id")
- checkSparkAnswerAndOperator(
- s"SELECT id, name2, rpad(name2, $length, ' ') FROM $table ORDER BY
id")
- }
- }
- }
-
test("isnan") {
Seq("true", "false").foreach { dictionary =>
withSQLConf("parquet.enable.dictionary" -> dictionary) {
diff --git a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala
index 44d40cf1c..a63aba8da 100644
--- a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala
@@ -19,12 +19,133 @@
package org.apache.comet
+import scala.util.Random
+
import org.apache.parquet.hadoop.ParquetOutputFormat
import org.apache.spark.sql.{CometTestBase, DataFrame}
import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{DataTypes, StructField, StructType}
+
+import org.apache.comet.testing.{DataGenOptions, FuzzDataGenerator}
class CometStringExpressionSuite extends CometTestBase {
+ test("lpad string") {
+ testStringPadding("lpad")
+ }
+
+ test("rpad string") {
+ testStringPadding("rpad")
+ }
+
+ test("lpad binary") {
+ testBinaryPadding("lpad")
+ }
+
+ test("rpad binary") {
+ testBinaryPadding("rpad")
+ }
+
+ private def testStringPadding(expr: String): Unit = {
+ val r = new Random(42)
+ val schema = StructType(
+ Seq(
+ StructField("str", DataTypes.StringType, nullable = true),
+ StructField("len", DataTypes.IntegerType, nullable = true),
+ StructField("pad", DataTypes.StringType, nullable = true)))
+ // scalastyle:off
+ val edgeCases = Seq(
+ "é", // unicode 'e\\u{301}'
+ "é", // unicode '\\u{e9}'
+ "తెలుగు")
+ // scalastyle:on
+ val df = FuzzDataGenerator.generateDataFrame(
+ r,
+ spark,
+ schema,
+ 1000,
+ DataGenOptions(maxStringLength = 6, customStrings = edgeCases))
+ df.createOrReplaceTempView("t1")
+
+ // test all combinations of scalar and array arguments
+ for (str <- Seq("'hello'", "str")) {
+ for (len <- Seq("6", "-6", "0", "len % 10")) {
+ for (pad <- Seq(Some("'x'"), Some("'zzz'"), Some("pad"), None)) {
+ val sql = pad match {
+ case Some(p) =>
+ // 3 args
+ s"SELECT $str, $len, $expr($str, $len, $p) FROM t1 ORDER BY str,
len, pad"
+ case _ =>
+ // 2 args (default pad of ' ')
+ s"SELECT $str, $len, $expr($str, $len) FROM t1 ORDER BY str,
len, pad"
+ }
+ val isLiteralStr = str == "'hello'"
+ val isLiteralLen = !len.contains("len")
+ val isLiteralPad = !pad.contains("pad")
+ if (isLiteralStr && isLiteralLen && isLiteralPad) {
+ // all arguments are literal, so Spark constant folding will kick in
+ // and pad function will not be evaluated by Comet
+ checkSparkAnswer(sql)
+ } else if (isLiteralStr) {
+ checkSparkAnswerAndFallbackReason(
+ sql,
+ "Scalar values are not supported for the str argument")
+ } else if (!isLiteralPad) {
+ checkSparkAnswerAndFallbackReason(
+ sql,
+ "Only scalar values are supported for the pad argument")
+ } else {
+ checkSparkAnswerAndOperator(sql)
+ }
+ }
+ }
+ }
+ }
+
+ private def testBinaryPadding(expr: String): Unit = {
+ val r = new Random(42)
+ val schema = StructType(
+ Seq(
+ StructField("str", DataTypes.BinaryType, nullable = true),
+ StructField("len", DataTypes.IntegerType, nullable = true),
+ StructField("pad", DataTypes.BinaryType, nullable = true)))
+ val df = FuzzDataGenerator.generateDataFrame(r, spark, schema, 1000, DataGenOptions())
+ df.createOrReplaceTempView("t1")
+
+ // test all combinations of scalar and array arguments
+ for (str <- Seq("unhex('DDEEFF')", "str")) {
+ // Spark does not support negative length for lpad/rpad with binary input and Comet does
+ // not support abs yet, so use `10 + len % 10` to avoid negative length
+ for (len <- Seq("6", "0", "10 + len % 10")) {
+ for (pad <- Seq(Some("unhex('CAFE')"), Some("pad"), None)) {
+
+ val sql = pad match {
+ case Some(p) =>
+ // 3 args
+ s"SELECT $str, $len, $expr($str, $len, $p) FROM t1 ORDER BY str,
len, pad"
+ case _ =>
+ // 2 args (default pad of ' ')
+ s"SELECT $str, $len, $expr($str, $len) FROM t1 ORDER BY str,
len, pad"
+ }
+
+ val isLiteralStr = str != "str"
+ val isLiteralLen = !len.contains("len")
+ val isLiteralPad = !pad.contains("pad")
+
+ if (isLiteralStr && isLiteralLen && isLiteralPad) {
+ // all arguments are literal, so Spark constant folding will kick in
+ // and pad function will not be evaluated by Comet
+ checkSparkAnswer(sql)
+ } else {
+ // Comet will fall back to Spark because the plan contains a staticinvoke instruction
+ // which is not supported
+ checkSparkAnswerAndFallbackReason(sql, "staticinvoke is not supported")
+ }
+ }
+ }
+ }
+ }
+
test("Various String scalar functions") {
val table = "names"
withTable(table) {
diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
index 3a4e52b4a..2308858f6 100644
--- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
@@ -166,6 +166,13 @@ abstract class CometTestBase
(sparkPlan, dfComet.queryExecution.executedPlan)
}
+ /** Check for the correct results as well as the expected fallback reason */
+ def checkSparkAnswerAndFallbackReason(sql: String, fallbackReason: String): Unit = {
+ val (_, cometPlan) = checkSparkAnswer(sql)
+ val explain = new ExtendedExplainInfo().generateVerboseExtendedInfo(cometPlan)
+ assert(explain.contains(fallbackReason))
+ }
+
protected def checkSparkAnswerAndOperator(query: String, excludedClasses: Class[_]*): Unit = {
checkSparkAnswerAndOperator(sql(query), excludedClasses: _*)
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]