This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new a93b4a1f9d7 [SPARK-42944][SS][PYTHON][CONNECT][FOLLOWUP] Streaming 
ForeachBatch in Python followups
a93b4a1f9d7 is described below

commit a93b4a1f9d7023695e83d8369fe2a229185c4127
Author: Wei Liu <wei....@databricks.com>
AuthorDate: Mon Jul 24 15:28:33 2023 -0700

    [SPARK-42944][SS][PYTHON][CONNECT][FOLLOWUP] Streaming ForeachBatch in Python followups
    
    ### What changes were proposed in this pull request?
    
    Follow-up to https://github.com/apache/spark/pull/42035 that addresses its review comments.
    
    ### Why are the changes needed?
    
    Code-quality cleanups (consistent logging calls, snake_case naming, tracked TODOs) and documentation changes.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Unit tests
    
    Closes #42096 from WweiL/spark-42944-followup.
    
    Authored-by: Wei Liu <wei....@databricks.com>
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
    (cherry picked from commit cfc279c6c0fdaf278f0e9f1444e6cf99e3848be7)
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 .../spark/sql/streaming/StreamingQuerySuite.scala  |  2 +-
 .../planner/StreamingForeachBatchHelper.scala      | 13 +++++-----
 .../spark/api/python/StreamingPythonRunner.scala   |  8 +++----
 dev/sparktestsupport/modules.py                    |  1 +
 python/pyspark/sql/connect/session.py              |  2 +-
 python/pyspark/sql/streaming/readwriter.py         |  9 +++++++
 ...foreachBatch.py => test_parity_foreachBatch.py} |  2 +-
 python/pyspark/streaming_worker.py                 | 28 +++++++++++-----------
 8 files changed, 38 insertions(+), 27 deletions(-)

diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
index 91d744b9e48..62770ae383e 100644
--- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
+++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
@@ -346,7 +346,7 @@ class StreamingQuerySuite extends QueryTest with SQLHelper 
with Logging {
           .collect()
           .toSeq
         assert(rows.size > 0)
-        log.info(s"Rows in $tableName: $rows")
+        logInfo(s"Rows in $tableName: $rows")
       }
 
       q.stop()
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
index 31481393777..9770ac4cee5 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
@@ -44,16 +44,16 @@ object StreamingForeachBatchHelper extends Logging {
       sessionHolder: SessionHolder): ForeachBatchFnType = { (df: DataFrame, 
batchId: Long) =>
     {
       val dfId = UUID.randomUUID().toString
-      log.info(s"Caching DataFrame with id $dfId") // TODO: Add query id to 
the log.
+      logInfo(s"Caching DataFrame with id $dfId") // TODO: Add query id to the 
log.
 
-      // TODO: Sanity check there is no other active DataFrame for this query. 
The query id
-      //       needs to be saved in the cache for this check.
+      // TODO(SPARK-44462): Sanity check there is no other active DataFrame 
for this query.
+      //  The query id needs to be saved in the cache for this check.
 
       sessionHolder.cacheDataFrameById(dfId, df)
       try {
         fn(FnArgsWithId(dfId, df, batchId))
       } finally {
-        log.info(s"Removing DataFrame with id $dfId from the cache")
+        logInfo(s"Removing DataFrame with id $dfId from the cache")
         sessionHolder.removeCachedDataFrame(dfId)
       }
     }
@@ -69,7 +69,8 @@ object StreamingForeachBatchHelper extends Logging {
   def scalaForeachBatchWrapper(
       fn: ForeachBatchFnType,
       sessionHolder: SessionHolder): ForeachBatchFnType = {
-    // TODO: Set up Spark Connect session. Do we actually need this for the 
first version?
+    // TODO(SPARK-44462): Set up Spark Connect session.
+    // Do we actually need this for the first version?
     dataFrameCachingWrapper(
       (args: FnArgsWithId) => {
         fn(args.df, args.batchId) // dfId is not used, see hack comment above.
@@ -104,7 +105,7 @@ object StreamingForeachBatchHelper extends Logging {
       dataOut.flush()
 
       val ret = dataIn.readInt()
-      log.info(s"Python foreach batch for dfId ${args.dfId} completed (ret: 
$ret)")
+      logInfo(s"Python foreach batch for dfId ${args.dfId} completed (ret: 
$ret)")
     }
 
     dataFrameCachingWrapper(foreachBatchRunnerFn, sessionHolder)
diff --git 
a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala 
b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala
index 77dc88e0cfa..faf462a1990 100644
--- 
a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala
+++ 
b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala
@@ -49,7 +49,7 @@ private[spark] class StreamingPythonRunner(func: 
PythonFunction, connectUrl: Str
    * to be used with the functions.
    */
   def init(sessionId: String): (DataOutputStream, DataInputStream) = {
-    log.info(s"Initializing Python runner (session: $sessionId ,pythonExec: 
$pythonExec")
+    logInfo(s"Initializing Python runner (session: $sessionId ,pythonExec: 
$pythonExec")
 
     val env = SparkEnv.get
 
@@ -67,12 +67,12 @@ private[spark] class StreamingPythonRunner(func: 
PythonFunction, connectUrl: Str
     val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
     val dataOut = new DataOutputStream(stream)
 
-    // TODO: verify python version
+    // TODO(SPARK-44461): verify python version
 
     // Send sessionId
     PythonRDD.writeUTF(sessionId, dataOut)
 
-    // send the user function to python process
+    // Send the user function to python process
     val command = func.command
     dataOut.writeInt(command.length)
     dataOut.write(command.toArray)
@@ -81,7 +81,7 @@ private[spark] class StreamingPythonRunner(func: 
PythonFunction, connectUrl: Str
     val dataIn = new DataInputStream(new 
BufferedInputStream(worker.getInputStream, bufferSize))
 
     val resFromPython = dataIn.readInt()
-    log.info(s"Runner initialization returned $resFromPython")
+    logInfo(s"Runner initialization returned $resFromPython")
 
     (dataOut, dataIn)
   }
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 6382e9a5369..9d0ba219e79 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -863,6 +863,7 @@ pyspark_connect = Module(
         "pyspark.sql.tests.connect.client.test_client",
         "pyspark.sql.tests.connect.streaming.test_parity_streaming",
         "pyspark.sql.tests.connect.streaming.test_parity_foreach",
+        "pyspark.sql.tests.connect.streaming.test_parity_foreachBatch",
         "pyspark.sql.tests.connect.test_parity_pandas_grouped_map_with_state",
         "pyspark.sql.tests.connect.test_parity_pandas_udf_scalar",
         "pyspark.sql.tests.connect.test_parity_pandas_udf_grouped_agg",
diff --git a/python/pyspark/sql/connect/session.py 
b/python/pyspark/sql/connect/session.py
index 37a5bdd9f9f..a49e4cdd0f4 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -671,7 +671,7 @@ class SparkSession:
 
     copyFromLocalToFs.__doc__ = PySparkSession.copyFromLocalToFs.__doc__
 
-    def _createRemoteDataFrame(self, remote_id: str) -> "DataFrame":
+    def _create_remote_dataframe(self, remote_id: str) -> "DataFrame":
         """
         In internal API to reference a runtime DataFrame on the server side.
         This is used in ForeachBatch() runner, where the remote DataFrame 
refers to the
diff --git a/python/pyspark/sql/streaming/readwriter.py 
b/python/pyspark/sql/streaming/readwriter.py
index 08d7396ba86..2026651ce12 100644
--- a/python/pyspark/sql/streaming/readwriter.py
+++ b/python/pyspark/sql/streaming/readwriter.py
@@ -1404,20 +1404,29 @@ class DataStreamWriter:
 
         .. versionadded:: 2.4.0
 
+        .. versionchanged:: 3.5.0
+            Supports Spark Connect.
+
         Notes
         -----
         This API is evolving.
+        This function behaves differently in Spark Connect mode. See examples.
+        In Connect, the provided function doesn't have access to variables 
defined outside of it.
 
         Examples
         --------
         >>> import time
         >>> df = spark.readStream.format("rate").load()
+        >>> my_value = -1
         >>> def func(batch_df, batch_id):
+        ...     global my_value
+        ...     my_value = 100
         ...     batch_df.collect()
         ...
         >>> q = df.writeStream.foreachBatch(func).start()
         >>> time.sleep(3)
         >>> q.stop()
+        >>> # if in Spark Connect, my_value = -1, else my_value = 100
         """
 
         from pyspark.java_gateway import ensure_callback_server_started
diff --git 
a/python/pyspark/sql/tests/connect/streaming/test_parity_streaming_foreachBatch.py
 b/python/pyspark/sql/tests/connect/streaming/test_parity_foreachBatch.py
similarity index 94%
rename from 
python/pyspark/sql/tests/connect/streaming/test_parity_streaming_foreachBatch.py
rename to python/pyspark/sql/tests/connect/streaming/test_parity_foreachBatch.py
index c4aa936a43e..01108c95391 100644
--- 
a/python/pyspark/sql/tests/connect/streaming/test_parity_streaming_foreachBatch.py
+++ b/python/pyspark/sql/tests/connect/streaming/test_parity_foreachBatch.py
@@ -33,7 +33,7 @@ class 
StreamingForeachBatchParityTests(StreamingTestsForeachBatchMixin, ReusedCo
 
 if __name__ == "__main__":
     import unittest
-    from 
pyspark.sql.tests.connect.streaming.test_parity_streaming_foreachBatch import * 
 # noqa: F401,E501
+    from pyspark.sql.tests.connect.streaming.test_parity_foreachBatch import * 
 # noqa: F401,E501
 
     try:
         import xmlrunner  # type: ignore[import]
diff --git a/python/pyspark/streaming_worker.py 
b/python/pyspark/streaming_worker.py
index 490bae44d99..a818880a984 100644
--- a/python/pyspark/streaming_worker.py
+++ b/python/pyspark/streaming_worker.py
@@ -30,38 +30,38 @@ from pyspark.serializers import (
 from pyspark import worker
 from pyspark.sql import SparkSession
 
-pickleSer = CPickleSerializer()
+pickle_ser = CPickleSerializer()
 utf8_deserializer = UTF8Deserializer()
 
 
 def main(infile, outfile):  # type: ignore[no-untyped-def]
     log_name = "Streaming ForeachBatch worker"
     connect_url = os.environ["SPARK_CONNECT_LOCAL_URL"]
-    sessionId = utf8_deserializer.loads(infile)
+    session_id = utf8_deserializer.loads(infile)
 
-    print(f"{log_name} is starting with url {connect_url} and sessionId 
{sessionId}.")
+    print(f"{log_name} is starting with url {connect_url} and sessionId 
{session_id}.")
 
-    sparkConnectSession = 
SparkSession.builder.remote(connect_url).getOrCreate()
-    sparkConnectSession._client._session_id = sessionId
+    spark_connect_session = 
SparkSession.builder.remote(connect_url).getOrCreate()
+    spark_connect_session._client._session_id = session_id
 
     # TODO(SPARK-44460): Pass credentials.
     # TODO(SPARK-44461): Enable Process Isolation
 
-    func = worker.read_command(pickleSer, infile)
+    func = worker.read_command(pickle_ser, infile)
     write_int(0, outfile)  # Indicate successful initialization
 
     outfile.flush()
 
-    def process(dfId, batchId):  # type: ignore[no-untyped-def]
-        print(f"{log_name} Started batch {batchId} with DF id {dfId}")
-        batchDf = sparkConnectSession._createRemoteDataFrame(dfId)
-        func(batchDf, batchId)
-        print(f"{log_name} Completed batch {batchId} with DF id {dfId}")
+    def process(df_id, batch_id):  # type: ignore[no-untyped-def]
+        print(f"{log_name} Started batch {batch_id} with DF id {df_id}")
+        batch_df = spark_connect_session._create_remote_dataframe(df_id)
+        func(batch_df, batch_id)
+        print(f"{log_name} Completed batch {batch_id} with DF id {df_id}")
 
     while True:
-        dfRefId = utf8_deserializer.loads(infile)
-        batchId = read_long(infile)
-        process(dfRefId, int(batchId))  # TODO(SPARK-44463): Propagate error 
to the user.
+        df_ref_id = utf8_deserializer.loads(infile)
+        batch_id = read_long(infile)
+        process(df_ref_id, int(batch_id))  # TODO(SPARK-44463): Propagate 
error to the user.
         write_int(0, outfile)
         outfile.flush()
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to