WeichenXu123 commented on code in PR #40607:
URL: https://github.com/apache/spark/pull/40607#discussion_r1155430021


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -32,54 +32,65 @@
 
 from pyspark import cloudpickle
 from pyspark.sql import SparkSession
+from pyspark.taskcontext import BarrierTaskContext
 from pyspark.ml.torch.log_communication import (  # type: ignore
-    get_driver_host,
     LogStreamingClient,
     LogStreamingServer,
 )
-from pyspark.context import SparkContext
-from pyspark.taskcontext import BarrierTaskContext
 
 
-# TODO(SPARK-41589): will move the functions and tests to an external file
-#       once we are in agreement about which functions should be in utils.py
-def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
-    """Get the conf "key" from the given spark context,
+def _get_active_session() -> SparkSession:
+    from pyspark.sql.utils import is_remote
+
+    if not is_remote():
+        spark = SparkSession.getActiveSession()
+    else:
+        from pyspark.sql.connect.session import _active_spark_session
+
+        spark = _active_spark_session  # type: ignore[assignment]
+
+    if spark is None:
+        raise RuntimeError("An active SparkSession is required for the distributor.")
+    return spark
+
+
+def _get_conf(spark: SparkSession, key: str, default_value: str) -> str:
+    """Get the conf "key" from the given spark session,
     or return the default value if the conf is not set.
-    This expects the conf value to be a boolean or string;
-    if the value is a string, this checks for all capitalization
-    patterns of "true" and "false" to match Scala.
+    If this session is a remote connect session, the SparkConf
+    code path always fail, and fallback to the RuntimeConf.
 
     Parameters
     ----------
-    sc : :class:`SparkContext`
-        The :class:`SparkContext` for the distributor.
+    spark : :class:`SparkSession`
+        The :class:`SparkSession` for the distributor.
     key : str
         string for conf name
     default_value : str
         default value for the conf value for the given key
 
     Returns
     -------
-    bool
-        Returns the boolean value that corresponds to the conf
-
-    Raises
-    ------
-    ValueError
-        Thrown when the conf value is not a valid boolean
+    str
+        Returns the string value that corresponds to the conf
     """
-    val = sc.getConf().get(key, default_value)
-    lowercase_val = val.lower()
-    if lowercase_val == "true":
-        return True
-    if lowercase_val == "false":
-        return False
-    raise ValueError(
-        f"The conf value for '{key}' was expected to be a boolean "
-        f"value but found value of type {type(val)} "
-        f"with value: {val}"
-    )
+    from pyspark.sql.utils import is_remote
+
+    if not is_remote():
+        value = spark.sparkContext.getConf().get(key, default_value)

Review Comment:
   Q: Why not use `spark.conf.get` here as well? That way we wouldn't need two branches.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to