This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch branch-3.3
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.3 by this push:
new 476ce566c41 [SPARK-40874][PYTHON] Fix broadcasts in Python UDFs when
encryption enabled
476ce566c41 is described below
commit 476ce566c412437c0dde6b4006d3685548370784
Author: Peter Toth <[email protected]>
AuthorDate: Sat Oct 22 10:39:32 2022 +0900
[SPARK-40874][PYTHON] Fix broadcasts in Python UDFs when encryption enabled
This PR fixes a bug in broadcast handling in `PythonRunner` when encryption is
enabled. Due to this bug the following PySpark script:
```
bin/pyspark --conf spark.io.encryption.enabled=true
...
bar = {"a": "aa", "b": "bb"}
foo = spark.sparkContext.broadcast(bar)
spark.udf.register("MYUDF", lambda x: foo.value[x] if x else "")
spark.sql("SELECT MYUDF('a') AS a, MYUDF('b') AS b").collect()
```
fails with:
```
22/10/21 17:14:32 ERROR Executor: Exception in task 0.0 in stage 0.0 (TID 0)
org.apache.spark.api.python.PythonException: Traceback (most recent call
last):
File
"/Users/petertoth/git/apache/spark/python/lib/pyspark.zip/pyspark/worker.py",
line 811, in main
func, profiler, deserializer, serializer = read_command(pickleSer,
infile)
File
"/Users/petertoth/git/apache/spark/python/lib/pyspark.zip/pyspark/worker.py",
line 87, in read_command
command = serializer._read_with_length(file)
File
"/Users/petertoth/git/apache/spark/python/lib/pyspark.zip/pyspark/serializers.py",
line 173, in _read_with_length
return self.loads(obj)
File
"/Users/petertoth/git/apache/spark/python/lib/pyspark.zip/pyspark/serializers.py",
line 471, in loads
return cloudpickle.loads(obj, encoding=encoding)
EOFError: Ran out of input
```
The reason for this failure is that multiple Python UDFs reference the same
broadcast, and in the current code:
https://github.com/apache/spark/blob/748fa2792e488a6b923b32e2898d9bb6e16fb4ca/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala#L385-L420
the number of broadcasts (`cnt`) is correct (1), but the broadcast id is
serialized twice from the JVM to Python, corrupting the next item that Python
expects from the JVM side.
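To illustrate how a duplicated id desynchronizes the stream, here is a
simplified, hypothetical model of the handshake (not the actual PySpark wire
format): the writer sends a count, then that many 8-byte broadcast ids, then
the next protocol item. If the id is written twice while `cnt` stays 1, the
reader consumes the duplicate as the next item:

```python
import io
import struct

def write_broadcasts(ids_to_write, next_token):
    # Writer side (JVM in the real system): cnt, then the ids, then the
    # next item the protocol expects.
    buf = io.BytesIO()
    buf.write(struct.pack(">i", 1))           # cnt: number of *new* broadcasts
    for bid in ids_to_write:                  # buggy path writes the id twice
        buf.write(struct.pack(">q", bid))
    buf.write(struct.pack(">q", next_token))
    buf.seek(0)
    return buf

def read_broadcasts(buf):
    # Reader side (Python worker): trusts cnt, reads exactly cnt ids.
    (cnt,) = struct.unpack(">i", buf.read(4))
    ids = [struct.unpack(">q", buf.read(8))[0] for _ in range(cnt)]
    (token,) = struct.unpack(">q", buf.read(8))
    return ids, token

# Correct stream: cnt=1, one id, then the token.
ids, token = read_broadcasts(write_broadcasts([42], next_token=-1))
assert (ids, token) == ([42], -1)

# Buggy stream: cnt=1 but the id is written twice; the reader
# misinterprets the duplicated id (42) as the next protocol item.
ids, token = read_broadcasts(write_broadcasts([42, 42], next_token=-1))
assert token == 42  # desynchronized: -1 is still sitting unread in the stream
```

In the real failure the leftover bytes eventually starve the pickle
deserializer, hence the `EOFError: Ran out of input` above.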
Please note that the example above works in Spark 3.3 without this fix.
That is because https://github.com/apache/spark/pull/36121 in Spark 3.4
modified `ExpressionSet`, and so `udfs` in `ExtractPythonUDFs`:
https://github.com/apache/spark/blob/748fa2792e488a6b923b32e2898d9bb6e16fb4ca/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala#L239-L242
changed from a `Stream` to a `Vector`. When `broadcastVars` (and so
`idsAndFiles`) is a `Stream`, the example accidentally works because the
broadcast id is written to `dataOut` only once (`oldBids.add(id)` in
`idsAndFiles.foreach` is called before the 2nd item is computed in
`broadcastVars.flatMap`). But that doesn't mean that
https://github.com/apache/spark/pull/36121 introduced the regression, as
`EncryptedPythonBroadcastServer` shouldn't serve the broadcast
data 2 times (which `EncryptedPythonBr [...]
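The lazy-vs-eager masking can be sketched in Python, using a generator to
stand in for Scala's `Stream` and a list comprehension for `Vector` (a
simplified analogy, not Spark code):

```python
# Two UDFs referencing the same broadcast -> the same id appears twice.
broadcasts = [42, 42]

# Eager collection (Vector-like): the whole filter runs before any id is
# recorded, so the duplicate slips through.
old_bids = set()
eager = [b for b in broadcasts if b not in old_bids]
for b in eager:
    old_bids.add(b)          # too late: both copies are already in the list
assert eager == [42, 42]     # id would be written to dataOut twice

# Lazy collection (Stream-like): filtering interleaves with consumption,
# so recording the first id filters out the second.
old_bids = set()
lazy = (b for b in broadcasts if b not in old_bids)
written = []
for b in lazy:
    written.append(b)
    old_bids.add(b)          # runs before the generator tests the 2nd element
assert written == [42]       # id written once -> the bug is masked
```

This is why the bug was latent until the collection became eager: correctness
depended on a side effect racing the lazy evaluation.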
Why are the changes needed? To fix a bug.
Does this PR introduce any user-facing change? No.
How was this patch tested? Added a new UT.
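The fix itself is a one-liner: record the broadcast id in `oldBids` at filter
time rather than at write time, so each id is emitted at most once regardless
of whether the collection is lazy or eager. A minimal Python sketch of that
dedup pattern (illustrative names, not Spark code):

```python
old_bids = set()
broadcasts = [42, 42]        # same broadcast referenced by two UDFs

ids_and_files = []
for b in broadcasts:
    if b not in old_bids:
        old_bids.add(b)              # record at filter time (the fix)
        ids_and_files.append(b)

assert ids_and_files == [42]         # each broadcast id sent exactly once
```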
Closes #38334 from peter-toth/SPARK-40874-fix-broadcasts-in-python-udf.
Authored-by: Peter Toth <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../scala/org/apache/spark/api/python/PythonRunner.scala | 2 +-
python/pyspark/tests/test_broadcast.py | 14 ++++++++++++++
2 files changed, 15 insertions(+), 1 deletion(-)
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index 15707ab9157..f32c80f3ef5 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -401,6 +401,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
// the decrypted data to python
val idsAndFiles = broadcastVars.flatMap { broadcast =>
if (!oldBids.contains(broadcast.id)) {
+ oldBids.add(broadcast.id)
Some((broadcast.id, broadcast.value.path))
} else {
None
@@ -414,7 +415,6 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
idsAndFiles.foreach { case (id, _) =>
// send new broadcast
dataOut.writeLong(id)
- oldBids.add(id)
}
dataOut.flush()
logTrace("waiting for python to read decrypted broadcast data from
server")
diff --git a/python/pyspark/tests/test_broadcast.py
b/python/pyspark/tests/test_broadcast.py
index 56763e8d80a..6dce34c4ca5 100644
--- a/python/pyspark/tests/test_broadcast.py
+++ b/python/pyspark/tests/test_broadcast.py
@@ -23,6 +23,7 @@ import unittest
from pyspark import SparkConf, SparkContext
from pyspark.java_gateway import launch_gateway
from pyspark.serializers import ChunkedStream
+from pyspark.sql import SparkSession, Row
class BroadcastTest(unittest.TestCase):
@@ -99,6 +100,19 @@ class BroadcastTest(unittest.TestCase):
finally:
b.destroy()
+ def test_broadcast_in_udfs_with_encryption(self):
+ conf = SparkConf()
+ conf.set("spark.io.encryption.enabled", "true")
+ conf.setMaster("local-cluster[2,1,1024]")
+ self.sc = SparkContext(conf=conf)
+ bar = {"a": "aa", "b": "bb"}
+ foo = self.sc.broadcast(bar)
+ spark = SparkSession(self.sc)
+ spark.udf.register("MYUDF", lambda x: foo.value[x] if x else "")
+ sel = spark.sql("SELECT MYUDF('a') AS a, MYUDF('b') AS b")
+ self.assertEqual(sel.collect(), [Row(a="aa", b="bb")])
+ spark.stop()
+
class BroadcastFrameProtocolTest(unittest.TestCase):
@classmethod