Github user kevinyu98 commented on a diff in the pull request: https://github.com/apache/spark/pull/12646#discussion_r137366326 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala --- @@ -503,69 +504,319 @@ case class FindInSet(left: Expression, right: Expression) extends BinaryExpressi override def prettyName: String = "find_in_set" } +trait String2TrimExpression extends Expression with ImplicitCastInputTypes { + + override def dataType: DataType = StringType + override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringType) + + override def nullable: Boolean = children.exists(_.nullable) + override def foldable: Boolean = children.forall(_.foldable) + + override def sql: String = { + if (children.size == 1) { + val childrenSQL = children.map(_.sql).mkString(", ") + s"$prettyName($childrenSQL)" + } else { + val trimSQL = children(0).map(_.sql).mkString(", ") + val tarSQL = children(1).map(_.sql).mkString(", ") + s"$prettyName($trimSQL, $tarSQL)" + } + } +} + +object StringTrim { + def apply(str: Expression, trimStr: Expression) : StringTrim = StringTrim(str, Some(trimStr)) + def apply(str: Expression) : StringTrim = StringTrim(str, None) +} + /** - * A function that trim the spaces from both ends for the specified string. - */ + * A function that takes a character string, removes the leading and trailing characters matching with the characters + * in the trim string, returns the new string. + * If BOTH and trimStr keywords are not specified, it defaults to remove space character from both ends. The trim + * function will have one argument, which contains the source string. + * If BOTH and trimStr keywords are specified, it trims the characters from both ends, and the trim function will have + * two arguments, the first argument contains trimStr, the second argument contains the source string. + * trimStr: A character string to be trimmed from the source string, if it has multiple characters, the function + * searches for each character in the source string, removes the characters from the source string until it + * encounters the first non-match character. + * BOTH: removes any characters from both ends of the source string that matches characters in the trim string. + */ @ExpressionDescription( - usage = "_FUNC_(str) - Removes the leading and trailing space characters from `str`.", + usage = """ + _FUNC_(str) - Removes the leading and trailing space characters from `str`. + _FUNC_(BOTH trimStr FROM str) - Remove the leading and trailing trimString from `str` + """, + arguments = """ + Arguments: + * str - a string expression + * trimString - the trim string + * BOTH, FROM - these are keyword to specify for trim string from both ends of the string + """, examples = """ Examples: > SELECT _FUNC_(' SparkSQL '); SparkSQL + > SELECT _FUNC_(BOTH 'SL' FROM 'SSparkSQLS'); + parkSQ """) -case class StringTrim(child: Expression) - extends UnaryExpression with String2StringExpression { +case class StringTrim( + srcStr: Expression, + trimStr: Option[Expression] = None) + extends String2TrimExpression { - def convert(v: UTF8String): UTF8String = v.trim() + def this (trimStr: Expression, srcStr: Expression) = this(srcStr, Option(trimStr)) + + def this(srcStr: Expression) = this(srcStr, None) override def prettyName: String = "trim" + override def children: Seq[Expression] = if (trimStr.isDefined) { + srcStr :: trimStr.get :: Nil + } else { + srcStr :: Nil + } + override def eval(input: InternalRow): Any = { + val srcString = srcStr.eval(input).asInstanceOf[UTF8String] + if (srcString != null) { + if (trimStr.isDefined) { + return srcString.trim(trimStr.get.eval(input).asInstanceOf[UTF8String]) + } else { + return srcString.trim() + } + } + null + } + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, c => s"($c).trim()") + val evals = children.map(_.genCode(ctx)) + val srcString = evals(0) + + if (evals.length == 1) { + ev.copy(evals.map(_.code).mkString("\n") + s""" + boolean ${ev.isNull} = false; + UTF8String ${ev.value} = null; + if (${srcString.isNull}) { + ${ev.isNull} = true; + } else { + ${ev.value} = ${srcString.value}.trim(); + } + """.stripMargin) + } else { + val trimString = evals(1) + val getTrimFunction = + s""" + if (${trimString.isNull}) { + ${ev.isNull} = true; + } else { + ${ev.value} = ${srcString.value}.trim(${trimString.value}); + }""".stripMargin + ev.copy(evals.map(_.code).mkString("\n") + + s""" + boolean ${ev.isNull} = false; + UTF8String ${ev.value} = null; + if (${srcString.isNull}) { + ${ev.isNull} = true; + } else { + $getTrimFunction + } + """.stripMargin) + } } } +object StringTrimLeft { + def apply(str: Expression, trimStr: Expression) : StringTrimLeft = StringTrimLeft(str, Some(trimStr)) + def apply(str: Expression) : StringTrimLeft = StringTrimLeft(str, None) +} + /** - * A function that trim the spaces from left end for given string. + * A function that trims the characters from left end for a given string. + * If LEADING and trimStr keywords are not specified, it defaults to remove space character from the left end. The ltrim + * function will have one argument, which contains the source string. + * If LEADING and trimStr keywords are not specified, it trims the characters from left end. The ltrim function will + * have two arguments, the first argument contains trimStr, the second argument contains the source string. + * trimStr: the function removes any characters from the left end of the source string which matches with the characters + * from trimStr, it stops at the first non-match character. + * LEADING: removes any characters from the left end of the source string that matches characters in the trim string. */ @ExpressionDescription( - usage = "_FUNC_(str) - Removes the leading and trailing space characters from `str`.", + usage = """ + _FUNC_(str) - Removes the leading space characters from `str`. + _FUNC_(trimStr, str) - Removes the leading string contains the characters from the trim string + """, + arguments = """ + Arguments: + * str - a string expression + * trimString - the trim string + * BOTH, FROM - these are keyword to specify for trim string from both ends of the string + """, examples = """ Examples: - > SELECT _FUNC_(' SparkSQL'); + > SELECT _FUNC_(' SparkSQL '); SparkSQL + > SELECT _FUNC_('Sp', 'SSparkSQLS'); + arkSQLS """) -case class StringTrimLeft(child: Expression) - extends UnaryExpression with String2StringExpression { +case class StringTrimLeft( + srcStr: Expression, + trimStr: Option[Expression] = None) + extends String2TrimExpression { + + def this(trimStr: Expression, srcStr: Expression) = this(srcStr, Option(trimStr)) - def convert(v: UTF8String): UTF8String = v.trimLeft() + def this(srcStr: Expression) = this(srcStr, None) override def prettyName: String = "ltrim" - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, c => s"($c).trimLeft()") + override def children: Seq[Expression] = if (trimStr.isDefined) { + srcStr :: trimStr.get :: Nil + } else { + srcStr :: Nil + } + + override def eval(input: InternalRow): Any = { + val srcString = srcStr.eval(input).asInstanceOf[UTF8String] + if (srcString != null) { + if (trimStr.isDefined) { + return srcString.trimLeft(trimStr.get.eval(input).asInstanceOf[UTF8String]) + } else { + return srcString.trimLeft() + } + } + null } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val evals = children.map(_.genCode(ctx)) + val srcString = evals(0) + + if (evals.length == 1) { + ev.copy(evals.map(_.code).mkString("\n") + s""" + boolean ${ev.isNull} = false; + UTF8String ${ev.value} = null; + if (${srcString.isNull}) { + ${ev.isNull} = true; + } else { + ${ev.value} = ${srcString.value}.trimLeft(); + }""".stripMargin) + } else { + val trimString = evals(1) + val getTrimLeftFunction = + s""" + if (${trimString.isNull}) { + ${ev.isNull} = true; + } else { + ${ev.value} = ${srcString.value}.trimLeft(${trimString.value}); + }""".stripMargin + ev.copy(evals.map(_.code).mkString("\n") + + s""" + boolean ${ev.isNull} = false; + UTF8String ${ev.value} = null; + if (${srcString.isNull}) { + ${ev.isNull} = true; + } else { + $getTrimLeftFunction + } + """.stripMargin ) + } + } +} + +object StringTrimRight { + def apply(str: Expression, trimStr: Expression) : StringTrimRight = StringTrimRight(str, Some(trimStr)) + def apply(str: Expression) : StringTrimRight = StringTrimRight(str, None) } /** - * A function that trim the spaces from right end for given string. + * A function that trims the characters from right end for a given string. + * If TRAILING and trimStr keywords are not specified, it defaults to remove space character from the right end. The + * rtrim function will have one argument, which contains the source string. + * If TRAILING and trimStr keywords are specified, it trims the characters from right end. The rtrim function will + * have two arguments, the first argument contains trimStr, the second argument contains the source string. + * trimStr: the function removes any characters from the right end of source string which matches with the characters + * from trimStr, it stops at the first non-match character. + * TRAILING: removes any characters from the right end of the source string that matches characters in the trim string. */ @ExpressionDescription( - usage = "_FUNC_(str) - Removes the trailing space characters from `str`.", + usage = """ + _FUNC_(str) - Removes the trailing space characters from `str`. + _FUNC_(trimStr, str) - Removes the trailing string which contains the character from the trim string from the `str` + """, + arguments = """ + Arguments: + * str - a string expression + * trimString - the trim string + * BOTH, FROM - these are keyword to specify for trim string from both ends of the string + """, examples = """ Examples: > SELECT _FUNC_(' SparkSQL '); - SparkSQL + SparkSQL + > SELECT _FUNC_('LQSa', 'SSparkSQLS'); + SSpark """) -case class StringTrimRight(child: Expression) - extends UnaryExpression with String2StringExpression { +case class StringTrimRight( + srcStr: Expression, + trimStr: Option[Expression] = None) + extends String2TrimExpression { - def convert(v: UTF8String): UTF8String = v.trimRight() + def this(trimStr: Expression, srcStr: Expression) = this(srcStr, Option(trimStr)) + + def this(srcStr: Expression) = this(srcStr, None) override def prettyName: String = "rtrim" - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, c => s"($c).trimRight()") + override def children: Seq[Expression] = if (trimStr.isDefined) { + srcStr :: trimStr.get :: Nil + } else { + srcStr :: Nil + } + + override def eval(input: InternalRow): Any = { + val srcString = srcStr.eval(input).asInstanceOf[UTF8String] + if (srcString != null) { + if (trimStr.isDefined) { + return srcString.trimRight(trimStr.get.eval(input).asInstanceOf[UTF8String]) + } else { + return srcString.trimRight() + } + } --- End diff -- changed.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org