This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push:
new b52086b55b2 [SPARK-44461][3.5][SS][PYTHON][CONNECT] Verify Python
Version for spark connect streaming workers
b52086b55b2 is described below
commit b52086b55b278818755f403ae55e80eda79bd250
Author: Wei Liu <[email protected]>
AuthorDate: Sat Aug 12 08:48:41 2023 +0900
[SPARK-44461][3.5][SS][PYTHON][CONNECT] Verify Python Version for spark
connect streaming workers
### What changes were proposed in this pull request?
Backport https://github.com/apache/spark/pull/42421 into 3.5. Also
partially backport https://github.com/apache/spark/pull/42135
Add python version check for spark connect streaming `foreachBatch_worker`
and `listener_worker`
### Why are the changes needed?
Necessary check
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
I believe it can be skipped here
Closes #42443 from WweiL/SPARK-44461-3.5-backport-1.
Authored-by: Wei Liu <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../spark/api/python/PythonWorkerUtils.scala | 46 ++++++++++++++++++++++
.../spark/api/python/StreamingPythonRunner.scala | 2 +-
python/pyspark/sql/connect/streaming/query.py | 3 +-
python/pyspark/sql/connect/streaming/readwriter.py | 3 +-
.../streaming/worker/foreachBatch_worker.py | 3 ++
.../connect/streaming/worker/listener_worker.py | 3 ++
python/pyspark/worker_util.py | 46 ++++++++++++++++++++++
7 files changed, 103 insertions(+), 3 deletions(-)
diff --git
a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerUtils.scala
b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerUtils.scala
new file mode 100644
index 00000000000..d55465f135e
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerUtils.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.python
+
+import java.io.DataOutputStream
+import java.nio.charset.StandardCharsets
+
+import org.apache.spark.internal.Logging
+
+private[spark] object PythonWorkerUtils extends Logging {
+
+ /**
+ * Write a string in UTF-8 charset.
+ *
+ * It will be read by `UTF8Deserializer.loads` in Python.
+ */
+ def writeUTF(str: String, dataOut: DataOutputStream): Unit = {
+ val bytes = str.getBytes(StandardCharsets.UTF_8)
+ dataOut.writeInt(bytes.length)
+ dataOut.write(bytes)
+ }
+
+ /**
+ * Write a Python version to check if the Python version is expected.
+ *
+ * It will be read and checked by `worker_util.check_python_version`.
+ */
+  def writePythonVersion(pythonVer: String, dataOut: DataOutputStream): Unit = {
+    writeUTF(pythonVer, dataOut)
+  }
+}
diff --git
a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala
b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala
index 1a75965eb92..cddda6fb7a7 100644
---
a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala
+++
b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala
@@ -83,7 +83,7 @@ private[spark] class StreamingPythonRunner(
val stream = new BufferedOutputStream(pythonWorker.get.getOutputStream,
bufferSize)
val dataOut = new DataOutputStream(stream)
- // TODO(SPARK-44461): verify python version
+ PythonWorkerUtils.writePythonVersion(pythonVer, dataOut)
// Send sessionId
PythonRDD.writeUTF(sessionId, dataOut)
diff --git a/python/pyspark/sql/connect/streaming/query.py
b/python/pyspark/sql/connect/streaming/query.py
index 59e98e7bc30..021d27e939d 100644
--- a/python/pyspark/sql/connect/streaming/query.py
+++ b/python/pyspark/sql/connect/streaming/query.py
@@ -23,6 +23,7 @@ from pyspark.errors import StreamingQueryException,
PySparkValueError
import pyspark.sql.connect.proto as pb2
from pyspark.serializers import CloudPickleSerializer
from pyspark.sql.connect import proto
+from pyspark.sql.connect.utils import get_python_ver
from pyspark.sql.streaming import StreamingQueryListener
from pyspark.sql.streaming.query import (
StreamingQuery as PySparkStreamingQuery,
@@ -237,7 +238,7 @@ class StreamingQueryManager:
cmd = pb2.StreamingQueryManagerCommand()
expr = proto.PythonUDF()
expr.command = CloudPickleSerializer().dumps(listener)
- expr.python_ver = "%d.%d" % sys.version_info[:2]
+ expr.python_ver = get_python_ver()
cmd.add_listener.python_listener_payload.CopyFrom(expr)
cmd.add_listener.id = listener._id
self._execute_streaming_query_manager_cmd(cmd)
diff --git a/python/pyspark/sql/connect/streaming/readwriter.py
b/python/pyspark/sql/connect/streaming/readwriter.py
index c8cd408404f..89097fcf43a 100644
--- a/python/pyspark/sql/connect/streaming/readwriter.py
+++ b/python/pyspark/sql/connect/streaming/readwriter.py
@@ -31,6 +31,7 @@ from pyspark.sql.streaming.readwriter import (
DataStreamReader as PySparkDataStreamReader,
DataStreamWriter as PySparkDataStreamWriter,
)
+from pyspark.sql.connect.utils import get_python_ver
from pyspark.sql.types import Row, StructType
from pyspark.errors import PySparkTypeError, PySparkValueError
@@ -499,7 +500,7 @@ class DataStreamWriter:
self._write_proto.foreach_batch.python_function.command =
CloudPickleSerializer().dumps(
func
)
-        self._write_proto.foreach_batch.python_function.python_ver = "%d.%d" % sys.version_info[:2]
+        self._write_proto.foreach_batch.python_function.python_ver = get_python_ver()
return self
foreachBatch.__doc__ = PySparkDataStreamWriter.foreachBatch.__doc__
diff --git a/python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py
b/python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py
index 48a9848de40..cf61463cd68 100644
--- a/python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py
+++ b/python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py
@@ -31,12 +31,15 @@ from pyspark.serializers import (
from pyspark import worker
from pyspark.sql import SparkSession
from typing import IO
+from pyspark.worker_util import check_python_version
pickle_ser = CPickleSerializer()
utf8_deserializer = UTF8Deserializer()
def main(infile: IO, outfile: IO) -> None:
+ check_python_version(infile)
+
connect_url = os.environ["SPARK_CONNECT_LOCAL_URL"]
session_id = utf8_deserializer.loads(infile)
diff --git a/python/pyspark/sql/connect/streaming/worker/listener_worker.py
b/python/pyspark/sql/connect/streaming/worker/listener_worker.py
index 7aef911426d..e1f4678e42f 100644
--- a/python/pyspark/sql/connect/streaming/worker/listener_worker.py
+++ b/python/pyspark/sql/connect/streaming/worker/listener_worker.py
@@ -39,12 +39,15 @@ from pyspark.sql.streaming.listener import (
QueryTerminatedEvent,
QueryIdleEvent,
)
+from pyspark.worker_util import check_python_version
pickle_ser = CPickleSerializer()
utf8_deserializer = UTF8Deserializer()
def main(infile: IO, outfile: IO) -> None:
+ check_python_version(infile)
+
connect_url = os.environ["SPARK_CONNECT_LOCAL_URL"]
session_id = utf8_deserializer.loads(infile)
diff --git a/python/pyspark/worker_util.py b/python/pyspark/worker_util.py
new file mode 100644
index 00000000000..0c119a5efaf
--- /dev/null
+++ b/python/pyspark/worker_util.py
@@ -0,0 +1,46 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+Util functions for workers.
+"""
+import sys
+from typing import IO
+
+from pyspark.errors import PySparkRuntimeError
+from pyspark.serializers import (
+ UTF8Deserializer,
+ CPickleSerializer,
+)
+
+pickleSer = CPickleSerializer()
+utf8_deserializer = UTF8Deserializer()
+
+
+def check_python_version(infile: IO) -> None:
+ """
+    Check the Python version between the running process and the one used to serialize the command.
+ """
+ version = utf8_deserializer.loads(infile)
+ if version != "%d.%d" % sys.version_info[:2]:
+ raise PySparkRuntimeError(
+ error_class="PYTHON_VERSION_MISMATCH",
+ message_parameters={
+ "worker_version": str(sys.version_info[:2]),
+ "driver_version": str(version),
+ },
+ )
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]