This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 78ea0bf0b chore: Refactor string expression serde, part 1 (#2068)
78ea0bf0b is described below

commit 78ea0bf0bca01bd915171feec8c605e6c0f9c47d
Author: Andy Grove <agr...@apache.org>
AuthorDate: Thu Aug 7 11:37:33 2025 -0600

    chore: Refactor string expression serde, part 1 (#2068)
---
 .../org/apache/comet/serde/QueryPlanSerde.scala    |  62 ++-----
 .../scala/org/apache/comet/serde/strings.scala     | 178 ++++++++++++++++++++-
 2 files changed, 185 insertions(+), 55 deletions(-)

diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 4a62cf8b6..623126276 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -114,7 +114,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
     classOf[StringRepeat] -> CometStringRepeat,
     classOf[StringReplace] -> CometStringReplace,
     classOf[StringTranslate] -> CometStringTranslate,
-    classOf[StringTrim] -> CometTrim,
+    classOf[StringTrim] -> CometStringTrim,
     classOf[StringTrimLeft] -> CometStringTrimLeft,
     classOf[StringTrimRight] -> CometStringTrimRight,
     classOf[StringTrimBoth] -> CometStringTrimBoth,
@@ -141,7 +141,16 @@ object QueryPlanSerde extends Logging with CometExprShim {
     classOf[Randn] -> CometRandn,
     classOf[SparkPartitionID] -> CometSparkPartitionId,
     classOf[MonotonicallyIncreasingID] -> CometMonotonicallyIncreasingId,
-    classOf[StringSpace] -> UnaryScalarFuncSerde("string_space"))
+    classOf[StringSpace] -> CometStringSpace,
+    classOf[StartsWith] -> CometStartsWith,
+    classOf[EndsWith] -> CometEndsWith,
+    classOf[Contains] -> CometContains,
+    classOf[Substring] -> CometSubstring,
+    classOf[Like] -> CometLike,
+    classOf[RLike] -> CometRLike,
+    classOf[OctetLength] -> CometOctetLength,
+    classOf[Reverse] -> CometReverse,
+    classOf[StringRPad] -> CometStringRPad)
 
   /**
    * Mapping of Spark aggregate expression class to Comet expression handler.
@@ -747,25 +756,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
         withInfo(expr, s"Unsupported datatype $dataType")
         None
 
-      case Substring(str, Literal(pos, _), Literal(len, _)) =>
-        val strExpr = exprToProtoInternal(str, inputs, binding)
-
-        if (strExpr.isDefined) {
-          val builder = ExprOuterClass.Substring.newBuilder()
-          builder.setChild(strExpr.get)
-          builder.setStart(pos.asInstanceOf[Int])
-          builder.setLen(len.asInstanceOf[Int])
-
-          Some(
-            ExprOuterClass.Expr
-              .newBuilder()
-              .setSubstring(builder)
-              .build())
-        } else {
-          withInfo(expr, str)
-          None
-        }
-
       // ToPrettyString is new in Spark 3.5
       case _
           if expr.getClass.getSimpleName == "ToPrettyString" && expr
@@ -1388,18 +1378,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
           None
         }
 
-      case OctetLength(child) =>
-        val castExpr = Cast(child, StringType)
-        val childExpr = exprToProtoInternal(castExpr, inputs, binding)
-        val optExpr = scalarFunctionExprToProto("octet_length", childExpr)
-        optExprWithInfo(optExpr, expr, castExpr)
-
-      case Reverse(child) =>
-        val castExpr = Cast(child, StringType)
-        val childExpr = exprToProtoInternal(castExpr, inputs, binding)
-        val optExpr = scalarFunctionExprToProto("reverse", childExpr)
-        optExprWithInfo(optExpr, expr, castExpr)
-
       case BitwiseAnd(left, right) =>
         createBinaryExpr(
           expr,
@@ -1464,24 +1442,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
           None
         }
 
-      // read-side padding in Spark 3.5.2+ is represented by rpad function
-      case StringRPad(srcStr, size, chars) =>
-        chars match {
-          case Literal(str, DataTypes.StringType) if str.toString == " " =>
-            val arg0 = exprToProtoInternal(srcStr, inputs, binding)
-            val arg1 = exprToProtoInternal(size, inputs, binding)
-            if (arg0.isDefined && arg1.isDefined) {
-              scalarFunctionExprToProto("rpad", arg0, arg1)
-            } else {
-              withInfo(expr, "rpad unsupported arguments", srcStr, size)
-              None
-            }
-
-          case _ =>
-            withInfo(expr, "rpad only supports padding with spaces")
-            None
-        }
-
       case KnownFloatingPointNormalized(NormalizeNaNAndZero(expr)) =>
         val dataType = serializeDataType(expr.dataType)
         if (dataType.isEmpty) {
diff --git a/spark/src/main/scala/org/apache/comet/serde/strings.scala b/spark/src/main/scala/org/apache/comet/serde/strings.scala
index 991c55b42..de54a074c 100644
--- a/spark/src/main/scala/org/apache/comet/serde/strings.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/strings.scala
@@ -21,13 +21,13 @@ package org.apache.comet.serde
 
 import scala.util.Try
 
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Expression}
-import org.apache.spark.sql.types.{LongType, StringType}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Contains, EndsWith, Expression, Like, Literal, OctetLength, Reverse, RLike, StartsWith, StringRPad, Substring}
+import org.apache.spark.sql.types.{ArrayType, DataTypes, LongType, StringType}

 import org.apache.comet.CometConf
 import org.apache.comet.CometSparkSessionExtensions.withInfo
 import org.apache.comet.serde.ExprOuterClass.Expr
-import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto}
+import org.apache.comet.serde.QueryPlanSerde.{createBinaryExpr, exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto}
 
 object CometAscii extends CometExpressionSerde {
 
@@ -125,7 +125,7 @@ object CometStringTranslate extends CometExpressionSerde {
   }
 }
 
-object CometTrim extends CometExpressionSerde {
+object CometStringTrim extends CometExpressionSerde {
 
   override def convert(
       expr: Expression,
@@ -308,3 +308,173 @@ object CometConcatWs extends CometExpressionSerde {
     optExprWithInfo(optExpr, expr, childExprs: _*)
   }
 }
+
+/**
+ * Serde for Spark `StringSpace`; delegates to `UnaryScalarFuncSerde` with the native
+ * function name "string_space".
+ */
+object CometStringSpace extends UnaryScalarFuncSerde("string_space")
+
+/**
+ * Serde for Spark `StartsWith`, serialized as the native scalar function `starts_with`.
+ *
+ * The result is passed through `optExprWithInfo` (as `CometOctetLength` and `CometReverse`
+ * do) so that, when a child cannot be converted, the reason can be attached to `expr`
+ * instead of the fallback to Spark going unexplained.
+ */
+object CometStartsWith extends CometExpressionSerde {
+
+  override def convert(
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[Expr] = {
+    val startsWith = expr.asInstanceOf[StartsWith]
+    val optExpr = scalarFunctionExprToProto(
+      "starts_with",
+      exprToProtoInternal(startsWith.left, inputs, binding),
+      exprToProtoInternal(startsWith.right, inputs, binding))
+    optExprWithInfo(optExpr, expr, startsWith.left, startsWith.right)
+  }
+}
+
+/**
+ * Serde for Spark `EndsWith`, serialized as the native scalar function `ends_with`.
+ *
+ * The result is passed through `optExprWithInfo` (as `CometOctetLength` and `CometReverse`
+ * do) so that, when a child cannot be converted, the reason can be attached to `expr`
+ * instead of the fallback to Spark going unexplained.
+ */
+object CometEndsWith extends CometExpressionSerde {
+
+  override def convert(
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[Expr] = {
+    val endsWith = expr.asInstanceOf[EndsWith]
+    val optExpr = scalarFunctionExprToProto(
+      "ends_with",
+      exprToProtoInternal(endsWith.left, inputs, binding),
+      exprToProtoInternal(endsWith.right, inputs, binding))
+    optExprWithInfo(optExpr, expr, endsWith.left, endsWith.right)
+  }
+}
+
+/**
+ * Serde for Spark `Contains`, serialized as the native scalar function `contains`.
+ *
+ * The result is passed through `optExprWithInfo` (as `CometOctetLength` and `CometReverse`
+ * do) so that, when a child cannot be converted, the reason can be attached to `expr`
+ * instead of the fallback to Spark going unexplained.
+ */
+object CometContains extends CometExpressionSerde {
+
+  override def convert(
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[Expr] = {
+    val contains = expr.asInstanceOf[Contains]
+    val optExpr = scalarFunctionExprToProto(
+      "contains",
+      exprToProtoInternal(contains.left, inputs, binding),
+      exprToProtoInternal(contains.right, inputs, binding))
+    optExprWithInfo(optExpr, expr, contains.left, contains.right)
+  }
+}
+
+/**
+ * Serde for Spark `Substring`, serialized as a dedicated `ExprOuterClass.Substring` message.
+ *
+ * Only non-null integer literals are supported for `pos` and `len`. A null literal is
+ * rejected rather than converted: `null.asInstanceOf[Int]` would silently yield 0 and the
+ * native side would return a substring where Spark returns null.
+ */
+object CometSubstring extends CometExpressionSerde {
+
+  override def convert(
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[Expr] = {
+    val substring = expr.asInstanceOf[Substring]
+    (substring.pos, substring.len) match {
+      // Typed patterns never match null, so null literals fall through to the next case.
+      case (Literal(pos: Int, _), Literal(len: Int, _)) =>
+        exprToProtoInternal(substring.str, inputs, binding) match {
+          case Some(strExpr) =>
+            val builder = ExprOuterClass.Substring.newBuilder()
+            builder.setChild(strExpr)
+            builder.setStart(pos)
+            builder.setLen(len)
+            Some(ExprOuterClass.Expr.newBuilder().setSubstring(builder).build())
+          case None =>
+            withInfo(expr, substring.str)
+            None
+        }
+      case (Literal(_, _), Literal(_, _)) =>
+        withInfo(expr, "Substring pos and len must be non-null integer literals")
+        None
+      case _ =>
+        withInfo(expr, "Substring pos and len must be literals")
+        None
+    }
+  }
+}
+
+/**
+ * Serde for Spark `Like`, serialized as a binary `like` expression.
+ *
+ * Only the backslash escape character is accepted; any other escape character is reported
+ * via `withInfo` and the conversion returns None.
+ */
+object CometLike extends CometExpressionSerde {
+
+  override def convert(
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[Expr] = {
+    val like = expr.asInstanceOf[Like]
+    like.escapeChar match {
+      case '\\' =>
+        createBinaryExpr(
+          expr,
+          like.left,
+          like.right,
+          inputs,
+          binding,
+          (builder, binaryExpr) => builder.setLike(binaryExpr))
+      case unsupported =>
+        withInfo(expr, s"custom escape character $unsupported not supported in LIKE")
+        None
+    }
+  }
+}
+
+/**
+ * Serde for Spark `RLike`, serialized as a binary `rlike` expression.
+ *
+ * Only non-null string literal patterns are supported, and patterns containing the inline
+ * flags `(?i)` or `(?-i)` are rejected. A null pattern is handled explicitly because calling
+ * `toString` on it would throw a NullPointerException.
+ */
+object CometRLike extends CometExpressionSerde {
+
+  override def convert(
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[Expr] = {
+    val rlike = expr.asInstanceOf[RLike]
+    rlike.right match {
+      case Literal(null, _) =>
+        withInfo(expr, "Null regexp pattern is not supported")
+        None
+      case Literal(pattern, DataTypes.StringType) =>
+        val regex = pattern.toString
+        if (regex.contains("(?i)") || regex.contains("(?-i)")) {
+          withInfo(expr, "Regex flag (?i) and (?-i) are not supported")
+          None
+        } else {
+          createBinaryExpr(
+            expr,
+            rlike.left,
+            rlike.right,
+            inputs,
+            binding,
+            (builder, binaryExpr) => builder.setRlike(binaryExpr))
+        }
+      case _ =>
+        withInfo(expr, "Only scalar regexp patterns are supported")
+        None
+    }
+  }
+}
+
+/**
+ * Serde for Spark `OctetLength`: the child is wrapped in a cast to `StringType` and the
+ * result is serialized as the native scalar function `octet_length`.
+ */
+object CometOctetLength extends CometExpressionSerde {
+
+  override def convert(
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[Expr] = {
+    val octetLength = expr.asInstanceOf[OctetLength]
+    val asString = Cast(octetLength.child, StringType)
+    val childProto = exprToProtoInternal(asString, inputs, binding)
+    val optExpr = scalarFunctionExprToProto("octet_length", childProto)
+    optExprWithInfo(optExpr, expr, asString)
+  }
+}
+
+/**
+ * Serde for Spark `Reverse` on strings: the child is wrapped in a cast to `StringType` and
+ * the result is serialized as the native scalar function `reverse`.
+ *
+ * Spark's `Reverse` also accepts arrays, which it reverses element-wise. Casting an array to
+ * a string and reversing its characters would not match that, so array inputs are reported
+ * via `withInfo` and left to Spark.
+ */
+object CometReverse extends CometExpressionSerde {
+
+  override def convert(
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[Expr] = {
+    val reverse = expr.asInstanceOf[Reverse]
+    reverse.child.dataType match {
+      case _: ArrayType =>
+        withInfo(expr, "reverse on arrays is not supported")
+        None
+      case _ =>
+        val castExpr = Cast(reverse.child, StringType)
+        optExprWithInfo(
+          scalarFunctionExprToProto("reverse", exprToProtoInternal(castExpr, inputs, binding)),
+          expr,
+          castExpr)
+    }
+  }
+}
+
+/**
+ * Serde for Spark `StringRPad`, serialized as the native scalar function `rpad`.
+ *
+ * Read-side padding in Spark 3.5.2+ is represented by rpad, and only padding with a single
+ * space literal is supported. A null pad literal is guarded explicitly because calling
+ * `toString` on it would throw a NullPointerException.
+ */
+object CometStringRPad extends CometExpressionSerde {
+
+  override def convert(
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[Expr] = {
+    val stringRPad = expr.asInstanceOf[StringRPad]
+    stringRPad.pad match {
+      case Literal(str, DataTypes.StringType) if str != null && str.toString == " " =>
+        val optExpr = scalarFunctionExprToProto(
+          "rpad",
+          exprToProtoInternal(stringRPad.str, inputs, binding),
+          exprToProtoInternal(stringRPad.len, inputs, binding))
+        // Keep the fallback reason when `str` or `len` cannot be converted.
+        optExprWithInfo(optExpr, expr, stringRPad.str, stringRPad.len)
+      case _ =>
+        withInfo(expr, "StringRPad with non-space characters is not supported")
+        None
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org
For additional commands, e-mail: commits-h...@datafusion.apache.org

Reply via email to