This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 7c5daaaa20b [SPARK-43032][SS][CONNECT] Python SQM bug fix
7c5daaaa20b is described below
commit 7c5daaaa20bb012110d5855e5908cc01658355ed
Author: Wei Liu <[email protected]>
AuthorDate: Mon May 8 10:08:27 2023 +0900
[SPARK-43032][SS][CONNECT] Python SQM bug fix
### What changes were proposed in this pull request?
Some bug fixes for streaming ***connect*** python SQM
Note that I also changed ***non-connect***'s StreamingQueryManager `get()`
API to return an `Optional[StreamingQuery]`.
Before, it looked like this when you get a non-existent query:
```
>>> a = spark.streams.get("00000000-0000-0001-0000-000000000001")
>>> a
<pyspark.sql.streaming.query.StreamingQuery object at 0x7f86465702b0>
>>> a.id
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/wei.liu/oss-spark/python/pyspark/sql/streaming/query.py",
line 78, in id
return self._jsq.id().toString()
AttributeError: 'NoneType' object has no attribute 'id'
```
But now it looks like:
```
>>> a = spark.streams.get("00000000-0000-0001-0000-000000000001")
>>> a.id
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
AttributeError: 'NoneType' object has no attribute 'id'
```
The only difference is the return type annotation, which is not typically
enforced in Python... but I am not sure whether that counts as a breaking change.
### Why are the changes needed?
Bug fix
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Manually tested. Also verified that it won't throw even without this fix, so
it's not that urgent.
Closes #41037 from WweiL/SPARK-43032-python-sqm-fix.
Authored-by: Wei Liu <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../src/main/protobuf/spark/connect/commands.proto | 2 +-
.../sql/connect/planner/SparkConnectPlanner.scala | 22 +++++++------
python/pyspark/sql/connect/proto/commands_pb2.py | 36 +++++++++++-----------
python/pyspark/sql/connect/proto/commands_pb2.pyi | 16 +++++-----
python/pyspark/sql/connect/streaming/query.py | 14 ++++++---
python/pyspark/sql/streaming/query.py | 8 +++--
6 files changed, 54 insertions(+), 44 deletions(-)
diff --git
a/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
b/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
index b929ffa2564..72bc8b5b6ef 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
@@ -330,7 +330,7 @@ message StreamingQueryManagerCommand {
// active() API, returns a list of active queries.
bool active = 1;
// get() API, returns the StreamingQuery identified by id.
- string get = 2;
+ string get_query = 2;
// awaitAnyTermination() API, wait until any query terminates or timeout.
AwaitAnyTerminationCommand await_any_termination = 3;
// resetTerminated() API.
diff --git
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 8c43f982ec1..01f1e890630 100644
---
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -2466,16 +2466,18 @@ class SparkConnectPlanner(val session: SparkSession) {
.toIterable
.asJava)
- case StreamingQueryManagerCommand.CommandCase.GET =>
- val query = session.streams.get(command.getGet)
- respBuilder.getQueryBuilder
- .setId(
- StreamingQueryInstanceId
- .newBuilder()
- .setId(query.id.toString)
- .setRunId(query.runId.toString)
- .build())
- .setName(SparkConnectService.convertNullString(query.name))
+ case StreamingQueryManagerCommand.CommandCase.GET_QUERY =>
+ val query = session.streams.get(command.getGetQuery)
+ if (query != null) {
+ respBuilder.getQueryBuilder
+ .setId(
+ StreamingQueryInstanceId
+ .newBuilder()
+ .setId(query.id.toString)
+ .setRunId(query.runId.toString)
+ .build())
+ .setName(SparkConnectService.convertNullString(query.name))
+ }
case StreamingQueryManagerCommand.CommandCase.AWAIT_ANY_TERMINATION =>
if (command.getAwaitAnyTermination.hasTimeoutMs) {
diff --git a/python/pyspark/sql/connect/proto/commands_pb2.py
b/python/pyspark/sql/connect/proto/commands_pb2.py
index 9848a40adab..bc764926213 100644
--- a/python/pyspark/sql/connect/proto/commands_pb2.py
+++ b/python/pyspark/sql/connect/proto/commands_pb2.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import relations_pb2 as
spark_dot_connect_dot_rel
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-
b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto"\x86\x07\n\x07\x43ommand\x12]\n\x11register_function\x18\x01
\x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x10registerFunction\x12H\n\x0fwrite_operation\x18\x02
\x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x
[...]
+
b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto"\x86\x07\n\x07\x43ommand\x12]\n\x11register_function\x18\x01
\x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x10registerFunction\x12H\n\x0fwrite_operation\x18\x02
\x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x
[...]
)
@@ -525,21 +525,21 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_STREAMINGQUERYCOMMANDRESULT_AWAITTERMINATIONRESULT._serialized_start =
5817
_STREAMINGQUERYCOMMANDRESULT_AWAITTERMINATIONRESULT._serialized_end = 5873
_STREAMINGQUERYMANAGERCOMMAND._serialized_start = 5891
- _STREAMINGQUERYMANAGERCOMMAND._serialized_end = 6230
- _STREAMINGQUERYMANAGERCOMMAND_AWAITANYTERMINATIONCOMMAND._serialized_start
= 6140
- _STREAMINGQUERYMANAGERCOMMAND_AWAITANYTERMINATIONCOMMAND._serialized_end =
6219
- _STREAMINGQUERYMANAGERCOMMANDRESULT._serialized_start = 6233
- _STREAMINGQUERYMANAGERCOMMANDRESULT._serialized_end = 6942
- _STREAMINGQUERYMANAGERCOMMANDRESULT_ACTIVERESULT._serialized_start = 6636
- _STREAMINGQUERYMANAGERCOMMANDRESULT_ACTIVERESULT._serialized_end = 6763
-
_STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYINSTANCE._serialized_start =
6765
- _STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYINSTANCE._serialized_end
= 6866
-
_STREAMINGQUERYMANAGERCOMMANDRESULT_AWAITANYTERMINATIONRESULT._serialized_start
= 6868
-
_STREAMINGQUERYMANAGERCOMMANDRESULT_AWAITANYTERMINATIONRESULT._serialized_end =
6927
- _GETRESOURCESCOMMAND._serialized_start = 6944
- _GETRESOURCESCOMMAND._serialized_end = 6965
- _GETRESOURCESCOMMANDRESULT._serialized_start = 6968
- _GETRESOURCESCOMMANDRESULT._serialized_end = 7180
- _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY._serialized_start = 7084
- _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY._serialized_end = 7180
+ _STREAMINGQUERYMANAGERCOMMAND._serialized_end = 6241
+ _STREAMINGQUERYMANAGERCOMMAND_AWAITANYTERMINATIONCOMMAND._serialized_start
= 6151
+ _STREAMINGQUERYMANAGERCOMMAND_AWAITANYTERMINATIONCOMMAND._serialized_end =
6230
+ _STREAMINGQUERYMANAGERCOMMANDRESULT._serialized_start = 6244
+ _STREAMINGQUERYMANAGERCOMMANDRESULT._serialized_end = 6953
+ _STREAMINGQUERYMANAGERCOMMANDRESULT_ACTIVERESULT._serialized_start = 6647
+ _STREAMINGQUERYMANAGERCOMMANDRESULT_ACTIVERESULT._serialized_end = 6774
+
_STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYINSTANCE._serialized_start =
6776
+ _STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYINSTANCE._serialized_end
= 6877
+
_STREAMINGQUERYMANAGERCOMMANDRESULT_AWAITANYTERMINATIONRESULT._serialized_start
= 6879
+
_STREAMINGQUERYMANAGERCOMMANDRESULT_AWAITANYTERMINATIONRESULT._serialized_end =
6938
+ _GETRESOURCESCOMMAND._serialized_start = 6955
+ _GETRESOURCESCOMMAND._serialized_end = 6976
+ _GETRESOURCESCOMMANDRESULT._serialized_start = 6979
+ _GETRESOURCESCOMMANDRESULT._serialized_end = 7191
+ _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY._serialized_start = 7095
+ _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY._serialized_end = 7191
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/commands_pb2.pyi
b/python/pyspark/sql/connect/proto/commands_pb2.pyi
index 6fec61b02dd..2c80614c3fd 100644
--- a/python/pyspark/sql/connect/proto/commands_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/commands_pb2.pyi
@@ -1283,12 +1283,12 @@ class
StreamingQueryManagerCommand(google.protobuf.message.Message):
) -> typing_extensions.Literal["timeout_ms"] | None: ...
ACTIVE_FIELD_NUMBER: builtins.int
- GET_FIELD_NUMBER: builtins.int
+ GET_QUERY_FIELD_NUMBER: builtins.int
AWAIT_ANY_TERMINATION_FIELD_NUMBER: builtins.int
RESET_TERMINATED_FIELD_NUMBER: builtins.int
active: builtins.bool
"""active() API, returns a list of active queries."""
- get: builtins.str
+ get_query: builtins.str
"""get() API, returns the StreamingQuery identified by id."""
@property
def await_any_termination(
@@ -1301,7 +1301,7 @@ class
StreamingQueryManagerCommand(google.protobuf.message.Message):
self,
*,
active: builtins.bool = ...,
- get: builtins.str = ...,
+ get_query: builtins.str = ...,
await_any_termination:
global___StreamingQueryManagerCommand.AwaitAnyTerminationCommand
| None = ...,
reset_terminated: builtins.bool = ...,
@@ -1315,8 +1315,8 @@ class
StreamingQueryManagerCommand(google.protobuf.message.Message):
b"await_any_termination",
"command",
b"command",
- "get",
- b"get",
+ "get_query",
+ b"get_query",
"reset_terminated",
b"reset_terminated",
],
@@ -1330,8 +1330,8 @@ class
StreamingQueryManagerCommand(google.protobuf.message.Message):
b"await_any_termination",
"command",
b"command",
- "get",
- b"get",
+ "get_query",
+ b"get_query",
"reset_terminated",
b"reset_terminated",
],
@@ -1339,7 +1339,7 @@ class
StreamingQueryManagerCommand(google.protobuf.message.Message):
def WhichOneof(
self, oneof_group: typing_extensions.Literal["command", b"command"]
) -> typing_extensions.Literal[
- "active", "get", "await_any_termination", "reset_terminated"
+ "active", "get_query", "await_any_termination", "reset_terminated"
] | None: ...
global___StreamingQueryManagerCommand = StreamingQueryManagerCommand
diff --git a/python/pyspark/sql/connect/streaming/query.py
b/python/pyspark/sql/connect/streaming/query.py
index 606c4d4febc..e5aa881c990 100644
--- a/python/pyspark/sql/connect/streaming/query.py
+++ b/python/pyspark/sql/connect/streaming/query.py
@@ -187,11 +187,15 @@ class StreamingQueryManager:
active.__doc__ = PySparkStreamingQueryManager.active.__doc__
- def get(self, id: str) -> StreamingQuery:
+ def get(self, id: str) -> Optional[StreamingQuery]:
cmd = pb2.StreamingQueryManagerCommand()
- cmd.get = id
- query = self._execute_streaming_query_manager_cmd(cmd).query
- return StreamingQuery(self._session, query.id.id, query.id.run_id,
query.name)
+ cmd.get_query = id
+ response = self._execute_streaming_query_manager_cmd(cmd)
+ if response.HasField("query"):
+ query = response.query
+ return StreamingQuery(self._session, query.id.id, query.id.run_id,
query.name)
+ else:
+ return None
get.__doc__ = PySparkStreamingQueryManager.get.__doc__
@@ -221,7 +225,7 @@ class StreamingQueryManager:
def resetTerminated(self) -> None:
cmd = pb2.StreamingQueryManagerCommand()
cmd.reset_terminated = True
- self._execute_streaming_query_manager_cmd(cmd).active.active_queries
+ self._execute_streaming_query_manager_cmd(cmd)
return None
resetTerminated.__doc__ =
PySparkStreamingQueryManager.resetTerminated.__doc__
diff --git a/python/pyspark/sql/streaming/query.py
b/python/pyspark/sql/streaming/query.py
index b6268dcdb18..ac7a1acfcaa 100644
--- a/python/pyspark/sql/streaming/query.py
+++ b/python/pyspark/sql/streaming/query.py
@@ -445,7 +445,7 @@ class StreamingQueryManager:
"""
return [StreamingQuery(jsq) for jsq in self._jsqm.active()]
- def get(self, id: str) -> StreamingQuery:
+ def get(self, id: str) -> Optional[StreamingQuery]:
"""
Returns an active query from this :class:`SparkSession`.
@@ -484,7 +484,11 @@ class StreamingQueryManager:
True
>>> sq.stop()
"""
- return StreamingQuery(self._jsqm.get(id))
+ query = self._jsqm.get(id)
+ if query is not None:
+ return StreamingQuery(query)
+ else:
+ return None
def awaitAnyTermination(self, timeout: Optional[int] = None) ->
Optional[bool]:
"""
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]