Repository: spark
Updated Branches:
  refs/heads/master d5bec48b9 -> 46bb2b512


[SPARK-23924][SQL] Add element_at function

## What changes were proposed in this pull request?

The PR adds the SQL function `element_at`. The behavior of the function is
based on Presto's `element_at`.

This function returns the element of an array at the given (1-based) index if
the column is an array, or the value for the given key if the column is a map.

## How was this patch tested?

Added unit tests.

Author: Kazuaki Ishizaki <[email protected]>

Closes #21053 from kiszk/SPARK-23924.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/46bb2b51
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/46bb2b51
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/46bb2b51

Branch: refs/heads/master
Commit: 46bb2b5129833cc5829089bf1174a76cb7b81741
Parents: d5bec48
Author: Kazuaki Ishizaki <[email protected]>
Authored: Thu Apr 19 21:00:10 2018 +0900
Committer: Takuya UESHIN <[email protected]>
Committed: Thu Apr 19 21:00:10 2018 +0900

----------------------------------------------------------------------
 python/pyspark/sql/functions.py                 |  24 +++++
 .../catalyst/analysis/FunctionRegistry.scala    |   1 +
 .../expressions/collectionOperations.scala      | 104 +++++++++++++++++++
 .../expressions/complexTypeExtractors.scala     |  64 +++++++-----
 .../CollectionExpressionsSuite.scala            |  48 +++++++++
 .../scala/org/apache/spark/sql/functions.scala  |  11 ++
 .../spark/sql/DataFrameFunctionsSuite.scala     |  48 +++++++++
 7 files changed, 276 insertions(+), 24 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/46bb2b51/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 36dcabc..1be68f2 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1862,6 +1862,32 @@ def array_position(col, value):
     return Column(sc._jvm.functions.array_position(_to_java_column(col), value))
 
 
+@ignore_unicode_prefix
+@since(2.4)
+def element_at(col, extraction):
+    """
+    Collection function: Returns element of array at given index in extraction if col is array.
+    Returns value for the given key in extraction if col is map.
+
+    :param col: name of column containing array or map
+    :param extraction: index to check for in array or key to check for in map
+
+    .. note:: The position is not zero based, but 1 based index. A negative index accesses
+        elements from the last to the first, an index of 0 raises an error, and an index whose
+        absolute value exceeds the array length (or a key missing from the map) yields null.
+
+    >>> df = spark.createDataFrame([(["a", "b", "c"],), ([],)], ['data'])
+    >>> df.select(element_at(df.data, 1)).collect()
+    [Row(element_at(data, 1)=u'a'), Row(element_at(data, 1)=None)]
+
+    >>> df = spark.createDataFrame([({"a": 1.0, "b": 2.0},), ({},)], ['data'])
+    >>> df.select(element_at(df.data, "a")).collect()
+    [Row(element_at(data, a)=1.0), Row(element_at(data, a)=None)]
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.element_at(_to_java_column(col), extraction))
+
+
 @since(1.4)
 def explode(col):
     """Returns a new row for each element in the given array or map.

http://git-wip-us.apache.org/repos/asf/spark/blob/46bb2b51/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 74095fe..a44f2d5 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -405,6 +405,7 @@ object FunctionRegistry {
     expression[ArrayPosition]("array_position"),
     expression[CreateMap]("map"),
     expression[CreateNamedStruct]("named_struct"),
+    expression[ElementAt]("element_at"),
     expression[MapKeys]("map_keys"),
     expression[MapValues]("map_values"),
     expression[Size]("size"),

http://git-wip-us.apache.org/repos/asf/spark/blob/46bb2b51/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index e6a05f5..dba426e 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -561,3 +561,119 @@ case class ArrayPosition(left: Expression, right: Expression)
     })
   }
 }
