jingz-db commented on code in PR #47878:
URL: https://github.com/apache/spark/pull/47878#discussion_r1777566285
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala:
##########
@@ -251,4 +336,64 @@ class TransformWithStateInPandasStateServer(
sendResponse(1, s"state $stateName already exists")
}
}
+
+ /** Utils object for sending response to Python client. */
+ private object PythonResponseWriterUtils {
+ def sendResponse(
+ status: Int,
+ errorMessage: String = null,
+ byteString: ByteString = null): Unit = {
+ val responseMessageBuilder =
StateResponse.newBuilder().setStatusCode(status)
+ if (status != 0 && errorMessage != null) {
+ responseMessageBuilder.setErrorMessage(errorMessage)
+ }
+ if (byteString != null) {
+ responseMessageBuilder.setValue(byteString)
+ }
+ val responseMessage = responseMessageBuilder.build()
+ val responseMessageBytes = responseMessage.toByteArray
+ val byteLength = responseMessageBytes.length
+ outputStream.writeInt(byteLength)
+ outputStream.write(responseMessageBytes)
+ }
+
+ def sendResponseWithLongVal(
+ status: Int,
+ errorMessage: String = null,
+ longVal: Long): Unit = {
+ val responseMessageBuilder =
StateResponseWithLongTypeVal.newBuilder().setStatusCode(status)
+ if (status != 0 && errorMessage != null) {
+ responseMessageBuilder.setErrorMessage(errorMessage)
+ }
+ responseMessageBuilder.setValue(longVal)
+ val responseMessage = responseMessageBuilder.build()
+ val responseMessageBytes = responseMessage.toByteArray
+ val byteLength = responseMessageBytes.length
+ outputStream.writeInt(byteLength)
+ outputStream.write(responseMessageBytes)
+ }
+
+ def sendIteratorAsArrowBatches[T](
+ iter: Iterator[T], outputSchema: StructType)(func: T => InternalRow):
Unit = {
+ outputStream.flush()
+ val arrowSchema = ArrowUtils.toArrowSchema(outputSchema, timeZoneId,
+ errorOnDuplicatedFieldNames, largeVarTypes)
+ val allocator = ArrowUtils.rootAllocator.newChildAllocator(
+ s"stdout writer for transformWithStateInPandas state socket", 0,
Long.MaxValue)
+ val root = VectorSchemaRoot.create(arrowSchema, allocator)
+ val writer = new ArrowStreamWriter(root, null, outputStream)
+
+ val arrowStreamWriter = new BaseStreamingArrowWriter(root, writer,
+ arrowTransformWithStateInPandasMaxRecordsPerBatch)
+ while (iter.hasNext) {
+ val data = iter.next()
+ val internalRow = func(data)
+ arrowStreamWriter.writeRow(internalRow)
+ }
+ arrowStreamWriter.finalizeCurrentArrowBatch()
+ writer.end()
Review Comment:
Done!
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]