This is an automated email from the ASF dual-hosted git repository.

yangjie01 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 345e7da1cb0b [SPARK-53530][SS] Clean up the useless code related to `TransformWithStateInPySparkStateServer`
345e7da1cb0b is described below

commit 345e7da1cb0b764d2eba9ae0620714a103670696
Author: yangjie01 <[email protected]>
AuthorDate: Wed Sep 10 12:07:33 2025 +0800

    [SPARK-53530][SS] Clean up the useless code related to `TransformWithStateInPySparkStateServer`
    
    ### What changes were proposed in this pull request?
    This PR performs the following cleanup on the code related to `TransformWithStateInPySparkStateServer`:
    
    - Removed the `private` function `sendIteratorForListState` from `TransformWithStateInPySparkStateServer`, as it is no longer used after SPARK-51891.
    - Removed the function `sendIteratorAsArrowBatches` from `TransformWithStateInPySparkStateServer`, as it is no longer used after SPARK-52333.
    - Removed the input parameters `timeZoneId`, `errorOnDuplicatedFieldNames`, `largeVarTypes`, and `arrowStreamWriterForTest` from the constructor of `TransformWithStateInPySparkStateServer`, as they are no longer used after the cleanup of `sendIteratorAsArrowBatches`.
    - Removed the input parameter `timeZoneId` from the constructor of `TransformWithStateInPySparkPythonPreInitRunner`, as it was only used for constructing `TransformWithStateInPySparkStateServer`.
    
    ### Why are the changes needed?
    Code cleanup.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Pass GitHub Actions
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #52279 from LuciferYang/TransformWithStateInPySparkStateServer.
    
    Lead-authored-by: yangjie01 <[email protected]>
    Co-authored-by: YangJie <[email protected]>
    Signed-off-by: yangjie01 <[email protected]>
---
 .../TransformWithStateInPySparkExec.scala          |  1 -
 .../TransformWithStateInPySparkPythonRunner.scala  |  6 +-
 .../TransformWithStateInPySparkStateServer.scala   | 70 ----------------------
 ...arkTransformWithStateInPySparkStateServer.scala |  3 -
 ...ansformWithStateInPySparkStateServerSuite.scala | 33 +++++-----
 5 files changed, 18 insertions(+), 95 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala
