Repository: spark Updated Branches: refs/heads/master 7822c3f8d -> 5b4a38d82
[SPARK-23939][SQL] Add transform_keys function ## What changes were proposed in this pull request? This PR adds the transform_keys function, which applies the given function to each entry of the map and transforms the keys. ```sql > SELECT transform_keys(map_from_arrays(array(1, 2, 3), array(1, 2, 3)), (k,v) -> k + 1); map(2->1, 3->2, 4->3) > SELECT transform_keys(map_from_arrays(array(1, 2, 3), array(1, 2, 3)), (k,v) -> k + v); map(2->1, 4->2, 6->3) ``` ## How was this patch tested? Added tests. Closes #22013 from codeatri/SPARK-23939. Authored-by: codeatri <nehapat...@gmail.com> Signed-off-by: Takuya UESHIN <ues...@databricks.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5b4a38d8 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5b4a38d8 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5b4a38d8 Branch: refs/heads/master Commit: 5b4a38d826807ea6733e4382c8f9b82a355a6eb4 Parents: 7822c3f Author: codeatri <nehapat...@gmail.com> Authored: Thu Aug 16 17:07:33 2018 +0900 Committer: Takuya UESHIN <ues...@databricks.com> Committed: Thu Aug 16 17:07:33 2018 +0900 ---------------------------------------------------------------------- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/higherOrderFunctions.scala | 53 ++++++++++++ .../expressions/HigherOrderFunctionsSuite.scala | 75 +++++++++++++++++ .../sql-tests/inputs/higher-order-functions.sql | 14 ++++ .../results/higher-order-functions.sql.out | 39 ++++++++- .../spark/sql/DataFrameFunctionsSuite.scala | 87 ++++++++++++++++++++ 6 files changed, 268 insertions(+), 1 deletion(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/5b4a38d8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala ---------------------------------------------------------------------- diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index cc2b758..b993e1a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -446,6 +446,7 @@ object FunctionRegistry { expression[ArrayFilter]("filter"), expression[ArrayExists]("exists"), expression[ArrayAggregate]("aggregate"), + expression[TransformKeys]("transform_keys"), expression[MapZipWith]("map_zip_with"), CreateStruct.registryEntry, http://git-wip-us.apache.org/repos/asf/spark/blob/5b4a38d8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 22210f6..a305a05 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -498,6 +498,59 @@ case class ArrayAggregate( } /** + * Transform Keys for every entry of the map by applying the transform_keys function. 
+ * Returns map with transformed key entries + */ +@ExpressionDescription( + usage = "_FUNC_(expr, func) - Transforms elements in a map using the function.", + examples = """ + Examples: + > SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + 1); + map(array(2, 3, 4), array(1, 2, 3)) + > SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + v); + map(array(2, 4, 6), array(1, 2, 3)) + """, + since = "2.4.0") +case class TransformKeys( + argument: Expression, + function: Expression) + extends MapBasedSimpleHigherOrderFunction with CodegenFallback { + + override def nullable: Boolean = argument.nullable + + @transient lazy val MapType(keyType, valueType, valueContainsNull) = argument.dataType + + override def dataType: DataType = MapType(function.dataType, valueType, valueContainsNull) + + override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): TransformKeys = { + copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil)) + } + + @transient lazy val LambdaFunction( + _, (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function + + + override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = { + val map = argumentValue.asInstanceOf[MapData] + val resultKeys = new GenericArrayData(new Array[Any](map.numElements)) + var i = 0 + while (i < map.numElements) { + keyVar.value.set(map.keyArray().get(i, keyVar.dataType)) + valueVar.value.set(map.valueArray().get(i, valueVar.dataType)) + val result = functionForEval.eval(inputRow) + if (result == null) { + throw new RuntimeException("Cannot use null as map key!") + } + resultKeys.update(i, result) + i += 1 + } + new ArrayBasedMapData(resultKeys, map.valueArray()) + } + + override def prettyName: String = "transform_keys" +} + +/** * Merges two given maps into a single map by applying function to the pair of values with * the same key. 
*/ http://git-wip-us.apache.org/repos/asf/spark/blob/5b4a38d8/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala index 3137dc9..12ef018 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util.ArrayBasedMapData import org.apache.spark.sql.types._ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -74,6 +75,11 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper ArrayFilter(expr, createLambda(at.elementType, at.containsNull, f)) } + def transformKeys(expr: Expression, f: (Expression, Expression) => Expression): Expression = { + val map = expr.dataType.asInstanceOf[MapType] + TransformKeys(expr, createLambda(map.keyType, false, map.valueType, map.valueContainsNull, f)) + } + def aggregate( expr: Expression, zero: Expression, @@ -283,6 +289,75 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper 15) } + test("TransformKeys") { + val ai0 = Literal.create( + Map(1 -> 1, 2 -> 2, 3 -> 3, 4 -> 4), + MapType(IntegerType, IntegerType, valueContainsNull = false)) + val ai1 = Literal.create( + Map.empty[Int, Int], + MapType(IntegerType, IntegerType, valueContainsNull = true)) + val ai2 = Literal.create( + Map(1 -> 1, 2 -> null, 3 -> 3), + MapType(IntegerType, IntegerType, valueContainsNull = true)) + val ai3 = Literal.create(null, 
MapType(IntegerType, IntegerType, valueContainsNull = false)) + + val plusOne: (Expression, Expression) => Expression = (k, v) => k + 1 + val plusValue: (Expression, Expression) => Expression = (k, v) => k + v + val modKey: (Expression, Expression) => Expression = (k, v) => k % 3 + + checkEvaluation(transformKeys(ai0, plusOne), Map(2 -> 1, 3 -> 2, 4 -> 3, 5 -> 4)) + checkEvaluation(transformKeys(ai0, plusValue), Map(2 -> 1, 4 -> 2, 6 -> 3, 8 -> 4)) + checkEvaluation( + transformKeys(transformKeys(ai0, plusOne), plusValue), Map(3 -> 1, 5 -> 2, 7 -> 3, 9 -> 4)) + checkEvaluation(transformKeys(ai0, modKey), + ArrayBasedMapData(Array(1, 2, 0, 1), Array(1, 2, 3, 4))) + checkEvaluation(transformKeys(ai1, plusOne), Map.empty[Int, Int]) + checkEvaluation(transformKeys(ai1, plusOne), Map.empty[Int, Int]) + checkEvaluation( + transformKeys(transformKeys(ai1, plusOne), plusValue), Map.empty[Int, Int]) + checkEvaluation(transformKeys(ai2, plusOne), Map(2 -> 1, 3 -> null, 4 -> 3)) + checkEvaluation( + transformKeys(transformKeys(ai2, plusOne), plusOne), Map(3 -> 1, 4 -> null, 5 -> 3)) + checkEvaluation(transformKeys(ai3, plusOne), null) + + val as0 = Literal.create( + Map("a" -> "xy", "bb" -> "yz", "ccc" -> "zx"), + MapType(StringType, StringType, valueContainsNull = false)) + val as1 = Literal.create( + Map("a" -> "xy", "bb" -> "yz", "ccc" -> null), + MapType(StringType, StringType, valueContainsNull = true)) + val as2 = Literal.create(null, + MapType(StringType, StringType, valueContainsNull = false)) + val as3 = Literal.create(Map.empty[StringType, StringType], + MapType(StringType, StringType, valueContainsNull = true)) + + val concatValue: (Expression, Expression) => Expression = (k, v) => Concat(Seq(k, v)) + val convertKeyToKeyLength: (Expression, Expression) => Expression = + (k, v) => Length(k) + 1 + + checkEvaluation( + transformKeys(as0, concatValue), Map("axy" -> "xy", "bbyz" -> "yz", "ccczx" -> "zx")) + checkEvaluation( + transformKeys(transformKeys(as0, 
concatValue), concatValue), + Map("axyxy" -> "xy", "bbyzyz" -> "yz", "ccczxzx" -> "zx")) + checkEvaluation(transformKeys(as3, concatValue), Map.empty[String, String]) + checkEvaluation( + transformKeys(transformKeys(as3, concatValue), convertKeyToKeyLength), + Map.empty[Int, String]) + checkEvaluation(transformKeys(as0, convertKeyToKeyLength), + Map(2 -> "xy", 3 -> "yz", 4 -> "zx")) + checkEvaluation(transformKeys(as1, convertKeyToKeyLength), + Map(2 -> "xy", 3 -> "yz", 4 -> null)) + checkEvaluation(transformKeys(as2, convertKeyToKeyLength), null) + checkEvaluation(transformKeys(as3, convertKeyToKeyLength), Map.empty[Int, String]) + + val ax0 = Literal.create( + Map(1 -> "x", 2 -> "y", 3 -> "z"), + MapType(IntegerType, StringType, valueContainsNull = false)) + + checkEvaluation(transformKeys(ax0, plusOne), Map(2 -> "x", 3 -> "y", 4 -> "z")) + } + test("MapZipWith") { def map_zip_with( left: Expression, http://git-wip-us.apache.org/repos/asf/spark/blob/5b4a38d8/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql ---------------------------------------------------------------------- diff --git a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql index ce1d0da..9a84544 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql @@ -51,3 +51,17 @@ select exists(ys, y -> y > 30) as v from nested; -- Check for element existence in a null array select exists(cast(null as array<int>), y -> y > 30) as v; + +create or replace temporary view nested as values + (1, map(1, 1, 2, 2, 3, 3)), + (2, map(4, 4, 5, 5, 6, 6)) + as t(x, ys); + +-- Identity Transform Keys in a map +select transform_keys(ys, (k, v) -> k) as v from nested; + +-- Transform Keys in a map by adding constant +select transform_keys(ys, (k, v) -> k + 1) as v from nested; + +-- Transform Keys in a 
map using values +select transform_keys(ys, (k, v) -> k + v) as v from nested; http://git-wip-us.apache.org/repos/asf/spark/blob/5b4a38d8/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out ---------------------------------------------------------------------- diff --git a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out index e18abce..b77bda7 100644 --- a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 15 +-- Number of queries: 20 -- !query 0 @@ -163,3 +163,40 @@ select exists(cast(null as array<int>), y -> y > 30) as v struct<v:boolean> -- !query 16 output NULL + + +-- !query 17 +create or replace temporary view nested as values + (1, map(1, 1, 2, 2, 3, 3)), + (2, map(4, 4, 5, 5, 6, 6)) + as t(x, ys) +-- !query 17 schema +struct<> +-- !query 17 output + + +-- !query 18 +select transform_keys(ys, (k, v) -> k) as v from nested +-- !query 18 schema +struct<v:map<int,int>> +-- !query 18 output +{1:1,2:2,3:3} +{4:4,5:5,6:6} + + +-- !query 19 +select transform_keys(ys, (k, v) -> k + 1) as v from nested +-- !query 19 schema +struct<v:map<int,int>> +-- !query 19 output +{2:1,3:2,4:3} +{5:4,6:5,7:6} + + +-- !query 20 +select transform_keys(ys, (k, v) -> k + v) as v from nested +-- !query 20 schema +struct<v:map<int,int>> +-- !query 20 output +{10:5,12:6,8:4} +{2:1,4:2,6:3} http://git-wip-us.apache.org/repos/asf/spark/blob/5b4a38d8/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 
8d7695b..22f1912 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2302,6 +2302,93 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(ex5.getMessage.contains("function map_zip_with does not support ordering on type map")) } + test("transform keys function - primitive data types") { + val dfExample1 = Seq( + Map[Int, Int](1 -> 1, 9 -> 9, 8 -> 8, 7 -> 7) + ).toDF("i") + + val dfExample2 = Seq( + Map[Int, Double](1 -> 1.0, 2 -> 1.40, 3 -> 1.70) + ).toDF("j") + + val dfExample3 = Seq( + Map[Int, Boolean](25 -> true, 26 -> false) + ).toDF("x") + + val dfExample4 = Seq( + Map[Array[Int], Boolean](Array(1, 2) -> false) + ).toDF("y") + + + def testMapOfPrimitiveTypesCombination(): Unit = { + checkAnswer(dfExample1.selectExpr("transform_keys(i, (k, v) -> k + v)"), + Seq(Row(Map(2 -> 1, 18 -> 9, 16 -> 8, 14 -> 7)))) + + checkAnswer(dfExample2.selectExpr("transform_keys(j, " + + "(k, v) -> map_from_arrays(ARRAY(1, 2, 3), ARRAY('one', 'two', 'three'))[k])"), + Seq(Row(Map("one" -> 1.0, "two" -> 1.4, "three" -> 1.7)))) + + checkAnswer(dfExample2.selectExpr("transform_keys(j, (k, v) -> CAST(v * 2 AS BIGINT) + k)"), + Seq(Row(Map(3 -> 1.0, 4 -> 1.4, 6 -> 1.7)))) + + checkAnswer(dfExample2.selectExpr("transform_keys(j, (k, v) -> k + v)"), + Seq(Row(Map(2.0 -> 1.0, 3.4 -> 1.4, 4.7 -> 1.7)))) + + checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> k % 2 = 0 OR v)"), + Seq(Row(Map(true -> true, true -> false)))) + + checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> if(v, 2 * k, 3 * k))"), + Seq(Row(Map(50 -> true, 78 -> false)))) + + checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> if(v, 2 * k, 3 * k))"), + Seq(Row(Map(50 -> true, 78 -> false)))) + + checkAnswer(dfExample4.selectExpr("transform_keys(y, (k, v) -> array_contains(k, 3) AND v)"), + Seq(Row(Map(false -> false)))) + } + // Test with 
local relation, the Project will be evaluated without codegen + testMapOfPrimitiveTypesCombination() + dfExample1.cache() + dfExample2.cache() + dfExample3.cache() + dfExample4.cache() + // Test with cached relation, the Project will be evaluated with codegen + testMapOfPrimitiveTypesCombination() + } + + test("transform keys function - Invalid lambda functions and exceptions") { + + val dfExample1 = Seq( + Map[String, String]("a" -> null) + ).toDF("i") + + val dfExample2 = Seq( + Seq(1, 2, 3, 4) + ).toDF("j") + + val ex1 = intercept[AnalysisException] { + dfExample1.selectExpr("transform_keys(i, k -> k)") + } + assert(ex1.getMessage.contains("The number of lambda function arguments '1' does not match")) + + val ex2 = intercept[AnalysisException] { + dfExample1.selectExpr("transform_keys(i, (k, v, x) -> k + 1)") + } + assert(ex2.getMessage.contains( + "The number of lambda function arguments '3' does not match")) + + val ex3 = intercept[RuntimeException] { + dfExample1.selectExpr("transform_keys(i, (k, v) -> v)").show() + } + assert(ex3.getMessage.contains("Cannot use null as map key!")) + + val ex4 = intercept[AnalysisException] { + dfExample2.selectExpr("transform_keys(j, (k, v) -> k + 1)") + } + assert(ex4.getMessage.contains( + "data type mismatch: argument 1 requires map type")) + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org