HeartSaVioR commented on code in PR #49156:
URL: https://github.com/apache/spark/pull/49156#discussion_r1893860278


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasExec.scala:
##########
@@ -109,6 +112,21 @@ case class TransformWithStateInPandasExec(
   // Each state variable has its own schema, this is a dummy one.
   protected val schemaForValueRow: StructType = new StructType().add("value", 
BinaryType)
 
+  override def operatorStateMetadataVersion: Int = 2
+
+  override def getColFamilySchemas(): Map[String, StateStoreColFamilySchema] = 
{
+    driverProcessorHandle.getColumnFamilySchemas
+  }
+
+  override def getStateVariableInfos(): Map[String, 
TransformWithStateVariableInfo] = {
+    driverProcessorHandle.getStateVariableInfos
+  }
+
+  /** Metadata of this stateful operator and its states stores.
+   * Written during IncrementalExecution */

Review Comment:
   nit: Let's explicitly mention which method will initialize this - it looks 
like driverProcessorHandle would only be useful "after 
validateAndMaybeEvolveStateSchema is called".



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala:
##########
@@ -238,7 +224,137 @@ abstract class 
TransformWithStateInPandasPythonBaseRunner[I](
     super.compute(inputIterator, partitionIndex, context)
   }
 
-  private def closeServerSocketChannelSilently(stateServerSocket: 
ServerSocket): Unit = {
+  override protected def writeUDF(dataOut: DataOutputStream): Unit = {
+    PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets, None)
+  }
+}
+
+/**
+ * TransformWithStateInPandas driver side Python runner. Similar as executor 
side runner,
+ * will start a new daemon thread on the Python runner to run state server.
+ */
+class TransformWithStateInPandasPythonPreInitRunner(

Review Comment:
   nit: This class has a lot of duplication - I'm OK with leaving it as is 
for now, but let's consider it tech debt.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala:
##########
@@ -238,7 +224,137 @@ abstract class 
TransformWithStateInPandasPythonBaseRunner[I](
     super.compute(inputIterator, partitionIndex, context)
   }
 
-  private def closeServerSocketChannelSilently(stateServerSocket: 
ServerSocket): Unit = {
+  override protected def writeUDF(dataOut: DataOutputStream): Unit = {
+    PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets, None)
+  }
+}
+
+/**
+ * TransformWithStateInPandas driver side Python runner. Similar as executor 
side runner,
+ * will start a new daemon thread on the Python runner to run state server.
+ */
+class TransformWithStateInPandasPythonPreInitRunner(
+    func: PythonFunction,
+    workerModule: String,
+    timeZoneId: String,
+    groupingKeySchema: StructType,
+    processorHandleImpl: DriverStatefulProcessorHandleImpl)
+  extends StreamingPythonRunner(func, "", "", workerModule)
+  with TransformWithStateInPandasPythonRunnerUtils
+  with Logging {
+  protected val sqlConf = SQLConf.get
+
+  private var dataIn: DataInputStream = _
+  private var dataOut: DataOutputStream = _
+
+  private var daemonThread: Thread = _
+
+  override def init(): (DataOutputStream, DataInputStream) = {
+    val env = SparkEnv.get
+
+    val localdir = env.blockManager.diskBlockManager.localDirs.map(f => 
f.getPath()).mkString(",")
+    envVars.put("SPARK_LOCAL_DIRS", localdir)
+    envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString)
+    envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString)
+
+    val workerFactory =
+      new PythonWorkerFactory(pythonExec, workerModule, envVars.asScala.toMap, 
false)
+    val (worker: PythonWorker, _) = 
workerFactory.createSimpleWorker(blockingMode = true)
+    pythonWorker = Some(worker)
+    pythonWorkerFactory = Some(workerFactory)
+
+    val stream = new BufferedOutputStream(
+      pythonWorker.get.channel.socket().getOutputStream, bufferSize)
+    dataOut = new DataOutputStream(stream)
+
+    PythonWorkerUtils.writePythonVersion(pythonVer, dataOut)
+
+    // Send the user function to python process
+    PythonWorkerUtils.writePythonFunction(func, dataOut)
+    dataOut.flush()
+
+    dataIn = new DataInputStream(
+      new 
BufferedInputStream(pythonWorker.get.channel.socket().getInputStream, 
bufferSize))
+
+    val resFromPython = dataIn.readInt()
+    if (resFromPython != 0) {
+      val errMessage = PythonWorkerUtils.readUTF(dataIn)
+      throw streamingPythonRunnerInitializationFailure(resFromPython, 
errMessage)
+    }
+    logInfo("Driver Python Runner initialization succeeded.")
+
+    // start state server, update socket port
+    startStateServer()
+    (dataOut, dataIn)
+  }
+
+  def process(): Unit = {
+    // Also write the port number for state server
+    dataOut.writeInt(stateServerSocketPort)
+    PythonWorkerUtils.writeUTF(groupingKeySchema.json, dataOut)
+    dataOut.flush()
+
+    val resFromPython = dataIn.readInt()
+    if (resFromPython != 0) {
+      val errMessage = PythonWorkerUtils.readUTF(dataIn)
+      throw streamingPythonRunnerInitializationFailure(resFromPython, 
errMessage)
+    }
+  }
+
+  override def stop(): Unit = {
+    super.stop()
+    closeServerSocketChannelSilently(stateServerSocket)
+    daemonThread.stop()
+  }
+
+  private def startStateServer(): Unit = {
+    initStateServer()
+
+    daemonThread = new Thread {
+      override def run(): Unit = {
+        try {
+          new TransformWithStateInPandasStateServer(stateServerSocket, 
processorHandleImpl,
+            groupingKeySchema, timeZoneId, errorOnDuplicatedFieldNames = true,
+            largeVarTypes = sqlConf.arrowUseLargeVarTypes,
+            sqlConf.arrowTransformWithStateInPandasMaxRecordsPerBatch).run()
+        } catch {
+          case e: Exception =>
+            throw new SparkException("TransformWithStateInPandas state server 
" +
+              "daemon thread exited unexpectedly (crashed)", e)
+        }
+      }
+    }
+    daemonThread.setDaemon(true)
+    daemonThread.setName("stateConnectionListenerThread")
+    daemonThread.start()
+  }
+}
+
+/**
+ * TransformWithStateInPandas Python runner utils functions for handling a 
state server
+ * in a new daemon thread.
+ */
+trait TransformWithStateInPandasPythonRunnerUtils extends Logging{

Review Comment:
   nit: space between `g` and `{`



##########
python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py:
##########
@@ -358,7 +369,7 @@ def check_results(batch_df, batch_id):
                 for q in self.spark.streams.active:
                     q.stop()
             if batch_id == 0 or batch_id == 1:
-                time.sleep(6)
+                time.sleep(4)

Review Comment:
   nit: Is this a relevant change? I am happy with reducing the test 
duration; I'd just like to see the rationale.



##########
python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py:
##########
@@ -852,13 +872,6 @@ def check_results(batch_df, batch_id):
             StatefulProcessorWithInitialStateTimers(), check_results, 
"processingTime"
         )
 
