HyukjinKwon closed pull request #23337: [SPARK-26019][PYSPARK] Allow insecure 
py4j gateways
URL: https://github.com/apache/spark/pull/23337
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as GitHub does not otherwise display the original diff for merged fork pull requests):

diff --git 
a/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala 
b/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala
index 9ddc4a4910180..17c65f6170d67 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala
@@ -43,12 +43,17 @@ private[spark] object PythonGatewayServer extends Logging {
     // with the same secret, in case the app needs callbacks from the JVM to 
the underlying
     // python processes.
     val localhost = InetAddress.getLoopbackAddress()
-    val gatewayServer: GatewayServer = new GatewayServer.GatewayServerBuilder()
-      .authToken(secret)
+    val builder = new GatewayServer.GatewayServerBuilder()
       .javaPort(0)
       .javaAddress(localhost)
       .callbackClient(GatewayServer.DEFAULT_PYTHON_PORT, localhost, secret)
-      .build()
+    if (sys.env.getOrElse("_PYSPARK_CREATE_INSECURE_GATEWAY", "0") != "1") {
+      builder.authToken(secret)
+    } else {
+      assert(sys.env.getOrElse("SPARK_TESTING", "0") == "1",
+        "Creating insecure Java gateways only allowed for testing")
+    }
+    val gatewayServer: GatewayServer = builder.build()
 
     gatewayServer.start()
     val boundPort: Int = gatewayServer.getListeningPort
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala 
b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 5ed5070558af7..81494b167af50 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -616,8 +616,10 @@ private[spark] class PythonAccumulatorV2(
     if (socket == null || socket.isClosed) {
       socket = new Socket(serverHost, serverPort)
       logInfo(s"Connected to AccumulatorServer at host: $serverHost port: 
$serverPort")
-      // send the secret just for the initial authentication when opening a 
new connection
-      
socket.getOutputStream.write(secretToken.getBytes(StandardCharsets.UTF_8))
+      if (secretToken != null) {
+        // send the secret just for the initial authentication when opening a 
new connection
+        
socket.getOutputStream.write(secretToken.getBytes(StandardCharsets.UTF_8))
+      }
     }
     socket
   }
diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py
index 00ec094e7e3b4..855d8fb4a859f 100644
--- a/python/pyspark/accumulators.py
+++ b/python/pyspark/accumulators.py
@@ -262,9 +262,10 @@ def authenticate_and_accum_updates():
                 raise Exception(
                     "The value of the provided token to the AccumulatorServer 
is not correct.")
 
-        # first we keep polling till we've received the authentication token
-        poll(authenticate_and_accum_updates)
-        # now we've authenticated, don't need to check for the token anymore
+        if auth_token is not None:
+            # first we keep polling till we've received the authentication 
token
+            poll(authenticate_and_accum_updates)
+        # now we've authenticated if needed, don't need to check for the token 
anymore
         poll(accum_updates)
 
 
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 0924d3d95f044..6d99e9823f001 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -112,6 +112,20 @@ def __init__(self, master=None, appName=None, 
sparkHome=None, pyFiles=None,
         ValueError:...
         """
         self._callsite = first_spark_call() or CallSite(None, None, None)
+        if gateway is not None and gateway.gateway_parameters.auth_token is 
None:
+            allow_insecure_env = 
os.environ.get("PYSPARK_ALLOW_INSECURE_GATEWAY", "0")
+            if allow_insecure_env == "1" or allow_insecure_env.lower() == 
"true":
+                warnings.warn(
+                    "You are passing in an insecure Py4j gateway.  This "
+                    "presents a security risk, and will be completely 
forbidden in Spark 3.0")
+            else:
+                raise ValueError(
+                    "You are trying to pass an insecure Py4j gateway to Spark. 
This"
+                    " presents a security risk.  If you are sure you 
understand and accept this"
+                    " risk, you can set the environment variable"
+                    " 'PYSPARK_ALLOW_INSECURE_GATEWAY=1', but"
+                    " note this option will be removed in Spark 3.0")
+
         SparkContext._ensure_initialized(self, gateway=gateway, conf=conf)
         try:
             self._do_init(master, appName, sparkHome, pyFiles, environment, 
batchSize, serializer,
diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
index c8c5f801f89bb..feb6b7bd6aa3d 100644
--- a/python/pyspark/java_gateway.py
+++ b/python/pyspark/java_gateway.py
@@ -41,8 +41,20 @@ def launch_gateway(conf=None):
     """
     launch jvm gateway
     :param conf: spark configuration passed to spark-submit
-    :return:
+    :return: a JVM gateway
     """
+    return _launch_gateway(conf)
+
+
+def _launch_gateway(conf=None, insecure=False):
+    """
+    launch jvm gateway
+    :param conf: spark configuration passed to spark-submit
+    :param insecure: True to create an insecure gateway; only for testing
+    :return: a JVM gateway
+    """
+    if insecure and os.environ.get("SPARK_TESTING", "0") != "1":
+        raise ValueError("creating insecure gateways is only for testing")
     if "PYSPARK_GATEWAY_PORT" in os.environ:
         gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"])
         gateway_secret = os.environ["PYSPARK_GATEWAY_SECRET"]
@@ -74,6 +86,8 @@ def launch_gateway(conf=None):
 
             env = dict(os.environ)
             env["_PYSPARK_DRIVER_CONN_INFO_PATH"] = conn_info_file
+            if insecure:
+                env["_PYSPARK_CREATE_INSECURE_GATEWAY"] = "1"
 
             # Launch the Java gateway.
             # We open a pipe to stdin so that the Java gateway can die when 
the pipe is broken
@@ -116,9 +130,10 @@ def killChild():
             atexit.register(killChild)
 
     # Connect to the gateway
-    gateway = JavaGateway(
-        gateway_parameters=GatewayParameters(port=gateway_port, 
auth_token=gateway_secret,
-                                             auto_convert=True))
+    gateway_params = GatewayParameters(port=gateway_port, auto_convert=True)
+    if not insecure:
+        gateway_params.auth_token = gateway_secret
+    gateway = JavaGateway(gateway_parameters=gateway_params)
 
     # Import the classes used by PySpark
     java_import(gateway.jvm, "org.apache.spark.SparkConf")
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 131c51e108cad..a2d825ba36256 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -61,6 +61,7 @@
 from pyspark import keyword_only
 from pyspark.conf import SparkConf
 from pyspark.context import SparkContext
+from pyspark.java_gateway import _launch_gateway
 from pyspark.rdd import RDD
 from pyspark.files import SparkFiles
 from pyspark.serializers import read_int, BatchedSerializer, 
MarshalSerializer, PickleSerializer, \
@@ -2381,6 +2382,37 @@ def test_startTime(self):
         with SparkContext() as sc:
             self.assertGreater(sc.startTime, 0)
 
+    def test_forbid_insecure_gateway(self):
+        # By default, we fail immediately if you try to create a SparkContext
+        # with an insecure gateway
+        gateway = _launch_gateway(insecure=True)
+        log4j = gateway.jvm.org.apache.log4j
+        old_level = log4j.LogManager.getRootLogger().getLevel()
+        try:
+            log4j.LogManager.getRootLogger().setLevel(log4j.Level.FATAL)
+            with self.assertRaises(Exception) as context:
+                SparkContext(gateway=gateway)
+            self.assertIn("insecure Py4j gateway", str(context.exception))
+            self.assertIn("PYSPARK_ALLOW_INSECURE_GATEWAY", 
str(context.exception))
+            self.assertIn("removed in Spark 3.0", str(context.exception))
+        finally:
+            log4j.LogManager.getRootLogger().setLevel(old_level)
+
+    def test_allow_insecure_gateway_with_conf(self):
+        with SparkContext._lock:
+            SparkContext._gateway = None
+            SparkContext._jvm = None
+        gateway = _launch_gateway(insecure=True)
+        try:
+            os.environ["PYSPARK_ALLOW_INSECURE_GATEWAY"] = "1"
+            with SparkContext(gateway=gateway) as sc:
+                a = sc.accumulator(1)
+                rdd = sc.parallelize([1, 2, 3])
+                rdd.foreach(lambda x: a.add(x))
+                self.assertEqual(7, a.value)
+        finally:
+            os.environ.pop("PYSPARK_ALLOW_INSECURE_GATEWAY", None)
+
 
 class ConfTests(unittest.TestCase):
     def test_memory_conf(self):


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to