This is an automated email from the ASF dual-hosted git repository.

mbutrovich pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 7b7ba1942 chore: Refactor string expression serde, part 2 (#2097)
7b7ba1942 is described below

commit 7b7ba19422d4dd5e07cdf3d6da21946322586fff
Author: Andy Grove <agr...@apache.org>
AuthorDate: Fri Aug 8 07:36:50 2025 -0600

    chore: Refactor string expression serde, part 2 (#2097)
---
 .../org/apache/comet/serde/QueryPlanSerde.scala    |  69 +++--
 .../scala/org/apache/comet/serde/strings.scala     | 320 +--------------------
 2 files changed, 45 insertions(+), 344 deletions(-)

diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 623126276..3391a10c9 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -85,39 +85,39 @@ object QueryPlanSerde extends Logging with CometExprShim {
     classOf[ArraysOverlap] -> CometArraysOverlap,
     classOf[ArrayUnion] -> CometArrayUnion,
     classOf[CreateArray] -> CometCreateArray,
-    classOf[Ascii] -> CometAscii,
-    classOf[ConcatWs] -> CometConcatWs,
-    classOf[Chr] -> CometChr,
+    classOf[Ascii] -> CometScalarFunction("ascii"),
+    classOf[ConcatWs] -> CometScalarFunction("concat_ws"),
+    classOf[Chr] -> CometScalarFunction("char"),
     classOf[InitCap] -> CometInitCap,
     classOf[BitwiseCount] -> CometBitwiseCount,
     classOf[BitwiseGet] -> CometBitwiseGet,
     classOf[BitwiseNot] -> CometBitwiseNot,
     classOf[BitwiseOr] -> CometBitwiseOr,
     classOf[BitwiseXor] -> CometBitwiseXor,
-    classOf[BitLength] -> CometBitLength,
+    classOf[BitLength] -> CometScalarFunction("bit_length"),
     classOf[FromUnixTime] -> CometFromUnixTime,
-    classOf[Length] -> CometLength,
-    classOf[Acos] -> UnaryScalarFuncSerde("acos"),
-    classOf[Cos] -> UnaryScalarFuncSerde("cos"),
-    classOf[Asin] -> UnaryScalarFuncSerde("asin"),
-    classOf[Sin] -> UnaryScalarFuncSerde("sin"),
-    classOf[Atan] -> UnaryScalarFuncSerde("atan"),
-    classOf[Tan] -> UnaryScalarFuncSerde("tan"),
-    classOf[Exp] -> UnaryScalarFuncSerde("exp"),
-    classOf[Expm1] -> UnaryScalarFuncSerde("expm1"),
-    classOf[Sqrt] -> UnaryScalarFuncSerde("sqrt"),
-    classOf[Signum] -> UnaryScalarFuncSerde("signum"),
-    classOf[Md5] -> UnaryScalarFuncSerde("md5"),
+    classOf[Length] -> CometScalarFunction("length"),
+    classOf[Acos] -> CometScalarFunction("acos"),
+    classOf[Cos] -> CometScalarFunction("cos"),
+    classOf[Asin] -> CometScalarFunction("asin"),
+    classOf[Sin] -> CometScalarFunction("sin"),
+    classOf[Atan] -> CometScalarFunction("atan"),
+    classOf[Tan] -> CometScalarFunction("tan"),
+    classOf[Exp] -> CometScalarFunction("exp"),
+    classOf[Expm1] -> CometScalarFunction("expm1"),
+    classOf[Sqrt] -> CometScalarFunction("sqrt"),
+    classOf[Signum] -> CometScalarFunction("signum"),
+    classOf[Md5] -> CometScalarFunction("md5"),
     classOf[ShiftLeft] -> CometShiftLeft,
     classOf[ShiftRight] -> CometShiftRight,
-    classOf[StringInstr] -> CometStringInstr,
+    classOf[StringInstr] -> CometScalarFunction("instr"),
     classOf[StringRepeat] -> CometStringRepeat,
-    classOf[StringReplace] -> CometStringReplace,
-    classOf[StringTranslate] -> CometStringTranslate,
-    classOf[StringTrim] -> CometStringTrim,
-    classOf[StringTrimLeft] -> CometStringTrimLeft,
-    classOf[StringTrimRight] -> CometStringTrimRight,
-    classOf[StringTrimBoth] -> CometStringTrimBoth,
+    classOf[StringReplace] -> CometScalarFunction("replace"),
+    classOf[StringTranslate] -> CometScalarFunction("translate"),
+    classOf[StringTrim] -> CometScalarFunction("trim"),
+    classOf[StringTrimLeft] -> CometScalarFunction("ltrim"),
+    classOf[StringTrimRight] -> CometScalarFunction("rtrim"),
+    classOf[StringTrimBoth] -> CometScalarFunction("btrim"),
     classOf[Upper] -> CometUpper,
     classOf[Lower] -> CometLower,
     classOf[Murmur3Hash] -> CometMurmur3Hash,
@@ -141,15 +141,15 @@ object QueryPlanSerde extends Logging with CometExprShim {
     classOf[Randn] -> CometRandn,
     classOf[SparkPartitionID] -> CometSparkPartitionId,
     classOf[MonotonicallyIncreasingID] -> CometMonotonicallyIncreasingId,
-    classOf[StringSpace] -> CometStringSpace,
-    classOf[StartsWith] -> CometStartsWith,
-    classOf[EndsWith] -> CometEndsWith,
-    classOf[Contains] -> CometContains,
+    classOf[StringSpace] -> CometScalarFunction("string_space"),
+    classOf[StartsWith] -> CometScalarFunction("starts_with"),
+    classOf[EndsWith] -> CometScalarFunction("ends_with"),
+    classOf[Contains] -> CometScalarFunction("contains"),
     classOf[Substring] -> CometSubstring,
     classOf[Like] -> CometLike,
     classOf[RLike] -> CometRLike,
-    classOf[OctetLength] -> CometOctetLength,
-    classOf[Reverse] -> CometReverse,
+    classOf[OctetLength] -> CometScalarFunction("octet_length"),
+    classOf[Reverse] -> CometScalarFunction("reverse"),
     classOf[StringRPad] -> CometStringRPad)
 
   /**
@@ -2576,15 +2576,14 @@ trait CometAggregateExpressionSerde {
 /** Marker trait for an expression that is not guaranteed to be 100% compatible with Spark */
 trait IncompatExpr {}
 
-/** Serde for single-argument scalar function. */
-case class UnaryScalarFuncSerde(name: String) extends CometExpressionSerde {
+/** Serde for scalar function. */
+case class CometScalarFunction(name: String) extends CometExpressionSerde {
   override def convert(
       expr: Expression,
       inputs: Seq[Attribute],
       binding: Boolean): Option[Expr] = {
-    val child = expr.children.head
-    val childExpr = exprToProtoInternal(child, inputs, binding)
-    val optExpr = scalarFunctionExprToProto(name, childExpr)
-    optExprWithInfo(optExpr, expr, child)
+    val childExpr = expr.children.map(exprToProtoInternal(_, inputs, binding))
+    val optExpr = scalarFunctionExprToProto(name, childExpr: _*)
+    optExprWithInfo(optExpr, expr, expr.children: _*)
   }
 }
diff --git a/spark/src/main/scala/org/apache/comet/serde/strings.scala b/spark/src/main/scala/org/apache/comet/serde/strings.scala
index de896e0df..75e7e8bd4 100644
--- a/spark/src/main/scala/org/apache/comet/serde/strings.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/strings.scala
@@ -19,9 +19,7 @@
 
 package org.apache.comet.serde
 
-import scala.util.Try
-
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Contains, EndsWith, Expression, Like, Literal, OctetLength, Reverse, RLike, StartsWith, StringRPad, Substring}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Expression, Like, Literal, RLike, StringRPad, Substring}
 import org.apache.spark.sql.types.{DataTypes, LongType, StringType}
 
 import org.apache.comet.CometConf
@@ -29,50 +27,6 @@ import org.apache.comet.CometSparkSessionExtensions.withInfo
 import org.apache.comet.serde.ExprOuterClass.Expr
 import org.apache.comet.serde.QueryPlanSerde.{createBinaryExpr, exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto}
 
-object CometAscii extends CometExpressionSerde {
-
-  override def convert(
-      expr: Expression,
-      inputs: Seq[Attribute],
-      binding: Boolean): Option[ExprOuterClass.Expr] = {
-    val child = expr.children.head
-    val castExpr = Cast(child, StringType)
-    val childExpr = exprToProtoInternal(castExpr, inputs, binding)
-    val optExpr = scalarFunctionExprToProto("ascii", childExpr)
-    optExprWithInfo(optExpr, expr, castExpr)
-  }
-}
-
-object CometBitLength extends CometExpressionSerde {
-
-  override def convert(
-      expr: Expression,
-      inputs: Seq[Attribute],
-      binding: Boolean): Option[ExprOuterClass.Expr] = {
-    val child = expr.children.head
-    val castExpr = Cast(child, StringType)
-    val childExpr = exprToProtoInternal(castExpr, inputs, binding)
-    val optExpr = scalarFunctionExprToProto("bit_length", childExpr)
-    optExprWithInfo(optExpr, expr, castExpr)
-  }
-}
-
-object CometStringInstr extends CometExpressionSerde {
-
-  override def convert(
-      expr: Expression,
-      inputs: Seq[Attribute],
-      binding: Boolean): Option[ExprOuterClass.Expr] = {
-    val children = expr.children
-    val leftCast = Cast(children(0), StringType)
-    val rightCast = Cast(children(1), StringType)
-    val leftExpr = exprToProtoInternal(leftCast, inputs, binding)
-    val rightExpr = exprToProtoInternal(rightCast, inputs, binding)
-    val optExpr = scalarFunctionExprToProto("strpos", leftExpr, rightExpr)
-    optExprWithInfo(optExpr, expr, leftCast, rightCast)
-  }
-}
-
 object CometStringRepeat extends CometExpressionSerde {
 
   override def convert(
@@ -89,267 +43,43 @@ object CometStringRepeat extends CometExpressionSerde {
   }
 }
 
-object CometStringReplace extends CometExpressionSerde {
-
-  override def convert(
-      expr: Expression,
-      inputs: Seq[Attribute],
-      binding: Boolean): Option[ExprOuterClass.Expr] = {
-    val children = expr.children
-    val srcCast = Cast(children(0), StringType)
-    val searchCast = Cast(children(1), StringType)
-    val replaceCast = Cast(children(2), StringType)
-    val srcExpr = exprToProtoInternal(srcCast, inputs, binding)
-    val searchExpr = exprToProtoInternal(searchCast, inputs, binding)
-    val replaceExpr = exprToProtoInternal(replaceCast, inputs, binding)
-    val optExpr = scalarFunctionExprToProto("replace", srcExpr, searchExpr, replaceExpr)
-    optExprWithInfo(optExpr, expr, srcCast, searchCast, replaceCast)
-  }
-}
-
-object CometStringTranslate extends CometExpressionSerde {
-
-  override def convert(
-      expr: Expression,
-      inputs: Seq[Attribute],
-      binding: Boolean): Option[ExprOuterClass.Expr] = {
-    val children = expr.children
-    val srcCast = Cast(children(0), StringType)
-    val matchingCast = Cast(children(1), StringType)
-    val replaceCast = Cast(children(2), StringType)
-    val srcExpr = exprToProtoInternal(srcCast, inputs, binding)
-    val matchingExpr = exprToProtoInternal(matchingCast, inputs, binding)
-    val replaceExpr = exprToProtoInternal(replaceCast, inputs, binding)
-    val optExpr = scalarFunctionExprToProto("translate", srcExpr, matchingExpr, replaceExpr)
-    optExprWithInfo(optExpr, expr, srcCast, matchingCast, replaceCast)
-  }
-}
-
-object CometStringTrim extends CometExpressionSerde {
-
-  override def convert(
-      expr: Expression,
-      inputs: Seq[Attribute],
-      binding: Boolean): Option[ExprOuterClass.Expr] = {
-    val children = expr.children
-    val srcStr = children(0)
-    val trimStr = Try(children(1)).toOption
-    CometTrimCommon.trim(expr, srcStr, trimStr, inputs, binding, "trim")
-  }
-}
-
-object CometStringTrimLeft extends CometExpressionSerde {
-
-  override def convert(
-      expr: Expression,
-      inputs: Seq[Attribute],
-      binding: Boolean): Option[Expr] = {
-    val children = expr.children
-    val srcStr = children(0)
-    val trimStr = Try(children(1)).toOption
-    CometTrimCommon.trim(expr, srcStr, trimStr, inputs, binding, "ltrim")
-  }
-}
-
-object CometStringTrimRight extends CometExpressionSerde {
-
-  override def convert(
-      expr: Expression,
-      inputs: Seq[Attribute],
-      binding: Boolean): Option[Expr] = {
-    val children = expr.children
-    val srcStr = children(0)
-    val trimStr = Try(children(1)).toOption
-    CometTrimCommon.trim(expr, srcStr, trimStr, inputs, binding, "rtrim")
-  }
-}
-
-object CometStringTrimBoth extends CometExpressionSerde {
-
-  override def convert(
-      expr: Expression,
-      inputs: Seq[Attribute],
-      binding: Boolean): Option[Expr] = {
-    val children = expr.children
-    val srcStr = children(0)
-    val trimStr = Try(children(1)).toOption
-    CometTrimCommon.trim(expr, srcStr, trimStr, inputs, binding, "btrim")
-  }
-}
-
-private object CometTrimCommon {
-  def trim(
-      expr: Expression, // parent expression
-      srcStr: Expression,
-      trimStr: Option[Expression],
-      inputs: Seq[Attribute],
-      binding: Boolean,
-      trimType: String): Option[Expr] = {
-    val srcCast = Cast(srcStr, StringType)
-    val srcExpr = exprToProtoInternal(srcCast, inputs, binding)
-    if (trimStr.isDefined) {
-      val trimCast = Cast(trimStr.get, StringType)
-      val trimExpr = exprToProtoInternal(trimCast, inputs, binding)
-      val optExpr = scalarFunctionExprToProto(trimType, srcExpr, trimExpr)
-      optExprWithInfo(optExpr, expr, srcCast, trimCast)
-    } else {
-      val optExpr = scalarFunctionExprToProto(trimType, srcExpr)
-      optExprWithInfo(optExpr, expr, srcCast)
-    }
-  }
-}
-
-object CometUpper extends CometExpressionSerde {
-
-  override def convert(
-      expr: Expression,
-      inputs: Seq[Attribute],
-      binding: Boolean): Option[Expr] = {
-    if (CometConf.COMET_CASE_CONVERSION_ENABLED.get()) {
-      val castExpr = Cast(expr.children.head, StringType)
-      val childExpr = exprToProtoInternal(castExpr, inputs, binding)
-      val optExpr = scalarFunctionExprToProto("upper", childExpr)
-      optExprWithInfo(optExpr, expr, castExpr)
-    } else {
-      withInfo(
-        expr,
-        "Comet is not compatible with Spark for case conversion in " +
-          s"locale-specific cases. Set ${CometConf.COMET_CASE_CONVERSION_ENABLED.key}=true " +
-          "to enable it anyway.")
-      None
-    }
-  }
-}
-
-object CometLower extends CometExpressionSerde {
+class CometCaseConversionBase(function: String) extends CometScalarFunction(function) {
 
   override def convert(
       expr: Expression,
       inputs: Seq[Attribute],
       binding: Boolean): Option[Expr] = {
-    if (CometConf.COMET_CASE_CONVERSION_ENABLED.get()) {
-      val castExpr = Cast(expr.children.head, StringType)
-      val childExpr = exprToProtoInternal(castExpr, inputs, binding)
-      val optExpr = scalarFunctionExprToProto("lower", childExpr)
-      optExprWithInfo(optExpr, expr, castExpr)
-    } else {
+    if (!CometConf.COMET_CASE_CONVERSION_ENABLED.get()) {
       withInfo(
         expr,
         "Comet is not compatible with Spark for case conversion in " +
           s"locale-specific cases. Set ${CometConf.COMET_CASE_CONVERSION_ENABLED.key}=true " +
           "to enable it anyway.")
-      None
+      return None
     }
+    super.convert(expr, inputs, binding)
   }
 }
 
-object CometLength extends CometExpressionSerde {
+object CometUpper extends CometCaseConversionBase("upper")
 
-  override def convert(
-      expr: Expression,
-      inputs: Seq[Attribute],
-      binding: Boolean): Option[Expr] = {
-    val castExpr = Cast(expr.children.head, StringType)
-    val childExpr = exprToProtoInternal(castExpr, inputs, binding)
-    val optExpr = scalarFunctionExprToProto("length", childExpr)
-    optExprWithInfo(optExpr, expr, castExpr)
-  }
-}
+object CometLower extends CometCaseConversionBase("lower")
 
-object CometInitCap extends CometExpressionSerde {
+object CometInitCap extends CometScalarFunction("initcap") {
 
   override def convert(
       expr: Expression,
       inputs: Seq[Attribute],
       binding: Boolean): Option[Expr] = {
-    if (CometConf.COMET_EXEC_INITCAP_ENABLED.get()) {
-      val castExpr = Cast(expr.children.head, StringType)
-      val childExpr = exprToProtoInternal(castExpr, inputs, binding)
-      val optExpr = scalarFunctionExprToProto("initcap", childExpr)
-      optExprWithInfo(optExpr, expr, castExpr)
-    } else {
+    if (!CometConf.COMET_EXEC_INITCAP_ENABLED.get()) {
       withInfo(
         expr,
         "Comet initCap is not compatible with Spark yet. " +
           "See https://github.com/apache/datafusion-comet/issues/1052 ." +
           s"Set ${CometConf.COMET_EXEC_INITCAP_ENABLED.key}=true to enable it anyway.")
-      None
+      return None
     }
-
-  }
-}
-
-object CometChr extends CometExpressionSerde {
-
-  override def convert(
-      expr: Expression,
-      inputs: Seq[Attribute],
-      binding: Boolean): Option[Expr] = {
-    val child = expr.children.head
-    val childExpr = exprToProtoInternal(child, inputs, binding)
-    val optExpr = scalarFunctionExprToProto("char", childExpr)
-    optExprWithInfo(optExpr, expr, child)
-  }
-}
-
-object CometConcatWs extends CometExpressionSerde {
-
-  override def convert(
-      expr: Expression,
-      inputs: Seq[Attribute],
-      binding: Boolean): Option[Expr] = {
-    var childExprs: Seq[Expression] = Seq()
-    val exprs = expr.children.map(e => {
-      val castExpr = Cast(e, StringType)
-      childExprs = childExprs :+ castExpr
-      exprToProtoInternal(castExpr, inputs, binding)
-    })
-    val optExpr = scalarFunctionExprToProto("concat_ws", exprs: _*)
-    optExprWithInfo(optExpr, expr, childExprs: _*)
-  }
-}
-
-object CometStringSpace extends UnaryScalarFuncSerde("string_space")
-
-object CometStartsWith extends CometExpressionSerde {
-
-  override def convert(
-      expr: Expression,
-      inputs: Seq[Attribute],
-      binding: Boolean): Option[Expr] = {
-    val startsWith = expr.asInstanceOf[StartsWith]
-    scalarFunctionExprToProto(
-      "starts_with",
-      exprToProtoInternal(startsWith.left, inputs, binding),
-      exprToProtoInternal(startsWith.right, inputs, binding))
-  }
-}
-
-object CometEndsWith extends CometExpressionSerde {
-
-  override def convert(
-      expr: Expression,
-      inputs: Seq[Attribute],
-      binding: Boolean): Option[Expr] = {
-    val endsWith = expr.asInstanceOf[EndsWith]
-    scalarFunctionExprToProto(
-      "ends_with",
-      exprToProtoInternal(endsWith.left, inputs, binding),
-      exprToProtoInternal(endsWith.right, inputs, binding))
-  }
-}
-
-object CometContains extends CometExpressionSerde {
-
-  override def convert(
-      expr: Expression,
-      inputs: Seq[Attribute],
-      binding: Boolean): Option[Expr] = {
-    val contains = expr.asInstanceOf[Contains]
-    scalarFunctionExprToProto(
-      "contains",
-      exprToProtoInternal(contains.left, inputs, binding),
-      exprToProtoInternal(contains.right, inputs, binding))
+    super.convert(expr, inputs, binding)
   }
 }
 
@@ -431,34 +161,6 @@ object CometRLike extends CometExpressionSerde {
   }
 }
 
-object CometOctetLength extends CometExpressionSerde {
-
-  override def convert(
-      expr: Expression,
-      inputs: Seq[Attribute],
-      binding: Boolean): Option[Expr] = {
-    val castExpr = Cast(expr.asInstanceOf[OctetLength].child, StringType)
-    optExprWithInfo(
-      scalarFunctionExprToProto("octet_length", exprToProtoInternal(castExpr, inputs, binding)),
-      expr,
-      castExpr)
-  }
-}
-
-object CometReverse extends CometExpressionSerde {
-
-  override def convert(
-      expr: Expression,
-      inputs: Seq[Attribute],
-      binding: Boolean): Option[Expr] = {
-    val castExpr = Cast(expr.asInstanceOf[Reverse].child, StringType)
-    optExprWithInfo(
-      scalarFunctionExprToProto("reverse", exprToProtoInternal(castExpr, inputs, binding)),
-      expr,
-      castExpr)
-  }
-}
-
 object CometStringRPad extends CometExpressionSerde {
 
   override def convert(


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org
For additional commands, e-mail: commits-h...@datafusion.apache.org

Reply via email to