Repository: spark
Updated Branches:
  refs/heads/master e2ab7deae -> 42263fd0c


[SPARK-23938][SQL] Add map_zip_with function

## What changes were proposed in this pull request?

This PR adds a new SQL function called ```map_zip_with```. It merges the two
given maps into a single map by applying a function to the pair of values with
the same key.
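
To make the semantics concrete, here is a small illustrative example (not part
of the patch itself; it simply follows the behavior described above, where NULL
is passed for a key missing on one side):

```sql
-- Illustrative usage of the new function.
SELECT map_zip_with(
  map(1, 'a', 2, 'b'),
  map(1, 'x', 3, 'z'),
  (k, v1, v2) -> concat(coalesce(v1, '_'), coalesce(v2, '_')));
-- Result: {1:"ax", 2:"b_", 3:"_z"}
```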

## How was this patch tested?

Added new tests into:
- DataFrameFunctionsSuite.scala
- HigherOrderFunctionsSuite.scala

Closes #22017 from mn-mikke/SPARK-23938.

Authored-by: Marek Novotny <mn.mi...@gmail.com>
Signed-off-by: Takuya UESHIN <ues...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/42263fd0
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/42263fd0
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/42263fd0

Branch: refs/heads/master
Commit: 42263fd0cbdc86c68438515ac439a15033b8bbd2
Parents: e2ab7de
Author: Marek Novotny <mn.mi...@gmail.com>
Authored: Tue Aug 14 21:14:15 2018 +0900
Committer: Takuya UESHIN <ues...@databricks.com>
Committed: Tue Aug 14 21:14:15 2018 +0900

----------------------------------------------------------------------
 .../catalyst/analysis/FunctionRegistry.scala    |   1 +
 .../sql/catalyst/analysis/TypeCoercion.scala    |  25 +++
 .../expressions/higherOrderFunctions.scala      | 197 ++++++++++++++++++-
 .../expressions/HigherOrderFunctionsSuite.scala | 129 ++++++++++++
 .../inputs/typeCoercion/native/mapZipWith.sql   |  66 +++++++
 .../typeCoercion/native/mapZipWith.sql.out      | 142 +++++++++++++
 .../spark/sql/DataFrameFunctionsSuite.scala     |  64 ++++++
 7 files changed, 621 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/42263fd0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 15543c9..cc2b758 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -446,6 +446,7 @@ object FunctionRegistry {
     expression[ArrayFilter]("filter"),
     expression[ArrayExists]("exists"),
     expression[ArrayAggregate]("aggregate"),
+    expression[MapZipWith]("map_zip_with"),
     CreateStruct.registryEntry,
 
     // misc functions

http://git-wip-us.apache.org/repos/asf/spark/blob/42263fd0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 10d9ee5..288b635 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -54,6 +54,7 @@ object TypeCoercion {
       BooleanEquality ::
       FunctionArgumentConversion ::
       ConcatCoercion(conf) ::
+      MapZipWithCoercion ::
       EltCoercion(conf) ::
       CaseWhenCoercion ::
       IfCoercion ::
@@ -763,6 +764,30 @@ object TypeCoercion {
   }
 
   /**
+   * Coerces key types of two different [[MapType]] arguments of the [[MapZipWith]] expression
+   * to a common type.
+   */
+  object MapZipWithCoercion extends TypeCoercionRule {
+    override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
+      // Lambda function isn't resolved when the rule is executed.
+      case m @ MapZipWith(left, right, function) if m.arguments.forall(a => a.resolved &&
+          MapType.acceptsType(a.dataType)) && !m.leftKeyType.sameType(m.rightKeyType) =>
+        findWiderTypeForTwo(m.leftKeyType, m.rightKeyType) match {
+          case Some(finalKeyType) if !Cast.forceNullable(m.leftKeyType, finalKeyType) &&
+              !Cast.forceNullable(m.rightKeyType, finalKeyType) =>
+            val newLeft = castIfNotSameType(
+              left,
+              MapType(finalKeyType, m.leftValueType, m.leftValueContainsNull))
+            val newRight = castIfNotSameType(
+              right,
+              MapType(finalKeyType, m.rightValueType, m.rightValueContainsNull))
+            MapZipWith(newLeft, newRight, function)
+          case _ => m
+        }
+    }
+  }
+
+  /**
    * Coerces the types of [[Elt]] children to expected ones.
    *
   * If `spark.sql.function.eltOutputAsString` is false and all children types are binary,
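
For context, the effect of the new rule can be sketched with a short query (the
behavior is exercised by the new mapZipWith.sql golden file below): key types
that differ but share a safe wider type are widened before zipping.

```sql
-- Illustrative sketch; mirrors query 1 of the new mapZipWith.sql tests.
SELECT map_zip_with(map(2Y, 1Y), map(2S, 1S), (k, v1, v2) -> struct(k, v1, v2)) m;
-- Resulting schema: m: map<smallint,struct<k:smallint,v1:tinyint,v2:smallint>>
```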

http://git-wip-us.apache.org/repos/asf/spark/blob/42263fd0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
index 5d1b8c4..22210f6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
@@ -22,11 +22,11 @@ import java.util.concurrent.atomic.AtomicReference
 import scala.collection.mutable
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute}
 import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
+import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.array.ByteArrayMethods
 
 /**
  * A named lambda variable.
@@ -496,3 +496,194 @@ case class ArrayAggregate(
 
   override def prettyName: String = "aggregate"
 }
+
+/**
+ * Merges two given maps into a single map by applying a function to the pair of values with
+ * the same key.
+ */
+@ExpressionDescription(
+  usage =
+    """
+      _FUNC_(map1, map2, function) - Merges two given maps into a single map by applying
+      the function to the pair of values with the same key. For keys only present in one map,
+      NULL will be passed as the value for the missing key. If an input map contains duplicated
+      keys, only the first entry of the duplicated key is passed into the lambda function.
+    """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_(map(1, 'a', 2, 'b'), map(1, 'x', 2, 'y'), (k, v1, v2) -> concat(v1, v2));
+       {1:"ax",2:"by"}
+  """,
+  since = "2.4.0")
case class MapZipWith(left: Expression, right: Expression, function: Expression)
+  extends HigherOrderFunction with CodegenFallback {
+
+  def functionForEval: Expression = functionsForEval.head
+
+  @transient lazy val MapType(leftKeyType, leftValueType, leftValueContainsNull) = left.dataType
+
+  @transient lazy val MapType(rightKeyType, rightValueType, rightValueContainsNull) = right.dataType
+
+  @transient lazy val keyType =
+    TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(leftKeyType, rightKeyType).get
+
+  @transient lazy val ordering = TypeUtils.getInterpretedOrdering(keyType)
+
+  override def arguments: Seq[Expression] = left :: right :: Nil
+
+  override def argumentTypes: Seq[AbstractDataType] = MapType :: MapType :: Nil
+
+  override def functions: Seq[Expression] = function :: Nil
+
+  override def functionTypes: Seq[AbstractDataType] = AnyDataType :: Nil
+
+  override def nullable: Boolean = left.nullable || right.nullable
+
+  override def dataType: DataType = MapType(keyType, function.dataType, function.nullable)
+
+  override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): MapZipWith = {
+    val arguments = Seq((keyType, false), (leftValueType, true), (rightValueType, true))
+    copy(function = f(function, arguments))
+  }
+
+  override def checkArgumentDataTypes(): TypeCheckResult = {
+    super.checkArgumentDataTypes() match {
+      case TypeCheckResult.TypeCheckSuccess =>
+        if (leftKeyType.sameType(rightKeyType)) {
+          TypeUtils.checkForOrderingExpr(leftKeyType, s"function $prettyName")
+        } else {
+          TypeCheckResult.TypeCheckFailure(s"The input to function $prettyName should have " +
+            s"been two ${MapType.simpleString}s with compatible key types, but the key types are " +
+            s"[${leftKeyType.catalogString}, ${rightKeyType.catalogString}].")
+        }
+      case failure => failure
+    }
+  }
+
+  override def checkInputDataTypes(): TypeCheckResult = checkArgumentDataTypes()
+
+  override def eval(input: InternalRow): Any = {
+    val value1 = left.eval(input)
+    if (value1 == null) {
+      null
+    } else {
+      val value2 = right.eval(input)
+      if (value2 == null) {
+        null
+      } else {
+        nullSafeEval(input, value1, value2)
+      }
+    }
+  }
+
+  @transient lazy val LambdaFunction(_, Seq(
+    keyVar: NamedLambdaVariable,
+    value1Var: NamedLambdaVariable,
+    value2Var: NamedLambdaVariable),
+    _) = function
+
+  private def keyTypeSupportsEquals = keyType match {
+    case BinaryType => false
+    case _: AtomicType => true
+    case _ => false
+  }
+
+  /**
+   * The function accepts two key arrays and returns a collection of keys with indexes
+   * to value arrays. Indexes are represented as an array of two items. This is a small
+   * optimization leveraging mutability of arrays.
+   */
+  @transient private lazy val getKeysWithValueIndexes:
+      (ArrayData, ArrayData) => mutable.Iterable[(Any, Array[Option[Int]])] = {
+    if (keyTypeSupportsEquals) {
+      getKeysWithIndexesFast
+    } else {
+      getKeysWithIndexesBruteForce
+    }
+  }
+
+  private def assertSizeOfArrayBuffer(size: Int): Unit = {
+    if (size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
+      throw new RuntimeException(s"Unsuccessful try to zip maps with $size " +
+        s"unique keys due to exceeding the array size limit " +
+        s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.")
+    }
+  }
+
+  private def getKeysWithIndexesFast(keys1: ArrayData, keys2: ArrayData) = {
+    val hashMap = new mutable.LinkedHashMap[Any, Array[Option[Int]]]
+    for((z, array) <- Array((0, keys1), (1, keys2))) {
+      var i = 0
+      while (i < array.numElements()) {
+        val key = array.get(i, keyType)
+        hashMap.get(key) match {
+          case Some(indexes) =>
+            if (indexes(z).isEmpty) {
+              indexes(z) = Some(i)
+            }
+          case None =>
+            val indexes = Array[Option[Int]](None, None)
+            indexes(z) = Some(i)
+            hashMap.put(key, indexes)
+        }
+        i += 1
+      }
+    }
+    hashMap
+  }
+
+  private def getKeysWithIndexesBruteForce(keys1: ArrayData, keys2: ArrayData) = {
+    val arrayBuffer = new mutable.ArrayBuffer[(Any, Array[Option[Int]])]
+    for((z, array) <- Array((0, keys1), (1, keys2))) {
+      var i = 0
+      while (i < array.numElements()) {
+        val key = array.get(i, keyType)
+        var found = false
+        var j = 0
+        while (!found && j < arrayBuffer.size) {
+          val (bufferKey, indexes) = arrayBuffer(j)
+          if (ordering.equiv(bufferKey, key)) {
+            found = true
+            if(indexes(z).isEmpty) {
+              indexes(z) = Some(i)
+            }
+          }
+          j += 1
+        }
+        if (!found) {
+          assertSizeOfArrayBuffer(arrayBuffer.size)
+          val indexes = Array[Option[Int]](None, None)
+          indexes(z) = Some(i)
+          arrayBuffer += Tuple2(key, indexes)
+        }
+        i += 1
+      }
+    }
+    arrayBuffer
+  }
+
+  private def nullSafeEval(inputRow: InternalRow, value1: Any, value2: Any): Any = {
+    val mapData1 = value1.asInstanceOf[MapData]
+    val mapData2 = value2.asInstanceOf[MapData]
+    val keysWithIndexes = getKeysWithValueIndexes(mapData1.keyArray(), mapData2.keyArray())
+    val size = keysWithIndexes.size
+    val keys = new GenericArrayData(new Array[Any](size))
+    val values = new GenericArrayData(new Array[Any](size))
+    val valueData1 = mapData1.valueArray()
+    val valueData2 = mapData2.valueArray()
+    var i = 0
+    for ((key, Array(index1, index2)) <- keysWithIndexes) {
+      val v1 = index1.map(valueData1.get(_, leftValueType)).getOrElse(null)
+      val v2 = index2.map(valueData2.get(_, rightValueType)).getOrElse(null)
+      keyVar.value.set(key)
+      value1Var.value.set(v1)
+      value2Var.value.set(v2)
+      keys.update(i, key)
+      values.update(i, functionForEval.eval(inputRow))
+      i += 1
+    }
+    new ArrayBasedMapData(keys, values)
+  }
+
+  override def prettyName: String = "map_zip_with"
+}
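
Aside: the key-matching strategy above can be summarized by the following
standalone sketch in plain Scala collections rather than Spark's ArrayData
(the helper name is invented for illustration; it corresponds to the
hash-based fast path, while the brute-force path does the same using
ordering.equiv for key types without reliable equals):

```scala
import scala.collection.mutable

// For each distinct key, record the first index at which it occurs on each
// side (None if the key is absent from that side). Insertion order is kept,
// matching the LinkedHashMap used in the patch.
def keysWithIndexes[K](keys1: Seq[K], keys2: Seq[K]): mutable.LinkedHashMap[K, Array[Option[Int]]] = {
  val result = new mutable.LinkedHashMap[K, Array[Option[Int]]]
  for ((keys, side) <- Seq((keys1, 0), (keys2, 1)); (key, i) <- keys.zipWithIndex) {
    val indexes = result.getOrElseUpdate(key, Array[Option[Int]](None, None))
    // Only the first occurrence of a duplicated key is kept.
    if (indexes(side).isEmpty) indexes(side) = Some(i)
  }
  result
}

// keysWithIndexes(Seq(1, 2, 2), Seq(2, 3)) yields, in first-seen order:
//   1 -> [Some(0), None], 2 -> [Some(1), Some(0)], 3 -> [None, Some(1)]
```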

http://git-wip-us.apache.org/repos/asf/spark/blob/42263fd0/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
index bc7d04c..3137dc9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
@@ -44,6 +44,21 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
     LambdaFunction(function, Seq(lv1, lv2))
   }
 
+  private def createLambda(
+      dt1: DataType,
+      nullable1: Boolean,
+      dt2: DataType,
+      nullable2: Boolean,
+      dt3: DataType,
+      nullable3: Boolean,
+      f: (Expression, Expression, Expression) => Expression): Expression = {
+    val lv1 = NamedLambdaVariable("arg1", dt1, nullable1)
+    val lv2 = NamedLambdaVariable("arg2", dt2, nullable2)
+    val lv3 = NamedLambdaVariable("arg3", dt3, nullable3)
+    val function = f(lv1, lv2, lv3)
+    LambdaFunction(function, Seq(lv1, lv2, lv3))
+  }
+
   def transform(expr: Expression, f: Expression => Expression): Expression = {
     val at = expr.dataType.asInstanceOf[ArrayType]
     ArrayTransform(expr, createLambda(at.elementType, at.containsNull, f))
@@ -267,4 +282,118 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
         (acc, array) => coalesce(aggregate(array, acc, (acc, elem) => acc + elem), acc)),
       15)
   }
+
+  test("MapZipWith") {
+    def map_zip_with(
+        left: Expression,
+        right: Expression,
+        f: (Expression, Expression, Expression) => Expression): Expression = {
+      val MapType(kt, vt1, vcn1) = left.dataType.asInstanceOf[MapType]
+      val MapType(_, vt2, vcn2) = right.dataType.asInstanceOf[MapType]
+      MapZipWith(left, right, createLambda(kt, false, vt1, vcn1, vt2, vcn2, f))
+    }
+
+    val mii0 = Literal.create(Map(1 -> 10, 2 -> 20, 3 -> 30),
+      MapType(IntegerType, IntegerType, valueContainsNull = false))
+    val mii1 = Literal.create(Map(1 -> -1, 2 -> -2, 4 -> -4),
+      MapType(IntegerType, IntegerType, valueContainsNull = false))
+    val mii2 = Literal.create(Map(1 -> null, 2 -> -2, 3 -> null),
+      MapType(IntegerType, IntegerType, valueContainsNull = true))
+    val mii3 = Literal.create(Map(), MapType(IntegerType, IntegerType, valueContainsNull = false))
+    val mii4 = MapFromArrays(
+      Literal.create(Seq(2, 2), ArrayType(IntegerType, false)),
+      Literal.create(Seq(20, 200), ArrayType(IntegerType, false)))
+    val miin = Literal.create(null, MapType(IntegerType, IntegerType, valueContainsNull = false))
+
+    val multiplyKeyWithValues: (Expression, Expression, Expression) => Expression = {
+      (k, v1, v2) => k * v1 * v2
+    }
+
+    checkEvaluation(
+      map_zip_with(mii0, mii1, multiplyKeyWithValues),
+      Map(1 -> -10, 2 -> -80, 3 -> null, 4 -> null))
+    checkEvaluation(
+      map_zip_with(mii0, mii2, multiplyKeyWithValues),
+      Map(1 -> null, 2 -> -80, 3 -> null))
+    checkEvaluation(
+      map_zip_with(mii0, mii3, multiplyKeyWithValues),
+      Map(1 -> null, 2 -> null, 3 -> null))
+    checkEvaluation(
+      map_zip_with(mii0, mii4, multiplyKeyWithValues),
+      Map(1 -> null, 2 -> 800, 3 -> null))
+    checkEvaluation(
+      map_zip_with(mii4, mii0, multiplyKeyWithValues),
+      Map(2 -> 800, 1 -> null, 3 -> null))
+    checkEvaluation(
+      map_zip_with(mii0, miin, multiplyKeyWithValues),
+      null)
+
+    val mss0 = Literal.create(Map("a" -> "x", "b" -> "y", "d" -> "z"),
+      MapType(StringType, StringType, valueContainsNull = false))
+    val mss1 = Literal.create(Map("d" -> "b", "b" -> "d"),
+      MapType(StringType, StringType, valueContainsNull = false))
+    val mss2 = Literal.create(Map("c" -> null, "b" -> "t", "a" -> null),
+      MapType(StringType, StringType, valueContainsNull = true))
+    val mss3 = Literal.create(Map(), MapType(StringType, StringType, valueContainsNull = false))
+    val mss4 = MapFromArrays(
+      Literal.create(Seq("a", "a"), ArrayType(StringType, false)),
+      Literal.create(Seq("a", "n"), ArrayType(StringType, false)))
+    val mssn = Literal.create(null, MapType(StringType, StringType, valueContainsNull = false))
+
+    val concat: (Expression, Expression, Expression) => Expression = {
+      (k, v1, v2) => Concat(Seq(k, v1, v2))
+    }
+
+    checkEvaluation(
+      map_zip_with(mss0, mss1, concat),
+      Map("a" -> null, "b" -> "byd", "d" -> "dzb"))
+    checkEvaluation(
+      map_zip_with(mss1, mss2, concat),
+      Map("d" -> null, "b" -> "bdt", "c" -> null, "a" -> null))
+    checkEvaluation(
+      map_zip_with(mss0, mss3, concat),
+      Map("a" -> null, "b" -> null, "d" -> null))
+    checkEvaluation(
+      map_zip_with(mss0, mss4, concat),
+      Map("a" -> "axa", "b" -> null, "d" -> null))
+    checkEvaluation(
+      map_zip_with(mss4, mss0, concat),
+      Map("a" -> "aax", "b" -> null, "d" -> null))
+    checkEvaluation(
+      map_zip_with(mss0, mssn, concat),
+      null)
+
+    def b(data: Byte*): Array[Byte] = Array[Byte](data: _*)
+
+    val mbb0 = Literal.create(Map(b(1, 2) -> b(4), b(2, 1) -> b(5), b(1, 3) -> b(8)),
+      MapType(BinaryType, BinaryType, valueContainsNull = false))
+    val mbb1 = Literal.create(Map(b(2, 1) -> b(7), b(1, 2) -> b(3), b(1, 1) -> b(6)),
+      MapType(BinaryType, BinaryType, valueContainsNull = false))
+    val mbb2 = Literal.create(Map(b(1, 3) -> null, b(1, 2) -> b(2), b(2, 1) -> null),
+      MapType(BinaryType, BinaryType, valueContainsNull = true))
+    val mbb3 = Literal.create(Map(), MapType(BinaryType, BinaryType, valueContainsNull = false))
+    val mbb4 = MapFromArrays(
+      Literal.create(Seq(b(2, 1), b(2, 1)), ArrayType(BinaryType, false)),
+      Literal.create(Seq(b(1), b(9)), ArrayType(BinaryType, false)))
+    val mbbn = Literal.create(null, MapType(BinaryType, BinaryType, valueContainsNull = false))
+
+    checkEvaluation(
+      map_zip_with(mbb0, mbb1, concat),
+      Map(b(1, 2) -> b(1, 2, 4, 3), b(2, 1) -> b(2, 1, 5, 7), b(1, 3) -> null, b(1, 1) -> null))
+    checkEvaluation(
+      map_zip_with(mbb1, mbb2, concat),
+      Map(b(2, 1) -> null, b(1, 2) -> b(1, 2, 3, 2), b(1, 1) -> null, b(1, 3) -> null))
+    checkEvaluation(
+      map_zip_with(mbb0, mbb3, concat),
+      Map(b(1, 2) -> null, b(2, 1) -> null, b(1, 3) -> null))
+    checkEvaluation(
+      map_zip_with(mbb0, mbb4, concat),
+      Map(b(1, 2) -> null, b(2, 1) -> b(2, 1, 5, 1), b(1, 3) -> null))
+    checkEvaluation(
+      map_zip_with(mbb4, mbb0, concat),
+      Map(b(2, 1) -> b(2, 1, 1, 5), b(1, 2) -> null, b(1, 3) -> null))
+    checkEvaluation(
+      map_zip_with(mbb0, mbbn, concat),
+      null)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/42263fd0/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapZipWith.sql
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapZipWith.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapZipWith.sql
new file mode 100644
index 0000000..119f868
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapZipWith.sql
@@ -0,0 +1,66 @@
+CREATE TEMPORARY VIEW various_maps AS SELECT * FROM VALUES (
+  map(true, false),
+  map(2Y, 1Y),
+  map(2S, 1S),
+  map(2, 1),
+  map(2L, 1L),
+  map(922337203685477897945456575809789456, 922337203685477897945456575809789456),
+  map(9.22337203685477897945456575809789456, 9.22337203685477897945456575809789456),
+  map(2.0D, 1.0D),
+  map(float(2.0), float(1.0)),
+  map(date '2016-03-14', date '2016-03-13'),
+  map(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'),
+  map('true', 'false', '2', '1'),
+  map('2016-03-14', '2016-03-13'),
+  map('2016-11-15 20:54:00.000', '2016-11-12 20:54:00.000'),
+  map('922337203685477897945456575809789456', 'text'),
+  map(array(1L, 2L), array(1L, 2L)), map(array(1, 2), array(1, 2)),
+  map(struct(1S, 2L), struct(1S, 2L)), map(struct(1, 2), struct(1, 2))
+) AS various_maps(
+  boolean_map,
+  tinyint_map,
+  smallint_map,
+  int_map,
+  bigint_map,
+  decimal_map1, decimal_map2,
+  double_map,
+  float_map,
+  date_map,
+  timestamp_map,
+  string_map1, string_map2, string_map3, string_map4,
+  array_map1, array_map2,
+  struct_map1, struct_map2
+);
+
+SELECT map_zip_with(tinyint_map, smallint_map, (k, v1, v2) -> struct(k, v1, v2)) m
+FROM various_maps;
+
+SELECT map_zip_with(smallint_map, int_map, (k, v1, v2) -> struct(k, v1, v2)) m
+FROM various_maps;
+
+SELECT map_zip_with(int_map, bigint_map, (k, v1, v2) -> struct(k, v1, v2)) m
+FROM various_maps;
+
+SELECT map_zip_with(double_map, float_map, (k, v1, v2) -> struct(k, v1, v2)) m
+FROM various_maps;
+
+SELECT map_zip_with(decimal_map1, decimal_map2, (k, v1, v2) -> struct(k, v1, v2)) m
+FROM various_maps;
+
+SELECT map_zip_with(string_map1, int_map, (k, v1, v2) -> struct(k, v1, v2)) m
+FROM various_maps;
+
+SELECT map_zip_with(string_map2, date_map, (k, v1, v2) -> struct(k, v1, v2)) m
+FROM various_maps;
+
+SELECT map_zip_with(timestamp_map, string_map3, (k, v1, v2) -> struct(k, v1, v2)) m
+FROM various_maps;
+
+SELECT map_zip_with(decimal_map1, string_map4, (k, v1, v2) -> struct(k, v1, v2)) m
+FROM various_maps;
+
+SELECT map_zip_with(array_map1, array_map2, (k, v1, v2) -> struct(k, v1, v2)) m
+FROM various_maps;
+
+SELECT map_zip_with(struct_map1, struct_map2, (k, v1, v2) -> struct(k, v1, v2)) m
+FROM various_maps;

http://git-wip-us.apache.org/repos/asf/spark/blob/42263fd0/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out
new file mode 100644
index 0000000..7f7e2f0
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out
@@ -0,0 +1,142 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 12
+
+
+-- !query 0
+CREATE TEMPORARY VIEW various_maps AS SELECT * FROM VALUES (
+  map(true, false),
+  map(2Y, 1Y),
+  map(2S, 1S),
+  map(2, 1),
+  map(2L, 1L),
+  map(922337203685477897945456575809789456, 922337203685477897945456575809789456),
+  map(9.22337203685477897945456575809789456, 9.22337203685477897945456575809789456),
+  map(2.0D, 1.0D),
+  map(float(2.0), float(1.0)),
+  map(date '2016-03-14', date '2016-03-13'),
+  map(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'),
+  map('true', 'false', '2', '1'),
+  map('2016-03-14', '2016-03-13'),
+  map('2016-11-15 20:54:00.000', '2016-11-12 20:54:00.000'),
+  map('922337203685477897945456575809789456', 'text'),
+  map(array(1L, 2L), array(1L, 2L)), map(array(1, 2), array(1, 2)),
+  map(struct(1S, 2L), struct(1S, 2L)), map(struct(1, 2), struct(1, 2))
+) AS various_maps(
+  boolean_map,
+  tinyint_map,
+  smallint_map,
+  int_map,
+  bigint_map,
+  decimal_map1, decimal_map2,
+  double_map,
+  float_map,
+  date_map,
+  timestamp_map,
+  string_map1, string_map2, string_map3, string_map4,
+  array_map1, array_map2,
+  struct_map1, struct_map2
+)
+-- !query 0 schema
+struct<>
+-- !query 0 output
+
+
+
+-- !query 1
+SELECT map_zip_with(tinyint_map, smallint_map, (k, v1, v2) -> struct(k, v1, v2)) m
+FROM various_maps
+-- !query 1 schema
+struct<m:map<smallint,struct<k:smallint,v1:tinyint,v2:smallint>>>
+-- !query 1 output
+{2:{"k":2,"v1":1,"v2":1}}
+
+
+-- !query 2
+SELECT map_zip_with(smallint_map, int_map, (k, v1, v2) -> struct(k, v1, v2)) m
+FROM various_maps
+-- !query 2 schema
+struct<m:map<int,struct<k:int,v1:smallint,v2:int>>>
+-- !query 2 output
+{2:{"k":2,"v1":1,"v2":1}}
+
+
+-- !query 3
+SELECT map_zip_with(int_map, bigint_map, (k, v1, v2) -> struct(k, v1, v2)) m
+FROM various_maps
+-- !query 3 schema
+struct<m:map<bigint,struct<k:bigint,v1:int,v2:bigint>>>
+-- !query 3 output
+{2:{"k":2,"v1":1,"v2":1}}
+
+
+-- !query 4
+SELECT map_zip_with(double_map, float_map, (k, v1, v2) -> struct(k, v1, v2)) m
+FROM various_maps
+-- !query 4 schema
+struct<m:map<double,struct<k:double,v1:double,v2:float>>>
+-- !query 4 output
+{2.0:{"k":2.0,"v1":1.0,"v2":1.0}}
+
+
+-- !query 5
+SELECT map_zip_with(decimal_map1, decimal_map2, (k, v1, v2) -> struct(k, v1, v2)) m
+FROM various_maps
+-- !query 5 schema
+struct<>
+-- !query 5 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'map_zip_with(various_maps.`decimal_map1`, various_maps.`decimal_map2`, lambdafunction(named_struct(NamePlaceholder(), `k`, NamePlaceholder(), `v1`, NamePlaceholder(), `v2`), `k`, `v1`, `v2`))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,0), decimal(36,35)].; line 1 pos 7
+
+
+-- !query 6
+SELECT map_zip_with(string_map1, int_map, (k, v1, v2) -> struct(k, v1, v2)) m
+FROM various_maps
+-- !query 6 schema
+struct<m:map<string,struct<k:string,v1:string,v2:int>>>
+-- !query 6 output
+{"2":{"k":"2","v1":"1","v2":1},"true":{"k":"true","v1":"false","v2":null}}
+
+
+-- !query 7
+SELECT map_zip_with(string_map2, date_map, (k, v1, v2) -> struct(k, v1, v2)) m
+FROM various_maps
+-- !query 7 schema
+struct<m:map<string,struct<k:string,v1:string,v2:date>>>
+-- !query 7 output
+{"2016-03-14":{"k":"2016-03-14","v1":"2016-03-13","v2":2016-03-13}}
+
+
+-- !query 8
+SELECT map_zip_with(timestamp_map, string_map3, (k, v1, v2) -> struct(k, v1, v2)) m
+FROM various_maps
+-- !query 8 schema
+struct<m:map<string,struct<k:string,v1:timestamp,v2:string>>>
+-- !query 8 output
+{"2016-11-15 20:54:00":{"k":"2016-11-15 20:54:00","v1":2016-11-12 
20:54:00.0,"v2":null},"2016-11-15 20:54:00.000":{"k":"2016-11-15 
20:54:00.000","v1":null,"v2":"2016-11-12 20:54:00.000"}}
+
+
+-- !query 9
+SELECT map_zip_with(decimal_map1, string_map4, (k, v1, v2) -> struct(k, v1, v2)) m
+FROM various_maps
+-- !query 9 schema
+struct<m:map<string,struct<k:string,v1:decimal(36,0),v2:string>>>
+-- !query 9 output
+{"922337203685477897945456575809789456":{"k":"922337203685477897945456575809789456","v1":922337203685477897945456575809789456,"v2":"text"}}
+
+
+-- !query 10
+SELECT map_zip_with(array_map1, array_map2, (k, v1, v2) -> struct(k, v1, v2)) m
+FROM various_maps
+-- !query 10 schema
+struct<m:map<array<bigint>,struct<k:array<bigint>,v1:array<bigint>,v2:array<int>>>>
+-- !query 10 output
+{[1,2]:{"k":[1,2],"v1":[1,2],"v2":[1,2]}}
+
+
+-- !query 11
+SELECT map_zip_with(struct_map1, struct_map2, (k, v1, v2) -> struct(k, v1, v2)) m
+FROM various_maps
+-- !query 11 schema
+struct<m:map<struct<col1:int,col2:bigint>,struct<k:struct<col1:int,col2:bigint>,v1:struct<col1:smallint,col2:bigint>,v2:struct<col1:int,col2:int>>>>
+-- !query 11 output
+{{"col1":1,"col2":2}:{"k":{"col1":1,"col2":2},"v1":{"col1":1,"col2":2},"v2":{"col1":1,"col2":2}}}

http://git-wip-us.apache.org/repos/asf/spark/blob/42263fd0/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 6401e3f..8d7695b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -2238,6 +2238,70 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     assert(ex5.getMessage.contains("cannot resolve '`a`'"))
   }
 
+  test("map_zip_with function - map of primitive types") {
+    val df = Seq(
+      (Map(8 -> 6L, 3 -> 5L, 6 -> 2L), Map[Integer, Integer]((6, 4), (8, 2), (3, 2))),
+      (Map(10 -> 6L, 8 -> 3L), Map[Integer, Integer]((8, 4), (4, null))),
+      (Map.empty[Int, Long], Map[Integer, Integer]((5, 1))),
+      (Map(5 -> 1L), null)
+    ).toDF("m1", "m2")
+
+    checkAnswer(df.selectExpr("map_zip_with(m1, m2, (k, v1, v2) -> k == v1 + 
v2)"),
+      Seq(
+        Row(Map(8 -> true, 3 -> false, 6 -> true)),
+        Row(Map(10 -> null, 8 -> false, 4 -> null)),
+        Row(Map(5 -> null)),
+        Row(null)))
+  }
+
+  test("map_zip_with function - map of non-primitive types") {
+    val df = Seq(
+      (Map("z" -> "a", "y" -> "b", "x" -> "c"), Map("x" -> "a", "z" -> "c")),
+      (Map("b" -> "a", "c" -> "d"), Map("c" -> "a", "b" -> null, "d" -> "k")),
+      (Map("a" -> "d"), Map.empty[String, String]),
+      (Map("a" -> "d"), null)
+    ).toDF("m1", "m2")
+
+    checkAnswer(df.selectExpr("map_zip_with(m1, m2, (k, v1, v2) -> (v1, v2))"),
+      Seq(
+        Row(Map("z" -> Row("a", "c"), "y" -> Row("b", null), "x" -> Row("c", 
"a"))),
+        Row(Map("b" -> Row("a", null), "c" -> Row("d", "a"), "d" -> Row(null, 
"k"))),
+        Row(Map("a" -> Row("d", null))),
+        Row(null)))
+  }
+
+  test("map_zip_with function - invalid") {
+    val df = Seq(
+      (Map(1 -> 2), Map(1 -> "a"), Map("a" -> "b"), Map(Map(1 -> 2) -> 2), 1)
+    ).toDF("mii", "mis", "mss", "mmi", "i")
+
+    val ex1 = intercept[AnalysisException] {
+      df.selectExpr("map_zip_with(mii, mis, (x, y) -> x + y)")
+    }
+    assert(ex1.getMessage.contains("The number of lambda function arguments '2' does not match"))
+
+    val ex2 = intercept[AnalysisException] {
+      df.selectExpr("map_zip_with(mis, mmi, (x, y, z) -> concat(x, y, z))")
+    }
+    assert(ex2.getMessage.contains("The input to function map_zip_with should have " +
+      "been two maps with compatible key types"))
+
+    val ex3 = intercept[AnalysisException] {
+      df.selectExpr("map_zip_with(i, mis, (x, y, z) -> concat(x, y, z))")
+    }
+    assert(ex3.getMessage.contains("type mismatch: argument 1 requires map type"))
+
+    val ex4 = intercept[AnalysisException] {
+      df.selectExpr("map_zip_with(mis, i, (x, y, z) -> concat(x, y, z))")
+    }
+    assert(ex4.getMessage.contains("type mismatch: argument 2 requires map type"))
+
+    val ex5 = intercept[AnalysisException] {
+      df.selectExpr("map_zip_with(mmi, mmi, (x, y, z) -> x)")
+    }
+    assert(ex5.getMessage.contains("function map_zip_with does not support ordering on type map"))
+  }
+
   private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
     import DataFrameFunctionsSuite.CodegenFallbackExpr
    for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {

