This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new e2796d2  [SPARK-38227][SQL][SS] Apply strict nullability of nested column in time window / session window
e2796d2 is described below

commit e2796d2d2f9069119b273fa9d7a777eca87fa015
Author: Jungtaek Lim <kabhwan.opensou...@gmail.com>
AuthorDate: Sun Feb 20 21:44:58 2022 -0800

    [SPARK-38227][SQL][SS] Apply strict nullability of nested column in time window / session window
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to apply strict nullability to the nested columns of the
    window struct for both time window and session window, so that the
    replacement respects the dataType of TimeWindow and SessionWindow.
    
    ### Why are the changes needed?
    
    The implementations of the TimeWindowing and SessionWindowing rules have
    exposed a possible inconsistency between the dataType of
    TimeWindow/SessionWindow and the expression that replaces it. For the
    replacement, the analyzer/optimizer may decide that the value expressions
    are non-nullable, even though the dataType of TimeWindow/SessionWindow
    defines the 'start' and 'end' fields as nullable.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    New tests added.
    
    Closes #35543 from HeartSaVioR/SPARK-38227.
    
    Authored-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
    Signed-off-by: Liang-Chi Hsieh <vii...@gmail.com>
---
 .../spark/sql/catalyst/analysis/Analyzer.scala     | 14 +++--
 .../apache/spark/sql/catalyst/dsl/package.scala    |  8 +++
 .../expressions/constraintExpressions.scala        | 11 ++++
 .../spark/sql/DataFrameSessionWindowingSuite.scala | 61 ++++++++++++++++++++++
 .../spark/sql/DataFrameTimeWindowingSuite.scala    | 48 +++++++++++++++++
 5 files changed, 138 insertions(+), 4 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index c560062..5cb5f21 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -3893,11 +3893,13 @@ object TimeWindowing extends Rule[LogicalPlan] {
           val windowStart = lastStart - i * window.slideDuration
           val windowEnd = windowStart + window.windowDuration
 
+          // We make sure value fields are nullable since the dataType of 
TimeWindow defines them
+          // as nullable.
           CreateNamedStruct(
             Literal(WINDOW_START) ::
-              PreciseTimestampConversion(windowStart, LongType, dataType) ::
+              PreciseTimestampConversion(windowStart, LongType, 
dataType).castNullable() ::
               Literal(WINDOW_END) ::
-              PreciseTimestampConversion(windowEnd, LongType, dataType) ::
+              PreciseTimestampConversion(windowEnd, LongType, 
dataType).castNullable() ::
               Nil)
         }
 
@@ -4012,11 +4014,15 @@ object SessionWindowing extends Rule[LogicalPlan] {
         val sessionEnd = PreciseTimestampConversion(session.timeColumn + 
gapDuration,
           session.timeColumn.dataType, LongType)
 
+        // We make sure value fields are nullable since the dataType of 
SessionWindow defines them
+        // as nullable.
         val literalSessionStruct = CreateNamedStruct(
           Literal(SESSION_START) ::
-            PreciseTimestampConversion(sessionStart, LongType, 
session.timeColumn.dataType) ::
+            PreciseTimestampConversion(sessionStart, LongType, 
session.timeColumn.dataType)
+              .castNullable() ::
             Literal(SESSION_END) ::
-            PreciseTimestampConversion(sessionEnd, LongType, 
session.timeColumn.dataType) ::
+            PreciseTimestampConversion(sessionEnd, LongType, 
session.timeColumn.dataType)
+              .castNullable() ::
             Nil)
 
         val sessionStruct = Alias(literalSessionStruct, SESSION_COL_NAME)(
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index dda0d19..0988bef 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -138,6 +138,14 @@ package object dsl {
       }
     }
 
+    def castNullable(): Expression = {
+      if (expr.resolved && expr.nullable) {
+        expr
+      } else {
+        KnownNullable(expr)
+      }
+    }
+
     def asc: SortOrder = SortOrder(expr, Ascending)
     def asc_nullsLast: SortOrder = SortOrder(expr, Ascending, NullsLast, 
Seq.empty)
     def desc: SortOrder = SortOrder(expr, Descending)
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala
index 8feaf52..75d9126 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala
@@ -30,6 +30,17 @@ trait TaggingExpression extends UnaryExpression {
   override def eval(input: InternalRow): Any = child.eval(input)
 }
 
+case class KnownNullable(child: Expression) extends TaggingExpression {
+  override def nullable: Boolean = true
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): 
ExprCode = {
+    child.genCode(ctx)
+  }
+
+  override protected def withNewChildInternal(newChild: Expression): 
KnownNullable =
+    copy(child = newChild)
+}
+
 case class KnownNotNull(child: Expression) extends TaggingExpression {
   override def nullable: Boolean = false
 
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala
index b3d2127..076b64c 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala
@@ -21,6 +21,7 @@ import java.time.LocalDateTime
 
 import org.scalatest.BeforeAndAfterEach
 
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand}
 import org.apache.spark.sql.functions._
@@ -406,4 +407,64 @@ class DataFrameSessionWindowingSuite extends QueryTest 
with SharedSparkSession
 
     checkAnswer(aggDF, Seq(Row("2016-03-27 19:39:25", "2016-03-27 19:39:40", 
2)))
   }
