This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 7021588  [SPARK-28306][SQL] Make NormalizeFloatingNumbers rule idempotent
7021588 is described below

commit 7021588ba8365b2018f9f7b5d3b05285886cd0be
Author: Yesheng Ma <[email protected]>
AuthorDate: Thu Jul 11 10:22:00 2019 +0800

    [SPARK-28306][SQL] Make NormalizeFloatingNumbers rule idempotent
    
    ## What changes were proposed in this pull request?
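    The optimizer rule `NormalizeFloatingNumbers` is not idempotent: every run
    re-normalizes expressions that an earlier run already normalized, stacking
    additional `NormalizeNaNAndZero` and `ArrayTransform` expression nodes on
    top of them. This patch fixes the non-idempotence by wrapping normalized
    expressions in a marker (tagging) expression, `KnownFloatingPointNormalized`,
    which the rule recognizes and skips on subsequent runs. It also adds the
    missing unit tests for `NormalizeFloatingNumbers`.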
    
    ## How was this patch tested?
    New unit tests in `NormalizeFloatingPointNumbersSuite`.
    
    Closes #25080 from yeshengm/spark-28306.
    
    Authored-by: Yesheng Ma <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../expressions/constraintExpressions.scala        | 18 +++--
 .../optimizer/NormalizeFloatingNumbers.scala       | 20 ++++--
 .../NormalizeFloatingPointNumbersSuite.scala       | 82 ++++++++++++++++++++++
 3 files changed, 108 insertions(+), 12 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala
index 2917b0b..5bfae7b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala
@@ -21,15 +21,21 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, 
ExprCode, FalseLiteral}
 import org.apache.spark.sql.types.DataType
 
-case class KnownNotNull(child: Expression) extends UnaryExpression {
-  override def nullable: Boolean = false
+trait TaggingExpression extends UnaryExpression {
+  override def nullable: Boolean = child.nullable
   override def dataType: DataType = child.dataType
 
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): 
ExprCode = child.genCode(ctx)
+
+  override def eval(input: InternalRow): Any = child.eval(input)
+}
+
+case class KnownNotNull(child: Expression) extends TaggingExpression {
+  override def nullable: Boolean = false
+
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): 
ExprCode = {
     child.genCode(ctx).copy(isNull = FalseLiteral)
   }
-
-  override def eval(input: InternalRow): Any = {
-    child.eval(input)
-  }
 }
+
+case class KnownFloatingPointNormalized(child: Expression) extends 
TaggingExpression
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
index a5921eb..b036092 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
-import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, 
CreateArray, CreateMap, CreateNamedStruct, CreateNamedStructUnsafe, 
CreateStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, 
LambdaFunction, NamedLambdaVariable, UnaryExpression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, 
CreateArray, CreateMap, CreateNamedStruct, CreateNamedStructUnsafe, 
CreateStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, 
KnownFloatingPointNormalized, LambdaFunction, NamedLambdaVariable, 
UnaryExpression}
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, 
ExprCode}
 import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery, 
Window}
@@ -61,7 +61,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
     case _: Subquery => plan
 
     case _ => plan transform {
-      case w: Window if w.partitionSpec.exists(p => needNormalize(p.dataType)) 
=>
+      case w: Window if w.partitionSpec.exists(p => needNormalize(p)) =>
         // Although the `windowExpressions` may refer to `partitionSpec` 
expressions, we don't need
         // to normalize the `windowExpressions`, as they are executed per 
input row and should take
         // the input row as it is.
@@ -73,7 +73,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
       case j @ ExtractEquiJoinKeys(_, leftKeys, rightKeys, condition, _, _, _)
           // The analyzer guarantees left and right joins keys are of the same 
data type. Here we
           // only need to check join keys of one side.
-          if leftKeys.exists(k => needNormalize(k.dataType)) =>
+          if leftKeys.exists(k => needNormalize(k)) =>
         val newLeftJoinKeys = leftKeys.map(normalize)
         val newRightJoinKeys = rightKeys.map(normalize)
         val newConditions = newLeftJoinKeys.zip(newRightJoinKeys).map {
@@ -87,6 +87,14 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
     }
   }
 
+  /**
+   * Short circuit if the underlying expression is already normalized
+   */
+  private def needNormalize(expr: Expression): Boolean = expr match {
+    case KnownFloatingPointNormalized(_) => false
+    case _ => needNormalize(expr.dataType)
+  }
+
   private def needNormalize(dt: DataType): Boolean = dt match {
     case FloatType | DoubleType => true
     case StructType(fields) => fields.exists(f => needNormalize(f.dataType))
@@ -98,7 +106,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
   }
 
   private[sql] def normalize(expr: Expression): Expression = expr match {
-    case _ if !needNormalize(expr.dataType) => expr
+    case _ if !needNormalize(expr) => expr
 
     case a: Alias =>
       a.withNewChildren(Seq(normalize(a.child)))
@@ -116,7 +124,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
       CreateMap(children.map(normalize))
 
     case _ if expr.dataType == FloatType || expr.dataType == DoubleType =>
-      NormalizeNaNAndZero(expr)
+      KnownFloatingPointNormalized(NormalizeNaNAndZero(expr))
 
     case _ if expr.dataType.isInstanceOf[StructType] =>
       val fields = expr.dataType.asInstanceOf[StructType].fields.indices.map { 
i =>
@@ -128,7 +136,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
       val ArrayType(et, containsNull) = expr.dataType
       val lv = NamedLambdaVariable("arg", et, containsNull)
       val function = normalize(lv)
-      ArrayTransform(expr, LambdaFunction(function, Seq(lv)))
+      KnownFloatingPointNormalized(ArrayTransform(expr, 
LambdaFunction(function, Seq(lv))))
 
     case _ => throw new IllegalStateException(s"fail to normalize $expr")
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingPointNumbersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingPointNumbersSuite.scala
new file mode 100644
index 0000000..5f616da
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingPointNumbersSuite.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.KnownFloatingPointNormalized
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+
+class NormalizeFloatingPointNumbersSuite extends PlanTest {
+
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches = Batch("NormalizeFloatingPointNumbers", Once, 
NormalizeFloatingNumbers) :: Nil
+  }
+
+  val testRelation1 = LocalRelation('a.double)
+  val a = testRelation1.output(0)
+  val testRelation2 = LocalRelation('a.double)
+  val b = testRelation2.output(0)
+
+  test("normalize floating points in window function expressions") {
+    val query = testRelation1.window(Seq(sum(a).as("sum")), Seq(a), Seq(a.asc))
+
+    val optimized = Optimize.execute(query)
+    val correctAnswer = testRelation1.window(Seq(sum(a).as("sum")),
+      Seq(KnownFloatingPointNormalized(NormalizeNaNAndZero(a))), Seq(a.asc))
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("normalize floating points in window function expressions - 
idempotence") {
+    val query = testRelation1.window(Seq(sum(a).as("sum")), Seq(a), Seq(a.asc))
+
+    val optimized = Optimize.execute(query)
+    val doubleOptimized = Optimize.execute(optimized)
+    val correctAnswer = testRelation1.window(Seq(sum(a).as("sum")),
+      Seq(KnownFloatingPointNormalized(NormalizeNaNAndZero(a))), Seq(a.asc))
+
+    comparePlans(doubleOptimized, correctAnswer)
+  }
+
+  test("normalize floating points in join keys") {
+    val query = testRelation1.join(testRelation2, condition = Some(a === b))
+
+    val optimized = Optimize.execute(query)
+    val joinCond = Some(KnownFloatingPointNormalized(NormalizeNaNAndZero(a))
+        === KnownFloatingPointNormalized(NormalizeNaNAndZero(b)))
+    val correctAnswer = testRelation1.join(testRelation2, condition = joinCond)
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("normalize floating points in join keys - idempotence") {
+    val query = testRelation1.join(testRelation2, condition = Some(a === b))
+
+    val optimized = Optimize.execute(query)
+    val doubleOptimized = Optimize.execute(optimized)
+    val joinCond = Some(KnownFloatingPointNormalized(NormalizeNaNAndZero(a))
+      === KnownFloatingPointNormalized(NormalizeNaNAndZero(b)))
+    val correctAnswer = testRelation1.join(testRelation2, condition = joinCond)
+
+    comparePlans(doubleOptimized, correctAnswer)
+  }
+}
+


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to