This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 511839b6eac9 [SPARK-47137][PYTHON][CONNECT] Add getAll to spark.conf for feature parity with Scala
511839b6eac9 is described below

commit 511839b6eac974351410a1713f5a90329e49abe9
Author: Takuya UESHIN <[email protected]>
AuthorDate: Thu Feb 22 20:22:43 2024 -0800

    [SPARK-47137][PYTHON][CONNECT] Add getAll to spark.conf for feature parity with Scala
    
    ### What changes were proposed in this pull request?
    
    Adds `getAll` to `spark.conf` for feature parity with Scala.
    
    ```py
    >>> spark.conf.getAll
    {'spark.sql.warehouse.dir': ...}
    ```
    
    ### Why are the changes needed?
    
    Scala API provides `spark.conf.getAll`; whereas Python doesn't.
    
    ```scala
    scala> spark.conf.getAll
    val res0: Map[String,String] = HashMap(spark.sql.warehouse.dir -> ...
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, `spark.conf.getAll` will be available in PySpark.
    
    ### How was this patch tested?
    
    Added the related tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #45222 from ueshin/issues/SPARK-47137/getAll.
    
    Authored-by: Takuya UESHIN <[email protected]>
    Signed-off-by: Dongjoon Hyun <[email protected]>
---
 python/pyspark/sql/conf.py                         | 16 +++++-
 python/pyspark/sql/connect/conf.py                 | 15 +++++-
 python/pyspark/sql/tests/test_conf.py              | 63 ++++++++++++++--------
 .../scala/org/apache/spark/sql/RuntimeConfig.scala |  6 +++
 4 files changed, 75 insertions(+), 25 deletions(-)

diff --git a/python/pyspark/sql/conf.py b/python/pyspark/sql/conf.py
index e77039565dd1..dd43991b0706 100644
--- a/python/pyspark/sql/conf.py
+++ b/python/pyspark/sql/conf.py
@@ -16,7 +16,7 @@
 #
 
 import sys
-from typing import Any, Optional, Union
+from typing import Any, Dict, Optional, Union
 
 from py4j.java_gateway import JavaObject
 
@@ -93,6 +93,20 @@ class RuntimeConfig:
                 self._check_type(default, "default")
             return self._jconf.get(key, default)
 
+    @property
+    def getAll(self) -> Dict[str, str]:
+        """
+        Returns all properties set in this conf.
+
+        .. versionadded:: 4.0.0
+
+        Returns
+        -------
+        dict
+            A dictionary containing all properties set in this conf.
+        """
+        return dict(self._jconf.getAllAsJava())
+
     def unset(self, key: str) -> None:
         """
         Resets the configuration property for the given key.
diff --git a/python/pyspark/sql/connect/conf.py b/python/pyspark/sql/connect/conf.py
index 3548a31fef03..57a669aca889 100644
--- a/python/pyspark/sql/connect/conf.py
+++ b/python/pyspark/sql/connect/conf.py
@@ -19,7 +19,7 @@ from pyspark.sql.connect.utils import check_dependencies
 
 check_dependencies(__name__)
 
-from typing import Any, Optional, Union, cast
+from typing import Any, Dict, Optional, Union, cast
 import warnings
 
 from pyspark import _NoValue
@@ -68,6 +68,19 @@ class RuntimeConf:
 
     get.__doc__ = PySparkRuntimeConfig.get.__doc__
 
+    @property
+    def getAll(self) -> Dict[str, str]:
+        op_get_all = proto.ConfigRequest.GetAll()
+        operation = proto.ConfigRequest.Operation(get_all=op_get_all)
+        result = self._client.config(operation)
+        confs: Dict[str, str] = dict()
+        for key, value in result.pairs:
+            assert value is not None
+            confs[key] = value
+        return confs
+
+    getAll.__doc__ = PySparkRuntimeConfig.getAll.__doc__
+
     def unset(self, key: str) -> None:
         op_unset = proto.ConfigRequest.Unset(keys=[key])
         operation = proto.ConfigRequest.Operation(unset=op_unset)
diff --git a/python/pyspark/sql/tests/test_conf.py b/python/pyspark/sql/tests/test_conf.py
index 9b939205b1d1..68b147f09746 100644
--- a/python/pyspark/sql/tests/test_conf.py
+++ b/python/pyspark/sql/tests/test_conf.py
@@ -50,32 +50,49 @@ class ConfTestsMixin:
     def test_conf_with_python_objects(self):
         spark = self.spark
 
-        for value, expected in [(True, "true"), (False, "false")]:
-            spark.conf.set("foo", value)
-            self.assertEqual(spark.conf.get("foo"), expected)
-
-        spark.conf.set("foo", 1)
-        self.assertEqual(spark.conf.get("foo"), "1")
-
-        with self.assertRaises(IllegalArgumentException):
-            spark.conf.set("foo", None)
-
-        with self.assertRaises(Exception):
-            spark.conf.set("foo", Decimal(1))
+        try:
+            for value, expected in [(True, "true"), (False, "false")]:
+                spark.conf.set("foo", value)
+                self.assertEqual(spark.conf.get("foo"), expected)
+
+            spark.conf.set("foo", 1)
+            self.assertEqual(spark.conf.get("foo"), "1")
+
+            with self.assertRaises(IllegalArgumentException):
+                spark.conf.set("foo", None)
+
+            with self.assertRaises(Exception):
+                spark.conf.set("foo", Decimal(1))
+
+            with self.assertRaises(PySparkTypeError) as pe:
+                spark.conf.get(123)
+
+            self.check_error(
+                exception=pe.exception,
+                error_class="NOT_STR",
+                message_parameters={
+                    "arg_name": "key",
+                    "arg_type": "int",
+                },
+            )
+        finally:
+            spark.conf.unset("foo")
+
+    def test_get_all(self):
+        spark = self.spark
+        all_confs = spark.conf.getAll
 
-        with self.assertRaises(PySparkTypeError) as pe:
-            spark.conf.get(123)
+        self.assertTrue(len(all_confs) > 0)
+        self.assertNotIn("foo", all_confs)
 
-        self.check_error(
-            exception=pe.exception,
-            error_class="NOT_STR",
-            message_parameters={
-                "arg_name": "key",
-                "arg_type": "int",
-            },
-        )
+        try:
+            spark.conf.set("foo", "bar")
+            updated = spark.conf.getAll
 
-        spark.conf.unset("foo")
+            self.assertEquals(len(updated), len(all_confs) + 1)
+            self.assertIn("foo", updated)
+        finally:
+            spark.conf.unset("foo")
 
 
 class ConfTests(ConfTestsMixin, ReusedSQLTestCase):
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala b/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala
index 936d40f5c387..ed8cf4f121f0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql
 
+import scala.jdk.CollectionConverters._
+
 import org.apache.spark.SPARK_DOC_ROOT
 import org.apache.spark.annotation.Stable
 import org.apache.spark.internal.config.{ConfigEntry, OptionalConfigEntry}
@@ -118,6 +120,10 @@ class RuntimeConfig private[sql](val sqlConf: SQLConf = new SQLConf) {
     sqlConf.getAllConfs
   }
 
+  private[sql] def getAllAsJava: java.util.Map[String, String] = {
+    getAll.asJava
+  }
+
   /**
   * Returns the value of Spark runtime configuration property for the given key.
    *


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to