Repository: spark
Updated Branches:
  refs/heads/master 7822c3f8d -> 5b4a38d82


[SPARK-23939][SQL] Add transform_keys function

## What changes were proposed in this pull request?
This PR adds the `transform_keys` function, which applies a given lambda function to each entry of a map and transforms the keys.
```sql
> SELECT transform_keys(map_from_arrays(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + 1);
       map(2->1, 3->2, 4->3)

> SELECT transform_keys(map_from_arrays(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + v);
       map(2->1, 4->2, 6->3)
```
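
This patch exposes `transform_keys` through SQL only; from the Dataset API it is reachable via `expr`/`selectExpr`. A minimal sketch (assumes a `SparkSession` named `spark` with its implicits imported; neither is part of the patch):

```scala
import spark.implicits._  // assumes an in-scope SparkSession named `spark`

// No typed Column helper is added by this patch, so go through selectExpr.
Seq(Map(1 -> 1, 2 -> 2, 3 -> 3)).toDF("m")
  .selectExpr("transform_keys(m, (k, v) -> k + v) AS m")
  .show(truncate = false)
// Expected single row: [2 -> 1, 4 -> 2, 6 -> 3]
```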

## How was this patch tested?
Added tests.

Closes #22013 from codeatri/SPARK-23939.

Authored-by: codeatri <nehapat...@gmail.com>
Signed-off-by: Takuya UESHIN <ues...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5b4a38d8
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5b4a38d8
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5b4a38d8

Branch: refs/heads/master
Commit: 5b4a38d826807ea6733e4382c8f9b82a355a6eb4
Parents: 7822c3f
Author: codeatri <nehapat...@gmail.com>
Authored: Thu Aug 16 17:07:33 2018 +0900
Committer: Takuya UESHIN <ues...@databricks.com>
Committed: Thu Aug 16 17:07:33 2018 +0900

----------------------------------------------------------------------
 .../catalyst/analysis/FunctionRegistry.scala    |  1 +
 .../expressions/higherOrderFunctions.scala      | 53 ++++++++++++
 .../expressions/HigherOrderFunctionsSuite.scala | 75 +++++++++++++++++
 .../sql-tests/inputs/higher-order-functions.sql | 14 ++++
 .../results/higher-order-functions.sql.out      | 39 ++++++++-
 .../spark/sql/DataFrameFunctionsSuite.scala     | 87 ++++++++++++++++++++
 6 files changed, 268 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5b4a38d8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index cc2b758..b993e1a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -446,6 +446,7 @@ object FunctionRegistry {
     expression[ArrayFilter]("filter"),
     expression[ArrayExists]("exists"),
     expression[ArrayAggregate]("aggregate"),
+    expression[TransformKeys]("transform_keys"),
     expression[MapZipWith]("map_zip_with"),
     CreateStruct.registryEntry,
 

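With the registry entry above in place, the name resolves like any other built-in. A quick smoke check, as a sketch (assumes a build containing this patch and a `SparkSession` named `spark`):

```scala
// Sketch: resolve and evaluate the newly registered function by name.
spark.sql("SELECT transform_keys(map(1, 'a', 2, 'b'), (k, v) -> k + 1) AS m")
  .show(truncate = false)
// Expected shape: one map column with shifted keys, e.g. [2 -> a, 3 -> b];
// values pass through unchanged.
```
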
http://git-wip-us.apache.org/repos/asf/spark/blob/5b4a38d8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
index 22210f6..a305a05 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
@@ -498,6 +498,59 @@ case class ArrayAggregate(
 }
 
 /**
+ * Transforms the keys of a map by applying the given function to every entry.
+ * Returns a map with the transformed keys and the original values.
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(expr, func) - Transforms elements in a map using the function.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_(map_from_arrays(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + 1);
+       {2:1,3:2,4:3}
+      > SELECT _FUNC_(map_from_arrays(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + v);
+       {2:1,4:2,6:3}
+  """,
+  since = "2.4.0")
+case class TransformKeys(
+    argument: Expression,
+    function: Expression)
+  extends MapBasedSimpleHigherOrderFunction with CodegenFallback {
+
+  override def nullable: Boolean = argument.nullable
+
+  @transient lazy val MapType(keyType, valueType, valueContainsNull) = argument.dataType
+
+  override def dataType: DataType = MapType(function.dataType, valueType, valueContainsNull)
+
+  override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): TransformKeys = {
+    copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil))
+  }
+
+  @transient lazy val LambdaFunction(
+    _, (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function
+
+
+  override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = {
+    val map = argumentValue.asInstanceOf[MapData]
+    val resultKeys = new GenericArrayData(new Array[Any](map.numElements))
+    var i = 0
+    while (i < map.numElements) {
+      keyVar.value.set(map.keyArray().get(i, keyVar.dataType))
+      valueVar.value.set(map.valueArray().get(i, valueVar.dataType))
+      val result = functionForEval.eval(inputRow)
+      if (result == null) {
+        throw new RuntimeException("Cannot use null as map key!")
+      }
+      resultKeys.update(i, result)
+      i += 1
+    }
+    new ArrayBasedMapData(resultKeys, map.valueArray())
+  }
+
+  override def prettyName: String = "transform_keys"
+}
+
+/**
  * Merges two given maps into a single map by applying function to the pair of values with
  * the same key.
  */

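Note that `nullSafeEval` above rejects a null key at evaluation time rather than analysis time. A sketch of how that surfaces from the DataFrame side, mirroring the `ex3` case in the suite below (assumes a `SparkSession` named `spark` with implicits imported):

```scala
import scala.util.Try
import spark.implicits._  // assumes an in-scope SparkSession named `spark`

val df = Seq(Map[String, String]("a" -> null)).toDF("m")

// The lambda promotes the nullable value to the key, so evaluation hits
// the "Cannot use null as map key!" guard in nullSafeEval.
val result = Try(df.selectExpr("transform_keys(m, (k, v) -> v)").show())
println(result)  // Failure(java.lang.RuntimeException: Cannot use null as map key!)
```
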
http://git-wip-us.apache.org/repos/asf/spark/blob/5b4a38d8/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
index 3137dc9..12ef018 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
 import org.apache.spark.sql.types._
 
 class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -74,6 +74,11 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
     ArrayFilter(expr, createLambda(at.elementType, at.containsNull, f))
   }
 
+  def transformKeys(expr: Expression, f: (Expression, Expression) => Expression): Expression = {
+    val map = expr.dataType.asInstanceOf[MapType]
+    TransformKeys(expr, createLambda(map.keyType, false, map.valueType, map.valueContainsNull, f))
+  }
+
   def aggregate(
       expr: Expression,
       zero: Expression,
@@ -283,6 +289,75 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
       15)
   }
 
+  test("TransformKeys") {
+    val ai0 = Literal.create(
+      Map(1 -> 1, 2 -> 2, 3 -> 3, 4 -> 4),
+      MapType(IntegerType, IntegerType, valueContainsNull = false))
+    val ai1 = Literal.create(
+      Map.empty[Int, Int],
+      MapType(IntegerType, IntegerType, valueContainsNull = true))
+    val ai2 = Literal.create(
+      Map(1 -> 1, 2 -> null, 3 -> 3),
+      MapType(IntegerType, IntegerType, valueContainsNull = true))
+    val ai3 = Literal.create(null, MapType(IntegerType, IntegerType, valueContainsNull = false))
+
+    val plusOne: (Expression, Expression) => Expression = (k, v) => k + 1
+    val plusValue: (Expression, Expression) => Expression = (k, v) => k + v
+    val modKey: (Expression, Expression) => Expression = (k, v) => k % 3
+
+    checkEvaluation(transformKeys(ai0, plusOne), Map(2 -> 1, 3 -> 2, 4 -> 3, 5 -> 4))
+    checkEvaluation(transformKeys(ai0, plusValue), Map(2 -> 1, 4 -> 2, 6 -> 3, 8 -> 4))
+    checkEvaluation(
+      transformKeys(transformKeys(ai0, plusOne), plusValue), Map(3 -> 1, 5 -> 2, 7 -> 3, 9 -> 4))
+    checkEvaluation(transformKeys(ai0, modKey),
+      ArrayBasedMapData(Array(1, 2, 0, 1), Array(1, 2, 3, 4)))
+    checkEvaluation(transformKeys(ai1, plusOne), Map.empty[Int, Int])
+    checkEvaluation(
+      transformKeys(transformKeys(ai1, plusOne), plusValue), Map.empty[Int, Int])
+    checkEvaluation(transformKeys(ai2, plusOne), Map(2 -> 1, 3 -> null, 4 -> 3))
+    checkEvaluation(
+      transformKeys(transformKeys(ai2, plusOne), plusOne), Map(3 -> 1, 4 -> null, 5 -> 3))
+    checkEvaluation(transformKeys(ai3, plusOne), null)
+
+    val as0 = Literal.create(
+      Map("a" -> "xy", "bb" -> "yz", "ccc" -> "zx"),
+      MapType(StringType, StringType, valueContainsNull = false))
+    val as1 = Literal.create(
+      Map("a" -> "xy", "bb" -> "yz", "ccc" -> null),
+      MapType(StringType, StringType, valueContainsNull = true))
+    val as2 = Literal.create(null,
+      MapType(StringType, StringType, valueContainsNull = false))
+    val as3 = Literal.create(Map.empty[String, String],
+      MapType(StringType, StringType, valueContainsNull = true))
+
+    val concatValue: (Expression, Expression) => Expression = (k, v) => Concat(Seq(k, v))
+    val convertKeyToKeyLength: (Expression, Expression) => Expression =
+      (k, v) => Length(k) + 1
+
+    checkEvaluation(
+      transformKeys(as0, concatValue), Map("axy" -> "xy", "bbyz" -> "yz", "ccczx" -> "zx"))
+    checkEvaluation(
+      transformKeys(transformKeys(as0, concatValue), concatValue),
+      Map("axyxy" -> "xy", "bbyzyz" -> "yz", "ccczxzx" -> "zx"))
+    checkEvaluation(transformKeys(as3, concatValue), Map.empty[String, String])
+    checkEvaluation(
+      transformKeys(transformKeys(as3, concatValue), convertKeyToKeyLength),
+      Map.empty[Int, String])
+    checkEvaluation(transformKeys(as0, convertKeyToKeyLength),
+      Map(2 -> "xy", 3 -> "yz", 4 -> "zx"))
+    checkEvaluation(transformKeys(as1, convertKeyToKeyLength),
+      Map(2 -> "xy", 3 -> "yz", 4 -> null))
+    checkEvaluation(transformKeys(as2, convertKeyToKeyLength), null)
+    checkEvaluation(transformKeys(as3, convertKeyToKeyLength), Map.empty[Int, String])
+
+    val ax0 = Literal.create(
+      Map(1 -> "x", 2 -> "y", 3 -> "z"),
+      MapType(IntegerType, StringType, valueContainsNull = false))
+
+    checkEvaluation(transformKeys(ax0, plusOne), Map(2 -> "x", 3 -> "y", 4 -> "z"))
+  }
+
   test("MapZipWith") {
     def map_zip_with(
         left: Expression,

http://git-wip-us.apache.org/repos/asf/spark/blob/5b4a38d8/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql
index ce1d0da..9a84544 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql
@@ -51,3 +51,17 @@ select exists(ys, y -> y > 30) as v from nested;
 
 -- Check for element existence in a null array
 select exists(cast(null as array<int>), y -> y > 30) as v;
+
+create or replace temporary view nested as values
+  (1, map(1, 1, 2, 2, 3, 3)),
+  (2, map(4, 4, 5, 5, 6, 6))
+  as t(x, ys);
+
+-- Identity Transform Keys in a map
+select transform_keys(ys, (k, v) -> k) as v from nested;
+
+-- Transform Keys in a map by adding constant
+select transform_keys(ys, (k, v) -> k + 1) as v from nested;
+
+-- Transform Keys in a map using values
+select transform_keys(ys, (k, v) -> k + v) as v from nested;

http://git-wip-us.apache.org/repos/asf/spark/blob/5b4a38d8/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out
index e18abce..b77bda7 100644
--- a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 15
+-- Number of queries: 20
 
 
 -- !query 0
@@ -163,3 +163,40 @@ select exists(cast(null as array<int>), y -> y > 30) as v
 struct<v:boolean>
 -- !query 16 output
 NULL
+
+
+-- !query 17
+create or replace temporary view nested as values
+  (1, map(1, 1, 2, 2, 3, 3)),
+  (2, map(4, 4, 5, 5, 6, 6))
+  as t(x, ys)
+-- !query 17 schema
+struct<>
+-- !query 17 output
+
+
+-- !query 18
+select transform_keys(ys, (k, v) -> k) as v from nested
+-- !query 18 schema
+struct<v:map<int,int>>
+-- !query 18 output
+{1:1,2:2,3:3}
+{4:4,5:5,6:6}
+
+
+-- !query 19
+select transform_keys(ys, (k, v) -> k + 1) as v from nested
+-- !query 19 schema
+struct<v:map<int,int>>
+-- !query 19 output
+{2:1,3:2,4:3}
+{5:4,6:5,7:6}
+
+
+-- !query 20
+select transform_keys(ys, (k, v) -> k + v) as v from nested
+-- !query 20 schema
+struct<v:map<int,int>>
+-- !query 20 output
+{10:5,12:6,8:4}
+{2:1,4:2,6:3}

http://git-wip-us.apache.org/repos/asf/spark/blob/5b4a38d8/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 8d7695b..22f1912 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -2302,6 +2302,93 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     assert(ex5.getMessage.contains("function map_zip_with does not support ordering on type map"))
   }
 
+  test("transform keys function - primitive data types") {
+    val dfExample1 = Seq(
+      Map[Int, Int](1 -> 1, 9 -> 9, 8 -> 8, 7 -> 7)
+    ).toDF("i")
+
+    val dfExample2 = Seq(
+      Map[Int, Double](1 -> 1.0, 2 -> 1.40, 3 -> 1.70)
+    ).toDF("j")
+
+    val dfExample3 = Seq(
+      Map[Int, Boolean](25 -> true, 26 -> false)
+    ).toDF("x")
+
+    val dfExample4 = Seq(
+      Map[Array[Int], Boolean](Array(1, 2) -> false)
+    ).toDF("y")
+
+
+    def testMapOfPrimitiveTypesCombination(): Unit = {
+      checkAnswer(dfExample1.selectExpr("transform_keys(i, (k, v) -> k + v)"),
+        Seq(Row(Map(2 -> 1, 18 -> 9, 16 -> 8, 14 -> 7))))
+
+      checkAnswer(dfExample2.selectExpr("transform_keys(j, " +
+        "(k, v) -> map_from_arrays(ARRAY(1, 2, 3), ARRAY('one', 'two', 'three'))[k])"),
+        Seq(Row(Map("one" -> 1.0, "two" -> 1.4, "three" -> 1.7))))
+
+      checkAnswer(dfExample2.selectExpr("transform_keys(j, (k, v) -> CAST(v * 2 AS BIGINT) + k)"),
+        Seq(Row(Map(3 -> 1.0, 4 -> 1.4, 6 -> 1.7))))
+
+      checkAnswer(dfExample2.selectExpr("transform_keys(j, (k, v) -> k + v)"),
+        Seq(Row(Map(2.0 -> 1.0, 3.4 -> 1.4, 4.7 -> 1.7))))
+
+      checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> k % 2 = 0 OR v)"),
+        Seq(Row(Map(true -> true, true -> false))))
+
+      checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> if(v, 2 * k, 3 * k))"),
+        Seq(Row(Map(50 -> true, 78 -> false))))
+
+      checkAnswer(dfExample4.selectExpr("transform_keys(y, (k, v) -> array_contains(k, 3) AND v)"),
+        Seq(Row(Map(false -> false))))
+    }
+    // Test with local relation, the Project will be evaluated without codegen
+    testMapOfPrimitiveTypesCombination()
+    dfExample1.cache()
+    dfExample2.cache()
+    dfExample3.cache()
+    dfExample4.cache()
+    // Test with cached relation, the Project will be evaluated with codegen
+    testMapOfPrimitiveTypesCombination()
+  }
+
+  test("transform keys function - Invalid lambda functions and exceptions") {
+
+    val dfExample1 = Seq(
+      Map[String, String]("a" -> null)
+    ).toDF("i")
+
+    val dfExample2 = Seq(
+      Seq(1, 2, 3, 4)
+    ).toDF("j")
+
+    val ex1 = intercept[AnalysisException] {
+      dfExample1.selectExpr("transform_keys(i, k -> k)")
+    }
+    assert(ex1.getMessage.contains("The number of lambda function arguments '1' does not match"))
+
+    val ex2 = intercept[AnalysisException] {
+      dfExample1.selectExpr("transform_keys(i, (k, v, x) -> k + 1)")
+    }
+    assert(ex2.getMessage.contains(
+      "The number of lambda function arguments '3' does not match"))
+
+    val ex3 = intercept[RuntimeException] {
+      dfExample1.selectExpr("transform_keys(i, (k, v) -> v)").show()
+    }
+    assert(ex3.getMessage.contains("Cannot use null as map key!"))
+
+    val ex4 = intercept[AnalysisException] {
+      dfExample2.selectExpr("transform_keys(j, (k, v) -> k + 1)")
+    }
+    assert(ex4.getMessage.contains(
+      "data type mismatch: argument 1 requires map type"))
+  }
+
   private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
     import DataFrameFunctionsSuite.CodegenFallbackExpr
    for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {

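A design note on the suites above: `TransformKeys` extends `CodegenFallback`, so the expression itself is always interpreted; caching the input DataFrames only switches the surrounding `Project` between the interpreted and whole-stage-codegen paths, which is why each test body runs twice. A compact sketch of that pattern (hypothetical `check` helper, assumes `spark` and its implicits):

```scala
import spark.implicits._  // assumes an in-scope SparkSession named `spark`

val df = Seq(Map(1 -> "a", 2 -> "b")).toDF("m")

// Hypothetical helper for this sketch: the same assertion is evaluated once
// without codegen (local relation) and once with codegen (cached relation).
def check(): Unit = {
  val row = df.selectExpr("transform_keys(m, (k, v) -> k * 10)").head
  assert(row.getMap[Int, String](0) == Map(10 -> "a", 20 -> "b"))
}

check()      // Project evaluated without codegen
df.cache()   // materialized relation: Project now goes through codegen
check()
```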
