bogao007 commented on code in PR #47933:
URL: https://github.com/apache/spark/pull/47933#discussion_r1769253275


##########
python/pyspark/sql/streaming/stateful_processor.py:
##########
@@ -77,6 +78,58 @@ def clear(self) -> None:
         self._value_state_client.clear(self._state_name)
 
 
+class ListState:
+    """
+    Class used for arbitrary stateful operations with transformWithState to capture single value

Review Comment:
   Good catch, updated, thanks!



##########
python/pyspark/sql/streaming/stateful_processor.py:
##########
@@ -77,6 +78,58 @@ def clear(self) -> None:
         self._value_state_client.clear(self._state_name)
 
 
+class ListState:
+    """
+    Class used for arbitrary stateful operations with transformWithState to capture single value
+    state.
+
+    .. versionadded:: 4.0.0
+    """
+
+    def __init__(
+        self, list_state_client: ListStateClient, state_name: str, schema: Union[StructType, str]
+    ) -> None:
+        self._list_state_client = list_state_client
+        self._state_name = state_name
+        self.schema = schema
+
+    def exists(self) -> bool:
+        """
+        Whether list state exists or not.
+        """
+        return self._list_state_client.exists(self._state_name)
+
+    def get(self) -> Iterator[Row]:
+        """
+        Get list state with an iterator.
+        """
+        return ListStateIterator(self._list_state_client, self._state_name)
+
+    def put(self, new_state: List[Any]) -> None:

Review Comment:
   Makes sense. Should we use `Tuple` for all the cases here? I see we are doing so in `GroupState`: https://github.com/apache/spark/blob/f76a9b1135e748649bdb9a2104360f0dc533cc1f/python/pyspark/sql/streaming/state.py#L95
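For reference, here is a minimal sketch of what typing every list element as `Tuple` could look like, following the `GroupState` convention linked above. `InMemoryListStateClient` and `ListStateSketch` are hypothetical names. The in-memory client stands in for `ListStateClient`, which in the PR talks to the JVM state server, and `put` here overwrites the stored list rather than appending to it.

```python
from typing import Dict, Iterator, List, Tuple


class InMemoryListStateClient:
    """Hypothetical stand-in for ListStateClient: keeps each list in a dict."""

    def __init__(self) -> None:
        self._store: Dict[str, List[Tuple]] = {}

    def exists(self, state_name: str) -> bool:
        return state_name in self._store

    def get(self, state_name: str) -> Iterator[Tuple]:
        return iter(self._store.get(state_name, []))

    def put(self, state_name: str, values: List[Tuple]) -> None:
        self._store[state_name] = list(values)


class ListStateSketch:
    """Hypothetical ListState whose elements are all typed as Tuple."""

    def __init__(self, client: InMemoryListStateClient, state_name: str) -> None:
        self._client = client
        self._state_name = state_name

    def exists(self) -> bool:
        return self._client.exists(self._state_name)

    def get(self) -> Iterator[Tuple]:
        return self._client.get(self._state_name)

    def put(self, new_state: List[Tuple]) -> None:
        # Reject non-tuple elements early so the Tuple annotation holds at runtime.
        for value in new_state:
            if not isinstance(value, tuple):
                raise TypeError(f"expected tuple, got {type(value).__name__}")
        # Overwrites any existing list rather than appending to it.
        self._client.put(self._state_name, new_state)


state = ListStateSketch(InMemoryListStateClient(), "events")
state.put([("click", 1), ("view", 2)])
print(state.exists(), list(state.get()))  # True [('click', 1), ('view', 2)]
```

Using `Tuple` everywhere keeps the Python-side contract aligned with how rows are built from the state schema, at the cost of rejecting lists or dicts that a caller might otherwise pass.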



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala:
##########
@@ -201,13 +236,95 @@ class TransformWithStateInPandasStateServer(
         }
       case ValueStateCall.MethodCase.VALUESTATEUPDATE =>
         val byteArray = message.getValueStateUpdate.getValue.toByteArray
-        val valueStateTuple = valueStates(stateName)
         // The value row is serialized as a byte array, we need to convert it back to a Row
-        val valueRow = PythonSQLUtils.toJVMRow(byteArray, valueStateTuple._2, valueStateTuple._3)
-        valueStateTuple._1.update(valueRow)
+        val valueRow = PythonSQLUtils.toJVMRow(byteArray, valueStateInfo.schema,
+          valueStateInfo.deserializer)
+        valueStateInfo.valueState.update(valueRow)
         sendResponse(0)
       case ValueStateCall.MethodCase.CLEAR =>
-        valueStates(stateName)._1.clear()
+        valueStateInfo.valueState.clear()
+        sendResponse(0)
+      case _ =>
+        throw new IllegalArgumentException("Invalid method call")
+    }
+  }
+
+  private[sql] def handleListStateRequest(message: ListStateCall): Unit = {
+    val stateName = message.getStateName
+    if (!listStates.contains(stateName)) {
+      logWarning(log"List state ${MDC(LogKeys.STATE_NAME, stateName)} is not initialized.")
+      sendResponse(1, s"List state $stateName is not initialized.")
+      return
+    }
+    val listStateInfo = listStates(stateName)
+    val deserializer = if (deserializerForTest != null) {
+      deserializerForTest
+    } else {
+      new TransformWithStateInPandasDeserializer(listStateInfo.deserializer)
+    }
+    message.getMethodCase match {
+      case ListStateCall.MethodCase.EXISTS =>
+        if (listStateInfo.listState.exists()) {
+          sendResponse(0)
+        } else {
+          // Send status code 2 to indicate that the list state doesn't have a value yet.
+          sendResponse(2, s"state $stateName doesn't exist")
+        }
+      case ListStateCall.MethodCase.LISTSTATEPUT =>
+        val rows = deserializer.readArrowBatches(inputStream)
+        listStateInfo.listState.put(rows.toArray)
+        sendResponse(0)
+      case ListStateCall.MethodCase.LISTSTATEGET =>
+        val iteratorId = message.getListStateGet.getIteratorId
+        var iteratorOption = listStateIterators.get(iteratorId)
+        if (iteratorOption.isEmpty) {
+          iteratorOption = Some(listStateInfo.listState.get())
+          listStateIterators.put(iteratorId, iteratorOption.get)
+        }
+        if (!iteratorOption.get.hasNext) {
+          sendResponse(2, s"List state $stateName doesn't contain any value.")
+          return
+        } else {
+          sendResponse(0)
+        }
+        outputStream.flush()
+        val arrowStreamWriter = if (arrowStreamWriterForTest != null) {
+          arrowStreamWriterForTest
+        } else {
+          val arrowSchema = ArrowUtils.toArrowSchema(listStateInfo.schema, timeZoneId,
+            errorOnDuplicatedFieldNames, largeVarTypes)
+          val allocator = ArrowUtils.rootAllocator.newChildAllocator(
+          s"stdout writer for transformWithStateInPandas state socket", 0, Long.MaxValue)
+          val root = VectorSchemaRoot.create(arrowSchema, allocator)
+          new BaseStreamingArrowWriter(root, new ArrowStreamWriter(root, null, outputStream),
+            arrowTransformWithStateInPandasMaxRecordsPerBatch)
+        }
+        val listRowSerializer = listStateInfo.serializer
+        // Only write a single batch in each GET request. Stops writing row if rowCount reaches
+        // the arrowTransformWithStateInPandasMaxRecordsPerBatch limit. This is to avoid a case

Review Comment:
   Sorry about the misleading wording; I meant to handle/support such a case. Updated.



##########
python/pyspark/sql/streaming/stateful_processor_api_client.py:
##########
@@ -168,3 +207,43 @@ def _serialize_to_bytes(self, schema: StructType, data: Tuple) -> bytes:
 
     def _deserialize_from_bytes(self, value: bytes) -> Any:
         return self.pickleSer.loads(value)
+
+    def _send_arrow_state(self, schema: StructType, state: List[Tuple]) -> None:

Review Comment:
   We have something similar in `types.py`: `_from_numpy_type` (https://github.com/apache/spark/blob/f76a9b1135e748649bdb9a2104360f0dc533cc1f/python/pyspark/sql/types.py#L2166), which is the opposite of the use case in this PR. Should we move these utility functions there?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala:
##########
@@ -157,7 +185,11 @@ class TransformWithStateInPandasStateServer(
         val ttlDurationMs = if (message.getGetValueState.hasTtl) {
           Some(message.getGetValueState.getTtl.getDurationMs)
         } else None
-        initializeValueState(stateName, schema, ttlDurationMs)
+        initializeStateVariable(stateName, schema, "valueState", ttlDurationMs)
+      case StatefulProcessorCall.MethodCase.GETLISTSTATE =>
+        val stateName = message.getGetListState.getStateName
+        val schema = message.getGetListState.getSchema
+        initializeStateVariable(stateName, schema, "listState", None)

Review Comment:
   Added. By the way, I also have the TTL change locally; would you prefer to combine it with the current PR? I think separate PRs make more sense, but let me know if you have other ideas.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala:
##########
@@ -125,9 +149,13 @@ class TransformWithStateInPandasStateServer(
         // The key row is serialized as a byte array, we need to convert it back to a Row
         val keyRow = PythonSQLUtils.toJVMRow(keyBytes, groupingKeySchema, 
keyRowDeserializer)
         ImplicitGroupingKeyTracker.setImplicitKey(keyRow)
+        // Reset the list state iterators for a new grouping key.
+        listStateIterators = new mutable.HashMap[String, Iterator[Row]]()

Review Comment:
   In unit tests, we only access the iterator map when calling `handleListStateRequest`, which doesn't contain logic to reset the map. I think as long as we don't explicitly combine `ImplicitGroupingKeyRequest` and `ListStateRequest` in a unit test, we should be good. I added a note in the unit test.
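The reset added next to `ImplicitGroupingKeyTracker.setImplicitKey` can be illustrated the same way. In this minimal sketch (`StateServerSketch`, `set_implicit_key`, and `handle_list_state_get` are hypothetical names), iterators are cached per iterator id and the cache is replaced whenever the grouping key changes, so a reused id cannot keep reading the previous key's list. Calling `handle_list_state_get` on its own never clears the cache, which matches the situation described in the unit-test note.

```python
from typing import Callable, Dict, Hashable, Iterator, Optional, Tuple


class StateServerSketch:
    """Hypothetical model of the iterator bookkeeping in the state server."""

    def __init__(self) -> None:
        self.implicit_key: Optional[Hashable] = None
        self.list_state_iterators: Dict[str, Iterator[Tuple]] = {}

    def set_implicit_key(self, key: Hashable) -> None:
        self.implicit_key = key
        # Iterators opened for the previous key are no longer valid; start fresh.
        self.list_state_iterators = {}

    def handle_list_state_get(
        self, iterator_id: str, load_state: Callable[[Hashable], Iterator[Tuple]]
    ) -> Optional[Tuple]:
        it = self.list_state_iterators.get(iterator_id)
        if it is None:
            # Open the list for the current grouping key on first use of this id.
            it = load_state(self.implicit_key)
            self.list_state_iterators[iterator_id] = it
        # Return the next element, or None once this key's list is exhausted.
        return next(it, None)


store = {"k1": [(1,), (2,)], "k2": [(10,)]}
server = StateServerSketch()
server.set_implicit_key("k1")
print(server.handle_list_state_get("it", lambda k: iter(store[k])))  # (1,)
server.set_implicit_key("k2")
print(server.handle_list_state_get("it", lambda k: iter(store[k])))  # (10,)
```

Without the reset in `set_implicit_key`, the second call would return `(2,)` from the stale `k1` iterator.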



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
