HeartSaVioR commented on code in PR #47933:
URL: https://github.com/apache/spark/pull/47933#discussion_r1768129788


##########
python/pyspark/sql/streaming/stateful_processor.py:
##########
@@ -77,6 +78,58 @@ def clear(self) -> None:
         self._value_state_client.clear(self._state_name)
 
 
+class ListState:
+    """
+    Class used for arbitrary stateful operations with transformWithState to capture single value
+    state.
+
+    .. versionadded:: 4.0.0
+    """
+
+    def __init__(
+        self, list_state_client: ListStateClient, state_name: str, schema: Union[StructType, str]
+    ) -> None:
+        self._list_state_client = list_state_client
+        self._state_name = state_name
+        self.schema = schema
+
+    def exists(self) -> bool:
+        """
+        Whether list state exists or not.
+        """
+        return self._list_state_client.exists(self._state_name)
+
+    def get(self) -> Iterator[Row]:
+        """
+        Get list state with an iterator.
+        """
+        return ListStateIterator(self._list_state_client, self._state_name)
+
+    def put(self, new_state: List[Any]) -> None:

Review Comment:
   I understand the type is just a hint, but I wonder why we use `Any` and `Row` here. We even use `Tuple` on the other side. What is the actual type here? Do we allow multiple types?



##########
python/pyspark/sql/streaming/stateful_processor_api_client.py:
##########
@@ -168,3 +207,43 @@ def _serialize_to_bytes(self, schema: StructType, data: Tuple) -> bytes:
 
     def _deserialize_from_bytes(self, value: bytes) -> Any:
         return self.pickleSer.loads(value)
+
+    def _send_arrow_state(self, schema: StructType, state: List[Tuple]) -> None:

Review Comment:
   Is there any utility class/function in PySpark we can reuse? This feels like functionality we should already have built to support the Pandas API on Spark; I'm thinking of `_send_arrow_state`, `_convert`, and `_to_numpy_type`.
   
   cc. @HyukjinKwon 



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala:
##########
@@ -125,9 +149,13 @@ class TransformWithStateInPandasStateServer(
         // The key row is serialized as a byte array, we need to convert it back to a Row
         val keyRow = PythonSQLUtils.toJVMRow(keyBytes, groupingKeySchema, keyRowDeserializer)
         ImplicitGroupingKeyTracker.setImplicitKey(keyRow)
+        // Reset the list state iterators for a new grouping key.
+        listStateIterators = new mutable.HashMap[String, Iterator[Row]]()

Review Comment:
   How do we handle the case where the param `listStateIteratorMapForTest` is given? Are we assuming a test only deals with a single grouping key? If so, let's mention it explicitly somewhere, e.g. in the class doc.
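   
   Something like this in the class doc would work (a sketch only; the wording is a placeholder):
   
   ```scala
   /**
    * ...
    * @param listStateIteratorMapForTest Test-only. Assumes the test exercises a single
    *        grouping key, since the iterator map is reset whenever the grouping key changes.
    */
   ```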



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala:
##########
@@ -127,24 +151,79 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo
   }
 
   test("value state clear") {
-    val message = ValueStateCall.newBuilder().setStateName(valueStateName)
+    val message = ValueStateCall.newBuilder().setStateName(stateName)
       .setClear(Clear.newBuilder().build()).build()
     stateServer.handleValueStateRequest(message)
     verify(valueState).clear()
     verify(outputStream).writeInt(0)
   }
 
   test("value state update") {
-    // Below byte array is a serialized row with a single integer value 1.
-    val byteArray: Array[Byte] = Array(0x80.toByte, 0x05.toByte, 0x95.toByte, 0x05.toByte,
-      0x00.toByte, 0x00.toByte, 0x00.toByte, 0x00.toByte, 0x00.toByte, 0x00.toByte, 0x00.toByte,
-      'K'.toByte, 0x01.toByte, 0x85.toByte, 0x94.toByte, '.'.toByte
-    )
     val byteString: ByteString = ByteString.copyFrom(byteArray)
-    val message = ValueStateCall.newBuilder().setStateName(valueStateName)
+    val message = ValueStateCall.newBuilder().setStateName(stateName)
       .setValueStateUpdate(ValueStateUpdate.newBuilder().setValue(byteString).build()).build()
     stateServer.handleValueStateRequest(message)
     verify(valueState).update(any[Row])
     verify(outputStream).writeInt(0)
   }
+
+  test("list state exists") {
+    val message = ListStateCall.newBuilder().setStateName(stateName)
+      .setExists(Exists.newBuilder().build()).build()
+    stateServer.handleListStateRequest(message)
+    verify(listState).exists()
+  }
+
+  test("list state get - iterator in map") {
+    val message = ListStateCall.newBuilder().setStateName(stateName)
+      .setListStateGet(ListStateGet.newBuilder().setIteratorId(iteratorId).build()).build()
+    stateServer.handleListStateRequest(message)
+    verify(listState, times(0)).get()
+    verify(arrowStreamWriter).writeRow(any)
+    verify(arrowStreamWriter).finalizeCurrentArrowBatch()
+  }
+
+  test("list state get - iterator not in map") {
+    val maxRecordsPerBatch = 2
+    val message = ListStateCall.newBuilder().setStateName(stateName)
+      .setListStateGet(ListStateGet.newBuilder().setIteratorId(iteratorId).build()).build()
+    val iteratorMap: mutable.HashMap[String, Iterator[Row]] = mutable.HashMap()
+    stateServer = new TransformWithStateInPandasStateServer(serverSocket,
+      statefulProcessorHandle, groupingKeySchema, "", false, false,
+      maxRecordsPerBatch, outputStream, valueStateMap,
+      transformWithStateInPandasDeserializer, arrowStreamWriter, listStateMap, iteratorMap)
+    when(listState.get()).thenReturn(Iterator(new GenericRowWithSchema(Array(1), stateSchema),
+      new GenericRowWithSchema(Array(2), stateSchema),
+      new GenericRowWithSchema(Array(3), stateSchema)))
+    stateServer.handleListStateRequest(message)
+    verify(listState).get()
+    // Verify that only maxRecordsPerBatch (2) rows are written to the output stream while still
+    // having 1 row left in the iterator.
+    verify(arrowStreamWriter, times(maxRecordsPerBatch)).writeRow(any)
+    verify(arrowStreamWriter).finalizeCurrentArrowBatch()

Review Comment:
   Would we also like to test that the remaining rows are returned when we request the same iterator ID again? It would probably also be good to have a separate test that fully consumes an iterator whose length is a multiple of maxRecordsPerBatch.
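   
   Roughly something like this, reusing the fixtures from the test above (a sketch only; I haven't run it):
   
   ```scala
   test("list state get - resume iterator with the same iterator id") {
     val maxRecordsPerBatch = 2
     val iteratorMap: mutable.HashMap[String, Iterator[Row]] = mutable.HashMap()
     stateServer = new TransformWithStateInPandasStateServer(serverSocket,
       statefulProcessorHandle, groupingKeySchema, "", false, false,
       maxRecordsPerBatch, outputStream, valueStateMap,
       transformWithStateInPandasDeserializer, arrowStreamWriter, listStateMap, iteratorMap)
     // 4 rows is a multiple of maxRecordsPerBatch (2), so two GET requests consume it fully.
     when(listState.get()).thenReturn(Iterator(
       new GenericRowWithSchema(Array(1), stateSchema),
       new GenericRowWithSchema(Array(2), stateSchema),
       new GenericRowWithSchema(Array(3), stateSchema),
       new GenericRowWithSchema(Array(4), stateSchema)))
     val message = ListStateCall.newBuilder().setStateName(stateName)
       .setListStateGet(ListStateGet.newBuilder().setIteratorId(iteratorId).build()).build()
     stateServer.handleListStateRequest(message)
     // The second request with the same iterator id should resume where the first stopped.
     stateServer.handleListStateRequest(message)
     verify(listState).get()
     verify(arrowStreamWriter, times(2 * maxRecordsPerBatch)).writeRow(any)
     verify(arrowStreamWriter, times(2)).finalizeCurrentArrowBatch()
   }
   ```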



##########
python/pyspark/sql/streaming/stateful_processor.py:
##########
@@ -77,6 +78,58 @@ def clear(self) -> None:
         self._value_state_client.clear(self._state_name)
 
 
+class ListState:
+    """
+    Class used for arbitrary stateful operations with transformWithState to capture single value

Review Comment:
   nit: This is not for single-value state. Probably a copy-and-paste issue that wasn't fixed.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala:
##########
@@ -157,7 +185,11 @@ class TransformWithStateInPandasStateServer(
         val ttlDurationMs = if (message.getGetValueState.hasTtl) {
           Some(message.getGetValueState.getTtl.getDurationMs)
         } else None
-        initializeValueState(stateName, schema, ttlDurationMs)
+        initializeStateVariable(stateName, schema, "valueState", ttlDurationMs)
+      case StatefulProcessorCall.MethodCase.GETLISTSTATE =>
+        val stateName = message.getGetListState.getStateName
+        val schema = message.getGetListState.getSchema
+        initializeStateVariable(stateName, schema, "listState", None)

Review Comment:
   Is supporting TTL for ListState a TODO? It would probably be better to have a JIRA ticket for it.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala:
##########
@@ -232,23 +349,53 @@ class TransformWithStateInPandasStateServer(
     outputStream.write(responseMessageBytes)
   }
 
-  private def initializeValueState(
-      stateName: String,
-      schemaString: String,
-      ttlDurationMs: Option[Int]): Unit = {
-    if (!valueStates.contains(stateName)) {
-      val schema = StructType.fromString(schemaString)
-      val state = if (ttlDurationMs.isEmpty) {
-        statefulProcessorHandle.getValueState[Row](stateName, Encoders.row(schema))
-      } else {
-        statefulProcessorHandle.getValueState(
-          stateName, Encoders.row(schema), TTLConfig(Duration.ofMillis(ttlDurationMs.get)))
-      }
-      val valueRowDeserializer = ExpressionEncoder(schema).resolveAndBind().createDeserializer()
-      valueStates.put(stateName, (state, schema, valueRowDeserializer))
-      sendResponse(0)
-    } else {
-      sendResponse(1, s"state $stateName already exists")
+  private def initializeStateVariable(
+    stateName: String,
+    schemaString: String,
+    stateVariable: String,

Review Comment:
   nit: state`Type`? And would you mind using an enum for the type?
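   
   Something along these lines (a rough sketch; names are placeholders):
   
   ```scala
   // A sealed trait gives exhaustive pattern matching, unlike raw strings.
   sealed trait StateVariableType
   object StateVariableType {
     case object ValueState extends StateVariableType
     case object ListState extends StateVariableType
   }

   private def initializeStateVariable(
       stateName: String,
       schemaString: String,
       stateType: StateVariableType,
       ttlDurationMs: Option[Int]): Unit = stateType match {
     case StateVariableType.ValueState => // existing value state branch
     case StateVariableType.ListState => // existing list state branch
   }
   ```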



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala:
##########
@@ -201,13 +236,95 @@ class TransformWithStateInPandasStateServer(
         }
       case ValueStateCall.MethodCase.VALUESTATEUPDATE =>
         val byteArray = message.getValueStateUpdate.getValue.toByteArray
-        val valueStateTuple = valueStates(stateName)
         // The value row is serialized as a byte array, we need to convert it back to a Row
-        val valueRow = PythonSQLUtils.toJVMRow(byteArray, valueStateTuple._2, valueStateTuple._3)
-        valueStateTuple._1.update(valueRow)
+        val valueRow = PythonSQLUtils.toJVMRow(byteArray, valueStateInfo.schema,
+          valueStateInfo.deserializer)
+        valueStateInfo.valueState.update(valueRow)
         sendResponse(0)
       case ValueStateCall.MethodCase.CLEAR =>
-        valueStates(stateName)._1.clear()
+        valueStateInfo.valueState.clear()
+        sendResponse(0)
+      case _ =>
+        throw new IllegalArgumentException("Invalid method call")
+    }
+  }
+
+  private[sql] def handleListStateRequest(message: ListStateCall): Unit = {
+    val stateName = message.getStateName
+    if (!listStates.contains(stateName)) {
+      logWarning(log"List state ${MDC(LogKeys.STATE_NAME, stateName)} is not initialized.")
+      sendResponse(1, s"List state $stateName is not initialized.")
+      return
+    }
+    val listStateInfo = listStates(stateName)
+    val deserializer = if (deserializerForTest != null) {
+      deserializerForTest
+    } else {
+      new TransformWithStateInPandasDeserializer(listStateInfo.deserializer)
+    }
+    message.getMethodCase match {
+      case ListStateCall.MethodCase.EXISTS =>
+        if (listStateInfo.listState.exists()) {
+          sendResponse(0)
+        } else {
+          // Send status code 2 to indicate that the list state doesn't have a value yet.
+          sendResponse(2, s"state $stateName doesn't exist")
+        }
+      case ListStateCall.MethodCase.LISTSTATEPUT =>
+        val rows = deserializer.readArrowBatches(inputStream)
+        listStateInfo.listState.put(rows.toArray)
+        sendResponse(0)
+      case ListStateCall.MethodCase.LISTSTATEGET =>
+        val iteratorId = message.getListStateGet.getIteratorId
+        var iteratorOption = listStateIterators.get(iteratorId)
+        if (iteratorOption.isEmpty) {
+          iteratorOption = Some(listStateInfo.listState.get())
+          listStateIterators.put(iteratorId, iteratorOption.get)
+        }
+        if (!iteratorOption.get.hasNext) {
+          sendResponse(2, s"List state $stateName doesn't contain any value.")
+          return
+        } else {
+          sendResponse(0)
+        }
+        outputStream.flush()
+        val arrowStreamWriter = if (arrowStreamWriterForTest != null) {
+          arrowStreamWriterForTest
+        } else {
+          val arrowSchema = ArrowUtils.toArrowSchema(listStateInfo.schema, timeZoneId,
+            errorOnDuplicatedFieldNames, largeVarTypes)
+          val allocator = ArrowUtils.rootAllocator.newChildAllocator(
+          s"stdout writer for transformWithStateInPandas state socket", 0, Long.MaxValue)
+          val root = VectorSchemaRoot.create(arrowSchema, allocator)
+          new BaseStreamingArrowWriter(root, new ArrowStreamWriter(root, null, outputStream),
+            arrowTransformWithStateInPandasMaxRecordsPerBatch)
+        }
+        val listRowSerializer = listStateInfo.serializer
+        // Only write a single batch in each GET request. Stops writing row if rowCount reaches
+        // the arrowTransformWithStateInPandasMaxRecordsPerBatch limit. This is to avoid a case

Review Comment:
   Is this to avoid the case, or to support it? From the implementation it looks like it's to support the case. Otherwise, do we have logic to block such a request?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala:
##########
@@ -232,23 +349,53 @@ class TransformWithStateInPandasStateServer(
     outputStream.write(responseMessageBytes)
   }
 
-  private def initializeValueState(
-      stateName: String,
-      schemaString: String,
-      ttlDurationMs: Option[Int]): Unit = {
-    if (!valueStates.contains(stateName)) {
-      val schema = StructType.fromString(schemaString)
-      val state = if (ttlDurationMs.isEmpty) {
-        statefulProcessorHandle.getValueState[Row](stateName, Encoders.row(schema))
-      } else {
-        statefulProcessorHandle.getValueState(
-          stateName, Encoders.row(schema), TTLConfig(Duration.ofMillis(ttlDurationMs.get)))
-      }
-      val valueRowDeserializer = ExpressionEncoder(schema).resolveAndBind().createDeserializer()
-      valueStates.put(stateName, (state, schema, valueRowDeserializer))
-      sendResponse(0)
-    } else {
-      sendResponse(1, s"state $stateName already exists")
+  private def initializeStateVariable(
+    stateName: String,

Review Comment:
   nit: indent the params with two more spaces



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala:
##########
@@ -232,23 +349,53 @@ class TransformWithStateInPandasStateServer(
     outputStream.write(responseMessageBytes)
   }
 
-  private def initializeValueState(
-      stateName: String,
-      schemaString: String,
-      ttlDurationMs: Option[Int]): Unit = {
-    if (!valueStates.contains(stateName)) {
-      val schema = StructType.fromString(schemaString)
-      val state = if (ttlDurationMs.isEmpty) {
-        statefulProcessorHandle.getValueState[Row](stateName, Encoders.row(schema))
-      } else {
-        statefulProcessorHandle.getValueState(
-          stateName, Encoders.row(schema), TTLConfig(Duration.ofMillis(ttlDurationMs.get)))
-      }
-      val valueRowDeserializer = ExpressionEncoder(schema).resolveAndBind().createDeserializer()
-      valueStates.put(stateName, (state, schema, valueRowDeserializer))
-      sendResponse(0)
-    } else {
-      sendResponse(1, s"state $stateName already exists")
+  private def initializeStateVariable(
+    stateName: String,
+    schemaString: String,
+    stateVariable: String,
+    ttlDurationMs: Option[Int]): Unit = {
+    val schema = StructType.fromString(schemaString)
+    val expressionEncoder = ExpressionEncoder(schema).resolveAndBind()
+    stateVariable match {
+        case "valueState" => if (!valueStates.contains(stateName)) {
+          val state = if (ttlDurationMs.isEmpty) {
+            statefulProcessorHandle.getValueState[Row](stateName, Encoders.row(schema))
+          } else {
+            statefulProcessorHandle.getValueState(
+              stateName, Encoders.row(schema), TTLConfig(Duration.ofMillis(ttlDurationMs.get)))
+          }
+          valueStates.put(stateName,
+            ValueStateInfo(state, schema, expressionEncoder.createDeserializer()))
+          sendResponse(0)
+        } else {
+          sendResponse(1, s"Value state $stateName already exists")
+        }
+        case "listState" => if (!listStates.contains(stateName)) {

Review Comment:
   nit: let's either assert that ttlDurationMs is empty or leave a TODO comment in the code.
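   
   e.g. (a sketch; the message wording is a placeholder):
   
   ```scala
   case "listState" => if (!listStates.contains(stateName)) {
     // List state doesn't support TTL yet; fail fast rather than silently dropping it.
     assert(ttlDurationMs.isEmpty, s"TTL is not supported for list state $stateName yet")
     // ... existing list state initialization ...
   }
   ```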



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

