rithwik-db commented on code in PR #40045:
URL: https://github.com/apache/spark/pull/40045#discussion_r1107891507


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -187,11 +187,9 @@ def _get_num_tasks(self) -> int:
                 return math.ceil(self.num_processes / task_gpu_amount)
             else:
                 key = "spark.driver.resource.gpu.amount"
-                if "gpu" not in self.sc.resources:

Review Comment:
   @HyukjinKwon, this is the stack trace:
   ```
   File /databricks/spark/python/pyspark/ml/torch/distributor.py:344, in TorchDistributor.__init__(self, num_processes, local_mode, use_gpu)
       314 def __init__(
       315     self,
       316     num_processes: int = 1,
       317     local_mode: bool = True,
       318     use_gpu: bool = True,
       319 ):
       320     """Initializes the distributor.
       321 
       322     Parameters
      (...)
       342         If an active SparkSession is unavailable.
       343     """
   --> 344     super().__init__(num_processes, local_mode, use_gpu)
       345     self.ssl_conf = "pytorch.spark.distributor.ignoreSsl"  # type: ignore
       346     self._validate_input_params()
   File /databricks/spark/python/pyspark/ml/torch/distributor.py:156, in Distributor.__init__(self, num_processes, local_mode, use_gpu)
       154     raise RuntimeError("An active SparkSession is required for the distributor.")
       155 self.sc = self.spark.sparkContext
   --> 156 self.num_tasks = self._get_num_tasks()
       157 self.ssl_conf = None
   File /databricks/spark/python/pyspark/ml/torch/distributor.py:190, in Distributor._get_num_tasks(self)
       188 else:
       189     key = "spark.driver.resource.gpu.amount"
   --> 190     if "gpu" not in self.sc.resources:
       191         raise RuntimeError("GPUs were unable to be found on the driver.")
       192     num_available_gpus = int(self.sc.getConf().get(key, "0"))
   File /databricks/spark/python/pyspark/context.py:2545, in SparkContext.resources(self)
      2543 resources = {}
      2544 jresources = self._jsc.resources()
   -> 2545 for x in jresources:
      2546     name = jresources[x].name()
      2547     jaddresses = jresources[x].addresses()
   File /databricks/spark/python/lib/py4j-0.10.9.7-src.zip/py4j/java_collections.py:62, in JavaIterator.next(self)
        58     self._methods[self._next_name] = JavaMember(
        59         self._next_name, self,
        60         self._target_id, self._gateway_client)
        61 try:
   ---> 62     return self._methods[self._next_name]()
        63 except Py4JError:
        64     raise StopIteration()
   File /databricks/spark/python/lib/py4j-0.10.9.7-src.zip/py4j/java_gateway.py:1322, in JavaMember.__call__(self, *args)
      1316 command = proto.CALL_COMMAND_NAME +\
      1317     self.command_header +\
      1318     args_command +\
      1319     proto.END_COMMAND_PART
      1321 answer = self.gateway_client.send_command(command)
   -> 1322 return_value = get_return_value(
      1323     answer, self.gateway_client, self.target_id, self.name)
      1325 for temp_arg in temp_args:
      1326     if hasattr(temp_arg, "_detach"):
   File /databricks/spark/python/pyspark/errors/exceptions.py:230, in capture_sql_exception.<locals>.deco(*a, **kw)
       228     return f(*a, **kw)
       229 except Py4JJavaError as e:
   --> 230     converted = convert_exception(e.java_exception)
       231     if not isinstance(converted, UnknownException):
       232         # Hide where the exception came from that shows a non-Pythonic
       233         # JVM exception message.
       234         raise converted from None
   File /databricks/spark/python/pyspark/errors/exceptions.py:198, in convert_exception(e)
       194     return SparkUpgradeException(origin=e)
       196 # BEGIN-EDGE
       197 # For Delta exception improvement.
   --> 198 from delta.exceptions import _convert_delta_exception
       200 delta_exception = _convert_delta_exception(e)
       201 if delta_exception is not None:
   File /databricks/python_shell/dbruntime/PythonPackageImportsInstrumentation/__init__.py:171, in _create_import_patch.<locals>.import_patch(name, globals, locals, fromlist, level)
       166 thread_local._nest_level += 1
       168 try:
       169     # Import the desired module. If you’re seeing this while debugging a failed import,
       170     # look at preceding stack frames for relevant error information.
   --> 171     original_result = python_builtin_import(name, globals, locals, fromlist, level)
       173     is_root_import = thread_local._nest_level == 1
       174     # `level` represents the number of leading dots in a relative import statement.
       175     # If it's zero, then this is an absolute import.
   File /databricks/spark/python/delta/__init__.py:17
         1 #
         2 # Copyright (2021) The Delta Lake Project Authors.
         3 #
      (...)
        14 # limitations under the License.
        15 #
   ---> 17 from delta.tables import DeltaTable
        18 from delta.pip_utils import configure_spark_with_delta_pip
        20 __all__ = ['DeltaTable', 'configure_spark_with_delta_pip']
   File /databricks/python_shell/dbruntime/PythonPackageImportsInstrumentation/__init__.py:171, in _create_import_patch.<locals>.import_patch(name, globals, locals, fromlist, level)
       166 thread_local._nest_level += 1
       168 try:
       169     # Import the desired module. If you’re seeing this while debugging a failed import,
       170     # look at preceding stack frames for relevant error information.
   --> 171     original_result = python_builtin_import(name, globals, locals, fromlist, level)
       173     is_root_import = thread_local._nest_level == 1
       174     # `level` represents the number of leading dots in a relative import statement.
       175     # If it's zero, then this is an absolute import.
   File /databricks/spark/python/delta/tables.py:30
        28 from pyspark.sql.column import _to_seq  # type: ignore[attr-defined]
        29 from pyspark.sql.types import DataType, StructType, StructField
   ---> 30 from pyspark.sql.connect.session import SparkSession as RemoteSparkSession # Edge
        33 if TYPE_CHECKING:
        34     from py4j.java_gateway import JavaObject, JVMView  # type: ignore[import]
   File /databricks/python_shell/dbruntime/PythonPackageImportsInstrumentation/__init__.py:171, in _create_import_patch.<locals>.import_patch(name, globals, locals, fromlist, level)
       166 thread_local._nest_level += 1
       168 try:
       169     # Import the desired module. If you’re seeing this while debugging a failed import,
       170     # look at preceding stack frames for relevant error information.
   --> 171     original_result = python_builtin_import(name, globals, locals, fromlist, level)
       173     is_root_import = thread_local._nest_level == 1
       174     # `level` represents the number of leading dots in a relative import statement.
       175     # If it's zero, then this is an absolute import.
   File /databricks/spark/python/pyspark/sql/connect/session.py:19
         1 #
         2 # Licensed to the Apache Software Foundation (ASF) under one or more
         3 # contributor license agreements.  See the NOTICE file distributed with
      (...)
        15 # limitations under the License.
        16 #
        17 from pyspark.sql.connect import check_dependencies
   ---> 19 check_dependencies(__name__, __file__)
        21 import os
        22 import warnings
   File /databricks/spark/python/pyspark/sql/connect/__init__.py:42, in check_dependencies(mod_name, file_name)
        40 require_minimum_pandas_version()
        41 require_minimum_pyarrow_version()
   ---> 42 require_minimum_grpc_version()
   File /databricks/spark/python/pyspark/sql/pandas/utils.py:89, in require_minimum_grpc_version()
        85     raise ImportError(
        86         "grpc >= %s must be installed; however, " "it was not found." % minimum_grpc_version
        87     ) from error
        88 if LooseVersion(grpc.__version__) < LooseVersion(minimum_grpc_version):
   ---> 89     raise ImportError(
        90         "gRPC >= %s must be installed; however, "
        91         "your version was %s." % (minimum_grpc_version, grpc.__version__)
        92     )
   ImportError: gRPC >= 1.48.1 must be installed; however, your version was 1.42.0.)))
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to