This is an automated email from the ASF dual-hosted git repository.
kabhwan pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push:
new a93b4a1f9d7 [SPARK-42944][SS][PYTHON][CONNECT][FOLLOWUP] Streaming ForeachBatch in Python followups
a93b4a1f9d7 is described below
commit a93b4a1f9d7023695e83d8369fe2a229185c4127
Author: Wei Liu <[email protected]>
AuthorDate: Mon Jul 24 15:28:33 2023 -0700
[SPARK-42944][SS][PYTHON][CONNECT][FOLLOWUP] Streaming ForeachBatch in Python followups
### What changes were proposed in this pull request?
Follow-up of https://github.com/apache/spark/pull/42035, addressing review comments.
### Why are the changes needed?
Code quality and documentation improvements.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Unit tests
Closes #42096 from WweiL/spark-42944-followup.
Authored-by: Wei Liu <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
(cherry picked from commit cfc279c6c0fdaf278f0e9f1444e6cf99e3848be7)
Signed-off-by: Jungtaek Lim <[email protected]>
---
.../spark/sql/streaming/StreamingQuerySuite.scala | 2 +-
.../planner/StreamingForeachBatchHelper.scala | 13 +++++-----
.../spark/api/python/StreamingPythonRunner.scala | 8 +++----
dev/sparktestsupport/modules.py | 1 +
python/pyspark/sql/connect/session.py | 2 +-
python/pyspark/sql/streaming/readwriter.py | 9 +++++++
...foreachBatch.py => test_parity_foreachBatch.py} | 2 +-
python/pyspark/streaming_worker.py | 28 +++++++++++-----------
8 files changed, 38 insertions(+), 27 deletions(-)
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
index 91d744b9e48..62770ae383e 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
@@ -346,7 +346,7 @@ class StreamingQuerySuite extends QueryTest with SQLHelper with Logging {
.collect()
.toSeq
assert(rows.size > 0)
- log.info(s"Rows in $tableName: $rows")
+ logInfo(s"Rows in $tableName: $rows")
}
q.stop()
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
index 31481393777..9770ac4cee5 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
@@ -44,16 +44,16 @@ object StreamingForeachBatchHelper extends Logging {
sessionHolder: SessionHolder): ForeachBatchFnType = { (df: DataFrame, batchId: Long) =>
{
val dfId = UUID.randomUUID().toString
- log.info(s"Caching DataFrame with id $dfId") // TODO: Add query id to
the log.
+ logInfo(s"Caching DataFrame with id $dfId") // TODO: Add query id to the
log.
- // TODO: Sanity check there is no other active DataFrame for this query. The query id
- // needs to be saved in the cache for this check.
+ // TODO(SPARK-44462): Sanity check there is no other active DataFrame for this query.
+ // The query id needs to be saved in the cache for this check.
sessionHolder.cacheDataFrameById(dfId, df)
try {
fn(FnArgsWithId(dfId, df, batchId))
} finally {
- log.info(s"Removing DataFrame with id $dfId from the cache")
+ logInfo(s"Removing DataFrame with id $dfId from the cache")
sessionHolder.removeCachedDataFrame(dfId)
}
}
@@ -69,7 +69,8 @@ object StreamingForeachBatchHelper extends Logging {
def scalaForeachBatchWrapper(
fn: ForeachBatchFnType,
sessionHolder: SessionHolder): ForeachBatchFnType = {
- // TODO: Set up Spark Connect session. Do we actually need this for the first version?
+ // TODO(SPARK-44462): Set up Spark Connect session.
+ // Do we actually need this for the first version?
dataFrameCachingWrapper(
(args: FnArgsWithId) => {
fn(args.df, args.batchId) // dfId is not used, see hack comment above.
@@ -104,7 +105,7 @@ object StreamingForeachBatchHelper extends Logging {
dataOut.flush()
val ret = dataIn.readInt()
- log.info(s"Python foreach batch for dfId ${args.dfId} completed (ret:
$ret)")
+ logInfo(s"Python foreach batch for dfId ${args.dfId} completed (ret:
$ret)")
}
dataFrameCachingWrapper(foreachBatchRunnerFn, sessionHolder)
diff --git a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala
index 77dc88e0cfa..faf462a1990 100644
--- a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala
@@ -49,7 +49,7 @@ private[spark] class StreamingPythonRunner(func: PythonFunction, connectUrl: Str
* to be used with the functions.
*/
def init(sessionId: String): (DataOutputStream, DataInputStream) = {
- log.info(s"Initializing Python runner (session: $sessionId ,pythonExec:
$pythonExec")
+ logInfo(s"Initializing Python runner (session: $sessionId ,pythonExec:
$pythonExec")
val env = SparkEnv.get
@@ -67,12 +67,12 @@ private[spark] class StreamingPythonRunner(func: PythonFunction, connectUrl: Str
val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
val dataOut = new DataOutputStream(stream)
- // TODO: verify python version
+ // TODO(SPARK-44461): verify python version
// Send sessionId
PythonRDD.writeUTF(sessionId, dataOut)
- // send the user function to python process
+ // Send the user function to python process
val command = func.command
dataOut.writeInt(command.length)
dataOut.write(command.toArray)
@@ -81,7 +81,7 @@ private[spark] class StreamingPythonRunner(func: PythonFunction, connectUrl: Str
val dataIn = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
val resFromPython = dataIn.readInt()
- log.info(s"Runner initialization returned $resFromPython")
+ logInfo(s"Runner initialization returned $resFromPython")
(dataOut, dataIn)
}
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 6382e9a5369..9d0ba219e79 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -863,6 +863,7 @@ pyspark_connect = Module(
"pyspark.sql.tests.connect.client.test_client",
"pyspark.sql.tests.connect.streaming.test_parity_streaming",
"pyspark.sql.tests.connect.streaming.test_parity_foreach",
+ "pyspark.sql.tests.connect.streaming.test_parity_foreachBatch",
"pyspark.sql.tests.connect.test_parity_pandas_grouped_map_with_state",
"pyspark.sql.tests.connect.test_parity_pandas_udf_scalar",
"pyspark.sql.tests.connect.test_parity_pandas_udf_grouped_agg",
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 37a5bdd9f9f..a49e4cdd0f4 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -671,7 +671,7 @@ class SparkSession:
copyFromLocalToFs.__doc__ = PySparkSession.copyFromLocalToFs.__doc__
- def _createRemoteDataFrame(self, remote_id: str) -> "DataFrame":
+ def _create_remote_dataframe(self, remote_id: str) -> "DataFrame":
"""
In internal API to reference a runtime DataFrame on the server side.
This is used in ForeachBatch() runner, where the remote DataFrame refers to the
diff --git a/python/pyspark/sql/streaming/readwriter.py b/python/pyspark/sql/streaming/readwriter.py
index 08d7396ba86..2026651ce12 100644
--- a/python/pyspark/sql/streaming/readwriter.py
+++ b/python/pyspark/sql/streaming/readwriter.py
@@ -1404,20 +1404,29 @@ class DataStreamWriter:
.. versionadded:: 2.4.0
+ .. versionchanged:: 3.5.0
+ Supports Spark Connect.
+
Notes
-----
This API is evolving.
+ This function behaves differently in Spark Connect mode. See examples.
+ In Connect, the provided function doesn't have access to variables defined outside of it.
Examples
--------
>>> import time
>>> df = spark.readStream.format("rate").load()
+ >>> my_value = -1
>>> def func(batch_df, batch_id):
+ ... global my_value
+ ... my_value = 100
... batch_df.collect()
...
>>> q = df.writeStream.foreachBatch(func).start()
>>> time.sleep(3)
>>> q.stop()
+ >>> # if in Spark Connect, my_value = -1, else my_value = 100
"""
from pyspark.java_gateway import ensure_callback_server_started
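[Editor's note] The doc change above states that under Spark Connect the foreachBatch callback runs in a separate Python worker, so mutating a captured client-side variable (as the doctest demonstrates) is not visible on the client. Below is a minimal sketch, not part of this patch, of a callback that behaves the same in both modes by channeling its side effects through the batch DataFrame; the table name "batch_counts" is hypothetical.

```python
import time

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()


def write_batch(batch_df, batch_id):
    # Runs inside the foreachBatch worker: persist results through the batch
    # DataFrame instead of mutating variables captured from the client.
    batch_df.groupBy().count().write.mode("append").saveAsTable("batch_counts")


q = (
    spark.readStream.format("rate").load()
    .writeStream.foreachBatch(write_batch)
    .start()
)
time.sleep(3)
q.stop()
```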
diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_streaming_foreachBatch.py b/python/pyspark/sql/tests/connect/streaming/test_parity_foreachBatch.py
similarity index 94%
rename from python/pyspark/sql/tests/connect/streaming/test_parity_streaming_foreachBatch.py
rename to python/pyspark/sql/tests/connect/streaming/test_parity_foreachBatch.py
index c4aa936a43e..01108c95391 100644
--- a/python/pyspark/sql/tests/connect/streaming/test_parity_streaming_foreachBatch.py
+++ b/python/pyspark/sql/tests/connect/streaming/test_parity_foreachBatch.py
@@ -33,7 +33,7 @@ class StreamingForeachBatchParityTests(StreamingTestsForeachBatchMixin, ReusedCo
if __name__ == "__main__":
import unittest
- from pyspark.sql.tests.connect.streaming.test_parity_streaming_foreachBatch import *  # noqa: F401,E501
+ from pyspark.sql.tests.connect.streaming.test_parity_foreachBatch import *  # noqa: F401,E501
try:
import xmlrunner # type: ignore[import]
diff --git a/python/pyspark/streaming_worker.py b/python/pyspark/streaming_worker.py
index 490bae44d99..a818880a984 100644
--- a/python/pyspark/streaming_worker.py
+++ b/python/pyspark/streaming_worker.py
@@ -30,38 +30,38 @@ from pyspark.serializers import (
from pyspark import worker
from pyspark.sql import SparkSession
-pickleSer = CPickleSerializer()
+pickle_ser = CPickleSerializer()
utf8_deserializer = UTF8Deserializer()
def main(infile, outfile): # type: ignore[no-untyped-def]
log_name = "Streaming ForeachBatch worker"
connect_url = os.environ["SPARK_CONNECT_LOCAL_URL"]
- sessionId = utf8_deserializer.loads(infile)
+ session_id = utf8_deserializer.loads(infile)
- print(f"{log_name} is starting with url {connect_url} and sessionId
{sessionId}.")
+ print(f"{log_name} is starting with url {connect_url} and sessionId
{session_id}.")
- sparkConnectSession = SparkSession.builder.remote(connect_url).getOrCreate()
- sparkConnectSession._client._session_id = sessionId
+ spark_connect_session = SparkSession.builder.remote(connect_url).getOrCreate()
+ spark_connect_session._client._session_id = session_id
# TODO(SPARK-44460): Pass credentials.
# TODO(SPARK-44461): Enable Process Isolation
- func = worker.read_command(pickleSer, infile)
+ func = worker.read_command(pickle_ser, infile)
write_int(0, outfile) # Indicate successful initialization
outfile.flush()
- def process(dfId, batchId): # type: ignore[no-untyped-def]
- print(f"{log_name} Started batch {batchId} with DF id {dfId}")
- batchDf = sparkConnectSession._createRemoteDataFrame(dfId)
- func(batchDf, batchId)
- print(f"{log_name} Completed batch {batchId} with DF id {dfId}")
+ def process(df_id, batch_id): # type: ignore[no-untyped-def]
+ print(f"{log_name} Started batch {batch_id} with DF id {df_id}")
+ batch_df = spark_connect_session._create_remote_dataframe(df_id)
+ func(batch_df, batch_id)
+ print(f"{log_name} Completed batch {batch_id} with DF id {df_id}")
while True:
- dfRefId = utf8_deserializer.loads(infile)
- batchId = read_long(infile)
- process(dfRefId, int(batchId))  # TODO(SPARK-44463): Propagate error to the user.
+ df_ref_id = utf8_deserializer.loads(infile)
+ batch_id = read_long(infile)
+ process(df_ref_id, int(batch_id))  # TODO(SPARK-44463): Propagate error to the user.
write_int(0, outfile)
outfile.flush()
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]