-    # run the same test suites again but with single shuffle partition

Review Comment:
   Isn't this something we intentionally added because of some observed 
failure? Mind explaining why we are going to remove this?



##########
python/pyspark/sql/streaming/transform_with_state_driver_worker.py:
##########
@@ -0,0 +1,103 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import json
+from typing import Any, Iterator, TYPE_CHECKING
+
+from pyspark.util import local_connect_and_auth
+from pyspark.serializers import (
+    write_int,
+    read_int,
+    UTF8Deserializer,
+    CPickleSerializer,
+)
+from pyspark import worker
+from pyspark.util import handle_worker_exception
+from typing import IO
+from pyspark.worker_util import check_python_version
+from pyspark.sql.streaming.stateful_processor_api_client import 
StatefulProcessorApiClient
+from pyspark.sql.streaming.stateful_processor_util import 
TransformWithStateInPandasFuncMode
+from pyspark.sql.types import StructType
+
+if TYPE_CHECKING:
+    from pyspark.sql.pandas._typing import (
+        DataFrameLike as PandasDataFrameLike,
+    )
+
+pickle_ser = CPickleSerializer()
+utf8_deserializer = UTF8Deserializer()
+
+
+def main(infile: IO, outfile: IO) -> None:
+    check_python_version(infile)
+
+    log_name = "Streaming TransformWithStateInPandas Python worker"
+    print(f"Starting {log_name}.\n")
+
+    def process(
+        processor: StatefulProcessorApiClient,
+        mode: TransformWithStateInPandasFuncMode,
+        key: Any,
+        input: Iterator["PandasDataFrameLike"],
+    ) -> None:
+        print(f"{log_name} Starting execution of UDF: {func}.\n")
+        func(processor, mode, key, input)
+        print(f"{log_name} Completed execution of UDF: {func}.\n")
+
+    try:
+        func, return_type = worker.read_command(pickle_ser, infile)
+        print(
+            f"{log_name} finish init stage of Python runner. Received UDF from 
JVM: {func}, "
+            f"received return type of UDF: {return_type}.\n"
+        )
+        # send signal for getting args
+        write_int(0, outfile)
+        outfile.flush()
+
+        while True:

