Repository: spark
Updated Branches:
  refs/heads/branch-2.3 55afac4e7 -> bf853018c


[SPARK-22937][SQL] SQL elt output binary for binary inputs

## What changes were proposed in this pull request?
This PR modifies `elt` to output binary for binary inputs.
In the current master, `elt` always outputs data as a string. However, in some 
databases (e.g., MySQL), `elt` outputs binary when all inputs are binary, so 
the current string-only behavior can be a small surprise.
This PR is related to #19977.
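
A minimal sketch of the new behavior from the Scala API (assuming a `SparkSession` named `spark` with this patch applied, and the default `spark.sql.function.eltOutputAsString=false`):

```scala
// All-binary inputs: SQL elt() now yields BinaryType instead of StringType.
spark.sql("SELECT elt(1, encode('ab', 'utf-8'), encode('cd', 'utf-8')) AS col").printSchema()
// root
//  |-- col: binary (nullable = true)

// Mixed inputs: any non-binary argument keeps the old string output.
spark.sql("SELECT elt(2, 'ab', encode('cd', 'utf-8')) AS col").printSchema()
// root
//  |-- col: string (nullable = true)
```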

## How was this patch tested?
Added tests in `SQLQueryTestSuite` and `TypeCoercionSuite`.

Author: Takeshi Yamamuro <yamam...@apache.org>

Closes #20135 from maropu/SPARK-22937.

(cherry picked from commit e8af7e8aeca15a6107248f358d9514521ffdc6d3)
Signed-off-by: gatorsmile <gatorsm...@gmail.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/bf853018
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/bf853018
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/bf853018

Branch: refs/heads/branch-2.3
Commit: bf853018cabcd3b3abf84bfe534d2981020b4a71
Parents: 55afac4
Author: Takeshi Yamamuro <yamam...@apache.org>
Authored: Sat Jan 6 09:26:03 2018 +0800
Committer: gatorsmile <gatorsm...@gmail.com>
Committed: Sat Jan 6 09:26:21 2018 +0800

----------------------------------------------------------------------
 docs/sql-programming-guide.md                   |   2 +
 .../sql/catalyst/analysis/TypeCoercion.scala    |  29 +++++
 .../expressions/stringExpressions.scala         |  46 +++++---
 .../org/apache/spark/sql/internal/SQLConf.scala |   8 ++
 .../catalyst/analysis/TypeCoercionSuite.scala   |  54 +++++++++
 .../inputs/typeCoercion/native/elt.sql          |  44 +++++++
 .../results/typeCoercion/native/elt.sql.out     | 115 +++++++++++++++++++
 7 files changed, 281 insertions(+), 17 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/bf853018/docs/sql-programming-guide.md
----------------------------------------------------------------------
diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index dc3e384..b50f936 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -1783,6 +1783,8 @@ options.
 
 - Since Spark 2.3, when all inputs are binary, `functions.concat()` returns an output as binary; otherwise, it returns a string. Until Spark 2.3, it always returned a string regardless of the input types. To keep the old behavior, set `spark.sql.function.concatBinaryAsString` to `true`.
 
+ - Since Spark 2.3, when all inputs are binary, SQL `elt()` returns an output as binary; otherwise, it returns a string. Until Spark 2.3, it always returned a string regardless of the input types. To keep the old behavior, set `spark.sql.function.eltOutputAsString` to `true`.
+
 ## Upgrading From Spark SQL 2.1 to 2.2
 
  - Spark 2.1.1 introduced a new configuration key: `spark.sql.hive.caseSensitiveInferenceMode`. It had a default setting of `NEVER_INFER`, which kept behavior identical to 2.1.0. However, Spark 2.2.0 changes this setting's default value to `INFER_AND_SAVE` to restore compatibility with reading Hive metastore tables whose underlying file schemas have mixed-case column names. With the `INFER_AND_SAVE` configuration value, on first access Spark will perform schema inference on any Hive metastore table for which it has not already saved an inferred schema. Note that schema inference can be a very time-consuming operation for tables with thousands of partitions. If compatibility with mixed-case column names is not a concern, you can safely set `spark.sql.hive.caseSensitiveInferenceMode` to `NEVER_INFER` to avoid the initial overhead of schema inference. Note that with the new default `INFER_AND_SAVE` setting, the results of the schema inference are saved as a metastore key for future use. Therefore, the initial schema inference occurs only at a table's first access.
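
As a usage sketch of the `elt()` migration note above (again assuming a `SparkSession` named `spark`):

```scala
// Opt out: restore the pre-2.3 behavior, where elt() outputs a string
// even when every input is binary.
spark.conf.set("spark.sql.function.eltOutputAsString", "true")
spark.sql("SELECT elt(1, encode('a', 'utf-8'), encode('b', 'utf-8')) AS col").printSchema()
// col: string

// Default (false): all-binary inputs produce a binary output.
spark.conf.set("spark.sql.function.eltOutputAsString", "false")
spark.sql("SELECT elt(1, encode('a', 'utf-8'), encode('b', 'utf-8')) AS col").printSchema()
// col: binary
```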

http://git-wip-us.apache.org/repos/asf/spark/blob/bf853018/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index e943636..e8669c4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -54,6 +54,7 @@ object TypeCoercion {
       BooleanEquality ::
       FunctionArgumentConversion ::
       ConcatCoercion(conf) ::
+      EltCoercion(conf) ::
       CaseWhenCoercion ::
       IfCoercion ::
       StackCoercion ::
@@ -685,6 +686,34 @@ object TypeCoercion {
   }
 
   /**
+   * Coerces the types of [[Elt]] children to expected ones.
+   *
+   * If `spark.sql.function.eltOutputAsString` is false and all children types are binary,
+   * the expected types are binary. Otherwise, the expected ones are strings.
+   */
+  case class EltCoercion(conf: SQLConf) extends TypeCoercionRule {
+
+    override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transform { case p =>
+      p transformExpressionsUp {
+        // Skip nodes if unresolved or not enough children
+        case c @ Elt(children) if !c.childrenResolved || children.size < 2 => c
+        case c @ Elt(children) =>
+          val index = children.head
+          val newIndex = ImplicitTypeCasts.implicitCast(index, IntegerType).getOrElse(index)
+          val newInputs = if (conf.eltOutputAsString ||
+              !children.tail.map(_.dataType).forall(_ == BinaryType)) {
+            children.tail.map { e =>
+              ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e)
+            }
+          } else {
+            children.tail
+          }
+          c.copy(children = newIndex +: newInputs)
+      }
+    }
+  }
+
+  /**
   * Turns Add/Subtract of DateType/TimestampType/StringType and CalendarIntervalType
    * to TimeAdd/TimeSub
    */
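
The target-type choice made by `EltCoercion` above can be summarized in isolation. A simplified sketch (`eltInputTargetType` is a hypothetical helper, not part of the patch; the types are the Catalyst ones used above):

```scala
import org.apache.spark.sql.types.{BinaryType, DataType, StringType}

// The non-index children keep BinaryType only when the flag is off and
// every one of them is already binary; otherwise each child is implicitly
// cast toward StringType, matching the rule's two branches above.
def eltInputTargetType(inputTypes: Seq[DataType], eltOutputAsString: Boolean): DataType =
  if (!eltOutputAsString && inputTypes.nonEmpty && inputTypes.forall(_ == BinaryType)) {
    BinaryType
  } else {
    StringType
  }
```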

http://git-wip-us.apache.org/repos/asf/spark/blob/bf853018/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 41dc762..e004bfc 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -271,33 +271,45 @@ case class ConcatWs(children: Seq[Expression])
   }
 }
 
+/**
+ * An expression that returns the `n`-th input in given inputs.
+ * If all inputs are binary, `elt` returns an output as binary. Otherwise, it returns as a string.
+ * If any input is null, `elt` returns null.
+ */
 // scalastyle:off line.size.limit
 @ExpressionDescription(
-  usage = "_FUNC_(n, str1, str2, ...) - Returns the `n`-th string, e.g., 
returns `str2` when `n` is 2.",
+  usage = "_FUNC_(n, input1, input2, ...) - Returns the `n`-th input, e.g., 
returns `input2` when `n` is 2.",
   examples = """
     Examples:
       > SELECT _FUNC_(1, 'scala', 'java');
        scala
   """)
 // scalastyle:on line.size.limit
-case class Elt(children: Seq[Expression])
-  extends Expression with ImplicitCastInputTypes {
+case class Elt(children: Seq[Expression]) extends Expression {
 
   private lazy val indexExpr = children.head
-  private lazy val stringExprs = children.tail.toArray
+  private lazy val inputExprs = children.tail.toArray
 
   /** This expression is always nullable because it returns null if index is out of range. */
   override def nullable: Boolean = true
 
-  override def dataType: DataType = StringType
-
-  override def inputTypes: Seq[DataType] = IntegerType +: Seq.fill(children.size - 1)(StringType)
+  override def dataType: DataType = inputExprs.map(_.dataType).headOption.getOrElse(StringType)
 
   override def checkInputDataTypes(): TypeCheckResult = {
     if (children.size < 2) {
       TypeCheckResult.TypeCheckFailure("elt function requires at least two 
arguments")
     } else {
-      super[ImplicitCastInputTypes].checkInputDataTypes()
+      val (indexType, inputTypes) = (indexExpr.dataType, inputExprs.map(_.dataType))
+      if (indexType != IntegerType) {
+        return TypeCheckResult.TypeCheckFailure(s"first input to function $prettyName should " +
+          s"have IntegerType, but it's $indexType")
+      }
+      if (inputTypes.exists(tpe => !Seq(StringType, BinaryType).contains(tpe))) {
+        return TypeCheckResult.TypeCheckFailure(
+          s"input to function $prettyName should have StringType or BinaryType, but it's " +
+            inputTypes.map(_.simpleString).mkString("[", ", ", "]"))
+      }
+      }
+      TypeUtils.checkForSameTypeInputExpr(inputTypes, s"function $prettyName")
     }
   }
 
@@ -307,27 +319,27 @@ case class Elt(children: Seq[Expression])
       null
     } else {
       val index = indexObj.asInstanceOf[Int]
-      if (index <= 0 || index > stringExprs.length) {
+      if (index <= 0 || index > inputExprs.length) {
         null
       } else {
-        stringExprs(index - 1).eval(input)
+        inputExprs(index - 1).eval(input)
       }
     }
   }
 
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val index = indexExpr.genCode(ctx)
-    val strings = stringExprs.map(_.genCode(ctx))
+    val inputs = inputExprs.map(_.genCode(ctx))
     val indexVal = ctx.freshName("index")
     val indexMatched = ctx.freshName("eltIndexMatched")
 
-    val stringVal = ctx.addMutableState(ctx.javaType(dataType), "stringVal")
+    val inputVal = ctx.addMutableState(ctx.javaType(dataType), "inputVal")
 
-    val assignStringValue = strings.zipWithIndex.map { case (eval, index) =>
+    val assignInputValue = inputs.zipWithIndex.map { case (eval, index) =>
       s"""
          |if ($indexVal == ${index + 1}) {
          |  ${eval.code}
-         |  $stringVal = ${eval.isNull} ? null : ${eval.value};
+         |  $inputVal = ${eval.isNull} ? null : ${eval.value};
          |  $indexMatched = true;
          |  continue;
          |}
@@ -335,7 +347,7 @@ case class Elt(children: Seq[Expression])
     }
 
     val codes = ctx.splitExpressionsWithCurrentInputs(
-      expressions = assignStringValue,
+      expressions = assignInputValue,
       funcName = "eltFunc",
       extraArguments = ("int", indexVal) :: Nil,
       returnType = ctx.JAVA_BOOLEAN,
@@ -361,11 +373,11 @@ case class Elt(children: Seq[Expression])
          |${index.code}
          |final int $indexVal = ${index.value};
          |${ctx.JAVA_BOOLEAN} $indexMatched = false;
-         |$stringVal = null;
+         |$inputVal = null;
          |do {
          |  $codes
          |} while (false);
-         |final UTF8String ${ev.value} = $stringVal;
+         |final ${ctx.javaType(dataType)} ${ev.value} = $inputVal;
          |final boolean ${ev.isNull} = ${ev.value} == null;
        """.stripMargin)
   }
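
The interpreted path in `Elt.eval` above reduces to a 1-based, null-tolerant lookup. A plain-Scala sketch of that contract (`eltLookup` is a hypothetical helper; Catalyst models the missing case as a SQL NULL rather than `None`):

```scala
// A 1-based index that is out of range yields no value; otherwise the
// (index - 1)-th input is returned as-is, string or binary alike.
def eltLookup[T](index: Int, inputs: Seq[T]): Option[T] =
  if (index <= 0 || index > inputs.length) None else Some(inputs(index - 1))

// eltLookup(2, Seq("ab", "cd")) == Some("cd")
// eltLookup(3, Seq("ab", "cd")) == None
```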

http://git-wip-us.apache.org/repos/asf/spark/blob/bf853018/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 5d6edf6..80b8965 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1052,6 +1052,12 @@ object SQLConf {
     .booleanConf
     .createWithDefault(false)
 
+  val ELT_OUTPUT_AS_STRING = buildConf("spark.sql.function.eltOutputAsString")
+    .doc("When this option is set to false and all inputs are binary, `elt` 
returns " +
+      "an output as binary. Otherwise, it returns as a string. ")
+    .booleanConf
+    .createWithDefault(false)
+
   val CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE =
     buildConf("spark.sql.streaming.continuous.executorQueueSize")
     .internal()
@@ -1412,6 +1418,8 @@ class SQLConf extends Serializable with Logging {
 
   def concatBinaryAsString: Boolean = getConf(CONCAT_BINARY_AS_STRING)
 
+  def eltOutputAsString: Boolean = getConf(ELT_OUTPUT_AS_STRING)
+
   def partitionOverwriteMode: PartitionOverwriteMode.Value =
     PartitionOverwriteMode.withName(getConf(PARTITION_OVERWRITE_MODE))
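
A quick way to inspect the new flag at runtime (a sketch, assuming a `SparkSession` named `spark`; `false` is the default installed by `createWithDefault` above):

```scala
// Read the current value of the new flag; it defaults to "false".
val eltAsString = spark.conf.get("spark.sql.function.eltOutputAsString")
println(s"spark.sql.function.eltOutputAsString = $eltAsString")
```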
 

http://git-wip-us.apache.org/repos/asf/spark/blob/bf853018/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
index 3661530..52a7ebd 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
@@ -923,6 +923,60 @@ class TypeCoercionSuite extends AnalysisTest {
     }
   }
 
+  test("type coercion for Elt") {
+    val rule = TypeCoercion.EltCoercion(conf)
+
+    ruleTest(rule,
+      Elt(Seq(Literal(1), Literal("ab"), Literal("cde"))),
+      Elt(Seq(Literal(1), Literal("ab"), Literal("cde"))))
+    ruleTest(rule,
+      Elt(Seq(Literal(1.toShort), Literal("ab"), Literal("cde"))),
+      Elt(Seq(Cast(Literal(1.toShort), IntegerType), Literal("ab"), Literal("cde"))))
+    ruleTest(rule,
+      Elt(Seq(Literal(2), Literal(null), Literal("abc"))),
+      Elt(Seq(Literal(2), Cast(Literal(null), StringType), Literal("abc"))))
+    ruleTest(rule,
+      Elt(Seq(Literal(2), Literal(1), Literal("234"))),
+      Elt(Seq(Literal(2), Cast(Literal(1), StringType), Literal("234"))))
+    ruleTest(rule,
+      Elt(Seq(Literal(3), Literal(1L), Literal(2.toByte), Literal(0.1))),
+      Elt(Seq(Literal(3), Cast(Literal(1L), StringType), Cast(Literal(2.toByte), StringType),
+        Cast(Literal(0.1), StringType))))
+    ruleTest(rule,
+      Elt(Seq(Literal(2), Literal(true), Literal(0.1f), Literal(3.toShort))),
+      Elt(Seq(Literal(2), Cast(Literal(true), StringType), Cast(Literal(0.1f), StringType),
+        Cast(Literal(3.toShort), StringType))))
+    ruleTest(rule,
+      Elt(Seq(Literal(1), Literal(1L), Literal(0.1))),
+      Elt(Seq(Literal(1), Cast(Literal(1L), StringType), Cast(Literal(0.1), StringType))))
+    ruleTest(rule,
+      Elt(Seq(Literal(1), Literal(Decimal(10)))),
+      Elt(Seq(Literal(1), Cast(Literal(Decimal(10)), StringType))))
+    ruleTest(rule,
+      Elt(Seq(Literal(1), Literal(BigDecimal.valueOf(10)))),
+      Elt(Seq(Literal(1), Cast(Literal(BigDecimal.valueOf(10)), StringType))))
+    ruleTest(rule,
+      Elt(Seq(Literal(1), Literal(java.math.BigDecimal.valueOf(10)))),
+      Elt(Seq(Literal(1), Cast(Literal(java.math.BigDecimal.valueOf(10)), StringType))))
+    ruleTest(rule,
+      Elt(Seq(Literal(2), Literal(new java.sql.Date(0)), Literal(new Timestamp(0)))),
+      Elt(Seq(Literal(2), Cast(Literal(new java.sql.Date(0)), StringType),
+        Cast(Literal(new Timestamp(0)), StringType))))
+
+    withSQLConf("spark.sql.function.eltOutputAsString" -> "true") {
+      ruleTest(rule,
+        Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes))),
+        Elt(Seq(Literal(1), Cast(Literal("123".getBytes), StringType),
+          Cast(Literal("456".getBytes), StringType))))
+    }
+
+    withSQLConf("spark.sql.function.eltOutputAsString" -> "false") {
+      ruleTest(rule,
+        Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes))),
+        Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes))))
+    }
+  }
+
   test("BooleanEquality type cast") {
     val be = TypeCoercion.BooleanEquality
    // Use something more than a literal to avoid triggering the simplification rules.

http://git-wip-us.apache.org/repos/asf/spark/blob/bf853018/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/elt.sql
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/elt.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/elt.sql
new file mode 100644
index 0000000..717616f
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/elt.sql
@@ -0,0 +1,44 @@
+-- Mixed inputs (output type is string)
+SELECT elt(2, col1, col2, col3, col4, col5) col
+FROM (
+  SELECT
+    'prefix_' col1,
+    id col2,
+    string(id + 1) col3,
+    encode(string(id + 2), 'utf-8') col4,
+    CAST(id AS DOUBLE) col5
+  FROM range(10)
+);
+
+SELECT elt(3, col1, col2, col3, col4) col
+FROM (
+  SELECT
+    string(id) col1,
+    string(id + 1) col2,
+    encode(string(id + 2), 'utf-8') col3,
+    encode(string(id + 3), 'utf-8') col4
+  FROM range(10)
+);
+
+-- turn on eltOutputAsString
+set spark.sql.function.eltOutputAsString=true;
+
+SELECT elt(1, col1, col2) col
+FROM (
+  SELECT
+    encode(string(id), 'utf-8') col1,
+    encode(string(id + 1), 'utf-8') col2
+  FROM range(10)
+);
+
+-- turn off eltOutputAsString
+set spark.sql.function.eltOutputAsString=false;
+
+-- Elt binary inputs (output type is binary)
+SELECT elt(2, col1, col2) col
+FROM (
+  SELECT
+    encode(string(id), 'utf-8') col1,
+    encode(string(id + 1), 'utf-8') col2
+  FROM range(10)
+);

http://git-wip-us.apache.org/repos/asf/spark/blob/bf853018/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/elt.sql.out
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/elt.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/elt.sql.out
new file mode 100644
index 0000000..b62e1b6
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/elt.sql.out
@@ -0,0 +1,115 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 6
+
+
+-- !query 0
+SELECT elt(2, col1, col2, col3, col4, col5) col
+FROM (
+  SELECT
+    'prefix_' col1,
+    id col2,
+    string(id + 1) col3,
+    encode(string(id + 2), 'utf-8') col4,
+    CAST(id AS DOUBLE) col5
+  FROM range(10)
+)
+-- !query 0 schema
+struct<col:string>
+-- !query 0 output
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+
+
+-- !query 1
+SELECT elt(3, col1, col2, col3, col4) col
+FROM (
+  SELECT
+    string(id) col1,
+    string(id + 1) col2,
+    encode(string(id + 2), 'utf-8') col3,
+    encode(string(id + 3), 'utf-8') col4
+  FROM range(10)
+)
+-- !query 1 schema
+struct<col:string>
+-- !query 1 output
+10
+11
+2
+3
+4
+5
+6
+7
+8
+9
+
+
+-- !query 2
+set spark.sql.function.eltOutputAsString=true
+-- !query 2 schema
+struct<key:string,value:string>
+-- !query 2 output
+spark.sql.function.eltOutputAsString   true
+
+
+-- !query 3
+SELECT elt(1, col1, col2) col
+FROM (
+  SELECT
+    encode(string(id), 'utf-8') col1,
+    encode(string(id + 1), 'utf-8') col2
+  FROM range(10)
+)
+-- !query 3 schema
+struct<col:string>
+-- !query 3 output
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+
+
+-- !query 4
+set spark.sql.function.eltOutputAsString=false
+-- !query 4 schema
+struct<key:string,value:string>
+-- !query 4 output
+spark.sql.function.eltOutputAsString   false
+
+
+-- !query 5
+SELECT elt(2, col1, col2) col
+FROM (
+  SELECT
+    encode(string(id), 'utf-8') col1,
+    encode(string(id + 1), 'utf-8') col2
+  FROM range(10)
+)
+-- !query 5 schema
+struct<col:binary>
+-- !query 5 output
+1
+10
+2
+3
+4
+5
+6
+7
+8
+9

