Repository: spark Updated Branches: refs/heads/master e2ab7deae -> 42263fd0c
[SPARK-23938][SQL] Add map_zip_with function ## What changes were proposed in this pull request? This PR adds a new SQL function called `map_zip_with`. It merges the two given maps into a single map by applying a function to the pair of values with the same key. ## How was this patch tested? Added new tests to: - DataFrameFunctionsSuite.scala - HigherOrderFunctionsSuite.scala Closes #22017 from mn-mikke/SPARK-23938. Authored-by: Marek Novotny <mn.mi...@gmail.com> Signed-off-by: Takuya UESHIN <ues...@databricks.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/42263fd0 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/42263fd0 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/42263fd0 Branch: refs/heads/master Commit: 42263fd0cbdc86c68438515ac439a15033b8bbd2 Parents: e2ab7de Author: Marek Novotny <mn.mi...@gmail.com> Authored: Tue Aug 14 21:14:15 2018 +0900 Committer: Takuya UESHIN <ues...@databricks.com> Committed: Tue Aug 14 21:14:15 2018 +0900 ---------------------------------------------------------------------- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../sql/catalyst/analysis/TypeCoercion.scala | 25 +++ .../expressions/higherOrderFunctions.scala | 197 ++++++++++++++++++- .../expressions/HigherOrderFunctionsSuite.scala | 129 ++++++++++++ .../inputs/typeCoercion/native/mapZipWith.sql | 66 +++++++ .../typeCoercion/native/mapZipWith.sql.out | 142 +++++++++++++ .../spark/sql/DataFrameFunctionsSuite.scala | 64 ++++++ 7 files changed, 621 insertions(+), 3 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/42263fd0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 15543c9..cc2b758 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -446,6 +446,7 @@ object FunctionRegistry { expression[ArrayFilter]("filter"), expression[ArrayExists]("exists"), expression[ArrayAggregate]("aggregate"), + expression[MapZipWith]("map_zip_with"), CreateStruct.registryEntry, // misc functions http://git-wip-us.apache.org/repos/asf/spark/blob/42263fd0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 10d9ee5..288b635 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -54,6 +54,7 @@ object TypeCoercion { BooleanEquality :: FunctionArgumentConversion :: ConcatCoercion(conf) :: + MapZipWithCoercion :: EltCoercion(conf) :: CaseWhenCoercion :: IfCoercion :: @@ -763,6 +764,30 @@ object TypeCoercion { } /** + * Coerces key types of two different [[MapType]] arguments of the [[MapZipWith]] expression + * to a common type. + */ + object MapZipWithCoercion extends TypeCoercionRule { + override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + // Lambda function isn't resolved when the rule is executed. 
+ case m @ MapZipWith(left, right, function) if m.arguments.forall(a => a.resolved && + MapType.acceptsType(a.dataType)) && !m.leftKeyType.sameType(m.rightKeyType) => + findWiderTypeForTwo(m.leftKeyType, m.rightKeyType) match { + case Some(finalKeyType) if !Cast.forceNullable(m.leftKeyType, finalKeyType) && + !Cast.forceNullable(m.rightKeyType, finalKeyType) => + val newLeft = castIfNotSameType( + left, + MapType(finalKeyType, m.leftValueType, m.leftValueContainsNull)) + val newRight = castIfNotSameType( + right, + MapType(finalKeyType, m.rightValueType, m.rightValueContainsNull)) + MapZipWith(newLeft, newRight, function) + case _ => m + } + } + } + + /** * Coerces the types of [[Elt]] children to expected ones. * * If `spark.sql.function.eltOutputAsString` is false and all children types are binary, http://git-wip-us.apache.org/repos/asf/spark/blob/42263fd0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 5d1b8c4..22210f6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -22,11 +22,11 @@ import java.util.concurrent.atomic.AtomicReference import scala.collection.mutable import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import 
org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.array.ByteArrayMethods /** * A named lambda variable. @@ -496,3 +496,194 @@ case class ArrayAggregate( override def prettyName: String = "aggregate" } + +/** + * Merges two given maps into a single map by applying function to the pair of values with + * the same key. + */ +@ExpressionDescription( + usage = + """ + _FUNC_(map1, map2, function) - Merges two given maps into a single map by applying + function to the pair of values with the same key. For keys only presented in one map, + NULL will be passed as the value for the missing key. If an input map contains duplicated + keys, only the first entry of the duplicated key is passed into the lambda function. + """, + examples = """ + Examples: + > SELECT _FUNC_(map(1, 'a', 2, 'b'), map(1, 'x', 2, 'y'), (k, v1, v2) -> concat(v1, v2)); + {1:"ax",2:"by"} + """, + since = "2.4.0") +case class MapZipWith(left: Expression, right: Expression, function: Expression) + extends HigherOrderFunction with CodegenFallback { + + def functionForEval: Expression = functionsForEval.head + + @transient lazy val MapType(leftKeyType, leftValueType, leftValueContainsNull) = left.dataType + + @transient lazy val MapType(rightKeyType, rightValueType, rightValueContainsNull) = right.dataType + + @transient lazy val keyType = + TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(leftKeyType, rightKeyType).get + + @transient lazy val ordering = TypeUtils.getInterpretedOrdering(keyType) + + override def arguments: Seq[Expression] = left :: right :: Nil + + override def argumentTypes: Seq[AbstractDataType] = MapType :: MapType :: Nil + + override def functions: Seq[Expression] = function :: Nil + + override def functionTypes: Seq[AbstractDataType] = AnyDataType :: Nil + + override def nullable: Boolean = left.nullable || right.nullable + + 
override def dataType: DataType = MapType(keyType, function.dataType, function.nullable) + + override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): MapZipWith = { + val arguments = Seq((keyType, false), (leftValueType, true), (rightValueType, true)) + copy(function = f(function, arguments)) + } + + override def checkArgumentDataTypes(): TypeCheckResult = { + super.checkArgumentDataTypes() match { + case TypeCheckResult.TypeCheckSuccess => + if (leftKeyType.sameType(rightKeyType)) { + TypeUtils.checkForOrderingExpr(leftKeyType, s"function $prettyName") + } else { + TypeCheckResult.TypeCheckFailure(s"The input to function $prettyName should have " + + s"been two ${MapType.simpleString}s with compatible key types, but the key types are " + + s"[${leftKeyType.catalogString}, ${rightKeyType.catalogString}].") + } + case failure => failure + } + } + + override def checkInputDataTypes(): TypeCheckResult = checkArgumentDataTypes() + + override def eval(input: InternalRow): Any = { + val value1 = left.eval(input) + if (value1 == null) { + null + } else { + val value2 = right.eval(input) + if (value2 == null) { + null + } else { + nullSafeEval(input, value1, value2) + } + } + } + + @transient lazy val LambdaFunction(_, Seq( + keyVar: NamedLambdaVariable, + value1Var: NamedLambdaVariable, + value2Var: NamedLambdaVariable), + _) = function + + private def keyTypeSupportsEquals = keyType match { + case BinaryType => false + case _: AtomicType => true + case _ => false + } + + /** + * The function accepts two key arrays and returns a collection of keys with indexes + * to value arrays. Indexes are represented as an array of two items. This is a small + * optimization leveraging mutability of arrays. 
+ */ + @transient private lazy val getKeysWithValueIndexes: + (ArrayData, ArrayData) => mutable.Iterable[(Any, Array[Option[Int]])] = { + if (keyTypeSupportsEquals) { + getKeysWithIndexesFast + } else { + getKeysWithIndexesBruteForce + } + } + + private def assertSizeOfArrayBuffer(size: Int): Unit = { + if (size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + throw new RuntimeException(s"Unsuccessful try to zip maps with $size " + + s"unique keys due to exceeding the array size limit " + + s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.") + } + } + + private def getKeysWithIndexesFast(keys1: ArrayData, keys2: ArrayData) = { + val hashMap = new mutable.LinkedHashMap[Any, Array[Option[Int]]] + for((z, array) <- Array((0, keys1), (1, keys2))) { + var i = 0 + while (i < array.numElements()) { + val key = array.get(i, keyType) + hashMap.get(key) match { + case Some(indexes) => + if (indexes(z).isEmpty) { + indexes(z) = Some(i) + } + case None => + val indexes = Array[Option[Int]](None, None) + indexes(z) = Some(i) + hashMap.put(key, indexes) + } + i += 1 + } + } + hashMap + } + + private def getKeysWithIndexesBruteForce(keys1: ArrayData, keys2: ArrayData) = { + val arrayBuffer = new mutable.ArrayBuffer[(Any, Array[Option[Int]])] + for((z, array) <- Array((0, keys1), (1, keys2))) { + var i = 0 + while (i < array.numElements()) { + val key = array.get(i, keyType) + var found = false + var j = 0 + while (!found && j < arrayBuffer.size) { + val (bufferKey, indexes) = arrayBuffer(j) + if (ordering.equiv(bufferKey, key)) { + found = true + if(indexes(z).isEmpty) { + indexes(z) = Some(i) + } + } + j += 1 + } + if (!found) { + assertSizeOfArrayBuffer(arrayBuffer.size) + val indexes = Array[Option[Int]](None, None) + indexes(z) = Some(i) + arrayBuffer += Tuple2(key, indexes) + } + i += 1 + } + } + arrayBuffer + } + + private def nullSafeEval(inputRow: InternalRow, value1: Any, value2: Any): Any = { + val mapData1 = value1.asInstanceOf[MapData] + val mapData2 = 
value2.asInstanceOf[MapData] + val keysWithIndexes = getKeysWithValueIndexes(mapData1.keyArray(), mapData2.keyArray()) + val size = keysWithIndexes.size + val keys = new GenericArrayData(new Array[Any](size)) + val values = new GenericArrayData(new Array[Any](size)) + val valueData1 = mapData1.valueArray() + val valueData2 = mapData2.valueArray() + var i = 0 + for ((key, Array(index1, index2)) <- keysWithIndexes) { + val v1 = index1.map(valueData1.get(_, leftValueType)).getOrElse(null) + val v2 = index2.map(valueData2.get(_, rightValueType)).getOrElse(null) + keyVar.value.set(key) + value1Var.value.set(v1) + value2Var.value.set(v2) + keys.update(i, key) + values.update(i, functionForEval.eval(inputRow)) + i += 1 + } + new ArrayBasedMapData(keys, values) + } + + override def prettyName: String = "map_zip_with" +} http://git-wip-us.apache.org/repos/asf/spark/blob/42263fd0/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala index bc7d04c..3137dc9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -44,6 +44,21 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper LambdaFunction(function, Seq(lv1, lv2)) } + private def createLambda( + dt1: DataType, + nullable1: Boolean, + dt2: DataType, + nullable2: Boolean, + dt3: DataType, + nullable3: Boolean, + f: (Expression, Expression, Expression) => Expression): Expression = { + val lv1 = NamedLambdaVariable("arg1", dt1, nullable1) + val lv2 = NamedLambdaVariable("arg2", dt2, nullable2) + val 
lv3 = NamedLambdaVariable("arg3", dt3, nullable3) + val function = f(lv1, lv2, lv3) + LambdaFunction(function, Seq(lv1, lv2, lv3)) + } + def transform(expr: Expression, f: Expression => Expression): Expression = { val at = expr.dataType.asInstanceOf[ArrayType] ArrayTransform(expr, createLambda(at.elementType, at.containsNull, f)) @@ -267,4 +282,118 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper (acc, array) => coalesce(aggregate(array, acc, (acc, elem) => acc + elem), acc)), 15) } + + test("MapZipWith") { + def map_zip_with( + left: Expression, + right: Expression, + f: (Expression, Expression, Expression) => Expression): Expression = { + val MapType(kt, vt1, vcn1) = left.dataType.asInstanceOf[MapType] + val MapType(_, vt2, vcn2) = right.dataType.asInstanceOf[MapType] + MapZipWith(left, right, createLambda(kt, false, vt1, vcn1, vt2, vcn2, f)) + } + + val mii0 = Literal.create(Map(1 -> 10, 2 -> 20, 3 -> 30), + MapType(IntegerType, IntegerType, valueContainsNull = false)) + val mii1 = Literal.create(Map(1 -> -1, 2 -> -2, 4 -> -4), + MapType(IntegerType, IntegerType, valueContainsNull = false)) + val mii2 = Literal.create(Map(1 -> null, 2 -> -2, 3 -> null), + MapType(IntegerType, IntegerType, valueContainsNull = true)) + val mii3 = Literal.create(Map(), MapType(IntegerType, IntegerType, valueContainsNull = false)) + val mii4 = MapFromArrays( + Literal.create(Seq(2, 2), ArrayType(IntegerType, false)), + Literal.create(Seq(20, 200), ArrayType(IntegerType, false))) + val miin = Literal.create(null, MapType(IntegerType, IntegerType, valueContainsNull = false)) + + val multiplyKeyWithValues: (Expression, Expression, Expression) => Expression = { + (k, v1, v2) => k * v1 * v2 + } + + checkEvaluation( + map_zip_with(mii0, mii1, multiplyKeyWithValues), + Map(1 -> -10, 2 -> -80, 3 -> null, 4 -> null)) + checkEvaluation( + map_zip_with(mii0, mii2, multiplyKeyWithValues), + Map(1 -> null, 2 -> -80, 3 -> null)) + checkEvaluation( + 
map_zip_with(mii0, mii3, multiplyKeyWithValues), + Map(1 -> null, 2 -> null, 3 -> null)) + checkEvaluation( + map_zip_with(mii0, mii4, multiplyKeyWithValues), + Map(1 -> null, 2 -> 800, 3 -> null)) + checkEvaluation( + map_zip_with(mii4, mii0, multiplyKeyWithValues), + Map(2 -> 800, 1 -> null, 3 -> null)) + checkEvaluation( + map_zip_with(mii0, miin, multiplyKeyWithValues), + null) + + val mss0 = Literal.create(Map("a" -> "x", "b" -> "y", "d" -> "z"), + MapType(StringType, StringType, valueContainsNull = false)) + val mss1 = Literal.create(Map("d" -> "b", "b" -> "d"), + MapType(StringType, StringType, valueContainsNull = false)) + val mss2 = Literal.create(Map("c" -> null, "b" -> "t", "a" -> null), + MapType(StringType, StringType, valueContainsNull = true)) + val mss3 = Literal.create(Map(), MapType(StringType, StringType, valueContainsNull = false)) + val mss4 = MapFromArrays( + Literal.create(Seq("a", "a"), ArrayType(StringType, false)), + Literal.create(Seq("a", "n"), ArrayType(StringType, false))) + val mssn = Literal.create(null, MapType(StringType, StringType, valueContainsNull = false)) + + val concat: (Expression, Expression, Expression) => Expression = { + (k, v1, v2) => Concat(Seq(k, v1, v2)) + } + + checkEvaluation( + map_zip_with(mss0, mss1, concat), + Map("a" -> null, "b" -> "byd", "d" -> "dzb")) + checkEvaluation( + map_zip_with(mss1, mss2, concat), + Map("d" -> null, "b" -> "bdt", "c" -> null, "a" -> null)) + checkEvaluation( + map_zip_with(mss0, mss3, concat), + Map("a" -> null, "b" -> null, "d" -> null)) + checkEvaluation( + map_zip_with(mss0, mss4, concat), + Map("a" -> "axa", "b" -> null, "d" -> null)) + checkEvaluation( + map_zip_with(mss4, mss0, concat), + Map("a" -> "aax", "b" -> null, "d" -> null)) + checkEvaluation( + map_zip_with(mss0, mssn, concat), + null) + + def b(data: Byte*): Array[Byte] = Array[Byte](data: _*) + + val mbb0 = Literal.create(Map(b(1, 2) -> b(4), b(2, 1) -> b(5), b(1, 3) -> b(8)), + MapType(BinaryType, BinaryType, 
valueContainsNull = false)) + val mbb1 = Literal.create(Map(b(2, 1) -> b(7), b(1, 2) -> b(3), b(1, 1) -> b(6)), + MapType(BinaryType, BinaryType, valueContainsNull = false)) + val mbb2 = Literal.create(Map(b(1, 3) -> null, b(1, 2) -> b(2), b(2, 1) -> null), + MapType(BinaryType, BinaryType, valueContainsNull = true)) + val mbb3 = Literal.create(Map(), MapType(BinaryType, BinaryType, valueContainsNull = false)) + val mbb4 = MapFromArrays( + Literal.create(Seq(b(2, 1), b(2, 1)), ArrayType(BinaryType, false)), + Literal.create(Seq(b(1), b(9)), ArrayType(BinaryType, false))) + val mbbn = Literal.create(null, MapType(BinaryType, BinaryType, valueContainsNull = false)) + + checkEvaluation( + map_zip_with(mbb0, mbb1, concat), + Map(b(1, 2) -> b(1, 2, 4, 3), b(2, 1) -> b(2, 1, 5, 7), b(1, 3) -> null, b(1, 1) -> null)) + checkEvaluation( + map_zip_with(mbb1, mbb2, concat), + Map(b(2, 1) -> null, b(1, 2) -> b(1, 2, 3, 2), b(1, 1) -> null, b(1, 3) -> null)) + checkEvaluation( + map_zip_with(mbb0, mbb3, concat), + Map(b(1, 2) -> null, b(2, 1) -> null, b(1, 3) -> null)) + checkEvaluation( + map_zip_with(mbb0, mbb4, concat), + Map(b(1, 2) -> null, b(2, 1) -> b(2, 1, 5, 1), b(1, 3) -> null)) + checkEvaluation( + map_zip_with(mbb4, mbb0, concat), + Map(b(2, 1) -> b(2, 1, 1, 5), b(1, 2) -> null, b(1, 3) -> null)) + checkEvaluation( + map_zip_with(mbb0, mbbn, concat), + null) + } } http://git-wip-us.apache.org/repos/asf/spark/blob/42263fd0/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapZipWith.sql ---------------------------------------------------------------------- diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapZipWith.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapZipWith.sql new file mode 100644 index 0000000..119f868 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapZipWith.sql @@ -0,0 +1,66 @@ +CREATE TEMPORARY VIEW various_maps AS SELECT * FROM VALUES ( + 
map(true, false), + map(2Y, 1Y), + map(2S, 1S), + map(2, 1), + map(2L, 1L), + map(922337203685477897945456575809789456, 922337203685477897945456575809789456), + map(9.22337203685477897945456575809789456, 9.22337203685477897945456575809789456), + map(2.0D, 1.0D), + map(float(2.0), float(1.0)), + map(date '2016-03-14', date '2016-03-13'), + map(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'), + map('true', 'false', '2', '1'), + map('2016-03-14', '2016-03-13'), + map('2016-11-15 20:54:00.000', '2016-11-12 20:54:00.000'), + map('922337203685477897945456575809789456', 'text'), + map(array(1L, 2L), array(1L, 2L)), map(array(1, 2), array(1, 2)), + map(struct(1S, 2L), struct(1S, 2L)), map(struct(1, 2), struct(1, 2)) +) AS various_maps( + boolean_map, + tinyint_map, + smallint_map, + int_map, + bigint_map, + decimal_map1, decimal_map2, + double_map, + float_map, + date_map, + timestamp_map, + string_map1, string_map2, string_map3, string_map4, + array_map1, array_map2, + struct_map1, struct_map2 +); + +SELECT map_zip_with(tinyint_map, smallint_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(smallint_map, int_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(int_map, bigint_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(double_map, float_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(decimal_map1, decimal_map2, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(string_map1, int_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(string_map2, date_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(timestamp_map, string_map3, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(decimal_map1, string_map4, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(array_map1, 
array_map2, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(struct_map1, struct_map2, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; http://git-wip-us.apache.org/repos/asf/spark/blob/42263fd0/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out ---------------------------------------------------------------------- diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out new file mode 100644 index 0000000..7f7e2f0 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out @@ -0,0 +1,142 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 12 + + +-- !query 0 +CREATE TEMPORARY VIEW various_maps AS SELECT * FROM VALUES ( + map(true, false), + map(2Y, 1Y), + map(2S, 1S), + map(2, 1), + map(2L, 1L), + map(922337203685477897945456575809789456, 922337203685477897945456575809789456), + map(9.22337203685477897945456575809789456, 9.22337203685477897945456575809789456), + map(2.0D, 1.0D), + map(float(2.0), float(1.0)), + map(date '2016-03-14', date '2016-03-13'), + map(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'), + map('true', 'false', '2', '1'), + map('2016-03-14', '2016-03-13'), + map('2016-11-15 20:54:00.000', '2016-11-12 20:54:00.000'), + map('922337203685477897945456575809789456', 'text'), + map(array(1L, 2L), array(1L, 2L)), map(array(1, 2), array(1, 2)), + map(struct(1S, 2L), struct(1S, 2L)), map(struct(1, 2), struct(1, 2)) +) AS various_maps( + boolean_map, + tinyint_map, + smallint_map, + int_map, + bigint_map, + decimal_map1, decimal_map2, + double_map, + float_map, + date_map, + timestamp_map, + string_map1, string_map2, string_map3, string_map4, + array_map1, array_map2, + struct_map1, struct_map2 +) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 
+SELECT map_zip_with(tinyint_map, smallint_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 1 schema +struct<m:map<smallint,struct<k:smallint,v1:tinyint,v2:smallint>>> +-- !query 1 output +{2:{"k":2,"v1":1,"v2":1}} + + +-- !query 2 +SELECT map_zip_with(smallint_map, int_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 2 schema +struct<m:map<int,struct<k:int,v1:smallint,v2:int>>> +-- !query 2 output +{2:{"k":2,"v1":1,"v2":1}} + + +-- !query 3 +SELECT map_zip_with(int_map, bigint_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 3 schema +struct<m:map<bigint,struct<k:bigint,v1:int,v2:bigint>>> +-- !query 3 output +{2:{"k":2,"v1":1,"v2":1}} + + +-- !query 4 +SELECT map_zip_with(double_map, float_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 4 schema +struct<m:map<double,struct<k:double,v1:double,v2:float>>> +-- !query 4 output +{2.0:{"k":2.0,"v1":1.0,"v2":1.0}} + + +-- !query 5 +SELECT map_zip_with(decimal_map1, decimal_map2, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 5 schema +struct<> +-- !query 5 output +org.apache.spark.sql.AnalysisException +cannot resolve 'map_zip_with(various_maps.`decimal_map1`, various_maps.`decimal_map2`, lambdafunction(named_struct(NamePlaceholder(), `k`, NamePlaceholder(), `v1`, NamePlaceholder(), `v2`), `k`, `v1`, `v2`))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,0), decimal(36,35)].; line 1 pos 7 + + +-- !query 6 +SELECT map_zip_with(string_map1, int_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 6 schema +struct<m:map<string,struct<k:string,v1:string,v2:int>>> +-- !query 6 output +{"2":{"k":"2","v1":"1","v2":1},"true":{"k":"true","v1":"false","v2":null}} + + +-- !query 7 +SELECT map_zip_with(string_map2, date_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 7 schema 
+struct<m:map<string,struct<k:string,v1:string,v2:date>>> +-- !query 7 output +{"2016-03-14":{"k":"2016-03-14","v1":"2016-03-13","v2":2016-03-13}} + + +-- !query 8 +SELECT map_zip_with(timestamp_map, string_map3, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 8 schema +struct<m:map<string,struct<k:string,v1:timestamp,v2:string>>> +-- !query 8 output +{"2016-11-15 20:54:00":{"k":"2016-11-15 20:54:00","v1":2016-11-12 20:54:00.0,"v2":null},"2016-11-15 20:54:00.000":{"k":"2016-11-15 20:54:00.000","v1":null,"v2":"2016-11-12 20:54:00.000"}} + + +-- !query 9 +SELECT map_zip_with(decimal_map1, string_map4, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 9 schema +struct<m:map<string,struct<k:string,v1:decimal(36,0),v2:string>>> +-- !query 9 output +{"922337203685477897945456575809789456":{"k":"922337203685477897945456575809789456","v1":922337203685477897945456575809789456,"v2":"text"}} + + +-- !query 10 +SELECT map_zip_with(array_map1, array_map2, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 10 schema +struct<m:map<array<bigint>,struct<k:array<bigint>,v1:array<bigint>,v2:array<int>>>> +-- !query 10 output +{[1,2]:{"k":[1,2],"v1":[1,2],"v2":[1,2]}} + + +-- !query 11 +SELECT map_zip_with(struct_map1, struct_map2, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 11 schema +struct<m:map<struct<col1:int,col2:bigint>,struct<k:struct<col1:int,col2:bigint>,v1:struct<col1:smallint,col2:bigint>,v2:struct<col1:int,col2:int>>>> +-- !query 11 output +{{"col1":1,"col2":2}:{"k":{"col1":1,"col2":2},"v1":{"col1":1,"col2":2},"v2":{"col1":1,"col2":2}}} http://git-wip-us.apache.org/repos/asf/spark/blob/42263fd0/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala 
index 6401e3f..8d7695b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2238,6 +2238,70 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(ex5.getMessage.contains("cannot resolve '`a`'")) } + test("map_zip_with function - map of primitive types") { + val df = Seq( + (Map(8 -> 6L, 3 -> 5L, 6 -> 2L), Map[Integer, Integer]((6, 4), (8, 2), (3, 2))), + (Map(10 -> 6L, 8 -> 3L), Map[Integer, Integer]((8, 4), (4, null))), + (Map.empty[Int, Long], Map[Integer, Integer]((5, 1))), + (Map(5 -> 1L), null) + ).toDF("m1", "m2") + + checkAnswer(df.selectExpr("map_zip_with(m1, m2, (k, v1, v2) -> k == v1 + v2)"), + Seq( + Row(Map(8 -> true, 3 -> false, 6 -> true)), + Row(Map(10 -> null, 8 -> false, 4 -> null)), + Row(Map(5 -> null)), + Row(null))) + } + + test("map_zip_with function - map of non-primitive types") { + val df = Seq( + (Map("z" -> "a", "y" -> "b", "x" -> "c"), Map("x" -> "a", "z" -> "c")), + (Map("b" -> "a", "c" -> "d"), Map("c" -> "a", "b" -> null, "d" -> "k")), + (Map("a" -> "d"), Map.empty[String, String]), + (Map("a" -> "d"), null) + ).toDF("m1", "m2") + + checkAnswer(df.selectExpr("map_zip_with(m1, m2, (k, v1, v2) -> (v1, v2))"), + Seq( + Row(Map("z" -> Row("a", "c"), "y" -> Row("b", null), "x" -> Row("c", "a"))), + Row(Map("b" -> Row("a", null), "c" -> Row("d", "a"), "d" -> Row(null, "k"))), + Row(Map("a" -> Row("d", null))), + Row(null))) + } + + test("map_zip_with function - invalid") { + val df = Seq( + (Map(1 -> 2), Map(1 -> "a"), Map("a" -> "b"), Map(Map(1 -> 2) -> 2), 1) + ).toDF("mii", "mis", "mss", "mmi", "i") + + val ex1 = intercept[AnalysisException] { + df.selectExpr("map_zip_with(mii, mis, (x, y) -> x + y)") + } + assert(ex1.getMessage.contains("The number of lambda function arguments '2' does not match")) + + val ex2 = intercept[AnalysisException] { + df.selectExpr("map_zip_with(mis, mmi, (x, y, z) 
-> concat(x, y, z))") + } + assert(ex2.getMessage.contains("The input to function map_zip_with should have " + + "been two maps with compatible key types")) + + val ex3 = intercept[AnalysisException] { + df.selectExpr("map_zip_with(i, mis, (x, y, z) -> concat(x, y, z))") + } + assert(ex3.getMessage.contains("type mismatch: argument 1 requires map type")) + + val ex4 = intercept[AnalysisException] { + df.selectExpr("map_zip_with(mis, i, (x, y, z) -> concat(x, y, z))") + } + assert(ex4.getMessage.contains("type mismatch: argument 2 requires map type")) + + val ex5 = intercept[AnalysisException] { + df.selectExpr("map_zip_with(mmi, mmi, (x, y, z) -> x)") + } + assert(ex5.getMessage.contains("function map_zip_with does not support ordering on type map")) + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org