Repository: spark
Updated Branches:
  refs/heads/master a8a1ac01c -> 6b8fbbfb1


[SPARK-25141][SQL][TEST] Modify tests for higher-order functions to check bind method.

## What changes were proposed in this pull request?

We should also check that the `HigherOrderFunction.bind` method passes the expected parameters.
This PR modifies the tests for higher-order functions to check the `bind` method.
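
For readers skimming the diff, here is a condensed sketch of the pattern the modified tests use (it assumes the suite's existing `createLambda` helper and the usual Catalyst expression imports, as shown in the diff below): `bind` hands the validator the lambda together with the `(DataType, nullable)` pairs resolved for its arguments, and the validator asserts they match what the test expects.

```scala
// Condensed from the diff below: the validator checks that `bind` resolved the
// lambda's arguments to the expected data types and nullability, then returns
// the lambda unchanged so the test can evaluate it as before.
private def validateBinding(
    e: Expression,
    argInfo: Seq[(DataType, Boolean)]): LambdaFunction = e match {
  case f: LambdaFunction =>
    assert(f.arguments.size === argInfo.size)
    f.arguments.zip(argInfo).foreach { case (arg, (dataType, nullable)) =>
      assert(arg.dataType === dataType)
      assert(arg.nullable === nullable)
    }
    f
}

// Each test helper then routes through the validator, e.g. for ArrayTransform:
def transform(expr: Expression, f: Expression => Expression): Expression = {
  val ArrayType(et, cn) = expr.dataType
  ArrayTransform(expr, createLambda(et, cn, f)).bind(validateBinding)
}
```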

## How was this patch tested?

Modified tests.

Closes #22131 from ueshin/issues/SPARK-25141/bind_test.

Authored-by: Takuya UESHIN <ues...@databricks.com>
Signed-off-by: Takuya UESHIN <ues...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6b8fbbfb
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6b8fbbfb
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6b8fbbfb

Branch: refs/heads/master
Commit: 6b8fbbfb110601ffc3343b08113d13267baf27bf
Parents: a8a1ac0
Author: Takuya UESHIN <ues...@databricks.com>
Authored: Sun Aug 19 09:18:47 2018 +0900
Committer: Takuya UESHIN <ues...@databricks.com>
Committed: Sun Aug 19 09:18:47 2018 +0900

----------------------------------------------------------------------
 .../expressions/HigherOrderFunctionsSuite.scala | 49 +++++++++++++-------
 1 file changed, 32 insertions(+), 17 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/6b8fbbfb/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
index ea85c21..e13f4d9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
@@ -60,24 +60,37 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
     LambdaFunction(function, Seq(lv1, lv2, lv3))
   }
 
+  private def validateBinding(
+      e: Expression,
+      argInfo: Seq[(DataType, Boolean)]): LambdaFunction = e match {
+    case f: LambdaFunction =>
+      assert(f.arguments.size === argInfo.size)
+      f.arguments.zip(argInfo).foreach {
+        case (arg, (dataType, nullable)) =>
+          assert(arg.dataType === dataType)
+          assert(arg.nullable === nullable)
+      }
+      f
+  }
+
   def transform(expr: Expression, f: Expression => Expression): Expression = {
-    val at = expr.dataType.asInstanceOf[ArrayType]
-    ArrayTransform(expr, createLambda(at.elementType, at.containsNull, f))
+    val ArrayType(et, cn) = expr.dataType
+    ArrayTransform(expr, createLambda(et, cn, f)).bind(validateBinding)
   }
 
   def transform(expr: Expression, f: (Expression, Expression) => Expression): Expression = {
-    val at = expr.dataType.asInstanceOf[ArrayType]
-    ArrayTransform(expr, createLambda(at.elementType, at.containsNull, IntegerType, false, f))
+    val ArrayType(et, cn) = expr.dataType
+    ArrayTransform(expr, createLambda(et, cn, IntegerType, false, f)).bind(validateBinding)
   }
 
   def filter(expr: Expression, f: Expression => Expression): Expression = {
-    val at = expr.dataType.asInstanceOf[ArrayType]
-    ArrayFilter(expr, createLambda(at.elementType, at.containsNull, f))
+    val ArrayType(et, cn) = expr.dataType
+    ArrayFilter(expr, createLambda(et, cn, f)).bind(validateBinding)
   }
 
   def transformKeys(expr: Expression, f: (Expression, Expression) => Expression): Expression = {
-    val map = expr.dataType.asInstanceOf[MapType]
-    TransformKeys(expr, createLambda(map.keyType, false, map.valueType, map.valueContainsNull, f))
+    val MapType(kt, vt, vcn) = expr.dataType
+    TransformKeys(expr, createLambda(kt, false, vt, vcn, f)).bind(validateBinding)
   }
 
   def aggregate(
@@ -85,13 +98,14 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
       zero: Expression,
       merge: (Expression, Expression) => Expression,
       finish: Expression => Expression): Expression = {
-    val at = expr.dataType.asInstanceOf[ArrayType]
+    val ArrayType(et, cn) = expr.dataType
     val zeroType = zero.dataType
     ArrayAggregate(
       expr,
       zero,
-      createLambda(zeroType, true, at.elementType, at.containsNull, merge),
+      createLambda(zeroType, true, et, cn, merge),
       createLambda(zeroType, true, finish))
+      .bind(validateBinding)
   }
 
   def aggregate(
@@ -102,8 +116,8 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
   }
 
   def transformValues(expr: Expression, f: (Expression, Expression) => Expression): Expression = {
-    val map = expr.dataType.asInstanceOf[MapType]
-    TransformValues(expr, createLambda(map.keyType, false, map.valueType, map.valueContainsNull, f))
+    val MapType(kt, vt, vcn) = expr.dataType
+    TransformValues(expr, createLambda(kt, false, vt, vcn, f)).bind(validateBinding)
   }
 
   test("ArrayTransform") {
@@ -149,8 +163,8 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
 
   test("MapFilter") {
     def mapFilter(expr: Expression, f: (Expression, Expression) => Expression): Expression = {
-      val mt = expr.dataType.asInstanceOf[MapType]
-      MapFilter(expr, createLambda(mt.keyType, false, mt.valueType, mt.valueContainsNull, f))
+      val MapType(kt, vt, vcn) = expr.dataType
+      MapFilter(expr, createLambda(kt, false, vt, vcn, f)).bind(validateBinding)
     }
     val mii0 = Literal.create(Map(1 -> 0, 2 -> 10, 3 -> -1),
       MapType(IntegerType, IntegerType, valueContainsNull = false))
@@ -230,8 +244,8 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
 
   test("ArrayExists") {
     def exists(expr: Expression, f: Expression => Expression): Expression = {
-      val at = expr.dataType.asInstanceOf[ArrayType]
-      ArrayExists(expr, createLambda(at.elementType, at.containsNull, f))
+      val ArrayType(et, cn) = expr.dataType
+      ArrayExists(expr, createLambda(et, cn, f)).bind(validateBinding)
     }
 
     val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))
@@ -439,6 +453,7 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
       val MapType(kt, vt1, _) = left.dataType
       val MapType(_, vt2, _) = right.dataType
       MapZipWith(left, right, createLambda(kt, false, vt1, true, vt2, true, f))
+        .bind(validateBinding)
     }
 
     val mii0 = Literal.create(Map(1 -> 10, 2 -> 20, 3 -> 30),
@@ -556,7 +571,7 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
         f: (Expression, Expression) => Expression): Expression = {
       val ArrayType(leftT, _) = left.dataType
       val ArrayType(rightT, _) = right.dataType
-      ZipWith(left, right, createLambda(leftT, true, rightT, true, f))
+      ZipWith(left, right, createLambda(leftT, true, rightT, true, f)).bind(validateBinding)
     }
 
     val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))