+
+/**
+ * Returns the value of index `right` in Array `left` or the value for key `right` in Map `left`.
+ *
+ * Array indices are 1-based: a negative index accesses elements from the last to the first, an
+ * index whose absolute value exceeds the array length yields null, and an index of 0 throws an
+ * `ArrayIndexOutOfBoundsException`. A key that is not contained in a map yields null.
+ */
+@ExpressionDescription(
+  usage = """
+    _FUNC_(array, index) - Returns element of array at given (1-based) index. If index < 0,
+      accesses elements from the last to the first. Returns NULL if the index exceeds the length
+      of the array.
+
+    _FUNC_(map, key) - Returns value for given key, or NULL if the key is not contained in the map
+  """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_(array(1, 2, 3), 2);
+       2
+      > SELECT _FUNC_(map(1, 'a', 2, 'b'), 2);
+       b
+  """,
+  since = "2.4.0")
+case class ElementAt(left: Expression, right: Expression) extends GetMapValueUtil {
+
+  override def dataType: DataType = left.dataType match {
+    case ArrayType(elementType, _) => elementType
+    case MapType(_, valueType, _) => valueType
+  }
+
+  override def inputTypes: Seq[AbstractDataType] = {
+    // The type expected for `right` depends on `left`. For any other `left` type, accept any
+    // `right` so that the TypeCollection check on `left` reports the mismatch instead of this
+    // method failing with a MatchError.
+    val indexOrKeyType: AbstractDataType = left.dataType match {
+      case _: ArrayType => IntegerType
+      case MapType(keyType, _, _) => keyType
+      case _ => AnyDataType
+    }
+    Seq(TypeCollection(ArrayType, MapType), indexOrKeyType)
+  }
+
+  override def nullable: Boolean = true
+
+  override def nullSafeEval(value: Any, ordinal: Any): Any = {
+    left.dataType match {
+      case ArrayType(_, containsNull) =>
+        val array = value.asInstanceOf[ArrayData]
+        val index = ordinal.asInstanceOf[Int]
+        val numElements = array.numElements()
+        // Check both bounds explicitly: `math.abs(Int.MinValue)` is still negative, so an
+        // `abs`-based check would let Int.MinValue through to a negative array offset.
+        if (index > numElements || index < -numElements) {
+          null
+        } else {
+          val idx = if (index == 0) {
+            throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1")
+          } else if (index > 0) {
+            index - 1
+          } else {
+            numElements + index
+          }
+          if (containsNull && array.isNullAt(idx)) {
+            null
+          } else {
+            array.get(idx, dataType)
+          }
+        }
+      case MapType(keyType, _, _) =>
+        getValueEval(value, ordinal, keyType)
+    }
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    left.dataType match {
+      case ArrayType(_, containsNull) =>
+        nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
+          val index = ctx.freshName("elementAtIndex")
+          // Only arrays that may contain nulls need a per-element null check.
+          val nullCheck = if (containsNull) {
+            s"""
+               |if ($eval1.isNullAt($index)) {
+               |  ${ev.isNull} = true;
+               |} else
+             """.stripMargin
+          } else {
+            ""
+          }
+          // Same bounds logic as `nullSafeEval`, including the Int.MinValue-safe range check.
+          s"""
+             |int $index = (int) $eval2;
+             |if ($index > $eval1.numElements() || $index < -$eval1.numElements()) {
+             |  ${ev.isNull} = true;
+             |} else {
+             |  if ($index == 0) {
+             |    throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1");
+             |  } else if ($index > 0) {
+             |    $index--;
+             |  } else {
+             |    $index += $eval1.numElements();
+             |  }
+             |  $nullCheck
+             |  {
+             |    ${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)};
+             |  }
+             |}
+           """.stripMargin
+        })
+      case mapType: MapType =>
+        doGetValueGenCode(ctx, ev, mapType)
+    }
+  }
+
+  override def prettyName: String = "element_at"
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/46bb2b51/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index 6cdad19..3fba52d 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -268,31 +268,12 @@ case class GetArrayItem(child: Expression, ordinal: 
Expression)
 }
 
 /**
- * Returns the value of key `key` in Map `child`.
- *
- * We need to do type checking here as `key` expression maybe unresolved.
+ * Common base class for [[GetMapValue]] and [[ElementAt]].
  */
-case class GetMapValue(child: Expression, key: Expression)
-  extends BinaryExpression with ImplicitCastInputTypes with ExtractValue with 
NullIntolerant {
-
-  private def keyType = child.dataType.asInstanceOf[MapType].keyType
-
-  // We have done type checking for child in `ExtractValue`, so only need to 
check the `key`.
-  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, keyType)
-
-  override def toString: String = s"$child[$key]"
-  override def sql: String = s"${child.sql}[${key.sql}]"
-
-  override def left: Expression = child
-  override def right: Expression = key
-
-  /** `Null` is returned for invalid ordinals. */
-  override def nullable: Boolean = true
-
-  override def dataType: DataType = 
child.dataType.asInstanceOf[MapType].valueType
 
+abstract class GetMapValueUtil extends BinaryExpression with 
ImplicitCastInputTypes {
   // todo: current search is O(n), improve it.
-  protected override def nullSafeEval(value: Any, ordinal: Any): Any = {
+  def getValueEval(value: Any, ordinal: Any, keyType: DataType): Any = {
     val map = value.asInstanceOf[MapData]
     val length = map.numElements()
     val keys = map.keyArray()
@@ -315,14 +296,15 @@ case class GetMapValue(child: Expression, key: Expression)
     }
   }
 
-  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+  def doGetValueGenCode(ctx: CodegenContext, ev: ExprCode, mapType: MapType): 
ExprCode = {
     val index = ctx.freshName("index")
     val length = ctx.freshName("length")
     val keys = ctx.freshName("keys")
     val found = ctx.freshName("found")
     val key = ctx.freshName("key")
     val values = ctx.freshName("values")
-    val nullCheck = if 
(child.dataType.asInstanceOf[MapType].valueContainsNull) {
+    val keyType = mapType.keyType
+    val nullCheck = if (mapType.valueContainsNull) {
       s" || $values.isNullAt($index)"
     } else {
       ""
@@ -354,3 +336,37 @@ case class GetMapValue(child: Expression, key: Expression)
     })
   }
 }
+
+/**
+ * Returns the value of key `key` in Map `child`.
+ *
+ * We need to do type checking here as `key` expression maybe unresolved. The lookup itself is
+ * implemented by [[GetMapValueUtil]], which [[ElementAt]] shares.
+ */
+case class GetMapValue(child: Expression, key: Expression)
+  extends GetMapValueUtil with ExtractValue with NullIntolerant {
+
+  // `child` is assumed to be a map here; `ExtractValue` has already type-checked it.
+  private def mapType: MapType = child.dataType.asInstanceOf[MapType]
+
+  private def keyType = mapType.keyType
+
+  // Only the `key` still needs type checking, which `ImplicitCastInputTypes` handles.
+  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, keyType)
+
+  override def left: Expression = child
+  override def right: Expression = key
+
+  override def toString: String = s"$child[$key]"
+  override def sql: String = s"${child.sql}[${key.sql}]"
+
+  /** `Null` is returned for invalid ordinals. */
+  override def nullable: Boolean = true
+
+  override def dataType: DataType = mapType.valueType
+
+  override def nullSafeEval(value: Any, ordinal: Any): Any = getValueEval(value, ordinal, keyType)
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
+    doGetValueGenCode(ctx, ev, mapType)
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/46bb2b51/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 916cd3b..7d8fe21 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -191,4 +191,55 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     checkEvaluation(ArrayPosition(a3, Literal("")), null)
     checkEvaluation(ArrayPosition(a3, Literal.create(null, StringType)), null)
   }
+
+  test("elementAt") {
+    val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType))
+    val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType))
+    val a2 = Literal.create(Seq(null), ArrayType(LongType))
+    val a3 = Literal.create(null, ArrayType(StringType))
+
+    // Index 0 must be rejected. The message is checked on the exception thrown by `eval`
+    // itself, and the `contains` result is asserted (a bare `.contains` checks nothing).
+    val zeroIndexError = intercept[ArrayIndexOutOfBoundsException] {
+      ElementAt(a0, Literal(0)).eval()
+    }
+    assert(zeroIndexError.getMessage.contains("SQL array indices start at 1"))
+    intercept[Exception] { checkEvaluation(ElementAt(a0, Literal(0)), null) }
+    intercept[Exception] { checkEvaluation(ElementAt(a0, Literal(1.1)), null) }
+    checkEvaluation(ElementAt(a0, Literal(4)), null)
+    checkEvaluation(ElementAt(a0, Literal(-4)), null)
+
+    checkEvaluation(ElementAt(a0, Literal(1)), 1)
+    checkEvaluation(ElementAt(a0, Literal(2)), 2)
+    checkEvaluation(ElementAt(a0, Literal(3)), 3)
+    checkEvaluation(ElementAt(a0, Literal(-3)), 1)
+    checkEvaluation(ElementAt(a0, Literal(-2)), 2)
+    checkEvaluation(ElementAt(a0, Literal(-1)), 3)
+
+    checkEvaluation(ElementAt(a1, Literal(1)), null)
+    checkEvaluation(ElementAt(a1, Literal(2)), "")
+    checkEvaluation(ElementAt(a1, Literal(-2)), null)
+    checkEvaluation(ElementAt(a1, Literal(-1)), "")
+
+    checkEvaluation(ElementAt(a2, Literal(1)), null)
+
+    checkEvaluation(ElementAt(a3, Literal(1)), null)
+
+    val m0 =
+      Literal.create(Map("a" -> "1", "b" -> "2", "c" -> null), MapType(StringType, StringType))
+    val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType))
+    val m2 = Literal.create(null, MapType(StringType, StringType))
+
+    checkEvaluation(ElementAt(m0, Literal(1.0)), null)
+
+    checkEvaluation(ElementAt(m0, Literal("d")), null)
+
+    checkEvaluation(ElementAt(m1, Literal("a")), null)
+
+    checkEvaluation(ElementAt(m0, Literal("a")), "1")
+    checkEvaluation(ElementAt(m0, Literal("b")), "2")
+    checkEvaluation(ElementAt(m0, Literal("c")), null)
+
+    checkEvaluation(ElementAt(m2, Literal("a")), null)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/46bb2b51/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 3a09ec4..9c85803 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3053,6 +3053,23 @@ object functions {
   }
 
   /**
+   * Returns element of array at given index in value if column is array. Returns value for
+   * the given key in value if column is map.
+   *
+   * The array index is 1-based and a negative index accesses elements from the last to the
+   * first. Null is returned when the absolute value of the index exceeds the array length or
+   * the key is not contained in the map, while an index of 0 fails at evaluation time.
+   *
+   * @param column array or map column to extract from
+   * @param value 1-based array index, or map key, to look up (used as a literal)
+   * @group collection_funcs
+   * @since 2.4.0
+   */
+  def element_at(column: Column, value: Any): Column = withExpr {
+    ElementAt(column.expr, Literal(value))
+  }
+
+  /**
    * Creates a new row for each element in the given array or map column.
    *
    * @group collection_funcs

http://git-wip-us.apache.org/repos/asf/spark/blob/46bb2b51/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 13161e7..7c976c1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -569,6 +569,57 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     )
   }
 
+  test("element_at function") {
+    val df = Seq(
+      Seq[String]("1", "2", "3"),
+      Seq[String](null, ""),
+      Seq[String]()
+    ).toDF("a")
+
+    // Index 0 must fail at evaluation time. Assert on the message explicitly; a bare
+    // `intercept { ... }.getMessage.contains(...)` discards its result and checks nothing.
+    val zeroIndexError = intercept[Exception] {
+      df.select(element_at(df("a"), 0)).collect()
+    }
+    assert(zeroIndexError.getMessage.contains("SQL array indices start at 1"))
+    // NOTE(review): if the analyzer implicitly casts 1.1 to an integer index, this intercept
+    // only catches the mismatch with the expected answer below, not an evaluation error.
+    // Confirm the intended behavior for fractional indices and assert it explicitly.
+    intercept[Exception] {
+      checkAnswer(
+        df.select(element_at(df("a"), 1.1)),
+        Seq(Row(null), Row(null), Row(null))
+      )
+    }
+    checkAnswer(
+      df.select(element_at(df("a"), 4)),
+      Seq(Row(null), Row(null), Row(null))
+    )
+
+    checkAnswer(
+      df.select(element_at(df("a"), 1)),
+      Seq(Row("1"), Row(null), Row(null))
+    )
+    checkAnswer(
+      df.select(element_at(df("a"), -1)),
+      Seq(Row("3"), Row(""), Row(null))
+    )
+
+    checkAnswer(
+      df.selectExpr("element_at(a, 4)"),
+      Seq(Row(null), Row(null), Row(null))
+    )
+
+    checkAnswer(
+      df.selectExpr("element_at(a, 1)"),
+      Seq(Row("1"), Row(null), Row(null))
+    )
+    checkAnswer(
+      df.selectExpr("element_at(a, -1)"),
+      Seq(Row("3"), Row(""), Row(null))
+    )
+  }
+
   private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
     import DataFrameFunctionsSuite.CodegenFallbackExpr
     for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), 
(false, true))) {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to