Github user gatorsmile commented on a diff in the pull request: https://github.com/apache/spark/pull/12646#discussion_r139513283 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala --- @@ -503,69 +504,304 @@ case class FindInSet(left: Expression, right: Expression) extends BinaryExpressi override def prettyName: String = "find_in_set" } +trait String2TrimExpression extends Expression with ImplicitCastInputTypes { + + override def dataType: DataType = StringType + override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringType) + + override def nullable: Boolean = children.exists(_.nullable) + override def foldable: Boolean = children.forall(_.foldable) +} + +object StringTrim { + def apply(str: Expression, trimStr: Expression) : StringTrim = StringTrim(str, Some(trimStr)) + def apply(str: Expression) : StringTrim = StringTrim(str, None) +} + /** - * A function that trim the spaces from both ends for the specified string. + * A function that takes a character string, removes the leading and trailing characters matching with any character + * in the trim string, returns the new string. + * If BOTH and trimStr keywords are not specified, it defaults to remove space character from both ends. The trim + * function will have one argument, which contains the source string. + * If BOTH and trimStr keywords are specified, it trims the characters from both ends, and the trim function will have + * two arguments, the first argument contains trimStr, the second argument contains the source string. + * trimStr: A character string to be trimmed from the source string, if it has multiple characters, the function + * searches for each character in the source string, removes the characters from the source string until it + * encounters the first non-match character. + * BOTH: removes any character from both ends of the source string that matches characters in the trim string. 
*/ @ExpressionDescription( - usage = "_FUNC_(str) - Removes the leading and trailing space characters from `str`.", + usage = """ + _FUNC_(str) - Removes the leading and trailing space characters from `str`. + _FUNC_(BOTH trimStr FROM str) - Remove the leading and trailing trimString from `str` + """, + arguments = """ + Arguments: + * str - a string expression + * trimString - the trim string + * BOTH, FROM - these are keyword to specify for trim string from both ends of the string + """, examples = """ Examples: > SELECT _FUNC_(' SparkSQL '); SparkSQL + > SELECT _FUNC_(BOTH 'SL' FROM 'SSparkSQLS'); + parkSQ """) -case class StringTrim(child: Expression) - extends UnaryExpression with String2StringExpression { +case class StringTrim( + srcStr: Expression, + trimStr: Option[Expression] = None) + extends String2TrimExpression { + + def this(trimStr: Expression, srcStr: Expression) = this(srcStr, Option(trimStr)) - def convert(v: UTF8String): UTF8String = v.trim() + def this(srcStr: Expression) = this(srcStr, None) override def prettyName: String = "trim" + override def children: Seq[Expression] = if (trimStr.isDefined) { + srcStr :: trimStr.get :: Nil + } else { + srcStr :: Nil + } + override def eval(input: InternalRow): Any = { + val srcString = srcStr.eval(input).asInstanceOf[UTF8String] + if (srcString == null) { + null + } else { + if (trimStr.isDefined) { + srcString.trim(trimStr.get.eval(input).asInstanceOf[UTF8String]) + } else { + srcString.trim() + } + } + } + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, c => s"($c).trim()") + val evals = children.map(_.genCode(ctx)) + val srcString = evals(0) + + if (evals.length == 1) { + ev.copy(evals.map(_.code).mkString + s""" + boolean ${ev.isNull} = false; + UTF8String ${ev.value} = null; + if (${srcString.isNull}) { + ${ev.isNull} = true; + } else { + ${ev.value} = ${srcString.value}.trim(); + }""") + } else { + val trimString = evals(1) + val getTrimFunction = + 
s""" + if (${trimString.isNull}) { + ${ev.isNull} = true; + } else { + ${ev.value} = ${srcString.value}.trim(${trimString.value}); + }""" + ev.copy(evals.map(_.code).mkString + s""" + boolean ${ev.isNull} = false; + UTF8String ${ev.value} = null; + if (${srcString.isNull}) { + ${ev.isNull} = true; + } else { + $getTrimFunction + }""") + } } } +object StringTrimLeft { + def apply(str: Expression, trimStr: Expression) : StringTrimLeft = StringTrimLeft(str, Some(trimStr)) + def apply(str: Expression) : StringTrimLeft = StringTrimLeft(str, None) +} + /** - * A function that trim the spaces from left end for given string. + * A function that trims the characters from left end for a given string. + * If LEADING and trimStr keywords are not specified, it defaults to remove space character from the left end. The ltrim + * function will have one argument, which contains the source string. + * If LEADING and trimStr keywords are not specified, it trims the characters from left end. The ltrim function will + * have two arguments, the first argument contains trimStr, the second argument contains the source string. + * trimStr: the function removes any character from the left end of the source string which matches with the characters + * from trimStr, it stops at the first non-match character. + * LEADING: removes any character from the left end of the source string that matches characters in the trim string. */ @ExpressionDescription( - usage = "_FUNC_(str) - Removes the leading and trailing space characters from `str`.", + usage = """ + _FUNC_(str) - Removes the leading space characters from `str`. 
+ _FUNC_(trimStr, str) - Removes the leading string contains the characters from the trim string + """, + arguments = """ + Arguments: + * str - a string expression + * trimString - the trim string + * BOTH, FROM - these are keyword to specify for trim string from both ends of the string + """, examples = """ Examples: - > SELECT _FUNC_(' SparkSQL'); + > SELECT _FUNC_(' SparkSQL '); SparkSQL + > SELECT _FUNC_('Sp', 'SSparkSQLS'); + arkSQLS """) -case class StringTrimLeft(child: Expression) - extends UnaryExpression with String2StringExpression { +case class StringTrimLeft( + srcStr: Expression, + trimStr: Option[Expression] = None) + extends String2TrimExpression { - def convert(v: UTF8String): UTF8String = v.trimLeft() + def this(trimStr: Expression, srcStr: Expression) = this(srcStr, Option(trimStr)) + + def this(srcStr: Expression) = this(srcStr, None) override def prettyName: String = "ltrim" - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, c => s"($c).trimLeft()") + override def children: Seq[Expression] = if (trimStr.isDefined) { + srcStr :: trimStr.get :: Nil + } else { + srcStr :: Nil } + + override def eval(input: InternalRow): Any = { + val srcString = srcStr.eval(input).asInstanceOf[UTF8String] + if (srcString == null) { + null + } else { + if (trimStr.isDefined) { + srcString.trimLeft(trimStr.get.eval(input).asInstanceOf[UTF8String]) + } else { + srcString.trimLeft() + } + } + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val evals = children.map(_.genCode(ctx)) + val srcString = evals(0) + + if (evals.length == 1) { + ev.copy(evals.map(_.code).mkString + s""" + boolean ${ev.isNull} = false; + UTF8String ${ev.value} = null; + if (${srcString.isNull}) { + ${ev.isNull} = true; + } else { + ${ev.value} = ${srcString.value}.trimLeft(); + }""") + } else { + val trimString = evals(1) + val getTrimLeftFunction = + s""" + if (${trimString.isNull}) { + ${ev.isNull} = 
true; + } else { + ${ev.value} = ${srcString.value}.trimLeft(${trimString.value}); + }""" + ev.copy(evals.map(_.code).mkString + s""" + boolean ${ev.isNull} = false; + UTF8String ${ev.value} = null; + if (${srcString.isNull}) { + ${ev.isNull} = true; + } else { + $getTrimLeftFunction + }""") + } + } +} + +object StringTrimRight { + def apply(str: Expression, trimStr: Expression) : StringTrimRight = StringTrimRight(str, Some(trimStr)) + def apply(str: Expression) : StringTrimRight = StringTrimRight(str, None) } /** - * A function that trim the spaces from right end for given string. + * A function that trims the characters from right end for a given string. + * If TRAILING and trimStr keywords are not specified, it defaults to remove space character from the right end. The + * rtrim function will have one argument, which contains the source string. + * If TRAILING and trimStr keywords are specified, it trims the characters from right end. The rtrim function will + * have two arguments, the first argument contains trimStr, the second argument contains the source string. + * trimStr: the function removes any character from the right end of source string which matches with the characters + * from trimStr, it stops at the first non-match character. + * TRAILING: removes any character from the right end of the source string that matches characters in the trim string. */ @ExpressionDescription( - usage = "_FUNC_(str) - Removes the trailing space characters from `str`.", + usage = """ + _FUNC_(str) - Removes the trailing space characters from `str`. 
+ _FUNC_(trimStr, str) - Removes the trailing string which contains the characters from the trim string from the `str` + """, + arguments = """ + Arguments: + * str - a string expression + * trimString - the trim string + * BOTH, FROM - these are keyword to specify for trim string from both ends of the string + """, examples = """ Examples: > SELECT _FUNC_(' SparkSQL '); - SparkSQL + SparkSQL + > SELECT _FUNC_('LQSa', 'SSparkSQLS'); + SSpark --- End diff -- Same here. Could you add another example that shows how to use `BOTH` and `FROM`?
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org