This is an automated email from the ASF dual-hosted git repository.

yamamuro pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new aafce7e  [SPARK-28412][SQL] ANSI SQL: OVERLAY function support byte 
array
aafce7e is described below

commit aafce7ebffe1acd8f6022f208beaa9ec6c9f7592
Author: gengjiaan <gengji...@360.cn>
AuthorDate: Tue Sep 10 08:16:18 2019 +0900

    [SPARK-28412][SQL] ANSI SQL: OVERLAY function support byte array
    
    ## What changes were proposed in this pull request?
    
    This is an ANSI SQL feature; its feature id is `T312`.
    
    ```
    <binary overlay function> ::=
    OVERLAY <left paren> <binary value expression> PLACING <binary value 
expression>
    FROM <start position> [ FOR <string length> ] <right paren>
    ```
    
    This PR is related to https://github.com/apache/spark/pull/24918 and adds
    support for byte arrays.
    
    ref: https://www.postgresql.org/docs/11/functions-binarystring.html
    
    ## How was this patch tested?
    
    new UT.
    Below are some example runs of this PR in my production environment.
    ```
    spark-sql> select overlay(encode('Spark SQL', 'utf-8') PLACING encode('_', 
'utf-8') FROM 6);
    Spark_SQL
    Time taken: 0.285 s
    spark-sql> select overlay(encode('Spark SQL', 'utf-8') PLACING 
encode('CORE', 'utf-8') FROM 7);
    Spark CORE
    Time taken: 0.202 s
    spark-sql> select overlay(encode('Spark SQL', 'utf-8') PLACING encode('ANSI 
', 'utf-8') FROM 7 FOR 0);
    Spark ANSI SQL
    Time taken: 0.165 s
    spark-sql> select overlay(encode('Spark SQL', 'utf-8') PLACING 
encode('tructured', 'utf-8') FROM 2 FOR 4);
    Structured SQL
    Time taken: 0.141 s
    ```
    
    Closes #25172 from beliefer/ansi-overlay-byte-array.
    
    Lead-authored-by: gengjiaan <gengji...@360.cn>
    Co-authored-by: Jiaan Geng <belie...@163.com>
    Signed-off-by: Takeshi Yamamuro <yamam...@apache.org>
---
 .../catalyst/expressions/stringExpressions.scala   | 60 +++++++++++++++---
 .../expressions/StringExpressionsSuite.scala       | 72 +++++++++++++++++++++-
 .../scala/org/apache/spark/sql/functions.scala     | 16 ++---
 .../apache/spark/sql/StringFunctionsSuite.scala    | 33 +++++++---
 4 files changed, 157 insertions(+), 24 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index d7a5fb2..e4847e9 100755
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -472,6 +472,19 @@ object Overlay {
     builder.append(input.substringSQL(pos + length, Int.MaxValue))
     builder.build()
   }
+
+  def calculate(input: Array[Byte], replace: Array[Byte], pos: Int, len: Int): 
Array[Byte] = {
+    // If you specify length, it must be a positive whole number or zero.
+    // Otherwise it will be ignored.
+    // The default value for length is the length of replace.
+    val length = if (len >= 0) {
+      len
+    } else {
+      replace.length
+    }
+    ByteArray.concat(ByteArray.subStringSQL(input, 1, pos - 1),
+      replace, ByteArray.subStringSQL(input, pos + length, Int.MaxValue))
+  }
 }
 
 // scalastyle:off line.size.limit
