This is an automated email from the ASF dual-hosted git repository.
hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 0ccaacf54ce3 [SPARK-54194][CONNECT] Spark Connect Proto Plan
Compression
0ccaacf54ce3 is described below
commit 0ccaacf54ce33db16f0b675e81a61b9b6a595479
Author: Xi Lyu <[email protected]>
AuthorDate: Tue Nov 11 09:11:04 2025 -0400
[SPARK-54194][CONNECT] Spark Connect Proto Plan Compression
### What changes were proposed in this pull request?
Currently, Spark Connect enforces gRPC message limits on both the client and
the server. These limits are largely meant to protect the server from potential
OOMs by rejecting abnormally large messages. However, there are several cases
where genuine messages exceed the limit and cause execution failures.
To improve Spark Connect stability, this PR implements compressing
unresolved proto plans to mitigate the issue of oversized messages from the
client to the server. The compression applies to ExecutePlan and AnalyzePlan -
the only two methods that might hit the message limit. Oversized messages
from the server to the client are a separate issue and out of scope here
(already fixed in https://github.com/apache/spark/pull/52271).
In the implementation,
* Zstandard is leveraged to compress the proto plan, as it showed
consistently high performance in our benchmarks and achieves a good balance
between compression ratio and speed.
* The config `spark.connect.maxPlanSize` is introduced to control the
maximum size of a (decompressed) proto plan that can be executed in Spark
Connect. It is mainly used to avoid decompression bomb attacks.
(Scala client changes are being implemented in a follow-up PR.)
To reproduce the existing issue we are solving here, run this code on Spark
Connect:
```
import random
import string

def random_letters(length: int) -> str:
    return ''.join(random.choices(string.ascii_letters, k=length))

num_unique_small_relations = 5
size_per_small_relation = 512 * 1024

small_dfs = [
    spark.createDataFrame([(random_letters(size_per_small_relation),)])
    for _ in range(num_unique_small_relations)
]

result_df = small_dfs[0]
for _ in range(512):
    result_df = result_df.unionByName(
        small_dfs[random.randint(0, len(small_dfs) - 1)]
    )
result_df.collect()
```
It fails with a `StatusCode.RESOURCE_EXHAUSTED` error and the message `Sent
message larger than max (269178955 vs. 134217728)`, because the client tried
to send a message larger than the configured limit to the server.
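For reference, the 134217728 bytes in that error is a 128 MiB gRPC message cap. A raw gRPC client can raise its own send limit via the standard `grpc.max_send_message_length` channel option (the endpoint and value below are illustrative, and this is not what the PR does), but that alone does not help because the server enforces its own receive limit as well, which is why compressing the plan is the better fix:

```python
import grpc

# 134217728 bytes from the error message is exactly 128 MiB.
assert 128 * 1024 * 1024 == 134217728

# Standard gRPC channel options; raising only the client-side send cap
# still leaves the server's receive limit in place.
channel = grpc.insecure_channel(
    "localhost:15002",  # illustrative Spark Connect endpoint
    options=[("grpc.max_send_message_length", 256 * 1024 * 1024)],
)
channel.close()
```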
Note: repeated small local relations are just one way to produce a large
plan; plan size can also come from repeated subtrees of plan
transformations, serialized UDFs, external variables captured by UDFs, etc.
With the improvement introduced by the PR, the above code runs successfully
and prints the expected result.
### Why are the changes needed?
It improves Spark Connect stability when executing and analyzing large
plans.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
New tests on both the server side and the client side.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #52894 from xi-db/plan-compression.
Authored-by: Xi Lyu <[email protected]>
Signed-off-by: Herman van Hovell <[email protected]>
---
.../src/main/resources/error/error-conditions.json | 22 ++
dev/requirements.txt | 1 +
dev/spark-test-image/lint/Dockerfile | 1 +
dev/spark-test-image/numpy-213/Dockerfile | 2 +-
dev/spark-test-image/python-310/Dockerfile | 2 +-
dev/spark-test-image/python-311/Dockerfile | 2 +-
dev/spark-test-image/python-312/Dockerfile | 2 +-
dev/spark-test-image/python-313-nogil/Dockerfile | 2 +-
dev/spark-test-image/python-313/Dockerfile | 2 +-
dev/spark-test-image/python-314/Dockerfile | 2 +-
dev/spark-test-image/python-minimum/Dockerfile | 2 +-
dev/spark-test-image/python-ps-minimum/Dockerfile | 2 +-
python/packaging/classic/setup.py | 3 +
python/packaging/client/setup.py | 2 +
python/packaging/connect/setup.py | 2 +
python/pyspark/sql/connect/client/core.py | 131 ++++++-
python/pyspark/sql/connect/plan.py | 3 +-
python/pyspark/sql/connect/proto/base_pb2.py | 384 +++++++++++----------
python/pyspark/sql/connect/proto/base_pb2.pyi | 93 ++++-
python/pyspark/sql/connect/utils.py | 17 +
.../sql/tests/connect/test_connect_basic.py | 27 ++
.../pyspark/sql/tests/connect/test_connect_plan.py | 2 +-
python/pyspark/testing/connectutils.py | 13 +
.../src/main/protobuf/spark/connect/base.proto | 26 +-
.../apache/spark/sql/connect/config/Connect.scala | 31 ++
.../config/ConnectPlanCompressionAlgorithm.scala | 21 ++
.../service/SparkConnectAnalyzeHandler.scala | 22 +-
.../service/SparkConnectExecutePlanHandler.scala | 13 +-
.../sql/connect/utils/PlanCompressionUtils.scala | 118 +++++++
.../service/SparkConnectServiceE2ESuite.scala | 162 +++++++++
30 files changed, 889 insertions(+), 223 deletions(-)
diff --git a/common/utils/src/main/resources/error/error-conditions.json
b/common/utils/src/main/resources/error/error-conditions.json
index a13a4694c668..57ed891087f2 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -906,6 +906,28 @@
},
"sqlState" : "56K00"
},
+ "CONNECT_INVALID_PLAN" : {
+ "message" : [
+ "The Spark Connect plan is invalid."
+ ],
+ "subClass" : {
+ "CANNOT_PARSE" : {
+ "message" : [
+ "Cannot decompress or parse the input plan (<errorMsg>)",
+ "This may be caused by a corrupted compressed plan.",
+ "To disable plan compression, set
'spark.connect.session.planCompression.threshold' to -1."
+ ]
+ },
+ "PLAN_SIZE_LARGER_THAN_MAX" : {
+ "message" : [
+ "The plan size is larger than max (<planSize> vs. <maxPlanSize>)",
+ "This typically occurs when building very complex queries with many
operations, large literals, or deeply nested expressions.",
+ "Consider splitting the query into smaller parts using temporary
views for intermediate results or reducing the number of operations."
+ ]
+ }
+ },
+ "sqlState" : "56K00"
+ },
"CONNECT_ML" : {
"message" : [
"Generic Spark Connect ML error."
diff --git a/dev/requirements.txt b/dev/requirements.txt
index ddaeb9b3dd9d..cde0957715bf 100644
--- a/dev/requirements.txt
+++ b/dev/requirements.txt
@@ -65,6 +65,7 @@ grpcio>=1.76.0
grpcio-status>=1.76.0
googleapis-common-protos>=1.71.0
protobuf==6.33.0
+zstandard>=0.25.0
# Spark Connect python proto generation plugin (optional)
mypy-protobuf==3.3.0
diff --git a/dev/spark-test-image/lint/Dockerfile
b/dev/spark-test-image/lint/Dockerfile
index 4dfceae63a17..6ab571bf35d6 100644
--- a/dev/spark-test-image/lint/Dockerfile
+++ b/dev/spark-test-image/lint/Dockerfile
@@ -84,6 +84,7 @@ RUN python3.11 -m pip install \
'grpc-stubs==1.24.11' \
'grpcio-status==1.76.0' \
'grpcio==1.76.0' \
+ 'zstandard==0.25.0' \
'ipython' \
'ipython_genutils' \
'jinja2' \
diff --git a/dev/spark-test-image/numpy-213/Dockerfile
b/dev/spark-test-image/numpy-213/Dockerfile
index bc9a507853c2..713e9e7d7ef4 100644
--- a/dev/spark-test-image/numpy-213/Dockerfile
+++ b/dev/spark-test-image/numpy-213/Dockerfile
@@ -71,7 +71,7 @@ RUN apt-get update && apt-get install -y \
# Pin numpy==2.1.3
ARG BASIC_PIP_PKGS="numpy==2.1.3 pyarrow>=22.0.0 six==1.16.0 pandas==2.2.3
scipy plotly<6.0.0 mlflow>=2.8.1 coverage matplotlib openpyxl
memory-profiler>=0.61.0 scikit-learn>=1.3.2"
# Python deps for Spark Connect
-ARG CONNECT_PIP_PKGS="grpcio==1.76.0 grpcio-status==1.76.0 protobuf==6.33.0
googleapis-common-protos==1.71.0 graphviz==0.20.3"
+ARG CONNECT_PIP_PKGS="grpcio==1.76.0 grpcio-status==1.76.0 protobuf==6.33.0
googleapis-common-protos==1.71.0 zstandard==0.25.0 graphviz==0.20.3"
# Install Python 3.11 packages
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.11
diff --git a/dev/spark-test-image/python-310/Dockerfile
b/dev/spark-test-image/python-310/Dockerfile
index c318a615b7e0..9b5b18d061c2 100644
--- a/dev/spark-test-image/python-310/Dockerfile
+++ b/dev/spark-test-image/python-310/Dockerfile
@@ -66,7 +66,7 @@ RUN apt-get update && apt-get install -y \
ARG BASIC_PIP_PKGS="numpy pyarrow>=22.0.0 six==1.16.0 pandas==2.3.3 scipy
plotly<6.0.0 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0
scikit-learn>=1.3.2"
# Python deps for Spark Connect
-ARG CONNECT_PIP_PKGS="grpcio==1.76.0 grpcio-status==1.76.0 protobuf==6.33.0
googleapis-common-protos==1.71.0 graphviz==0.20.3"
+ARG CONNECT_PIP_PKGS="grpcio==1.76.0 grpcio-status==1.76.0 protobuf==6.33.0
googleapis-common-protos==1.71.0 zstandard==0.25.0 graphviz==0.20.3"
# Install Python 3.10 packages
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.10
diff --git a/dev/spark-test-image/python-311/Dockerfile
b/dev/spark-test-image/python-311/Dockerfile
index 69d47e62774a..f8a9df5842ce 100644
--- a/dev/spark-test-image/python-311/Dockerfile
+++ b/dev/spark-test-image/python-311/Dockerfile
@@ -70,7 +70,7 @@ RUN apt-get update && apt-get install -y \
ARG BASIC_PIP_PKGS="numpy pyarrow>=22.0.0 six==1.16.0 pandas==2.3.3 scipy
plotly<6.0.0 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0
scikit-learn>=1.3.2"
# Python deps for Spark Connect
-ARG CONNECT_PIP_PKGS="grpcio==1.76.0 grpcio-status==1.76.0 protobuf==6.33.0
googleapis-common-protos==1.71.0 graphviz==0.20.3"
+ARG CONNECT_PIP_PKGS="grpcio==1.76.0 grpcio-status==1.76.0 protobuf==6.33.0
googleapis-common-protos==1.71.0 zstandard==0.25.0 graphviz==0.20.3"
# Install Python 3.11 packages
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.11
diff --git a/dev/spark-test-image/python-312/Dockerfile
b/dev/spark-test-image/python-312/Dockerfile
index 0c8b816f8629..ca62bc5ebc61 100644
--- a/dev/spark-test-image/python-312/Dockerfile
+++ b/dev/spark-test-image/python-312/Dockerfile
@@ -70,7 +70,7 @@ RUN apt-get update && apt-get install -y \
ARG BASIC_PIP_PKGS="numpy pyarrow>=22.0.0 six==1.16.0 pandas==2.3.3 scipy
plotly<6.0.0 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0
scikit-learn>=1.3.2"
# Python deps for Spark Connect
-ARG CONNECT_PIP_PKGS="grpcio==1.76.0 grpcio-status==1.76.0 protobuf==6.33.0
googleapis-common-protos==1.71.0 graphviz==0.20.3"
+ARG CONNECT_PIP_PKGS="grpcio==1.76.0 grpcio-status==1.76.0 protobuf==6.33.0
googleapis-common-protos==1.71.0 zstandard==0.25.0 graphviz==0.20.3"
# Install Python 3.12 packages
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.12
diff --git a/dev/spark-test-image/python-313-nogil/Dockerfile
b/dev/spark-test-image/python-313-nogil/Dockerfile
index 1262089f43e1..b6e2dd7c80a9 100644
--- a/dev/spark-test-image/python-313-nogil/Dockerfile
+++ b/dev/spark-test-image/python-313-nogil/Dockerfile
@@ -69,7 +69,7 @@ RUN apt-get update && apt-get install -y \
ARG BASIC_PIP_PKGS="numpy pyarrow>=22.0.0 six==1.16.0 pandas==2.3.3 scipy
plotly<6.0.0 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0
scikit-learn>=1.3.2"
-ARG CONNECT_PIP_PKGS="grpcio==1.76.0 grpcio-status==1.76.0 protobuf==6.33.0
googleapis-common-protos==1.71.0 graphviz==0.20.3"
+ARG CONNECT_PIP_PKGS="grpcio==1.76.0 grpcio-status==1.76.0 protobuf==6.33.0
googleapis-common-protos==1.71.0 zstandard==0.25.0 graphviz==0.20.3"
# Install Python 3.13 packages
diff --git a/dev/spark-test-image/python-313/Dockerfile
b/dev/spark-test-image/python-313/Dockerfile
index 2e4dde33077d..bd64ecb31087 100644
--- a/dev/spark-test-image/python-313/Dockerfile
+++ b/dev/spark-test-image/python-313/Dockerfile
@@ -70,7 +70,7 @@ RUN apt-get update && apt-get install -y \
ARG BASIC_PIP_PKGS="numpy pyarrow>=22.0.0 six==1.16.0 pandas==2.3.3 scipy
plotly<6.0.0 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0
scikit-learn>=1.3.2"
# Python deps for Spark Connect
-ARG CONNECT_PIP_PKGS="grpcio==1.76.0 grpcio-status==1.76.0 protobuf==6.33.0
googleapis-common-protos==1.71.0 graphviz==0.20.3"
+ARG CONNECT_PIP_PKGS="grpcio==1.76.0 grpcio-status==1.76.0 protobuf==6.33.0
googleapis-common-protos==1.71.0 zstandard==0.25.0 graphviz==0.20.3"
# Install Python 3.13 packages
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.13
diff --git a/dev/spark-test-image/python-314/Dockerfile
b/dev/spark-test-image/python-314/Dockerfile
index 07916fc35a0d..f3da21e005b3 100644
--- a/dev/spark-test-image/python-314/Dockerfile
+++ b/dev/spark-test-image/python-314/Dockerfile
@@ -70,7 +70,7 @@ RUN apt-get update && apt-get install -y \
ARG BASIC_PIP_PKGS="numpy pyarrow>=22.0.0 six==1.16.0 pandas==2.3.3 scipy
plotly<6.0.0 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0
scikit-learn>=1.3.2"
# Python deps for Spark Connect
-ARG CONNECT_PIP_PKGS="grpcio==1.76.0 grpcio-status==1.76.0 protobuf==6.33.0
googleapis-common-protos==1.71.0 graphviz==0.20.3"
+ARG CONNECT_PIP_PKGS="grpcio==1.76.0 grpcio-status==1.76.0 protobuf==6.33.0
googleapis-common-protos==1.71.0 zstandard==0.25.0 graphviz==0.20.3"
# Install Python 3.14 packages
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.14
diff --git a/dev/spark-test-image/python-minimum/Dockerfile
b/dev/spark-test-image/python-minimum/Dockerfile
index 627dccdf34b1..575b4afdd02c 100644
--- a/dev/spark-test-image/python-minimum/Dockerfile
+++ b/dev/spark-test-image/python-minimum/Dockerfile
@@ -64,7 +64,7 @@ RUN apt-get update && apt-get install -y \
ARG BASIC_PIP_PKGS="numpy==1.22.4 pyarrow==15.0.0 pandas==2.2.0 six==1.16.0
scipy scikit-learn coverage unittest-xml-reporting"
# Python deps for Spark Connect
-ARG CONNECT_PIP_PKGS="grpcio==1.76.0 grpcio-status==1.76.0
googleapis-common-protos==1.71.0 graphviz==0.20 protobuf"
+ARG CONNECT_PIP_PKGS="grpcio==1.76.0 grpcio-status==1.76.0
googleapis-common-protos==1.71.0 zstandard==0.25.0 graphviz==0.20 protobuf"
# Install Python 3.10 packages
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.10
diff --git a/dev/spark-test-image/python-ps-minimum/Dockerfile
b/dev/spark-test-image/python-ps-minimum/Dockerfile
index 13a5f2db386c..5142d46cc3eb 100644
--- a/dev/spark-test-image/python-ps-minimum/Dockerfile
+++ b/dev/spark-test-image/python-ps-minimum/Dockerfile
@@ -65,7 +65,7 @@ RUN apt-get update && apt-get install -y \
ARG BASIC_PIP_PKGS="pyarrow==15.0.0 pandas==2.2.0 six==1.16.0 numpy scipy
coverage unittest-xml-reporting"
# Python deps for Spark Connect
-ARG CONNECT_PIP_PKGS="grpcio==1.76.0 grpcio-status==1.76.0
googleapis-common-protos==1.71.0 graphviz==0.20 protobuf"
+ARG CONNECT_PIP_PKGS="grpcio==1.76.0 grpcio-status==1.76.0
googleapis-common-protos==1.71.0 zstandard==0.25.0 graphviz==0.20 protobuf"
# Install Python 3.10 packages
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.10
diff --git a/python/packaging/classic/setup.py
b/python/packaging/classic/setup.py
index e6ac729f20d6..54ec4abe3be9 100755
--- a/python/packaging/classic/setup.py
+++ b/python/packaging/classic/setup.py
@@ -156,6 +156,7 @@ _minimum_pyarrow_version = "15.0.0"
_minimum_grpc_version = "1.76.0"
_minimum_googleapis_common_protos_version = "1.71.0"
_minimum_pyyaml_version = "3.11"
+_minimum_zstandard_version = "0.25.0"
class InstallCommand(install):
@@ -366,6 +367,7 @@ try:
"grpcio>=%s" % _minimum_grpc_version,
"grpcio-status>=%s" % _minimum_grpc_version,
"googleapis-common-protos>=%s" %
_minimum_googleapis_common_protos_version,
+ "zstandard>=%s" % _minimum_zstandard_version,
"numpy>=%s" % _minimum_numpy_version,
],
"pipelines": [
@@ -375,6 +377,7 @@ try:
"grpcio>=%s" % _minimum_grpc_version,
"grpcio-status>=%s" % _minimum_grpc_version,
"googleapis-common-protos>=%s" %
_minimum_googleapis_common_protos_version,
+ "zstandard>=%s" % _minimum_zstandard_version,
"pyyaml>=%s" % _minimum_pyyaml_version,
],
},
diff --git a/python/packaging/client/setup.py b/python/packaging/client/setup.py
index b9dfce056a06..ee404210f293 100755
--- a/python/packaging/client/setup.py
+++ b/python/packaging/client/setup.py
@@ -139,6 +139,7 @@ try:
_minimum_grpc_version = "1.76.0"
_minimum_googleapis_common_protos_version = "1.71.0"
_minimum_pyyaml_version = "3.11"
+ _minimum_zstandard_version = "0.25.0"
with open("README.md") as f:
long_description = f.read()
@@ -211,6 +212,7 @@ try:
"grpcio>=%s" % _minimum_grpc_version,
"grpcio-status>=%s" % _minimum_grpc_version,
"googleapis-common-protos>=%s" %
_minimum_googleapis_common_protos_version,
+ "zstandard>=%s" % _minimum_zstandard_version,
"numpy>=%s" % _minimum_numpy_version,
"pyyaml>=%s" % _minimum_pyyaml_version,
],
diff --git a/python/packaging/connect/setup.py
b/python/packaging/connect/setup.py
index 03915f286cee..9a1a4ea81255 100755
--- a/python/packaging/connect/setup.py
+++ b/python/packaging/connect/setup.py
@@ -92,6 +92,7 @@ try:
_minimum_grpc_version = "1.76.0"
_minimum_googleapis_common_protos_version = "1.71.0"
_minimum_pyyaml_version = "3.11"
+ _minimum_zstandard_version = "0.25.0"
with open("README.md") as f:
long_description = f.read()
@@ -121,6 +122,7 @@ try:
"grpcio>=%s" % _minimum_grpc_version,
"grpcio-status>=%s" % _minimum_grpc_version,
"googleapis-common-protos>=%s" %
_minimum_googleapis_common_protos_version,
+ "zstandard>=%s" % _minimum_zstandard_version,
"numpy>=%s" % _minimum_numpy_version,
"pyyaml>=%s" % _minimum_pyyaml_version,
],
diff --git a/python/pyspark/sql/connect/client/core.py
b/python/pyspark/sql/connect/client/core.py
index 6f92af394913..6bff531c23d4 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -34,6 +34,7 @@ import platform
import urllib.parse
import uuid
import sys
+import time
from typing import (
Iterable,
Iterator,
@@ -113,6 +114,19 @@ if TYPE_CHECKING:
from pyspark.sql.datasource import DataSource
+def _import_zstandard_if_available() -> Optional[Any]:
+ """
+ Import zstandard if available, otherwise return None.
+ This is used to handle the case when zstandard is not installed.
+ """
+ try:
+ import zstandard
+
+ return zstandard
+ except ImportError:
+ return None
+
+
class ChannelBuilder:
"""
This is a helper class that is used to create a GRPC channel based on the
given
@@ -706,6 +720,10 @@ class SparkConnectClient(object):
self._progress_handlers: List[ProgressHandler] = []
+ self._zstd_module = _import_zstandard_if_available()
+ self._plan_compression_threshold: Optional[int] = None # Will be
fetched lazily
+ self._plan_compression_algorithm: Optional[str] = None # Will be
fetched lazily
+
# cleanup ml cache if possible
atexit.register(self._cleanup_ml_cache)
@@ -1156,7 +1174,7 @@ class SparkConnectClient(object):
req = self._execute_plan_request_with_metadata()
if self._user_id:
req.user_context.user_id = self._user_id
- req.plan.command.CopyFrom(command)
+ self._set_command_in_plan(req.plan, command)
data, _, metrics, observed_metrics, properties =
self._execute_and_fetch(
req, observations or {}
)
@@ -1182,7 +1200,7 @@ class SparkConnectClient(object):
req = self._execute_plan_request_with_metadata()
if self._user_id:
req.user_context.user_id = self._user_id
- req.plan.command.CopyFrom(command)
+ self._set_command_in_plan(req.plan, command)
for response in self._execute_and_fetch_as_iterator(req, observations
or {}):
if isinstance(response, dict):
yield response
@@ -1963,6 +1981,17 @@ class SparkConnectClient(object):
if info.metadata.get("errorClass") ==
"INVALID_HANDLE.SESSION_CHANGED":
self._closed = True
+ if info.metadata.get("errorClass") ==
"CONNECT_INVALID_PLAN.CANNOT_PARSE":
+ # Disable plan compression if the server fails to
interpret the plan.
+ logger.info(
+ "Disabling plan compression for the session due to
"
+ "CONNECT_INVALID_PLAN.CANNOT_PARSE error."
+ )
+ self._plan_compression_threshold,
self._plan_compression_algorithm = (
+ -1,
+ "NONE",
+ )
+
raise convert_exception(
info,
status.message,
@@ -2112,6 +2141,104 @@ class SparkConnectClient(object):
ml_command_result = properties["ml_command_result"]
return ml_command_result.param.long
+ def _set_relation_in_plan(self, plan: pb2.Plan, relation: pb2.Relation) ->
None:
+ """Sets the relation in the plan, attempting compression if
configured."""
+ self._try_compress_and_set_plan(
+ plan=plan,
+ message=relation,
+ op_type=pb2.Plan.CompressedOperation.OpType.OP_TYPE_RELATION,
+ )
+
+ def _set_command_in_plan(self, plan: pb2.Plan, command: pb2.Command) ->
None:
+ """Sets the command in the plan, attempting compression if
configured."""
+ self._try_compress_and_set_plan(
+ plan=plan,
+ message=command,
+ op_type=pb2.Plan.CompressedOperation.OpType.OP_TYPE_COMMAND,
+ )
+
+ def _try_compress_and_set_plan(
+ self,
+ plan: pb2.Plan,
+ message: google.protobuf.message.Message,
+ op_type: pb2.Plan.CompressedOperation.OpType.ValueType,
+ ) -> None:
+ """
+ Tries to compress a protobuf message and sets it on the plan.
+ If compression is not enabled, not effective, or not available,
+ it falls back to the original message.
+ """
+ (
+ plan_compression_threshold,
+ plan_compression_algorithm,
+ ) = self._get_plan_compression_threshold_and_algorithm()
+ plan_compression_enabled = (
+ plan_compression_threshold is not None
+ and plan_compression_threshold >= 0
+ and plan_compression_algorithm is not None
+ and plan_compression_algorithm != "NONE"
+ )
+ if plan_compression_enabled:
+ serialized_msg = message.SerializeToString()
+ original_size = len(serialized_msg)
+ if (
+ original_size > plan_compression_threshold
+ and plan_compression_algorithm == "ZSTD"
+ and self._zstd_module
+ ):
+ start_time = time.time()
+ compressed_operation = pb2.Plan.CompressedOperation(
+ data=self._zstd_module.compress(serialized_msg),
+ op_type=op_type,
+
compression_codec=pb2.CompressionCodec.COMPRESSION_CODEC_ZSTD,
+ )
+ duration = time.time() - start_time
+ compressed_size = len(compressed_operation.data)
+ logger.debug(
+ f"Plan compression: original_size={original_size}, "
+ f"compressed_size={compressed_size}, "
+ f"saving_ratio={1 - compressed_size / original_size:.2f}, "
+ f"duration_s={duration:.1f}"
+ )
+ if compressed_size < original_size:
+ plan.compressed_operation.CopyFrom(compressed_operation)
+ return
+ else:
+ logger.debug("Plan compression not effective. Using
original plan.")
+
+ if op_type == pb2.Plan.CompressedOperation.OpType.OP_TYPE_RELATION:
+ plan.root.CopyFrom(message) # type: ignore[arg-type]
+ else:
+ plan.command.CopyFrom(message) # type: ignore[arg-type]
+
+ def _get_plan_compression_threshold_and_algorithm(self) -> Tuple[int, str]:
+ if self._plan_compression_threshold is None or
self._plan_compression_algorithm is None:
+ try:
+ (
+ plan_compression_threshold_str,
+ self._plan_compression_algorithm,
+ ) = self.get_configs(
+ "spark.connect.session.planCompression.threshold",
+ "spark.connect.session.planCompression.defaultAlgorithm",
+ )
+ self._plan_compression_threshold = (
+ int(plan_compression_threshold_str) if
plan_compression_threshold_str else -1
+ )
+ logger.debug(
+ f"Plan compression threshold:
{self._plan_compression_threshold}, "
+ f"algorithm: {self._plan_compression_algorithm}"
+ )
+ except Exception as e:
+ self._plan_compression_threshold = -1
+ self._plan_compression_algorithm = "NONE"
+ logger.debug(
+ "Plan compression is disabled because the server does not
support it.", e
+ )
+ return (
+ self._plan_compression_threshold,
+ self._plan_compression_algorithm,
+ ) # type: ignore[return-value]
+
def clone(self, new_session_id: Optional[str] = None) ->
"SparkConnectClient":
"""
Clone this client session on the server side. The server-side session
is cloned with
diff --git a/python/pyspark/sql/connect/plan.py
b/python/pyspark/sql/connect/plan.py
index 82a6326c7dc5..02fe7176b6fe 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -143,7 +143,8 @@ class LogicalPlan:
if enabled, the proto plan will be printed.
"""
plan = proto.Plan()
- plan.root.CopyFrom(self.plan(session))
+ relation = self.plan(session)
+ session._set_relation_in_plan(plan, relation)
if debug:
print(plan)
diff --git a/python/pyspark/sql/connect/proto/base_pb2.py
b/python/pyspark/sql/connect/proto/base_pb2.py
index 32bf6802df7b..32b2840dffad 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.py
+++ b/python/pyspark/sql/connect/proto/base_pb2.py
@@ -45,7 +45,7 @@ from pyspark.sql.connect.proto import pipelines_pb2 as
spark_dot_connect_dot_pip
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-
b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto\x1a\x16spark/connect/ml.proto\x1a\x1dspark/connect/pipelines.proto"t\n\x04Plan\x12-\n\x04root\x18\x01
\x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02
\x01(\x0b\x32\x16.spark.connect.Com [...]
+
b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto\x1a\x16spark/connect/ml.proto\x1a\x1dspark/connect/pipelines.proto"\xe3\x03\n\x04Plan\x12-\n\x04root\x18\x01
\x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02
\x01(\x0b\x32\x16.spark.conn [...]
)
_globals = globals()
@@ -70,200 +70,206 @@ if not _descriptor._USE_C_DESCRIPTORS:
_globals[
"_FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY"
]._serialized_options = b"8\001"
- _globals["_PLAN"]._serialized_start = 274
- _globals["_PLAN"]._serialized_end = 390
- _globals["_USERCONTEXT"]._serialized_start = 392
- _globals["_USERCONTEXT"]._serialized_end = 514
- _globals["_ANALYZEPLANREQUEST"]._serialized_start = 517
- _globals["_ANALYZEPLANREQUEST"]._serialized_end = 3194
- _globals["_ANALYZEPLANREQUEST_SCHEMA"]._serialized_start = 1879
- _globals["_ANALYZEPLANREQUEST_SCHEMA"]._serialized_end = 1928
- _globals["_ANALYZEPLANREQUEST_EXPLAIN"]._serialized_start = 1931
- _globals["_ANALYZEPLANREQUEST_EXPLAIN"]._serialized_end = 2246
- _globals["_ANALYZEPLANREQUEST_EXPLAIN_EXPLAINMODE"]._serialized_start =
2074
- _globals["_ANALYZEPLANREQUEST_EXPLAIN_EXPLAINMODE"]._serialized_end = 2246
- _globals["_ANALYZEPLANREQUEST_TREESTRING"]._serialized_start = 2248
- _globals["_ANALYZEPLANREQUEST_TREESTRING"]._serialized_end = 2338
- _globals["_ANALYZEPLANREQUEST_ISLOCAL"]._serialized_start = 2340
- _globals["_ANALYZEPLANREQUEST_ISLOCAL"]._serialized_end = 2390
- _globals["_ANALYZEPLANREQUEST_ISSTREAMING"]._serialized_start = 2392
- _globals["_ANALYZEPLANREQUEST_ISSTREAMING"]._serialized_end = 2446
- _globals["_ANALYZEPLANREQUEST_INPUTFILES"]._serialized_start = 2448
- _globals["_ANALYZEPLANREQUEST_INPUTFILES"]._serialized_end = 2501
- _globals["_ANALYZEPLANREQUEST_SPARKVERSION"]._serialized_start = 2503
- _globals["_ANALYZEPLANREQUEST_SPARKVERSION"]._serialized_end = 2517
- _globals["_ANALYZEPLANREQUEST_DDLPARSE"]._serialized_start = 2519
- _globals["_ANALYZEPLANREQUEST_DDLPARSE"]._serialized_end = 2560
- _globals["_ANALYZEPLANREQUEST_SAMESEMANTICS"]._serialized_start = 2562
- _globals["_ANALYZEPLANREQUEST_SAMESEMANTICS"]._serialized_end = 2683
- _globals["_ANALYZEPLANREQUEST_SEMANTICHASH"]._serialized_start = 2685
- _globals["_ANALYZEPLANREQUEST_SEMANTICHASH"]._serialized_end = 2740
- _globals["_ANALYZEPLANREQUEST_PERSIST"]._serialized_start = 2743
- _globals["_ANALYZEPLANREQUEST_PERSIST"]._serialized_end = 2894
- _globals["_ANALYZEPLANREQUEST_UNPERSIST"]._serialized_start = 2896
- _globals["_ANALYZEPLANREQUEST_UNPERSIST"]._serialized_end = 3006
- _globals["_ANALYZEPLANREQUEST_GETSTORAGELEVEL"]._serialized_start = 3008
- _globals["_ANALYZEPLANREQUEST_GETSTORAGELEVEL"]._serialized_end = 3078
- _globals["_ANALYZEPLANREQUEST_JSONTODDL"]._serialized_start = 3080
- _globals["_ANALYZEPLANREQUEST_JSONTODDL"]._serialized_end = 3124
- _globals["_ANALYZEPLANRESPONSE"]._serialized_start = 3197
- _globals["_ANALYZEPLANRESPONSE"]._serialized_end = 5063
- _globals["_ANALYZEPLANRESPONSE_SCHEMA"]._serialized_start = 4438
- _globals["_ANALYZEPLANRESPONSE_SCHEMA"]._serialized_end = 4495
- _globals["_ANALYZEPLANRESPONSE_EXPLAIN"]._serialized_start = 4497
- _globals["_ANALYZEPLANRESPONSE_EXPLAIN"]._serialized_end = 4545
- _globals["_ANALYZEPLANRESPONSE_TREESTRING"]._serialized_start = 4547
- _globals["_ANALYZEPLANRESPONSE_TREESTRING"]._serialized_end = 4592
- _globals["_ANALYZEPLANRESPONSE_ISLOCAL"]._serialized_start = 4594
- _globals["_ANALYZEPLANRESPONSE_ISLOCAL"]._serialized_end = 4630
- _globals["_ANALYZEPLANRESPONSE_ISSTREAMING"]._serialized_start = 4632
- _globals["_ANALYZEPLANRESPONSE_ISSTREAMING"]._serialized_end = 4680
- _globals["_ANALYZEPLANRESPONSE_INPUTFILES"]._serialized_start = 4682
- _globals["_ANALYZEPLANRESPONSE_INPUTFILES"]._serialized_end = 4716
- _globals["_ANALYZEPLANRESPONSE_SPARKVERSION"]._serialized_start = 4718
- _globals["_ANALYZEPLANRESPONSE_SPARKVERSION"]._serialized_end = 4758
- _globals["_ANALYZEPLANRESPONSE_DDLPARSE"]._serialized_start = 4760
- _globals["_ANALYZEPLANRESPONSE_DDLPARSE"]._serialized_end = 4819
- _globals["_ANALYZEPLANRESPONSE_SAMESEMANTICS"]._serialized_start = 4821
- _globals["_ANALYZEPLANRESPONSE_SAMESEMANTICS"]._serialized_end = 4860
- _globals["_ANALYZEPLANRESPONSE_SEMANTICHASH"]._serialized_start = 4862
- _globals["_ANALYZEPLANRESPONSE_SEMANTICHASH"]._serialized_end = 4900
- _globals["_ANALYZEPLANRESPONSE_PERSIST"]._serialized_start = 2743
- _globals["_ANALYZEPLANRESPONSE_PERSIST"]._serialized_end = 2752
- _globals["_ANALYZEPLANRESPONSE_UNPERSIST"]._serialized_start = 2896
- _globals["_ANALYZEPLANRESPONSE_UNPERSIST"]._serialized_end = 2907
- _globals["_ANALYZEPLANRESPONSE_GETSTORAGELEVEL"]._serialized_start = 4926
- _globals["_ANALYZEPLANRESPONSE_GETSTORAGELEVEL"]._serialized_end = 5009
- _globals["_ANALYZEPLANRESPONSE_JSONTODDL"]._serialized_start = 5011
- _globals["_ANALYZEPLANRESPONSE_JSONTODDL"]._serialized_end = 5053
- _globals["_EXECUTEPLANREQUEST"]._serialized_start = 5066
- _globals["_EXECUTEPLANREQUEST"]._serialized_end = 5837
- _globals["_EXECUTEPLANREQUEST_REQUESTOPTION"]._serialized_start = 5500
- _globals["_EXECUTEPLANREQUEST_REQUESTOPTION"]._serialized_end = 5761
- _globals["_EXECUTEPLANRESPONSE"]._serialized_start = 5840
- _globals["_EXECUTEPLANRESPONSE"]._serialized_end = 9297
- _globals["_EXECUTEPLANRESPONSE_SQLCOMMANDRESULT"]._serialized_start = 7940
- _globals["_EXECUTEPLANRESPONSE_SQLCOMMANDRESULT"]._serialized_end = 8011
- _globals["_EXECUTEPLANRESPONSE_ARROWBATCH"]._serialized_start = 8014
- _globals["_EXECUTEPLANRESPONSE_ARROWBATCH"]._serialized_end = 8262
- _globals["_EXECUTEPLANRESPONSE_METRICS"]._serialized_start = 8265
- _globals["_EXECUTEPLANRESPONSE_METRICS"]._serialized_end = 8782
- _globals["_EXECUTEPLANRESPONSE_METRICS_METRICOBJECT"]._serialized_start =
8360
- _globals["_EXECUTEPLANRESPONSE_METRICS_METRICOBJECT"]._serialized_end =
8692
+ _globals["_COMPRESSIONCODEC"]._serialized_start = 18571
+ _globals["_COMPRESSIONCODEC"]._serialized_end = 18652
+ _globals["_PLAN"]._serialized_start = 275
+ _globals["_PLAN"]._serialized_end = 758
+ _globals["_PLAN_COMPRESSEDOPERATION"]._serialized_start = 477
+ _globals["_PLAN_COMPRESSEDOPERATION"]._serialized_end = 747
+ _globals["_PLAN_COMPRESSEDOPERATION_OPTYPE"]._serialized_start = 671
+ _globals["_PLAN_COMPRESSEDOPERATION_OPTYPE"]._serialized_end = 747
+ _globals["_USERCONTEXT"]._serialized_start = 760
+ _globals["_USERCONTEXT"]._serialized_end = 882
+ _globals["_ANALYZEPLANREQUEST"]._serialized_start = 885
+ _globals["_ANALYZEPLANREQUEST"]._serialized_end = 3562
+ _globals["_ANALYZEPLANREQUEST_SCHEMA"]._serialized_start = 2247
+ _globals["_ANALYZEPLANREQUEST_SCHEMA"]._serialized_end = 2296
+ _globals["_ANALYZEPLANREQUEST_EXPLAIN"]._serialized_start = 2299
+ _globals["_ANALYZEPLANREQUEST_EXPLAIN"]._serialized_end = 2614
+ _globals["_ANALYZEPLANREQUEST_EXPLAIN_EXPLAINMODE"]._serialized_start =
2442
+ _globals["_ANALYZEPLANREQUEST_EXPLAIN_EXPLAINMODE"]._serialized_end = 2614
+ _globals["_ANALYZEPLANREQUEST_TREESTRING"]._serialized_start = 2616
+ _globals["_ANALYZEPLANREQUEST_TREESTRING"]._serialized_end = 2706
+ _globals["_ANALYZEPLANREQUEST_ISLOCAL"]._serialized_start = 2708
+ _globals["_ANALYZEPLANREQUEST_ISLOCAL"]._serialized_end = 2758
+ _globals["_ANALYZEPLANREQUEST_ISSTREAMING"]._serialized_start = 2760
+ _globals["_ANALYZEPLANREQUEST_ISSTREAMING"]._serialized_end = 2814
+ _globals["_ANALYZEPLANREQUEST_INPUTFILES"]._serialized_start = 2816
+ _globals["_ANALYZEPLANREQUEST_INPUTFILES"]._serialized_end = 2869
+ _globals["_ANALYZEPLANREQUEST_SPARKVERSION"]._serialized_start = 2871
+ _globals["_ANALYZEPLANREQUEST_SPARKVERSION"]._serialized_end = 2885
+ _globals["_ANALYZEPLANREQUEST_DDLPARSE"]._serialized_start = 2887
+ _globals["_ANALYZEPLANREQUEST_DDLPARSE"]._serialized_end = 2928
+ _globals["_ANALYZEPLANREQUEST_SAMESEMANTICS"]._serialized_start = 2930
+ _globals["_ANALYZEPLANREQUEST_SAMESEMANTICS"]._serialized_end = 3051
+ _globals["_ANALYZEPLANREQUEST_SEMANTICHASH"]._serialized_start = 3053
+ _globals["_ANALYZEPLANREQUEST_SEMANTICHASH"]._serialized_end = 3108
+ _globals["_ANALYZEPLANREQUEST_PERSIST"]._serialized_start = 3111
+ _globals["_ANALYZEPLANREQUEST_PERSIST"]._serialized_end = 3262
+ _globals["_ANALYZEPLANREQUEST_UNPERSIST"]._serialized_start = 3264
+ _globals["_ANALYZEPLANREQUEST_UNPERSIST"]._serialized_end = 3374
+ _globals["_ANALYZEPLANREQUEST_GETSTORAGELEVEL"]._serialized_start = 3376
+ _globals["_ANALYZEPLANREQUEST_GETSTORAGELEVEL"]._serialized_end = 3446
+ _globals["_ANALYZEPLANREQUEST_JSONTODDL"]._serialized_start = 3448
+ _globals["_ANALYZEPLANREQUEST_JSONTODDL"]._serialized_end = 3492
+ _globals["_ANALYZEPLANRESPONSE"]._serialized_start = 3565
+ _globals["_ANALYZEPLANRESPONSE"]._serialized_end = 5431
+ _globals["_ANALYZEPLANRESPONSE_SCHEMA"]._serialized_start = 4806
+ _globals["_ANALYZEPLANRESPONSE_SCHEMA"]._serialized_end = 4863
+ _globals["_ANALYZEPLANRESPONSE_EXPLAIN"]._serialized_start = 4865
+ _globals["_ANALYZEPLANRESPONSE_EXPLAIN"]._serialized_end = 4913
+ _globals["_ANALYZEPLANRESPONSE_TREESTRING"]._serialized_start = 4915
+ _globals["_ANALYZEPLANRESPONSE_TREESTRING"]._serialized_end = 4960
+ _globals["_ANALYZEPLANRESPONSE_ISLOCAL"]._serialized_start = 4962
+ _globals["_ANALYZEPLANRESPONSE_ISLOCAL"]._serialized_end = 4998
+ _globals["_ANALYZEPLANRESPONSE_ISSTREAMING"]._serialized_start = 5000
+ _globals["_ANALYZEPLANRESPONSE_ISSTREAMING"]._serialized_end = 5048
+ _globals["_ANALYZEPLANRESPONSE_INPUTFILES"]._serialized_start = 5050
+ _globals["_ANALYZEPLANRESPONSE_INPUTFILES"]._serialized_end = 5084
+ _globals["_ANALYZEPLANRESPONSE_SPARKVERSION"]._serialized_start = 5086
+ _globals["_ANALYZEPLANRESPONSE_SPARKVERSION"]._serialized_end = 5126
+ _globals["_ANALYZEPLANRESPONSE_DDLPARSE"]._serialized_start = 5128
+ _globals["_ANALYZEPLANRESPONSE_DDLPARSE"]._serialized_end = 5187
+ _globals["_ANALYZEPLANRESPONSE_SAMESEMANTICS"]._serialized_start = 5189
+ _globals["_ANALYZEPLANRESPONSE_SAMESEMANTICS"]._serialized_end = 5228
+ _globals["_ANALYZEPLANRESPONSE_SEMANTICHASH"]._serialized_start = 5230
+ _globals["_ANALYZEPLANRESPONSE_SEMANTICHASH"]._serialized_end = 5268
+ _globals["_ANALYZEPLANRESPONSE_PERSIST"]._serialized_start = 3111
+ _globals["_ANALYZEPLANRESPONSE_PERSIST"]._serialized_end = 3120
+ _globals["_ANALYZEPLANRESPONSE_UNPERSIST"]._serialized_start = 3264
+ _globals["_ANALYZEPLANRESPONSE_UNPERSIST"]._serialized_end = 3275
+ _globals["_ANALYZEPLANRESPONSE_GETSTORAGELEVEL"]._serialized_start = 5294
+ _globals["_ANALYZEPLANRESPONSE_GETSTORAGELEVEL"]._serialized_end = 5377
+ _globals["_ANALYZEPLANRESPONSE_JSONTODDL"]._serialized_start = 5379
+ _globals["_ANALYZEPLANRESPONSE_JSONTODDL"]._serialized_end = 5421
+ _globals["_EXECUTEPLANREQUEST"]._serialized_start = 5434
+ _globals["_EXECUTEPLANREQUEST"]._serialized_end = 6205
+ _globals["_EXECUTEPLANREQUEST_REQUESTOPTION"]._serialized_start = 5868
+ _globals["_EXECUTEPLANREQUEST_REQUESTOPTION"]._serialized_end = 6129
+ _globals["_EXECUTEPLANRESPONSE"]._serialized_start = 6208
+ _globals["_EXECUTEPLANRESPONSE"]._serialized_end = 9665
+ _globals["_EXECUTEPLANRESPONSE_SQLCOMMANDRESULT"]._serialized_start = 8308
+ _globals["_EXECUTEPLANRESPONSE_SQLCOMMANDRESULT"]._serialized_end = 8379
+ _globals["_EXECUTEPLANRESPONSE_ARROWBATCH"]._serialized_start = 8382
+ _globals["_EXECUTEPLANRESPONSE_ARROWBATCH"]._serialized_end = 8630
+ _globals["_EXECUTEPLANRESPONSE_METRICS"]._serialized_start = 8633
+ _globals["_EXECUTEPLANRESPONSE_METRICS"]._serialized_end = 9150
+ _globals["_EXECUTEPLANRESPONSE_METRICS_METRICOBJECT"]._serialized_start =
8728
+ _globals["_EXECUTEPLANRESPONSE_METRICS_METRICOBJECT"]._serialized_end =
9060
_globals[
"_EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY"
- ]._serialized_start = 8569
+ ]._serialized_start = 8937
_globals[
"_EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY"
- ]._serialized_end = 8692
- _globals["_EXECUTEPLANRESPONSE_METRICS_METRICVALUE"]._serialized_start =
8694
- _globals["_EXECUTEPLANRESPONSE_METRICS_METRICVALUE"]._serialized_end = 8782
- _globals["_EXECUTEPLANRESPONSE_OBSERVEDMETRICS"]._serialized_start = 8785
- _globals["_EXECUTEPLANRESPONSE_OBSERVEDMETRICS"]._serialized_end = 8926
- _globals["_EXECUTEPLANRESPONSE_RESULTCOMPLETE"]._serialized_start = 8928
- _globals["_EXECUTEPLANRESPONSE_RESULTCOMPLETE"]._serialized_end = 8944
- _globals["_EXECUTEPLANRESPONSE_EXECUTIONPROGRESS"]._serialized_start = 8947
- _globals["_EXECUTEPLANRESPONSE_EXECUTIONPROGRESS"]._serialized_end = 9280
-
_globals["_EXECUTEPLANRESPONSE_EXECUTIONPROGRESS_STAGEINFO"]._serialized_start
= 9103
-
_globals["_EXECUTEPLANRESPONSE_EXECUTIONPROGRESS_STAGEINFO"]._serialized_end =
9280
- _globals["_KEYVALUE"]._serialized_start = 9299
- _globals["_KEYVALUE"]._serialized_end = 9364
- _globals["_CONFIGREQUEST"]._serialized_start = 9367
- _globals["_CONFIGREQUEST"]._serialized_end = 10566
- _globals["_CONFIGREQUEST_OPERATION"]._serialized_start = 9675
- _globals["_CONFIGREQUEST_OPERATION"]._serialized_end = 10173
- _globals["_CONFIGREQUEST_SET"]._serialized_start = 10175
- _globals["_CONFIGREQUEST_SET"]._serialized_end = 10267
- _globals["_CONFIGREQUEST_GET"]._serialized_start = 10269
- _globals["_CONFIGREQUEST_GET"]._serialized_end = 10294
- _globals["_CONFIGREQUEST_GETWITHDEFAULT"]._serialized_start = 10296
- _globals["_CONFIGREQUEST_GETWITHDEFAULT"]._serialized_end = 10359
- _globals["_CONFIGREQUEST_GETOPTION"]._serialized_start = 10361
- _globals["_CONFIGREQUEST_GETOPTION"]._serialized_end = 10392
- _globals["_CONFIGREQUEST_GETALL"]._serialized_start = 10394
- _globals["_CONFIGREQUEST_GETALL"]._serialized_end = 10442
- _globals["_CONFIGREQUEST_UNSET"]._serialized_start = 10444
- _globals["_CONFIGREQUEST_UNSET"]._serialized_end = 10471
- _globals["_CONFIGREQUEST_ISMODIFIABLE"]._serialized_start = 10473
- _globals["_CONFIGREQUEST_ISMODIFIABLE"]._serialized_end = 10507
- _globals["_CONFIGRESPONSE"]._serialized_start = 10569
- _globals["_CONFIGRESPONSE"]._serialized_end = 10744
- _globals["_ADDARTIFACTSREQUEST"]._serialized_start = 10747
- _globals["_ADDARTIFACTSREQUEST"]._serialized_end = 11749
- _globals["_ADDARTIFACTSREQUEST_ARTIFACTCHUNK"]._serialized_start = 11222
- _globals["_ADDARTIFACTSREQUEST_ARTIFACTCHUNK"]._serialized_end = 11275
- _globals["_ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT"]._serialized_start =
11277
- _globals["_ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT"]._serialized_end =
11388
- _globals["_ADDARTIFACTSREQUEST_BATCH"]._serialized_start = 11390
- _globals["_ADDARTIFACTSREQUEST_BATCH"]._serialized_end = 11483
- _globals["_ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT"]._serialized_start =
11486
- _globals["_ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT"]._serialized_end =
11679
- _globals["_ADDARTIFACTSRESPONSE"]._serialized_start = 11752
- _globals["_ADDARTIFACTSRESPONSE"]._serialized_end = 12024
- _globals["_ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY"]._serialized_start = 11943
- _globals["_ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY"]._serialized_end = 12024
- _globals["_ARTIFACTSTATUSESREQUEST"]._serialized_start = 12027
- _globals["_ARTIFACTSTATUSESREQUEST"]._serialized_end = 12353
- _globals["_ARTIFACTSTATUSESRESPONSE"]._serialized_start = 12356
- _globals["_ARTIFACTSTATUSESRESPONSE"]._serialized_end = 12708
- _globals["_ARTIFACTSTATUSESRESPONSE_STATUSESENTRY"]._serialized_start =
12551
- _globals["_ARTIFACTSTATUSESRESPONSE_STATUSESENTRY"]._serialized_end = 12666
- _globals["_ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS"]._serialized_start =
12668
- _globals["_ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS"]._serialized_end =
12708
- _globals["_INTERRUPTREQUEST"]._serialized_start = 12711
- _globals["_INTERRUPTREQUEST"]._serialized_end = 13314
- _globals["_INTERRUPTREQUEST_INTERRUPTTYPE"]._serialized_start = 13114
- _globals["_INTERRUPTREQUEST_INTERRUPTTYPE"]._serialized_end = 13242
- _globals["_INTERRUPTRESPONSE"]._serialized_start = 13317
- _globals["_INTERRUPTRESPONSE"]._serialized_end = 13461
- _globals["_REATTACHOPTIONS"]._serialized_start = 13463
- _globals["_REATTACHOPTIONS"]._serialized_end = 13516
- _globals["_RESULTCHUNKINGOPTIONS"]._serialized_start = 13519
- _globals["_RESULTCHUNKINGOPTIONS"]._serialized_end = 13700
- _globals["_REATTACHEXECUTEREQUEST"]._serialized_start = 13703
- _globals["_REATTACHEXECUTEREQUEST"]._serialized_end = 14109
- _globals["_RELEASEEXECUTEREQUEST"]._serialized_start = 14112
- _globals["_RELEASEEXECUTEREQUEST"]._serialized_end = 14697
- _globals["_RELEASEEXECUTEREQUEST_RELEASEALL"]._serialized_start = 14566
- _globals["_RELEASEEXECUTEREQUEST_RELEASEALL"]._serialized_end = 14578
- _globals["_RELEASEEXECUTEREQUEST_RELEASEUNTIL"]._serialized_start = 14580
- _globals["_RELEASEEXECUTEREQUEST_RELEASEUNTIL"]._serialized_end = 14627
- _globals["_RELEASEEXECUTERESPONSE"]._serialized_start = 14700
- _globals["_RELEASEEXECUTERESPONSE"]._serialized_end = 14865
- _globals["_RELEASESESSIONREQUEST"]._serialized_start = 14868
- _globals["_RELEASESESSIONREQUEST"]._serialized_end = 15080
- _globals["_RELEASESESSIONRESPONSE"]._serialized_start = 15082
- _globals["_RELEASESESSIONRESPONSE"]._serialized_end = 15190
- _globals["_FETCHERRORDETAILSREQUEST"]._serialized_start = 15193
- _globals["_FETCHERRORDETAILSREQUEST"]._serialized_end = 15525
- _globals["_FETCHERRORDETAILSRESPONSE"]._serialized_start = 15528
- _globals["_FETCHERRORDETAILSRESPONSE"]._serialized_end = 17537
- _globals["_FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT"]._serialized_start
= 15757
- _globals["_FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT"]._serialized_end =
15931
- _globals["_FETCHERRORDETAILSRESPONSE_QUERYCONTEXT"]._serialized_start =
15934
- _globals["_FETCHERRORDETAILSRESPONSE_QUERYCONTEXT"]._serialized_end = 16302
-
_globals["_FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE"]._serialized_start
= 16265
-
_globals["_FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE"]._serialized_end
= 16302
- _globals["_FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE"]._serialized_start =
16305
- _globals["_FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE"]._serialized_end =
16855
+ ]._serialized_end = 9060
+ _globals["_EXECUTEPLANRESPONSE_METRICS_METRICVALUE"]._serialized_start =
9062
+ _globals["_EXECUTEPLANRESPONSE_METRICS_METRICVALUE"]._serialized_end = 9150
+ _globals["_EXECUTEPLANRESPONSE_OBSERVEDMETRICS"]._serialized_start = 9153
+ _globals["_EXECUTEPLANRESPONSE_OBSERVEDMETRICS"]._serialized_end = 9294
+ _globals["_EXECUTEPLANRESPONSE_RESULTCOMPLETE"]._serialized_start = 9296
+ _globals["_EXECUTEPLANRESPONSE_RESULTCOMPLETE"]._serialized_end = 9312
+ _globals["_EXECUTEPLANRESPONSE_EXECUTIONPROGRESS"]._serialized_start = 9315
+ _globals["_EXECUTEPLANRESPONSE_EXECUTIONPROGRESS"]._serialized_end = 9648
+
_globals["_EXECUTEPLANRESPONSE_EXECUTIONPROGRESS_STAGEINFO"]._serialized_start
= 9471
+
_globals["_EXECUTEPLANRESPONSE_EXECUTIONPROGRESS_STAGEINFO"]._serialized_end =
9648
+ _globals["_KEYVALUE"]._serialized_start = 9667
+ _globals["_KEYVALUE"]._serialized_end = 9732
+ _globals["_CONFIGREQUEST"]._serialized_start = 9735
+ _globals["_CONFIGREQUEST"]._serialized_end = 10934
+ _globals["_CONFIGREQUEST_OPERATION"]._serialized_start = 10043
+ _globals["_CONFIGREQUEST_OPERATION"]._serialized_end = 10541
+ _globals["_CONFIGREQUEST_SET"]._serialized_start = 10543
+ _globals["_CONFIGREQUEST_SET"]._serialized_end = 10635
+ _globals["_CONFIGREQUEST_GET"]._serialized_start = 10637
+ _globals["_CONFIGREQUEST_GET"]._serialized_end = 10662
+ _globals["_CONFIGREQUEST_GETWITHDEFAULT"]._serialized_start = 10664
+ _globals["_CONFIGREQUEST_GETWITHDEFAULT"]._serialized_end = 10727
+ _globals["_CONFIGREQUEST_GETOPTION"]._serialized_start = 10729
+ _globals["_CONFIGREQUEST_GETOPTION"]._serialized_end = 10760
+ _globals["_CONFIGREQUEST_GETALL"]._serialized_start = 10762
+ _globals["_CONFIGREQUEST_GETALL"]._serialized_end = 10810
+ _globals["_CONFIGREQUEST_UNSET"]._serialized_start = 10812
+ _globals["_CONFIGREQUEST_UNSET"]._serialized_end = 10839
+ _globals["_CONFIGREQUEST_ISMODIFIABLE"]._serialized_start = 10841
+ _globals["_CONFIGREQUEST_ISMODIFIABLE"]._serialized_end = 10875
+ _globals["_CONFIGRESPONSE"]._serialized_start = 10937
+ _globals["_CONFIGRESPONSE"]._serialized_end = 11112
+ _globals["_ADDARTIFACTSREQUEST"]._serialized_start = 11115
+ _globals["_ADDARTIFACTSREQUEST"]._serialized_end = 12117
+ _globals["_ADDARTIFACTSREQUEST_ARTIFACTCHUNK"]._serialized_start = 11590
+ _globals["_ADDARTIFACTSREQUEST_ARTIFACTCHUNK"]._serialized_end = 11643
+ _globals["_ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT"]._serialized_start =
11645
+ _globals["_ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT"]._serialized_end =
11756
+ _globals["_ADDARTIFACTSREQUEST_BATCH"]._serialized_start = 11758
+ _globals["_ADDARTIFACTSREQUEST_BATCH"]._serialized_end = 11851
+ _globals["_ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT"]._serialized_start =
11854
+ _globals["_ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT"]._serialized_end =
12047
+ _globals["_ADDARTIFACTSRESPONSE"]._serialized_start = 12120
+ _globals["_ADDARTIFACTSRESPONSE"]._serialized_end = 12392
+ _globals["_ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY"]._serialized_start = 12311
+ _globals["_ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY"]._serialized_end = 12392
+ _globals["_ARTIFACTSTATUSESREQUEST"]._serialized_start = 12395
+ _globals["_ARTIFACTSTATUSESREQUEST"]._serialized_end = 12721
+ _globals["_ARTIFACTSTATUSESRESPONSE"]._serialized_start = 12724
+ _globals["_ARTIFACTSTATUSESRESPONSE"]._serialized_end = 13076
+ _globals["_ARTIFACTSTATUSESRESPONSE_STATUSESENTRY"]._serialized_start =
12919
+ _globals["_ARTIFACTSTATUSESRESPONSE_STATUSESENTRY"]._serialized_end = 13034
+ _globals["_ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS"]._serialized_start =
13036
+ _globals["_ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS"]._serialized_end =
13076
+ _globals["_INTERRUPTREQUEST"]._serialized_start = 13079
+ _globals["_INTERRUPTREQUEST"]._serialized_end = 13682
+ _globals["_INTERRUPTREQUEST_INTERRUPTTYPE"]._serialized_start = 13482
+ _globals["_INTERRUPTREQUEST_INTERRUPTTYPE"]._serialized_end = 13610
+ _globals["_INTERRUPTRESPONSE"]._serialized_start = 13685
+ _globals["_INTERRUPTRESPONSE"]._serialized_end = 13829
+ _globals["_REATTACHOPTIONS"]._serialized_start = 13831
+ _globals["_REATTACHOPTIONS"]._serialized_end = 13884
+ _globals["_RESULTCHUNKINGOPTIONS"]._serialized_start = 13887
+ _globals["_RESULTCHUNKINGOPTIONS"]._serialized_end = 14068
+ _globals["_REATTACHEXECUTEREQUEST"]._serialized_start = 14071
+ _globals["_REATTACHEXECUTEREQUEST"]._serialized_end = 14477
+ _globals["_RELEASEEXECUTEREQUEST"]._serialized_start = 14480
+ _globals["_RELEASEEXECUTEREQUEST"]._serialized_end = 15065
+ _globals["_RELEASEEXECUTEREQUEST_RELEASEALL"]._serialized_start = 14934
+ _globals["_RELEASEEXECUTEREQUEST_RELEASEALL"]._serialized_end = 14946
+ _globals["_RELEASEEXECUTEREQUEST_RELEASEUNTIL"]._serialized_start = 14948
+ _globals["_RELEASEEXECUTEREQUEST_RELEASEUNTIL"]._serialized_end = 14995
+ _globals["_RELEASEEXECUTERESPONSE"]._serialized_start = 15068
+ _globals["_RELEASEEXECUTERESPONSE"]._serialized_end = 15233
+ _globals["_RELEASESESSIONREQUEST"]._serialized_start = 15236
+ _globals["_RELEASESESSIONREQUEST"]._serialized_end = 15448
+ _globals["_RELEASESESSIONRESPONSE"]._serialized_start = 15450
+ _globals["_RELEASESESSIONRESPONSE"]._serialized_end = 15558
+ _globals["_FETCHERRORDETAILSREQUEST"]._serialized_start = 15561
+ _globals["_FETCHERRORDETAILSREQUEST"]._serialized_end = 15893
+ _globals["_FETCHERRORDETAILSRESPONSE"]._serialized_start = 15896
+ _globals["_FETCHERRORDETAILSRESPONSE"]._serialized_end = 17905
+ _globals["_FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT"]._serialized_start
= 16125
+ _globals["_FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT"]._serialized_end =
16299
+ _globals["_FETCHERRORDETAILSRESPONSE_QUERYCONTEXT"]._serialized_start =
16302
+ _globals["_FETCHERRORDETAILSRESPONSE_QUERYCONTEXT"]._serialized_end = 16670
+
_globals["_FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE"]._serialized_start
= 16633
+
_globals["_FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE"]._serialized_end
= 16670
+ _globals["_FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE"]._serialized_start =
16673
+ _globals["_FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE"]._serialized_end =
17223
_globals[
"_FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY"
- ]._serialized_start = 16732
+ ]._serialized_start = 17100
_globals[
"_FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY"
- ]._serialized_end = 16800
-
_globals["_FETCHERRORDETAILSRESPONSE_BREAKINGCHANGEINFO"]._serialized_start =
16858
- _globals["_FETCHERRORDETAILSRESPONSE_BREAKINGCHANGEINFO"]._serialized_end
= 17108
- _globals["_FETCHERRORDETAILSRESPONSE_MITIGATIONCONFIG"]._serialized_start
= 17110
- _globals["_FETCHERRORDETAILSRESPONSE_MITIGATIONCONFIG"]._serialized_end =
17168
- _globals["_FETCHERRORDETAILSRESPONSE_ERROR"]._serialized_start = 17171
- _globals["_FETCHERRORDETAILSRESPONSE_ERROR"]._serialized_end = 17518
- _globals["_CHECKPOINTCOMMANDRESULT"]._serialized_start = 17539
- _globals["_CHECKPOINTCOMMANDRESULT"]._serialized_end = 17629
- _globals["_CLONESESSIONREQUEST"]._serialized_start = 17632
- _globals["_CLONESESSIONREQUEST"]._serialized_end = 17994
- _globals["_CLONESESSIONRESPONSE"]._serialized_start = 17997
- _globals["_CLONESESSIONRESPONSE"]._serialized_end = 18201
- _globals["_SPARKCONNECTSERVICE"]._serialized_start = 18204
- _globals["_SPARKCONNECTSERVICE"]._serialized_end = 19241
+ ]._serialized_end = 17168
+
_globals["_FETCHERRORDETAILSRESPONSE_BREAKINGCHANGEINFO"]._serialized_start =
17226
+ _globals["_FETCHERRORDETAILSRESPONSE_BREAKINGCHANGEINFO"]._serialized_end
= 17476
+ _globals["_FETCHERRORDETAILSRESPONSE_MITIGATIONCONFIG"]._serialized_start
= 17478
+ _globals["_FETCHERRORDETAILSRESPONSE_MITIGATIONCONFIG"]._serialized_end =
17536
+ _globals["_FETCHERRORDETAILSRESPONSE_ERROR"]._serialized_start = 17539
+ _globals["_FETCHERRORDETAILSRESPONSE_ERROR"]._serialized_end = 17886
+ _globals["_CHECKPOINTCOMMANDRESULT"]._serialized_start = 17907
+ _globals["_CHECKPOINTCOMMANDRESULT"]._serialized_end = 17997
+ _globals["_CLONESESSIONREQUEST"]._serialized_start = 18000
+ _globals["_CLONESESSIONREQUEST"]._serialized_end = 18362
+ _globals["_CLONESESSIONRESPONSE"]._serialized_start = 18365
+ _globals["_CLONESESSIONRESPONSE"]._serialized_end = 18569
+ _globals["_SPARKCONNECTSERVICE"]._serialized_start = 18655
+ _globals["_SPARKCONNECTSERVICE"]._serialized_end = 19692
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi
b/python/pyspark/sql/connect/proto/base_pb2.pyi
index dc3099ecdffc..f12c21e5536d 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/base_pb2.pyi
@@ -57,42 +57,123 @@ else:
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
+class _CompressionCodec:
+ ValueType = typing.NewType("ValueType", builtins.int)
+ V: typing_extensions.TypeAlias = ValueType
+
+class _CompressionCodecEnumTypeWrapper(
+    google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[_CompressionCodec.ValueType],
+ builtins.type,
+): # noqa: F821
+ DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
+ COMPRESSION_CODEC_UNSPECIFIED: _CompressionCodec.ValueType # 0
+ COMPRESSION_CODEC_ZSTD: _CompressionCodec.ValueType # 1
+
+class CompressionCodec(_CompressionCodec, metaclass=_CompressionCodecEnumTypeWrapper):
+ """Compression codec for plan compression."""
+
+COMPRESSION_CODEC_UNSPECIFIED: CompressionCodec.ValueType # 0
+COMPRESSION_CODEC_ZSTD: CompressionCodec.ValueType # 1
+global___CompressionCodec = CompressionCodec
+
class Plan(google.protobuf.message.Message):
"""A [[Plan]] is the structure that carries the runtime information for
the execution from the
-    client to the server. A [[Plan]] can either be of the type [[Relation]] which is a reference
-    to the underlying logical plan or it can be of the [[Command]] type that is used to execute
-    commands on the server.
+    client to the server. A [[Plan]] can be one of the following:
+    - [[Relation]]: a reference to the underlying logical plan.
+    - [[Command]]: used to execute commands on the server.
+    - [[CompressedOperation]]: a compressed representation of either a Relation or a Command.
"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
+ class CompressedOperation(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ class _OpType:
+ ValueType = typing.NewType("ValueType", builtins.int)
+ V: typing_extensions.TypeAlias = ValueType
+
+ class _OpTypeEnumTypeWrapper(
+ google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[
+ Plan.CompressedOperation._OpType.ValueType
+ ],
+ builtins.type,
+ ): # noqa: F821
+ DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
+ OP_TYPE_UNSPECIFIED: Plan.CompressedOperation._OpType.ValueType # 0
+ OP_TYPE_RELATION: Plan.CompressedOperation._OpType.ValueType # 1
+ OP_TYPE_COMMAND: Plan.CompressedOperation._OpType.ValueType # 2
+
+ class OpType(_OpType, metaclass=_OpTypeEnumTypeWrapper): ...
+ OP_TYPE_UNSPECIFIED: Plan.CompressedOperation.OpType.ValueType # 0
+ OP_TYPE_RELATION: Plan.CompressedOperation.OpType.ValueType # 1
+ OP_TYPE_COMMAND: Plan.CompressedOperation.OpType.ValueType # 2
+
+ DATA_FIELD_NUMBER: builtins.int
+ OP_TYPE_FIELD_NUMBER: builtins.int
+ COMPRESSION_CODEC_FIELD_NUMBER: builtins.int
+ data: builtins.bytes
+ op_type: global___Plan.CompressedOperation.OpType.ValueType
+ compression_codec: global___CompressionCodec.ValueType
+ def __init__(
+ self,
+ *,
+ data: builtins.bytes = ...,
+ op_type: global___Plan.CompressedOperation.OpType.ValueType = ...,
+ compression_codec: global___CompressionCodec.ValueType = ...,
+ ) -> None: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal[
+                "compression_codec", b"compression_codec", "data", b"data", "op_type", b"op_type"
+ ],
+ ) -> None: ...
+
ROOT_FIELD_NUMBER: builtins.int
COMMAND_FIELD_NUMBER: builtins.int
+ COMPRESSED_OPERATION_FIELD_NUMBER: builtins.int
@property
def root(self) -> pyspark.sql.connect.proto.relations_pb2.Relation: ...
@property
def command(self) -> pyspark.sql.connect.proto.commands_pb2.Command: ...
+ @property
+ def compressed_operation(self) -> global___Plan.CompressedOperation: ...
def __init__(
self,
*,
root: pyspark.sql.connect.proto.relations_pb2.Relation | None = ...,
command: pyspark.sql.connect.proto.commands_pb2.Command | None = ...,
+ compressed_operation: global___Plan.CompressedOperation | None = ...,
) -> None: ...
def HasField(
self,
field_name: typing_extensions.Literal[
- "command", b"command", "op_type", b"op_type", "root", b"root"
+ "command",
+ b"command",
+ "compressed_operation",
+ b"compressed_operation",
+ "op_type",
+ b"op_type",
+ "root",
+ b"root",
],
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
- "command", b"command", "op_type", b"op_type", "root", b"root"
+ "command",
+ b"command",
+ "compressed_operation",
+ b"compressed_operation",
+ "op_type",
+ b"op_type",
+ "root",
+ b"root",
],
) -> None: ...
def WhichOneof(
self, oneof_group: typing_extensions.Literal["op_type", b"op_type"]
- ) -> typing_extensions.Literal["root", "command"] | None: ...
+    ) -> typing_extensions.Literal["root", "command", "compressed_operation"] | None: ...
global___Plan = Plan
diff --git a/python/pyspark/sql/connect/utils.py
b/python/pyspark/sql/connect/utils.py
index a2511836816c..0e0e04244653 100644
--- a/python/pyspark/sql/connect/utils.py
+++ b/python/pyspark/sql/connect/utils.py
@@ -37,6 +37,7 @@ def check_dependencies(mod_name: str) -> None:
require_minimum_grpc_version()
require_minimum_grpcio_status_version()
require_minimum_googleapis_common_protos_version()
+ require_minimum_zstandard_version()
def require_minimum_grpc_version() -> None:
@@ -96,5 +97,21 @@ def require_minimum_googleapis_common_protos_version() ->
None:
) from error
+def require_minimum_zstandard_version() -> None:
+ """Raise ImportError if zstandard is not installed"""
+ minimum_zstandard_version = "0.25.0"
+
+ try:
+ import zstandard # noqa
+ except ImportError as error:
+ raise PySparkImportError(
+ errorClass="PACKAGE_NOT_INSTALLED",
+ messageParameters={
+ "package_name": "zstandard",
+ "minimum_version": str(minimum_zstandard_version),
+ },
+ ) from error
+
+
def get_python_ver() -> str:
return "%d.%d" % sys.version_info[:2]
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py
b/python/pyspark/sql/tests/connect/test_connect_basic.py
index c1ba9a6fc2d4..b789d7919c94 100755
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -1447,6 +1447,33 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
        proto_string_truncated_3 = self.connect._client._proto_to_string(plan3, True)
        self.assertTrue(len(proto_string_truncated_3) < 64000, len(proto_string_truncated_3))
+ def test_plan_compression(self):
+ self.assertTrue(self.connect._client._zstd_module is not None)
+ self.connect.range(1).count()
+        default_plan_compression_threshold = self.connect._client._plan_compression_threshold
+        self.assertTrue(default_plan_compression_threshold > 0)
+        self.assertTrue(self.connect._client._plan_compression_algorithm == "ZSTD")
+ try:
+ self.connect._client._plan_compression_threshold = 1000
+
+ # Small plan should not be compressed
+ cdf1 = self.connect.range(1).select(CF.lit("Apache Spark"))
+ plan1 = cdf1._plan.to_proto(self.connect._client)
+ self.assertTrue(plan1.root is not None)
+ self.assertTrue(cdf1.count() == 1)
+
+ # Large plan should be compressed
+ cdf2 = self.connect.range(1).select(CF.lit("Apache Spark" * 1000))
+ plan2 = cdf2._plan.to_proto(self.connect._client)
+ self.assertTrue(plan2.compressed_operation is not None)
+ # Test compressed relation
+ self.assertTrue(cdf2.count() == 1)
+ # Test compressed command
+ cdf2.createOrReplaceTempView("temp_view_cdf2")
+            self.assertTrue(self.connect.sql("SELECT * FROM temp_view_cdf2").count() == 1)
+        finally:
+            self.connect._client._plan_compression_threshold = default_plan_compression_threshold
+
class SparkConnectGCTests(SparkConnectSQLTestCase):
@classmethod
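The threshold gating exercised by `test_plan_compression` above can be sketched in a dependency-free way. In this illustrative sketch, zlib stands in for the Zstandard codec the client actually uses, and `maybe_compress`/`threshold` are made-up names rather than the client's real API:

```python
import zlib
from typing import Optional


def maybe_compress(serialized_plan: bytes, threshold: int) -> Optional[bytes]:
    """Compress a serialized plan only when it exceeds the threshold.

    Mirrors the gating logic tested above: small plans are sent as-is
    (None means "no compression"), large plans are compressed before
    transmission. zlib stands in for Zstandard to avoid dependencies.
    """
    if len(serialized_plan) < threshold:
        return None
    return zlib.compress(serialized_plan)


small = b"Apache Spark"
large = b"Apache Spark" * 1000
assert maybe_compress(small, threshold=1000) is None
compressed = maybe_compress(large, threshold=1000)
assert compressed is not None and zlib.decompress(compressed) == large
```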
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan.py
b/python/pyspark/sql/tests/connect/test_connect_plan.py
index d25799f0c9f2..1d4d85e8426e 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan.py
@@ -864,7 +864,7 @@ class SparkConnectPlanTests(PlanOnlyTestFixture):
def test_column_literals(self):
df = self.connect.with_plan(Read("table"))
lit_df = df.select(lit(10))
- self.assertIsNotNone(lit_df._plan.to_proto(None))
+ self.assertIsNotNone(lit_df._plan.to_proto(self.connect))
self.assertIsNotNone(lit(10).to_plan(None))
plan = lit(10).to_plan(None)
diff --git a/python/pyspark/testing/connectutils.py
b/python/pyspark/testing/connectutils.py
index bfcb886e1c91..d895c1d8a26b 100644
--- a/python/pyspark/testing/connectutils.py
+++ b/python/pyspark/testing/connectutils.py
@@ -23,6 +23,7 @@ import unittest
import uuid
import contextlib
+import pyspark.sql.connect.proto as pb2
from pyspark import Row, SparkConf
from pyspark.util import is_remote_only
from pyspark.testing.utils import PySparkErrorTestUtils
@@ -113,6 +114,16 @@ class PlanOnlyTestFixture(unittest.TestCase,
PySparkErrorTestUtils):
def _session_sql(cls, query):
return cls._df_mock(SQL(query))
+    @classmethod
+    def _set_relation_in_plan(cls, plan: pb2.Plan, relation: pb2.Relation) -> None:
+        # Skip plan compression in plan-only tests.
+        plan.root.CopyFrom(relation)
+
+    @classmethod
+    def _set_command_in_plan(cls, plan: pb2.Plan, command: pb2.Command) -> None:
+        # Skip plan compression in plan-only tests.
+        plan.command.CopyFrom(command)
+
if have_pandas:
@classmethod
@@ -129,6 +140,8 @@ class PlanOnlyTestFixture(unittest.TestCase,
PySparkErrorTestUtils):
cls.connect.set_hook("range", cls._session_range)
cls.connect.set_hook("sql", cls._session_sql)
cls.connect.set_hook("with_plan", cls._with_plan)
+        cls.connect.set_hook("_set_relation_in_plan", cls._set_relation_in_plan)
+        cls.connect.set_hook("_set_command_in_plan", cls._set_command_in_plan)
@classmethod
def tearDownClass(cls):
diff --git a/sql/connect/common/src/main/protobuf/spark/connect/base.proto
b/sql/connect/common/src/main/protobuf/spark/connect/base.proto
index 6e1029bf0a6a..a97d2d25f490 100644
--- a/sql/connect/common/src/main/protobuf/spark/connect/base.proto
+++ b/sql/connect/common/src/main/protobuf/spark/connect/base.proto
@@ -33,17 +33,35 @@ option java_package = "org.apache.spark.connect.proto";
option go_package = "internal/generated";
 // A [[Plan]] is the structure that carries the runtime information for the execution from the
-// client to the server. A [[Plan]] can either be of the type [[Relation]] which is a reference
-// to the underlying logical plan or it can be of the [[Command]] type that is used to execute
-// commands on the server.
+// client to the server. A [[Plan]] can be one of the following:
+// - [[Relation]]: a reference to the underlying logical plan.
+// - [[Command]]: used to execute commands on the server.
+// - [[CompressedOperation]]: a compressed representation of either a Relation or a Command.
message Plan {
oneof op_type {
Relation root = 1;
Command command = 2;
+ CompressedOperation compressed_operation = 3;
}
-}
+ message CompressedOperation {
+ bytes data = 1;
+ OpType op_type = 2;
+ CompressionCodec compression_codec = 3;
+ enum OpType {
+ OP_TYPE_UNSPECIFIED = 0;
+ OP_TYPE_RELATION = 1;
+ OP_TYPE_COMMAND = 2;
+ }
+ }
+}
+
+// Compression codec for plan compression.
+enum CompressionCodec {
+ COMPRESSION_CODEC_UNSPECIFIED = 0;
+ COMPRESSION_CODEC_ZSTD = 1;
+}
// User Context is used to refer to one particular user session that is
executing
// queries in the backend.
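To make the new envelope concrete, here is a rough client-side sketch of filling a `CompressedOperation`. A plain dataclass stands in for the generated protobuf class, zlib stands in for the Zstandard codec, and the integer constants simply mirror the enum values declared above — illustrative only, not the actual client implementation:

```python
import zlib
from dataclasses import dataclass

# Illustrative stand-ins mirroring the proto enum values above.
OP_TYPE_RELATION = 1
OP_TYPE_COMMAND = 2
COMPRESSION_CODEC_ZSTD = 1


@dataclass
class CompressedOperation:
    """Stand-in for the generated Plan.CompressedOperation message."""

    data: bytes
    op_type: int
    compression_codec: int


def compress_relation(serialized_relation: bytes) -> CompressedOperation:
    # The real client serializes the Relation message and compresses it
    # with Zstandard; zlib keeps this sketch free of third-party deps.
    return CompressedOperation(
        data=zlib.compress(serialized_relation),
        op_type=OP_TYPE_RELATION,
        compression_codec=COMPRESSION_CODEC_ZSTD,
    )


payload = b"\x0a\x04demo" * 500  # pretend-serialized Relation bytes
op = compress_relation(payload)
assert op.op_type == OP_TYPE_RELATION
assert zlib.decompress(op.data) == payload
```

Because `op_type` travels alongside the compressed bytes, the server knows whether to parse the decompressed payload as a Relation or a Command without sniffing the content.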
diff --git
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
index c6049187f6be..1ffed714b4ca 100644
---
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
+++
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
@@ -16,6 +16,7 @@
*/
package org.apache.spark.sql.connect.config
+import java.util.Locale
import java.util.concurrent.TimeUnit
import org.apache.spark.SparkEnv
@@ -418,4 +419,34 @@ object Connect {
.bytesConf(ByteUnit.BYTE)
// 90% of the max message size by default to allow for some overhead.
      .createWithDefault((ConnectCommon.CONNECT_GRPC_MAX_MESSAGE_SIZE * 0.9).toInt)
+
+ private[spark] val CONNECT_MAX_PLAN_SIZE =
+ buildStaticConf("spark.connect.maxPlanSize")
+ .doc(
+        "The maximum size of a (decompressed) proto plan that can be executed in Spark " +
+        "Connect. If the size of the plan exceeds this limit, an error will be thrown. " +
+        "The size is in bytes.")
+ .version("4.1.0")
+ .internal()
+ .bytesConf(ByteUnit.BYTE)
+ .createWithDefault(512 * 1024 * 1024) // 512 MB
+
+ val CONNECT_SESSION_PLAN_COMPRESSION_THRESHOLD =
+ buildConf("spark.connect.session.planCompression.threshold")
+ .doc("The threshold in bytes for the size of proto plan to be compressed. " +
+ "If the size of proto plan is smaller than this threshold, it will not be compressed.")
+ .version("4.1.0")
+ .internal()
+ .intConf
+ .createWithDefault(10 * 1024 * 1024) // 10 MB
+
+ val CONNECT_PLAN_COMPRESSION_DEFAULT_ALGORITHM =
+ buildConf("spark.connect.session.planCompression.defaultAlgorithm")
+ .doc("The default algorithm of proto plan compression.")
+ .version("4.1.0")
+ .internal()
+ .stringConf
+ .transform(_.toUpperCase(Locale.ROOT))
+ .checkValues(ConnectPlanCompressionAlgorithm.values.map(_.toString))
+ .createWithDefault(ConnectPlanCompressionAlgorithm.ZSTD.toString)
}
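The threshold config above drives a simple client-side decision: small plans are sent as-is, and only plans whose serialized size crosses the threshold get compressed. A minimal sketch of that decision, with zlib as a stdlib stand-in for Zstandard and `maybe_compress` as a purely illustrative name, not a real client API:

```python
import zlib

# Mirrors the 10 MB default of spark.connect.session.planCompression.threshold.
DEFAULT_THRESHOLD = 10 * 1024 * 1024

def maybe_compress(serialized_plan: bytes,
                   threshold: int = DEFAULT_THRESHOLD):
    """Return (payload, is_compressed): compress only above the threshold."""
    if len(serialized_plan) < threshold:
        # Small plans skip compression entirely to avoid needless CPU cost.
        return serialized_plan, False
    return zlib.compress(serialized_plan), True
```

Keeping the threshold per-session (a `buildConf`, not a static conf) lets clients negotiate it without a server restart, unlike the hard `maxPlanSize` cap.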
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/ConnectPlanCompressionAlgorithm.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/ConnectPlanCompressionAlgorithm.scala
new file mode 100644
index 000000000000..0f9b959ee725
--- /dev/null
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/ConnectPlanCompressionAlgorithm.scala
@@ -0,0 +1,21 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.config
+
+object ConnectPlanCompressionAlgorithm extends Enumeration {
+ val ZSTD = Value
+}
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala
index 8fa003c11681..cdf7013211f7 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.classic.{DataFrame, Dataset}
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, InvalidPlanInput, StorageLevelProtoConverter}
import org.apache.spark.sql.connect.planner.SparkConnectPlanner
+import org.apache.spark.sql.connect.utils.PlanCompressionUtils
import org.apache.spark.sql.execution.{CodegenMode, CommandExecutionMode, CostMode, ExtendedMode, FormattedMode, SimpleMode}
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.util.ArrayImplicits._
@@ -63,6 +64,9 @@ private[connect] class SparkConnectAnalyzeHandler(
val builder = proto.AnalyzePlanResponse.newBuilder()
def transformRelation(rel: proto.Relation) =
planner.transformRelation(rel, cachePlan = true)
+ def transformRelationPlan(plan: proto.Plan) = {
+ transformRelation(PlanCompressionUtils.decompressPlan(plan).getRoot)
+ }
def getDataFrameWithoutExecuting(rel: LogicalPlan): DataFrame = {
val qe = session.sessionState.executePlan(rel, CommandExecutionMode.SKIP)
@@ -71,7 +75,7 @@ private[connect] class SparkConnectAnalyzeHandler(
request.getAnalyzeCase match {
case proto.AnalyzePlanRequest.AnalyzeCase.SCHEMA =>
- val rel = transformRelation(request.getSchema.getPlan.getRoot)
+ val rel = transformRelationPlan(request.getSchema.getPlan)
val schema = getDataFrameWithoutExecuting(rel).schema
builder.setSchema(
proto.AnalyzePlanResponse.Schema
@@ -79,7 +83,7 @@ private[connect] class SparkConnectAnalyzeHandler(
.setSchema(DataTypeProtoConverter.toConnectProtoType(schema))
.build())
case proto.AnalyzePlanRequest.AnalyzeCase.EXPLAIN =>
- val rel = transformRelation(request.getExplain.getPlan.getRoot)
+ val rel = transformRelationPlan(request.getExplain.getPlan)
val queryExecution = getDataFrameWithoutExecuting(rel).queryExecution
val explainString = request.getExplain.getExplainMode match {
case proto.AnalyzePlanRequest.Explain.ExplainMode.EXPLAIN_MODE_SIMPLE =>
@@ -101,7 +105,7 @@ private[connect] class SparkConnectAnalyzeHandler(
.build())
case proto.AnalyzePlanRequest.AnalyzeCase.TREE_STRING =>
- val rel = transformRelation(request.getTreeString.getPlan.getRoot)
+ val rel = transformRelationPlan(request.getTreeString.getPlan)
val schema = getDataFrameWithoutExecuting(rel).schema
val treeString = if (request.getTreeString.hasLevel) {
schema.treeString(request.getTreeString.getLevel)
@@ -115,7 +119,7 @@ private[connect] class SparkConnectAnalyzeHandler(
.build())
case proto.AnalyzePlanRequest.AnalyzeCase.IS_LOCAL =>
- val rel = transformRelation(request.getIsLocal.getPlan.getRoot)
+ val rel = transformRelationPlan(request.getIsLocal.getPlan)
val isLocal = getDataFrameWithoutExecuting(rel).isLocal
builder.setIsLocal(
proto.AnalyzePlanResponse.IsLocal
@@ -124,7 +128,7 @@ private[connect] class SparkConnectAnalyzeHandler(
.build())
case proto.AnalyzePlanRequest.AnalyzeCase.IS_STREAMING =>
- val rel = transformRelation(request.getIsStreaming.getPlan.getRoot)
+ val rel = transformRelationPlan(request.getIsStreaming.getPlan)
val isStreaming = getDataFrameWithoutExecuting(rel).isStreaming
builder.setIsStreaming(
proto.AnalyzePlanResponse.IsStreaming
@@ -133,7 +137,7 @@ private[connect] class SparkConnectAnalyzeHandler(
.build())
case proto.AnalyzePlanRequest.AnalyzeCase.INPUT_FILES =>
- val rel = transformRelation(request.getInputFiles.getPlan.getRoot)
+ val rel = transformRelationPlan(request.getInputFiles.getPlan)
val inputFiles = getDataFrameWithoutExecuting(rel).inputFiles
builder.setInputFiles(
proto.AnalyzePlanResponse.InputFiles
@@ -157,8 +161,8 @@ private[connect] class SparkConnectAnalyzeHandler(
.build())
case proto.AnalyzePlanRequest.AnalyzeCase.SAME_SEMANTICS =>
- val targetRel = transformRelation(request.getSameSemantics.getTargetPlan.getRoot)
- val otherRel = transformRelation(request.getSameSemantics.getOtherPlan.getRoot)
+ val targetRel = transformRelationPlan(request.getSameSemantics.getTargetPlan)
+ val otherRel = transformRelationPlan(request.getSameSemantics.getOtherPlan)
val target = getDataFrameWithoutExecuting(targetRel)
val other = getDataFrameWithoutExecuting(otherRel)
builder.setSameSemantics(
@@ -167,7 +171,7 @@ private[connect] class SparkConnectAnalyzeHandler(
.setResult(target.sameSemantics(other)))
case proto.AnalyzePlanRequest.AnalyzeCase.SEMANTIC_HASH =>
- val rel = transformRelation(request.getSemanticHash.getPlan.getRoot)
+ val rel = transformRelationPlan(request.getSemanticHash.getPlan)
val semanticHash = getDataFrameWithoutExecuting(rel)
.semanticHash()
builder.setSemanticHash(
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutePlanHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutePlanHandler.scala
index 027f4517cf3b..6780ca37e96a 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutePlanHandler.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutePlanHandler.scala
@@ -22,6 +22,7 @@ import io.grpc.stub.StreamObserver
import org.apache.spark.SparkSQLException
import org.apache.spark.connect.proto
import org.apache.spark.internal.Logging
+import org.apache.spark.sql.connect.utils.PlanCompressionUtils
class SparkConnectExecutePlanHandler(responseObserver: StreamObserver[proto.ExecutePlanResponse])
extends Logging {
@@ -35,12 +36,20 @@ class SparkConnectExecutePlanHandler(responseObserver: StreamObserver[proto.Exec
.getOrCreateIsolatedSession(v.getUserContext.getUserId, v.getSessionId, previousSessionId)
val executeKey = ExecuteKey(v, sessionHolder)
+ val decompressedRequest =
+ v.toBuilder.setPlan(PlanCompressionUtils.decompressPlan(v.getPlan)).build()
+
SparkConnectService.executionManager.getExecuteHolder(executeKey) match {
case None =>
// Create a new execute holder and attach to it.
SparkConnectService.executionManager
- .createExecuteHolderAndAttach(executeKey, v, sessionHolder, responseObserver)
- case Some(executeHolder) if executeHolder.request.getPlan.equals(v.getPlan) =>
+ .createExecuteHolderAndAttach(
+ executeKey,
+ decompressedRequest,
+ sessionHolder,
+ responseObserver)
+ case Some(executeHolder)
+ if executeHolder.request.getPlan.equals(decompressedRequest.getPlan) =>
// If the execute holder already exists with the same plan, reattach to it.
SparkConnectService.executionManager
.reattachExecuteHolder(executeHolder, responseObserver, None)
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/PlanCompressionUtils.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/PlanCompressionUtils.scala
new file mode 100644
index 000000000000..708ef1ee6558
--- /dev/null
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/PlanCompressionUtils.scala
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.utils
+
+import java.io.IOException
+
+import scala.util.control.NonFatal
+
+import com.github.luben.zstd.{Zstd, ZstdInputStreamNoFinalizer}
+import com.google.protobuf.{ByteString, CodedInputStream}
+import org.apache.commons.io.input.BoundedInputStream
+
+import org.apache.spark.{SparkEnv, SparkSQLException}
+import org.apache.spark.connect.proto
+import org.apache.spark.sql.connect.config.Connect
+import org.apache.spark.sql.connect.planner.InvalidInputErrors
+
+object PlanCompressionUtils {
+ def decompressPlan(plan: proto.Plan): proto.Plan = {
+ plan.getOpTypeCase match {
+ case proto.Plan.OpTypeCase.COMPRESSED_OPERATION =>
+ val (cis, closeStream) = decompressBytes(
+ plan.getCompressedOperation.getData,
+ plan.getCompressedOperation.getCompressionCodec)
+ try {
+ plan.getCompressedOperation.getOpType match {
+ case proto.Plan.CompressedOperation.OpType.OP_TYPE_RELATION =>
+ proto.Plan.newBuilder().setRoot(proto.Relation.parser().parseFrom(cis)).build()
+ case proto.Plan.CompressedOperation.OpType.OP_TYPE_COMMAND =>
+ proto.Plan.newBuilder().setCommand(proto.Command.parser().parseFrom(cis)).build()
+ case other =>
+ throw InvalidInputErrors.invalidOneOfField(
+ other,
+ plan.getCompressedOperation.getDescriptorForType)
+ }
+ } catch {
+ case e: SparkSQLException =>
+ throw e
+ case NonFatal(e) =>
+ throw new SparkSQLException(
+ errorClass = "CONNECT_INVALID_PLAN.CANNOT_PARSE",
+ messageParameters = Map("errorMsg" -> e.getMessage))
+ } finally {
+ try {
+ closeStream()
+ } catch {
+ case NonFatal(_) =>
+ }
+ }
+ case _ => plan
+ }
+ }
+
+ private def getMaxPlanSize: Long = {
+ SparkEnv.get.conf.get(Connect.CONNECT_MAX_PLAN_SIZE)
+ }
+
+ /**
+ * Decompress the given bytes using the specified codec.
+ * @return
+ * A tuple of decompressed CodedInputStream and a function to close the underlying stream.
+ */
+ private def decompressBytes(
+ data: ByteString,
+ compressionCodec: proto.CompressionCodec): (CodedInputStream, () => Unit) = {
+ compressionCodec match {
+ case proto.CompressionCodec.COMPRESSION_CODEC_ZSTD =>
+ decompressBytesWithZstd(data, getMaxPlanSize)
+ case other =>
+ throw InvalidInputErrors.invalidEnum(other)
+ }
+ }
+
+ private def decompressBytesWithZstd(
+ input: ByteString,
+ maxOutputSize: Long): (CodedInputStream, () => Unit) = {
+ // Check the declared size in the header against the limit.
+ val declaredSize = Zstd.getFrameContentSize(input.asReadOnlyByteBuffer())
+ if (declaredSize > maxOutputSize) {
+ throw new SparkSQLException(
+ errorClass = "CONNECT_INVALID_PLAN.PLAN_SIZE_LARGER_THAN_MAX",
+ messageParameters =
+ Map("planSize" -> declaredSize.toString, "maxPlanSize" -> maxOutputSize.toString))
+ }
+
+ val zstdStream = new ZstdInputStreamNoFinalizer(input.newInput())
+
+ // Create a bounded input stream to limit the decompressed output size to avoid
+ // decompression bomb attacks.
+ val boundedStream = new BoundedInputStream(zstdStream, maxOutputSize) {
+ @throws[IOException]
+ override protected def onMaxLength(maxBytes: Long, count: Long): Unit =
+ throw new SparkSQLException(
+ errorClass = "CONNECT_INVALID_PLAN.PLAN_SIZE_LARGER_THAN_MAX",
+ messageParameters =
+ Map("planSize" -> "unknown", "maxPlanSize" -> maxOutputSize.toString))
+ }
+ val cis = CodedInputStream.newInstance(boundedStream)
+ cis.setSizeLimit(Integer.MAX_VALUE)
+ cis.setRecursionLimit(SparkEnv.get.conf.get(Connect.CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT))
+ (cis, () => boundedStream.close())
+ }
+}
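The two-layer guard above (check the declared zstd frame size, then bound the actual output while streaming) boils down to one idea: never materialize more than `maxPlanSize` bytes, whatever the header claims. A stdlib Python approximation of that idea, with zlib standing in for Zstandard and `max_output_size` playing the role of `spark.connect.maxPlanSize`:

```python
import zlib

class PlanTooLargeError(Exception):
    """Stand-in for CONNECT_INVALID_PLAN.PLAN_SIZE_LARGER_THAN_MAX."""

def bounded_decompress(data: bytes, max_output_size: int) -> bytes:
    # Ask the decompressor for at most one byte beyond the limit:
    # receiving that extra byte proves the plan exceeds max_output_size,
    # without ever materializing the full (possibly huge) output.
    d = zlib.decompressobj()
    out = d.decompress(data, max_output_size + 1)
    if len(out) > max_output_size:
        raise PlanTooLargeError(
            f"decompressed plan exceeds {max_output_size} bytes")
    return out
```

The early frame-size check in the Scala code is an optimization on top of this: when the zstd header honestly declares an oversized payload, the request is rejected before any decompression work is done.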
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala
index 6eee71db5709..91dbf419479f 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala
@@ -16,12 +16,16 @@
*/
package org.apache.spark.sql.connect.service
+import java.io.ByteArrayOutputStream
import java.util.UUID
+import com.github.luben.zstd.{Zstd, ZstdOutputStreamNoFinalizer}
+import com.google.protobuf.ByteString
import org.scalatest.concurrent.Eventually
import org.scalatest.time.SpanSugar._
import org.apache.spark.SparkException
+import org.apache.spark.connect.proto
import org.apache.spark.sql.connect.SparkConnectServerTest
import org.apache.spark.sql.connect.config.Connect
@@ -365,4 +369,162 @@ class SparkConnectServiceE2ESuite extends SparkConnectServerTest {
assert(error.getMessage.contains("operation_id"))
}
}
+
+ test("Relation as compressed plan works") {
+ withClient { client =>
+ val relation = buildPlan("SELECT 1").getRoot
+ val compressedRelation = Zstd.compress(relation.toByteArray)
+ val plan = proto.Plan
+ .newBuilder()
+ .setCompressedOperation(
+ proto.Plan.CompressedOperation
+ .newBuilder()
+ .setData(ByteString.copyFrom(compressedRelation))
+ .setOpType(proto.Plan.CompressedOperation.OpType.OP_TYPE_RELATION)
+ .setCompressionCodec(proto.CompressionCodec.COMPRESSION_CODEC_ZSTD)
+ .build())
+ .build()
+ val query = client.execute(plan)
+ while (query.hasNext) query.next()
+ }
+ }
+
+ test("Command as compressed plan works") {
+ withClient { client =>
+ val command = buildSqlCommandPlan("SET spark.sql.session.timeZone=Europe/Berlin").getCommand
+ val compressedCommand = Zstd.compress(command.toByteArray)
+ val plan = proto.Plan
+ .newBuilder()
+ .setCompressedOperation(
+ proto.Plan.CompressedOperation
+ .newBuilder()
+ .setData(ByteString.copyFrom(compressedCommand))
+ .setOpType(proto.Plan.CompressedOperation.OpType.OP_TYPE_COMMAND)
+ .setCompressionCodec(proto.CompressionCodec.COMPRESSION_CODEC_ZSTD)
+ .build())
+ .build()
+ val query = client.execute(plan)
+ while (query.hasNext) query.next()
+ }
+ }
+
+ private def compressInZstdStreamingMode(input: Array[Byte]): Array[Byte] = {
+ val outputStream = new ByteArrayOutputStream()
+ val zstdStream = new ZstdOutputStreamNoFinalizer(outputStream)
+ zstdStream.write(input)
+ zstdStream.flush()
+ zstdStream.close()
+ outputStream.toByteArray
+ }
+
+ test("Compressed plans generated in streaming mode also work correctly") {
+ withClient { client =>
+ val relation = buildPlan("SELECT 1").getRoot
+ val compressedRelation = compressInZstdStreamingMode(relation.toByteArray)
+ val plan = proto.Plan
+ .newBuilder()
+ .setCompressedOperation(
+ proto.Plan.CompressedOperation
+ .newBuilder()
+ .setData(ByteString.copyFrom(compressedRelation))
+ .setOpType(proto.Plan.CompressedOperation.OpType.OP_TYPE_RELATION)
+ .setCompressionCodec(proto.CompressionCodec.COMPRESSION_CODEC_ZSTD)
+ .build())
+ .build()
+ val query = client.execute(plan)
+ while (query.hasNext) query.next()
+ }
+ }
+
+ test("Invalid compressed bytes errors out") {
+ withClient { client =>
+ val invalidBytes = "invalidBytes".getBytes
+ val plan = proto.Plan
+ .newBuilder()
+ .setCompressedOperation(
+ proto.Plan.CompressedOperation
+ .newBuilder()
+ .setData(ByteString.copyFrom(invalidBytes))
+ .setOpType(proto.Plan.CompressedOperation.OpType.OP_TYPE_RELATION)
+ .setCompressionCodec(proto.CompressionCodec.COMPRESSION_CODEC_ZSTD)
+ .build())
+ .build()
+ val ex = intercept[SparkException] {
+ val query = client.execute(plan)
+ while (query.hasNext) query.next()
+ }
+ assert(ex.getMessage.contains("CONNECT_INVALID_PLAN.CANNOT_PARSE"))
+ }
+ }
+
+ test("Invalid compressed proto message errors out") {
+ withClient { client =>
+ val data = Zstd.compress("Apache Spark".getBytes)
+ val plan = proto.Plan
+ .newBuilder()
+ .setCompressedOperation(
+ proto.Plan.CompressedOperation
+ .newBuilder()
+ .setData(ByteString.copyFrom(data))
+ .setOpType(proto.Plan.CompressedOperation.OpType.OP_TYPE_RELATION)
+ .setCompressionCodec(proto.CompressionCodec.COMPRESSION_CODEC_ZSTD)
+ .build())
+ .build()
+ val ex = intercept[SparkException] {
+ val query = client.execute(plan)
+ while (query.hasNext) query.next()
+ }
+ assert(ex.getMessage.contains("CONNECT_INVALID_PLAN.CANNOT_PARSE"))
+ }
+ }
+
+ test("Large compressed plan errors out") {
+ withClient { client =>
+ withSparkEnvConfs(Connect.CONNECT_MAX_PLAN_SIZE.key -> "100") {
+ val relation = buildPlan("SELECT '" + "Apache Spark" * 100 + "'").getRoot
+ val compressedRelation = Zstd.compress(relation.toByteArray)
+
+ val plan = proto.Plan
+ .newBuilder()
+ .setCompressedOperation(
+ proto.Plan.CompressedOperation
+ .newBuilder()
+ .setData(ByteString.copyFrom(compressedRelation))
+ .setOpType(proto.Plan.CompressedOperation.OpType.OP_TYPE_RELATION)
+ .setCompressionCodec(proto.CompressionCodec.COMPRESSION_CODEC_ZSTD)
+ .build())
+ .build()
+ val ex = intercept[SparkException] {
+ val query = client.execute(plan)
+ while (query.hasNext) query.next()
+ }
+ assert(ex.getMessage.contains("CONNECT_INVALID_PLAN.PLAN_SIZE_LARGER_THAN_MAX"))
+ }
+ }
+ }
+
+ test("Large compressed plan generated in streaming mode also errors out") {
+ withClient { client =>
+ withSparkEnvConfs(Connect.CONNECT_MAX_PLAN_SIZE.key -> "100") {
+ val relation = buildPlan("SELECT '" + "Apache Spark" * 100 + "'").getRoot
+ val compressedRelation = compressInZstdStreamingMode(relation.toByteArray)
+
+ val plan = proto.Plan
+ .newBuilder()
+ .setCompressedOperation(
+ proto.Plan.CompressedOperation
+ .newBuilder()
+ .setData(ByteString.copyFrom(compressedRelation))
+ .setOpType(proto.Plan.CompressedOperation.OpType.OP_TYPE_RELATION)
+ .setCompressionCodec(proto.CompressionCodec.COMPRESSION_CODEC_ZSTD)
+ .build())
+ .build()
+ val ex = intercept[SparkException] {
+ val query = client.execute(plan)
+ while (query.hasNext) query.next()
+ }
+ assert(ex.getMessage.contains("CONNECT_INVALID_PLAN.PLAN_SIZE_LARGER_THAN_MAX"))
+ }
+ }
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]