This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new e718d0d4dc5 [SPARK-44464][SS] Implement applyInPandasWithState in PySpark
e718d0d4dc5 is described below

commit e718d0d4dc57f9e0ecdb7067ee9778250200fe83
Author: Siying Dong <siying.d...@databricks.com>
AuthorDate: Wed Jul 19 07:39:11 2023 +0900

    [SPARK-44464][SS] Implement applyInPandasWithState in PySpark
    
    ### What changes were proposed in this pull request?
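    Change the serialization format of applyInPandasWithState outputs sent from the
    Python worker to the JVM: each Arrow record batch now carries an extra hidden
    leading column that records how many data records and how many state records
    the batch contains.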
    
    ### Why are the changes needed?
    The current implementation of ApplyInPandasWithStatePythonRunner cannot correctly
    handle outputs whose first column is null. It cannot distinguish between two cases:
    a column that is genuinely null, and a field that was padded with null because the
    batch contains fewer data records than state records. As a result, it produces
    incorrect results in the former case.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
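    Added unit tests covering null values and several other scenarios: no state, no
    state and no data, more data rows than state rows, and fewer data rows than state rows.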
    
    Closes #42046 from siying/pypanda.
    
    Authored-by: Siying Dong <siying.d...@databricks.com>
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
    (cherry picked from commit 6cebd97f4670e1e998b0064ddf4db11050fe52dd)
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 python/pyspark/sql/pandas/serializers.py           |  43 ++++++--
 .../test_parity_pandas_grouped_map_with_state.py   |  20 ++++
 .../pandas/test_pandas_grouped_map_with_state.py   | 114 ++++++++++++++++++---
 .../ApplyInPandasWithStatePythonRunner.scala       |  84 +++++++++------
 4 files changed, 208 insertions(+), 53 deletions(-)

diff --git a/python/pyspark/sql/pandas/serializers.py 
b/python/pyspark/sql/pandas/serializers.py
index f835ea57b77..f22a73cbbef 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -27,7 +27,15 @@ from pyspark.sql.pandas.types import (
     _create_converter_from_pandas,
     _create_converter_to_pandas,
 )
-from pyspark.sql.types import DataType, StringType, StructType, BinaryType, 
StructField, LongType
+from pyspark.sql.types import (
+    DataType,
+    StringType,
+    StructType,
+    BinaryType,
+    StructField,
+    LongType,
+    IntegerType,
+)
 
 
 class SpecialLengths:
@@ -603,6 +611,15 @@ class 
ApplyInPandasWithStateSerializer(ArrowStreamPandasUDFSerializer):
         self.utf8_deserializer = UTF8Deserializer()
         self.state_object_schema = state_object_schema
 
+        self.result_count_df_type = StructType(
+            [
+                StructField("dataCount", IntegerType()),
+                StructField("stateCount", IntegerType()),
+            ]
+        )
+
+        self.result_count_pdf_arrow_type = 
to_arrow_type(self.result_count_df_type)
+
         self.result_state_df_type = StructType(
             [
                 StructField("properties", StringType()),
@@ -799,16 +816,26 @@ class 
ApplyInPandasWithStateSerializer(ArrowStreamPandasUDFSerializer):
         def construct_record_batch(pdfs, pdf_data_cnt, pdf_schema, state_pdfs, 
state_data_cnt):
             """
             Construct a new Arrow RecordBatch based on output pandas 
DataFrames and states. Each
-            one matches to the single struct field for Arrow schema, hence the 
return value of
-            Arrow RecordBatch will have schema with two fields, in `data`, 
`state` order.
+            one matches to the single struct field for Arrow schema. We also 
need an extra one to
+            indicate array length for data and state, so the return value of 
Arrow RecordBatch will
+            have schema with three fields, in `count`, `data`, `state` order.
             (Readers are expected to access the field via position rather than 
the name. We do
             not guarantee the name of the field.)
 
             Note that Arrow RecordBatch requires all columns to have all same 
number of rows,
-            hence this function inserts empty data for state/data with less 
elements to compensate.
+            hence this function inserts empty data for count/state/data with 
less elements to
+            compensate.
             """
 
-            max_data_cnt = max(pdf_data_cnt, state_data_cnt)
+            max_data_cnt = max(1, max(pdf_data_cnt, state_data_cnt))
+
+            # We only use the first row in the count column, and fill other 
rows to be the same
+            # value, hoping it is more friendly for compression, in case it is 
needed.
+            count_dict = {
+                "dataCount": [pdf_data_cnt] * max_data_cnt,
+                "stateCount": [state_data_cnt] * max_data_cnt,
+            }
+            count_pdf = pd.DataFrame.from_dict(count_dict)
 
             empty_row_cnt_in_data = max_data_cnt - pdf_data_cnt
             empty_row_cnt_in_state = max_data_cnt - state_data_cnt
@@ -829,7 +856,11 @@ class 
ApplyInPandasWithStateSerializer(ArrowStreamPandasUDFSerializer):
             merged_state_pdf = pd.concat(state_pdfs, ignore_index=True)
 
             return self._create_batch(
-                [(merged_pdf, pdf_schema), (merged_state_pdf, 
self.result_state_pdf_arrow_type)]
+                [
+                    (count_pdf, self.result_count_pdf_arrow_type),
+                    (merged_pdf, pdf_schema),
+                    (merged_state_pdf, self.result_state_pdf_arrow_type),
+                ]
             )
 
         def serialize_batches():
diff --git 
a/python/pyspark/sql/tests/connect/test_parity_pandas_grouped_map_with_state.py 
b/python/pyspark/sql/tests/connect/test_parity_pandas_grouped_map_with_state.py
index 3a38cd17406..dc3bdf28f81 100644
--- 
a/python/pyspark/sql/tests/connect/test_parity_pandas_grouped_map_with_state.py
+++ 
b/python/pyspark/sql/tests/connect/test_parity_pandas_grouped_map_with_state.py
@@ -29,6 +29,26 @@ class GroupedApplyInPandasWithStateTests(
     def test_apply_in_pandas_with_state_basic(self):
         super().test_apply_in_pandas_with_state_basic()
 
+    @unittest.skip("foreachBatch will be supported in SPARK-42944.")
+    def test_apply_in_pandas_with_state_basic_no_state(self):
+        super().test_apply_in_pandas_with_state_basic()
+
+    @unittest.skip("foreachBatch will be supported in SPARK-42944.")
+    def test_apply_in_pandas_with_state_basic_no_state_no_data(self):
+        super().test_apply_in_pandas_with_state_basic()
+
+    @unittest.skip("foreachBatch will be supported in SPARK-42944.")
+    def test_apply_in_pandas_with_state_basic_more_data(self):
+        super().test_apply_in_pandas_with_state_basic()
+
+    @unittest.skip("foreachBatch will be supported in SPARK-42944.")
+    def test_apply_in_pandas_with_state_basic_fewer_data(self):
+        super().test_apply_in_pandas_with_state_basic()
+
+    @unittest.skip("foreachBatch will be supported in SPARK-42944.")
+    def test_apply_in_pandas_with_state_basic_with_null(self):
+        super().test_apply_in_pandas_with_state_basic()
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests.connect.test_parity_pandas_grouped_map_with_state 
import *  # noqa: F401
diff --git 
a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py 
b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py
index a2a6544faa0..e1ec97928f7 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py
@@ -60,7 +60,7 @@ class GroupedApplyInPandasWithStateTestsMixin:
         cfg.set("spark.sql.shuffle.partitions", "5")
         return cfg
 
-    def test_apply_in_pandas_with_state_basic(self):
+    def _test_apply_in_pandas_with_state_basic(self, func, check_results):
         input_path = tempfile.mkdtemp()
 
         def prepare_test_resource():
@@ -81,6 +81,22 @@ class GroupedApplyInPandasWithStateTestsMixin:
         )
         state_type = StructType([StructField("c", LongType())])
 
+        q = (
+            df.groupBy(df["value"])
+            .applyInPandasWithState(
+                func, output_type, state_type, "Update", 
GroupStateTimeout.NoTimeout
+            )
+            .writeStream.queryName("this_query")
+            .foreachBatch(check_results)
+            .outputMode("update")
+            .start()
+        )
+
+        self.assertEqual(q.name, "this_query")
+        self.assertTrue(q.isActive)
+        q.processAllAvailable()
+
+    def test_apply_in_pandas_with_state_basic(self):
         def func(key, pdf_iter, state):
             assert isinstance(state, GroupState)
 
@@ -98,20 +114,92 @@ class GroupedApplyInPandasWithStateTestsMixin:
                 {Row(key="hello", countAsString="1"), Row(key="this", 
countAsString="1")},
             )
 
-        q = (
-            df.groupBy(df["value"])
-            .applyInPandasWithState(
-                func, output_type, state_type, "Update", 
GroupStateTimeout.NoTimeout
+        self._test_apply_in_pandas_with_state_basic(func, check_results)
+
+    def test_apply_in_pandas_with_state_basic_no_state(self):
+        def func(key, pdf_iter, state):
+            assert isinstance(state, GroupState)
+            # 2 data rows
+            yield pd.DataFrame({"key": [key[0], "foo"], "countAsString": 
["100", "222"]})
+
+        def check_results(batch_df, _):
+            self.assertEqual(
+                set(batch_df.sort("key").collect()),
+                {
+                    Row(key="hello", countAsString="100"),
+                    Row(key="this", countAsString="100"),
+                    Row(key="foo", countAsString="222"),
+                },
             )
-            .writeStream.queryName("this_query")
-            .foreachBatch(check_results)
-            .outputMode("update")
-            .start()
-        )
 
-        self.assertEqual(q.name, "this_query")
-        self.assertTrue(q.isActive)
-        q.processAllAvailable()
+        self._test_apply_in_pandas_with_state_basic(func, check_results)
+
+    def test_apply_in_pandas_with_state_basic_no_state_no_data(self):
+        def func(key, pdf_iter, state):
+            assert isinstance(state, GroupState)
+            # 2 data rows
+            yield pd.DataFrame({"key": [], "countAsString": []})
+
+        def check_results(batch_df, _):
+            self.assertTrue(len(set(batch_df.sort("key").collect())) == 0)
+
+        self._test_apply_in_pandas_with_state_basic(func, check_results)
+
+    def test_apply_in_pandas_with_state_basic_more_data(self):
+        # Test data rows returned are more or fewer than state.
+        def func(key, pdf_iter, state):
+            state.update((1,))
+            assert isinstance(state, GroupState)
+            # 3 rows
+            yield pd.DataFrame(
+                {"key": [key[0], "foo", key[0] + "_2"], "countAsString": ["1", 
"666", "2"]}
+            )
+
+        def check_results(batch_df, _):
+            self.assertEqual(
+                set(batch_df.sort("key").collect()),
+                {
+                    Row(key="hello", countAsString="1"),
+                    Row(key="foo", countAsString="666"),
+                    Row(key="hello_2", countAsString="2"),
+                    Row(key="this", countAsString="1"),
+                    Row(key="this_2", countAsString="2"),
+                },
+            )
+
+        self._test_apply_in_pandas_with_state_basic(func, check_results)
+
+    def test_apply_in_pandas_with_state_basic_fewer_data(self):
+        # Test data rows returned are more or fewer than state.
+        def func(key, pdf_iter, state):
+            state.update((1,))
+            assert isinstance(state, GroupState)
+            yield pd.DataFrame({"key": [], "countAsString": []})
+
+        def check_results(batch_df, _):
+            self.assertTrue(len(set(batch_df.sort("key").collect())) == 0)
+
+        self._test_apply_in_pandas_with_state_basic(func, check_results)
+
+    def test_apply_in_pandas_with_state_basic_with_null(self):
+        def func(key, pdf_iter, state):
+            assert isinstance(state, GroupState)
+
+            total_len = 0
+            for pdf in pdf_iter:
+                total_len += len(pdf)
+
+            state.update((total_len,))
+            assert state.get[0] == 1
+            yield pd.DataFrame({"key": [None], "countAsString": 
[str(total_len)]})
+
+        def check_results(batch_df, _):
+            self.assertEqual(
+                set(batch_df.sort("key").collect()),
+                {Row(key=None, countAsString="1")},
+            )
+
+        self._test_apply_in_pandas_with_state_basic(func, check_results)
 
     def test_apply_in_pandas_with_state_python_worker_random_failure(self):
         input_path = tempfile.mkdtemp()
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
index 9fc6ae04e94..d4c535fe76a 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
@@ -34,7 +34,7 @@ import 
org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.execution.metric.SQLMetric
-import 
org.apache.spark.sql.execution.python.ApplyInPandasWithStatePythonRunner.{InType,
 OutType, OutTypeForState, STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER}
+import 
org.apache.spark.sql.execution.python.ApplyInPandasWithStatePythonRunner.{COUNT_COLUMN_SCHEMA_FROM_PYTHON_WORKER,
 InType, OutType, OutTypeForState, STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER}
 import 
org.apache.spark.sql.execution.python.ApplyInPandasWithStateWriter.STATE_METADATA_SCHEMA
 import org.apache.spark.sql.execution.streaming.GroupStateImpl
 import org.apache.spark.sql.internal.SQLConf
@@ -155,7 +155,24 @@ class ApplyInPandasWithStatePythonRunner(
     // data and state metadata have same number of rows, which is required by 
Arrow record
     // batch.
     assert(batch.numRows() > 0)
-    assert(schema.length == 2)
+    assert(schema.length == 3)
+
+    def getValueFromCountColumn(batch: ColumnarBatch): (Int, Int) = {
+      //  UDF returns a StructType column in ColumnarBatch, select the 
children here
+      val structVector = batch.column(0).asInstanceOf[ArrowColumnVector]
+      val dataType = schema(0).dataType.asInstanceOf[StructType]
+      assert(
+        DataTypeUtils.sameType(dataType, 
COUNT_COLUMN_SCHEMA_FROM_PYTHON_WORKER),
+        s"Schema equality check failure! type from Arrow: $dataType, " +
+        s"expected type: $COUNT_COLUMN_SCHEMA_FROM_PYTHON_WORKER"
+      )
+
+      // NOTE: See 
ApplyInPandasWithStatePythonRunner.COUNT_COLUMN_SCHEMA_FROM_PYTHON_WORKER
+      // for the schema.
+      val dataCount = structVector.getChild(0).getInt(0)
+      val stateCount = structVector.getChild(1).getInt(0)
+      (dataCount, stateCount)
+    }
 
     def getColumnarBatchForStructTypeColumn(
         batch: ColumnarBatch,
@@ -174,51 +191,43 @@ class ApplyInPandasWithStatePythonRunner(
       flattenedBatch
     }
 
-    def constructIterForData(batch: ColumnarBatch): Iterator[InternalRow] = {
-      val dataBatch = getColumnarBatchForStructTypeColumn(batch, 0, 
outputSchema)
-      dataBatch.rowIterator.asScala.flatMap { row =>
-        if (row.isNullAt(0)) {
-          // The entire row in record batch seems to be for state metadata.
-          None
-        } else {
-          Some(row)
-        }
+
+    def constructIterForData(batch: ColumnarBatch, numRows: Int): 
Iterator[InternalRow] = {
+      val dataBatch = getColumnarBatchForStructTypeColumn(batch, 1, 
outputSchema)
+      dataBatch.rowIterator.asScala.take(numRows).flatMap { row =>
+        Some(row)
       }
     }
 
-    def constructIterForState(batch: ColumnarBatch): Iterator[OutTypeForState] 
= {
-      val stateMetadataBatch = getColumnarBatchForStructTypeColumn(batch, 1,
+    def constructIterForState(batch: ColumnarBatch, numRows: Int): 
Iterator[OutTypeForState] = {
+      val stateMetadataBatch = getColumnarBatchForStructTypeColumn(batch, 2,
         STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER)
 
-      stateMetadataBatch.rowIterator().asScala.flatMap { row =>
+      stateMetadataBatch.rowIterator().asScala.take(numRows).flatMap { row =>
         implicit val formats = org.json4s.DefaultFormats
 
-        if (row.isNullAt(0)) {
-          // The entire row in record batch seems to be for data.
+        // NOTE: See 
ApplyInPandasWithStatePythonRunner.STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER
+        // for the schema.
+        val propertiesAsJson = parse(row.getUTF8String(0).toString)
+        val keyRowAsUnsafeAsBinary = row.getBinary(1)
+        val keyRowAsUnsafe = new UnsafeRow(keySchema.fields.length)
+        keyRowAsUnsafe.pointTo(keyRowAsUnsafeAsBinary, 
keyRowAsUnsafeAsBinary.length)
+        val maybeObjectRow = if (row.isNullAt(2)) {
           None
         } else {
-          // NOTE: See 
ApplyInPandasWithStatePythonRunner.STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER
-          // for the schema.
-          val propertiesAsJson = parse(row.getUTF8String(0).toString)
-          val keyRowAsUnsafeAsBinary = row.getBinary(1)
-          val keyRowAsUnsafe = new UnsafeRow(keySchema.fields.length)
-          keyRowAsUnsafe.pointTo(keyRowAsUnsafeAsBinary, 
keyRowAsUnsafeAsBinary.length)
-          val maybeObjectRow = if (row.isNullAt(2)) {
-            None
-          } else {
-            val pickledStateValue = row.getBinary(2)
-            Some(PythonSQLUtils.toJVMRow(pickledStateValue, stateValueSchema,
-              stateRowDeserializer))
-          }
-          val oldTimeoutTimestamp = row.getLong(3)
-
-          Some((keyRowAsUnsafe, GroupStateImpl.fromJson(maybeObjectRow, 
propertiesAsJson),
-            oldTimeoutTimestamp))
+          val pickledStateValue = row.getBinary(2)
+          Some(PythonSQLUtils.toJVMRow(pickledStateValue, stateValueSchema,
+            stateRowDeserializer))
         }
+        val oldTimeoutTimestamp = row.getLong(3)
+
+        Some((keyRowAsUnsafe, GroupStateImpl.fromJson(maybeObjectRow, 
propertiesAsJson),
+          oldTimeoutTimestamp))
       }
     }
 
-    (constructIterForState(batch), constructIterForData(batch))
+    val (dataCount, stateCount) = getValueFromCountColumn(batch)
+    (constructIterForState(batch, stateCount), constructIterForData(batch, 
dataCount))
   }
 }
 
@@ -235,4 +244,11 @@ object ApplyInPandasWithStatePythonRunner {
       StructField("oldTimeoutTimestamp", LongType)
     )
   )
+  val COUNT_COLUMN_SCHEMA_FROM_PYTHON_WORKER: StructType = StructType(
+    Array(
+      StructField("dataCount", IntegerType),
+      StructField("stateCount", IntegerType)
+    )
+  )
+
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to