This is an automated email from the ASF dual-hosted git repository. mbutrovich pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push: new 7b7ba1942 chore: Refactor string expression serde, part 2 (#2097) 7b7ba1942 is described below commit 7b7ba19422d4dd5e07cdf3d6da21946322586fff Author: Andy Grove <agr...@apache.org> AuthorDate: Fri Aug 8 07:36:50 2025 -0600 chore: Refactor string expression serde, part 2 (#2097) --- .../org/apache/comet/serde/QueryPlanSerde.scala | 69 +++-- .../scala/org/apache/comet/serde/strings.scala | 320 +-------------------- 2 files changed, 45 insertions(+), 344 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 623126276..3391a10c9 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -85,39 +85,39 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[ArraysOverlap] -> CometArraysOverlap, classOf[ArrayUnion] -> CometArrayUnion, classOf[CreateArray] -> CometCreateArray, - classOf[Ascii] -> CometAscii, - classOf[ConcatWs] -> CometConcatWs, - classOf[Chr] -> CometChr, + classOf[Ascii] -> CometScalarFunction("ascii"), + classOf[ConcatWs] -> CometScalarFunction("concat_ws"), + classOf[Chr] -> CometScalarFunction("char"), classOf[InitCap] -> CometInitCap, classOf[BitwiseCount] -> CometBitwiseCount, classOf[BitwiseGet] -> CometBitwiseGet, classOf[BitwiseNot] -> CometBitwiseNot, classOf[BitwiseOr] -> CometBitwiseOr, classOf[BitwiseXor] -> CometBitwiseXor, - classOf[BitLength] -> CometBitLength, + classOf[BitLength] -> CometScalarFunction("bit_length"), classOf[FromUnixTime] -> CometFromUnixTime, - classOf[Length] -> CometLength, - classOf[Acos] -> UnaryScalarFuncSerde("acos"), - classOf[Cos] -> UnaryScalarFuncSerde("cos"), - classOf[Asin] -> UnaryScalarFuncSerde("asin"), - classOf[Sin] -> UnaryScalarFuncSerde("sin"), - classOf[Atan] -> UnaryScalarFuncSerde("atan"), - classOf[Tan] -> UnaryScalarFuncSerde("tan"), - classOf[Exp] -> UnaryScalarFuncSerde("exp"), - classOf[Expm1] -> UnaryScalarFuncSerde("expm1"), - classOf[Sqrt] -> UnaryScalarFuncSerde("sqrt"), - classOf[Signum] -> UnaryScalarFuncSerde("signum"), - classOf[Md5] -> UnaryScalarFuncSerde("md5"), + classOf[Length] -> CometScalarFunction("length"), + classOf[Acos] -> CometScalarFunction("acos"), + classOf[Cos] -> CometScalarFunction("cos"), + classOf[Asin] -> CometScalarFunction("asin"), + classOf[Sin] -> CometScalarFunction("sin"), + classOf[Atan] -> CometScalarFunction("atan"), + classOf[Tan] -> CometScalarFunction("tan"), + classOf[Exp] -> CometScalarFunction("exp"), + classOf[Expm1] -> CometScalarFunction("expm1"), + classOf[Sqrt] -> CometScalarFunction("sqrt"), + classOf[Signum] -> CometScalarFunction("signum"), + classOf[Md5] -> CometScalarFunction("md5"), classOf[ShiftLeft] -> CometShiftLeft, classOf[ShiftRight] -> CometShiftRight, - classOf[StringInstr] -> CometStringInstr, + classOf[StringInstr] -> CometScalarFunction("instr"), classOf[StringRepeat] -> CometStringRepeat, - classOf[StringReplace] -> CometStringReplace, - classOf[StringTranslate] -> CometStringTranslate, - classOf[StringTrim] -> CometStringTrim, - classOf[StringTrimLeft] -> CometStringTrimLeft, - classOf[StringTrimRight] -> CometStringTrimRight, - classOf[StringTrimBoth] -> CometStringTrimBoth, + classOf[StringReplace] -> CometScalarFunction("replace"), + classOf[StringTranslate] -> CometScalarFunction("translate"), + classOf[StringTrim] -> CometScalarFunction("trim"), + classOf[StringTrimLeft] -> CometScalarFunction("ltrim"), + classOf[StringTrimRight] -> CometScalarFunction("rtrim"), + classOf[StringTrimBoth] -> CometScalarFunction("btrim"), classOf[Upper] -> CometUpper, classOf[Lower] -> CometLower, classOf[Murmur3Hash] -> CometMurmur3Hash, @@ -141,15 +141,15 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[Randn] -> CometRandn, classOf[SparkPartitionID] -> CometSparkPartitionId, classOf[MonotonicallyIncreasingID] -> CometMonotonicallyIncreasingId, - classOf[StringSpace] -> CometStringSpace, - classOf[StartsWith] -> CometStartsWith, - classOf[EndsWith] -> CometEndsWith, - classOf[Contains] -> CometContains, + classOf[StringSpace] -> CometScalarFunction("string_space"), + classOf[StartsWith] -> CometScalarFunction("starts_with"), + classOf[EndsWith] -> CometScalarFunction("ends_with"), + classOf[Contains] -> CometScalarFunction("contains"), classOf[Substring] -> CometSubstring, classOf[Like] -> CometLike, classOf[RLike] -> CometRLike, - classOf[OctetLength] -> CometOctetLength, - classOf[Reverse] -> CometReverse, + classOf[OctetLength] -> CometScalarFunction("octet_length"), + classOf[Reverse] -> CometScalarFunction("reverse"), classOf[StringRPad] -> CometStringRPad) /** @@ -2576,15 +2576,14 @@ trait CometAggregateExpressionSerde { /** Marker trait for an expression that is not guaranteed to be 100% compatible with Spark */ trait IncompatExpr {} -/** Serde for single-argument scalar function. */ -case class UnaryScalarFuncSerde(name: String) extends CometExpressionSerde { +/** Serde for scalar function. */ +case class CometScalarFunction(name: String) extends CometExpressionSerde { override def convert( expr: Expression, inputs: Seq[Attribute], binding: Boolean): Option[Expr] = { - val child = expr.children.head - val childExpr = exprToProtoInternal(child, inputs, binding) - val optExpr = scalarFunctionExprToProto(name, childExpr) - optExprWithInfo(optExpr, expr, child) + val childExpr = expr.children.map(exprToProtoInternal(_, inputs, binding)) + val optExpr = scalarFunctionExprToProto(name, childExpr: _*) + optExprWithInfo(optExpr, expr, expr.children: _*) } } diff --git a/spark/src/main/scala/org/apache/comet/serde/strings.scala b/spark/src/main/scala/org/apache/comet/serde/strings.scala index de896e0df..75e7e8bd4 100644 --- a/spark/src/main/scala/org/apache/comet/serde/strings.scala +++ b/spark/src/main/scala/org/apache/comet/serde/strings.scala @@ -19,9 +19,7 @@ package org.apache.comet.serde -import scala.util.Try - -import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Contains, EndsWith, Expression, Like, Literal, OctetLength, Reverse, RLike, StartsWith, StringRPad, Substring} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Expression, Like, Literal, RLike, StringRPad, Substring} import org.apache.spark.sql.types.{DataTypes, LongType, StringType} import org.apache.comet.CometConf @@ -29,50 +27,6 @@ import org.apache.comet.CometSparkSessionExtensions.withInfo import org.apache.comet.serde.ExprOuterClass.Expr import org.apache.comet.serde.QueryPlanSerde.{createBinaryExpr, exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto} -object CometAscii extends CometExpressionSerde { - - override def convert( - expr: Expression, - inputs: Seq[Attribute], - binding: Boolean): Option[ExprOuterClass.Expr] = { - val child = expr.children.head - val castExpr = Cast(child, StringType) - val childExpr = exprToProtoInternal(castExpr, inputs, binding) - val optExpr = scalarFunctionExprToProto("ascii", childExpr) - optExprWithInfo(optExpr, expr, castExpr) - } -} - -object CometBitLength extends CometExpressionSerde { - - override def convert( - expr: Expression, - inputs: Seq[Attribute], - binding: Boolean): Option[ExprOuterClass.Expr] = { - val child = expr.children.head - val castExpr = Cast(child, StringType) - val childExpr = exprToProtoInternal(castExpr, inputs, binding) - val optExpr = scalarFunctionExprToProto("bit_length", childExpr) - optExprWithInfo(optExpr, expr, castExpr) - } -} - -object CometStringInstr extends CometExpressionSerde { - - override def convert( - expr: Expression, - inputs: Seq[Attribute], - binding: Boolean): Option[ExprOuterClass.Expr] = { - val children = expr.children - val leftCast = Cast(children(0), StringType) - val rightCast = Cast(children(1), StringType) - val leftExpr = exprToProtoInternal(leftCast, inputs, binding) - val rightExpr = exprToProtoInternal(rightCast, inputs, binding) - val optExpr = scalarFunctionExprToProto("strpos", leftExpr, rightExpr) - optExprWithInfo(optExpr, expr, leftCast, rightCast) - } -} - object CometStringRepeat extends CometExpressionSerde { override def convert( @@ -89,267 +43,43 @@ object CometStringRepeat extends CometExpressionSerde { } } -object CometStringReplace extends CometExpressionSerde { - - override def convert( - expr: Expression, - inputs: Seq[Attribute], - binding: Boolean): Option[ExprOuterClass.Expr] = { - val children = expr.children - val srcCast = Cast(children(0), StringType) - val searchCast = Cast(children(1), StringType) - val replaceCast = Cast(children(2), StringType) - val srcExpr = exprToProtoInternal(srcCast, inputs, binding) - val searchExpr = exprToProtoInternal(searchCast, inputs, binding) - val replaceExpr = exprToProtoInternal(replaceCast, inputs, binding) - val optExpr = scalarFunctionExprToProto("replace", srcExpr, searchExpr, replaceExpr) - optExprWithInfo(optExpr, expr, srcCast, searchCast, replaceCast) - } -} - -object CometStringTranslate extends CometExpressionSerde { - - override def convert( - expr: Expression, - inputs: Seq[Attribute], - binding: Boolean): Option[ExprOuterClass.Expr] = { - val children = expr.children - val srcCast = Cast(children(0), StringType) - val matchingCast = Cast(children(1), StringType) - val replaceCast = Cast(children(2), StringType) - val srcExpr = exprToProtoInternal(srcCast, inputs, binding) - val matchingExpr = exprToProtoInternal(matchingCast, inputs, binding) - val replaceExpr = exprToProtoInternal(replaceCast, inputs, binding) - val optExpr = scalarFunctionExprToProto("translate", srcExpr, matchingExpr, replaceExpr) - optExprWithInfo(optExpr, expr, srcCast, matchingCast, replaceCast) - } -} - -object CometStringTrim extends CometExpressionSerde { - - override def convert( - expr: Expression, - inputs: Seq[Attribute], - binding: Boolean): Option[ExprOuterClass.Expr] = { - val children = expr.children - val srcStr = children(0) - val trimStr = Try(children(1)).toOption - CometTrimCommon.trim(expr, srcStr, trimStr, inputs, binding, "trim") - } -} - -object CometStringTrimLeft extends CometExpressionSerde { - - override def convert( - expr: Expression, - inputs: Seq[Attribute], - binding: Boolean): Option[Expr] = { - val children = expr.children - val srcStr = children(0) - val trimStr = Try(children(1)).toOption - CometTrimCommon.trim(expr, srcStr, trimStr, inputs, binding, "ltrim") - } -} - -object CometStringTrimRight extends CometExpressionSerde { - - override def convert( - expr: Expression, - inputs: Seq[Attribute], - binding: Boolean): Option[Expr] = { - val children = expr.children - val srcStr = children(0) - val trimStr = Try(children(1)).toOption - CometTrimCommon.trim(expr, srcStr, trimStr, inputs, binding, "rtrim") - } -} - -object CometStringTrimBoth extends CometExpressionSerde { - - override def convert( - expr: Expression, - inputs: Seq[Attribute], - binding: Boolean): Option[Expr] = { - val children = expr.children - val srcStr = children(0) - val trimStr = Try(children(1)).toOption - CometTrimCommon.trim(expr, srcStr, trimStr, inputs, binding, "btrim") - } -} - -private object CometTrimCommon { - def trim( - expr: Expression, // parent expression - srcStr: Expression, - trimStr: Option[Expression], - inputs: Seq[Attribute], - binding: Boolean, - trimType: String): Option[Expr] = { - val srcCast = Cast(srcStr, StringType) - val srcExpr = exprToProtoInternal(srcCast, inputs, binding) - if (trimStr.isDefined) { - val trimCast = Cast(trimStr.get, StringType) - val trimExpr = exprToProtoInternal(trimCast, inputs, binding) - val optExpr = scalarFunctionExprToProto(trimType, srcExpr, trimExpr) - optExprWithInfo(optExpr, expr, srcCast, trimCast) - } else { - val optExpr = scalarFunctionExprToProto(trimType, srcExpr) - optExprWithInfo(optExpr, expr, srcCast) - } - } -} - -object CometUpper extends CometExpressionSerde { - - override def convert( - expr: Expression, - inputs: Seq[Attribute], - binding: Boolean): Option[Expr] = { - if (CometConf.COMET_CASE_CONVERSION_ENABLED.get()) { - val castExpr = Cast(expr.children.head, StringType) - val childExpr = exprToProtoInternal(castExpr, inputs, binding) - val optExpr = scalarFunctionExprToProto("upper", childExpr) - optExprWithInfo(optExpr, expr, castExpr) - } else { - withInfo( - expr, - "Comet is not compatible with Spark for case conversion in " + - s"locale-specific cases. Set ${CometConf.COMET_CASE_CONVERSION_ENABLED.key}=true " + - "to enable it anyway.") - None - } - } -} - -object CometLower extends CometExpressionSerde { +class CometCaseConversionBase(function: String) extends CometScalarFunction(function) { override def convert( expr: Expression, inputs: Seq[Attribute], binding: Boolean): Option[Expr] = { - if (CometConf.COMET_CASE_CONVERSION_ENABLED.get()) { - val castExpr = Cast(expr.children.head, StringType) - val childExpr = exprToProtoInternal(castExpr, inputs, binding) - val optExpr = scalarFunctionExprToProto("lower", childExpr) - optExprWithInfo(optExpr, expr, castExpr) - } else { + if (!CometConf.COMET_CASE_CONVERSION_ENABLED.get()) { withInfo( expr, "Comet is not compatible with Spark for case conversion in " + s"locale-specific cases. Set ${CometConf.COMET_CASE_CONVERSION_ENABLED.key}=true " + "to enable it anyway.") - None + return None } + super.convert(expr, inputs, binding) } } -object CometLength extends CometExpressionSerde { +object CometUpper extends CometCaseConversionBase("upper") - override def convert( - expr: Expression, - inputs: Seq[Attribute], - binding: Boolean): Option[Expr] = { - val castExpr = Cast(expr.children.head, StringType) - val childExpr = exprToProtoInternal(castExpr, inputs, binding) - val optExpr = scalarFunctionExprToProto("length", childExpr) - optExprWithInfo(optExpr, expr, castExpr) - } -} +object CometLower extends CometCaseConversionBase("lower") -object CometInitCap extends CometExpressionSerde { +object CometInitCap extends CometScalarFunction("initcap") { override def convert( expr: Expression, inputs: Seq[Attribute], binding: Boolean): Option[Expr] = { - if (CometConf.COMET_EXEC_INITCAP_ENABLED.get()) { - val castExpr = Cast(expr.children.head, StringType) - val childExpr = exprToProtoInternal(castExpr, inputs, binding) - val optExpr = scalarFunctionExprToProto("initcap", childExpr) - optExprWithInfo(optExpr, expr, castExpr) - } else { + if (!CometConf.COMET_EXEC_INITCAP_ENABLED.get()) { withInfo( expr, "Comet initCap is not compatible with Spark yet. " + "See https://github.com/apache/datafusion-comet/issues/1052 ." + s"Set ${CometConf.COMET_EXEC_INITCAP_ENABLED.key}=true to enable it anyway.") - None + return None } - - } -} - -object CometChr extends CometExpressionSerde { - - override def convert( - expr: Expression, - inputs: Seq[Attribute], - binding: Boolean): Option[Expr] = { - val child = expr.children.head - val childExpr = exprToProtoInternal(child, inputs, binding) - val optExpr = scalarFunctionExprToProto("char", childExpr) - optExprWithInfo(optExpr, expr, child) - } -} - -object CometConcatWs extends CometExpressionSerde { - - override def convert( - expr: Expression, - inputs: Seq[Attribute], - binding: Boolean): Option[Expr] = { - var childExprs: Seq[Expression] = Seq() - val exprs = expr.children.map(e => { - val castExpr = Cast(e, StringType) - childExprs = childExprs :+ castExpr - exprToProtoInternal(castExpr, inputs, binding) - }) - val optExpr = scalarFunctionExprToProto("concat_ws", exprs: _*) - optExprWithInfo(optExpr, expr, childExprs: _*) - } -} - -object CometStringSpace extends UnaryScalarFuncSerde("string_space") - -object CometStartsWith extends CometExpressionSerde { - - override def convert( - expr: Expression, - inputs: Seq[Attribute], - binding: Boolean): Option[Expr] = { - val startsWith = expr.asInstanceOf[StartsWith] - scalarFunctionExprToProto( - "starts_with", - exprToProtoInternal(startsWith.left, inputs, binding), - exprToProtoInternal(startsWith.right, inputs, binding)) - } -} - -object CometEndsWith extends CometExpressionSerde { - - override def convert( - expr: Expression, - inputs: Seq[Attribute], - binding: Boolean): Option[Expr] = { - val endsWith = expr.asInstanceOf[EndsWith] - scalarFunctionExprToProto( - "ends_with", - exprToProtoInternal(endsWith.left, inputs, binding), - exprToProtoInternal(endsWith.right, inputs, binding)) - } -} - -object CometContains extends CometExpressionSerde { - - override def convert( - expr: Expression, - inputs: Seq[Attribute], - binding: Boolean): Option[Expr] = { - val contains = expr.asInstanceOf[Contains] - scalarFunctionExprToProto( - "contains", - exprToProtoInternal(contains.left, inputs, binding), - exprToProtoInternal(contains.right, inputs, binding)) + super.convert(expr, inputs, binding) } } @@ -431,34 +161,6 @@ object CometRLike extends CometExpressionSerde { } } -object CometOctetLength extends CometExpressionSerde { - - override def convert( - expr: Expression, - inputs: Seq[Attribute], - binding: Boolean): Option[Expr] = { - val castExpr = Cast(expr.asInstanceOf[OctetLength].child, StringType) - optExprWithInfo( - scalarFunctionExprToProto("octet_length", exprToProtoInternal(castExpr, inputs, binding)), - expr, - castExpr) - } -} - -object CometReverse extends CometExpressionSerde { - - override def convert( - expr: Expression, - inputs: Seq[Attribute], - binding: Boolean): Option[Expr] = { - val castExpr = Cast(expr.asInstanceOf[Reverse].child, StringType) - optExprWithInfo( - scalarFunctionExprToProto("reverse", exprToProtoInternal(castExpr, inputs, binding)), - expr, - castExpr) - } -} - object CometStringRPad extends CometExpressionSerde { override def convert( --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org For additional commands, e-mail: commits-h...@datafusion.apache.org