Review Comment:
   Is this intentional? Is this worker process reused by multiple queries in 
the same driver? Otherwise, how would this be used?



##########
python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py:
##########
@@ -907,6 +920,359 @@ def 
test_transform_with_state_in_pandas_batch_query_initial_state(self):
             Row(id="1", value=str(146 + 346)),
         }
 
+    # This test covers mapState with TTL, an empty state variable
+    # and additional test against initial state python runner
+    def test_transform_with_map_state_metadata(self):
+        checkpoint_path = tempfile.mktemp()
+
+        def check_results(batch_df, batch_id):
+            if batch_id == 0:
+                assert set(batch_df.sort("id").collect()) == {
+                    Row(id="0", countAsString="2"),
+                    Row(id="1", countAsString="2"),
+                }
+            else:
+                # check for state metadata source
+                metadata_df = 
self.spark.read.format("state-metadata").load(checkpoint_path)
+                assert set(
+                    metadata_df.select(
+                        "operatorId",
+                        "operatorName",
+                        "stateStoreName",
+                        "numPartitions",
+                        "minBatchId",
+                        "maxBatchId",
+                    ).collect()
+                ) == {
+                    Row(
+                        operatorId=0,
+                        operatorName="transformWithStateInPandasExec",
+                        stateStoreName="default",
+                        numPartitions=5,
+                        minBatchId=0,
+                        maxBatchId=0,
+                    )
+                }
+                operator_properties_json_obj = json.loads(
+                    metadata_df.select("operatorProperties").collect()[0][0]
+                )
+                assert operator_properties_json_obj["timeMode"] == 
"ProcessingTime"
+                assert operator_properties_json_obj["outputMode"] == "Update"
+
+                state_var_list = operator_properties_json_obj["stateVariables"]
+                assert len(state_var_list) == 3
+                for state_var in state_var_list:
+                    if state_var["stateName"] == "mapState":
+                        assert state_var["stateVariableType"] == "MapState"
+                        assert state_var["ttlEnabled"]
+                    elif state_var["stateName"] == "listState":
+                        assert state_var["stateVariableType"] == "ListState"
+                        assert not state_var["ttlEnabled"]
+                    else:
+                        assert state_var["stateName"] == 
"$procTimers_keyToTimestamp"
+                        assert state_var["stateVariableType"] == "TimerState"
+
+                # check for state data source
+                map_state_df = (
+                    self.spark.read.format("statestore")
+                    .option("path", checkpoint_path)
+                    .option("stateVarName", "mapState")
+                    .load()
+                )
+                assert set(
+                    map_state_df.selectExpr(
+                        "key.id AS groupingKey",
+                        "user_map_key.name AS mapKey",
+                        "user_map_value.value.count AS mapValue",
+                    )
+                    .sort("groupingKey")
+                    .collect()
+                ) == {
+                    Row(groupingKey="0", mapKey="key2", mapValue=2),
+                    Row(groupingKey="1", mapKey="key2", mapValue=2),
+                }
+
+                ttl_df = map_state_df.selectExpr(
+                    "user_map_value.ttlExpirationMs AS TTLVal"
+                ).collect()
+                # check if there are two rows containing TTL value in map 
state dataframe
+                assert len(ttl_df) == 2
+                # check if two rows are of the same TTL value
+                assert len(set(ttl_df)) == 1
+
+                list_state_df = (
+                    self.spark.read.format("statestore")
+                    .option("path", checkpoint_path)
+                    .option("stateVarName", "listState")
+                    .load()
+                )
+                assert list_state_df.isEmpty()
+
+                for q in self.spark.streams.active:
+                    q.stop()
+
+        self._test_transform_with_state_in_pandas_basic(
+            MapStateLargeTTLProcessor(),
+            check_results,
+            True,
+            "processingTime",
+            checkpoint_path=checkpoint_path,
+            initial_state=None,
+        )
+
+        # run the same test suite again but with no-op initial state
+        # TWS with initial state is using a different python runner
+        init_data = [("0", 789), ("3", 987)]
+        initial_state = self.spark.createDataFrame(init_data, "id string, 
temperature int").groupBy(
+            "id"
+        )
+        self._test_transform_with_state_in_pandas_basic(
+            MapStateLargeTTLProcessor(),
+            check_results,
+            True,
+            "processingTime",
+            checkpoint_path=checkpoint_path,
+            initial_state=initial_state,
+        )
+
+    # This test covers multiple list state variables and flatten option
+    def test_transform_with_list_state_metadata(self):
+        checkpoint_path = tempfile.mktemp()
+
+        def check_results(batch_df, batch_id):
+            if batch_id == 0:
+                assert set(batch_df.sort("id").collect()) == {
+                    Row(id="0", countAsString="2"),
+                    Row(id="1", countAsString="2"),
+                }
+            else:
+                # check for state metadata source
+                metadata_df = 
self.spark.read.format("state-metadata").load(checkpoint_path)
+                operator_properties_json_obj = json.loads(
+                    metadata_df.select("operatorProperties").collect()[0][0]
+                )
+                state_var_list = operator_properties_json_obj["stateVariables"]
+                assert len(state_var_list) == 3
+                for state_var in state_var_list:
+                    if state_var["stateName"] in ["listState1", "listState2"]:
+                        state_var["stateVariableType"] == "ListState"
+                    else:
+                        assert state_var["stateName"] == 
"$procTimers_keyToTimestamp"
+                        assert state_var["stateVariableType"] == "TimerState"
+
+                # check for state data source and flatten option
+                list_state_1_df = (
+                    self.spark.read.format("statestore")
+                    .option("path", checkpoint_path)
+                    .option("stateVarName", "listState1")
+                    .option("flattenCollectionTypes", True)
+                    .load()
+                )
+                assert list_state_1_df.selectExpr(
+                    "key.id AS groupingKey",
+                    "list_element.temperature AS listElement",
+                ).sort("groupingKey", "listElement").collect() == [
+                    Row(groupingKey="0", listElement=20),
+                    Row(groupingKey="0", listElement=20),
+                    Row(groupingKey="0", listElement=111),
+                    Row(groupingKey="0", listElement=120),
+                    Row(groupingKey="0", listElement=120),
+                    Row(groupingKey="1", listElement=20),
+                    Row(groupingKey="1", listElement=20),
+                    Row(groupingKey="1", listElement=111),
+                    Row(groupingKey="1", listElement=120),
+                    Row(groupingKey="1", listElement=120),
+                ]
+
+                list_state_2_df = (
+                    self.spark.read.format("statestore")
+                    .option("path", checkpoint_path)
+                    .option("stateVarName", "listState2")
+                    .option("flattenCollectionTypes", False)
+                    .load()
+                )
+                assert list_state_2_df.selectExpr(
+                    "key.id AS groupingKey", "list_value.temperature AS 
valueList"
+                ).sort("groupingKey").withColumn(
+                    "valueSortedList", array_sort(col("valueList"))
+                ).select(
+                    "groupingKey", "valueSortedList"
+                ).collect() == [
+                    Row(groupingKey="0", valueSortedList=[20, 20, 120, 120, 
222]),
+                    Row(groupingKey="1", valueSortedList=[20, 20, 120, 120, 
222]),
+                ]
+
+                for q in self.spark.streams.active:
+                    q.stop()
+
+        self._test_transform_with_state_in_pandas_basic(
+            ListStateProcessor(),
+            check_results,
+            True,
+            "processingTime",
+            checkpoint_path=checkpoint_path,
+            initial_state=None,
+        )
+
+    # This test covers value state variable and read change feed,
+    # snapshotStartBatchId related options
+    def test_transform_with_value_state_metadata(self):
+        checkpoint_path = tempfile.mktemp()
+
+        def check_results(batch_df, batch_id):
+            if batch_id == 0:
+                assert set(batch_df.sort("id").collect()) == {
+                    Row(id="0", countAsString="2"),
+                    Row(id="1", countAsString="2"),
+                }
+            else:
+                assert set(batch_df.sort("id").collect()) == {
+                    Row(id="0", countAsString="3"),
+                    Row(id="1", countAsString="2"),
+                }
+
+                # check for state metadata source
+                metadata_df = 
self.spark.read.format("state-metadata").load(checkpoint_path)
+                operator_properties_json_obj = json.loads(
+                    metadata_df.select("operatorProperties").collect()[0][0]
+                )
+                state_var_list = operator_properties_json_obj["stateVariables"]
+
+                assert len(state_var_list) == 3
+                for state_var in state_var_list:
+                    if state_var["stateName"] in ["numViolations", 
"tempState"]:
+                        state_var["stateVariableType"] == "ValueState"
+                    else:
+                        assert state_var["stateName"] == 
"$procTimers_keyToTimestamp"
+                        assert state_var["stateVariableType"] == "TimerState"
+
+                # check for state data source and readChangeFeed
+                value_state_df = (
+                    self.spark.read.format("statestore")
+                    .option("path", checkpoint_path)
+                    .option("stateVarName", "numViolations")
+                    .option("readChangeFeed", True)
+                    .option("changeStartBatchId", 0)
+                    .load()
+                ).selectExpr(
+                    "change_type", "key.id AS groupingKey", "value.value AS 
value", "partition_id"
+                )
+
+                assert value_state_df.select("change_type", "groupingKey", 
"value").sort(
+                    "groupingKey"
+                ).collect() == [
+                    Row(change_type="update", groupingKey="0", value=1),
+                    Row(change_type="update", groupingKey="1", value=2),
+                ]
+
+                partition_id_list = [
+                    row["partition_id"] for row in 
value_state_df.select("partition_id").collect()
+                ]
+
+                for partition_id in partition_id_list:
+                    # check for state data source and snapshotStartBatchId 
options
+                    state_snapshot_df = (
+                        self.spark.read.format("statestore")
+                        .option("path", checkpoint_path)
+                        .option("stateVarName", "numViolations")
+                        .option("snapshotPartitionId", partition_id)
+                        .option("snapshotStartBatchId", 0)
+                        .load()
+                    )
+
+                    assert (
+                        value_state_df.select("partition_id", "groupingKey", 
"value")
+                        .filter(value_state_df["partition_id"] == partition_id)
+                        .sort("groupingKey")
+                        .collect()
+                        == state_snapshot_df.selectExpr(
+                            "partition_id", "key.id AS groupingKey", 
"value.value AS value"
+                        )
+                        .sort("groupingKey")
+                        .collect()
+                    )
+
+                for q in self.spark.streams.active:
+                    q.stop()
+
+        with self.sql_conf(
+            
{"spark.sql.streaming.stateStore.rocksdb.changelogCheckpointing.enabled": 
"true"}
+        ):
+            self._test_transform_with_state_in_pandas_basic(
+                SimpleStatefulProcessor(),
+                check_results,
+                False,
+                "processingTime",
+                checkpoint_path=checkpoint_path,
+            )
+
+    """

Review Comment:
   nit: Is this intentional? Tests are commented out.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to