This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new 59fcecb5a59 [SPARK-44464][SS] Fix applyInPandasWithStatePythonRunner 
to output rows that have Null as first column value
59fcecb5a59 is described below

commit 59fcecb5a59df54ecb3c675d4f3722fc72c1466e
Author: Siying Dong <[email protected]>
AuthorDate: Fri Jul 21 14:11:06 2023 +0900

    [SPARK-44464][SS] Fix applyInPandasWithStatePythonRunner to output rows 
that have Null as first column value
    
    Ports back #42046 to 3.4.
    
    ### What changes were proposed in this pull request?
    Change the serialization format for group-by-with-state outputs: include an 
explicit hidden column indicating how many data and state records there are.
    
    ### Why are the changes needed?
    The current implementation of ApplyInPandasWithStatePythonRunner cannot 
deal with outputs where the first column of a row is null, because it cannot 
distinguish a genuinely null column from a padding row that was inserted 
because there are fewer data records than state records. As a result, rows 
whose first column is genuinely null are dropped, producing incorrect results.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
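    Added unit tests that cover the null case and several other scenarios 
(no state, no state and no data, more data rows than state rows, and fewer 
data rows than state rows).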
    
    Closes #42074 from siying/pandas34.
    
    Authored-by: Siying Dong <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/sql/pandas/serializers.py           |  35 +++++--
 .../pandas/test_pandas_grouped_map_with_state.py   | 114 ++++++++++++++++++---
 .../ApplyInPandasWithStatePythonRunner.scala       |  84 +++++++++------
 3 files changed, 180 insertions(+), 53 deletions(-)

diff --git a/python/pyspark/sql/pandas/serializers.py 
b/python/pyspark/sql/pandas/serializers.py
index ca249c75ea5..dd46dc85ab1 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -21,7 +21,7 @@ Serializers for PyArrow and pandas conversions. See 
`pyspark.serializers` for mo
 
 from pyspark.serializers import Serializer, read_int, write_int, 
UTF8Deserializer, CPickleSerializer
 from pyspark.sql.pandas.types import to_arrow_type
-from pyspark.sql.types import StringType, StructType, BinaryType, StructField, 
LongType
+from pyspark.sql.types import StringType, StructType, BinaryType, StructField, 
LongType, IntegerType
 
 
 class SpecialLengths:
@@ -408,6 +408,15 @@ class 
ApplyInPandasWithStateSerializer(ArrowStreamPandasUDFSerializer):
         self.utf8_deserializer = UTF8Deserializer()
         self.state_object_schema = state_object_schema
 
+        self.result_count_df_type = StructType(
+            [
+                StructField("dataCount", IntegerType()),
+                StructField("stateCount", IntegerType()),
+            ]
+        )
+
+        self.result_count_pdf_arrow_type = 
to_arrow_type(self.result_count_df_type)
+
         self.result_state_df_type = StructType(
             [
                 StructField("properties", StringType()),
@@ -604,16 +613,26 @@ class 
ApplyInPandasWithStateSerializer(ArrowStreamPandasUDFSerializer):
         def construct_record_batch(pdfs, pdf_data_cnt, pdf_schema, state_pdfs, 
state_data_cnt):
             """
             Construct a new Arrow RecordBatch based on output pandas 
DataFrames and states. Each
-            one matches to the single struct field for Arrow schema, hence the 
return value of
-            Arrow RecordBatch will have schema with two fields, in `data`, 
`state` order.
+            one matches to the single struct field for Arrow schema. We also 
need an extra one to
+            indicate array length for data and state, so the return value of 
Arrow RecordBatch will
+            have schema with three fields, in `count`, `data`, `state` order.
             (Readers are expected to access the field via position rather than 
the name. We do
             not guarantee the name of the field.)
 
             Note that Arrow RecordBatch requires all columns to have all same 
number of rows,
-            hence this function inserts empty data for state/data with less 
elements to compensate.
+            hence this function inserts empty data for count/state/data with 
less elements to
+            compensate.
             """
 
-            max_data_cnt = max(pdf_data_cnt, state_data_cnt)
+            max_data_cnt = max(1, max(pdf_data_cnt, state_data_cnt))
+
+            # We only use the first row in the count column, and fill other 
rows to be the same
+            # value, hoping it is more friendly for compression, in case it is 
needed.
+            count_dict = {
+                "dataCount": [pdf_data_cnt] * max_data_cnt,
+                "stateCount": [state_data_cnt] * max_data_cnt,
+            }
+            count_pdf = pd.DataFrame.from_dict(count_dict)
 
             empty_row_cnt_in_data = max_data_cnt - pdf_data_cnt
             empty_row_cnt_in_state = max_data_cnt - state_data_cnt
@@ -634,7 +653,11 @@ class 
ApplyInPandasWithStateSerializer(ArrowStreamPandasUDFSerializer):
             merged_state_pdf = pd.concat(state_pdfs, ignore_index=True)
 
             return self._create_batch(
-                [(merged_pdf, pdf_schema), (merged_state_pdf, 
self.result_state_pdf_arrow_type)]
+                [
+                    (count_pdf, self.result_count_pdf_arrow_type),
+                    (merged_pdf, pdf_schema),
+                    (merged_state_pdf, self.result_state_pdf_arrow_type),
+                ]
             )
 
         def serialize_batches():
diff --git 
a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py 
b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py
index 655f0bf151d..e4c2e229387 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py
@@ -60,7 +60,7 @@ class GroupedMapInPandasWithStateTests(ReusedSQLTestCase):
         cfg.set("spark.sql.shuffle.partitions", "5")
         return cfg
 
-    def test_apply_in_pandas_with_state_basic(self):
+    def _test_apply_in_pandas_with_state_basic(self, func, check_results):
         input_path = tempfile.mkdtemp()
 
         def prepare_test_resource():
@@ -81,6 +81,22 @@ class GroupedMapInPandasWithStateTests(ReusedSQLTestCase):
         )
         state_type = StructType([StructField("c", LongType())])
 
+        q = (
+            df.groupBy(df["value"])
+            .applyInPandasWithState(
+                func, output_type, state_type, "Update", 
GroupStateTimeout.NoTimeout
+            )
+            .writeStream.queryName("this_query")
+            .foreachBatch(check_results)
+            .outputMode("update")
+            .start()
+        )
+
+        self.assertEqual(q.name, "this_query")
+        self.assertTrue(q.isActive)
+        q.processAllAvailable()
+
+    def test_apply_in_pandas_with_state_basic(self):
         def func(key, pdf_iter, state):
             assert isinstance(state, GroupState)
 
@@ -98,20 +114,92 @@ class GroupedMapInPandasWithStateTests(ReusedSQLTestCase):
                 {Row(key="hello", countAsString="1"), Row(key="this", 
countAsString="1")},
             )
 
-        q = (
-            df.groupBy(df["value"])
-            .applyInPandasWithState(
-                func, output_type, state_type, "Update", 
GroupStateTimeout.NoTimeout
+        self._test_apply_in_pandas_with_state_basic(func, check_results)
+
+    def test_apply_in_pandas_with_state_basic_no_state(self):
+        def func(key, pdf_iter, state):
+            assert isinstance(state, GroupState)
+            # 2 data rows
+            yield pd.DataFrame({"key": [key[0], "foo"], "countAsString": 
["100", "222"]})
+
+        def check_results(batch_df, _):
+            self.assertEqual(
+                set(batch_df.sort("key").collect()),
+                {
+                    Row(key="hello", countAsString="100"),
+                    Row(key="this", countAsString="100"),
+                    Row(key="foo", countAsString="222"),
+                },
             )
-            .writeStream.queryName("this_query")
-            .foreachBatch(check_results)
-            .outputMode("update")
-            .start()
-        )
 
-        self.assertEqual(q.name, "this_query")
-        self.assertTrue(q.isActive)
-        q.processAllAvailable()
+        self._test_apply_in_pandas_with_state_basic(func, check_results)
+
+    def test_apply_in_pandas_with_state_basic_no_state_no_data(self):
+        def func(key, pdf_iter, state):
+            assert isinstance(state, GroupState)
+            # 2 data rows
+            yield pd.DataFrame({"key": [], "countAsString": []})
+
+        def check_results(batch_df, _):
+            self.assertTrue(len(set(batch_df.sort("key").collect())) == 0)
+
+        self._test_apply_in_pandas_with_state_basic(func, check_results)
+
+    def test_apply_in_pandas_with_state_basic_more_data(self):
+        # Test data rows returned are more or fewer than state.
+        def func(key, pdf_iter, state):
+            state.update((1,))
+            assert isinstance(state, GroupState)
+            # 3 rows
+            yield pd.DataFrame(
+                {"key": [key[0], "foo", key[0] + "_2"], "countAsString": ["1", 
"666", "2"]}
+            )
+
+        def check_results(batch_df, _):
+            self.assertEqual(
+                set(batch_df.sort("key").collect()),
+                {
+                    Row(key="hello", countAsString="1"),
+                    Row(key="foo", countAsString="666"),
+                    Row(key="hello_2", countAsString="2"),
+                    Row(key="this", countAsString="1"),
+                    Row(key="this_2", countAsString="2"),
+                },
+            )
+
+        self._test_apply_in_pandas_with_state_basic(func, check_results)
+
+    def test_apply_in_pandas_with_state_basic_fewer_data(self):
+        # Test data rows returned are more or fewer than state.
+        def func(key, pdf_iter, state):
+            state.update((1,))
+            assert isinstance(state, GroupState)
+            yield pd.DataFrame({"key": [], "countAsString": []})
+
+        def check_results(batch_df, _):
+            self.assertTrue(len(set(batch_df.sort("key").collect())) == 0)
+
+        self._test_apply_in_pandas_with_state_basic(func, check_results)
+
+    def test_apply_in_pandas_with_state_basic_with_null(self):
+        def func(key, pdf_iter, state):
+            assert isinstance(state, GroupState)
+
+            total_len = 0
+            for pdf in pdf_iter:
+                total_len += len(pdf)
+
+            state.update((total_len,))
+            assert state.get[0] == 1
+            yield pd.DataFrame({"key": [None], "countAsString": 
[str(total_len)]})
+
+        def check_results(batch_df, _):
+            self.assertEqual(
+                set(batch_df.sort("key").collect()),
+                {Row(key=None, countAsString="1")},
+            )
+
+        self._test_apply_in_pandas_with_state_basic(func, check_results)
 
     def test_apply_in_pandas_with_state_python_worker_random_failure(self):
         input_path = tempfile.mkdtemp()
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
index f3531668c8e..ea1c64c0919 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.execution.metric.SQLMetric
-import 
org.apache.spark.sql.execution.python.ApplyInPandasWithStatePythonRunner.{InType,
 OutType, OutTypeForState, STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER}
+import 
org.apache.spark.sql.execution.python.ApplyInPandasWithStatePythonRunner.{COUNT_COLUMN_SCHEMA_FROM_PYTHON_WORKER,
 InType, OutType, OutTypeForState, STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER}
 import 
org.apache.spark.sql.execution.python.ApplyInPandasWithStateWriter.STATE_METADATA_SCHEMA
 import org.apache.spark.sql.execution.streaming.GroupStateImpl
 import org.apache.spark.sql.internal.SQLConf
@@ -145,7 +145,24 @@ class ApplyInPandasWithStatePythonRunner(
     // data and state metadata have same number of rows, which is required by 
Arrow record
     // batch.
     assert(batch.numRows() > 0)
-    assert(schema.length == 2)
+    assert(schema.length == 3)
+
+    def getValueFromCountColumn(batch: ColumnarBatch): (Int, Int) = {
+      //  UDF returns a StructType column in ColumnarBatch, select the 
children here
+      val structVector = batch.column(0).asInstanceOf[ArrowColumnVector]
+      val dataType = schema(0).dataType.asInstanceOf[StructType]
+      assert(
+        dataType.sameType(COUNT_COLUMN_SCHEMA_FROM_PYTHON_WORKER),
+        s"Schema equality check failure! type from Arrow: $dataType, " +
+        s"expected type: $COUNT_COLUMN_SCHEMA_FROM_PYTHON_WORKER"
+      )
+
+      // NOTE: See 
ApplyInPandasWithStatePythonRunner.COUNT_COLUMN_SCHEMA_FROM_PYTHON_WORKER
+      // for the schema.
+      val dataCount = structVector.getChild(0).getInt(0)
+      val stateCount = structVector.getChild(1).getInt(0)
+      (dataCount, stateCount)
+    }
 
     def getColumnarBatchForStructTypeColumn(
         batch: ColumnarBatch,
@@ -164,51 +181,43 @@ class ApplyInPandasWithStatePythonRunner(
       flattenedBatch
     }
 
-    def constructIterForData(batch: ColumnarBatch): Iterator[InternalRow] = {
-      val dataBatch = getColumnarBatchForStructTypeColumn(batch, 0, 
outputSchema)
-      dataBatch.rowIterator.asScala.flatMap { row =>
-        if (row.isNullAt(0)) {
-          // The entire row in record batch seems to be for state metadata.
-          None
-        } else {
-          Some(row)
-        }
+
+    def constructIterForData(batch: ColumnarBatch, numRows: Int): 
Iterator[InternalRow] = {
+      val dataBatch = getColumnarBatchForStructTypeColumn(batch, 1, 
outputSchema)
+      dataBatch.rowIterator.asScala.take(numRows).flatMap { row =>
+        Some(row)
       }
     }
 
-    def constructIterForState(batch: ColumnarBatch): Iterator[OutTypeForState] 
= {
-      val stateMetadataBatch = getColumnarBatchForStructTypeColumn(batch, 1,
+    def constructIterForState(batch: ColumnarBatch, numRows: Int): 
Iterator[OutTypeForState] = {
+      val stateMetadataBatch = getColumnarBatchForStructTypeColumn(batch, 2,
         STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER)
 
-      stateMetadataBatch.rowIterator().asScala.flatMap { row =>
+      stateMetadataBatch.rowIterator().asScala.take(numRows).flatMap { row =>
         implicit val formats = org.json4s.DefaultFormats
 
-        if (row.isNullAt(0)) {
-          // The entire row in record batch seems to be for data.
+        // NOTE: See 
ApplyInPandasWithStatePythonRunner.STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER
+        // for the schema.
+        val propertiesAsJson = parse(row.getUTF8String(0).toString)
+        val keyRowAsUnsafeAsBinary = row.getBinary(1)
+        val keyRowAsUnsafe = new UnsafeRow(keySchema.fields.length)
+        keyRowAsUnsafe.pointTo(keyRowAsUnsafeAsBinary, 
keyRowAsUnsafeAsBinary.length)
+        val maybeObjectRow = if (row.isNullAt(2)) {
           None
         } else {
-          // NOTE: See 
ApplyInPandasWithStatePythonRunner.STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER
-          // for the schema.
-          val propertiesAsJson = parse(row.getUTF8String(0).toString)
-          val keyRowAsUnsafeAsBinary = row.getBinary(1)
-          val keyRowAsUnsafe = new UnsafeRow(keySchema.fields.length)
-          keyRowAsUnsafe.pointTo(keyRowAsUnsafeAsBinary, 
keyRowAsUnsafeAsBinary.length)
-          val maybeObjectRow = if (row.isNullAt(2)) {
-            None
-          } else {
-            val pickledStateValue = row.getBinary(2)
-            Some(PythonSQLUtils.toJVMRow(pickledStateValue, stateValueSchema,
-              stateRowDeserializer))
-          }
-          val oldTimeoutTimestamp = row.getLong(3)
-
-          Some((keyRowAsUnsafe, GroupStateImpl.fromJson(maybeObjectRow, 
propertiesAsJson),
-            oldTimeoutTimestamp))
+          val pickledStateValue = row.getBinary(2)
+          Some(PythonSQLUtils.toJVMRow(pickledStateValue, stateValueSchema,
+            stateRowDeserializer))
         }
+        val oldTimeoutTimestamp = row.getLong(3)
+
+        Some((keyRowAsUnsafe, GroupStateImpl.fromJson(maybeObjectRow, 
propertiesAsJson),
+          oldTimeoutTimestamp))
       }
     }
 
-    (constructIterForState(batch), constructIterForData(batch))
+    val (dataCount, stateCount) = getValueFromCountColumn(batch)
+    (constructIterForState(batch, stateCount), constructIterForData(batch, 
dataCount))
   }
 }
 
@@ -225,4 +234,11 @@ object ApplyInPandasWithStatePythonRunner {
       StructField("oldTimeoutTimestamp", LongType)
     )
   )
+  val COUNT_COLUMN_SCHEMA_FROM_PYTHON_WORKER: StructType = StructType(
+    Array(
+      StructField("dataCount", IntegerType),
+      StructField("stateCount", IntegerType)
+    )
+  )
+
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to