cloud-fan commented on a change in pull request #30981:
URL: https://github.com/apache/spark/pull/30981#discussion_r551130707
##########
File path:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
##########
@@ -751,267 +751,348 @@ case class FindInSet(left: Expression, right:
Expression) extends BinaryExpressi
override def prettyName: String = "find_in_set"
}
-trait String2TrimExpression extends Expression with ImplicitCastInputTypes {
+trait TrimExpression extends Expression with ImplicitCastInputTypes {
- protected def srcStr: Expression
- protected def trimStr: Option[Expression]
+ protected def srcExpr: Expression
+ protected def trimExprOpt: Option[Expression]
protected def direction: String
- override def children: Seq[Expression] = srcStr +: trimStr.toSeq
- override def dataType: DataType = StringType
- override def inputTypes: Seq[AbstractDataType] =
Seq.fill(children.size)(StringType)
+ override def children: Seq[Expression] = srcExpr +: trimExprOpt.toSeq
+ override def dataType: DataType = srcExpr.dataType
+ override def inputTypes: Seq[AbstractDataType] =
+ Seq.fill(children.size)(TypeCollection(StringType, BinaryType))
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ val inputTypeCheck = super.checkInputDataTypes()
+ if (inputTypeCheck.isSuccess) {
+ TypeUtils.checkForSameTypeInputExpr(
+ children.map(_.dataType), s"function $prettyName")
+ } else {
+ inputTypeCheck
+ }
+ }
override def nullable: Boolean = children.exists(_.nullable)
override def foldable: Boolean = children.forall(_.foldable)
protected def doEval(srcString: UTF8String): UTF8String
+ protected def doEval(srcBytes: Array[Byte]): Array[Byte]
protected def doEval(srcString: UTF8String, trimString: UTF8String):
UTF8String
+ protected def doEval(srcBytes: Array[Byte], trimBytes: Array[Byte]):
Array[Byte]
+
+ private lazy val evalFunc = srcExpr.dataType match {
+ case StringType =>
+ (input: InternalRow) => {
+ val srcString = srcExpr.eval(input).asInstanceOf[UTF8String]
+ if (srcString == null) {
+ null
+ } else if (trimExprOpt.isDefined) {
+ doEval(srcString,
trimExprOpt.get.eval(input).asInstanceOf[UTF8String])
+ } else {
+ doEval(srcString)
+ }
+ }
+ case BinaryType =>
+ (input: InternalRow) => {
+ val srcBytes = srcExpr.eval (input).asInstanceOf[Array[Byte]]
+ if (srcBytes == null) {
+ null
+ } else if (trimExprOpt.isDefined) {
+ doEval(srcBytes,
trimExprOpt.get.eval(input).asInstanceOf[Array[Byte]])
+ } else {
+ doEval(srcBytes)
+ }
+ }
+ }
override def eval(input: InternalRow): Any = {
- val srcString = srcStr.eval(input).asInstanceOf[UTF8String]
- if (srcString == null) {
- null
- } else if (trimStr.isDefined) {
- doEval(srcString, trimStr.get.eval(input).asInstanceOf[UTF8String])
- } else {
- doEval(srcString)
- }
+ evalFunc(input)
}
protected val trimMethod: String
+ private lazy val resultType = srcExpr.dataType match {
Review comment:
We can define it as a local variable inside `doGenCode`
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]