This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 7de85bad5916 [SPARK-55314][CONNECT] Propagate observed metrics errors to client
7de85bad5916 is described below
commit 7de85bad591675822177ab4b0e2695fb9f801227
Author: Yihong He <[email protected]>
AuthorDate: Thu Feb 19 08:26:53 2026 +0900
[SPARK-55314][CONNECT] Propagate observed metrics errors to client
### What changes were proposed in this pull request?
Propagate observation metric collection errors to the client in Spark
Connect instead of silently returning empty metrics.
- **Proto:** Add optional `root_error_idx` and repeated `errors` to
`ExecutePlanResponse.ObservedMetrics` so the server can send observation
failures.
- **Python:** Add `convert_observation_errors()` and refactor exception
conversion to support it; in the client, when observed metrics have
`root_error_idx` set, convert and store the error on the Observation; in
`Observation.get`, raise the stored error if present (a simplified sketch
follows this list).
- **Scala/server:** Use `Try[Row]` / `Try[Seq[...]]` for observed metrics
end-to-end; on failure, serialize the throwable via
`ErrorUtils.throwableToProtoErrors` and set `root_error_idx`/`errors` on
ObservedMetrics; in Observation, rethrow the cause from `getRow` so the
original failure is exposed.
- **Tests:** New Python test and updated Scala Connect E2E and DatasetSuite
tests for the new behavior.
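
To make the client-side flow concrete, here is a minimal sketch in plain
Python. The `Error`/`ObservedMetrics` dataclasses and the
`handle_observed_metrics()` helper are illustrative stand-ins for the
generated proto classes and the client internals (the real code returns a
`SparkConnectException`), not the actual API:

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class Error:  # stand-in for FetchErrorDetailsResponse.Error
    message: str


@dataclass
class ObservedMetrics:  # stand-in for ExecutePlanResponse.ObservedMetrics
    name: str
    errors: List[Error] = field(default_factory=list)
    root_error_idx: Optional[int] = None  # set only when collection failed


def convert_observation_errors(root_error_idx: int, errors: List[Error]) -> Exception:
    # Guard against an inconsistent payload from the server.
    if root_error_idx < 0 or root_error_idx >= len(errors):
        return RuntimeError("Observation error: invalid root_error_idx")
    return RuntimeError(errors[root_error_idx].message)


def handle_observed_metrics(metrics: ObservedMetrics, observations: Dict[str, object]) -> None:
    # When root_error_idx is set, store the converted error on the Observation
    # instead of updating its result; Observation.get re-raises it later.
    if metrics.root_error_idx is not None and metrics.name in observations:
        observations[metrics.name]._error = convert_observation_errors(
            metrics.root_error_idx, metrics.errors
        )
```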
### Why are the changes needed?
Previously, when an error occurred during observation metric collection
(SPARK-55150), the error was silently ignored and an empty result was returned.
This was confusing for users since they would get empty metrics without knowing
an error occurred. With this change, the actual error is propagated to the
client so users can understand why their observation failed.
### Does this PR introduce _any_ user-facing change?
Yes. When an observation fails during metric collection, `observation.get`
now raises the underlying exception (e.g. `PySparkException` in Python,
`SparkRuntimeException` in Scala) instead of returning an empty map.
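
For example, a hedged usage sketch (it assumes a Spark Connect session bound
to `spark`):

```python
from pyspark.sql import Observation, functions as F
from pyspark.errors import PySparkException

observation = Observation("test_observation")
df = spark.range(10).observe(
    observation,
    F.sum("id").alias("sum_id"),
    F.raise_error(F.lit("test error")).alias("raise_error"),
)
df.collect()  # the query itself still completes and returns rows 0..9

try:
    observation.get  # previously returned {}; now raises the underlying error
except PySparkException as e:
    assert "test error" in str(e)
```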
### How was this patch tested?
New unit test in Python (`test_observation_errors_propagated_to_client`);
updated Scala Connect E2E test and DatasetSuite test to expect the exception
with message containing `"test error"` instead of empty metrics.
### Was this patch authored or co-authored using generative AI tooling?
Yes
Closes #54094 from heyihong/SPARK-55314.
Authored-by: Yihong He <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/errors/exceptions/connect.py | 107 +++++++---
python/pyspark/sql/connect/client/core.py | 50 +++--
python/pyspark/sql/connect/observation.py | 7 +
python/pyspark/sql/connect/proto/base_pb2.py | 234 ++++++++++-----------
python/pyspark/sql/connect/proto/base_pb2.pyi | 41 +++-
python/pyspark/sql/tests/test_observation.py | 19 ++
.../scala/org/apache/spark/sql/Observation.scala | 19 +-
.../spark/sql/connect/ClientE2ETestSuite.scala | 13 +-
.../src/main/protobuf/spark/connect/base.proto | 6 +
.../apache/spark/sql/connect/SparkSession.scala | 4 +-
.../connect/client/GrpcExceptionConverter.scala | 2 +-
.../spark/sql/connect/client/SparkResult.scala | 36 ++--
.../connect/execution/ExecuteThreadRunner.scala | 15 +-
.../execution/SparkConnectPlanExecution.scala | 38 ++--
.../spark/sql/connect/utils/ErrorUtils.scala | 30 ++-
.../scala/org/apache/spark/sql/DatasetSuite.scala | 11 +-
16 files changed, 395 insertions(+), 237 deletions(-)
diff --git a/python/pyspark/errors/exceptions/connect.py b/python/pyspark/errors/exceptions/connect.py
index d89705f47d24..88acc1a9d5b9 100644
--- a/python/pyspark/errors/exceptions/connect.py
+++ b/python/pyspark/errors/exceptions/connect.py
@@ -61,43 +61,89 @@ def convert_exception(
display_server_stacktrace: bool = False,
grpc_status_code: grpc.StatusCode = StatusCode.UNKNOWN,
) -> SparkConnectException:
+ raw_classes = info.metadata.get("classes")
+ classes: List[str] = json.loads(raw_classes) if raw_classes else []
+ raw_message_parameters = info.metadata.get("messageParameters")
+ message_parameters: Dict[str, str] = (
+ json.loads(raw_message_parameters) if raw_message_parameters else {}
+ )
+ root_error_idx = (
+ resp.root_error_idx if resp is not None and resp.HasField("root_error_idx") else None
+ )
converted = _convert_exception(
- info, truncated_message, resp, display_server_stacktrace, grpc_status_code
+ classes=classes,
+ sql_state=info.metadata.get("sqlState"),
+ error_class=info.metadata.get("errorClass"),
+ reason=info.reason,
+ root_error_idx=root_error_idx,
+ errors=list(resp.errors) if resp is not None else None,
+ truncated_message=truncated_message,
+ truncated_message_parameters=message_parameters,
+ truncated_stacktrace=info.metadata.get("stackTrace"),
+ display_server_stacktrace=display_server_stacktrace,
+ grpc_status_code=grpc_status_code,
)
return recover_python_exception(converted)
+def convert_observation_errors(
+ root_error_idx: int,
+ errors: List["pb2.FetchErrorDetailsResponse.Error"],
+) -> SparkConnectException:
+ """
+ Convert observation error payload (root_error_idx + list of Error from ObservedMetrics)
+ to a SparkConnectException.
+ """
+ if root_error_idx < 0 or root_error_idx >= len(errors):
+ return SparkConnectException("Observation error: invalid root_error_idx")
+
+ if len(errors) == 0:
+ return SparkConnectException("Observation error: no errors")
+
+ root_error = errors[root_error_idx]
+
+ return _convert_exception(
+ classes=list(root_error.error_type_hierarchy),
+ sql_state=root_error.spark_throwable.sql_state
+ if root_error.spark_throwable.HasField("sql_state")
+ else None,
+ error_class=root_error.spark_throwable.error_class
+ if root_error.spark_throwable.HasField("error_class")
+ else None,
+ reason=None,
+ root_error_idx=root_error_idx,
+ errors=errors,
+ truncated_message="",
+ truncated_message_parameters=None,
+ truncated_stacktrace=None,
+ )
+
+
def _convert_exception(
- info: "ErrorInfo",
+ classes: List[str],
+ sql_state: Optional[str],
+ error_class: Optional[str],
+ reason: Optional[str],
+ root_error_idx: Optional[int],
+ errors: Optional[List["pb2.FetchErrorDetailsResponse.Error"]],
truncated_message: str,
- resp: Optional["pb2.FetchErrorDetailsResponse"],
+ truncated_message_parameters: Optional[Dict[str, str]],
+ truncated_stacktrace: Optional[str],
display_server_stacktrace: bool = False,
grpc_status_code: grpc.StatusCode = StatusCode.UNKNOWN,
) -> SparkConnectException:
import pyspark.sql.connect.proto as pb2
- raw_classes = info.metadata.get("classes")
- classes: List[str] = json.loads(raw_classes) if raw_classes else []
- sql_state = info.metadata.get("sqlState")
- error_class = info.metadata.get("errorClass")
- raw_message_parameters = info.metadata.get("messageParameters")
- message_parameters: Dict[str, str] = (
- json.loads(raw_message_parameters) if raw_message_parameters else {}
- )
- stacktrace: Optional[str] = None
-
- if resp is not None and resp.HasField("root_error_idx"):
- message = resp.errors[resp.root_error_idx].message
- stacktrace = _extract_jvm_stacktrace(resp)
- else:
- message = truncated_message
- stacktrace = info.metadata.get("stackTrace")
- display_server_stacktrace = display_server_stacktrace if stacktrace else False
-
+ message = truncated_message
+ stacktrace = truncated_stacktrace
+ message_parameters = truncated_message_parameters
contexts = None
breaking_change_info = None
- if resp and resp.HasField("root_error_idx"):
- root_error = resp.errors[resp.root_error_idx]
+
+ if root_error_idx is not None and errors is not None:
+ root_error = errors[root_error_idx]
+ message = root_error.message
+ stacktrace = _extract_jvm_stacktrace(root_error_idx, errors)
if hasattr(root_error, "spark_throwable"):
# Extract errorClass from FetchErrorDetailsResponse if not in metadata
if error_class is None and root_error.spark_throwable.HasField("error_class"):
@@ -123,6 +169,8 @@ def _convert_exception(
"key": bci.mitigation_config.key,
"value": bci.mitigation_config.value,
}
+ else:
+ display_server_stacktrace = display_server_stacktrace if stacktrace else False
if "org.apache.spark.api.python.PythonException" in classes:
return PythonException(
@@ -158,7 +206,7 @@ def _convert_exception(
# Return UnknownException if there is no matched exception class
return UnknownException(
message,
- reason=info.reason,
+ reason=reason,
messageParameters=message_parameters,
errorClass=error_class,
sql_state=sql_state,
@@ -170,10 +218,9 @@ def _convert_exception(
)
-def _extract_jvm_stacktrace(resp: "pb2.FetchErrorDetailsResponse") -> str:
- if len(resp.errors[resp.root_error_idx].stack_trace) == 0:
- return ""
-
+def _extract_jvm_stacktrace(
+ root_error_idx: int, errors: List["pb2.FetchErrorDetailsResponse.Error"]
+) -> str:
lines: List[str] = []
def format_stacktrace(error: "pb2.FetchErrorDetailsResponse.Error") -> None:
@@ -190,9 +237,9 @@ def _extract_jvm_stacktrace(resp: "pb2.FetchErrorDetailsResponse") -> str:
# If this error has a cause, format that recursively
if error.HasField("cause_idx"):
- format_stacktrace(resp.errors[error.cause_idx])
+ format_stacktrace(errors[error.cause_idx])
- format_stacktrace(resp.errors[resp.root_error_idx])
+ format_stacktrace(errors[root_error_idx])
return "\n".join(lines)
diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index 58cbd22a36b7..b906aee1b7d0 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -55,7 +55,6 @@ from typing import (
cast,
TYPE_CHECKING,
Type,
- Sequence,
)
import pandas as pd
@@ -88,6 +87,7 @@ import pyspark.sql.connect.proto.base_pb2_grpc as grpc_lib
import pyspark.sql.connect.types as types
from pyspark.errors.exceptions.connect import (
convert_exception,
+ convert_observation_errors,
SparkConnectException,
SparkConnectGrpcException,
)
@@ -1003,11 +1003,6 @@ class SparkConnectClient(object):
resources = properties["get_resources_command_result"]
return resources
- def _build_observed_metrics(
- self, metrics: Sequence["pb2.ExecutePlanResponse.ObservedMetrics"]
- ) -> Iterator[PlanObservedMetrics]:
- return (PlanObservedMetrics(x.name, [v for v in x.values], list(x.keys)) for x in metrics)
-
def to_table_as_iterator(
self, plan: pb2.Plan, observations: Dict[str, Observation]
) -> Iterator[Union[StructType, "pa.Table"]]:
@@ -1553,23 +1548,32 @@ class SparkConnectClient(object):
yield from self._build_metrics(b.metrics)
if b.observed_metrics:
logger.debug("Received observed metric batch.")
- for observed_metrics in self._build_observed_metrics(b.observed_metrics):
- if observed_metrics.name == "__python_accumulator__":
- for metric in observed_metrics.metrics:
- (aid, update) = pickleSer.loads(LiteralExpression._to_value(metric))
- if aid == SpecialAccumulatorIds.SQL_UDF_PROFIER:
- self._profiler_collector._update(update)
- elif observed_metrics.name in observations:
- observation_result = observations[observed_metrics.name]._result
- assert observation_result is not None
- observation_result.update(
- {
- key: LiteralExpression._to_value(metric)
- for key, metric in zip(
- observed_metrics.keys, observed_metrics.metrics
- )
- }
- )
+ for x in b.observed_metrics:
+ observed_metrics = PlanObservedMetrics(
+ x.name, [v for v in x.values], list(x.keys)
+ )
+ if x.HasField("root_error_idx"):
+ if x.name in observations:
+ converted = convert_observation_errors(x.root_error_idx, list(x.errors))
+ observations[x.name]._set_error(converted)
+ else:
+ if observed_metrics.name == "__python_accumulator__":
+ for metric in observed_metrics.metrics:
+ (aid, update) = pickleSer.loads(LiteralExpression._to_value(metric))
+ if aid == SpecialAccumulatorIds.SQL_UDF_PROFIER:
+ self._profiler_collector._update(update)
+ elif observed_metrics.name in observations:
+ observation_result = observations[observed_metrics.name]._result
+ assert observation_result is not None
+ observation_result.update(
+ {
+ key: LiteralExpression._to_value(metric)
+ for key, metric in zip(
+ observed_metrics.keys,
+ observed_metrics.metrics,
+ )
+ }
+ )
yield observed_metrics
if b.HasField("schema"):
logger.debug("Received the schema.")
diff --git a/python/pyspark/sql/connect/observation.py b/python/pyspark/sql/connect/observation.py
index bfb8a0a9355f..1ce4235a67ce 100644
--- a/python/pyspark/sql/connect/observation.py
+++ b/python/pyspark/sql/connect/observation.py
@@ -51,9 +51,14 @@ class Observation:
)
self._name = name
self._result: Optional[Dict[str, Any]] = None
+ self._error: Optional[BaseException] = None
__init__.__doc__ = PySparkObservation.__init__.__doc__
+ def _set_error(self, exc: BaseException) -> None:
+ """Set the error that occurred while collecting observed metrics (used
by the client)."""
+ self._error = exc
+
def _on(self, df: DataFrame, *exprs: Column) -> DataFrame:
if self._result is not None:
raise PySparkAssertionError(errorClass="REUSE_OBSERVATION", messageParameters={})
@@ -74,6 +79,8 @@ class Observation:
@property
def get(self) -> Dict[str, Any]:
+ if self._error is not None:
+ raise self._error
if self._result is None:
raise PySparkAssertionError(errorClass="NO_OBSERVE_BEFORE_GET", messageParameters={})
diff --git a/python/pyspark/sql/connect/proto/base_pb2.py b/python/pyspark/sql/connect/proto/base_pb2.py
index 6119d8dc5539..24c3a5ab31b0 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.py
+++ b/python/pyspark/sql/connect/proto/base_pb2.py
@@ -45,7 +45,7 @@ from pyspark.sql.connect.proto import pipelines_pb2 as spark_dot_connect_dot_pip
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto\x1a\x16spark/connect/ml.proto\x1a\x1dspark/connect/pipelines.proto"\xe3\x03\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.conn [...]
+ b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto\x1a\x16spark/connect/ml.proto\x1a\x1dspark/connect/pipelines.proto"\xe3\x03\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.conn [...]
)
_globals = globals()
@@ -70,8 +70,8 @@ if not _descriptor._USE_C_DESCRIPTORS:
_globals[
"_FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY"
]._serialized_options = b"8\001"
- _globals["_COMPRESSIONCODEC"]._serialized_start = 19857
- _globals["_COMPRESSIONCODEC"]._serialized_end = 19938
+ _globals["_COMPRESSIONCODEC"]._serialized_start = 19991
+ _globals["_COMPRESSIONCODEC"]._serialized_end = 20072
_globals["_PLAN"]._serialized_start = 275
_globals["_PLAN"]._serialized_end = 758
_globals["_PLAN_COMPRESSEDOPERATION"]._serialized_start = 477
@@ -147,7 +147,7 @@ if not _descriptor._USE_C_DESCRIPTORS:
_globals["_EXECUTEPLANREQUEST_REQUESTOPTION"]._serialized_start = 5868
_globals["_EXECUTEPLANREQUEST_REQUESTOPTION"]._serialized_end = 6129
_globals["_EXECUTEPLANRESPONSE"]._serialized_start = 6208
- _globals["_EXECUTEPLANRESPONSE"]._serialized_end = 9665
+ _globals["_EXECUTEPLANRESPONSE"]._serialized_end = 9799
_globals["_EXECUTEPLANRESPONSE_SQLCOMMANDRESULT"]._serialized_start = 8308
_globals["_EXECUTEPLANRESPONSE_SQLCOMMANDRESULT"]._serialized_end = 8379
_globals["_EXECUTEPLANRESPONSE_ARROWBATCH"]._serialized_start = 8382
@@ -165,121 +165,121 @@ if not _descriptor._USE_C_DESCRIPTORS:
_globals["_EXECUTEPLANRESPONSE_METRICS_METRICVALUE"]._serialized_start =
9062
_globals["_EXECUTEPLANRESPONSE_METRICS_METRICVALUE"]._serialized_end = 9150
_globals["_EXECUTEPLANRESPONSE_OBSERVEDMETRICS"]._serialized_start = 9153
- _globals["_EXECUTEPLANRESPONSE_OBSERVEDMETRICS"]._serialized_end = 9294
- _globals["_EXECUTEPLANRESPONSE_RESULTCOMPLETE"]._serialized_start = 9296
- _globals["_EXECUTEPLANRESPONSE_RESULTCOMPLETE"]._serialized_end = 9312
- _globals["_EXECUTEPLANRESPONSE_EXECUTIONPROGRESS"]._serialized_start = 9315
- _globals["_EXECUTEPLANRESPONSE_EXECUTIONPROGRESS"]._serialized_end = 9648
- _globals["_EXECUTEPLANRESPONSE_EXECUTIONPROGRESS_STAGEINFO"]._serialized_start = 9471
- _globals["_EXECUTEPLANRESPONSE_EXECUTIONPROGRESS_STAGEINFO"]._serialized_end = 9648
- _globals["_KEYVALUE"]._serialized_start = 9667
- _globals["_KEYVALUE"]._serialized_end = 9732
- _globals["_CONFIGREQUEST"]._serialized_start = 9735
- _globals["_CONFIGREQUEST"]._serialized_end = 10934
- _globals["_CONFIGREQUEST_OPERATION"]._serialized_start = 10043
- _globals["_CONFIGREQUEST_OPERATION"]._serialized_end = 10541
- _globals["_CONFIGREQUEST_SET"]._serialized_start = 10543
- _globals["_CONFIGREQUEST_SET"]._serialized_end = 10635
- _globals["_CONFIGREQUEST_GET"]._serialized_start = 10637
- _globals["_CONFIGREQUEST_GET"]._serialized_end = 10662
- _globals["_CONFIGREQUEST_GETWITHDEFAULT"]._serialized_start = 10664
- _globals["_CONFIGREQUEST_GETWITHDEFAULT"]._serialized_end = 10727
- _globals["_CONFIGREQUEST_GETOPTION"]._serialized_start = 10729
- _globals["_CONFIGREQUEST_GETOPTION"]._serialized_end = 10760
- _globals["_CONFIGREQUEST_GETALL"]._serialized_start = 10762
- _globals["_CONFIGREQUEST_GETALL"]._serialized_end = 10810
- _globals["_CONFIGREQUEST_UNSET"]._serialized_start = 10812
- _globals["_CONFIGREQUEST_UNSET"]._serialized_end = 10839
- _globals["_CONFIGREQUEST_ISMODIFIABLE"]._serialized_start = 10841
- _globals["_CONFIGREQUEST_ISMODIFIABLE"]._serialized_end = 10875
- _globals["_CONFIGRESPONSE"]._serialized_start = 10937
- _globals["_CONFIGRESPONSE"]._serialized_end = 11112
- _globals["_ADDARTIFACTSREQUEST"]._serialized_start = 11115
- _globals["_ADDARTIFACTSREQUEST"]._serialized_end = 12117
- _globals["_ADDARTIFACTSREQUEST_ARTIFACTCHUNK"]._serialized_start = 11590
- _globals["_ADDARTIFACTSREQUEST_ARTIFACTCHUNK"]._serialized_end = 11643
- _globals["_ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT"]._serialized_start =
11645
- _globals["_ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT"]._serialized_end =
11756
- _globals["_ADDARTIFACTSREQUEST_BATCH"]._serialized_start = 11758
- _globals["_ADDARTIFACTSREQUEST_BATCH"]._serialized_end = 11851
- _globals["_ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT"]._serialized_start =
11854
- _globals["_ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT"]._serialized_end =
12047
- _globals["_ADDARTIFACTSRESPONSE"]._serialized_start = 12120
- _globals["_ADDARTIFACTSRESPONSE"]._serialized_end = 12392
- _globals["_ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY"]._serialized_start = 12311
- _globals["_ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY"]._serialized_end = 12392
- _globals["_ARTIFACTSTATUSESREQUEST"]._serialized_start = 12395
- _globals["_ARTIFACTSTATUSESREQUEST"]._serialized_end = 12721
- _globals["_ARTIFACTSTATUSESRESPONSE"]._serialized_start = 12724
- _globals["_ARTIFACTSTATUSESRESPONSE"]._serialized_end = 13076
- _globals["_ARTIFACTSTATUSESRESPONSE_STATUSESENTRY"]._serialized_start =
12919
- _globals["_ARTIFACTSTATUSESRESPONSE_STATUSESENTRY"]._serialized_end = 13034
- _globals["_ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS"]._serialized_start =
13036
- _globals["_ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS"]._serialized_end =
13076
- _globals["_INTERRUPTREQUEST"]._serialized_start = 13079
- _globals["_INTERRUPTREQUEST"]._serialized_end = 13682
- _globals["_INTERRUPTREQUEST_INTERRUPTTYPE"]._serialized_start = 13482
- _globals["_INTERRUPTREQUEST_INTERRUPTTYPE"]._serialized_end = 13610
- _globals["_INTERRUPTRESPONSE"]._serialized_start = 13685
- _globals["_INTERRUPTRESPONSE"]._serialized_end = 13829
- _globals["_REATTACHOPTIONS"]._serialized_start = 13831
- _globals["_REATTACHOPTIONS"]._serialized_end = 13884
- _globals["_RESULTCHUNKINGOPTIONS"]._serialized_start = 13887
- _globals["_RESULTCHUNKINGOPTIONS"]._serialized_end = 14068
- _globals["_REATTACHEXECUTEREQUEST"]._serialized_start = 14071
- _globals["_REATTACHEXECUTEREQUEST"]._serialized_end = 14477
- _globals["_RELEASEEXECUTEREQUEST"]._serialized_start = 14480
- _globals["_RELEASEEXECUTEREQUEST"]._serialized_end = 15065
- _globals["_RELEASEEXECUTEREQUEST_RELEASEALL"]._serialized_start = 14934
- _globals["_RELEASEEXECUTEREQUEST_RELEASEALL"]._serialized_end = 14946
- _globals["_RELEASEEXECUTEREQUEST_RELEASEUNTIL"]._serialized_start = 14948
- _globals["_RELEASEEXECUTEREQUEST_RELEASEUNTIL"]._serialized_end = 14995
- _globals["_RELEASEEXECUTERESPONSE"]._serialized_start = 15068
- _globals["_RELEASEEXECUTERESPONSE"]._serialized_end = 15233
- _globals["_RELEASESESSIONREQUEST"]._serialized_start = 15236
- _globals["_RELEASESESSIONREQUEST"]._serialized_end = 15448
- _globals["_RELEASESESSIONRESPONSE"]._serialized_start = 15450
- _globals["_RELEASESESSIONRESPONSE"]._serialized_end = 15558
- _globals["_FETCHERRORDETAILSREQUEST"]._serialized_start = 15561
- _globals["_FETCHERRORDETAILSREQUEST"]._serialized_end = 15893
- _globals["_FETCHERRORDETAILSRESPONSE"]._serialized_start = 15896
- _globals["_FETCHERRORDETAILSRESPONSE"]._serialized_end = 17905
- _globals["_FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT"]._serialized_start
= 16125
- _globals["_FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT"]._serialized_end =
16299
- _globals["_FETCHERRORDETAILSRESPONSE_QUERYCONTEXT"]._serialized_start =
16302
- _globals["_FETCHERRORDETAILSRESPONSE_QUERYCONTEXT"]._serialized_end = 16670
-
_globals["_FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE"]._serialized_start
= 16633
-
_globals["_FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE"]._serialized_end
= 16670
- _globals["_FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE"]._serialized_start =
16673
- _globals["_FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE"]._serialized_end =
17223
+ _globals["_EXECUTEPLANRESPONSE_OBSERVEDMETRICS"]._serialized_end = 9428
+ _globals["_EXECUTEPLANRESPONSE_RESULTCOMPLETE"]._serialized_start = 9430
+ _globals["_EXECUTEPLANRESPONSE_RESULTCOMPLETE"]._serialized_end = 9446
+ _globals["_EXECUTEPLANRESPONSE_EXECUTIONPROGRESS"]._serialized_start = 9449
+ _globals["_EXECUTEPLANRESPONSE_EXECUTIONPROGRESS"]._serialized_end = 9782
+ _globals["_EXECUTEPLANRESPONSE_EXECUTIONPROGRESS_STAGEINFO"]._serialized_start = 9605
+ _globals["_EXECUTEPLANRESPONSE_EXECUTIONPROGRESS_STAGEINFO"]._serialized_end = 9782
+ _globals["_KEYVALUE"]._serialized_start = 9801
+ _globals["_KEYVALUE"]._serialized_end = 9866
+ _globals["_CONFIGREQUEST"]._serialized_start = 9869
+ _globals["_CONFIGREQUEST"]._serialized_end = 11068
+ _globals["_CONFIGREQUEST_OPERATION"]._serialized_start = 10177
+ _globals["_CONFIGREQUEST_OPERATION"]._serialized_end = 10675
+ _globals["_CONFIGREQUEST_SET"]._serialized_start = 10677
+ _globals["_CONFIGREQUEST_SET"]._serialized_end = 10769
+ _globals["_CONFIGREQUEST_GET"]._serialized_start = 10771
+ _globals["_CONFIGREQUEST_GET"]._serialized_end = 10796
+ _globals["_CONFIGREQUEST_GETWITHDEFAULT"]._serialized_start = 10798
+ _globals["_CONFIGREQUEST_GETWITHDEFAULT"]._serialized_end = 10861
+ _globals["_CONFIGREQUEST_GETOPTION"]._serialized_start = 10863
+ _globals["_CONFIGREQUEST_GETOPTION"]._serialized_end = 10894
+ _globals["_CONFIGREQUEST_GETALL"]._serialized_start = 10896
+ _globals["_CONFIGREQUEST_GETALL"]._serialized_end = 10944
+ _globals["_CONFIGREQUEST_UNSET"]._serialized_start = 10946
+ _globals["_CONFIGREQUEST_UNSET"]._serialized_end = 10973
+ _globals["_CONFIGREQUEST_ISMODIFIABLE"]._serialized_start = 10975
+ _globals["_CONFIGREQUEST_ISMODIFIABLE"]._serialized_end = 11009
+ _globals["_CONFIGRESPONSE"]._serialized_start = 11071
+ _globals["_CONFIGRESPONSE"]._serialized_end = 11246
+ _globals["_ADDARTIFACTSREQUEST"]._serialized_start = 11249
+ _globals["_ADDARTIFACTSREQUEST"]._serialized_end = 12251
+ _globals["_ADDARTIFACTSREQUEST_ARTIFACTCHUNK"]._serialized_start = 11724
+ _globals["_ADDARTIFACTSREQUEST_ARTIFACTCHUNK"]._serialized_end = 11777
+ _globals["_ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT"]._serialized_start =
11779
+ _globals["_ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT"]._serialized_end =
11890
+ _globals["_ADDARTIFACTSREQUEST_BATCH"]._serialized_start = 11892
+ _globals["_ADDARTIFACTSREQUEST_BATCH"]._serialized_end = 11985
+ _globals["_ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT"]._serialized_start =
11988
+ _globals["_ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT"]._serialized_end =
12181
+ _globals["_ADDARTIFACTSRESPONSE"]._serialized_start = 12254
+ _globals["_ADDARTIFACTSRESPONSE"]._serialized_end = 12526
+ _globals["_ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY"]._serialized_start = 12445
+ _globals["_ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY"]._serialized_end = 12526
+ _globals["_ARTIFACTSTATUSESREQUEST"]._serialized_start = 12529
+ _globals["_ARTIFACTSTATUSESREQUEST"]._serialized_end = 12855
+ _globals["_ARTIFACTSTATUSESRESPONSE"]._serialized_start = 12858
+ _globals["_ARTIFACTSTATUSESRESPONSE"]._serialized_end = 13210
+ _globals["_ARTIFACTSTATUSESRESPONSE_STATUSESENTRY"]._serialized_start =
13053
+ _globals["_ARTIFACTSTATUSESRESPONSE_STATUSESENTRY"]._serialized_end = 13168
+ _globals["_ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS"]._serialized_start =
13170
+ _globals["_ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS"]._serialized_end =
13210
+ _globals["_INTERRUPTREQUEST"]._serialized_start = 13213
+ _globals["_INTERRUPTREQUEST"]._serialized_end = 13816
+ _globals["_INTERRUPTREQUEST_INTERRUPTTYPE"]._serialized_start = 13616
+ _globals["_INTERRUPTREQUEST_INTERRUPTTYPE"]._serialized_end = 13744
+ _globals["_INTERRUPTRESPONSE"]._serialized_start = 13819
+ _globals["_INTERRUPTRESPONSE"]._serialized_end = 13963
+ _globals["_REATTACHOPTIONS"]._serialized_start = 13965
+ _globals["_REATTACHOPTIONS"]._serialized_end = 14018
+ _globals["_RESULTCHUNKINGOPTIONS"]._serialized_start = 14021
+ _globals["_RESULTCHUNKINGOPTIONS"]._serialized_end = 14202
+ _globals["_REATTACHEXECUTEREQUEST"]._serialized_start = 14205
+ _globals["_REATTACHEXECUTEREQUEST"]._serialized_end = 14611
+ _globals["_RELEASEEXECUTEREQUEST"]._serialized_start = 14614
+ _globals["_RELEASEEXECUTEREQUEST"]._serialized_end = 15199
+ _globals["_RELEASEEXECUTEREQUEST_RELEASEALL"]._serialized_start = 15068
+ _globals["_RELEASEEXECUTEREQUEST_RELEASEALL"]._serialized_end = 15080
+ _globals["_RELEASEEXECUTEREQUEST_RELEASEUNTIL"]._serialized_start = 15082
+ _globals["_RELEASEEXECUTEREQUEST_RELEASEUNTIL"]._serialized_end = 15129
+ _globals["_RELEASEEXECUTERESPONSE"]._serialized_start = 15202
+ _globals["_RELEASEEXECUTERESPONSE"]._serialized_end = 15367
+ _globals["_RELEASESESSIONREQUEST"]._serialized_start = 15370
+ _globals["_RELEASESESSIONREQUEST"]._serialized_end = 15582
+ _globals["_RELEASESESSIONRESPONSE"]._serialized_start = 15584
+ _globals["_RELEASESESSIONRESPONSE"]._serialized_end = 15692
+ _globals["_FETCHERRORDETAILSREQUEST"]._serialized_start = 15695
+ _globals["_FETCHERRORDETAILSREQUEST"]._serialized_end = 16027
+ _globals["_FETCHERRORDETAILSRESPONSE"]._serialized_start = 16030
+ _globals["_FETCHERRORDETAILSRESPONSE"]._serialized_end = 18039
+ _globals["_FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT"]._serialized_start
= 16259
+ _globals["_FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT"]._serialized_end =
16433
+ _globals["_FETCHERRORDETAILSRESPONSE_QUERYCONTEXT"]._serialized_start =
16436
+ _globals["_FETCHERRORDETAILSRESPONSE_QUERYCONTEXT"]._serialized_end = 16804
+
_globals["_FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE"]._serialized_start
= 16767
+
_globals["_FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE"]._serialized_end
= 16804
+ _globals["_FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE"]._serialized_start =
16807
+ _globals["_FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE"]._serialized_end =
17357
_globals[
"_FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY"
- ]._serialized_start = 17100
+ ]._serialized_start = 17234
_globals[
"_FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY"
- ]._serialized_end = 17168
- _globals["_FETCHERRORDETAILSRESPONSE_BREAKINGCHANGEINFO"]._serialized_start = 17226
- _globals["_FETCHERRORDETAILSRESPONSE_BREAKINGCHANGEINFO"]._serialized_end = 17476
- _globals["_FETCHERRORDETAILSRESPONSE_MITIGATIONCONFIG"]._serialized_start = 17478
- _globals["_FETCHERRORDETAILSRESPONSE_MITIGATIONCONFIG"]._serialized_end = 17536
- _globals["_FETCHERRORDETAILSRESPONSE_ERROR"]._serialized_start = 17539
- _globals["_FETCHERRORDETAILSRESPONSE_ERROR"]._serialized_end = 17886
- _globals["_CHECKPOINTCOMMANDRESULT"]._serialized_start = 17907
- _globals["_CHECKPOINTCOMMANDRESULT"]._serialized_end = 17997
- _globals["_CLONESESSIONREQUEST"]._serialized_start = 18000
- _globals["_CLONESESSIONREQUEST"]._serialized_end = 18362
- _globals["_CLONESESSIONRESPONSE"]._serialized_start = 18365
- _globals["_CLONESESSIONRESPONSE"]._serialized_end = 18569
- _globals["_GETSTATUSREQUEST"]._serialized_start = 18572
- _globals["_GETSTATUSREQUEST"]._serialized_end = 19167
- _globals["_GETSTATUSREQUEST_OPERATIONSTATUSREQUEST"]._serialized_start =
18971
- _globals["_GETSTATUSREQUEST_OPERATIONSTATUSREQUEST"]._serialized_end =
19087
- _globals["_GETSTATUSRESPONSE"]._serialized_start = 19170
- _globals["_GETSTATUSRESPONSE"]._serialized_end = 19855
- _globals["_GETSTATUSRESPONSE_OPERATIONSTATUS"]._serialized_start = 19428
- _globals["_GETSTATUSRESPONSE_OPERATIONSTATUS"]._serialized_end = 19855
- _globals["_GETSTATUSRESPONSE_OPERATIONSTATUS_OPERATIONSTATE"]._serialized_start = 19625
- _globals["_GETSTATUSRESPONSE_OPERATIONSTATUS_OPERATIONSTATE"]._serialized_end = 19855
- _globals["_SPARKCONNECTSERVICE"]._serialized_start = 19941
- _globals["_SPARKCONNECTSERVICE"]._serialized_end = 21060
+ ]._serialized_end = 17302
+ _globals["_FETCHERRORDETAILSRESPONSE_BREAKINGCHANGEINFO"]._serialized_start = 17360
+ _globals["_FETCHERRORDETAILSRESPONSE_BREAKINGCHANGEINFO"]._serialized_end = 17610
+ _globals["_FETCHERRORDETAILSRESPONSE_MITIGATIONCONFIG"]._serialized_start = 17612
+ _globals["_FETCHERRORDETAILSRESPONSE_MITIGATIONCONFIG"]._serialized_end = 17670
+ _globals["_FETCHERRORDETAILSRESPONSE_ERROR"]._serialized_start = 17673
+ _globals["_FETCHERRORDETAILSRESPONSE_ERROR"]._serialized_end = 18020
+ _globals["_CHECKPOINTCOMMANDRESULT"]._serialized_start = 18041
+ _globals["_CHECKPOINTCOMMANDRESULT"]._serialized_end = 18131
+ _globals["_CLONESESSIONREQUEST"]._serialized_start = 18134
+ _globals["_CLONESESSIONREQUEST"]._serialized_end = 18496
+ _globals["_CLONESESSIONRESPONSE"]._serialized_start = 18499
+ _globals["_CLONESESSIONRESPONSE"]._serialized_end = 18703
+ _globals["_GETSTATUSREQUEST"]._serialized_start = 18706
+ _globals["_GETSTATUSREQUEST"]._serialized_end = 19301
+ _globals["_GETSTATUSREQUEST_OPERATIONSTATUSREQUEST"]._serialized_start =
19105
+ _globals["_GETSTATUSREQUEST_OPERATIONSTATUSREQUEST"]._serialized_end =
19221
+ _globals["_GETSTATUSRESPONSE"]._serialized_start = 19304
+ _globals["_GETSTATUSRESPONSE"]._serialized_end = 19989
+ _globals["_GETSTATUSRESPONSE_OPERATIONSTATUS"]._serialized_start = 19562
+ _globals["_GETSTATUSRESPONSE_OPERATIONSTATUS"]._serialized_end = 19989
+ _globals["_GETSTATUSRESPONSE_OPERATIONSTATUS_OPERATIONSTATE"]._serialized_start = 19759
+ _globals["_GETSTATUSRESPONSE_OPERATIONSTATUS_OPERATIONSTATE"]._serialized_end = 19989
+ _globals["_SPARKCONNECTSERVICE"]._serialized_start = 20075
+ _globals["_SPARKCONNECTSERVICE"]._serialized_end = 21194
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi b/python/pyspark/sql/connect/proto/base_pb2.pyi
index 00e98535047c..6650704799e3 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/base_pb2.pyi
@@ -1588,6 +1588,8 @@ class ExecutePlanResponse(google.protobuf.message.Message):
VALUES_FIELD_NUMBER: builtins.int
KEYS_FIELD_NUMBER: builtins.int
PLAN_ID_FIELD_NUMBER: builtins.int
+ ROOT_ERROR_IDX_FIELD_NUMBER: builtins.int
+ ERRORS_FIELD_NUMBER: builtins.int
name: builtins.str
@property
def values(
@@ -1600,6 +1602,19 @@ class ExecutePlanResponse(google.protobuf.message.Message):
self,
) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
...
plan_id: builtins.int
+ root_error_idx: builtins.int
+ """(Optional) The index of the root error in errors.
+ The field will not be set if there are no errors.
+ """
+ @property
+ def errors(
+ self,
+ ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
+ global___FetchErrorDetailsResponse.Error
+ ]:
+ """A list of errors that occurred while collecting the observed
metrics.
+ If the length is 0, it means no errors occurred.
+ """
def __init__(
self,
*,
@@ -1610,13 +1625,37 @@ class ExecutePlanResponse(google.protobuf.message.Message):
| None = ...,
keys: collections.abc.Iterable[builtins.str] | None = ...,
plan_id: builtins.int = ...,
+ root_error_idx: builtins.int | None = ...,
+ errors: collections.abc.Iterable[global___FetchErrorDetailsResponse.Error] | None = ...,
) -> None: ...
+ def HasField(
+ self,
+ field_name: typing_extensions.Literal[
+ "_root_error_idx", b"_root_error_idx", "root_error_idx",
b"root_error_idx"
+ ],
+ ) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
- "keys", b"keys", "name", b"name", "plan_id", b"plan_id",
"values", b"values"
+ "_root_error_idx",
+ b"_root_error_idx",
+ "errors",
+ b"errors",
+ "keys",
+ b"keys",
+ "name",
+ b"name",
+ "plan_id",
+ b"plan_id",
+ "root_error_idx",
+ b"root_error_idx",
+ "values",
+ b"values",
],
) -> None: ...
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["_root_error_idx", b"_root_error_idx"]
+ ) -> typing_extensions.Literal["root_error_idx"] | None: ...
class ResultComplete(google.protobuf.message.Message):
"""If present, in a reattachable execution this means that after
server sends onComplete,
diff --git a/python/pyspark/sql/tests/test_observation.py b/python/pyspark/sql/tests/test_observation.py
index f7a8b20a66ce..66a7bdb79891 100644
--- a/python/pyspark/sql/tests/test_observation.py
+++ b/python/pyspark/sql/tests/test_observation.py
@@ -19,6 +19,7 @@ from pyspark.sql import Row, Observation, functions as F
from pyspark.sql.types import StructType, LongType
from pyspark.errors import (
PySparkAssertionError,
+ PySparkException,
PySparkTypeError,
PySparkValueError,
)
@@ -238,6 +239,24 @@ class DataFrameObservationTestsMixin:
self.assertEqual(observation.get, {"map": {"count": 10}})
+ def test_observation_errors_propagated_to_client(self):
+ observation = Observation("test_observation")
+ observed_df = self.spark.range(10).observe(
+ observation,
+ F.sum("id").alias("sum_id"),
+ F.raise_error(F.lit("test error")).alias("raise_error"),
+ )
+ actual = observed_df.collect()
+ self.assertEqual(
+ [row.asDict() for row in actual],
+ [{"id": i} for i in range(10)],
+ )
+
+ with self.assertRaises(PySparkException) as cm:
+ _ = observation.get
+
+ self.assertIn("test error", str(cm.exception))
+
class DataFrameObservationTests(
DataFrameObservationTestsMixin,
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/Observation.scala b/sql/api/src/main/scala/org/apache/spark/sql/Observation.scala
index 41b90b934a28..37e5d9c0f9f8 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/Observation.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/Observation.scala
@@ -25,6 +25,7 @@ import scala.concurrent.duration.Duration
import scala.jdk.CollectionConverters.MapHasAsJava
import scala.util.Try
+import org.apache.spark.SparkException
import org.apache.spark.util.SparkThreadUtils
/**
@@ -123,16 +124,6 @@ class Observation(val name: String) {
promise.tryComplete(metrics)
}
- /**
- * Get the observed metrics as a Row.
- *
- * @return
- * the observed metrics as a `Row`, or None if the metrics are not available.
- */
- private[sql] def getRowOrEmpty: Option[Row] = {
- future.value.flatMap(_.toOption)
- }
-
/**
* Get the observed metrics as a Row.
*
@@ -140,7 +131,13 @@ class Observation(val name: String) {
* the observed metrics as a `Row`.
*/
private[sql] def getRow: Row = {
- SparkThreadUtils.awaitResult(future, Duration.Inf)
+ try {
+ SparkThreadUtils.awaitResult(future, Duration.Inf)
+ } catch {
+ case e: SparkException =>
+ // Throw the root cause since awaitResult wraps it in a SparkException.
+ throw e.getCause
+ }
}
}
diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala
index 52d87087805f..45bb2bb386c6 100644
--- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala
@@ -31,7 +31,7 @@ import org.apache.commons.io.output.TeeOutputStream
import org.scalactic.TolerantNumerics
import org.scalatest.PrivateMethodTester
-import org.apache.spark.{SparkArithmeticException, SparkException, SparkUpgradeException}
+import org.apache.spark.{SparkArithmeticException, SparkException, SparkRuntimeException, SparkUpgradeException}
import org.apache.spark.SparkBuildInfo.{spark_version => SPARK_VERSION}
import org.apache.spark.connect.proto
import org.apache.spark.internal.config.ConfigBuilder
@@ -1652,7 +1652,7 @@ class ClientE2ETestSuite
assert(metrics2 === Map("min(extra)" -> -1, "avg(extra)" -> 48, "max(extra)" -> 97))
}
- test("SPARK-55150: observation errors leads to empty result in connect
mode") {
+ test("SPARK-55150: observation errors are propagated to client in connect
mode") {
val observation = Observation("test_observation")
val observed_df = spark
.range(10)
@@ -1661,9 +1661,14 @@ class ClientE2ETestSuite
sum("id").as("sum_id"),
raise_error(lit("test error")).as("raise_error"))
- observed_df.collect()
+ val actual = observed_df.collect()
+ assert(actual.toSeq === (0 until 10).map(_.toLong))
- assert(observation.get.isEmpty)
+ val exception = intercept[SparkRuntimeException] {
+ observation.get
+ }
+
+ assert(exception.getMessage.contains("test error"))
}
test("SPARK-48852: trim function on a string column returns correct
results") {
diff --git a/sql/connect/common/src/main/protobuf/spark/connect/base.proto b/sql/connect/common/src/main/protobuf/spark/connect/base.proto
index c2e8c5af2c79..c7247129f190 100644
--- a/sql/connect/common/src/main/protobuf/spark/connect/base.proto
+++ b/sql/connect/common/src/main/protobuf/spark/connect/base.proto
@@ -491,6 +491,12 @@ message ExecutePlanResponse {
repeated Expression.Literal values = 2;
repeated string keys = 3;
int64 plan_id = 4;
+ // (Optional) The index of the root error in errors.
+ // The field will not be set if there are no errors.
+ optional int32 root_error_idx = 5;
+ // A list of errors that occurred while collecting the observed metrics.
+ // If the length is 0, it means no errors occurred.
+ repeated FetchErrorDetailsResponse.Error errors = 6;
}
message ResultComplete {
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
index be49f96a3958..dbb2630a52e0 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
@@ -742,12 +742,12 @@ class SparkSession private[sql] (
}
private def processRegisteredObservedMetrics(metrics: java.util.List[ObservedMetrics]): Unit = {
- metrics.asScala.map { metric =>
+ metrics.asScala.foreach { metric =>
// Here we only process metrics that belong to a registered Observation object.
// All metrics, whether registered or not, will be collected by `SparkResult`.
val observationOrNull = observationRegistry.remove(metric.getPlanId)
if (observationOrNull != null) {
- val metricsResult = Try(SparkResult.transformObservedMetrics(metric))
+ val metricsResult = SparkResult.transformObservedMetrics(metric)
observationOrNull.setMetricsAndNotify(metricsResult)
}
}
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala
index 7b57f75d55ce..6e5304b8cc77 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala
@@ -395,7 +395,7 @@ private[client] object GrpcExceptionConverter {
* FetchErrorDetailsResponse.Error with un-truncated error messages and server-side stacktrace
* (if set).
*/
- private def errorsToThrowable(
+ private[client] def errorsToThrowable(
errorIdx: Int,
errors: Seq[FetchErrorDetailsResponse.Error]): Throwable = {
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
index 4199801d8505..2ab250673b76 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
@@ -22,6 +22,7 @@ import java.util.Objects
import scala.collection.mutable
import scala.jdk.CollectionConverters._
+import scala.util.{Failure, Success, Try}
import com.google.protobuf.ByteString
import org.apache.arrow.memory.BufferAllocator
@@ -86,7 +87,7 @@ private[sql] class SparkResult[T](
private[this] var arrowSchema: pojo.Schema = _
private[this] var nextResultIndex: Int = 0
private val resultMap = mutable.Map.empty[Int, (Long, Seq[ArrowMessage])]
- private val observedMetrics = mutable.Map.empty[String, Row]
+ private val observedMetrics = mutable.Map.empty[String, Try[Row]]
private val cleanable =
SparkResult.cleaner.register(this, new SparkResultCloseable(resultMap, responses))
@@ -253,7 +254,7 @@ private[sql] class SparkResult[T](
}
private def processObservedMetrics(
- metrics: java.util.List[ObservedMetrics]): Iterable[(String, Row)] = {
+ metrics: java.util.List[ObservedMetrics]): Iterable[(String, Try[Row])] = {
metrics.asScala.map { metric =>
metric.getName -> SparkResult.transformObservedMetrics(metric)
}
@@ -315,7 +316,7 @@ private[sql] class SparkResult[T](
def getObservedMetrics: Map[String, Row] = {
// We need to process all responses to get all metrics.
processResponses()
- observedMetrics.toMap
+ observedMetrics.view.mapValues(_.get).toMap
}
/**
@@ -421,18 +422,25 @@ private[sql] object SparkResult {
private val cleaner: Cleaner = Cleaner.create()
/** Return value is a Seq of pairs, to preserve the order of values. */
- private[sql] def transformObservedMetrics(metric: ObservedMetrics): Row = {
- assert(metric.getKeysCount == metric.getValuesCount)
- var schema = new StructType()
- val values = mutable.ArrayBuilder.make[Any]
- values.sizeHint(metric.getKeysCount)
- (0 until metric.getKeysCount).foreach { i =>
- val key = metric.getKeys(i)
- val value = LiteralValueProtoConverter.toScalaValue(metric.getValues(i))
- schema = schema.add(key, LiteralValueProtoConverter.getDataType(metric.getValues(i)))
- values += value
+ private[sql] def transformObservedMetrics(metric: ObservedMetrics): Try[Row] = {
+ // Check if the metric contains errors
+ if (metric.hasRootErrorIdx) {
+ Failure(
+ GrpcExceptionConverter
+ .errorsToThrowable(metric.getRootErrorIdx, metric.getErrorsList.asScala.toSeq))
+ } else {
+ assert(metric.getKeysCount == metric.getValuesCount)
+ var schema = new StructType()
+ val values = mutable.ArrayBuilder.make[Any]
+ values.sizeHint(metric.getKeysCount)
+ (0 until metric.getKeysCount).foreach { i =>
+ val key = metric.getKeys(i)
+ val value = LiteralValueProtoConverter.toScalaValue(metric.getValues(i))
+ schema = schema.add(key, LiteralValueProtoConverter.getDataType(metric.getValues(i)))
+ values += value
+ }
+ Success(new GenericRowWithSchema(values.result(), schema))
}
- new GenericRowWithSchema(values.result(), schema)
}
}
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
index f206ee1555a7..b7c335c6cfcf 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.connect.execution
import java.util.concurrent.atomic.AtomicReference
import scala.jdk.CollectionConverters._
+import scala.util.{Success, Try}
import scala.util.control.NonFatal
import com.google.protobuf.Message
@@ -229,22 +230,20 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends
executeHolder.request.getPlan.getDescriptorForType)
}
- val observedMetrics: Map[String, Seq[(Option[String], Any, Option[DataType])]] = {
+ val observedMetrics: Map[String, Try[Seq[(Option[String], Any, Option[DataType])]]] = {
executeHolder.observations.map { case (name, observation) =>
- val values =
- observation.getRowOrEmpty
- .map(SparkConnectPlanExecution.toObservedMetricsValues(_))
- .getOrElse(Seq.empty)
- name -> values
+ name -> observation.future.value
+ .map(_.map(SparkConnectPlanExecution.toObservedMetricsValues))
+ .getOrElse(Success(Seq.empty))
}.toMap
}
- val accumulatedInPython: Map[String, Seq[(Option[String], Any, Option[DataType])]] = {
+ val accumulatedInPython: Map[String, Try[Seq[(Option[String], Any, Option[DataType])]]] = {
executeHolder.sessionHolder.pythonAccumulator.flatMap { accumulator =>
accumulator.synchronized {
val value = accumulator.value.asScala.toSeq
if (value.nonEmpty) {
accumulator.reset()
- Some("__python_accumulator__" -> value.map(value => (None,
value, None)))
+ Some("__python_accumulator__" -> Success(value.map(value =>
(None, value, None))))
} else {
None
}
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
index 4332074228d9..0df21ebaceaf 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.connect.execution
import scala.concurrent.duration.Duration
import scala.jdk.CollectionConverters._
-import scala.util.{Failure, Success}
+import scala.util.{Failure, Success, Try}
import com.google.protobuf.ByteString
import io.grpc.stub.StreamObserver
@@ -36,7 +36,7 @@ import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralP
import org.apache.spark.sql.connect.config.Connect.{CONNECT_GRPC_ARROW_MAX_BATCH_SIZE, CONNECT_SESSION_RESULT_CHUNKING_MAX_CHUNK_SIZE}
import org.apache.spark.sql.connect.planner.{InvalidInputErrors, SparkConnectPlanner}
import org.apache.spark.sql.connect.service.ExecuteHolder
-import org.apache.spark.sql.connect.utils.{MetricGenerator, PipelineAnalysisContextUtils}
+import org.apache.spark.sql.connect.utils.{ErrorUtils, MetricGenerator, PipelineAnalysisContextUtils}
import org.apache.spark.sql.execution.{DoNotCleanup, LocalTableScanExec, QueryExecution, RemoveShuffleFiles, SkipMigration, SQLExecution}
import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.internal.SQLConf
@@ -340,22 +340,30 @@ object SparkConnectPlanExecution {
sessionId: String,
serverSessionId: String,
observationAndPlanIds: Map[String, Long],
- metrics: Map[String, Seq[(Option[String], Any, Option[DataType])]]): ExecutePlanResponse = {
- val observedMetrics = metrics.map { case (name, values) =>
- val metrics = ExecutePlanResponse.ObservedMetrics
+ metrics: Map[String, Try[Seq[(Option[String], Any, Option[DataType])]]])
+ : ExecutePlanResponse = {
+ val observedMetrics = metrics.map { case (name, result) =>
+ val metricsBuilder = ExecutePlanResponse.ObservedMetrics
.newBuilder()
.setName(name)
- values.foreach { case (keyOpt, value, dataTypeOpt) =>
- dataTypeOpt match {
- case Some(dataType) =>
- metrics.addValues(toLiteralProto(value, dataType))
- case None =>
- metrics.addValues(toLiteralProto(value))
- }
- keyOpt.foreach(metrics.addKeys)
+ result match {
+ case Success(values) =>
+ values.foreach { case (keyOpt, value, dataTypeOpt) =>
+ dataTypeOpt match {
+ case Some(dataType) =>
+ metricsBuilder.addValues(toLiteralProto(value, dataType))
+ case None =>
+ metricsBuilder.addValues(toLiteralProto(value))
+ }
+ keyOpt.foreach(metricsBuilder.addKeys)
+ }
+ case Failure(throwable) =>
+ val (rootErrorIdx, errors) = ErrorUtils.throwableToProtoErrors(throwable)
+ metricsBuilder.setRootErrorIdx(rootErrorIdx)
+ metricsBuilder.addAllErrors(errors.asJava)
}
- observationAndPlanIds.get(name).foreach(metrics.setPlanId)
- metrics.build()
+ observationAndPlanIds.get(name).foreach(metricsBuilder.setPlanId)
+ metricsBuilder.build()
}
// Prepare a response with the observed metrics.
ExecutePlanResponse
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala
index 0c2c2940e685..40a1f647f7c9 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala
@@ -67,17 +67,17 @@ private[connect] object ErrorUtils extends Logging {
private[connect] val MAX_ERROR_CHAIN_LENGTH = 5
/**
- * Convert Throwable to a protobuf message FetchErrorDetailsResponse.
+ * Convert Throwable to a sequence of protobuf Error messages.
* @param st
* the Throwable to be converted
* @param serverStackTraceEnabled
* whether to return the server stack trace.
* @return
- * FetchErrorDetailsResponse
+ * A tuple of (rootErrorIdx, sequence of FetchErrorDetailsResponse.Error)
*/
- private[connect] def throwableToFetchErrorDetailsResponse(
+ private[connect] def throwableToProtoErrors(
st: Throwable,
- serverStackTraceEnabled: Boolean = false): FetchErrorDetailsResponse = {
+ serverStackTraceEnabled: Boolean = false): (Int, Seq[FetchErrorDetailsResponse.Error]) = {
var currentError = st
val buffer = mutable.Buffer.empty[FetchErrorDetailsResponse.Error]
@@ -177,10 +177,28 @@ private[connect] object ErrorUtils extends Logging {
currentError = currentError.getCause
}
+ (0, buffer.toSeq)
+ }
+
+ /**
+ * Convert Throwable to a protobuf message FetchErrorDetailsResponse.
+ * @param st
+ * the Throwable to be converted
+ * @param serverStackTraceEnabled
+ * whether to return the server stack trace.
+ * @return
+ * FetchErrorDetailsResponse
+ */
+ private[connect] def throwableToFetchErrorDetailsResponse(
+ st: Throwable,
+ serverStackTraceEnabled: Boolean = false): FetchErrorDetailsResponse = {
+
+ val (rootErrorIdx, errors) = throwableToProtoErrors(st, serverStackTraceEnabled)
+
FetchErrorDetailsResponse
.newBuilder()
- .setRootErrorIdx(0)
- .addAllErrors(buffer.asJava)
+ .setRootErrorIdx(rootErrorIdx)
+ .addAllErrors(errors.asJava)
.build()
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index ef053d638701..1e99642f2e34 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -31,7 +31,7 @@ import org.scalatest.Assertions._
import org.scalatest.exceptions.TestFailedException
import org.scalatest.prop.TableDrivenPropertyChecks._
-import org.apache.spark.{SparkConf, SparkException, SparkRuntimeException, SparkUnsupportedOperationException, TaskContext}
+import org.apache.spark.{SparkConf, SparkRuntimeException, SparkUnsupportedOperationException, TaskContext}
import org.apache.spark.TestUtils.withListener
import org.apache.spark.internal.config.MAX_RESULT_SIZE
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
@@ -1159,7 +1159,7 @@ class DatasetSuite extends QueryTest
assert(namedObservation2.get === expected2)
}
- test("SPARK-55150: observation errors are threw in Obseravtion.get in
classic mode") {
+ test("SPARK-55150: observation errors are thrown in Observation.get in
classic mode") {
val observation = Observation("test_observation")
val observed_df = spark.range(10).observe(
observation,
@@ -1167,13 +1167,14 @@ class DatasetSuite extends QueryTest
raise_error(lit("test error")).as("raise_error")
)
- observed_df.collect()
+ val actual = observed_df.collect()
+ assert(actual.toSeq === (0 until 10).map(_.toLong))
- val exception = intercept[SparkException] {
+ val exception = intercept[SparkRuntimeException] {
observation.get
}
- assert(exception.getCause.getMessage.contains("test error"))
+ assert(exception.getMessage.contains("test error"))
}
test("sample with replacement") {