jingz-db commented on code in PR #47933:
URL: https://github.com/apache/spark/pull/47933#discussion_r1755201977
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala:
##########
@@ -197,13 +232,92 @@ class TransformWithStateInPandasStateServer(
}
case ValueStateCall.MethodCase.VALUESTATEUPDATE =>
val byteArray = message.getValueStateUpdate.getValue.toByteArray
- val valueStateTuple = valueStates(stateName)
// The value row is serialized as a byte array, we need to convert it
back to a Row
- val valueRow = PythonSQLUtils.toJVMRow(byteArray, valueStateTuple._2,
valueStateTuple._3)
- valueStateTuple._1.update(valueRow)
+ val valueRow = PythonSQLUtils.toJVMRow(byteArray,
valueStateInfo.schema,
+ valueStateInfo.deserializer)
+ valueStateInfo.valueState.update(valueRow)
sendResponse(0)
case ValueStateCall.MethodCase.CLEAR =>
- valueStates(stateName)._1.clear()
+ valueStateInfo.valueState.clear()
+ sendResponse(0)
+ case _ =>
+ throw new IllegalArgumentException("Invalid method call")
+ }
+ }
+
+ private[sql] def handleListStateRequest(message: ListStateCall): Unit = {
+ val stateName = message.getStateName
+ if (!listStates.contains(stateName)) {
+ logWarning(log"List state ${MDC(LogKeys.STATE_NAME, stateName)} is not
initialized.")
+ sendResponse(1, s"List state $stateName is not initialized.")
+ return
+ }
+ val listStateInfo = listStates(stateName)
+ val deserializer = if (deserializerForTest != null) {
+ deserializerForTest
+ } else {
+ new TransformWithStateInPandasDeserializer(listStateInfo.deserializer)
+ }
+ message.getMethodCase match {
+ case ListStateCall.MethodCase.EXISTS =>
+ if (listStateInfo.listState.exists()) {
+ sendResponse(0)
+ } else {
+ // Send status code 2 to indicate that the list state doesn't have a
value yet.
+ sendResponse(2, s"state $stateName doesn't exist")
+ }
+ case ListStateCall.MethodCase.LISTSTATEPUT =>
+ val rows = deserializer.readArrowBatches(inputStream)
+ listStateInfo.listState.put(rows.toArray)
+ sendResponse(0)
+ case ListStateCall.MethodCase.GET =>
+ var iteratorOption = listStateIterators.get(stateName)
+ if (iteratorOption.isEmpty) {
+ iteratorOption = Some(listStateInfo.listState.get())
+ listStateIterators.put(stateName, iteratorOption.get)
+ }
+ if (!iteratorOption.get.hasNext) {
+ sendResponse(2, s"List state $stateName doesn't contain any value.")
+ return
+ } else {
+ sendResponse(0)
+ }
+ outputStream.flush()
+ val arrowStreamWriter = if (arrowStreamWriterForTest != null) {
+ arrowStreamWriterForTest
+ } else {
+ val arrowSchema = ArrowUtils.toArrowSchema(listStateInfo.schema,
timeZoneId,
+ errorOnDuplicatedFieldNames, largeVarTypes)
+ val allocator = ArrowUtils.rootAllocator.newChildAllocator(
+ s"stdout writer for transformWithStateInPandas state socket", 0,
Long.MaxValue)
+ val root = VectorSchemaRoot.create(arrowSchema, allocator)
+ new BaseStreamingArrowWriter(root, new ArrowStreamWriter(root, null,
outputStream),
+ arrowTransformWithStateInPandasMaxRecordsPerBatch)
+ }
+ val listRowSerializer = listStateInfo.serializer
+ // Only write a single batch in each GET request. Stops writing row if
rowCount reaches
Review Comment:
nits: If the number of rows exceeds `rowCount`, IIUC, we will just send
the remaining data in another batch? Shall we log some info here when the
number of rows exceeds `rowCount`? Also, is it possible to write some tests for
the above scenario to make sure the newly added config takes effect?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]