index f8390b7d878f..c10d21933c2f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala
@@ -154,7 +154,6 @@ case class TransformWithStateInPySparkExec(
     val runner = new TransformWithStateInPySparkPythonPreInitRunner(
       pythonFunction,
       "pyspark.sql.streaming.transform_with_state_driver_worker",
-      sessionLocalTimeZone,
       groupingKeySchema,
       driverProcessorHandle
     )
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala
index 51dc179c901a..329bd4335265 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala
@@ -220,7 +220,7 @@ abstract class TransformWithStateInPySparkPythonBaseRunner[I](
 
     executionContext.execute(
      new TransformWithStateInPySparkStateServer(stateServerSocket, processorHandle,
-        groupingKeySchema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes,
+        groupingKeySchema,
         sqlConf.arrowTransformWithStateInPySparkMaxStateRecordsPerBatch,
         batchTimestampMs, eventTimeWatermarkForEviction))
 
@@ -245,7 +245,6 @@ abstract class TransformWithStateInPySparkPythonBaseRunner[I](
 class TransformWithStateInPySparkPythonPreInitRunner(
     func: PythonFunction,
     workerModule: String,
-    timeZoneId: String,
     groupingKeySchema: StructType,
     processorHandleImpl: DriverStatefulProcessorHandleImpl)
   extends StreamingPythonRunner(func, "", "", workerModule)
@@ -299,8 +298,7 @@ class TransformWithStateInPySparkPythonPreInitRunner(
       override def run(): Unit = {
         try {
          new TransformWithStateInPySparkStateServer(stateServerSocket, processorHandleImpl,
-            groupingKeySchema, timeZoneId, errorOnDuplicatedFieldNames = true,
-            largeVarTypes = sqlConf.arrowUseLargeVarTypes,
+            groupingKeySchema,
             sqlConf.arrowTransformWithStateInPySparkMaxStateRecordsPerBatch).run()
         } catch {
           case e: Exception =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkStateServer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkStateServer.scala
index 4edeae132b47..59acf434035e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkStateServer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkStateServer.scala
@@ -25,15 +25,12 @@ import scala.collection.mutable
 import scala.jdk.CollectionConverters._
 
 import com.google.protobuf.ByteString
-import org.apache.arrow.vector.VectorSchemaRoot
-import org.apache.arrow.vector.ipc.ArrowStreamWriter
 
 import org.apache.spark.SparkEnv
 import org.apache.spark.internal.{Logging, LogKeys}
 import org.apache.spark.internal.config.Python.PYTHON_UNIX_DOMAIN_SOCKET_ENABLED
 import org.apache.spark.sql.{Encoders, Row}
 import org.apache.spark.sql.api.python.PythonSQLUtils
-import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.StateVariableType
@@ -43,8 +40,6 @@ import org.apache.spark.sql.execution.streaming.state.StateMessage.KeyAndValuePa
 import org.apache.spark.sql.execution.streaming.state.StateMessage.StateResponseWithListGet
 import org.apache.spark.sql.streaming.{ListState, MapState, TTLConfig, ValueState}
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.sql.util.ArrowUtils
-import org.apache.spark.util.Utils
 
 /**
  * This class is used to handle the state requests from the Python side. It runs on a separate
@@ -60,16 +55,12 @@ class TransformWithStateInPySparkStateServer(
     stateServerSocket: ServerSocketChannel,
     statefulProcessorHandle: StatefulProcessorHandleImplBase,
     groupingKeySchema: StructType,
-    timeZoneId: String,
-    errorOnDuplicatedFieldNames: Boolean,
-    largeVarTypes: Boolean,
     arrowTransformWithStateInPySparkMaxRecordsPerBatch: Int,
     batchTimestampMs: Option[Long] = None,
     eventTimeWatermarkForEviction: Option[Long] = None,
     outputStreamForTest: DataOutputStream = null,
     valueStateMapForTest: mutable.HashMap[String, ValueStateInfo] = null,
     deserializerForTest: TransformWithStateInPySparkDeserializer = null,
-    arrowStreamWriterForTest: BaseStreamingArrowWriter = null,
     listStatesMapForTest : mutable.HashMap[String, ListStateInfo] = null,
     iteratorMapForTest: mutable.HashMap[String, Iterator[Row]] = null,
     mapStatesMapForTest : mutable.HashMap[String, MapStateInfo] = null,
@@ -533,28 +524,6 @@ class TransformWithStateInPySparkStateServer(
     }
   }
 
-  private def sendIteratorForListState(iter: Iterator[Row]): Unit = {
-    // Only write a single batch in each GET request. Stops writing row if rowCount reaches
-    // the arrowTransformWithStateInPandasMaxRecordsPerBatch limit. This is to handle a case
-    // when there are multiple state variables, user tries to access a different state variable
-    // while the current state variable is not exhausted yet.
-    var rowCount = 0
-    while (iter.hasNext && rowCount < arrowTransformWithStateInPySparkMaxRecordsPerBatch) {
-      val data = iter.next()
-
-      // Serialize the value row as a byte array
-      val valueBytes = PythonSQLUtils.toPyRow(data)
-      val lenBytes = valueBytes.length
-
-      outputStream.writeInt(lenBytes)
-      outputStream.write(valueBytes)
-
-      rowCount += 1
-    }
-    outputStream.writeInt(-1)
-    outputStream.flush()
-  }
-
   private[sql] def handleMapStateRequest(message: MapStateCall): Unit = {
     val stateName = message.getStateName
     if (!mapStates.contains(stateName)) {
@@ -939,45 +908,6 @@ class TransformWithStateInPySparkStateServer(
       outputStream.write(responseMessageBytes)
     }
 
-    def sendIteratorAsArrowBatches[T](
-        iter: Iterator[T],
-        outputSchema: StructType,
-        arrowStreamWriterForTest: BaseStreamingArrowWriter = null)(func: T => InternalRow): Unit = {
-      outputStream.flush()
-      val arrowSchema = ArrowUtils.toArrowSchema(outputSchema, timeZoneId,
-        errorOnDuplicatedFieldNames, largeVarTypes)
-      val allocator = ArrowUtils.rootAllocator.newChildAllocator(
-        s"stdout writer for transformWithStateInPySpark state socket", 0, 
Long.MaxValue)
-      val root = VectorSchemaRoot.create(arrowSchema, allocator)
-      val writer = new ArrowStreamWriter(root, null, outputStream)
-      val arrowStreamWriter = if (arrowStreamWriterForTest != null) {
-        arrowStreamWriterForTest
-      } else {
-        new BaseStreamingArrowWriter(root, writer,
-          arrowTransformWithStateInPySparkMaxRecordsPerBatch)
-      }
-      // Only write a single batch in each GET request. Stops writing row if rowCount reaches
-      // the arrowTransformWithStateInPySparkMaxRecordsPerBatch limit. This is to handle a case
-      // when there are multiple state variables, user tries to access a different state variable
-      // while the current state variable is not exhausted yet.
-      var rowCount = 0
-      while (iter.hasNext && rowCount < arrowTransformWithStateInPySparkMaxRecordsPerBatch) {
-        val data = iter.next()
-        val internalRow = func(data)
-        arrowStreamWriter.writeRow(internalRow)
-        rowCount += 1
-      }
-      arrowStreamWriter.finalizeCurrentArrowBatch()
-      Utils.tryWithSafeFinally {
-        // end writes footer to the output stream and doesn't clean any resources.
-        // It could throw exception if the output stream is closed, so it should be
-        // in the try block.
-        writer.end()
-      } {
-        root.close()
-        allocator.close()
-      }
-    }
   }
 }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/benchmark/BenchmarkTransformWithStateInPySparkStateServer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/benchmark/BenchmarkTransformWithStateInPySparkStateServer.scala
index 5dc7d9733dcd..91162c7b02f9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/benchmark/BenchmarkTransformWithStateInPySparkStateServer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/benchmark/BenchmarkTransformWithStateInPySparkStateServer.scala
@@ -351,9 +351,6 @@ object BenchmarkTransformWithStateInPySparkStateServer extends App {
     serverSocketChannel,
     stateHandleImpl,
     groupingKeySchema,
-    timeZoneId,
-    errorOnDuplicatedFieldNames,
-    largeVarTypes,
     arrowTransformWithStateInPySparkMaxRecordsPerBatch
   )
   // scalastyle:off println
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkStateServerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkStateServerSuite.scala
index ff99b4ee280d..013aa375c308 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkStateServerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkStateServerSuite.scala
@@ -100,9 +100,9 @@ class TransformWithStateInPySparkStateServerSuite extends SparkFunSuite with Bef
     batchTimestampMs = mock(classOf[Option[Long]])
     eventTimeWatermarkForEviction = mock(classOf[Option[Long]])
     stateServer = new TransformWithStateInPySparkStateServer(serverSocket,
-      statefulProcessorHandle, groupingKeySchema, "", false, false, 2,
+      statefulProcessorHandle, groupingKeySchema, 2,
       batchTimestampMs, eventTimeWatermarkForEviction,
-      outputStream, valueStateMap, transformWithStateInPySparkDeserializer, arrowStreamWriter,
+      outputStream, valueStateMap, transformWithStateInPySparkDeserializer,
       listStateMap, iteratorMap, mapStateMap, keyValueIteratorMap, expiryTimerIter, listTimerMap)
     when(transformWithStateInPySparkDeserializer.readArrowBatches(any))
       .thenReturn(Seq(getIntegerRow(1)))
@@ -278,9 +278,9 @@ class TransformWithStateInPySparkStateServerSuite extends SparkFunSuite with Bef
     val iteratorMap = mutable.HashMap[String, Iterator[Row]](iteratorId ->
       Iterator(getIntegerRow(1), getIntegerRow(2), getIntegerRow(3), getIntegerRow(4)))
     stateServer = new TransformWithStateInPySparkStateServer(serverSocket,
-      statefulProcessorHandle, groupingKeySchema, "", false, false,
+      statefulProcessorHandle, groupingKeySchema,
       maxRecordsPerBatch, batchTimestampMs, eventTimeWatermarkForEviction, outputStream,
-      valueStateMap, transformWithStateInPySparkDeserializer, arrowStreamWriter,
+      valueStateMap, transformWithStateInPySparkDeserializer,
       listStateMap, iteratorMap)
     // First call should send 2 records.
     stateServer.handleListStateRequest(message)
@@ -307,9 +307,9 @@ class TransformWithStateInPySparkStateServerSuite extends SparkFunSuite with Bef
       .setListStateGet(ListStateGet.newBuilder().setIteratorId(iteratorId).build()).build()
     val iteratorMap: mutable.HashMap[String, Iterator[Row]] = mutable.HashMap()
     stateServer = new TransformWithStateInPySparkStateServer(serverSocket,
-      statefulProcessorHandle, groupingKeySchema, "", false, false,
+      statefulProcessorHandle, groupingKeySchema,
       maxRecordsPerBatch, batchTimestampMs, eventTimeWatermarkForEviction, outputStream,
-      valueStateMap, transformWithStateInPySparkDeserializer, arrowStreamWriter,
+      valueStateMap, transformWithStateInPySparkDeserializer,
       listStateMap, iteratorMap)
     when(listState.get()).thenReturn(Iterator(getIntegerRow(1), getIntegerRow(2), getIntegerRow(3)))
     stateServer.handleListStateRequest(message)
@@ -419,9 +419,9 @@ class TransformWithStateInPySparkStateServerSuite extends SparkFunSuite with Bef
       Iterator((getIntegerRow(1), getIntegerRow(1)), (getIntegerRow(2), getIntegerRow(2)),
         (getIntegerRow(3), getIntegerRow(3)), (getIntegerRow(4), getIntegerRow(4))))
     stateServer = new TransformWithStateInPySparkStateServer(serverSocket,
-      statefulProcessorHandle, groupingKeySchema, "", false, false,
+      statefulProcessorHandle, groupingKeySchema,
       maxRecordsPerBatch, batchTimestampMs, eventTimeWatermarkForEviction, outputStream,
-      valueStateMap, transformWithStateInPySparkDeserializer, arrowStreamWriter,
+      valueStateMap, transformWithStateInPySparkDeserializer,
       listStateMap, null, mapStateMap, keyValueIteratorMap)
     // First call should send 2 records.
     stateServer.handleMapStateRequest(message)
@@ -448,10 +448,10 @@ class TransformWithStateInPySparkStateServerSuite extends SparkFunSuite with Bef
       .setIterator(StateMessage.Iterator.newBuilder().setIteratorId(iteratorId).build()).build()
     val keyValueIteratorMap: mutable.HashMap[String, Iterator[(Row, Row)]] = mutable.HashMap()
     stateServer = new TransformWithStateInPySparkStateServer(serverSocket,
-      statefulProcessorHandle, groupingKeySchema, "", false, false,
+      statefulProcessorHandle, groupingKeySchema,
       maxRecordsPerBatch, batchTimestampMs, eventTimeWatermarkForEviction,
       outputStream, valueStateMap, transformWithStateInPySparkDeserializer,
-      arrowStreamWriter, listStateMap, null, mapStateMap, keyValueIteratorMap)
+      listStateMap, null, mapStateMap, keyValueIteratorMap)
     when(mapState.iterator()).thenReturn(Iterator((getIntegerRow(1), getIntegerRow(1)),
       (getIntegerRow(2), getIntegerRow(2)), (getIntegerRow(3), getIntegerRow(3))))
     stateServer.handleMapStateRequest(message)
@@ -481,10 +481,10 @@ class TransformWithStateInPySparkStateServerSuite extends SparkFunSuite with Bef
       .setKeys(Keys.newBuilder().setIteratorId(iteratorId).build()).build()
     val iteratorMap: mutable.HashMap[String, Iterator[Row]] = mutable.HashMap()
     stateServer = new TransformWithStateInPySparkStateServer(serverSocket,
-      statefulProcessorHandle, groupingKeySchema, "", false, false,
+      statefulProcessorHandle, groupingKeySchema,
       maxRecordsPerBatch, batchTimestampMs, eventTimeWatermarkForEviction,
       outputStream, valueStateMap, transformWithStateInPySparkDeserializer,
-      arrowStreamWriter, listStateMap, iteratorMap, mapStateMap)
+      listStateMap, iteratorMap, mapStateMap)
     when(mapState.keys()).thenReturn(Iterator(getIntegerRow(1), getIntegerRow(2), getIntegerRow(3)))
     stateServer.handleMapStateRequest(message)
     verify(mapState).keys()
@@ -513,10 +513,10 @@ class TransformWithStateInPySparkStateServerSuite extends SparkFunSuite with Bef
       .setValues(Values.newBuilder().setIteratorId(iteratorId).build()).build()
     val iteratorMap: mutable.HashMap[String, Iterator[Row]] = mutable.HashMap()
     stateServer = new TransformWithStateInPySparkStateServer(serverSocket,
-      statefulProcessorHandle, groupingKeySchema, "", false, false,
+      statefulProcessorHandle, groupingKeySchema,
       maxRecordsPerBatch, batchTimestampMs, eventTimeWatermarkForEviction, outputStream,
       valueStateMap, transformWithStateInPySparkDeserializer,
-      arrowStreamWriter, listStateMap, iteratorMap, mapStateMap)
+      listStateMap, iteratorMap, mapStateMap)
     when(mapState.values()).thenReturn(Iterator(getIntegerRow(1), getIntegerRow(2),
       getIntegerRow(3)))
     stateServer.handleMapStateRequest(message)
@@ -611,10 +611,9 @@ class TransformWithStateInPySparkStateServerSuite extends SparkFunSuite with Bef
         .build()
     ).build()
     stateServer = new TransformWithStateInPySparkStateServer(serverSocket,
-      statefulProcessorHandle, groupingKeySchema, "", false, false,
+      statefulProcessorHandle, groupingKeySchema,
       2, batchTimestampMs, eventTimeWatermarkForEviction, outputStream,
-      valueStateMap, transformWithStateInPySparkDeserializer,
-      arrowStreamWriter, listStateMap, null, mapStateMap, null,
+      valueStateMap, transformWithStateInPySparkDeserializer, listStateMap, null, mapStateMap, null,
       null, listTimerMap)
     when(statefulProcessorHandle.listTimers()).thenReturn(Iterator(1))
     stateServer.handleStatefulProcessorCall(message)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
