HeartSaVioR commented on code in PR #47933:
URL: https://github.com/apache/spark/pull/47933#discussion_r1768129788
##########
python/pyspark/sql/streaming/stateful_processor.py:
##########
@@ -77,6 +78,58 @@ def clear(self) -> None:
         self._value_state_client.clear(self._state_name)
+class ListState:
+    """
+    Class used for arbitrary stateful operations with transformWithState to capture single value
+    state.
+
+    .. versionadded:: 4.0.0
+    """
+
+    def __init__(
+        self, list_state_client: ListStateClient, state_name: str, schema: Union[StructType, str]
+    ) -> None:
+        self._list_state_client = list_state_client
+        self._state_name = state_name
+        self.schema = schema
+
+    def exists(self) -> bool:
+        """
+        Whether list state exists or not.
+        """
+        return self._list_state_client.exists(self._state_name)
+
+    def get(self) -> Iterator[Row]:
+        """
+        Get list state with an iterator.
+        """
+        return ListStateIterator(self._list_state_client, self._state_name)
+
+    def put(self, new_state: List[Any]) -> None:
Review Comment:
I understand the type is just a hint, but I wonder why we use `Any` and `Row` here. Going
further, we use `Tuple` on the other side. What is the actual type here? Do we allow multiple
types?
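
If `Row` is in fact the only element type we expect, a sketch of the narrowed hints
(hypothetical, only valid if no other types are allowed):

```python
from typing import Iterator, List

from pyspark.sql import Row


class ListState:
    # Hypothetical narrowed signatures, assuming Row is the sole element type.
    def put(self, new_state: List[Row]) -> None: ...

    def get(self) -> Iterator[Row]: ...
```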
##########
python/pyspark/sql/streaming/stateful_processor_api_client.py:
##########
@@ -168,3 +207,43 @@ def _serialize_to_bytes(self, schema: StructType, data: Tuple) -> bytes:
     def _deserialize_from_bytes(self, value: bytes) -> Any:
         return self.pickleSer.loads(value)
+
+    def _send_arrow_state(self, schema: StructType, state: List[Tuple]) -> None:
Review Comment:
Is there any utility class/function we can reuse in PySpark? I feel like this is functionality
we should already have built to support the Pandas API on Spark. I'm referring to
`_send_arrow_state`, `_convert`, and `_to_numpy_type`.
cc. @HyukjinKwon
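
If nothing turns out to be reusable, for comparison, a minimal sketch of this functionality in
plain `pyarrow` (hypothetical helper, not an existing PySpark utility):

```python
import pyarrow as pa


def send_rows_as_arrow(sink, fields, rows):
    # Hypothetical helper: write `rows` (a list of tuples) to `sink` as a single
    # batch in the Arrow IPC stream format. `fields` is a list of
    # (name, pa.DataType) pairs describing the columns.
    schema = pa.schema(fields)
    columns = list(zip(*rows)) if rows else [() for _ in fields]
    batch = pa.RecordBatch.from_arrays(
        [pa.array(col, type=field.type) for col, field in zip(columns, schema)],
        schema=schema,
    )
    with pa.ipc.new_stream(sink, schema) as writer:
        writer.write_batch(batch)
```

e.g. `send_rows_as_arrow(socket_file, [("value", pa.int64())], [(1,), (2,)])`.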
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala:
##########
@@ -125,9 +149,13 @@ class TransformWithStateInPandasStateServer(
         // The key row is serialized as a byte array, we need to convert it back to a Row
         val keyRow = PythonSQLUtils.toJVMRow(keyBytes, groupingKeySchema, keyRowDeserializer)
         ImplicitGroupingKeyTracker.setImplicitKey(keyRow)
+        // Reset the list state iterators for a new grouping key.
+        listStateIterators = new mutable.HashMap[String, Iterator[Row]]()
Review Comment:
How do we handle the case of the param `listStateIteratorMapForTest` given here? Are we
assuming the test should only deal with a single grouping key? If so, let's mention it
explicitly somewhere, e.g. in the class doc.
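
For instance, something along these lines in the class doc (wording is just a suggestion):

```scala
/**
 * ...
 * Note: `listStateIteratorMapForTest` assumes tests exercise only a single grouping key,
 * since the iterator map is reset whenever a new grouping key is set.
 */
```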
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala:
##########
@@ -127,24 +151,79 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo
   }

   test("value state clear") {
-    val message = ValueStateCall.newBuilder().setStateName(valueStateName)
+    val message = ValueStateCall.newBuilder().setStateName(stateName)
       .setClear(Clear.newBuilder().build()).build()
     stateServer.handleValueStateRequest(message)
     verify(valueState).clear()
     verify(outputStream).writeInt(0)
   }

   test("value state update") {
-    // Below byte array is a serialized row with a single integer value 1.
-    val byteArray: Array[Byte] = Array(0x80.toByte, 0x05.toByte, 0x95.toByte, 0x05.toByte,
-      0x00.toByte, 0x00.toByte, 0x00.toByte, 0x00.toByte, 0x00.toByte, 0x00.toByte, 0x00.toByte,
-      'K'.toByte, 0x01.toByte, 0x85.toByte, 0x94.toByte, '.'.toByte
-    )
     val byteString: ByteString = ByteString.copyFrom(byteArray)
-    val message = ValueStateCall.newBuilder().setStateName(valueStateName)
+    val message = ValueStateCall.newBuilder().setStateName(stateName)
       .setValueStateUpdate(ValueStateUpdate.newBuilder().setValue(byteString).build()).build()
     stateServer.handleValueStateRequest(message)
     verify(valueState).update(any[Row])
     verify(outputStream).writeInt(0)
   }
+
+  test("list state exists") {
+    val message = ListStateCall.newBuilder().setStateName(stateName)
+      .setExists(Exists.newBuilder().build()).build()
+    stateServer.handleListStateRequest(message)
+    verify(listState).exists()
+  }
+
+  test("list state get - iterator in map") {
+    val message = ListStateCall.newBuilder().setStateName(stateName)
+      .setListStateGet(ListStateGet.newBuilder().setIteratorId(iteratorId).build()).build()
+    stateServer.handleListStateRequest(message)
+    verify(listState, times(0)).get()
+    verify(arrowStreamWriter).writeRow(any)
+    verify(arrowStreamWriter).finalizeCurrentArrowBatch()
+  }
+
+  test("list state get - iterator not in map") {
+    val maxRecordsPerBatch = 2
+    val message = ListStateCall.newBuilder().setStateName(stateName)
+      .setListStateGet(ListStateGet.newBuilder().setIteratorId(iteratorId).build()).build()
+    val iteratorMap: mutable.HashMap[String, Iterator[Row]] = mutable.HashMap()
+    stateServer = new TransformWithStateInPandasStateServer(serverSocket,
+      statefulProcessorHandle, groupingKeySchema, "", false, false,
+      maxRecordsPerBatch, outputStream, valueStateMap,
+      transformWithStateInPandasDeserializer, arrowStreamWriter, listStateMap, iteratorMap)
+    when(listState.get()).thenReturn(Iterator(new GenericRowWithSchema(Array(1), stateSchema),
+      new GenericRowWithSchema(Array(2), stateSchema),
+      new GenericRowWithSchema(Array(3), stateSchema)))
+    stateServer.handleListStateRequest(message)
+    verify(listState).get()
+    // Verify that only maxRecordsPerBatch (2) rows are written to the output stream while still
+    // having 1 row left in the iterator.
+    verify(arrowStreamWriter, times(maxRecordsPerBatch)).writeRow(any)
+    verify(arrowStreamWriter).finalizeCurrentArrowBatch()
Review Comment:
Would we also like to test that the remaining rows are returned when the same iterator ID is
requested again? It would probably be good to have a separate test that fully consumes an
iterator whose length is a multiple of maxRecordsPerBatch.
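
Roughly like this, reusing the mocks from the tests above (untested sketch; `stateName`,
`iteratorId`, `stateSchema`, etc. are the existing suite fixtures):

```scala
test("list state get - resume existing iterator") {
  val maxRecordsPerBatch = 2
  val message = ListStateCall.newBuilder().setStateName(stateName)
    .setListStateGet(ListStateGet.newBuilder().setIteratorId(iteratorId).build()).build()
  // Pre-populate the iterator map with 4 rows (a multiple of maxRecordsPerBatch) so two
  // GET calls for the same iterator ID should fully drain it in exactly two batches.
  val rows = (1 to 4).map(i => new GenericRowWithSchema(Array(i), stateSchema))
  val iteratorMap = mutable.HashMap[String, Iterator[Row]](iteratorId -> rows.iterator)
  stateServer = new TransformWithStateInPandasStateServer(serverSocket,
    statefulProcessorHandle, groupingKeySchema, "", false, false,
    maxRecordsPerBatch, outputStream, valueStateMap,
    transformWithStateInPandasDeserializer, arrowStreamWriter, listStateMap, iteratorMap)
  stateServer.handleListStateRequest(message)
  stateServer.handleListStateRequest(message)
  verify(arrowStreamWriter, times(4)).writeRow(any)
  verify(arrowStreamWriter, times(2)).finalizeCurrentArrowBatch()
}
```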
##########
python/pyspark/sql/streaming/stateful_processor.py:
##########
@@ -77,6 +78,58 @@ def clear(self) -> None:
         self._value_state_client.clear(self._state_name)
+class ListState:
+    """
+    Class used for arbitrary stateful operations with transformWithState to capture single value
Review Comment:
nit: This is not for single-value state. Probably a copy & paste that wasn't fixed up.
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala:
##########
@@ -157,7 +185,11 @@ class TransformWithStateInPandasStateServer(
         val ttlDurationMs = if (message.getGetValueState.hasTtl) {
           Some(message.getGetValueState.getTtl.getDurationMs)
         } else None
-        initializeValueState(stateName, schema, ttlDurationMs)
+        initializeStateVariable(stateName, schema, "valueState", ttlDurationMs)
+      case StatefulProcessorCall.MethodCase.GETLISTSTATE =>
+        val stateName = message.getGetListState.getStateName
+        val schema = message.getGetListState.getSchema
+        initializeStateVariable(stateName, schema, "listState", None)
Review Comment:
Is supporting TTL for ListState a TODO? If so, it would probably be better to file a JIRA
ticket.
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala:
##########
@@ -232,23 +349,53 @@ class TransformWithStateInPandasStateServer(
     outputStream.write(responseMessageBytes)
   }

-  private def initializeValueState(
-      stateName: String,
-      schemaString: String,
-      ttlDurationMs: Option[Int]): Unit = {
-    if (!valueStates.contains(stateName)) {
-      val schema = StructType.fromString(schemaString)
-      val state = if (ttlDurationMs.isEmpty) {
-        statefulProcessorHandle.getValueState[Row](stateName, Encoders.row(schema))
-      } else {
-        statefulProcessorHandle.getValueState(
-          stateName, Encoders.row(schema), TTLConfig(Duration.ofMillis(ttlDurationMs.get)))
-      }
-      val valueRowDeserializer = ExpressionEncoder(schema).resolveAndBind().createDeserializer()
-      valueStates.put(stateName, (state, schema, valueRowDeserializer))
-      sendResponse(0)
-    } else {
-      sendResponse(1, s"state $stateName already exists")
+  private def initializeStateVariable(
+    stateName: String,
+    schemaString: String,
+    stateVariable: String,
Review Comment:
nit: state`Type`? And would you mind using an enum for the type?
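
For instance (names are hypothetical):

```scala
// Hypothetical sketch of the suggested enum.
object StateVariableType extends Enumeration {
  type StateVariableType = Value
  val ValueState, ListState = Value
}
```

Then call sites become `initializeStateVariable(stateName, schema, StateVariableType.ListState, None)`,
so a typo'd string can't slip through.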
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala:
##########
@@ -201,13 +236,95 @@ class TransformWithStateInPandasStateServer(
         }
       case ValueStateCall.MethodCase.VALUESTATEUPDATE =>
         val byteArray = message.getValueStateUpdate.getValue.toByteArray
-        val valueStateTuple = valueStates(stateName)
         // The value row is serialized as a byte array, we need to convert it back to a Row
-        val valueRow = PythonSQLUtils.toJVMRow(byteArray, valueStateTuple._2, valueStateTuple._3)
-        valueStateTuple._1.update(valueRow)
+        val valueRow = PythonSQLUtils.toJVMRow(byteArray, valueStateInfo.schema,
+          valueStateInfo.deserializer)
+        valueStateInfo.valueState.update(valueRow)
         sendResponse(0)
       case ValueStateCall.MethodCase.CLEAR =>
-        valueStates(stateName)._1.clear()
+        valueStateInfo.valueState.clear()
+        sendResponse(0)
+      case _ =>
+        throw new IllegalArgumentException("Invalid method call")
+    }
+  }
+
+  private[sql] def handleListStateRequest(message: ListStateCall): Unit = {
+    val stateName = message.getStateName
+    if (!listStates.contains(stateName)) {
+      logWarning(log"List state ${MDC(LogKeys.STATE_NAME, stateName)} is not initialized.")
+      sendResponse(1, s"List state $stateName is not initialized.")
+      return
+    }
+    val listStateInfo = listStates(stateName)
+    val deserializer = if (deserializerForTest != null) {
+      deserializerForTest
+    } else {
+      new TransformWithStateInPandasDeserializer(listStateInfo.deserializer)
+    }
+    message.getMethodCase match {
+      case ListStateCall.MethodCase.EXISTS =>
+        if (listStateInfo.listState.exists()) {
+          sendResponse(0)
+        } else {
+          // Send status code 2 to indicate that the list state doesn't have a value yet.
+          sendResponse(2, s"state $stateName doesn't exist")
+        }
+      case ListStateCall.MethodCase.LISTSTATEPUT =>
+        val rows = deserializer.readArrowBatches(inputStream)
+        listStateInfo.listState.put(rows.toArray)
+        sendResponse(0)
+      case ListStateCall.MethodCase.LISTSTATEGET =>
+        val iteratorId = message.getListStateGet.getIteratorId
+        var iteratorOption = listStateIterators.get(iteratorId)
+        if (iteratorOption.isEmpty) {
+          iteratorOption = Some(listStateInfo.listState.get())
+          listStateIterators.put(iteratorId, iteratorOption.get)
+        }
+        if (!iteratorOption.get.hasNext) {
+          sendResponse(2, s"List state $stateName doesn't contain any value.")
+          return
+        } else {
+          sendResponse(0)
+        }
+        outputStream.flush()
+        val arrowStreamWriter = if (arrowStreamWriterForTest != null) {
+          arrowStreamWriterForTest
+        } else {
+          val arrowSchema = ArrowUtils.toArrowSchema(listStateInfo.schema, timeZoneId,
+            errorOnDuplicatedFieldNames, largeVarTypes)
+          val allocator = ArrowUtils.rootAllocator.newChildAllocator(
+            s"stdout writer for transformWithStateInPandas state socket", 0, Long.MaxValue)
+          val root = VectorSchemaRoot.create(arrowSchema, allocator)
+          new BaseStreamingArrowWriter(root, new ArrowStreamWriter(root, null, outputStream),
+            arrowTransformWithStateInPandasMaxRecordsPerBatch)
+        }
+        val listRowSerializer = listStateInfo.serializer
+        // Only write a single batch in each GET request. Stops writing row if rowCount reaches
+        // the arrowTransformWithStateInPandasMaxRecordsPerBatch limit. This is to avoid a case
Review Comment:
Is this to avoid the case, or to support the case? From the code implementation, I feel like
it's to support the case. Otherwise, do we have logic to block such a request?
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala:
##########
@@ -232,23 +349,53 @@ class TransformWithStateInPandasStateServer(
     outputStream.write(responseMessageBytes)
   }

-  private def initializeValueState(
-      stateName: String,
-      schemaString: String,
-      ttlDurationMs: Option[Int]): Unit = {
-    if (!valueStates.contains(stateName)) {
-      val schema = StructType.fromString(schemaString)
-      val state = if (ttlDurationMs.isEmpty) {
-        statefulProcessorHandle.getValueState[Row](stateName, Encoders.row(schema))
-      } else {
-        statefulProcessorHandle.getValueState(
-          stateName, Encoders.row(schema), TTLConfig(Duration.ofMillis(ttlDurationMs.get)))
-      }
-      val valueRowDeserializer = ExpressionEncoder(schema).resolveAndBind().createDeserializer()
-      valueStates.put(stateName, (state, schema, valueRowDeserializer))
-      sendResponse(0)
-    } else {
-      sendResponse(1, s"state $stateName already exists")
+  private def initializeStateVariable(
+    stateName: String,
Review Comment:
nit: indent the params two more spaces.
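
i.e., to match the continuation indent used elsewhere in this file:

```scala
  private def initializeStateVariable(
      stateName: String,
      schemaString: String,
      stateVariable: String,
      ttlDurationMs: Option[Int]): Unit = {
```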
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala:
##########
@@ -232,23 +349,53 @@ class TransformWithStateInPandasStateServer(
     outputStream.write(responseMessageBytes)
   }

-  private def initializeValueState(
-      stateName: String,
-      schemaString: String,
-      ttlDurationMs: Option[Int]): Unit = {
-    if (!valueStates.contains(stateName)) {
-      val schema = StructType.fromString(schemaString)
-      val state = if (ttlDurationMs.isEmpty) {
-        statefulProcessorHandle.getValueState[Row](stateName, Encoders.row(schema))
-      } else {
-        statefulProcessorHandle.getValueState(
-          stateName, Encoders.row(schema), TTLConfig(Duration.ofMillis(ttlDurationMs.get)))
-      }
-      val valueRowDeserializer = ExpressionEncoder(schema).resolveAndBind().createDeserializer()
-      valueStates.put(stateName, (state, schema, valueRowDeserializer))
-      sendResponse(0)
-    } else {
-      sendResponse(1, s"state $stateName already exists")
+  private def initializeStateVariable(
+    stateName: String,
+    schemaString: String,
+    stateVariable: String,
+    ttlDurationMs: Option[Int]): Unit = {
+    val schema = StructType.fromString(schemaString)
+    val expressionEncoder = ExpressionEncoder(schema).resolveAndBind()
+    stateVariable match {
+      case "valueState" => if (!valueStates.contains(stateName)) {
+        val state = if (ttlDurationMs.isEmpty) {
+          statefulProcessorHandle.getValueState[Row](stateName, Encoders.row(schema))
+        } else {
+          statefulProcessorHandle.getValueState(
+            stateName, Encoders.row(schema), TTLConfig(Duration.ofMillis(ttlDurationMs.get)))
+        }
+        valueStates.put(stateName,
+          ValueStateInfo(state, schema, expressionEncoder.createDeserializer()))
+        sendResponse(0)
+      } else {
+        sendResponse(1, s"Value state $stateName already exists")
+      }
+      case "listState" => if (!listStates.contains(stateName)) {
Review Comment:
nit: let's either assert that ttlDurationMs is empty or leave a TODO code comment.
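
e.g. (sketch only; the TODO wording is up to you):

```scala
case "listState" => if (!listStates.contains(stateName)) {
  // TODO: TTL is not supported for list state yet; fail fast until it is.
  assert(ttlDurationMs.isEmpty, s"TTL is not supported for list state $stateName yet")
```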
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]