Repository: spark
Updated Branches:
  refs/heads/master 51bee7aca -> 4446a0b0d


[SPARK-23914][SQL][FOLLOW-UP] refactor ArrayUnion

## What changes were proposed in this pull request?

This PR refactors `ArrayUnion` based on [this suggestion](https://github.com/apache/spark/pull/21103#discussion_r205668821).
1. Generate optimized code for all of the primitive types except `boolean`
2. Generate code using `ArrayBuilder` or `ArrayBuffer`
3. Leave only a generic implementation in the interpreted path (a simplified sketch follows this list)
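
As a rough illustration of item 3, the interpreted path now reduces to a single generic loop that deduplicates with a hash set while preserving encounter order. The sketch below is a simplification, not the patch itself: it uses plain Scala collections in place of Spark's `ArrayData`, `OpenHashSet`, and `GenericArrayData`, and omits the array-length overflow check:

```scala
import scala.collection.mutable

// Simplified sketch of the generic interpreted union path, assuming plain
// Seq[Any] inputs instead of Spark's ArrayData.
def unionSketch(array1: Seq[Any], array2: Seq[Any]): Seq[Any] = {
  val result = new mutable.ArrayBuffer[Any]
  val seen = mutable.HashSet.empty[Any]
  var foundNull = false
  Seq(array1, array2).foreach { array =>
    array.foreach {
      case null =>
        // keep at most one null, at its first-seen position
        if (!foundNull) { result += null; foundNull = true }
      case elem =>
        // HashSet.add returns true only the first time the element is seen
        if (seen.add(elem)) { result += elem }
    }
  }
  result.toSeq
}

unionSketch(Seq(1, 2, 3), Seq(4, 2))  // returns Seq(1, 2, 3, 4)
```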

## How was this patch tested?

Existing tests
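
For reference, the refactored expression is exercised end to end by the existing `array_union` tests; a minimal way to exercise it from the DataFrame API looks roughly like this (not part of the patch; assumes a `SparkSession` with `spark.implicits._` in scope):

```scala
import org.apache.spark.sql.functions.array_union

val df = Seq((Array(1, 2, 3), Array(4, 2))).toDF("a", "b")
// Both forms union the two array columns without duplicates: [1, 2, 3, 4]
df.select(array_union($"a", $"b")).collect()
df.selectExpr("array_union(a, b)").collect()
```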

Author: Kazuaki Ishizaki <ishiz...@jp.ibm.com>

Closes #21937 from kiszk/SPARK-23914-follow.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4446a0b0
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4446a0b0
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4446a0b0

Branch: refs/heads/master
Commit: 4446a0b0d9bd830f0e903d6780dedac4db572b5a
Parents: 51bee7a
Author: Kazuaki Ishizaki <ishiz...@jp.ibm.com>
Authored: Tue Aug 7 12:07:56 2018 +0900
Committer: Takuya UESHIN <ues...@databricks.com>
Committed: Tue Aug 7 12:07:56 2018 +0900

----------------------------------------------------------------------
 .../expressions/collectionOperations.scala      | 325 +++++++------------
 .../CollectionExpressionsSuite.scala            |  21 +-
 .../spark/sql/DataFrameFunctionsSuite.scala     |  24 +-
 3 files changed, 153 insertions(+), 217 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/4446a0b0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index e385c2d..fbb1826 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -3767,230 +3767,159 @@ object ArraySetLike {
   """,
   since = "2.4.0")
 case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike
-    with ComplexTypeMergingExpression {
-  var hsInt: OpenHashSet[Int] = _
-  var hsLong: OpenHashSet[Long] = _
-
-  def assignInt(array: ArrayData, idx: Int, resultArray: ArrayData, pos: Int): Boolean = {
-    val elem = array.getInt(idx)
-    if (!hsInt.contains(elem)) {
-      if (resultArray != null) {
-        resultArray.setInt(pos, elem)
-      }
-      hsInt.add(elem)
-      true
-    } else {
-      false
-    }
-  }
-
-  def assignLong(array: ArrayData, idx: Int, resultArray: ArrayData, pos: Int): Boolean = {
-    val elem = array.getLong(idx)
-    if (!hsLong.contains(elem)) {
-      if (resultArray != null) {
-        resultArray.setLong(pos, elem)
-      }
-      hsLong.add(elem)
-      true
-    } else {
-      false
-    }
-  }
+  with ComplexTypeMergingExpression {
 
-  def evalIntLongPrimitiveType(
-      array1: ArrayData,
-      array2: ArrayData,
-      resultArray: ArrayData,
-      isLongType: Boolean): Int = {
-    // store elements into resultArray
-    var nullElementSize = 0
-    var pos = 0
-    Seq(array1, array2).foreach { array =>
-      var i = 0
-      while (i < array.numElements()) {
-        val size = if (!isLongType) hsInt.size else hsLong.size
-        if (size + nullElementSize > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
-          ArraySetLike.throwUnionLengthOverflowException(size)
-        }
-        if (array.isNullAt(i)) {
-          if (nullElementSize == 0) {
-            if (resultArray != null) {
-              resultArray.setNullAt(pos)
+  @transient lazy val evalUnion: (ArrayData, ArrayData) => ArrayData = {
+    if (elementTypeSupportEquals) {
+      (array1, array2) =>
+        val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
+        val hs = new OpenHashSet[Any]
+        var foundNullElement = false
+        Seq(array1, array2).foreach { array =>
+          var i = 0
+          while (i < array.numElements()) {
+            if (array.isNullAt(i)) {
+              if (!foundNullElement) {
+                arrayBuffer += null
+                foundNullElement = true
+              }
+            } else {
+              val elem = array.get(i, elementType)
+              if (!hs.contains(elem)) {
+                if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
+                  ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.size)
+                }
+                arrayBuffer += elem
+                hs.add(elem)
+              }
             }
-            pos += 1
-            nullElementSize = 1
+            i += 1
           }
-        } else {
-          val assigned = if (!isLongType) {
-            assignInt(array, i, resultArray, pos)
+        }
+        new GenericArrayData(arrayBuffer)
+    } else {
+      (array1, array2) =>
+        val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
+        var alreadyIncludeNull = false
+        Seq(array1, array2).foreach(_.foreach(elementType, (_, elem) => {
+          var found = false
+          if (elem == null) {
+            if (alreadyIncludeNull) {
+              found = true
+            } else {
+              alreadyIncludeNull = true
+            }
           } else {
-            assignLong(array, i, resultArray, pos)
+            // check whether elem is already stored in arrayBuffer
+            var j = 0
+            while (!found && j < arrayBuffer.size) {
+              val va = arrayBuffer(j)
+              if (va != null && ordering.equiv(va, elem)) {
+                found = true
+              }
+              j = j + 1
+            }
           }
-          if (assigned) {
-            pos += 1
+          if (!found) {
+            if (arrayBuffer.length > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
+              ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.length)
+            }
+            arrayBuffer += elem
           }
-        }
-        i += 1
-      }
+        }))
+        new GenericArrayData(arrayBuffer)
     }
-    pos
   }
 
   override def nullSafeEval(input1: Any, input2: Any): Any = {
     val array1 = input1.asInstanceOf[ArrayData]
     val array2 = input2.asInstanceOf[ArrayData]
 
-    if (elementTypeSupportEquals) {
-      elementType match {
-        case IntegerType =>
-          // avoid boxing of primitive int array elements
-          // calculate result array size
-          hsInt = new OpenHashSet[Int]
-          val elements = evalIntLongPrimitiveType(array1, array2, null, false)
-          hsInt = new OpenHashSet[Int]
-          val resultArray = if (UnsafeArrayData.shouldUseGenericArrayData(
-            IntegerType.defaultSize, elements)) {
-            new GenericArrayData(new Array[Any](elements))
-          } else {
-            UnsafeArrayData.forPrimitiveArray(
-              Platform.INT_ARRAY_OFFSET, elements, IntegerType.defaultSize)
-          }
-          evalIntLongPrimitiveType(array1, array2, resultArray, false)
-          resultArray
-        case LongType =>
-          // avoid boxing of primitive long array elements
-          // calculate result array size
-          hsLong = new OpenHashSet[Long]
-          val elements = evalIntLongPrimitiveType(array1, array2, null, true)
-          hsLong = new OpenHashSet[Long]
-          val resultArray = if (UnsafeArrayData.shouldUseGenericArrayData(
-            LongType.defaultSize, elements)) {
-            new GenericArrayData(new Array[Any](elements))
-          } else {
-            UnsafeArrayData.forPrimitiveArray(
-              Platform.LONG_ARRAY_OFFSET, elements, LongType.defaultSize)
-          }
-          evalIntLongPrimitiveType(array1, array2, resultArray, true)
-          resultArray
-        case _ =>
-          val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
-          val hs = new OpenHashSet[Any]
-          var foundNullElement = false
-          Seq(array1, array2).foreach { array =>
-            var i = 0
-            while (i < array.numElements()) {
-              if (array.isNullAt(i)) {
-                if (!foundNullElement) {
-                  arrayBuffer += null
-                  foundNullElement = true
-                }
-              } else {
-                val elem = array.get(i, elementType)
-                if (!hs.contains(elem)) {
-                  if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
-                    ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.size)
-                  }
-                  arrayBuffer += elem
-                  hs.add(elem)
-                }
-              }
-              i += 1
-            }
-          }
-          new GenericArrayData(arrayBuffer)
-      }
-    } else {
-      ArrayUnion.unionOrdering(array1, array2, elementType, ordering)
-    }
+    evalUnion(array1, array2)
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val i = ctx.freshName("i")
-    val pos = ctx.freshName("pos")
     val value = ctx.freshName("value")
     val size = ctx.freshName("size")
-    val (postFix, openHashElementType, getter, setter, javaTypeName, castOp, arrayBuilder) =
-      if (elementTypeSupportEquals) {
-        elementType match {
-          case ByteType | ShortType | IntegerType | LongType =>
-            val ptName = CodeGenerator.primitiveTypeName(elementType)
-            val unsafeArray = ctx.freshName("unsafeArray")
-            (if (elementType == LongType) s"$$mcJ$$sp" else s"$$mcI$$sp",
-              if (elementType == LongType) "Long" else "Int",
-              s"get$ptName($i)", s"set$ptName($pos, $value)", CodeGenerator.javaType(elementType),
-              if (elementType == LongType) "(long)" else "(int)",
-              s"""
-                 |${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")}
-                 |${ev.value} = $unsafeArray;
-               """.stripMargin)
-          case _ =>
-            val genericArrayData = classOf[GenericArrayData].getName
-            val et = ctx.addReferenceObj("elementType", elementType)
-            ("", "Object",
-              s"get($i, $et)", s"update($pos, $value)", "Object", "",
-              s"${ev.value} = new $genericArrayData(new Object[$size]);")
-        }
-      } else {
-        ("", "", "", "", "", "", "")
-      }
+    if (canUseSpecializedHashSet) {
+      val jt = CodeGenerator.javaType(elementType)
+      val ptName = CodeGenerator.primitiveTypeName(jt)
 
-    nullSafeCodeGen(ctx, ev, (array1, array2) => {
-      if (openHashElementType != "") {
-        // Here, we ensure elementTypeSupportEquals is true
+      nullSafeCodeGen(ctx, ev, (array1, array2) => {
         val foundNullElement = ctx.freshName("foundNullElement")
-        val openHashSet = classOf[OpenHashSet[_]].getName
-        val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$openHashElementType()"
-        val hs = ctx.freshName("hs")
-        val arrayData = classOf[ArrayData].getName
-        val arrays = ctx.freshName("arrays")
+        val nullElementIndex = ctx.freshName("nullElementIndex")
+        val builder = ctx.freshName("builder")
         val array = ctx.freshName("array")
+        val arrays = ctx.freshName("arrays")
         val arrayDataIdx = ctx.freshName("arrayDataIdx")
+        val openHashSet = classOf[OpenHashSet[_]].getName
+        val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()"
+        val hashSet = ctx.freshName("hashSet")
+        val arrayBuilder = classOf[mutable.ArrayBuilder[_]].getName
+        val arrayBuilderClass = s"$arrayBuilder$$of$ptName"
+
+        def withArrayNullAssignment(body: String) =
+          if (dataType.asInstanceOf[ArrayType].containsNull) {
+            s"""
+               |if ($array.isNullAt($i)) {
+               |  if (!$foundNullElement) {
+               |    $nullElementIndex = $size;
+               |    $foundNullElement = true;
+               |    $size++;
+               |    $builder.$$plus$$eq($nullValueHolder);
+               |  }
+               |} else {
+               |  $body
+               |}
+             """.stripMargin
+          } else {
+            body
+          }
+
+        val processArray = withArrayNullAssignment(
+          s"""
+             |$jt $value = ${genGetValue(array, i)};
+             |if (!$hashSet.contains($hsValueCast$value)) {
+             |  if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
+             |    break;
+             |  }
+             |  $hashSet.add$hsPostFix($hsValueCast$value);
+             |  $builder.$$plus$$eq($value);
+             |}
+           """.stripMargin)
+
+        // Only need to track null element index when result array's element is nullable.
+        val declareNullTrackVariables = if (dataType.asInstanceOf[ArrayType].containsNull) {
+          s"""
+             |boolean $foundNullElement = false;
+             |int $nullElementIndex = -1;
+           """.stripMargin
+        } else {
+          ""
+        }
+
         s"""
-           |$openHashSet $hs = new $openHashSet$postFix($classTag);
-           |boolean $foundNullElement = false;
-           |$arrayData[] $arrays = new $arrayData[]{$array1, $array2};
-           |for (int $arrayDataIdx = 0; $arrayDataIdx < 2; $arrayDataIdx++) {
-           |  $arrayData $array = $arrays[$arrayDataIdx];
-           |  for (int $i = 0; $i < $array.numElements(); $i++) {
-           |    if ($array.isNullAt($i)) {
-           |      $foundNullElement = true;
-           |    } else {
-           |      $hs.add$postFix($array.$getter);
-           |    }
-           |  }
-           |}
-           |int $size = $hs.size() + ($foundNullElement ? 1 : 0);
-           |$arrayBuilder
-           |$hs = new $openHashSet$postFix($classTag);
-           |$foundNullElement = false;
-           |int $pos = 0;
+           |$openHashSet $hashSet = new $openHashSet$hsPostFix($classTag);
+           |$declareNullTrackVariables
+           |int $size = 0;
+           |$arrayBuilderClass $builder = new $arrayBuilderClass();
+           |ArrayData[] $arrays = new ArrayData[]{$array1, $array2};
            |for (int $arrayDataIdx = 0; $arrayDataIdx < 2; $arrayDataIdx++) {
-           |  $arrayData $array = $arrays[$arrayDataIdx];
+           |  ArrayData $array = $arrays[$arrayDataIdx];
            |  for (int $i = 0; $i < $array.numElements(); $i++) {
-           |    if ($array.isNullAt($i)) {
-           |      if (!$foundNullElement) {
-           |        ${ev.value}.setNullAt($pos++);
-           |        $foundNullElement = true;
-           |      }
-           |    } else {
-           |      $javaTypeName $value = $array.$getter;
-           |      if (!$hs.contains($castOp $value)) {
-           |        $hs.add$postFix($value);
-           |        ${ev.value}.$setter;
-           |        $pos++;
-           |      }
-           |    }
+           |    $processArray
            |  }
            |}
+           |${buildResultArray(builder, ev.value, size, nullElementIndex)}
          """.stripMargin
-      } else {
-        val arrayUnion = classOf[ArrayUnion].getName
-        val et = ctx.addReferenceObj("elementTypeUnion", elementType)
-        val order = ctx.addReferenceObj("orderingUnion", ordering)
-        val method = "unionOrdering"
-        s"${ev.value} = $arrayUnion$$.MODULE$$.$method($array1, $array2, $et, $order);"
-      }
-    })
+      })
+    } else {
+      nullSafeCodeGen(ctx, ev, (array1, array2) => {
+        val expr = ctx.addReferenceObj("arrayUnionExpr", this)
+        s"${ev.value} = (ArrayData)$expr.nullSafeEval($array1, $array2);"
+      })
+    }
   }
 
   override def prettyName: String = "array_union"
@@ -4154,7 +4083,6 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetL
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    val arrayData = classOf[ArrayData].getName
     val i = ctx.freshName("i")
     val value = ctx.freshName("value")
     val size = ctx.freshName("size")
@@ -4268,7 +4196,7 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetL
     } else {
       nullSafeCodeGen(ctx, ev, (array1, array2) => {
         val expr = ctx.addReferenceObj("arrayIntersectExpr", this)
-        s"${ev.value} = ($arrayData)$expr.nullSafeEval($array1, $array2);"
+        s"${ev.value} = (ArrayData)$expr.nullSafeEval($array1, $array2);"
       })
     }
   }
@@ -4387,7 +4315,6 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    val arrayData = classOf[ArrayData].getName
     val i = ctx.freshName("i")
     val value = ctx.freshName("value")
     val size = ctx.freshName("size")
@@ -4490,7 +4417,7 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike
     } else {
       nullSafeCodeGen(ctx, ev, (array1, array2) => {
         val expr = ctx.addReferenceObj("arrayExceptExpr", this)
-        s"${ev.value} = ($arrayData)$expr.nullSafeEval($array1, $array2);"
+        s"${ev.value} = (ArrayData)$expr.nullSafeEval($array1, $array2);"
       })
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/4446a0b0/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 4daa113..c6b3f95 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -1362,10 +1362,16 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     val a02 = Literal.create(Seq(1, 2, null, 4, 5), ArrayType(IntegerType, containsNull = true))
     val a03 = Literal.create(Seq(-5, 4, -3, 2, 4), ArrayType(IntegerType, containsNull = false))
     val a04 = Literal.create(Seq.empty[Int], ArrayType(IntegerType, containsNull = false))
-    val a05 = Literal.create(Seq[Byte](1, 2, 3), ArrayType(ByteType, containsNull = false))
-    val a06 = Literal.create(Seq[Byte](4, 2), ArrayType(ByteType, containsNull = false))
-    val a07 = Literal.create(Seq[Short](1, 2, 3), ArrayType(ShortType, containsNull = false))
-    val a08 = Literal.create(Seq[Short](4, 2), ArrayType(ShortType, containsNull = false))
+    val abl0 = Literal.create(Seq[Boolean](true, true), ArrayType(BooleanType, false))
+    val abl1 = Literal.create(Seq[Boolean](false, false), ArrayType(BooleanType, false))
+    val ab0 = Literal.create(Seq[Byte](1, 2, 3, 2), ArrayType(ByteType, false))
+    val ab1 = Literal.create(Seq[Byte](4, 2, 4), ArrayType(ByteType, false))
+    val as0 = Literal.create(Seq[Short](1, 2, 3, 2), ArrayType(ShortType, false))
+    val as1 = Literal.create(Seq[Short](4, 2, 4), ArrayType(ShortType, false))
+    val af0 = Literal.create(Seq[Float](1.1F, 2.2F, 3.3F, 2.2F), ArrayType(FloatType, false))
+    val af1 = Literal.create(Seq[Float](4.4F, 2.2F, 4.4F), ArrayType(FloatType, false))
+    val ad0 = Literal.create(Seq[Double](1.1, 2.2, 3.3, 2.2), ArrayType(DoubleType, false))
+    val ad1 = Literal.create(Seq[Double](4.4, 2.2, 4.4), ArrayType(DoubleType, false))
 
     val a10 = Literal.create(Seq(1L, 2L, 3L), ArrayType(LongType, containsNull = false))
     val a11 = Literal.create(Seq(4L, 2L), ArrayType(LongType, containsNull = false))
@@ -1384,8 +1390,11 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     checkEvaluation(ArrayUnion(a02, a03), Seq(1, 2, null, 4, 5, -5, -3))
     checkEvaluation(ArrayUnion(a03, a02), Seq(-5, 4, -3, 2, 1, null, 5))
     checkEvaluation(ArrayUnion(a02, a04), Seq(1, 2, null, 4, 5))
-    checkEvaluation(ArrayUnion(a05, a06), Seq[Byte](1, 2, 3, 4))
-    checkEvaluation(ArrayUnion(a07, a08), Seq[Short](1, 2, 3, 4))
+    checkEvaluation(ArrayUnion(abl0, abl1), Seq[Boolean](true, false))
+    checkEvaluation(ArrayUnion(ab0, ab1), Seq[Byte](1, 2, 3, 4))
+    checkEvaluation(ArrayUnion(as0, as1), Seq[Short](1, 2, 3, 4))
+    checkEvaluation(ArrayUnion(af0, af1), Seq[Float](1.1F, 2.2F, 3.3F, 4.4F))
+    checkEvaluation(ArrayUnion(ad0, ad1), Seq[Double](1.1, 2.2, 3.3, 4.4))
 
     checkEvaluation(ArrayUnion(a10, a11), Seq(1L, 2L, 3L, 4L))
     checkEvaluation(ArrayUnion(a12, a13), Seq(1L, 2L, null, 4L, 5L, -5L, -3L, -1L))

http://git-wip-us.apache.org/repos/asf/spark/blob/4446a0b0/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 3c5831f..c04780d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -1148,28 +1148,28 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     checkAnswer(df5.selectExpr("array_union(a, b)"), ans5)
 
     val df6 = Seq((null, Array("a"))).toDF("a", "b")
-    intercept[AnalysisException] {
+    assert(intercept[AnalysisException] {
       df6.select(array_union($"a", $"b"))
-    }
-    intercept[AnalysisException] {
+    }.getMessage.contains("data type mismatch"))
+    assert(intercept[AnalysisException] {
       df6.selectExpr("array_union(a, b)")
-    }
+    }.getMessage.contains("data type mismatch"))
 
     val df7 = Seq((null, null)).toDF("a", "b")
-    intercept[AnalysisException] {
+    assert(intercept[AnalysisException] {
       df7.select(array_union($"a", $"b"))
-    }
-    intercept[AnalysisException] {
+    }.getMessage.contains("data type mismatch"))
+    assert(intercept[AnalysisException] {
       df7.selectExpr("array_union(a, b)")
-    }
+    }.getMessage.contains("data type mismatch"))
 
     val df8 = Seq((Array(Array(1)), Array("a"))).toDF("a", "b")
-    intercept[AnalysisException] {
+    assert(intercept[AnalysisException] {
       df8.select(array_union($"a", $"b"))
-    }
-    intercept[AnalysisException] {
+    }.getMessage.contains("data type mismatch"))
+    assert(intercept[AnalysisException] {
       df8.selectExpr("array_union(a, b)")
-    }
+    }.getMessage.contains("data type mismatch"))
   }
 
   test("concat function - arrays") {

