Repository: spark Updated Branches: refs/heads/master 2e3abdff2 -> b804ca577
[SPARK-23908][SQL][FOLLOW-UP] Rename inputs to arguments, and add argument type check. ## What changes were proposed in this pull request? This is a follow-up PR to #21954 that addresses review comments. - Rename ambiguous name `inputs` to `arguments`. - Add argument type check and remove hacky workaround. - Address other small comments. ## How was this patch tested? Existing tests and some additional tests. Closes #22075 from ueshin/issues/SPARK-23908/fup1. Authored-by: Takuya UESHIN <ues...@databricks.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b804ca57 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b804ca57 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b804ca57 Branch: refs/heads/master Commit: b804ca57718ad1568458d8185c8c30118be8275f Parents: 2e3abdf Author: Takuya UESHIN <ues...@databricks.com> Authored: Mon Aug 13 20:58:29 2018 +0800 Committer: Wenchen Fan <wenc...@databricks.com> Committed: Mon Aug 13 20:58:29 2018 +0800 ---------------------------------------------------------------------- .../sql/catalyst/analysis/CheckAnalysis.scala | 14 ++ .../analysis/higherOrderFunctions.scala | 12 +- .../expressions/ExpectsInputTypes.scala | 16 +- .../expressions/higherOrderFunctions.scala | 181 ++++++++++--------- .../spark/sql/catalyst/plans/PlanTest.scala | 2 +- .../spark/sql/DataFrameFunctionsSuite.scala | 25 +++ 6 files changed, 152 insertions(+), 98 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/b804ca57/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 4addc83..6a91d55 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -90,6 +90,20 @@ trait CheckAnalysis extends PredicateHelper { u.failAnalysis(s"Table or view not found: ${u.tableIdentifier}") case operator: LogicalPlan => + // Check argument data types of higher-order functions downwards first. + // If the arguments of the higher-order functions are resolved but the type check fails, + // the argument functions will not get resolved, but we should report the argument type + // check failure instead of claiming the argument functions are unresolved. + operator transformExpressionsDown { + case hof: HigherOrderFunction + if hof.argumentsResolved && hof.checkArgumentDataTypes().isFailure => + hof.checkArgumentDataTypes() match { + case TypeCheckResult.TypeCheckFailure(message) => + hof.failAnalysis( + s"cannot resolve '${hof.sql}' due to argument data type mismatch: $message") + } + } + operator transformExpressionsUp { case a: Attribute if !a.resolved => val from = operator.inputSet.map(_.qualifiedName).mkString(", ") http://git-wip-us.apache.org/repos/asf/spark/blob/b804ca57/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala index 5e2029c..dd08190 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala @@ -95,15 +95,15 @@ case class ResolveLambdaVariables(conf: 
SQLConf) extends Rule[LogicalPlan] { */ private def createLambda( e: Expression, - partialArguments: Seq[(DataType, Boolean)]): LambdaFunction = e match { + argInfo: Seq[(DataType, Boolean)]): LambdaFunction = e match { case f: LambdaFunction if f.bound => f case LambdaFunction(function, names, _) => - if (names.size != partialArguments.size) { + if (names.size != argInfo.size) { e.failAnalysis( s"The number of lambda function arguments '${names.size}' does not " + "match the number of arguments expected by the higher order function " + - s"'${partialArguments.size}'.") + s"'${argInfo.size}'.") } if (names.map(a => canonicalizer(a.name)).distinct.size < names.size) { @@ -111,7 +111,7 @@ case class ResolveLambdaVariables(conf: SQLConf) extends Rule[LogicalPlan] { "Lambda function arguments should not have names that are semantically the same.") } - val arguments = partialArguments.zip(names).map { + val arguments = argInfo.zip(names).map { case ((dataType, nullable), ne) => NamedLambdaVariable(ne.name, dataType, nullable) } @@ -122,7 +122,7 @@ case class ResolveLambdaVariables(conf: SQLConf) extends Rule[LogicalPlan] { // create a lambda function with default parameters because this is expected by the higher // order function. Note that we hide the lambda variables produced by this function in order // to prevent accidental naming collisions. 
- val arguments = partialArguments.zipWithIndex.map { + val arguments = argInfo.zipWithIndex.map { case ((dataType, nullable), i) => NamedLambdaVariable(s"col$i", dataType, nullable) } @@ -135,7 +135,7 @@ case class ResolveLambdaVariables(conf: SQLConf) extends Rule[LogicalPlan] { private def resolve(e: Expression, parentLambdaMap: LambdaVariableMap): Expression = e match { case _ if e.resolved => e - case h: HigherOrderFunction if h.inputResolved => + case h: HigherOrderFunction if h.argumentsResolved && h.checkArgumentDataTypes().isSuccess => h.bind(createLambda).mapChildren(resolve(_, parentLambdaMap)) case l: LambdaFunction if !l.bound => http://git-wip-us.apache.org/repos/asf/spark/blob/b804ca57/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala index d8f046c..981ce0b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala @@ -41,10 +41,19 @@ trait ExpectsInputTypes extends Expression { def inputTypes: Seq[AbstractDataType] override def checkInputDataTypes(): TypeCheckResult = { - val mismatches = children.zip(inputTypes).zipWithIndex.collect { - case ((child, expected), idx) if !expected.acceptsType(child.dataType) => + ExpectsInputTypes.checkInputDataTypes(children, inputTypes) + } +} + +object ExpectsInputTypes { + + def checkInputDataTypes( + inputs: Seq[Expression], + inputTypes: Seq[AbstractDataType]): TypeCheckResult = { + val mismatches = inputs.zip(inputTypes).zipWithIndex.collect { + case ((input, expected), idx) if !expected.acceptsType(input.dataType) => s"argument ${idx + 1} requires 
${expected.simpleString} type, " + - s"however, '${child.sql}' is of ${child.dataType.catalogString} type." + s"however, '${input.sql}' is of ${input.dataType.catalogString} type." } if (mismatches.isEmpty) { @@ -55,7 +64,6 @@ trait ExpectsInputTypes extends Expression { } } - /** * A mixin for the analyzer to perform implicit type casting using * [[org.apache.spark.sql.catalyst.analysis.TypeCoercion.ImplicitTypeCasts]]. http://git-wip-us.apache.org/repos/asf/spark/blob/b804ca57/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 7f8203a..5d1b8c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -35,8 +35,8 @@ case class NamedLambdaVariable( name: String, dataType: DataType, nullable: Boolean, - value: AtomicReference[Any] = new AtomicReference(), - exprId: ExprId = NamedExpression.newExprId) + exprId: ExprId = NamedExpression.newExprId, + value: AtomicReference[Any] = new AtomicReference()) extends LeafExpression with NamedExpression with CodegenFallback { @@ -44,7 +44,7 @@ case class NamedLambdaVariable( override def qualifier: Seq[String] = Seq.empty override def newInstance(): NamedExpression = - copy(value = new AtomicReference(), exprId = NamedExpression.newExprId) + copy(exprId = NamedExpression.newExprId, value = new AtomicReference()) override def toAttribute: Attribute = { AttributeReference(name, dataType, nullable, Metadata.empty)(exprId, Seq.empty) @@ -88,30 +88,45 @@ object LambdaFunction { * A higher order function takes one or more (lambda) functions and applies 
these to some objects. * The function produces a number of variables which can be consumed by some lambda function. */ -trait HigherOrderFunction extends Expression { +trait HigherOrderFunction extends Expression with ExpectsInputTypes { - override def children: Seq[Expression] = inputs ++ functions + override def children: Seq[Expression] = arguments ++ functions /** - * Inputs to the higher ordered function. + * Arguments of the higher ordered function. */ - def inputs: Seq[Expression] + def arguments: Seq[Expression] + + def argumentTypes: Seq[AbstractDataType] /** - * All inputs have been resolved. This means that the types and nullabilty of (most of) the + * All arguments have been resolved. This means that the types and nullabilty of (most of) the * lambda function arguments is known, and that we can start binding the lambda functions. */ - lazy val inputResolved: Boolean = inputs.forall(_.resolved) + lazy val argumentsResolved: Boolean = arguments.forall(_.resolved) + + /** + * Checks the argument data types, returns `TypeCheckResult.success` if it's valid, + * or returns a `TypeCheckResult` with an error message if invalid. + * Note: it's not valid to call this method until `argumentsResolved == true`. + */ + def checkArgumentDataTypes(): TypeCheckResult = { + ExpectsInputTypes.checkInputDataTypes(arguments, argumentTypes) + } /** * Functions applied by the higher order function. */ def functions: Seq[Expression] + def functionTypes: Seq[AbstractDataType] + + override def inputTypes: Seq[AbstractDataType] = argumentTypes ++ functionTypes + /** * All inputs must be resolved and all functions must be resolved lambda functions. 
*/ - override lazy val resolved: Boolean = inputResolved && functions.forall { + override lazy val resolved: Boolean = argumentsResolved && functions.forall { case l: LambdaFunction => l.resolved case _ => false } @@ -123,6 +138,8 @@ trait HigherOrderFunction extends Expression { */ def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): HigherOrderFunction + // Make sure the lambda variables refer the same instances as of arguments for case that the + // variables in instantiated separately during serialization or for some reason. @transient lazy val functionsForEval: Seq[Expression] = functions.map { case LambdaFunction(function, arguments, hidden) => val argumentMap = arguments.map { arg => arg.exprId -> arg }.toMap @@ -133,51 +150,38 @@ trait HigherOrderFunction extends Expression { } } -object HigherOrderFunction { - - def arrayArgumentType(dt: DataType): (DataType, Boolean) = { - dt match { - case ArrayType(elementType, containsNull) => (elementType, containsNull) - case _ => - val ArrayType(elementType, containsNull) = ArrayType.defaultConcreteType - (elementType, containsNull) - } - } - - def mapKeyValueArgumentType(dt: DataType): (DataType, DataType, Boolean) = dt match { - case MapType(kType, vType, vContainsNull) => (kType, vType, vContainsNull) - case _ => - val MapType(kType, vType, vContainsNull) = MapType.defaultConcreteType - (kType, vType, vContainsNull) - } -} - /** * Trait for functions having as input one argument and one function. 
*/ -trait SimpleHigherOrderFunction extends HigherOrderFunction with ExpectsInputTypes { +trait SimpleHigherOrderFunction extends HigherOrderFunction { + + def argument: Expression - def input: Expression + override def arguments: Seq[Expression] = argument :: Nil - override def inputs: Seq[Expression] = input :: Nil + def argumentType: AbstractDataType + + override def argumentTypes(): Seq[AbstractDataType] = argumentType :: Nil def function: Expression override def functions: Seq[Expression] = function :: Nil - def expectingFunctionType: AbstractDataType = AnyDataType + def functionType: AbstractDataType = AnyDataType + + override def functionTypes: Seq[AbstractDataType] = functionType :: Nil - @transient lazy val functionForEval: Expression = functionsForEval.head + def functionForEval: Expression = functionsForEval.head /** * Called by [[eval]]. If a subclass keeps the default nullability, it can override this method * in order to save null-check code. */ - protected def nullSafeEval(inputRow: InternalRow, input: Any): Any = + protected def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = sys.error(s"UnaryHigherOrderFunction must override either eval or nullSafeEval") override def eval(inputRow: InternalRow): Any = { - val value = input.eval(inputRow) + val value = argument.eval(inputRow) if (value == null) { null } else { @@ -187,11 +191,11 @@ trait SimpleHigherOrderFunction extends HigherOrderFunction with ExpectsInputTyp } trait ArrayBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction { - override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, expectingFunctionType) + override def argumentType: AbstractDataType = ArrayType } trait MapBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction { - override def inputTypes: Seq[AbstractDataType] = Seq(MapType, expectingFunctionType) + override def argumentType: AbstractDataType = MapType } /** @@ -209,21 +213,21 @@ trait MapBasedSimpleHigherOrderFunction extends 
SimpleHigherOrderFunction { """, since = "2.4.0") case class ArrayTransform( - input: Expression, + argument: Expression, function: Expression) extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback { - override def nullable: Boolean = input.nullable + override def nullable: Boolean = argument.nullable override def dataType: ArrayType = ArrayType(function.dataType, function.nullable) override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayTransform = { - val elem = HigherOrderFunction.arrayArgumentType(input.dataType) + val ArrayType(elementType, containsNull) = argument.dataType function match { case LambdaFunction(_, arguments, _) if arguments.size == 2 => - copy(function = f(function, elem :: (IntegerType, false) :: Nil)) + copy(function = f(function, (elementType, containsNull) :: (IntegerType, false) :: Nil)) case _ => - copy(function = f(function, elem :: Nil)) + copy(function = f(function, (elementType, containsNull) :: Nil)) } } @@ -237,8 +241,8 @@ case class ArrayTransform( (elementVar, indexVar) } - override def nullSafeEval(inputRow: InternalRow, inputValue: Any): Any = { - val arr = inputValue.asInstanceOf[ArrayData] + override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = { + val arr = argumentValue.asInstanceOf[ArrayData] val f = functionForEval val result = new GenericArrayData(new Array[Any](arr.numElements)) var i = 0 @@ -268,7 +272,7 @@ examples = """ """, since = "2.4.0") case class MapFilter( - input: Expression, + argument: Expression, function: Expression) extends MapBasedSimpleHigherOrderFunction with CodegenFallback { @@ -277,17 +281,16 @@ case class MapFilter( (args.head.asInstanceOf[NamedLambdaVariable], args.tail.head.asInstanceOf[NamedLambdaVariable]) } - @transient val (keyType, valueType, valueContainsNull) = - HigherOrderFunction.mapKeyValueArgumentType(input.dataType) + @transient lazy val MapType(keyType, valueType, valueContainsNull) = argument.dataType override def 
bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): MapFilter = { copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil)) } - override def nullable: Boolean = input.nullable + override def nullable: Boolean = argument.nullable - override def nullSafeEval(inputRow: InternalRow, value: Any): Any = { - val m = value.asInstanceOf[MapData] + override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = { + val m = argumentValue.asInstanceOf[MapData] val f = functionForEval val retKeys = new mutable.ListBuffer[Any] val retValues = new mutable.ListBuffer[Any] @@ -302,9 +305,9 @@ case class MapFilter( ArrayBasedMapData(retKeys.toArray, retValues.toArray) } - override def dataType: DataType = input.dataType + override def dataType: DataType = argument.dataType - override def expectingFunctionType: AbstractDataType = BooleanType + override def functionType: AbstractDataType = BooleanType override def prettyName: String = "map_filter" } @@ -321,25 +324,25 @@ case class MapFilter( """, since = "2.4.0") case class ArrayFilter( - input: Expression, + argument: Expression, function: Expression) extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback { - override def nullable: Boolean = input.nullable + override def nullable: Boolean = argument.nullable - override def dataType: DataType = input.dataType + override def dataType: DataType = argument.dataType - override def expectingFunctionType: AbstractDataType = BooleanType + override def functionType: AbstractDataType = BooleanType override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayFilter = { - val elem = HigherOrderFunction.arrayArgumentType(input.dataType) - copy(function = f(function, elem :: Nil)) + val ArrayType(elementType, containsNull) = argument.dataType + copy(function = f(function, (elementType, containsNull) :: Nil)) } @transient lazy val LambdaFunction(_, Seq(elementVar: NamedLambdaVariable), _) = function - 
override def nullSafeEval(inputRow: InternalRow, value: Any): Any = { - val arr = value.asInstanceOf[ArrayData] + override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = { + val arr = argumentValue.asInstanceOf[ArrayData] val f = functionForEval val buffer = new mutable.ArrayBuffer[Any](arr.numElements) var i = 0 @@ -368,25 +371,25 @@ case class ArrayFilter( """, since = "2.4.0") case class ArrayExists( - input: Expression, + argument: Expression, function: Expression) extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback { - override def nullable: Boolean = input.nullable + override def nullable: Boolean = argument.nullable override def dataType: DataType = BooleanType - override def expectingFunctionType: AbstractDataType = BooleanType + override def functionType: AbstractDataType = BooleanType override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayExists = { - val elem = HigherOrderFunction.arrayArgumentType(input.dataType) - copy(function = f(function, elem :: Nil)) + val ArrayType(elementType, containsNull) = argument.dataType + copy(function = f(function, (elementType, containsNull) :: Nil)) } @transient lazy val LambdaFunction(_, Seq(elementVar: NamedLambdaVariable), _) = function - override def nullSafeEval(inputRow: InternalRow, value: Any): Any = { - val arr = value.asInstanceOf[ArrayData] + override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = { + val arr = argumentValue.asInstanceOf[ArrayData] val f = functionForEval var exists = false var i = 0 @@ -422,45 +425,49 @@ case class ArrayExists( """, since = "2.4.0") case class ArrayAggregate( - input: Expression, + argument: Expression, zero: Expression, merge: Expression, finish: Expression) extends HigherOrderFunction with CodegenFallback { - def this(input: Expression, zero: Expression, merge: Expression) = { - this(input, zero, merge, LambdaFunction.identity) + def this(argument: Expression, zero: Expression, merge: 
Expression) = { + this(argument, zero, merge, LambdaFunction.identity) } - override def inputs: Seq[Expression] = input :: zero :: Nil + override def arguments: Seq[Expression] = argument :: zero :: Nil + + override def argumentTypes: Seq[AbstractDataType] = ArrayType :: AnyDataType :: Nil override def functions: Seq[Expression] = merge :: finish :: Nil - override def nullable: Boolean = input.nullable || finish.nullable + override def functionTypes: Seq[AbstractDataType] = zero.dataType :: AnyDataType :: Nil + + override def nullable: Boolean = argument.nullable || finish.nullable override def dataType: DataType = finish.dataType override def checkInputDataTypes(): TypeCheckResult = { - if (!ArrayType.acceptsType(input.dataType)) { - TypeCheckResult.TypeCheckFailure( - s"argument 1 requires ${ArrayType.simpleString} type, " + - s"however, '${input.sql}' is of ${input.dataType.catalogString} type.") - } else if (!DataType.equalsStructurally( - zero.dataType, merge.dataType, ignoreNullability = true)) { - TypeCheckResult.TypeCheckFailure( - s"argument 3 requires ${zero.dataType.simpleString} type, " + - s"however, '${merge.sql}' is of ${merge.dataType.catalogString} type.") - } else { - TypeCheckResult.TypeCheckSuccess + checkArgumentDataTypes() match { + case TypeCheckResult.TypeCheckSuccess => + if (!DataType.equalsStructurally( + zero.dataType, merge.dataType, ignoreNullability = true)) { + TypeCheckResult.TypeCheckFailure( + s"argument 3 requires ${zero.dataType.simpleString} type, " + + s"however, '${merge.sql}' is of ${merge.dataType.catalogString} type.") + } else { + TypeCheckResult.TypeCheckSuccess + } + case failure => failure } } override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayAggregate = { // Be very conservative with nullable. We cannot be sure that the accumulator does not // evaluate to null. So we always set nullable to true here. 
- val elem = HigherOrderFunction.arrayArgumentType(input.dataType) + val ArrayType(elementType, containsNull) = argument.dataType val acc = zero.dataType -> true - val newMerge = f(merge, acc :: elem :: Nil) + val newMerge = f(merge, acc :: (elementType, containsNull) :: Nil) val newFinish = f(finish, acc :: Nil) copy(merge = newMerge, finish = newFinish) } @@ -470,7 +477,7 @@ case class ArrayAggregate( @transient lazy val LambdaFunction(_, Seq(accForFinishVar: NamedLambdaVariable), _) = finish override def eval(input: InternalRow): Any = { - val arr = this.input.eval(input).asInstanceOf[ArrayData] + val arr = argument.eval(input).asInstanceOf[ArrayData] if (arr == null) { null } else { http://git-wip-us.apache.org/repos/asf/spark/blob/b804ca57/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 9e95b19..67740c3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -81,7 +81,7 @@ trait PlanTestBase extends PredicateHelper { self: Suite => case ae: AggregateExpression => ae.copy(resultId = ExprId(0)) case lv: NamedLambdaVariable => - lv.copy(value = null, exprId = ExprId(0)) + lv.copy(exprId = ExprId(0), value = null) } } http://git-wip-us.apache.org/repos/asf/spark/blob/b804ca57/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 2c4238e..6401e3f 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1852,6 +1852,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("transform(i, x -> x)") } assert(ex2.getMessage.contains("data type mismatch: argument 1 requires array type")) + + val ex3 = intercept[AnalysisException] { + df.selectExpr("transform(a, x -> x)") + } + assert(ex3.getMessage.contains("cannot resolve '`a`'")) } test("map_filter") { @@ -1898,6 +1903,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("map_filter(i, (k, v) -> k > v)") } assert(ex3.getMessage.contains("data type mismatch: argument 1 requires map type")) + + val ex4 = intercept[AnalysisException] { + df.selectExpr("map_filter(a, (k, v) -> k > v)") + } + assert(ex4.getMessage.contains("cannot resolve '`a`'")) } test("filter function - array for primitive type not containing null") { @@ -1994,6 +2004,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("filter(s, x -> x)") } assert(ex3.getMessage.contains("data type mismatch: argument 2 requires boolean type")) + + val ex4 = intercept[AnalysisException] { + df.selectExpr("filter(a, x -> x)") + } + assert(ex4.getMessage.contains("cannot resolve '`a`'")) } test("exists function - array for primitive type not containing null") { @@ -2090,6 +2105,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("exists(s, x -> x)") } assert(ex3.getMessage.contains("data type mismatch: argument 2 requires boolean type")) + + val ex4 = intercept[AnalysisException] { + df.selectExpr("exists(a, x -> x)") + } + assert(ex4.getMessage.contains("cannot resolve '`a`'")) } test("aggregate function - array for primitive type not containing null") { @@ -2211,6 +2231,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { 
df.selectExpr("aggregate(s, 0, (acc, x) -> x)") } assert(ex4.getMessage.contains("data type mismatch: argument 3 requires int type")) + + val ex5 = intercept[AnalysisException] { + df.selectExpr("aggregate(a, 0, (acc, x) -> x)") + } + assert(ex5.getMessage.contains("cannot resolve '`a`'")) } private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org