This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new a93b4a1f9d7  [SPARK-42944][SS][PYTHON][CONNECT][FOLLOWUP] Streaming ForeachBatch in Python followups
a93b4a1f9d7 is described below

commit a93b4a1f9d7023695e83d8369fe2a229185c4127
Author: Wei Liu <wei....@databricks.com>
AuthorDate: Mon Jul 24 15:28:33 2023 -0700

    [SPARK-42944][SS][PYTHON][CONNECT][FOLLOWUP] Streaming ForeachBatch in Python followups

    ### What changes were proposed in this pull request?

    Follow-up of https://github.com/apache/spark/pull/42035; addresses review comments.

    ### Why are the changes needed?

    Code quality and documentation improvements.

    ### Does this PR introduce _any_ user-facing change?

    No.

    ### How was this patch tested?

    Unit tests.

    Closes #42096 from WweiL/spark-42944-followup.

    Authored-by: Wei Liu <wei....@databricks.com>
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
    (cherry picked from commit cfc279c6c0fdaf278f0e9f1444e6cf99e3848be7)
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 .../spark/sql/streaming/StreamingQuerySuite.scala  |  2 +-
 .../planner/StreamingForeachBatchHelper.scala      | 13 +++++-----
 .../spark/api/python/StreamingPythonRunner.scala   |  8 +++----
 dev/sparktestsupport/modules.py                    |  1 +
 python/pyspark/sql/connect/session.py              |  2 +-
 python/pyspark/sql/streaming/readwriter.py         |  9 +++++++
 ...foreachBatch.py => test_parity_foreachBatch.py} |  2 +-
 python/pyspark/streaming_worker.py                 | 28 +++++++++++-----------
 8 files changed, 38 insertions(+), 27 deletions(-)

diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
index 91d744b9e48..62770ae383e 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
@@ -346,7 +346,7 @@ class StreamingQuerySuite extends QueryTest with SQLHelper with Logging {
         .collect()
         .toSeq
       assert(rows.size > 0)
-      log.info(s"Rows in $tableName: $rows")
+      logInfo(s"Rows in $tableName: $rows")
     }

     q.stop()

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
index 31481393777..9770ac4cee5 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
@@ -44,16 +44,16 @@ object StreamingForeachBatchHelper extends Logging {
       sessionHolder: SessionHolder): ForeachBatchFnType = { (df: DataFrame, batchId: Long) =>
     {
       val dfId = UUID.randomUUID().toString
-      log.info(s"Caching DataFrame with id $dfId") // TODO: Add query id to the log.
+      logInfo(s"Caching DataFrame with id $dfId") // TODO: Add query id to the log.

-      // TODO: Sanity check there is no other active DataFrame for this query. The query id
-      // needs to be saved in the cache for this check.
+      // TODO(SPARK-44462): Sanity check there is no other active DataFrame for this query.
+      // The query id needs to be saved in the cache for this check.
       sessionHolder.cacheDataFrameById(dfId, df)
       try {
         fn(FnArgsWithId(dfId, df, batchId))
       } finally {
-        log.info(s"Removing DataFrame with id $dfId from the cache")
+        logInfo(s"Removing DataFrame with id $dfId from the cache")
         sessionHolder.removeCachedDataFrame(dfId)
       }
     }
@@ -69,7 +69,8 @@ object StreamingForeachBatchHelper extends Logging {
   def scalaForeachBatchWrapper(
       fn: ForeachBatchFnType,
       sessionHolder: SessionHolder): ForeachBatchFnType = {
-    // TODO: Set up Spark Connect session. Do we actually need this for the first version?
+    // TODO(SPARK-44462): Set up Spark Connect session.
+    // Do we actually need this for the first version?
     dataFrameCachingWrapper(
       (args: FnArgsWithId) => {
         fn(args.df, args.batchId) // dfId is not used, see hack comment above.
@@ -104,7 +105,7 @@ object StreamingForeachBatchHelper extends Logging {
       dataOut.flush()

       val ret = dataIn.readInt()
-      log.info(s"Python foreach batch for dfId ${args.dfId} completed (ret: $ret)")
+      logInfo(s"Python foreach batch for dfId ${args.dfId} completed (ret: $ret)")
     }

     dataFrameCachingWrapper(foreachBatchRunnerFn, sessionHolder)

diff --git a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala
index 77dc88e0cfa..faf462a1990 100644
--- a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala
@@ -49,7 +49,7 @@ private[spark] class StreamingPythonRunner(func: PythonFunction, connectUrl: Str
    * to be used with the functions.
    */
   def init(sessionId: String): (DataOutputStream, DataInputStream) = {
-    log.info(s"Initializing Python runner (session: $sessionId ,pythonExec: $pythonExec")
+    logInfo(s"Initializing Python runner (session: $sessionId ,pythonExec: $pythonExec")

     val env = SparkEnv.get
@@ -67,12 +67,12 @@ private[spark] class StreamingPythonRunner(func: PythonFunction, connectUrl: Str
     val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
     val dataOut = new DataOutputStream(stream)

-    // TODO: verify python version
+    // TODO(SPARK-44461): verify python version

     // Send sessionId
     PythonRDD.writeUTF(sessionId, dataOut)

-    // send the user function to python process
+    // Send the user function to python process
     val command = func.command
     dataOut.writeInt(command.length)
     dataOut.write(command.toArray)
@@ -81,7 +81,7 @@ private[spark] class StreamingPythonRunner(func: PythonFunction, connectUrl: Str
     val dataIn = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))

     val resFromPython = dataIn.readInt()
-    log.info(s"Runner initialization returned $resFromPython")
+    logInfo(s"Runner initialization returned $resFromPython")

     (dataOut, dataIn)
   }

diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 6382e9a5369..9d0ba219e79 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -863,6 +863,7 @@ pyspark_connect = Module(
         "pyspark.sql.tests.connect.client.test_client",
         "pyspark.sql.tests.connect.streaming.test_parity_streaming",
         "pyspark.sql.tests.connect.streaming.test_parity_foreach",
+        "pyspark.sql.tests.connect.streaming.test_parity_foreachBatch",
         "pyspark.sql.tests.connect.test_parity_pandas_grouped_map_with_state",
         "pyspark.sql.tests.connect.test_parity_pandas_udf_scalar",
         "pyspark.sql.tests.connect.test_parity_pandas_udf_grouped_agg",

diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 37a5bdd9f9f..a49e4cdd0f4 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -671,7 +671,7 @@ class SparkSession:

     copyFromLocalToFs.__doc__ = PySparkSession.copyFromLocalToFs.__doc__

-    def _createRemoteDataFrame(self, remote_id: str) -> "DataFrame":
+    def _create_remote_dataframe(self, remote_id: str) -> "DataFrame":
         """
         In internal API to reference a runtime DataFrame on the server side.
         This is used in ForeachBatch() runner, where the remote DataFrame refers to the

diff --git a/python/pyspark/sql/streaming/readwriter.py b/python/pyspark/sql/streaming/readwriter.py
index 08d7396ba86..2026651ce12 100644
--- a/python/pyspark/sql/streaming/readwriter.py
+++ b/python/pyspark/sql/streaming/readwriter.py
@@ -1404,20 +1404,29 @@ class DataStreamWriter:

         .. versionadded:: 2.4.0

+        .. versionchanged:: 3.5.0
+            Supports Spark Connect.
+
         Notes
         -----
         This API is evolving.

+        This function behaves differently in Spark Connect mode. See examples.
+        In Connect, the provided function doesn't have access to variables defined outside of it.
+
         Examples
         --------
         >>> import time
         >>> df = spark.readStream.format("rate").load()
+        >>> my_value = -1
         >>> def func(batch_df, batch_id):
+        ...     global my_value
+        ...     my_value = 100
         ...     batch_df.collect()
         ...
         >>> q = df.writeStream.foreachBatch(func).start()
         >>> time.sleep(3)
         >>> q.stop()
+        >>> # if in Spark Connect, my_value = -1, else my_value = 100
         """
         from pyspark.java_gateway import ensure_callback_server_started

diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_streaming_foreachBatch.py b/python/pyspark/sql/tests/connect/streaming/test_parity_foreachBatch.py
similarity index 94%
rename from python/pyspark/sql/tests/connect/streaming/test_parity_streaming_foreachBatch.py
rename to python/pyspark/sql/tests/connect/streaming/test_parity_foreachBatch.py
index c4aa936a43e..01108c95391 100644
--- a/python/pyspark/sql/tests/connect/streaming/test_parity_streaming_foreachBatch.py
+++ b/python/pyspark/sql/tests/connect/streaming/test_parity_foreachBatch.py
@@ -33,7 +33,7 @@ class StreamingForeachBatchParityTests(StreamingTestsForeachBatchMixin, ReusedCo

 if __name__ == "__main__":
     import unittest

-    from pyspark.sql.tests.connect.streaming.test_parity_streaming_foreachBatch import *  # noqa: F401,E501
+    from pyspark.sql.tests.connect.streaming.test_parity_foreachBatch import *  # noqa: F401,E501

     try:
         import xmlrunner  # type: ignore[import]

diff --git a/python/pyspark/streaming_worker.py b/python/pyspark/streaming_worker.py
index 490bae44d99..a818880a984 100644
--- a/python/pyspark/streaming_worker.py
+++ b/python/pyspark/streaming_worker.py
@@ -30,38 +30,38 @@ from pyspark.serializers import (
 from pyspark import worker
 from pyspark.sql import SparkSession

-pickleSer = CPickleSerializer()
+pickle_ser = CPickleSerializer()
 utf8_deserializer = UTF8Deserializer()


 def main(infile, outfile):  # type: ignore[no-untyped-def]
     log_name = "Streaming ForeachBatch worker"
     connect_url = os.environ["SPARK_CONNECT_LOCAL_URL"]
-    sessionId = utf8_deserializer.loads(infile)
+    session_id = utf8_deserializer.loads(infile)

-    print(f"{log_name} is starting with url {connect_url} and sessionId {sessionId}.")
+    print(f"{log_name} is starting with url {connect_url} and sessionId {session_id}.")

-    sparkConnectSession = SparkSession.builder.remote(connect_url).getOrCreate()
-    sparkConnectSession._client._session_id = sessionId
+    spark_connect_session = SparkSession.builder.remote(connect_url).getOrCreate()
+    spark_connect_session._client._session_id = session_id

     # TODO(SPARK-44460): Pass credentials.

     # TODO(SPARK-44461): Enable Process Isolation

-    func = worker.read_command(pickleSer, infile)
+    func = worker.read_command(pickle_ser, infile)
     write_int(0, outfile)  # Indicate successful initialization
     outfile.flush()

-    def process(dfId, batchId):  # type: ignore[no-untyped-def]
-        print(f"{log_name} Started batch {batchId} with DF id {dfId}")
-        batchDf = sparkConnectSession._createRemoteDataFrame(dfId)
-        func(batchDf, batchId)
-        print(f"{log_name} Completed batch {batchId} with DF id {dfId}")
+    def process(df_id, batch_id):  # type: ignore[no-untyped-def]
+        print(f"{log_name} Started batch {batch_id} with DF id {df_id}")
+        batch_df = spark_connect_session._create_remote_dataframe(df_id)
+        func(batch_df, batch_id)
+        print(f"{log_name} Completed batch {batch_id} with DF id {df_id}")

     while True:
-        dfRefId = utf8_deserializer.loads(infile)
-        batchId = read_long(infile)
-        process(dfRefId, int(batchId))  # TODO(SPARK-44463): Propagate error to the user.
+        df_ref_id = utf8_deserializer.loads(infile)
+        batch_id = read_long(infile)
+        process(df_ref_id, int(batch_id))  # TODO(SPARK-44463): Propagate error to the user.
         write_int(0, outfile)
         outfile.flush()

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
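
As an illustration of the Spark Connect foreachBatch behavior documented in the
readwriter.py change above: the user function runs in a separate Python worker
process on the server side, so assignments to client-side globals are not
reflected back to the client. Below is a minimal sketch, not part of this
commit; it assumes a reachable Spark Connect server, and "sc://localhost" is a
placeholder URL.

    # Minimal sketch (not from this commit) of the foreachBatch scoping caveat
    # documented in readwriter.py above. Assumes a running Spark Connect
    # server; "sc://localhost" is a placeholder URL.
    import time

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

    my_value = -1

    def func(batch_df, batch_id):
        # Under Spark Connect this body executes in a remote Python worker,
        # so this assignment mutates the worker's copy of my_value only.
        global my_value
        my_value = 100
        batch_df.collect()

    df = spark.readStream.format("rate").load()
    q = df.writeStream.foreachBatch(func).start()
    time.sleep(3)
    q.stop()

    # Classic PySpark would print 100; under Spark Connect the client still
    # sees -1 because func ran in a remote process.
    print(my_value)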