@@ -487,6 +500,14 @@ object Overlay {
        Spark ANSI SQL
       > SELECT _FUNC_('Spark SQL' PLACING 'tructured' FROM 2 FOR 4);
        Structured SQL
+      > SELECT _FUNC_(encode('Spark SQL', 'utf-8') PLACING encode('_', 
'utf-8') FROM 6);
+       Spark_SQL
+      > SELECT _FUNC_(encode('Spark SQL', 'utf-8') PLACING encode('CORE', 
'utf-8') FROM 7);
+       Spark CORE
+      > SELECT _FUNC_(encode('Spark SQL', 'utf-8') PLACING encode('ANSI ', 
'utf-8') FROM 7 FOR 0);
+       Spark ANSI SQL
+      > SELECT _FUNC_(encode('Spark SQL', 'utf-8') PLACING encode('tructured', 
'utf-8') FROM 2 FOR 4);
+       Structured SQL
   """)
 // scalastyle:on line.size.limit
 case class Overlay(input: Expression, replace: Expression, pos: Expression, 
len: Expression)
@@ -496,19 +517,42 @@ case class Overlay(input: Expression, replace: 
Expression, pos: Expression, len:
     this(str, replace, pos, Literal.create(-1, IntegerType))
   }
 
-  override def dataType: DataType = StringType
+  override def dataType: DataType = input.dataType
 
-  override def inputTypes: Seq[AbstractDataType] =
-    Seq(StringType, StringType, IntegerType, IntegerType)
+  override def inputTypes: Seq[AbstractDataType] = 
Seq(TypeCollection(StringType, BinaryType),
+    TypeCollection(StringType, BinaryType), IntegerType, IntegerType)
 
   override def children: Seq[Expression] = input :: replace :: pos :: len :: 
Nil
 
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val inputTypeCheck = super.checkInputDataTypes()
+    if (inputTypeCheck.isSuccess) {
+      TypeUtils.checkForSameTypeInputExpr(
+        input.dataType :: replace.dataType :: Nil, s"function $prettyName")
+    } else {
+      inputTypeCheck
+    }
+  }
+
+  private lazy val replaceFunc = input.dataType match {
+    case StringType =>
+      (inputEval: Any, replaceEval: Any, posEval: Int, lenEval: Int) => {
+        Overlay.calculate(
+          inputEval.asInstanceOf[UTF8String],
+          replaceEval.asInstanceOf[UTF8String],
+          posEval, lenEval)
+      }
+    case BinaryType =>
+      (inputEval: Any, replaceEval: Any, posEval: Int, lenEval: Int) => {
+        Overlay.calculate(
+          inputEval.asInstanceOf[Array[Byte]],
+          replaceEval.asInstanceOf[Array[Byte]],
+          posEval, lenEval)
+      }
+  }
+
   override def nullSafeEval(inputEval: Any, replaceEval: Any, posEval: Any, 
lenEval: Any): Any = {
-    val inputStr = inputEval.asInstanceOf[UTF8String]
-    val replaceStr = replaceEval.asInstanceOf[UTF8String]
-    val position = posEval.asInstanceOf[Int]
-    val length = lenEval.asInstanceOf[Int]
-    Overlay.calculate(inputStr, replaceStr, position, length)
+    replaceFunc(inputEval, replaceEval, posEval.asInstanceOf[Int], 
lenEval.asInstanceOf[Int])
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index 1b5acf4..4308f98 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.types._
 
@@ -428,7 +429,7 @@ class StringExpressionsSuite extends SparkFunSuite with 
ExpressionEvalHelper {
     // scalastyle:on
   }
 
-  test("overlay") {
+  test("overlay for string") {
     checkEvaluation(new Overlay(Literal("Spark SQL"), Literal("_"),
       Literal.create(6, IntegerType)), "Spark_SQL")
     checkEvaluation(new Overlay(Literal("Spark SQL"), Literal("CORE"),
@@ -450,6 +451,75 @@ class StringExpressionsSuite extends SparkFunSuite with 
ExpressionEvalHelper {
     checkEvaluation(new Overlay(Literal("Spark的SQL"), Literal("_"),
       Literal.create(6, IntegerType)), "Spark_SQL")
     // scalastyle:on
+    // position greater than the length of input string
+    checkEvaluation(new Overlay(Literal("Spark SQL"), Literal("_"),
+      Literal.create(10, IntegerType)), "Spark SQL_")
+    checkEvaluation(Overlay(Literal("Spark SQL"), Literal("_"),
+      Literal.create(10, IntegerType), Literal.create(4, IntegerType)), "Spark 
SQL_")
+    // position is zero
+    checkEvaluation(new Overlay(Literal("Spark SQL"), Literal("__"),
+      Literal.create(0, IntegerType)), "__park SQL")
+    checkEvaluation(Overlay(Literal("Spark SQL"), Literal("__"),
+      Literal.create(0, IntegerType), Literal.create(4, IntegerType)), "__rk 
SQL")
+    // position is negative
+    checkEvaluation(new Overlay(Literal("Spark SQL"), Literal("__"),
+      Literal.create(-10, IntegerType)), "__park SQL")
+    checkEvaluation(Overlay(Literal("Spark SQL"), Literal("__"),
+      Literal.create(-10, IntegerType), Literal.create(4, IntegerType)), "__rk 
SQL")
+  }
+
+  test("overlay for byte array") {
+    val input = Literal(Array[Byte](1, 2, 3, 4, 5, 6, 7, 8, 9))
+    checkEvaluation(new Overlay(input, Literal(Array[Byte](-1)),
+      Literal.create(6, IntegerType)), Array[Byte](1, 2, 3, 4, 5, -1, 7, 8, 9))
+    checkEvaluation(new Overlay(input, Literal(Array[Byte](-1, -1, -1, -1)),
+      Literal.create(7, IntegerType)), Array[Byte](1, 2, 3, 4, 5, 6, -1, -1, 
-1, -1))
+    checkEvaluation(Overlay(input, Literal(Array[Byte](-1, -1)), 
Literal.create(7, IntegerType),
+      Literal.create(0, IntegerType)), Array[Byte](1, 2, 3, 4, 5, 6, -1, -1, 
7, 8, 9))
+    checkEvaluation(Overlay(input, Literal(Array[Byte](-1, -1, -1, -1, -1)),
+      Literal.create(2, IntegerType), Literal.create(4, IntegerType)),
+      Array[Byte](1, -1, -1, -1, -1, -1, 6, 7, 8, 9))
+
+    val nullInput = Literal.create(null, BinaryType)
+    checkEvaluation(new Overlay(nullInput, Literal(Array[Byte](-1)),
+      Literal.create(6, IntegerType)), null)
+    checkEvaluation(new Overlay(nullInput, Literal(Array[Byte](-1, -1, -1, 
-1)),
+      Literal.create(7, IntegerType)), null)
+    checkEvaluation(Overlay(nullInput, Literal(Array[Byte](-1, -1)),
+      Literal.create(7, IntegerType), Literal.create(0, IntegerType)), null)
+    checkEvaluation(Overlay(nullInput, Literal(Array[Byte](-1, -1, -1, -1, 
-1)),
+      Literal.create(2, IntegerType), Literal.create(4, IntegerType)), null)
+    // position greater than the length of input byte array
+    checkEvaluation(new Overlay(input, Literal(Array[Byte](-1)),
+      Literal.create(10, IntegerType)), Array[Byte](1, 2, 3, 4, 5, 6, 7, 8, 9, 
-1))
+    checkEvaluation(Overlay(input, Literal(Array[Byte](-1)), 
Literal.create(10, IntegerType),
+      Literal.create(4, IntegerType)), Array[Byte](1, 2, 3, 4, 5, 6, 7, 8, 9, 
-1))
+    // position is zero
+    checkEvaluation(new Overlay(input, Literal(Array[Byte](-1, -1)),
+      Literal.create(0, IntegerType)), Array[Byte](-1, -1, 2, 3, 4, 5, 6, 7, 
8, 9))
+    checkEvaluation(Overlay(input, Literal(Array[Byte](-1, -1)), 
Literal.create(0, IntegerType),
+      Literal.create(4, IntegerType)), Array[Byte](-1, -1, 4, 5, 6, 7, 8, 9))
+    // position is negative
+    checkEvaluation(new Overlay(input, Literal(Array[Byte](-1, -1)),
+      Literal.create(-10, IntegerType)), Array[Byte](-1, -1, 2, 3, 4, 5, 6, 7, 
8, 9))
+    checkEvaluation(Overlay(input, Literal(Array[Byte](-1, -1)), 
Literal.create(-10, IntegerType),
+      Literal.create(4, IntegerType)), Array[Byte](-1, -1, 4, 5, 6, 7, 8, 9))
+  }
+
+  test("Check Overlay.checkInputDataTypes results") {
+    assert(new Overlay(Literal("Spark SQL"), Literal("_"),
+      Literal.create(6, IntegerType)).checkInputDataTypes().isSuccess)
+    assert(Overlay(Literal("Spark SQL"), Literal("ANSI "), Literal.create(7, 
IntegerType),
+      Literal.create(0, IntegerType)).checkInputDataTypes().isSuccess)
+    assert(new Overlay(Literal.create("Spark SQL".getBytes), 
Literal.create("_".getBytes),
+      Literal.create(6, IntegerType)).checkInputDataTypes().isSuccess)
+    assert(Overlay(Literal.create("Spark SQL".getBytes), Literal.create("ANSI 
".getBytes),
+      Literal.create(7, IntegerType), Literal.create(0, IntegerType))
+      .checkInputDataTypes().isSuccess)
+    assert(new Overlay(Literal.create(1), Literal.create(2), Literal.create(0, 
IntegerType))
+      .checkInputDataTypes().isFailure)
+    assert(Overlay(Literal("Spark SQL"), Literal.create(2), Literal.create(7, 
IntegerType),
+      Literal.create(0, IntegerType)).checkInputDataTypes().isFailure)
   }
 
   test("translate") {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 6b8127b..395f1b4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -2521,25 +2521,25 @@ object functions {
   }
 
   /**
-   * Overlay the specified portion of `src` with `replaceString`,
-   *  starting from byte position `pos` of `inputString` and proceeding for 
`len` bytes.
+   * Overlay the specified portion of `src` with `replace`,
+   *  starting from byte position `pos` of `src` and proceeding for `len` 
bytes.
    *
    * @group string_funcs
    * @since 3.0.0
    */
-  def overlay(src: Column, replaceString: String, pos: Int, len: Int): Column 
= withExpr {
-    Overlay(src.expr, lit(replaceString).expr, lit(pos).expr, lit(len).expr)
+  def overlay(src: Column, replace: Column, pos: Column, len: Column): Column 
= withExpr {
+    Overlay(src.expr, replace.expr, pos.expr, len.expr)
   }
 
   /**
-   * Overlay the specified portion of `src` with `replaceString`,
-   *  starting from byte position `pos` of `inputString`.
+   * Overlay the specified portion of `src` with `replace`,
+   *  starting from byte position `pos` of `src`.
    *
    * @group string_funcs
    * @since 3.0.0
    */
-  def overlay(src: Column, replaceString: String, pos: Int): Column = withExpr 
{
-    new Overlay(src.expr, lit(replaceString).expr, lit(pos).expr)
+  def overlay(src: Column, replace: Column, pos: Column): Column = withExpr {
+    new Overlay(src.expr, replace.expr, pos.expr)
   }
 
   /**
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index 88b3e5e..5049df3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -129,18 +129,37 @@ class StringFunctionsSuite extends QueryTest with 
SharedSparkSession {
       Row("AQIDBA==", bytes))
   }
 
-  test("overlay function") {
+  test("string overlay function") {
     // scalastyle:off
     // non ascii characters are not allowed in the code, so we disable the 
scalastyle here.
-    val df = Seq(("Spark SQL", "Spark的SQL")).toDF("a", "b")
-    checkAnswer(df.select(overlay($"a", "_", 6)), Row("Spark_SQL"))
-    checkAnswer(df.select(overlay($"a", "CORE", 7)), Row("Spark CORE"))
-    checkAnswer(df.select(overlay($"a", "ANSI ", 7, 0)), Row("Spark ANSI SQL"))
-    checkAnswer(df.select(overlay($"a", "tructured", 2, 4)), Row("Structured 
SQL"))
-    checkAnswer(df.select(overlay($"b", "_", 6)), Row("Spark_SQL"))
+    val df = Seq(("Spark SQL", "Spark的SQL", "_", "CORE", "ANSI ", "tructured", 
6, 7, 0, 2, 4)).
+      toDF("a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k")
+    checkAnswer(df.select(overlay($"a", $"c", $"g")), Row("Spark_SQL"))
+    checkAnswer(df.select(overlay($"a", $"d", $"h")), Row("Spark CORE"))
+    checkAnswer(df.select(overlay($"a", $"e", $"h", $"i")), Row("Spark ANSI 
SQL"))
+    checkAnswer(df.select(overlay($"a", $"f", $"j", $"k")), Row("Structured 
SQL"))
+    checkAnswer(df.select(overlay($"b", $"c", $"g")), Row("Spark_SQL"))
     // scalastyle:on
   }
 
+  test("binary overlay function") {
+    // non ascii characters are not allowed in the code, so we disable the 
scalastyle here.
+    val df = Seq((
+      Array[Byte](1, 2, 3, 4, 5, 6, 7, 8, 9),
+      Array[Byte](-1),
+      Array[Byte](-1, -1, -1, -1),
+      Array[Byte](-1, -1),
+      Array[Byte](-1, -1, -1, -1, -1),
+      6, 7, 0, 2, 4)).toDF("a", "b", "c", "d", "e", "f", "g", "h", "i", "j")
+    checkAnswer(df.select(overlay($"a", $"b", $"f")), Row(Array[Byte](1, 2, 3, 
4, 5, -1, 7, 8, 9)))
+    checkAnswer(df.select(overlay($"a", $"c", $"g")),
+      Row(Array[Byte](1, 2, 3, 4, 5, 6, -1, -1, -1, -1)))
+    checkAnswer(df.select(overlay($"a", $"d", $"g", $"h")),
+      Row(Array[Byte](1, 2, 3, 4, 5, 6, -1, -1, 7, 8, 9)))
+    checkAnswer(df.select(overlay($"a", $"e", $"i", $"j")),
+      Row(Array[Byte](1, -1, -1, -1, -1, -1, 6, 7, 8, 9)))
+  }
+
   test("string / binary substring function") {
     // scalastyle:off
     // non ascii characters are not allowed in the code, so we disable the 
scalastyle here.


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to