This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new edd6076699c [SPARK-40670][SS][PYTHON] Fix NPE in 
applyInPandasWithState when the input schema has "non-nullable" column(s)
edd6076699c is described below

commit edd6076699c36a94c1bc1b9ca853f05e55ba9f2c
Author: Jungtaek Lim <kabhwan.opensou...@gmail.com>
AuthorDate: Thu Oct 6 15:17:58 2022 +0900

    [SPARK-40670][SS][PYTHON] Fix NPE in applyInPandasWithState when the input 
schema has "non-nullable" column(s)
    
    ### What changes were proposed in this pull request?
    
    This PR fixes a bug that causes an NPE when the input schema of applyInPandasWithState has "non-nullable" column(s).
    This PR also adds a code comment explaining the fix. Quoting:
    
    ```
      // See processTimedOutState: we create a row which contains the actual values for grouping key,
      // but all nulls for value side by intention. This technically changes the schema of input to
      // be "nullable", hence the schema information and the internal projection of row should take
      // this into consideration. Strictly saying, it's not applied to the part of grouping key, but
      // it doesn't hurt much even if we apply the same for grouping key as well.
    ```
    
    ### Why are the changes needed?
    
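    The operator did not take non-nullable input columns into account: when handling timed-out state it builds a row whose value columns are all null, which contradicts a schema that declares those columns non-nullable and results in an NPE. This PR fixes the bug.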
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    New unit test. The new test case fails with an NPE without the fix and passes with it.
    
    Closes #38115 from HeartSaVioR/SPARK-40670.
    
    Authored-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 .../FlatMapGroupsInPandasWithStateExec.scala       | 16 +++-
 .../FlatMapGroupsInPandasWithStateSuite.scala      | 87 +++++++++++++++++++++-
 2 files changed, 99 insertions(+), 4 deletions(-)

diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala
index 159f805f734..09123344c2e 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala
@@ -83,7 +83,17 @@ case class FlatMapGroupsInPandasWithStateExec(
   private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pythonFunction)))
   private lazy val (dedupAttributes, argOffsets) = resolveArgOffsets(
     groupingAttributes ++ child.output, groupingAttributes)
-  private lazy val unsafeProj = UnsafeProjection.create(dedupAttributes, 
child.output)
+
+  // See processTimedOutState: we create a row which contains the actual 
values for grouping key,
+  // but all nulls for value side by intention. This technically changes the 
schema of input to
+  // be "nullable", hence the schema information and the internal projection 
of row should take
+  // this into consideration. Strictly saying, it's not applied to the part of 
grouping key, but
+  // it doesn't hurt much even if we apply the same for grouping key as well.
+  private lazy val dedupAttributesWithNull =
+    dedupAttributes.map(_.withNullability(newNullability = true))
+  private lazy val childOutputWithNull = 
child.output.map(_.withNullability(newNullability = true))
+  private lazy val unsafeProj = 
UnsafeProjection.create(dedupAttributesWithNull,
+    childOutputWithNull)
 
   override def requiredChildDistribution: Seq[Distribution] =
     StatefulOperatorPartitioning.getCompatibleDistribution(
@@ -134,7 +144,7 @@ case class FlatMapGroupsInPandasWithStateExec(
           val joinedKeyRow = unsafeProj(
             new JoinedRow(
               stateData.keyRow,
-              new GenericInternalRow(Array.fill(dedupAttributes.length)(null: 
Any))))
+              new 
GenericInternalRow(Array.fill(dedupAttributesWithNull.length)(null: Any))))
 
           (stateData.keyRow, stateData, Iterator.single(joinedKeyRow))
         }
@@ -150,7 +160,7 @@ case class FlatMapGroupsInPandasWithStateExec(
         chainedFunc,
         PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
         Array(argOffsets),
-        StructType.fromAttributes(dedupAttributes),
+        StructType.fromAttributes(dedupAttributesWithNull),
         sessionLocalTimeZone,
         pythonRunnerConf,
         stateEncoder.asInstanceOf[ExpressionEncoder[Row]],
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala
index d8f7aeb5ac8..4d62ccd1423 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala
@@ -24,7 +24,7 @@ import 
org.apache.spark.sql.catalyst.plans.logical.{NoTimeout, ProcessingTimeTim
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Complete, 
Update}
 import org.apache.spark.sql.execution.python.FlatMapGroupsInPandasWithStateExec
 import org.apache.spark.sql.execution.streaming.MemoryStream
-import org.apache.spark.sql.functions.timestamp_seconds
+import org.apache.spark.sql.functions.{lit, timestamp_seconds}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.util.StreamManualClock
 import org.apache.spark.sql.types._
@@ -738,4 +738,89 @@ class FlatMapGroupsInPandasWithStateSuite extends 
StateStoreMetricsTest {
       }
     }
   }
+
+  test("SPARK-40670: applyInPandasWithState - streaming having non-null 
columns") {
+    // scalastyle:off assume
+    assume(shouldTestPandasUDFs)
+    // scalastyle:on assume
+
+    // Function to maintain the count as state and set the proc. time timeout 
delay of 10 seconds.
+    // It returns the count if changed, or -1 if the state was removed by 
timeout.
+    val pythonScript =
+    """
+      |import pandas as pd
+      |from pyspark.sql.types import StructType, StructField, StringType
+      |
+      |tpe = StructType([
+      |    StructField("key1", StringType()),
+      |    StructField("key2", StringType()),
+      |    StructField("countAsStr", StringType())])
+      |
+      |def func(key, pdf_iter, state):
+      |    ret = None
+      |    if state.hasTimedOut:
+      |        state.remove()
+      |        yield pd.DataFrame({'key1': [key[0]], 'key2': [key[1]], 
'countAsStr': [str(-1)]})
+      |    else:
+      |        count = state.getOption
+      |        if count is None:
+      |            count = 0
+      |        else:
+      |            count = count[0]
+      |
+      |        for pdf in pdf_iter:
+      |            count += len(pdf)
+      |
+      |        state.update((count,))
+      |        state.setTimeoutDuration(10000)
+      |        yield pd.DataFrame({'key1': [key[0]], 'key2': [key[1]], 
'countAsStr': [str(count)]})
+      |""".stripMargin
+    val pythonFunc = TestGroupedMapPandasUDFWithState(
+      name = "pandas_grouped_map_with_state", pythonScript = pythonScript)
+
+    val clock = new StreamManualClock
+    val inputData = MemoryStream[String]
+    val inputDataDS = inputData.toDS
+      .withColumnRenamed("value", "key1")
+      // the type of columns with string literal will be non-nullable
+      .withColumn("key2", lit("__FAKE__"))
+      .withColumn("val1", lit("__FAKE__"))
+      .withColumn("val2", lit("__FAKE__"))
+    val outputStructType = StructType(
+      Seq(
+        StructField("key1", StringType),
+        StructField("key2", StringType),
+        StructField("countAsStr", StringType)))
+    val stateStructType = StructType(Seq(StructField("count", LongType)))
+    val result =
+      inputDataDS
+        .groupBy("key1", "key2")
+        .applyInPandasWithState(
+          pythonFunc(
+            inputDataDS("key1"), inputDataDS("key2"), inputDataDS("val1"), 
inputDataDS("val2")
+          ).expr.asInstanceOf[PythonUDF],
+          outputStructType,
+          stateStructType,
+          "Update",
+          "ProcessingTimeTimeout")
+
+    testStream(result, Update)(
+      StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock),
+      AddData(inputData, "a"),
+      AdvanceManualClock(1 * 1000),
+      CheckNewAnswer(("a", "__FAKE__", "1")),
+      assertNumStateRows(total = 1, updated = 1),
+
+      AddData(inputData, "b"),
+      AdvanceManualClock(1 * 1000),
+      CheckNewAnswer(("b", "__FAKE__", "1")),
+      assertNumStateRows(total = 2, updated = 1),
+
+      AddData(inputData, "b"),
+      AdvanceManualClock(10 * 1000),
+      CheckNewAnswer(("a", "__FAKE__", "-1"), ("b", "__FAKE__", "2")),
+      assertNumStateRows(
+        total = Seq(1), updated = Seq(1), droppedByWatermark = Seq(0), removed 
= Some(Seq(1)))
+    )
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to