+
+  test("SPARK-38227: 'start' and 'end' fields should be nullable") {
+    // We expect the fields in window struct as nullable since the dataType of 
SessionWindow
+    // defines them as nullable. The rule 'SessionWindowing' should respect 
the dataType.
+    val df1 = Seq(
+      ("hello", "2016-03-27 09:00:05", 1),
+      ("structured", "2016-03-27 09:00:32", 2)).toDF("id", "time", "value")
+    val df2 = Seq(
+      ("world", LocalDateTime.parse("2016-03-27T09:00:05"), 1),
+      ("spark", LocalDateTime.parse("2016-03-27T09:00:32"), 2)).toDF("id", 
"time", "value")
+
+    val udf = spark.udf.register("gapDuration", (s: String) => {
+      if (s == "hello") {
+        "1 second"
+      } else if (s == "structured") {
+        // zero gap duration will be filtered out from aggregation
+        "0 second"
+      } else if (s == "world") {
+        // negative gap duration will be filtered out from aggregation
+        "-10 seconds"
+      } else {
+        "10 seconds"
+      }
+    })
+
+    def validateWindowColumnInSchema(schema: StructType, colName: String): 
Unit = {
+      schema.find(_.name == colName) match {
+        case Some(StructField(_, st: StructType, _, _)) =>
+          assertFieldInWindowStruct(st, "start")
+          assertFieldInWindowStruct(st, "end")
+
+        case _ => fail("Failed to find suitable window column from DataFrame!")
+      }
+    }
+
+    def assertFieldInWindowStruct(windowType: StructType, fieldName: String): 
Unit = {
+      val field = windowType.fields.find(_.name == fieldName)
+      assert(field.isDefined, s"'$fieldName' field should exist in window 
struct")
+      assert(field.get.nullable, s"'$fieldName' field should be nullable")
+    }
+
+    for {
+      df <- Seq(df1, df2)
+      nullable <- Seq(true, false)
+    } {
+      val dfWithDesiredNullability = new DataFrame(df.queryExecution, 
RowEncoder(
+        StructType(df.schema.fields.map(_.copy(nullable = nullable)))))
+      // session window without dynamic gap
+      val windowedProject = dfWithDesiredNullability
+        .select(session_window($"time", "10 seconds").as("session"), $"value")
+      val schema = windowedProject.queryExecution.optimizedPlan.schema
+      validateWindowColumnInSchema(schema, "session")
+
+      // session window with dynamic gap
+      val windowedProject2 = dfWithDesiredNullability
+        .select(session_window($"time", udf($"id")).as("session"), $"value")
+      val schema2 = windowedProject2.queryExecution.optimizedPlan.schema
+      validateWindowColumnInSchema(schema2, "session")
+    }
+  }
 }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala
index e9a145c..bd39453 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql
 
 import java.time.LocalDateTime
 
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, Filter}
 import org.apache.spark.sql.functions._
@@ -527,4 +528,51 @@ class DataFrameTimeWindowingSuite extends QueryTest with 
SharedSparkSession {
         "when windowDuration is multiple of slideDuration")
     }
   }
+
+  test("SPARK-38227: 'start' and 'end' fields should be nullable") {
+    // We expect the fields in window struct as nullable since the dataType of 
TimeWindow defines
+    // them as nullable. The rule 'TimeWindowing' should respect the dataType.
+    val df1 = Seq(
+      ("2016-03-27 09:00:05", 1),
+      ("2016-03-27 09:00:32", 2)).toDF("time", "value")
+    val df2 = Seq(
+      (LocalDateTime.parse("2016-03-27T09:00:05"), 1),
+      (LocalDateTime.parse("2016-03-27T09:00:32"), 2)).toDF("time", "value")
+
+    def validateWindowColumnInSchema(schema: StructType, colName: String): 
Unit = {
+      schema.find(_.name == colName) match {
+        case Some(StructField(_, st: StructType, _, _)) =>
+          assertFieldInWindowStruct(st, "start")
+          assertFieldInWindowStruct(st, "end")
+
+        case _ => fail("Failed to find suitable window column from DataFrame!")
+      }
+    }
+
+    def assertFieldInWindowStruct(windowType: StructType, fieldName: String): 
Unit = {
+      val field = windowType.fields.find(_.name == fieldName)
+      assert(field.isDefined, s"'$fieldName' field should exist in window 
struct")
+      assert(field.get.nullable, s"'$fieldName' field should be nullable")
+    }
+
+    for {
+      df <- Seq(df1, df2)
+      nullable <- Seq(true, false)
+    } {
+      val dfWithDesiredNullability = new DataFrame(df.queryExecution, 
RowEncoder(
+        StructType(df.schema.fields.map(_.copy(nullable = nullable)))))
+      // tumbling windows
+      val windowedProject = dfWithDesiredNullability
+        .select(window($"time", "10 seconds").as("window"), $"value")
+      val schema = windowedProject.queryExecution.optimizedPlan.schema
+      validateWindowColumnInSchema(schema, "window")
+
+      // sliding windows
+      val windowedProject2 = dfWithDesiredNullability
+        .select(window($"time", "10 seconds", "3 seconds").as("window"),
+        $"value")
+      val schema2 = windowedProject2.queryExecution.optimizedPlan.schema
+      validateWindowColumnInSchema(schema2, "window")
+    }
+  }
 }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to