This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.1
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.1 by this push:
     new 7c3887c1ed2 [SPARK-40874][PYTHON] Fix broadcasts in Python UDFs when encryption enabled
7c3887c1ed2 is described below

commit 7c3887c1ed2e23bd0010d3e79a847bad18818461
Author: Peter Toth <peter.t...@gmail.com>
AuthorDate: Sat Oct 22 10:39:32 2022 +0900

    [SPARK-40874][PYTHON] Fix broadcasts in Python UDFs when encryption enabled
    
    This PR fixes a bug in broadcast handling in `PythonRunner` when encryption is enabled. Due to this bug the following pyspark script:
    ```
    bin/pyspark --conf spark.io.encryption.enabled=true
    
    ...
    
    bar = {"a": "aa", "b": "bb"}
    foo = spark.sparkContext.broadcast(bar)
    spark.udf.register("MYUDF", lambda x: foo.value[x] if x else "")
    spark.sql("SELECT MYUDF('a') AS a, MYUDF('b') AS b").collect()
    ```
    fails with:
    ```
    22/10/21 17:14:32 ERROR Executor: Exception in task 0.0 in stage 0.0 (TID 0)
    org.apache.spark.api.python.PythonException: Traceback (most recent call last):
      File "/Users/petertoth/git/apache/spark/python/lib/pyspark.zip/pyspark/worker.py", line 811, in main
        func, profiler, deserializer, serializer = read_command(pickleSer, infile)
      File "/Users/petertoth/git/apache/spark/python/lib/pyspark.zip/pyspark/worker.py", line 87, in read_command
        command = serializer._read_with_length(file)
      File "/Users/petertoth/git/apache/spark/python/lib/pyspark.zip/pyspark/serializers.py", line 173, in _read_with_length
        return self.loads(obj)
      File "/Users/petertoth/git/apache/spark/python/lib/pyspark.zip/pyspark/serializers.py", line 471, in loads
        return cloudpickle.loads(obj, encoding=encoding)
    EOFError: Ran out of input
    ```
    The reason for this failure is that multiple Python UDFs reference the same broadcast and, in the current code:
    https://github.com/apache/spark/blob/748fa2792e488a6b923b32e2898d9bb6e16fb4ca/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala#L385-L420
    the number of broadcasts (`cnt`) is correct (1) but the broadcast id is serialized 2 times from JVM to Python, ruining the next item that Python expects from the JVM side.
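    The framing mismatch can be illustrated with a schematic sketch (plain `struct` framing, not the actual `PythonRunner` wire format; the id value 7 and the payload are made up): the "JVM" announces one new broadcast but writes its id twice, so the worker's next length-prefixed read is desynchronized and unpickling fails just like in the report above.

```python
import io
import pickle
import struct

buf = io.BytesIO()
# "JVM" side: announces 1 new broadcast but writes its id twice (the bug)
buf.write(struct.pack(">i", 1))             # count of new broadcasts
buf.write(struct.pack(">q", 7))             # broadcast id
buf.write(struct.pack(">q", 7))             # duplicate write of the same id
command = pickle.dumps("stand-in for the pickled UDF")
buf.write(struct.pack(">i", len(command)))  # length-prefixed command
buf.write(command)
buf.seek(0)

# "Python worker" side: trusts the count, so the stray id desynchronizes it
(count,) = struct.unpack(">i", buf.read(4))
ids = [struct.unpack(">q", buf.read(8))[0] for _ in range(count)]
(length,) = struct.unpack(">i", buf.read(4))  # first half of the stray id: 0
payload = buf.read(length)                    # empty instead of the command

try:
    pickle.loads(payload)
except EOFError as err:
    print(err)  # Ran out of input
```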
    
    Please note that the example above works in Spark 3.3 without this fix. That is because https://github.com/apache/spark/pull/36121 in Spark 3.4 modified `ExpressionSet` and so `udfs` in `ExtractPythonUDFs`:
    
https://github.com/apache/spark/blob/748fa2792e488a6b923b32e2898d9bb6e16fb4ca/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala#L239-L242
    changed from `Stream` to `Vector`. When `broadcastVars` (and so `idsAndFiles`) is a `Stream`, the example accidentally works as the broadcast id is written to `dataOut` once (`oldBids.add(id)` in `idsAndFiles.foreach` is called before the 2nd item is calculated in `broadcastVars.flatMap`). But that doesn't mean that https://github.com/apache/spark/pull/36121 introduced the regression as `EncryptedPythonBroadcastServer` shouldn't serve the broadcast data 2 times (which `EncryptedPythonBr [...]
    
    ### Why are the changes needed?
    To fix a bug.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Added new UT.
    
    Closes #38334 from peter-toth/SPARK-40874-fix-broadcasts-in-python-udf.
    
    Authored-by: Peter Toth <peter.t...@gmail.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
    (cherry picked from commit 8a96f69bb536729eaa59fae55160f8a6747efbe3)
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../scala/org/apache/spark/api/python/PythonRunner.scala   |  2 +-
 python/pyspark/tests/test_broadcast.py                     | 14 ++++++++++++++
 2 files changed, 15 insertions(+), 1 deletion(-)

diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index 8d9f2be6218..60689858628 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -360,6 +360,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
           // the decrypted data to python
           val idsAndFiles = broadcastVars.flatMap { broadcast =>
             if (!oldBids.contains(broadcast.id)) {
+              oldBids.add(broadcast.id)
               Some((broadcast.id, broadcast.value.path))
             } else {
               None
@@ -373,7 +374,6 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
           idsAndFiles.foreach { case (id, _) =>
             // send new broadcast
             dataOut.writeLong(id)
-            oldBids.add(id)
           }
           dataOut.flush()
           logTrace("waiting for python to read decrypted broadcast data from server")
diff --git a/python/pyspark/tests/test_broadcast.py b/python/pyspark/tests/test_broadcast.py
index c35c5a68e49..61e798c33c8 100644
--- a/python/pyspark/tests/test_broadcast.py
+++ b/python/pyspark/tests/test_broadcast.py
@@ -23,6 +23,7 @@ import unittest
 from pyspark import SparkConf, SparkContext
 from pyspark.java_gateway import launch_gateway
 from pyspark.serializers import ChunkedStream
+from pyspark.sql import SparkSession, Row
 
 
 class BroadcastTest(unittest.TestCase):
@@ -100,6 +101,19 @@ class BroadcastTest(unittest.TestCase):
         finally:
             b.destroy()
 
+    def test_broadcast_in_udfs_with_encryption(self):
+        conf = SparkConf()
+        conf.set("spark.io.encryption.enabled", "true")
+        conf.setMaster("local-cluster[2,1,1024]")
+        self.sc = SparkContext(conf=conf)
+        bar = {"a": "aa", "b": "bb"}
+        foo = self.sc.broadcast(bar)
+        spark = SparkSession(self.sc)
+        spark.udf.register("MYUDF", lambda x: foo.value[x] if x else "")
+        sel = spark.sql("SELECT MYUDF('a') AS a, MYUDF('b') AS b")
+        self.assertEqual(sel.collect(), [Row(a="aa", b="bb")])
+        spark.stop()
+
 
 class BroadcastFrameProtocolTest(unittest.TestCase):
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
