bogao007 commented on code in PR #47933:
URL: https://github.com/apache/spark/pull/47933#discussion_r1769332212
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala:
##########
@@ -127,24 +151,79 @@ class TransformWithStateInPandasStateServerSuite extends
SparkFunSuite with Befo
}
test("value state clear") {
- val message = ValueStateCall.newBuilder().setStateName(valueStateName)
+ val message = ValueStateCall.newBuilder().setStateName(stateName)
.setClear(Clear.newBuilder().build()).build()
stateServer.handleValueStateRequest(message)
verify(valueState).clear()
verify(outputStream).writeInt(0)
}
test("value state update") {
- // Below byte array is a serialized row with a single integer value 1.
- val byteArray: Array[Byte] = Array(0x80.toByte, 0x05.toByte, 0x95.toByte,
0x05.toByte,
- 0x00.toByte, 0x00.toByte, 0x00.toByte, 0x00.toByte, 0x00.toByte,
0x00.toByte, 0x00.toByte,
- 'K'.toByte, 0x01.toByte, 0x85.toByte, 0x94.toByte, '.'.toByte
- )
val byteString: ByteString = ByteString.copyFrom(byteArray)
- val message = ValueStateCall.newBuilder().setStateName(valueStateName)
+ val message = ValueStateCall.newBuilder().setStateName(stateName)
.setValueStateUpdate(ValueStateUpdate.newBuilder().setValue(byteString).build()).build()
stateServer.handleValueStateRequest(message)
verify(valueState).update(any[Row])
verify(outputStream).writeInt(0)
}
+
+ test("list state exists") {
+ val message = ListStateCall.newBuilder().setStateName(stateName)
+ .setExists(Exists.newBuilder().build()).build()
+ stateServer.handleListStateRequest(message)
+ verify(listState).exists()
+ }
+
+ test("list state get - iterator in map") {
+ val message = ListStateCall.newBuilder().setStateName(stateName)
+
.setListStateGet(ListStateGet.newBuilder().setIteratorId(iteratorId).build()).build()
+ stateServer.handleListStateRequest(message)
+ verify(listState, times(0)).get()
+ verify(arrowStreamWriter).writeRow(any)
+ verify(arrowStreamWriter).finalizeCurrentArrowBatch()
+ }
+
+ test("list state get - iterator not in map") {
+ val maxRecordsPerBatch = 2
+ val message = ListStateCall.newBuilder().setStateName(stateName)
+
.setListStateGet(ListStateGet.newBuilder().setIteratorId(iteratorId).build()).build()
+ val iteratorMap: mutable.HashMap[String, Iterator[Row]] = mutable.HashMap()
+ stateServer = new TransformWithStateInPandasStateServer(serverSocket,
+ statefulProcessorHandle, groupingKeySchema, "", false, false,
+ maxRecordsPerBatch, outputStream, valueStateMap,
+ transformWithStateInPandasDeserializer, arrowStreamWriter, listStateMap,
iteratorMap)
+ when(listState.get()).thenReturn(Iterator(new
GenericRowWithSchema(Array(1), stateSchema),
+ new GenericRowWithSchema(Array(2), stateSchema),
+ new GenericRowWithSchema(Array(3), stateSchema)))
+ stateServer.handleListStateRequest(message)
+ verify(listState).get()
+ // Verify that only maxRecordsPerBatch (2) rows are written to the output stream while still
+ // having 1 row left in the iterator.
+ verify(arrowStreamWriter, times(maxRecordsPerBatch)).writeRow(any)
+ verify(arrowStreamWriter).finalizeCurrentArrowBatch()
Review Comment:
Added a separate test case.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]