This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new edd6076699c [SPARK-40670][SS][PYTHON] Fix NPE in applyInPandasWithState when the input schema has "non-nullable" column(s)
edd6076699c is described below

commit edd6076699c36a94c1bc1b9ca853f05e55ba9f2c
Author: Jungtaek Lim <kabhwan.opensou...@gmail.com>
AuthorDate: Thu Oct 6 15:17:58 2022 +0900

    [SPARK-40670][SS][PYTHON] Fix NPE in applyInPandasWithState when the input schema has "non-nullable" column(s)

    ### What changes were proposed in this pull request?

    This PR fixes a bug that causes an NPE when the input schema of applyInPandasWithState has "non-nullable" column(s). This PR also leaves a code comment explaining the fix. Quoting:

    ```
    // See processTimedOutState: we create a row which contains the actual values for the grouping
    // key, but intentionally all nulls on the value side. This effectively makes the input schema
    // "nullable", so the schema information and the internal row projection should take this into
    // account. Strictly speaking, this does not apply to the grouping key part, but it does not
    // hurt much to apply the same relaxation to the grouping key as well.
    ```

    ### Why are the changes needed?

    There was a bug where non-nullable columns were not taken into account; this PR fixes it.

    ### Does this PR introduce _any_ user-facing change?

    No.

    ### How was this patch tested?

    New UT. The new test case failed with an NPE without the fix and succeeded with it.

    Closes #38115 from HeartSaVioR/SPARK-40670.

    Authored-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 .../FlatMapGroupsInPandasWithStateExec.scala  | 16 +++-
 .../FlatMapGroupsInPandasWithStateSuite.scala | 87 +++++++++++++++++++++-
 2 files changed, 99 insertions(+), 4 deletions(-)
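[Editor's note] To illustrate the failure mode the patch addresses, here is a minimal standalone sketch (not part of the commit) that mimics the timed-out state path: the real key row is joined with an all-null value row, so the projection must be built from nullable attributes. The object and attribute names are hypothetical, and it assumes spark-catalyst on the classpath.

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericInternalRow, JoinedRow, UnsafeProjection}
import org.apache.spark.sql.types.StringType
import org.apache.spark.unsafe.types.UTF8String

object NullableProjectionSketch extends App {
  // Hypothetical stand-ins for dedupAttributes: a grouping key and a value column,
  // both declared non-nullable (as produced by, e.g., lit("__FAKE__")).
  val attrs = Seq(
    AttributeReference("key1", StringType, nullable = false)(),
    AttributeReference("val1", StringType, nullable = false)())

  // The fix: relax nullability before building the projection, because the
  // timed-out state path feeds an all-null value row through it.
  val attrsWithNull = attrs.map(_.withNullability(newNullability = true))
  val proj = UnsafeProjection.create(attrsWithNull, attrsWithNull)

  // Mimic processTimedOutState: real key values joined with an all-null value row.
  val keyRow = InternalRow(UTF8String.fromString("a"))
  val nullValueRow = new GenericInternalRow(Array.fill[Any](1)(null))
  val joined = proj(new JoinedRow(keyRow, nullValueRow))

  // With nullable attributes the null value column projects cleanly; building the
  // projection from the original non-nullable attributes is where the NPE surfaced.
  assert(joined.isNullAt(1))
}
```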
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala
index 159f805f734..09123344c2e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala
@@ -83,7 +83,17 @@ case class FlatMapGroupsInPandasWithStateExec(
   private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pythonFunction)))
   private lazy val (dedupAttributes, argOffsets) = resolveArgOffsets(
     groupingAttributes ++ child.output, groupingAttributes)
-  private lazy val unsafeProj = UnsafeProjection.create(dedupAttributes, child.output)
+
+  // See processTimedOutState: we create a row which contains the actual values for the grouping
+  // key, but intentionally all nulls on the value side. This effectively makes the input schema
+  // "nullable", so the schema information and the internal row projection should take this into
+  // account. Strictly speaking, this does not apply to the grouping key part, but it does not
+  // hurt much to apply the same relaxation to the grouping key as well.
+  private lazy val dedupAttributesWithNull =
+    dedupAttributes.map(_.withNullability(newNullability = true))
+  private lazy val childOutputWithNull = child.output.map(_.withNullability(newNullability = true))
+  private lazy val unsafeProj = UnsafeProjection.create(dedupAttributesWithNull,
+    childOutputWithNull)
 
   override def requiredChildDistribution: Seq[Distribution] =
     StatefulOperatorPartitioning.getCompatibleDistribution(
@@ -134,7 +144,7 @@ case class FlatMapGroupsInPandasWithStateExec(
       val joinedKeyRow = unsafeProj(
         new JoinedRow(
           stateData.keyRow,
-          new GenericInternalRow(Array.fill(dedupAttributes.length)(null: Any))))
+          new GenericInternalRow(Array.fill(dedupAttributesWithNull.length)(null: Any))))
 
       (stateData.keyRow, stateData, Iterator.single(joinedKeyRow))
     }
@@ -150,7 +160,7 @@ case class FlatMapGroupsInPandasWithStateExec(
       chainedFunc,
       PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
       Array(argOffsets),
-      StructType.fromAttributes(dedupAttributes),
+      StructType.fromAttributes(dedupAttributesWithNull),
       sessionLocalTimeZone,
       pythonRunnerConf,
       stateEncoder.asInstanceOf[ExpressionEncoder[Row]],
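[Editor's note] On the last hunk above: the schema handed to the Arrow-based Python runner is derived from the attributes, so the nullability relaxation has to happen before StructType.fromAttributes is called. A small sketch (hypothetical object and attribute names, assumes spark-catalyst on the classpath):

```scala
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.types.{StringType, StructType}

object SchemaNullabilitySketch extends App {
  // Hypothetical non-nullable attribute, standing in for one of dedupAttributes.
  val attrs = Seq(AttributeReference("countAsStr", StringType, nullable = false)())

  // fromAttributes carries each attribute's nullability into the resulting schema.
  val before = StructType.fromAttributes(attrs)
  val after = StructType.fromAttributes(attrs.map(_.withNullability(newNullability = true)))

  assert(!before("countAsStr").nullable) // non-nullable without the relaxation
  assert(after("countAsStr").nullable)   // nullable with it
}
```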
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala
index d8f7aeb5ac8..4d62ccd1423 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{NoTimeout, ProcessingTimeTim
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Complete, Update}
 import org.apache.spark.sql.execution.python.FlatMapGroupsInPandasWithStateExec
 import org.apache.spark.sql.execution.streaming.MemoryStream
-import org.apache.spark.sql.functions.timestamp_seconds
+import org.apache.spark.sql.functions.{lit, timestamp_seconds}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.util.StreamManualClock
 import org.apache.spark.sql.types._
@@ -738,4 +738,89 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest {
       }
     }
   }
+
+  test("SPARK-40670: applyInPandasWithState - streaming having non-null columns") {
+    // scalastyle:off assume
+    assume(shouldTestPandasUDFs)
+    // scalastyle:on assume
+
+    // Function to maintain the count as state and set the processing-time timeout delay to
+    // 10 seconds. It returns the count if changed, or -1 if the state was removed by timeout.
+    val pythonScript =
+      """
+        |import pandas as pd
+        |from pyspark.sql.types import StructType, StructField, StringType
+        |
+        |tpe = StructType([
+        |    StructField("key1", StringType()),
+        |    StructField("key2", StringType()),
+        |    StructField("countAsStr", StringType())])
+        |
+        |def func(key, pdf_iter, state):
+        |    ret = None
+        |    if state.hasTimedOut:
+        |        state.remove()
+        |        yield pd.DataFrame({'key1': [key[0]], 'key2': [key[1]], 'countAsStr': [str(-1)]})
+        |    else:
+        |        count = state.getOption
+        |        if count is None:
+        |            count = 0
+        |        else:
+        |            count = count[0]
+        |
+        |        for pdf in pdf_iter:
+        |            count += len(pdf)
+        |
+        |        state.update((count,))
+        |        state.setTimeoutDuration(10000)
+        |        yield pd.DataFrame({'key1': [key[0]], 'key2': [key[1]], 'countAsStr': [str(count)]})
+        |""".stripMargin
+    val pythonFunc = TestGroupedMapPandasUDFWithState(
+      name = "pandas_grouped_map_with_state", pythonScript = pythonScript)
+
+    val clock = new StreamManualClock
+    val inputData = MemoryStream[String]
+    val inputDataDS = inputData.toDS
+      .withColumnRenamed("value", "key1")
+      // columns created from string literals are non-nullable
+      .withColumn("key2", lit("__FAKE__"))
+      .withColumn("val1", lit("__FAKE__"))
+      .withColumn("val2", lit("__FAKE__"))
+    val outputStructType = StructType(
+      Seq(
+        StructField("key1", StringType),
+        StructField("key2", StringType),
+        StructField("countAsStr", StringType)))
+    val stateStructType = StructType(Seq(StructField("count", LongType)))
+    val result =
+      inputDataDS
+        .groupBy("key1", "key2")
+        .applyInPandasWithState(
+          pythonFunc(
+            inputDataDS("key1"), inputDataDS("key2"), inputDataDS("val1"), inputDataDS("val2")
+          ).expr.asInstanceOf[PythonUDF],
+          outputStructType,
+          stateStructType,
+          "Update",
+          "ProcessingTimeTimeout")
+
+    testStream(result, Update)(
+      StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock),
+      AddData(inputData, "a"),
+      AdvanceManualClock(1 * 1000),
+      CheckNewAnswer(("a", "__FAKE__", "1")),
+      assertNumStateRows(total = 1, updated = 1),
+
+      AddData(inputData, "b"),
+      AdvanceManualClock(1 * 1000),
+      CheckNewAnswer(("b", "__FAKE__", "1")),
+      assertNumStateRows(total = 2, updated = 1),
+
+      AddData(inputData, "b"),
+      AdvanceManualClock(10 * 1000),
+      CheckNewAnswer(("a", "__FAKE__", "-1"), ("b", "__FAKE__", "2")),
+      assertNumStateRows(
+        total = Seq(1), updated = Seq(1), droppedByWatermark = Seq(0), removed = Some(Seq(1)))
+    )
+  }
 }
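[Editor's note] For context on the comment in the test above: a tiny standalone check (not part of the patch; the object name is hypothetical) of the assumption that literal-valued columns come out non-nullable, which is what steers this test into the fixed code path.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.lit

object LitNullabilityCheck extends App {
  val spark = SparkSession.builder().master("local[1]").appName("lit-nullability").getOrCreate()
  import spark.implicits._

  // A column created from a string literal is declared non-nullable in the schema,
  // unlike the plain string column it sits next to.
  val df = Seq("a").toDF("key1").withColumn("key2", lit("__FAKE__"))
  assert(df.schema("key1").nullable)
  assert(!df.schema("key2").nullable)

  spark.stop()
}
```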