This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 10b448402b9 [SPARK-41533][CONNECT] Proper Error Handling for Spark Connect Server / Client
10b448402b9 is described below

commit 10b448402b9e142db4a8d8c7989478a0c5d04315
Author: Martin Grund <martin.gr...@databricks.com>
AuthorDate: Thu Dec 29 09:04:21 2022 +0900

    [SPARK-41533][CONNECT] Proper Error Handling for Spark Connect Server / Client
    
    ### What changes were proposed in this pull request?
    This PR improves error handling on both the Spark Connect server and the
    client. First, it moves the server-side error handling logic into a common
    error handler partial function that differentiates between internal Spark
    errors and other runtime errors.
    
    For custom Spark exceptions, the actual internal error is wrapped into a
    Google RPC Status and sent as trailing metadata to the client.
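
    As an illustration only (not part of this patch), the sketch below shows how
    a client can read back such a `google.rpc.Status` carrying an `ErrorInfo`
    detail from the trailing metadata; it mirrors the `_handle_error` logic added
    to `python/pyspark/sql/connect/client.py`, and the helper name
    `describe_rpc_error` is hypothetical.

    ```python
    import grpc
    from grpc_status import rpc_status          # provided by grpcio-status
    from google.rpc import error_details_pb2    # googleapis-common-protos

    def describe_rpc_error(rpc_error: grpc.RpcError) -> str:
        # from_call reads the trailing metadata and returns a google.rpc.Status,
        # or None if the server did not attach one.
        status = rpc_status.from_call(rpc_error)
        if status is None:
            return str(rpc_error)
        for detail in status.details:
            if detail.Is(error_details_pb2.ErrorInfo.DESCRIPTOR):
                info = error_details_pb2.ErrorInfo()
                detail.Unpack(info)
                # reason holds the Spark exception class name; metadata holds
                # extras such as "message" and "plan" for AnalysisException.
                return f"({info.reason}) {status.message}"
        return status.message
    ```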
    
    Similarly, on the client side, the error handling is moved into a common
    function. All GRPC errors are wrapped into custom exceptions so that users
    are not presented with confusing raw GRPC errors. If available, the attached
    RPC status is extracted and added to the exception.
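
    From a user's perspective this looks roughly like the illustrative snippet
    below (assuming `spark` is a Spark Connect remote session); it mirrors the
    `test_error_handling` case added to `test_connect_basic.py`.

    ```python
    from pyspark.sql.connect.client import (
        SparkConnectAnalysisException,
        SparkConnectException,
    )

    try:
        # Selecting a non-existent column fails analysis on the server.
        spark.range(10).select("id2").collect()
    except SparkConnectAnalysisException as e:
        # Carries the server-side reason, message, and plan instead of a raw
        # GRPC error.
        print(e)
    except SparkConnectException as e:
        # Any other server-side failure surfaces as the generic wrapper.
        print(e)
    ```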
    
    Lastly, this patch adds basic logging functionality that can be enabled via
    the environment variable `SPARK_CONNECT_LOG_LEVEL`, which can be set to
    `info`, `warn`, `error`, or `debug`.
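
    A minimal sketch of enabling it (illustrative; the level is read when
    `pyspark.sql.connect.client` is first imported, so the variable must be set
    beforehand, and a reachable Spark Connect endpoint is assumed):

    ```python
    import os

    # One of: info, warn, error, debug. If unset, client logging stays disabled.
    os.environ["SPARK_CONNECT_LOG_LEVEL"] = "debug"

    from pyspark.sql.connect.client import SparkConnectClient

    client = SparkConnectClient("sc://localhost")  # default port 15002
    # Subsequent RPCs (schema, execute, ...) are now logged to stderr.
    ```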
    
    ### Why are the changes needed?
    Usability
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    UT
    
    Closes #39212 from grundprinzip/SPARK-41533.
    
    Authored-by: Martin Grund <martin.gr...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .github/workflows/build_and_test.yml               |   2 +-
 connector/connect/README.md                        |   2 +-
 .../sql/connect/service/SparkConnectService.scala  |  87 ++++++--
 .../service/SparkConnectStreamHandler.scala        |   9 +-
 .../connect/planner/SparkConnectServiceSuite.scala |   4 +-
 dev/create-release/spark-rm/Dockerfile             |   2 +-
 dev/infra/Dockerfile                               |   2 +-
 dev/lint-python                                    |   2 +
 dev/requirements.txt                               |   4 +
 python/docs/source/getting_started/install.rst     |  20 +-
 python/mypy.ini                                    |   1 +
 python/pyspark/sql/connect/client.py               | 234 ++++++++++++++++++---
 python/pyspark/sql/connect/dataframe.py            |   2 +-
 .../sql/tests/connect/test_connect_basic.py        |  29 ++-
 .../sql/tests/connect/test_connect_function.py     |  20 +-
 python/pyspark/testing/connectutils.py             |  25 ++-
 python/setup.py                                    |   5 +-
 17 files changed, 364 insertions(+), 86 deletions(-)

diff --git a/.github/workflows/build_and_test.yml 
b/.github/workflows/build_and_test.yml
index 5bd2fef9b0c..443fbf47942 100644
--- a/.github/workflows/build_and_test.yml
+++ b/.github/workflows/build_and_test.yml
@@ -589,7 +589,7 @@ jobs:
         #   See also https://issues.apache.org/jira/browse/SPARK-38279.
         python3.9 -m pip install 'sphinx<3.1.0' mkdocs pydata_sphinx_theme 
ipython nbsphinx numpydoc 'jinja2<3.0.0' 'markupsafe==2.0.1' 'pyzmq<24.0.0'
         python3.9 -m pip install ipython_genutils # See SPARK-38517
-        python3.9 -m pip install sphinx_plotly_directive 'numpy>=1.20.0' 
pyarrow pandas 'plotly>=4.8' 'grpcio==1.48.1' 'protobuf==3.19.5' 
'mypy-protobuf==3.3.0'
+        python3.9 -m pip install sphinx_plotly_directive 'numpy>=1.20.0' 
pyarrow pandas 'plotly>=4.8' 'grpcio==1.48.1' 'protobuf==3.19.5' 
'mypy-protobuf==3.3.0' 'grpc-stubs==1.24.11' 
'googleapis-common-protos-stubs==2.2.0'
         python3.9 -m pip install 'docutils<0.18.0' # See SPARK-39421
         apt-get update -y
         apt-get install -y ruby ruby-dev
diff --git a/connector/connect/README.md b/connector/connect/README.md
index d30f65ffb5d..d5cc767c744 100644
--- a/connector/connect/README.md
+++ b/connector/connect/README.md
@@ -90,7 +90,7 @@ To use the release version of Spark Connect:
 
 ### Generate proto generated files for the Python client
 1. Install `buf version 1.11.0`: https://docs.buf.build/installation
-2. Run `pip install grpcio==1.48.1 protobuf==3.19.5 mypy-protobuf==3.3.0`
+2. Run `pip install grpcio==1.48.1 protobuf==3.19.5 mypy-protobuf==3.3.0 
googleapis-common-protos==1.56.4 grpcio-status==1.48.1`
 3. Run `./connector/connect/dev/generate_protos.sh`
 4. Optional Check `./dev/check-codegen-python.py`
 
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
index bfcea3d2252..61f035630f7 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
@@ -20,19 +20,23 @@ package org.apache.spark.sql.connect.service
 import java.util.concurrent.TimeUnit
 
 import scala.collection.JavaConverters._
+import scala.util.control.NonFatal
 
 import com.google.common.base.Ticker
 import com.google.common.cache.CacheBuilder
+import com.google.protobuf.{Any => ProtoAny}
+import com.google.rpc.{Code => RPCCode, ErrorInfo, Status => RPCStatus}
 import io.grpc.{Server, Status}
 import io.grpc.netty.NettyServerBuilder
+import io.grpc.protobuf.StatusProto
 import io.grpc.protobuf.services.ProtoReflectionService
 import io.grpc.stub.StreamObserver
 
-import org.apache.spark.SparkEnv
+import org.apache.spark.{SparkEnv, SparkThrowable}
 import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.{AnalyzePlanRequest, 
AnalyzePlanResponse, ExecutePlanRequest, ExecutePlanResponse, 
SparkConnectServiceGrpc}
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{Dataset, SparkSession}
+import org.apache.spark.sql.{AnalysisException, Dataset, SparkSession}
 import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_BINDING_PORT
 import org.apache.spark.sql.connect.planner.{DataTypeProtoConverter, 
SparkConnectPlanner}
 import org.apache.spark.sql.execution.{CodegenMode, CostMode, ExplainMode, 
ExtendedMode, FormattedMode, SimpleMode}
@@ -49,6 +53,71 @@ class SparkConnectService(debug: Boolean)
     extends SparkConnectServiceGrpc.SparkConnectServiceImplBase
     with Logging {
 
+  private def buildStatusFromThrowable[A <: Throwable with SparkThrowable](st: 
A): RPCStatus = {
+    val t = Option(st.getCause).getOrElse(st)
+    RPCStatus
+      .newBuilder()
+      .setCode(RPCCode.INTERNAL_VALUE)
+      .addDetails(
+        ProtoAny.pack(
+          ErrorInfo
+            .newBuilder()
+            .setReason(t.getClass.getName)
+            .setDomain("org.apache.spark")
+            .build()))
+      .setMessage(t.getLocalizedMessage)
+      .build()
+  }
+
+  /**
+   * Common exception handling function for the Analysis and Execution 
methods. Closes the stream
+   * after the error has been sent.
+   *
+   * @param opType
+   *   String value indicating the operation type (analysis, execution)
+   * @param observer
+   *   The GRPC response observer.
+   * @tparam V
+   * @return
+   */
+  private def handleError[V](
+      opType: String,
+      observer: StreamObserver[V]): PartialFunction[Throwable, Unit] = {
+    case ae: AnalysisException =>
+      logError(s"Error during: $opType", ae)
+      val status = RPCStatus
+        .newBuilder()
+        .setCode(RPCCode.INTERNAL_VALUE)
+        .addDetails(
+          ProtoAny.pack(
+            ErrorInfo
+              .newBuilder()
+              .setReason(ae.getClass.getName)
+              .setDomain("org.apache.spark")
+              .putMetadata("message", ae.getSimpleMessage)
+              .putMetadata("plan", Option(ae.plan).flatten.map(p => 
s"$p").getOrElse(""))
+              .build()))
+        .setMessage(ae.getLocalizedMessage)
+        .build()
+      observer.onError(StatusProto.toStatusRuntimeException(status))
+    case st: SparkThrowable =>
+      logError(s"Error during: $opType", st)
+      val status = buildStatusFromThrowable(st)
+      observer.onError(StatusProto.toStatusRuntimeException(status))
+    case NonFatal(nf) =>
+      logError(s"Error during: $opType", nf)
+      val status = RPCStatus
+        .newBuilder()
+        .setCode(RPCCode.INTERNAL_VALUE)
+        .setMessage(nf.getLocalizedMessage)
+        .build()
+      observer.onError(StatusProto.toStatusRuntimeException(status))
+    case e: Throwable =>
+      logError(s"Error during: $opType", e)
+      observer.onError(
+        
Status.UNKNOWN.withCause(e).withDescription(e.getLocalizedMessage).asRuntimeException())
+  }
+
   /**
    * This is the main entry method for Spark Connect and all calls to execute 
a plan.
    *
@@ -64,12 +133,7 @@ class SparkConnectService(debug: Boolean)
       responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
     try {
       new SparkConnectStreamHandler(responseObserver).handle(request)
-    } catch {
-      case e: Throwable =>
-        log.error("Error executing plan.", e)
-        responseObserver.onError(
-          
Status.UNKNOWN.withCause(e).withDescription(e.getLocalizedMessage).asRuntimeException())
-    }
+    } catch handleError("execute", observer = responseObserver)
   }
 
   /**
@@ -114,12 +178,7 @@ class SparkConnectService(debug: Boolean)
       response.setClientId(request.getClientId)
       responseObserver.onNext(response.build())
       responseObserver.onCompleted()
-    } catch {
-      case e: Throwable =>
-        log.error("Error analyzing plan.", e)
-        responseObserver.onError(
-          
Status.UNKNOWN.withCause(e).withDescription(e.getLocalizedMessage).asRuntimeException())
-    }
+    } catch handleError("analyze", observer = responseObserver)
   }
 
   def handleAnalyzePlanRequest(
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
index 9631b93f6e9..9c1a8ca4dc4 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
@@ -18,7 +18,6 @@
 package org.apache.spark.sql.connect.service
 
 import scala.collection.JavaConverters._
-import scala.util.control.NonFatal
 
 import com.google.protobuf.ByteString
 import io.grpc.stub.StreamObserver
@@ -128,12 +127,8 @@ class SparkConnectStreamHandler(responseObserver: 
StreamObserver[ExecutePlanResp
             }
             partitions(currentPartitionId) = null
 
-            error.foreach {
-              case NonFatal(e) =>
-                responseObserver.onError(e)
-                logError("Error while processing query.", e)
-                return
-              case other => throw other
+            error.foreach { case other =>
+              throw other
             }
             part
           }
diff --git 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
index 6dcce0926dc..9c5df253aea 100644
--- 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
+++ 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
@@ -18,12 +18,12 @@ package org.apache.spark.sql.connect.planner
 
 import scala.collection.mutable
 
+import io.grpc.StatusRuntimeException
 import io.grpc.stub.StreamObserver
 import org.apache.arrow.memory.RootAllocator
 import org.apache.arrow.vector.{BigIntVector, Float8Vector}
 import org.apache.arrow.vector.ipc.ArrowStreamReader
 
-import org.apache.spark.SparkException
 import org.apache.spark.connect.proto
 import org.apache.spark.sql.connect.dsl.MockRemoteSession
 import org.apache.spark.sql.connect.dsl.plans._
@@ -185,7 +185,7 @@ class SparkConnectServiceSuite extends SharedSparkSession {
         }
 
         override def onError(throwable: Throwable): Unit = {
-          assert(throwable.isInstanceOf[SparkException])
+          assert(throwable.isInstanceOf[StatusRuntimeException])
         }
 
         override def onCompleted(): Unit = {
diff --git a/dev/create-release/spark-rm/Dockerfile 
b/dev/create-release/spark-rm/Dockerfile
index c65a0e1c759..38c64601882 100644
--- a/dev/create-release/spark-rm/Dockerfile
+++ b/dev/create-release/spark-rm/Dockerfile
@@ -42,7 +42,7 @@ ARG APT_INSTALL="apt-get install --no-install-recommends -y"
 #   We should use the latest Sphinx version once this is fixed.
 # TODO(SPARK-35375): Jinja2 3.0.0+ causes error when building with Sphinx.
 #   See also https://issues.apache.org/jira/browse/SPARK-35375.
-ARG PIP_PKGS="sphinx==3.0.4 mkdocs==1.1.2 numpy==1.19.4 
pydata_sphinx_theme==0.4.1 ipython==7.19.0 nbsphinx==0.8.0 numpydoc==1.1.0 
jinja2==2.11.3 twine==3.4.1 sphinx-plotly-directive==0.1.3 pandas==1.1.5 
pyarrow==3.0.0 plotly==5.4.0 markupsafe==2.0.1 docutils<0.17 grpcio==1.48.1 
protobuf==4.21.6"
+ARG PIP_PKGS="sphinx==3.0.4 mkdocs==1.1.2 numpy==1.19.4 
pydata_sphinx_theme==0.4.1 ipython==7.19.0 nbsphinx==0.8.0 numpydoc==1.1.0 
jinja2==2.11.3 twine==3.4.1 sphinx-plotly-directive==0.1.3 pandas==1.1.5 
pyarrow==3.0.0 plotly==5.4.0 markupsafe==2.0.1 docutils<0.17 grpcio==1.48.1 
protobuf==4.21.6 grpcio-status==1.48.1 googleapis-common-protos==1.56.4"
 ARG GEM_PKGS="bundler:2.2.9"
 
 # Install extra needed repos and refresh.
diff --git a/dev/infra/Dockerfile b/dev/infra/Dockerfile
index 1c326d437c7..92cb75360c1 100644
--- a/dev/infra/Dockerfile
+++ b/dev/infra/Dockerfile
@@ -68,4 +68,4 @@ RUN pypy3 -m pip install numpy 'pandas<=1.5.2' scipy coverage 
matplotlib
 RUN python3.9 -m pip install numpy pyarrow 'pandas<=1.5.2' scipy 
unittest-xml-reporting plotly>=4.8 sklearn 'mlflow>=1.0' coverage matplotlib 
openpyxl 'memory-profiler==0.60.0'
 
 # Add Python deps for Spark Connect.
-RUN python3.9 -m pip install grpcio protobuf
+RUN python3.9 -m pip install grpcio protobuf googleapis-common-protos 
grpcio-status
diff --git a/dev/lint-python b/dev/lint-python
index 806b7572dc6..59ce71980d9 100755
--- a/dev/lint-python
+++ b/dev/lint-python
@@ -69,6 +69,7 @@ function mypy_annotation_test {
 
     echo "starting mypy annotations test..."
     MYPY_REPORT=$( ($MYPY_BUILD \
+      --namespace-packages \
       --config-file python/mypy.ini \
       --cache-dir /tmp/.mypy_cache/ \
       python/pyspark) 2>&1)
@@ -127,6 +128,7 @@ function mypy_examples_test {
     echo "starting mypy examples test..."
 
     MYPY_REPORT=$( (MYPYPATH=python $MYPY_BUILD \
+      --namespace-packages \
       --config-file python/mypy.ini \
       --exclude "mllib/*" \
       examples/src/main/python/) 2>&1)
diff --git a/dev/requirements.txt b/dev/requirements.txt
index f91e2fed713..c3911b57eb9 100644
--- a/dev/requirements.txt
+++ b/dev/requirements.txt
@@ -50,7 +50,11 @@ black==22.6.0
 
 # Spark Connect (required)
 grpcio==1.48.1
+grpcio-status==1.48.1
 protobuf==3.19.5
+googleapis-common-protos==1.56.4
 
 # Spark Connect python proto generation plugin (optional)
 mypy-protobuf==3.3.0
+googleapis-common-protos-stubs==2.2.0
+grpc-stubs==1.24.11
diff --git a/python/docs/source/getting_started/install.rst 
b/python/docs/source/getting_started/install.rst
index d3b24be3d49..eddee8e30e1 100644
--- a/python/docs/source/getting_started/install.rst
+++ b/python/docs/source/getting_started/install.rst
@@ -153,15 +153,17 @@ To install PySpark from source, refer to 
|building_spark|_.
 
 Dependencies
 ------------
-============= ========================= 
======================================================================================
-Package       Minimum supported version Note
-============= ========================= 
======================================================================================
-`py4j`        0.10.9.7                  Required
-`pandas`      1.0.5                     Required for pandas API on Spark and 
Spark Connect; Optional for Spark SQL
-`pyarrow`     1.0.0                     Required for pandas API on Spark and 
Spark Connect; Optional for Spark SQL
-`numpy`       1.15                      Required for pandas API on Spark and 
MLLib DataFrame-based API; Optional for Spark SQL
-`grpc`        1.48.1                    Required for Spark Connect
-============= ========================= 
======================================================================================
+========================== ========================= 
======================================================================================
+Package                    Minimum supported version Note
+========================== ========================= 
======================================================================================
+`py4j`                     0.10.9.7                  Required
+`pandas`                   1.0.5                     Required for pandas API 
on Spark and Spark Connect; Optional for Spark SQL
+`pyarrow`                  1.0.0                     Required for pandas API 
on Spark and Spark Connect; Optional for Spark SQL
+`numpy`                    1.15                      Required for pandas API 
on Spark and MLLib DataFrame-based API; Optional for Spark SQL
+`grpc`                     1.48.1                    Required for Spark Connect
+`grpcio-status`            1.48.1                    Required for Spark Connect
+`googleapis-common-protos` 1.56.4                    Required for Spark Connect
+========================== ========================= 
======================================================================================
 
 Note that PySpark requires Java 8 or later with ``JAVA_HOME`` properly set.  
 If using JDK 11, set ``-Dio.netty.tryReflectionSetAccessible=true`` for Arrow 
related features and refer
diff --git a/python/mypy.ini b/python/mypy.ini
index 603647bd3cd..dd1c1cd4875 100644
--- a/python/mypy.ini
+++ b/python/mypy.ini
@@ -22,6 +22,7 @@ disallow_untyped_defs = True
 show_error_codes = True
 warn_unused_ignores = True
 warn_redundant_casts = True
+namespace_packages = True
 
 [mypy-pyspark.sql.connect.proto.*]
 ignore_errors = True
diff --git a/python/pyspark/sql/connect/client.py 
b/python/pyspark/sql/connect/client.py
index e258dbd92b4..e78c4de0f70 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -15,14 +15,19 @@
 # limitations under the License.
 #
 
+import logging
 import os
 import urllib.parse
 import uuid
-from typing import Iterable, Optional, Any, Union, List, Tuple, Dict
+from typing import Iterable, Optional, Any, Union, List, Tuple, Dict, 
NoReturn, cast
 
+import google.protobuf.message
+from grpc_status import rpc_status
 import grpc
 import pandas
+from google.protobuf import text_format
 import pyarrow as pa
+from google.rpc import error_details_pb2
 
 import pyspark.sql.connect.proto as pb2
 import pyspark.sql.connect.proto.base_pb2_grpc as grpc_lib
@@ -36,6 +41,50 @@ from pyspark.sql.types import (
 )
 
 
+def _configure_logging() -> logging.Logger:
+    """Configure logging for the Spark Connect clients."""
+    logger = logging.getLogger(__name__)
+    handler = logging.StreamHandler()
+    handler.setFormatter(
+        logging.Formatter(fmt="%(asctime)s %(process)d %(levelname)s 
%(funcName)s %(message)s")
+    )
+    logger.addHandler(handler)
+
+    # Check the environment variables for log levels:
+    if "SPARK_CONNECT_LOG_LEVEL" in os.environ:
+        logger.setLevel(os.getenv("SPARK_CONNECT_LOG_LEVEL", "error").upper())
+    else:
+        logger.disabled = True
+    return logger
+
+
+# Instantiate the logger based on the environment configuration.
+logger = _configure_logging()
+
+
+class SparkConnectException(Exception):
+    def __init__(self, message: str, reason: Optional[str] = None) -> None:
+        super(SparkConnectException, self).__init__(message)
+        self._reason = reason
+        self._message = message
+
+    def __str__(self) -> str:
+        if self._reason is not None:
+            return f"({self._reason}) {self._message}"
+        else:
+            return self._message
+
+
+class SparkConnectAnalysisException(SparkConnectException):
+    def __init__(self, reason: str, message: str, plan: str) -> None:
+        self._reason = reason
+        self._message = message
+        self._plan = plan
+
+    def __str__(self) -> str:
+        return f"{self._message}\nPlan: {self._plan}"
+
+
 class ChannelBuilder:
     """
     This is a helper class that is used to create a GRPC channel based on the 
given
@@ -60,7 +109,18 @@ class ChannelBuilder:
 
     DEFAULT_PORT = 15002
 
-    def __init__(self, url: str) -> None:
+    def __init__(self, url: str, channelOptions: Optional[List[Tuple[str, 
Any]]] = None) -> None:
+        """
+        Constructs a new channel builder. This is used to create the proper 
GRPC channel from
+        the connection string.
+
+        Parameters
+        ----------
+        url : str
+            Spark Connect connection string
+        channelOptions: list of tuple, optional
+            Additional options that can be passed to the GRPC channel 
construction.
+        """
         # Explicitly check the scheme of the URL.
         if url[:5] != "sc://":
             raise AttributeError("URL scheme must be set to `sc`.")
@@ -74,6 +134,7 @@ class ChannelBuilder:
                 f"Path component for connection URI must be empty: 
{self.url.path}"
             )
         self._extract_attributes()
+        self._channel_options = channelOptions
 
     def _extract_attributes(self) -> None:
         if len(self.url.params) > 0:
@@ -159,7 +220,8 @@ class ChannelBuilder:
     def toChannel(self) -> grpc.Channel:
         """
         Applies the parameters of the connection string and creates a new
-        GRPC channel according to the configuration.
+        GRPC channel according to the configuration. Passes optional channel 
options to
+        construct the channel.
 
         Returns
         -------
@@ -176,7 +238,7 @@ class ChannelBuilder:
             use_secure = False
 
         if not use_secure:
-            return grpc.insecure_channel(destination)
+            return grpc.insecure_channel(destination, 
options=self._channel_options)
         else:
             # Default SSL Credentials.
             opt_token = self.params.get(ChannelBuilder.PARAM_TOKEN, None)
@@ -186,9 +248,15 @@ class ChannelBuilder:
                 composite_creds = grpc.composite_channel_credentials(
                     ssl_creds, grpc.access_token_call_credentials(opt_token)
                 )
-                return grpc.secure_channel(destination, 
credentials=composite_creds)
+                return grpc.secure_channel(
+                    destination, credentials=composite_creds, 
options=self._channel_options
+                )
             else:
-                return grpc.secure_channel(destination, 
credentials=grpc.ssl_channel_credentials())
+                return grpc.secure_channel(
+                    destination,
+                    credentials=grpc.ssl_channel_credentials(),
+                    options=self._channel_options,
+                )
 
 
 class MetricValue:
@@ -272,7 +340,12 @@ class AnalyzeResult:
 class SparkConnectClient(object):
     """Conceptually the remote spark session that communicates with the 
server"""
 
-    def __init__(self, connectionString: str, userId: Optional[str] = None):
+    def __init__(
+        self,
+        connectionString: str,
+        userId: Optional[str] = None,
+        channelOptions: Optional[List[Tuple[str, Any]]] = None,
+    ):
         """
         Creates a new SparkSession for the Spark Connect interface.
 
@@ -288,7 +361,7 @@ class SparkConnectClient(object):
             takes precedence.
         """
         # Parse the connection string.
-        self._builder = ChannelBuilder(connectionString)
+        self._builder = ChannelBuilder(connectionString, channelOptions)
         self._user_id = None
         # Generate a unique session ID for this client. This UUID must be 
unique to allow
         # concurrent Spark sessions of the same user. If the channel is 
closed, creating
@@ -303,6 +376,7 @@ class SparkConnectClient(object):
 
         self._channel = self._builder.toChannel()
         self._stub = grpc_lib.SparkConnectServiceStub(self._channel)
+        # Configure logging for the SparkConnect client.
 
     def register_udf(
         self, function: Any, return_type: Union[str, 
pyspark.sql.types.DataType]
@@ -312,6 +386,7 @@ class SparkConnectClient(object):
         name = f"fun_{uuid.uuid4().hex}"
         fun = pb2.CreateScalarFunction()
         fun.parts.append(name)
+        logger.info(f"Registering UDF: {self._proto_to_string(fun)}")
         fun.serialized_function = cloudpickle.dumps((function, return_type))
 
         req = self._execute_plan_request_with_metadata()
@@ -331,7 +406,8 @@ class SparkConnectClient(object):
             for x in metrics.metrics
         ]
 
-    def _to_pandas(self, plan: pb2.Plan) -> "pandas.DataFrame":
+    def to_pandas(self, plan: pb2.Plan) -> "pandas.DataFrame":
+        logger.info(f"Executing plan {self._proto_to_string(plan)}")
         req = self._execute_plan_request_with_metadata()
         req.plan.CopyFrom(plan)
         return self._execute_and_fetch(req)
@@ -339,7 +415,22 @@ class SparkConnectClient(object):
     def _proto_schema_to_pyspark_schema(self, schema: pb2.DataType) -> 
DataType:
         return types.proto_schema_to_pyspark_data_type(schema)
 
+    def _proto_to_string(self, p: google.protobuf.message.Message) -> str:
+        """
+        Helper method to generate a one line string representation of the plan.
+        Parameters
+        ----------
+        p : google.protobuf.message.Message
+            Generic Message type
+
+        Returns
+        -------
+        Single line string of the serialized proto message.
+        """
+        return text_format.MessageToString(p, as_one_line=True)
+
     def schema(self, plan: pb2.Plan) -> StructType:
+        logger.info(f"Schema for plan: {self._proto_to_string(plan)}")
         proto_schema = self._analyze(plan).schema
         # Server side should populate the struct field which is the schema.
         assert proto_schema.HasField("struct")
@@ -355,10 +446,12 @@ class SparkConnectClient(object):
         return StructType(fields)
 
     def explain_string(self, plan: pb2.Plan, explain_mode: str = "extended") 
-> str:
+        logger.info(f"Explain (mode={explain_mode}) for plan 
{self._proto_to_string(plan)}")
         result = self._analyze(plan, explain_mode)
         return result.explain_string
 
     def execute_command(self, command: pb2.Command) -> None:
+        logger.info(f"Execute command for command 
{self._proto_to_string(command)}")
         req = self._execute_plan_request_with_metadata()
         if self._user_id:
             req.user_context.user_id = self._user_id
@@ -386,6 +479,20 @@ class SparkConnectClient(object):
         return req
 
     def _analyze(self, plan: pb2.Plan, explain_mode: str = "extended") -> 
AnalyzeResult:
+        """
+        Call the analyze RPC of Spark Connect.
+
+        Parameters
+        ----------
+        plan : :class:`pyspark.sql.connect.proto.Plan`
+           Proto representation of the plan.
+        explain_mode : str
+           Explain mode
+
+        Returns
+        -------
+        The result of the analyze call.
+        """
         req = self._analyze_plan_request_with_metadata()
         req.plan.CopyFrom(plan)
         if explain_mode not in ["simple", "extended", "codegen", "cost", 
"formatted"]:
@@ -406,36 +513,64 @@ class SparkConnectClient(object):
         else:  # formatted
             req.explain.explain_mode = pb2.Explain.ExplainMode.FORMATTED
 
-        resp = self._stub.AnalyzePlan(req, metadata=self._builder.metadata())
-        if resp.client_id != self._session_id:
-            raise ValueError("Received incorrect session identifier for 
request.")
-        return AnalyzeResult.fromProto(resp)
+        try:
+            resp = self._stub.AnalyzePlan(req, 
metadata=self._builder.metadata())
+            if resp.client_id != self._session_id:
+                raise SparkConnectException("Received incorrect session 
identifier for request.")
+            return AnalyzeResult.fromProto(resp)
+        except grpc.RpcError as rpc_error:
+            self._handle_error(rpc_error)
 
     def _process_batch(self, arrow_batch: pb2.ExecutePlanResponse.ArrowBatch) 
-> "pandas.DataFrame":
         with pa.ipc.open_stream(arrow_batch.data) as rd:
             return rd.read_pandas()
 
     def _execute(self, req: pb2.ExecutePlanRequest) -> None:
-        for b in self._stub.ExecutePlan(req, 
metadata=self._builder.metadata()):
-            if b.client_id != self._session_id:
-                raise ValueError("Received incorrect session identifier for 
request.")
-            continue
-        return
+        """
+        Execute the passed request `req` and drop all results.
+
+        Parameters
+        ----------
+        req : pb2.ExecutePlanRequest
+            Proto representation of the plan.
+
+        """
+        logger.info("Execute")
+        try:
+            for b in self._stub.ExecutePlan(req, 
metadata=self._builder.metadata()):
+                if b.client_id != self._session_id:
+                    raise SparkConnectException(
+                        "Received incorrect session identifier for request."
+                    )
+                continue
+        except grpc.RpcError as rpc_error:
+            self._handle_error(rpc_error)
 
     def _execute_and_fetch(self, req: pb2.ExecutePlanRequest) -> 
"pandas.DataFrame":
+        logger.info("ExecuteAndFetch")
         import pandas as pd
 
         m: Optional[pb2.ExecutePlanResponse.Metrics] = None
         result_dfs = []
 
-        for b in self._stub.ExecutePlan(req, 
metadata=self._builder.metadata()):
-            if b.client_id != self._session_id:
-                raise ValueError("Received incorrect session identifier for 
request.")
-            if b.metrics is not None:
-                m = b.metrics
-            if b.HasField("arrow_batch"):
-                pb = self._process_batch(b.arrow_batch)
-                result_dfs.append(pb)
+        try:
+            for b in self._stub.ExecutePlan(req, 
metadata=self._builder.metadata()):
+                if b.client_id != self._session_id:
+                    raise SparkConnectException(
+                        "Received incorrect session identifier for request."
+                    )
+                if b.metrics is not None:
+                    logger.debug("Received metric batch.")
+                    m = b.metrics
+                if b.HasField("arrow_batch"):
+                    logger.debug(
+                        f"Received arrow batch rows={b.arrow_batch.row_count} "
+                        f"size={len(b.arrow_batch.data)}"
+                    )
+                    pb = self._process_batch(b.arrow_batch)
+                    result_dfs.append(pb)
+        except grpc.RpcError as rpc_error:
+            self._handle_error(rpc_error)
 
         assert len(result_dfs) > 0
 
@@ -451,3 +586,50 @@ class SparkConnectClient(object):
         if m is not None:
             df.attrs["metrics"] = self._build_metrics(m)
         return df
+
+    def _handle_error(self, rpc_error: grpc.RpcError) -> NoReturn:
+        """
+        Error handling helper for dealing with GRPC Errors. On the server 
side, certain
+        exceptions are enriched with additional RPC Status information. These 
are
+        unpacked in this function and put into the exception.
+
+        To avoid overloading the user with GRPC errors, this message explicitly
+        swallows the error context from the call. This GRPC Error is logged 
however,
+        and can be enabled.
+
+        Parameters
+        ----------
+        rpc_error : grpc.RpcError
+           RPC Error containing the details of the exception.
+
+        Returns
+        -------
+        Throws the appropriate internal Python exception.
+        """
+        logger.exception("GRPC Error received")
+        # We have to cast the value here because, a RpcError is a Call as well.
+        # 
https://grpc.github.io/grpc/python/grpc.html#grpc.UnaryUnaryMultiCallable.__call__
+        status = rpc_status.from_call(cast(grpc.Call, rpc_error))
+        if status:
+            for d in status.details:
+                if d.Is(error_details_pb2.ErrorInfo.DESCRIPTOR):
+                    info = error_details_pb2.ErrorInfo()
+                    d.Unpack(info)
+                    if info.reason == "org.apache.spark.sql.AnalysisException":
+                        raise SparkConnectAnalysisException(
+                            info.reason, info.metadata["message"], 
info.metadata["plan"]
+                        ) from None
+                    else:
+                        raise SparkConnectException(status.message, 
info.reason) from None
+
+            raise SparkConnectException(status.message) from None
+        else:
+            raise SparkConnectException(str(rpc_error)) from None
+
+
+__all__ = [
+    "ChannelBuilder",
+    "SparkConnectClient",
+    "SparkConnectException",
+    "SparkConnectAnalysisException",
+]
diff --git a/python/pyspark/sql/connect/dataframe.py 
b/python/pyspark/sql/connect/dataframe.py
index 313d46ceb2f..08db6b61871 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -907,7 +907,7 @@ class DataFrame:
         if self._session is None:
             raise Exception("Cannot collect on empty session.")
         query = self._plan.to_proto(self._session.client)
-        return self._session.client._to_pandas(query)
+        return self._session.client.to_pandas(query)
 
     toPandas.__doc__ = PySparkDataFrame.toPandas.__doc__
 
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py 
b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 302a5a96010..da6d5afd1cd 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -33,6 +33,7 @@ import pyspark.sql.functions
 from pyspark.testing.utils import ReusedPySparkTestCase
 from pyspark.testing.connectutils import should_test_connect, 
connect_requirement_message
 from pyspark.testing.pandasutils import PandasOnSparkTestCase
+from pyspark.sql.connect.client import SparkConnectException, 
SparkConnectAnalysisException
 
 if should_test_connect:
     import grpc
@@ -116,6 +117,12 @@ class SparkConnectSQLTestCase(PandasOnSparkTestCase, 
ReusedPySparkTestCase, SQLT
 
 
 class SparkConnectTests(SparkConnectSQLTestCase):
+    def test_error_handling(self):
+        # SPARK-41533 Proper error handling for Spark Connect
+        df = self.connect.range(10).select("id2")
+        with self.assertRaises(SparkConnectAnalysisException):
+            df.collect()
+
     def test_simple_read(self):
         df = self.connect.read.table(self.tbl_name)
         data = df.limit(10).toPandas()
@@ -262,12 +269,12 @@ class SparkConnectTests(SparkConnectSQLTestCase):
         ):
             self.connect.createDataFrame(data, ["a", "b", "c", "d", "e"])
 
-        with self.assertRaises(grpc.RpcError):
+        with self.assertRaises(SparkConnectException):
             self.connect.createDataFrame(
                 data, "col1 magic_type, col2 int, col3 int, col4 int"
             ).show()
 
-        with self.assertRaises(grpc.RpcError):
+        with self.assertRaises(SparkConnectException):
             self.connect.createDataFrame(data, "col1 int, col2 int, col3 
int").show()
 
     def test_with_local_list(self):
@@ -299,12 +306,12 @@ class SparkConnectTests(SparkConnectSQLTestCase):
         ):
             self.connect.createDataFrame(data, ["a", "b", "c", "d", "e"])
 
-        with self.assertRaises(grpc.RpcError):
+        with self.assertRaises(SparkConnectException):
             self.connect.createDataFrame(
                 data, "col1 magic_type, col2 int, col3 int, col4 int"
             ).show()
 
-        with self.assertRaises(grpc.RpcError):
+        with self.assertRaises(SparkConnectException):
             self.connect.createDataFrame(data, "col1 int, col2 int, col3 
int").show()
 
     def test_with_atom_type(self):
@@ -457,7 +464,7 @@ class SparkConnectTests(SparkConnectSQLTestCase):
             ]
         )
 
-        with self.assertRaises(grpc.RpcError) as context:
+        with self.assertRaises(SparkConnectException) as context:
             self.connect.read.table(self.tbl_name).to(schema).toPandas()
             self.assertIn(
                 """Column or field `name` is of type "STRING" while it's 
required to be "INT".""",
@@ -679,7 +686,7 @@ class SparkConnectTests(SparkConnectSQLTestCase):
 
             # Test when creating a view which is already exists but
             
self.assertTrue(self.spark.catalog.tableExists("global_temp.view_1"))
-            with self.assertRaises(grpc.RpcError):
+            with self.assertRaises(SparkConnectException):
                 self.connect.sql("SELECT 1 AS X LIMIT 
0").createGlobalTempView("view_1")
 
     def test_create_session_local_temp_view(self):
@@ -691,7 +698,7 @@ class SparkConnectTests(SparkConnectSQLTestCase):
             self.assertEqual(self.connect.sql("SELECT * FROM 
view_local_temp").count(), 0)
 
             # Test when creating a view which is already exists but
-            with self.assertRaises(grpc.RpcError):
+            with self.assertRaises(SparkConnectException):
                 self.connect.sql("SELECT 1 AS X LIMIT 
0").createTempView("view_local_temp")
 
     def test_to_pandas(self):
@@ -876,7 +883,7 @@ class SparkConnectTests(SparkConnectSQLTestCase):
             self.connect.sql(query).replace({None: 1}, subset="a").toPandas()
             self.assertTrue("Mixed type replacements are not supported" in 
str(context.exception))
 
-        with self.assertRaises(grpc.RpcError) as context:
+        with self.assertRaises(SparkConnectException) as context:
             self.connect.sql(query).replace({1: 2, 3: -1}, subset=("a", 
"x")).toPandas()
             self.assertIn(
                 """Cannot resolve column name "x" among (a, b, c)""", 
str(context.exception)
@@ -957,7 +964,7 @@ class SparkConnectTests(SparkConnectSQLTestCase):
         )
 
         # Hint with unsupported parameter values
-        with self.assertRaises(grpc.RpcError):
+        with self.assertRaises(SparkConnectException):
             self.connect.read.table(self.tbl_name).hint("REPARTITION", 
"id+1").toPandas()
 
         # Hint with unsupported parameter types
@@ -965,7 +972,7 @@ class SparkConnectTests(SparkConnectSQLTestCase):
             self.connect.read.table(self.tbl_name).hint("REPARTITION", 
1.1).toPandas()
 
         # Hint with wrong combination
-        with self.assertRaises(grpc.RpcError):
+        with self.assertRaises(SparkConnectException):
             self.connect.read.table(self.tbl_name).hint("REPARTITION", "id", 
3).toPandas()
 
     def test_empty_dataset(self):
@@ -1084,7 +1091,7 @@ class SparkConnectTests(SparkConnectSQLTestCase):
         )
         self.assertEqual("name", col0)
 
-        with self.assertRaises(grpc.RpcError) as exc:
+        with self.assertRaises(SparkConnectException) as exc:
             self.connect.range(1, 10).select(col("id").alias("this", "is", 
"not")).collect()
         self.assertIn("(this, is, not)", str(exc.exception))
 
diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py 
b/python/pyspark/sql/tests/connect/test_connect_function.py
index c9d770f1399..edf58947712 100644
--- a/python/pyspark/sql/tests/connect/test_connect_function.py
+++ b/python/pyspark/sql/tests/connect/test_connect_function.py
@@ -22,9 +22,9 @@ from pyspark.testing.pandasutils import PandasOnSparkTestCase
 from pyspark.testing.connectutils import should_test_connect, 
connect_requirement_message
 from pyspark.testing.utils import ReusedPySparkTestCase
 from pyspark.testing.sqlutils import SQLTestUtils
+from pyspark.sql.connect.client import SparkConnectException
 
 if should_test_connect:
-    import grpc
     from pyspark.sql.connect.session import SparkSession as RemoteSparkSession
 
 
@@ -818,7 +818,7 @@ class SparkConnectFunctionTests(SparkConnectFuncTestCase):
             cdf.select(CF.rank().over(cdf.a))
 
         # invalid window function
-        with self.assertRaises(grpc.RpcError):
+        with self.assertRaises(SparkConnectException):
             cdf.select(cdf.b.over(CW.orderBy("b"))).show()
 
         # invalid window frame
@@ -832,34 +832,34 @@ class SparkConnectFunctionTests(SparkConnectFuncTestCase):
             CF.lead("c", 1),
             CF.ntile(1),
         ]:
-            with self.assertRaises(grpc.RpcError):
+            with self.assertRaises(SparkConnectException):
                 cdf.select(
                     ccol.over(CW.orderBy("b").rowsBetween(CW.currentRow, 
CW.currentRow + 123))
                 ).show()
 
-            with self.assertRaises(grpc.RpcError):
+            with self.assertRaises(SparkConnectException):
                 cdf.select(
                     ccol.over(CW.orderBy("b").rangeBetween(CW.currentRow, 
CW.currentRow + 123))
                 ).show()
 
-            with self.assertRaises(grpc.RpcError):
+            with self.assertRaises(SparkConnectException):
                 cdf.select(
                     
ccol.over(CW.orderBy("b").rangeBetween(CW.unboundedPreceding, CW.currentRow))
                 ).show()
 
         # Function 'cume_dist' requires Windowframe(RangeFrame, 
UnboundedPreceding, CurrentRow)
         ccol = CF.cume_dist()
-        with self.assertRaises(grpc.RpcError):
+        with self.assertRaises(SparkConnectException):
             cdf.select(
                 ccol.over(CW.orderBy("b").rangeBetween(CW.currentRow, 
CW.currentRow + 123))
             ).show()
 
-        with self.assertRaises(grpc.RpcError):
+        with self.assertRaises(SparkConnectException):
             cdf.select(
                 ccol.over(CW.orderBy("b").rowsBetween(CW.currentRow, 
CW.currentRow + 123))
             ).show()
 
-        with self.assertRaises(grpc.RpcError):
+        with self.assertRaises(SparkConnectException):
             cdf.select(
                 ccol.over(CW.orderBy("b").rowsBetween(CW.unboundedPreceding, 
CW.currentRow))
             ).show()
@@ -1964,11 +1964,11 @@ class 
SparkConnectFunctionTests(SparkConnectFuncTestCase):
         sdf = self.spark.sql(query)
 
         # test assert_true
-        with self.assertRaises(grpc.RpcError):
+        with self.assertRaises(SparkConnectException):
             cdf.select(CF.assert_true(cdf.a > 0, "a should be 
positive!")).show()
 
         # test raise_error
-        with self.assertRaises(grpc.RpcError):
+        with self.assertRaises(SparkConnectException):
             cdf.select(CF.raise_error("a should be positive!")).show()
 
         # test crc32
diff --git a/python/pyspark/testing/connectutils.py 
b/python/pyspark/testing/connectutils.py
index dcbc09f2210..bec116b5f79 100644
--- a/python/pyspark/testing/connectutils.py
+++ b/python/pyspark/testing/connectutils.py
@@ -34,8 +34,29 @@ except ImportError as e:
     grpc_requirement_message = str(e)
 have_grpc = grpc_requirement_message is None
 
+
+grpc_status_requirement_message = None
+try:
+    import grpc_status
+except ImportError as e:
+    grpc_status_requirement_message = str(e)
+have_grpc_status = grpc_status_requirement_message is None
+
+googleapis_common_protos_requirement_message = None
+try:
+    from google.rpc import error_details_pb2
+except ImportError as e:
+    googleapis_common_protos_requirement_message = str(e)
+have_googleapis_common_protos = googleapis_common_protos_requirement_message 
is None
+
 connect_not_compiled_message = None
-if have_pandas and have_pyarrow and have_grpc:
+if (
+    have_pandas
+    and have_pyarrow
+    and have_grpc
+    and have_grpc_status
+    and have_googleapis_common_protos
+):
     from pyspark.sql.connect import DataFrame
     from pyspark.sql.connect.plan import Read, Range, SQL
     from pyspark.testing.utils import search_jar
@@ -62,6 +83,8 @@ connect_requirement_message = (
     or pyarrow_requirement_message
     or grpc_requirement_message
     or connect_not_compiled_message
+    or googleapis_common_protos_requirement_message
+    or grpc_status_requirement_message
 )
 should_test_connect: str = typing.cast(str, connect_requirement_message is 
None)
 
diff --git a/python/setup.py b/python/setup.py
index 4ba2740246a..54115359a60 100755
--- a/python/setup.py
+++ b/python/setup.py
@@ -114,6 +114,7 @@ if (in_spark):
 _minimum_pandas_version = "1.0.5"
 _minimum_pyarrow_version = "1.0.0"
 _minimum_grpc_version = "1.48.1"
+_minimum_googleapis_common_protos_version = "1.56.4"
 
 
 class InstallCommand(install):
@@ -280,7 +281,9 @@ try:
             'connect': [
                 'pandas>=%s' % _minimum_pandas_version,
                 'pyarrow>=%s' % _minimum_pyarrow_version,
-                'grpc>=%s' % _minimum_grpc_version,
+                'grpcio>=%s' % _minimum_grpc_version,
+                'grpcio-status>=%s' % _minimum_grpc_version,
+                'googleapis-common-protos>=%s' % 
_minimum_googleapis_common_protos_version,
                 'numpy>=1.15',
             ],
         },

