This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 7021588 [SPARK-28306][SQL] Make NormalizeFloatingNumbers rule idempotent
7021588 is described below
commit 7021588ba8365b2018f9f7b5d3b05285886cd0be
Author: Yesheng Ma <[email protected]>
AuthorDate: Thu Jul 11 10:22:00 2019 +0800
[SPARK-28306][SQL] Make NormalizeFloatingNumbers rule idempotent
## What changes were proposed in this pull request?
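The optimizer rule `NormalizeFloatingNumbers` is not idempotent: each time it
runs, it wraps already-normalized expressions in additional `NormalizeNaNAndZero`
and `ArrayTransform` nodes. This patch fixes that by wrapping every expression the
rule normalizes in a marker expression, `KnownFloatingPointNormalized`. The rule
treats an expression carrying this marker as already normalized and leaves it
unchanged on subsequent runs. The patch also adds the missing unit tests for
`NormalizeFloatingNumbers`.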
## How was this patch tested?
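New unit tests in `NormalizeFloatingPointNumbersSuite`, covering window partition specs and equi-join keys, each run once and twice (to check idempotence).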
Closes #25080 from yeshengm/spark-28306.
Authored-by: Yesheng Ma <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../expressions/constraintExpressions.scala | 18 +++--
.../optimizer/NormalizeFloatingNumbers.scala | 20 ++++--
.../NormalizeFloatingPointNumbersSuite.scala | 82 ++++++++++++++++++++++
3 files changed, 108 insertions(+), 12 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala
index 2917b0b..5bfae7b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala
@@ -21,15 +21,21 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext,
ExprCode, FalseLiteral}
import org.apache.spark.sql.types.DataType
-case class KnownNotNull(child: Expression) extends UnaryExpression {
- override def nullable: Boolean = false
+trait TaggingExpression extends UnaryExpression {
+ override def nullable: Boolean = child.nullable
override def dataType: DataType = child.dataType
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode):
ExprCode = child.genCode(ctx)
+
+ override def eval(input: InternalRow): Any = child.eval(input)
+}
+
+case class KnownNotNull(child: Expression) extends TaggingExpression {
+ override def nullable: Boolean = false
+
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode):
ExprCode = {
child.genCode(ctx).copy(isNull = FalseLiteral)
}
-
- override def eval(input: InternalRow): Any = {
- child.eval(input)
- }
}
+
+case class KnownFloatingPointNormalized(child: Expression) extends
TaggingExpression
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
index a5921eb..b036092 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.optimizer
-import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform,
CreateArray, CreateMap, CreateNamedStruct, CreateNamedStructUnsafe,
CreateStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField,
LambdaFunction, NamedLambdaVariable, UnaryExpression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform,
CreateArray, CreateMap, CreateNamedStruct, CreateNamedStructUnsafe,
CreateStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField,
KnownFloatingPointNormalized, LambdaFunction, NamedLambdaVariable,
UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext,
ExprCode}
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery,
Window}
@@ -61,7 +61,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
case _: Subquery => plan
case _ => plan transform {
- case w: Window if w.partitionSpec.exists(p => needNormalize(p.dataType))
=>
+ case w: Window if w.partitionSpec.exists(p => needNormalize(p)) =>
// Although the `windowExpressions` may refer to `partitionSpec`
expressions, we don't need
// to normalize the `windowExpressions`, as they are executed per
input row and should take
// the input row as it is.
@@ -73,7 +73,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
case j @ ExtractEquiJoinKeys(_, leftKeys, rightKeys, condition, _, _, _)
// The analyzer guarantees left and right joins keys are of the same
data type. Here we
// only need to check join keys of one side.
- if leftKeys.exists(k => needNormalize(k.dataType)) =>
+ if leftKeys.exists(k => needNormalize(k)) =>
val newLeftJoinKeys = leftKeys.map(normalize)
val newRightJoinKeys = rightKeys.map(normalize)
val newConditions = newLeftJoinKeys.zip(newRightJoinKeys).map {
@@ -87,6 +87,14 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
}
}
+ /**
+ * Short circuit if the underlying expression is already normalized
+ */
+ private def needNormalize(expr: Expression): Boolean = expr match {
+ case KnownFloatingPointNormalized(_) => false
+ case _ => needNormalize(expr.dataType)
+ }
+
private def needNormalize(dt: DataType): Boolean = dt match {
case FloatType | DoubleType => true
case StructType(fields) => fields.exists(f => needNormalize(f.dataType))
@@ -98,7 +106,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
}
private[sql] def normalize(expr: Expression): Expression = expr match {
- case _ if !needNormalize(expr.dataType) => expr
+ case _ if !needNormalize(expr) => expr
case a: Alias =>
a.withNewChildren(Seq(normalize(a.child)))
@@ -116,7 +124,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
CreateMap(children.map(normalize))
case _ if expr.dataType == FloatType || expr.dataType == DoubleType =>
- NormalizeNaNAndZero(expr)
+ KnownFloatingPointNormalized(NormalizeNaNAndZero(expr))
case _ if expr.dataType.isInstanceOf[StructType] =>
val fields = expr.dataType.asInstanceOf[StructType].fields.indices.map {
i =>
@@ -128,7 +136,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
val ArrayType(et, containsNull) = expr.dataType
val lv = NamedLambdaVariable("arg", et, containsNull)
val function = normalize(lv)
- ArrayTransform(expr, LambdaFunction(function, Seq(lv)))
+ KnownFloatingPointNormalized(ArrayTransform(expr,
LambdaFunction(function, Seq(lv))))
case _ => throw new IllegalStateException(s"fail to normalize $expr")
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingPointNumbersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingPointNumbersSuite.scala
new file mode 100644
index 0000000..5f616da
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingPointNumbersSuite.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.KnownFloatingPointNormalized
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+
+class NormalizeFloatingPointNumbersSuite extends PlanTest {
+
+ object Optimize extends RuleExecutor[LogicalPlan] {
+ val batches = Batch("NormalizeFloatingPointNumbers", Once,
NormalizeFloatingNumbers) :: Nil
+ }
+
+ val testRelation1 = LocalRelation('a.double)
+ val a = testRelation1.output(0)
+ val testRelation2 = LocalRelation('a.double)
+ val b = testRelation2.output(0)
+
+ test("normalize floating points in window function expressions") {
+ val query = testRelation1.window(Seq(sum(a).as("sum")), Seq(a), Seq(a.asc))
+
+ val optimized = Optimize.execute(query)
+ val correctAnswer = testRelation1.window(Seq(sum(a).as("sum")),
+ Seq(KnownFloatingPointNormalized(NormalizeNaNAndZero(a))), Seq(a.asc))
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("normalize floating points in window function expressions -
idempotence") {
+ val query = testRelation1.window(Seq(sum(a).as("sum")), Seq(a), Seq(a.asc))
+
+ val optimized = Optimize.execute(query)
+ val doubleOptimized = Optimize.execute(optimized)
+ val correctAnswer = testRelation1.window(Seq(sum(a).as("sum")),
+ Seq(KnownFloatingPointNormalized(NormalizeNaNAndZero(a))), Seq(a.asc))
+
+ comparePlans(doubleOptimized, correctAnswer)
+ }
+
+ test("normalize floating points in join keys") {
+ val query = testRelation1.join(testRelation2, condition = Some(a === b))
+
+ val optimized = Optimize.execute(query)
+ val joinCond = Some(KnownFloatingPointNormalized(NormalizeNaNAndZero(a))
+ === KnownFloatingPointNormalized(NormalizeNaNAndZero(b)))
+ val correctAnswer = testRelation1.join(testRelation2, condition = joinCond)
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("normalize floating points in join keys - idempotence") {
+ val query = testRelation1.join(testRelation2, condition = Some(a === b))
+
+ val optimized = Optimize.execute(query)
+ val doubleOptimized = Optimize.execute(optimized)
+ val joinCond = Some(KnownFloatingPointNormalized(NormalizeNaNAndZero(a))
+ === KnownFloatingPointNormalized(NormalizeNaNAndZero(b)))
+ val correctAnswer = testRelation1.join(testRelation2, condition = joinCond)
+
+ comparePlans(doubleOptimized, correctAnswer)
+ }
+}
+
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]