beliefer commented on a change in pull request #24918: [SPARK-28077][SQL] Support ANSI SQL OVERLAY function. URL: https://github.com/apache/spark/pull/24918#discussion_r296566625
########## File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala ########## @@ -454,6 +454,70 @@ case class StringReplace(srcExpr: Expression, searchExpr: Expression, replaceExp override def prettyName: String = "replace" } +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(input, replace, pos[, len]) - Replace `input` with `replace` that starts at `pos` and is of length `len`.", + examples = """ + Examples: + > SELECT _FUNC_('Spark SQL' PLACING '_' FROM 6); + Spark_SQL + > SELECT _FUNC_('Spark SQL' PLACING 'CORE' FROM 7); + Spark CORE + > SELECT _FUNC_('Spark SQL' PLACING 'ANSI ' FROM 7 FOR 0); + Spark ANSI SQL + > SELECT _FUNC_('Spark SQL' PLACING 'tructured' FROM 2 FOR 4); + Structured SQL + """) +// scalastyle:on line.size.limit +case class Overlay(input: Expression, replace: Expression, pos: Expression, len: Expression) + extends QuaternaryExpression with ImplicitCastInputTypes with NullIntolerant { + + def this(str: Expression, replace: Expression, pos: Expression) = { + this(str, replace, pos, Literal(Integer.MAX_VALUE)) + } + + override def dataType: DataType = StringType + + override def inputTypes: Seq[AbstractDataType] = + Seq(StringType, StringType, IntegerType, IntegerType) + + override def children: Seq[Expression] = input :: replace :: pos :: len :: Nil + + override def nullSafeEval(inputEval: Any, replaceEval: Any, posEval: Any, lenEval: Any): Any = { + val inputStr = inputEval.asInstanceOf[UTF8String] + val replaceStr = replaceEval.asInstanceOf[UTF8String].toString + val position = posEval.asInstanceOf[Int] + var length = lenEval.asInstanceOf[Int] + if (length.equals(Int.MaxValue)) { + length = replaceStr.size + } + val headStr = inputStr.substringSQL(1, position - 1) + val tailStr = inputStr.substringSQL(position + length, Int.MaxValue) + UTF8String.fromString(headStr.toString + replaceStr + tailStr.toString) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val isMax = ctx.freshName("isMax") + val length = ctx.freshName("length") + val integer = classOf[java.lang.Integer].getName + val head = ctx.freshName("head") + val tail = ctx.freshName("tail") + defineCodeGen(ctx, ev, (input, replace, pos, len) => { + s""" + boolean $isMax = $len.equals($integer.MAX_VALUE); + int $length = $len; + if ($isMax) { + $length = $replace.toString.size; + } + UTF8String $head = $input.substringSQL(1, $pos - 1); + UTF8String $tail = $input.substringSQL($pos + $length, $integer.MAX_VALUE); + $head + $replace + $tail; + } Review comment: @ueshin Thanks for your review. I added some tests in `StringExpressionsSuite`. ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org