Repository: spark Updated Branches: refs/heads/master a8a1ac01c -> 6b8fbbfb1
[SPARK-25141][SQL][TEST] Modify tests for higher-order functions to check bind method. ## What changes were proposed in this pull request? We should also check that the `HigherOrderFunction.bind` method passes the expected parameters. This PR modifies the tests for higher-order functions to check the `bind` method. ## How was this patch tested? Modified tests. Closes #22131 from ueshin/issues/SPARK-25141/bind_test. Authored-by: Takuya UESHIN <ues...@databricks.com> Signed-off-by: Takuya UESHIN <ues...@databricks.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6b8fbbfb Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6b8fbbfb Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6b8fbbfb Branch: refs/heads/master Commit: 6b8fbbfb110601ffc3343b08113d13267baf27bf Parents: a8a1ac0 Author: Takuya UESHIN <ues...@databricks.com> Authored: Sun Aug 19 09:18:47 2018 +0900 Committer: Takuya UESHIN <ues...@databricks.com> Committed: Sun Aug 19 09:18:47 2018 +0900 ---------------------------------------------------------------------- .../expressions/HigherOrderFunctionsSuite.scala | 49 +++++++++++++------- 1 file changed, 32 insertions(+), 17 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/6b8fbbfb/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala index ea85c21..e13f4d9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ 
-60,24 +60,37 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper LambdaFunction(function, Seq(lv1, lv2, lv3)) } + private def validateBinding( + e: Expression, + argInfo: Seq[(DataType, Boolean)]): LambdaFunction = e match { + case f: LambdaFunction => + assert(f.arguments.size === argInfo.size) + f.arguments.zip(argInfo).foreach { + case (arg, (dataType, nullable)) => + assert(arg.dataType === dataType) + assert(arg.nullable === nullable) + } + f + } + def transform(expr: Expression, f: Expression => Expression): Expression = { - val at = expr.dataType.asInstanceOf[ArrayType] - ArrayTransform(expr, createLambda(at.elementType, at.containsNull, f)) + val ArrayType(et, cn) = expr.dataType + ArrayTransform(expr, createLambda(et, cn, f)).bind(validateBinding) } def transform(expr: Expression, f: (Expression, Expression) => Expression): Expression = { - val at = expr.dataType.asInstanceOf[ArrayType] - ArrayTransform(expr, createLambda(at.elementType, at.containsNull, IntegerType, false, f)) + val ArrayType(et, cn) = expr.dataType + ArrayTransform(expr, createLambda(et, cn, IntegerType, false, f)).bind(validateBinding) } def filter(expr: Expression, f: Expression => Expression): Expression = { - val at = expr.dataType.asInstanceOf[ArrayType] - ArrayFilter(expr, createLambda(at.elementType, at.containsNull, f)) + val ArrayType(et, cn) = expr.dataType + ArrayFilter(expr, createLambda(et, cn, f)).bind(validateBinding) } def transformKeys(expr: Expression, f: (Expression, Expression) => Expression): Expression = { - val map = expr.dataType.asInstanceOf[MapType] - TransformKeys(expr, createLambda(map.keyType, false, map.valueType, map.valueContainsNull, f)) + val MapType(kt, vt, vcn) = expr.dataType + TransformKeys(expr, createLambda(kt, false, vt, vcn, f)).bind(validateBinding) } def aggregate( @@ -85,13 +98,14 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper zero: Expression, merge: (Expression, Expression) 
=> Expression, finish: Expression => Expression): Expression = { - val at = expr.dataType.asInstanceOf[ArrayType] + val ArrayType(et, cn) = expr.dataType val zeroType = zero.dataType ArrayAggregate( expr, zero, - createLambda(zeroType, true, at.elementType, at.containsNull, merge), + createLambda(zeroType, true, et, cn, merge), createLambda(zeroType, true, finish)) + .bind(validateBinding) } def aggregate( @@ -102,8 +116,8 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper } def transformValues(expr: Expression, f: (Expression, Expression) => Expression): Expression = { - val map = expr.dataType.asInstanceOf[MapType] - TransformValues(expr, createLambda(map.keyType, false, map.valueType, map.valueContainsNull, f)) + val MapType(kt, vt, vcn) = expr.dataType + TransformValues(expr, createLambda(kt, false, vt, vcn, f)).bind(validateBinding) } test("ArrayTransform") { @@ -149,8 +163,8 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper test("MapFilter") { def mapFilter(expr: Expression, f: (Expression, Expression) => Expression): Expression = { - val mt = expr.dataType.asInstanceOf[MapType] - MapFilter(expr, createLambda(mt.keyType, false, mt.valueType, mt.valueContainsNull, f)) + val MapType(kt, vt, vcn) = expr.dataType + MapFilter(expr, createLambda(kt, false, vt, vcn, f)).bind(validateBinding) } val mii0 = Literal.create(Map(1 -> 0, 2 -> 10, 3 -> -1), MapType(IntegerType, IntegerType, valueContainsNull = false)) @@ -230,8 +244,8 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper test("ArrayExists") { def exists(expr: Expression, f: Expression => Expression): Expression = { - val at = expr.dataType.asInstanceOf[ArrayType] - ArrayExists(expr, createLambda(at.elementType, at.containsNull, f)) + val ArrayType(et, cn) = expr.dataType + ArrayExists(expr, createLambda(et, cn, f)).bind(validateBinding) } val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = 
false)) @@ -439,6 +453,7 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper val MapType(kt, vt1, _) = left.dataType val MapType(_, vt2, _) = right.dataType MapZipWith(left, right, createLambda(kt, false, vt1, true, vt2, true, f)) + .bind(validateBinding) } val mii0 = Literal.create(Map(1 -> 10, 2 -> 20, 3 -> 30), @@ -556,7 +571,7 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper f: (Expression, Expression) => Expression): Expression = { val ArrayType(leftT, _) = left.dataType val ArrayType(rightT, _) = right.dataType - ZipWith(left, right, createLambda(leftT, true, rightT, true, f)) + ZipWith(left, right, createLambda(leftT, true, rightT, true, f)).bind(validateBinding) } val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org