Repository: spark Updated Branches: refs/heads/branch-1.0 3aa52be39 -> b459aa77f
[SPARK-2327] [SQL] Fix nullabilities of Join/Generate/Aggregate. Fix nullabilities of `Join`/`Generate`/`Aggregate` because: - Output attributes of the opposite side of `OuterJoin` should be nullable. - Output attributes of the generator side of `Generate` should be nullable if `join` is `true` and `outer` is `true`. - The nullability of the `AttributeReference` of `computedAggregates` of `Aggregate` should be the same as `aggregateExpression`'s. Author: Takuya UESHIN <[email protected]> Closes #1266 from ueshin/issues/SPARK-2327 and squashes the following commits: 3ace83a [Takuya UESHIN] Add withNullability to Attribute and use it to change nullabilities. df1ae53 [Takuya UESHIN] Modify nullabilize to leave attribute if not resolved. 799ce56 [Takuya UESHIN] Add nullabilization to Generate of SparkPlan. a0fc9bc [Takuya UESHIN] Fix scalastyle errors. 0e31e37 [Takuya UESHIN] Fix Aggregate resultAttribute nullabilities. 09532ec [Takuya UESHIN] Fix Generate output nullabilities. f20f196 [Takuya UESHIN] Fix Join output nullabilities. 
(cherry picked from commit 9d5ecf8205b924dc8a3c13fed68beb78cc5c7553) Signed-off-by: Michael Armbrust <[email protected]> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b459aa77 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b459aa77 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b459aa77 Branch: refs/heads/branch-1.0 Commit: b459aa77f63d0a469dc20e0ef555cf94382f41ca Parents: 3aa52be Author: Takuya UESHIN <[email protected]> Authored: Sat Jul 5 11:51:48 2014 -0700 Committer: Michael Armbrust <[email protected]> Committed: Sat Jul 5 11:52:11 2014 -0700 ---------------------------------------------------------------------- .../sql/catalyst/analysis/unresolved.scala | 2 ++ .../catalyst/expressions/BoundAttribute.scala | 16 +++++----- .../catalyst/expressions/namedExpressions.scala | 3 +- .../catalyst/plans/logical/basicOperators.scala | 31 +++++++++++++++----- .../apache/spark/sql/execution/Aggregate.scala | 4 +-- .../apache/spark/sql/execution/Generate.scala | 12 ++++++-- .../org/apache/spark/sql/execution/joins.scala | 13 +++++++- 7 files changed, 60 insertions(+), 21 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/b459aa77/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index d629172..7abeb03 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -52,6 +52,7 @@ case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNo override lazy val resolved = false 
override def newInstance = this + override def withNullability(newNullability: Boolean) = this override def withQualifiers(newQualifiers: Seq[String]) = this // Unresolved attributes are transient at compile time and don't get evaluated during execution. @@ -95,6 +96,7 @@ case class Star( override lazy val resolved = false override def newInstance = this + override def withNullability(newNullability: Boolean) = this override def withQualifiers(newQualifiers: Seq[String]) = this def expand(input: Seq[Attribute]): Seq[NamedExpression] = { http://git-wip-us.apache.org/repos/asf/spark/blob/b459aa77/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 655d4a0..9ce1f01 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -33,14 +33,16 @@ case class BoundReference(ordinal: Int, baseReference: Attribute) type EvaluatedType = Any - def nullable = baseReference.nullable - def dataType = baseReference.dataType - def exprId = baseReference.exprId - def qualifiers = baseReference.qualifiers - def name = baseReference.name + override def nullable = baseReference.nullable + override def dataType = baseReference.dataType + override def exprId = baseReference.exprId + override def qualifiers = baseReference.qualifiers + override def name = baseReference.name - def newInstance = BoundReference(ordinal, baseReference.newInstance) - def withQualifiers(newQualifiers: Seq[String]) = + override def newInstance = BoundReference(ordinal, baseReference.newInstance) + override def withNullability(newNullability: Boolean) = + 
BoundReference(ordinal, baseReference.withNullability(newNullability)) + override def withQualifiers(newQualifiers: Seq[String]) = BoundReference(ordinal, baseReference.withQualifiers(newQualifiers)) override def toString = s"$baseReference:$ordinal" http://git-wip-us.apache.org/repos/asf/spark/blob/b459aa77/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 66ae22e..934bad8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -57,6 +57,7 @@ abstract class NamedExpression extends Expression { abstract class Attribute extends NamedExpression { self: Product => + def withNullability(newNullability: Boolean): Attribute def withQualifiers(newQualifiers: Seq[String]): Attribute def toAttribute = this @@ -133,7 +134,7 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea /** * Returns a copy of this [[AttributeReference]] with changed nullability. 
*/ - def withNullability(newNullability: Boolean) = { + override def withNullability(newNullability: Boolean) = { if (nullable == newNullability) { this } else { http://git-wip-us.apache.org/repos/asf/spark/blob/b459aa77/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 3e06398..b51a02d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.{LeftSemi, JoinType} +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.types._ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode { @@ -46,10 +46,16 @@ case class Generate( child: LogicalPlan) extends UnaryNode { - protected def generatorOutput: Seq[Attribute] = - alias + protected def generatorOutput: Seq[Attribute] = { + val output = alias .map(a => generator.output.map(_.withQualifiers(a :: Nil))) .getOrElse(generator.output) + if (join && outer) { + output.map(_.withNullability(true)) + } else { + output + } + } override def output = if (join) child.output ++ generatorOutput else generatorOutput @@ -81,11 +87,20 @@ case class Join( condition: Option[Expression]) extends BinaryNode { override def references = condition.map(_.references).getOrElse(Set.empty) - override def output = joinType match { - case LeftSemi => - left.output - case _ => - left.output ++ right.output + + override def output = { + joinType match { + case 
LeftSemi => + left.output + case LeftOuter => + left.output ++ right.output.map(_.withNullability(true)) + case RightOuter => + left.output.map(_.withNullability(true)) ++ right.output + case FullOuter => + left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) + case _ => + left.output ++ right.output + } } } http://git-wip-us.apache.org/repos/asf/spark/blob/b459aa77/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala index d85d2d7..c1ced8b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala @@ -83,8 +83,8 @@ case class Aggregate( case a: AggregateExpression => ComputedAggregate( a, - BindReferences.bindReference(a, childOutput).asInstanceOf[AggregateExpression], - AttributeReference(s"aggResult:$a", a.dataType, nullable = true)()) + BindReferences.bindReference(a, childOutput), + AttributeReference(s"aggResult:$a", a.dataType, a.nullable)()) } }.toArray http://git-wip-us.apache.org/repos/asf/spark/blob/b459aa77/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala index da1e08b..47b3d00 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.catalyst.expressions.{Generator, JoinedRow, Literal, Projection} +import 
org.apache.spark.sql.catalyst.expressions._ /** * :: DeveloperApi :: @@ -39,8 +39,16 @@ case class Generate( child: SparkPlan) extends UnaryNode { + protected def generatorOutput: Seq[Attribute] = { + if (join && outer) { + generator.output.map(_.withNullability(true)) + } else { + generator.output + } + } + override def output = - if (join) child.output ++ generator.output else generator.output + if (join) child.output ++ generatorOutput else generatorOutput override def execute() = { if (join) { http://git-wip-us.apache.org/repos/asf/spark/blob/b459aa77/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala index 84bdde3..4797cd7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala @@ -271,7 +271,18 @@ case class BroadcastNestedLoopJoin( override def otherCopyArgs = sqlContext :: Nil - def output = left.output ++ right.output + override def output = { + joinType match { + case LeftOuter => + left.output ++ right.output.map(_.withNullability(true)) + case RightOuter => + left.output.map(_.withNullability(true)) ++ right.output + case FullOuter => + left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) + case _ => + left.output ++ right.output + } + } /** The Streamed Relation */ def left = streamed
