This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 00922cfc4 fix: Fallback to Spark for lpad/rpad for unsupported arguments & fix negative length handling (#2630)
00922cfc4 is described below

commit 00922cfc464c934a583548727196144c695e7698
Author: Andy Grove <[email protected]>
AuthorDate: Fri Oct 24 04:18:06 2025 -0600

    fix: Fallback to Spark for lpad/rpad for unsupported arguments & fix negative length handling (#2630)
---
 .../char_varchar_utils/read_side_padding.rs        |  21 ++--
 .../scala/org/apache/comet/serde/strings.scala     |  35 +++---
 .../apache/comet/testing/FuzzDataGenerator.scala   |  15 ++-
 .../org/apache/comet/CometExpressionSuite.scala    |  62 -----------
 .../apache/comet/CometStringExpressionSuite.scala  | 121 +++++++++++++++++++++
 .../scala/org/apache/spark/sql/CometTestBase.scala |   7 ++
 6 files changed, 174 insertions(+), 87 deletions(-)

diff --git a/native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs b/native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs
index d969b6279..89485ddec 100644
--- a/native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs
+++ b/native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs
@@ -204,14 +204,21 @@ fn spark_read_side_padding_internal<T: OffsetSizeTrait>(
             );
 
             for (string, length) in string_array.iter().zip(int_pad_array) {
+                let length = length.unwrap();
                 match string {
-                    Some(string) => builder.append_value(add_padding_string(
-                        string.parse().unwrap(),
-                        length.unwrap() as usize,
-                        truncate,
-                        pad_string,
-                        is_left_pad,
-                    )?),
+                    Some(string) => {
+                        if length >= 0 {
+                            builder.append_value(add_padding_string(
+                                string.parse().unwrap(),
+                                length as usize,
+                                truncate,
+                                pad_string,
+                                is_left_pad,
+                            )?)
+                        } else {
+                            builder.append_value("");
+                        }
+                    }
                     _ => builder.append_null(),
                 }
             }
diff --git a/spark/src/main/scala/org/apache/comet/serde/strings.scala b/spark/src/main/scala/org/apache/comet/serde/strings.scala
index c6f5a8508..3d4bacfa2 100644
--- a/spark/src/main/scala/org/apache/comet/serde/strings.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/strings.scala
@@ -162,6 +162,16 @@ object CometRLike extends CometExpressionSerde[RLike] {
 
 object CometStringRPad extends CometExpressionSerde[StringRPad] {
 
+  override def getSupportLevel(expr: StringRPad): SupportLevel = {
+    if (expr.str.isInstanceOf[Literal]) {
+      return Unsupported(Some("Scalar values are not supported for the str 
argument"))
+    }
+    if (!expr.pad.isInstanceOf[Literal]) {
+      return Unsupported(Some("Only scalar values are supported for the pad 
argument"))
+    }
+    Compatible()
+  }
+
   override def convert(
       expr: StringRPad,
       inputs: Seq[Attribute],
@@ -177,21 +187,16 @@ object CometStringRPad extends CometExpressionSerde[StringRPad] {
 
 object CometStringLPad extends CometExpressionSerde[StringLPad] {
 
-  /**
-   * Convert a Spark expression into a protocol buffer representation that can 
be passed into
-   * native code.
-   *
-   * @param expr
-   *   The Spark expression.
-   * @param inputs
-   *   The input attributes.
-   * @param binding
-   *   Whether the attributes are bound (this is only relevant in aggregate 
expressions).
-   * @return
-   *   Protocol buffer representation, or None if the expression could not be 
converted. In this
-   *   case it is expected that the input expression will have been tagged 
with reasons why it
-   *   could not be converted.
-   */
+  override def getSupportLevel(expr: StringLPad): SupportLevel = {
+    if (expr.str.isInstanceOf[Literal]) {
+      return Unsupported(Some("Scalar values are not supported for the str 
argument"))
+    }
+    if (!expr.pad.isInstanceOf[Literal]) {
+      return Unsupported(Some("Only scalar values are supported for the pad 
argument"))
+    }
+    Compatible()
+  }
+
   override def convert(
       expr: StringLPad,
       inputs: Seq[Attribute],
diff --git a/spark/src/main/scala/org/apache/comet/testing/FuzzDataGenerator.scala b/spark/src/main/scala/org/apache/comet/testing/FuzzDataGenerator.scala
index 188da1d79..5363fda15 100644
--- a/spark/src/main/scala/org/apache/comet/testing/FuzzDataGenerator.scala
+++ b/spark/src/main/scala/org/apache/comet/testing/FuzzDataGenerator.scala
@@ -194,11 +194,13 @@ object FuzzDataGenerator {
             case 1 => r.nextInt().toByte.toString
             case 2 => r.nextLong().toString
             case 3 => r.nextDouble().toString
-            case 4 => RandomStringUtils.randomAlphabetic(8)
+            case 4 => 
RandomStringUtils.randomAlphabetic(options.maxStringLength)
             case 5 =>
               // use a constant value to trigger dictionary encoding
               "dict_encode_me!"
-            case _ => r.nextString(8)
+            case 6 if options.customStrings.nonEmpty =>
+              randomChoice(options.customStrings, r)
+            case _ => r.nextString(options.maxStringLength)
           }
         })
       case DataTypes.BinaryType =>
@@ -221,6 +223,11 @@ object FuzzDataGenerator {
       case _ => throw new IllegalStateException(s"Cannot generate data for 
$dataType yet")
     }
   }
+
+  private def randomChoice[T](list: Seq[T], r: Random): T = {
+    list(r.nextInt(list.length))
+  }
+
 }
 
 object SchemaGenOptions {
@@ -250,4 +257,6 @@ case class SchemaGenOptions(
 case class DataGenOptions(
     allowNull: Boolean = true,
     generateNegativeZero: Boolean = true,
-    baseDate: Long = FuzzDataGenerator.defaultBaseDate)
+    baseDate: Long = FuzzDataGenerator.defaultBaseDate,
+    customStrings: Seq[String] = Seq.empty,
+    maxStringLength: Int = 8)
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 1eca17dcc..ddbe7d14e 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -414,41 +414,6 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       }
     }
   }
-  test("Verify rpad expr support for second arg instead of just literal") {
-    val data = Seq(("IfIWasARoadIWouldBeBent", 10), ("తెలుగు", 2))
-    withParquetTable(data, "t1") {
-      val res = sql("select rpad(_1,_2) , rpad(_1,2) from t1 order by _1")
-      checkSparkAnswerAndOperator(res)
-    }
-  }
-
-  test("RPAD with character support other than default space") {
-    val data = Seq(("IfIWasARoadIWouldBeBent", 10), ("hi", 2))
-    withParquetTable(data, "t1") {
-      val res = sql(
-        """ select rpad(_1,_2,'?'), rpad(_1,_2,'??') , rpad(_1,2, '??'), 
hex(rpad(unhex('aabb'), 5)),
-          rpad(_1, 5, '??') from t1 order by _1 """.stripMargin)
-      checkSparkAnswerAndOperator(res)
-    }
-  }
-
-  test("test lpad expression support") {
-    val data = Seq(("IfIWasARoadIWouldBeBent", 10), ("తెలుగు", 2))
-    withParquetTable(data, "t1") {
-      val res = sql("select lpad(_1,_2) , lpad(_1,2) from t1 order by _1")
-      checkSparkAnswerAndOperator(res)
-    }
-  }
-
-  test("LPAD with character support other than default space") {
-    val data = Seq(("IfIWasARoadIWouldBeBent", 10), ("hi", 2))
-    withParquetTable(data, "t1") {
-      val res = sql(
-        """ select lpad(_1,_2,'?'), lpad(_1,_2,'??') , lpad(_1,2, '??'), 
hex(lpad(unhex('aabb'), 5)),
-          rpad(_1, 5, '??') from t1 order by _1 """.stripMargin)
-      checkSparkAnswerAndOperator(res)
-    }
-  }
 
   test("dictionary arithmetic") {
     // TODO: test ANSI mode
@@ -2292,33 +2257,6 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
-  test("rpad") {
-    val table = "rpad"
-    val gen = new DataGenerator(new Random(42))
-    withTable(table) {
-      // generate some data
-      val dataChars = "abc123"
-      sql(s"create table $table(id int, name1 char(8), name2 varchar(8)) using 
parquet")
-      val testData = gen.generateStrings(100, dataChars, 6) ++ Seq(
-        "é", // unicode 'e\\u{301}'
-        "é" // unicode '\\u{e9}'
-      )
-      testData.zipWithIndex.foreach { x =>
-        sql(s"insert into $table values(${x._2}, '${x._1}', '${x._1}')")
-      }
-      // test 2-arg version
-      checkSparkAnswerAndOperator(
-        s"SELECT id, rpad(name1, 10), rpad(name2, 10) FROM $table ORDER BY id")
-      // test 3-arg version
-      for (length <- Seq(2, 10)) {
-        checkSparkAnswerAndOperator(
-          s"SELECT id, name1, rpad(name1, $length, ' ') FROM $table ORDER BY 
id")
-        checkSparkAnswerAndOperator(
-          s"SELECT id, name2, rpad(name2, $length, ' ') FROM $table ORDER BY 
id")
-      }
-    }
-  }
-
   test("isnan") {
     Seq("true", "false").foreach { dictionary =>
       withSQLConf("parquet.enable.dictionary" -> dictionary) {
diff --git a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala
index 44d40cf1c..a63aba8da 100644
--- a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala
@@ -19,12 +19,133 @@
 
 package org.apache.comet
 
+import scala.util.Random
+
 import org.apache.parquet.hadoop.ParquetOutputFormat
 import org.apache.spark.sql.{CometTestBase, DataFrame}
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{DataTypes, StructField, StructType}
+
+import org.apache.comet.testing.{DataGenOptions, FuzzDataGenerator}
 
 class CometStringExpressionSuite extends CometTestBase {
 
+  test("lpad string") {
+    testStringPadding("lpad")
+  }
+
+  test("rpad string") {
+    testStringPadding("rpad")
+  }
+
+  test("lpad binary") {
+    testBinaryPadding("lpad")
+  }
+
+  test("rpad binary") {
+    testBinaryPadding("rpad")
+  }
+
+  private def testStringPadding(expr: String): Unit = {
+    val r = new Random(42)
+    val schema = StructType(
+      Seq(
+        StructField("str", DataTypes.StringType, nullable = true),
+        StructField("len", DataTypes.IntegerType, nullable = true),
+        StructField("pad", DataTypes.StringType, nullable = true)))
+    // scalastyle:off
+    val edgeCases = Seq(
+      "é", // unicode 'e\\u{301}'
+      "é", // unicode '\\u{e9}'
+      "తెలుగు")
+    // scalastyle:on
+    val df = FuzzDataGenerator.generateDataFrame(
+      r,
+      spark,
+      schema,
+      1000,
+      DataGenOptions(maxStringLength = 6, customStrings = edgeCases))
+    df.createOrReplaceTempView("t1")
+
+    // test all combinations of scalar and array arguments
+    for (str <- Seq("'hello'", "str")) {
+      for (len <- Seq("6", "-6", "0", "len % 10")) {
+        for (pad <- Seq(Some("'x'"), Some("'zzz'"), Some("pad"), None)) {
+          val sql = pad match {
+            case Some(p) =>
+              // 3 args
+              s"SELECT $str, $len, $expr($str, $len, $p) FROM t1 ORDER BY str, 
len, pad"
+            case _ =>
+              // 2 args (default pad of ' ')
+              s"SELECT $str, $len, $expr($str, $len) FROM t1 ORDER BY str, 
len, pad"
+          }
+          val isLiteralStr = str == "'hello'"
+          val isLiteralLen = !len.contains("len")
+          val isLiteralPad = !pad.contains("pad")
+          if (isLiteralStr && isLiteralLen && isLiteralPad) {
+            // all arguments are literal, so Spark constant folding will kick 
in
+            // and pad function will not be evaluated by Comet
+            checkSparkAnswer(sql)
+          } else if (isLiteralStr) {
+            checkSparkAnswerAndFallbackReason(
+              sql,
+              "Scalar values are not supported for the str argument")
+          } else if (!isLiteralPad) {
+            checkSparkAnswerAndFallbackReason(
+              sql,
+              "Only scalar values are supported for the pad argument")
+          } else {
+            checkSparkAnswerAndOperator(sql)
+          }
+        }
+      }
+    }
+  }
+
+  private def testBinaryPadding(expr: String): Unit = {
+    val r = new Random(42)
+    val schema = StructType(
+      Seq(
+        StructField("str", DataTypes.BinaryType, nullable = true),
+        StructField("len", DataTypes.IntegerType, nullable = true),
+        StructField("pad", DataTypes.BinaryType, nullable = true)))
+    val df = FuzzDataGenerator.generateDataFrame(r, spark, schema, 1000, 
DataGenOptions())
+    df.createOrReplaceTempView("t1")
+
+    // test all combinations of scalar and array arguments
+    for (str <- Seq("unhex('DDEEFF')", "str")) {
+      // Spark does not support negative length for lpad/rpad with binary 
input and Comet does
+      // not support abs yet, so use `10 + len % 10` to avoid negative length
+      for (len <- Seq("6", "0", "10 + len % 10")) {
+        for (pad <- Seq(Some("unhex('CAFE')"), Some("pad"), None)) {
+
+          val sql = pad match {
+            case Some(p) =>
+              // 3 args
+              s"SELECT $str, $len, $expr($str, $len, $p) FROM t1 ORDER BY str, 
len, pad"
+            case _ =>
+              // 2 args (default pad of ' ')
+              s"SELECT $str, $len, $expr($str, $len) FROM t1 ORDER BY str, 
len, pad"
+          }
+
+          val isLiteralStr = str != "str"
+          val isLiteralLen = !len.contains("len")
+          val isLiteralPad = !pad.contains("pad")
+
+          if (isLiteralStr && isLiteralLen && isLiteralPad) {
+            // all arguments are literal, so Spark constant folding will kick 
in
+            // and pad function will not be evaluated by Comet
+            checkSparkAnswer(sql)
+          } else {
+            // Comet will fall back to Spark because the plan contains a 
staticinvoke instruction
+            // which is not supported
+            checkSparkAnswerAndFallbackReason(sql, "staticinvoke is not 
supported")
+          }
+        }
+      }
+    }
+  }
+
   test("Various String scalar functions") {
     val table = "names"
     withTable(table) {
diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
index 3a4e52b4a..2308858f6 100644
--- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
@@ -166,6 +166,13 @@ abstract class CometTestBase
     (sparkPlan, dfComet.queryExecution.executedPlan)
   }
 
+  /** Check for the correct results as well as the expected fallback reason */
+  def checkSparkAnswerAndFallbackReason(sql: String, fallbackReason: String): 
Unit = {
+    val (_, cometPlan) = checkSparkAnswer(sql)
+    val explain = new 
ExtendedExplainInfo().generateVerboseExtendedInfo(cometPlan)
+    assert(explain.contains(fallbackReason))
+  }
+
   protected def checkSparkAnswerAndOperator(query: String, excludedClasses: 
Class[_]*): Unit = {
     checkSparkAnswerAndOperator(sql(query), excludedClasses: _*)
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to