This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push:
new 3b22267e3fe [SPARK-42402][CONNECT] Support parameterized SQL by `sql()`
3b22267e3fe is described below
commit 3b22267e3fe6adde8f1d50e99b3f06c86ec81ad8
Author: Takuya UESHIN <[email protected]>
AuthorDate: Sun Feb 12 09:05:02 2023 +0900
[SPARK-42402][CONNECT] Support parameterized SQL by `sql()`
### What changes were proposed in this pull request?
Supports parameterized SQL by `sql()`.
Note: `SparkSession.sql` in PySpark also supports string formatter, but it
will be handled separately.
### Why are the changes needed?
Currently `SparkSession.sql` in Spark Connect doesn't support parameterized
SQL.
### Does this PR introduce _any_ user-facing change?
Yes, parameterized SQL will be available to users.
For example:
```py
>>> spark.sql("SELECT * FROM range(10) WHERE id > :minId", args={"minId": "7"}).toPandas()
id
0 8
1 9
```
### How was this patch tested?
Added a test.
Closes #39971 from ueshin/issues/SPARK-42402/parameterized_sql.
Authored-by: Takuya UESHIN <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
(cherry picked from commit 03227e18793c6902a816bbacdce78031ce37b14a)
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../main/protobuf/spark/connect/relations.proto | 3 +
.../sql/connect/planner/SparkConnectPlanner.scala | 5 +-
python/pyspark/sql/connect/plan.py | 14 +-
python/pyspark/sql/connect/proto/relations_pb2.py | 205 +++++++++++----------
python/pyspark/sql/connect/proto/relations_pb2.pyi | 26 ++-
python/pyspark/sql/connect/session.py | 4 +-
python/pyspark/sql/session.py | 4 +-
.../sql/tests/connect/test_connect_basic.py | 5 +
.../pyspark/sql/tests/connect/test_connect_plan.py | 3 +-
9 files changed, 166 insertions(+), 103 deletions(-)
diff --git
a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
index 837727432f9..3d597fd2744 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -99,6 +99,9 @@ message RelationCommon {
message SQL {
// (Required) The SQL query.
string query = 1;
+
+ // (Optional) A map of parameter names to literal values.
+ map<string, string> args = 2;
}
// Relation that reads from a file / table or other data source. Does not have
additional
diff --git
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 3bf5d2b1d30..75581851b5f 100644
---
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -196,7 +196,10 @@ class SparkConnectPlanner(val session: SparkSession) {
}
private def transformSql(sql: proto.SQL): LogicalPlan = {
- session.sessionState.sqlParser.parsePlan(sql.getQuery)
+ val args = sql.getArgsMap.asScala.toMap
+ val parser = session.sessionState.sqlParser
+ val parsedArgs = args.mapValues(parser.parseExpression).toMap
+ Parameter.bind(parser.parsePlan(sql.getQuery), parsedArgs)
}
private def transformSubqueryAlias(alias: proto.SubqueryAlias): LogicalPlan
= {
diff --git a/python/pyspark/sql/connect/plan.py
b/python/pyspark/sql/connect/plan.py
index 39b32f065ea..4675631627a 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -919,13 +919,25 @@ class SubqueryAlias(LogicalPlan):
class SQL(LogicalPlan):
- def __init__(self, query: str) -> None:
+ def __init__(self, query: str, args: Optional[Dict[str, str]] = None) ->
None:
super().__init__(None)
+
+ if args is not None:
+ for k, v in args.items():
+ assert isinstance(k, str)
+ assert isinstance(v, str)
+
self._query = query
+ self._args = args
def plan(self, session: "SparkConnectClient") -> proto.Relation:
rel = proto.Relation()
rel.sql.query = self._query
+
+ if self._args is not None and len(self._args) > 0:
+ for k, v in self._args.items():
+ rel.sql.args[k] = v
+
return rel
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py
b/python/pyspark/sql/connect/proto/relations_pb2.py
index 4ef9a7278f9..1a24628ef30 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import catalog_pb2 as
spark_dot_connect_dot_catal
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-
b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xf9\x11\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01
\x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02
\x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03
\x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
+
b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xf9\x11\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01
\x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02
\x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03
\x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
)
@@ -44,6 +44,7 @@ _RELATION = DESCRIPTOR.message_types_by_name["Relation"]
_UNKNOWN = DESCRIPTOR.message_types_by_name["Unknown"]
_RELATIONCOMMON = DESCRIPTOR.message_types_by_name["RelationCommon"]
_SQL = DESCRIPTOR.message_types_by_name["SQL"]
+_SQL_ARGSENTRY = _SQL.nested_types_by_name["ArgsEntry"]
_READ = DESCRIPTOR.message_types_by_name["Read"]
_READ_NAMEDTABLE = _READ.nested_types_by_name["NamedTable"]
_READ_DATASOURCE = _READ.nested_types_by_name["DataSource"]
@@ -129,12 +130,22 @@ SQL = _reflection.GeneratedProtocolMessageType(
"SQL",
(_message.Message,),
{
+ "ArgsEntry": _reflection.GeneratedProtocolMessageType(
+ "ArgsEntry",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _SQL_ARGSENTRY,
+ "__module__": "spark.connect.relations_pb2"
+ #
@@protoc_insertion_point(class_scope:spark.connect.SQL.ArgsEntry)
+ },
+ ),
"DESCRIPTOR": _SQL,
"__module__": "spark.connect.relations_pb2"
# @@protoc_insertion_point(class_scope:spark.connect.SQL)
},
)
_sym_db.RegisterMessage(SQL)
+_sym_db.RegisterMessage(SQL.ArgsEntry)
Read = _reflection.GeneratedProtocolMessageType(
"Read",
@@ -606,6 +617,8 @@ if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
DESCRIPTOR._serialized_options =
b"\n\036org.apache.spark.connect.protoP\001"
+ _SQL_ARGSENTRY._options = None
+ _SQL_ARGSENTRY._serialized_options = b"8\001"
_READ_DATASOURCE_OPTIONSENTRY._options = None
_READ_DATASOURCE_OPTIONSENTRY._serialized_options = b"8\001"
_WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._options = None
@@ -616,98 +629,100 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_UNKNOWN._serialized_end = 2473
_RELATIONCOMMON._serialized_start = 2475
_RELATIONCOMMON._serialized_end = 2524
- _SQL._serialized_start = 2526
- _SQL._serialized_end = 2553
- _READ._serialized_start = 2556
- _READ._serialized_end = 3004
- _READ_NAMEDTABLE._serialized_start = 2698
- _READ_NAMEDTABLE._serialized_end = 2759
- _READ_DATASOURCE._serialized_start = 2762
- _READ_DATASOURCE._serialized_end = 2991
- _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 2922
- _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 2980
- _PROJECT._serialized_start = 3006
- _PROJECT._serialized_end = 3123
- _FILTER._serialized_start = 3125
- _FILTER._serialized_end = 3237
- _JOIN._serialized_start = 3240
- _JOIN._serialized_end = 3711
- _JOIN_JOINTYPE._serialized_start = 3503
- _JOIN_JOINTYPE._serialized_end = 3711
- _SETOPERATION._serialized_start = 3714
- _SETOPERATION._serialized_end = 4193
- _SETOPERATION_SETOPTYPE._serialized_start = 4030
- _SETOPERATION_SETOPTYPE._serialized_end = 4144
- _LIMIT._serialized_start = 4195
- _LIMIT._serialized_end = 4271
- _OFFSET._serialized_start = 4273
- _OFFSET._serialized_end = 4352
- _TAIL._serialized_start = 4354
- _TAIL._serialized_end = 4429
- _AGGREGATE._serialized_start = 4432
- _AGGREGATE._serialized_end = 5014
- _AGGREGATE_PIVOT._serialized_start = 4771
- _AGGREGATE_PIVOT._serialized_end = 4882
- _AGGREGATE_GROUPTYPE._serialized_start = 4885
- _AGGREGATE_GROUPTYPE._serialized_end = 5014
- _SORT._serialized_start = 5017
- _SORT._serialized_end = 5177
- _DROP._serialized_start = 5179
- _DROP._serialized_end = 5279
- _DEDUPLICATE._serialized_start = 5282
- _DEDUPLICATE._serialized_end = 5453
- _LOCALRELATION._serialized_start = 5455
- _LOCALRELATION._serialized_end = 5544
- _SAMPLE._serialized_start = 5547
- _SAMPLE._serialized_end = 5820
- _RANGE._serialized_start = 5823
- _RANGE._serialized_end = 5968
- _SUBQUERYALIAS._serialized_start = 5970
- _SUBQUERYALIAS._serialized_end = 6084
- _REPARTITION._serialized_start = 6087
- _REPARTITION._serialized_end = 6229
- _SHOWSTRING._serialized_start = 6232
- _SHOWSTRING._serialized_end = 6374
- _STATSUMMARY._serialized_start = 6376
- _STATSUMMARY._serialized_end = 6468
- _STATDESCRIBE._serialized_start = 6470
- _STATDESCRIBE._serialized_end = 6551
- _STATCROSSTAB._serialized_start = 6553
- _STATCROSSTAB._serialized_end = 6654
- _STATCOV._serialized_start = 6656
- _STATCOV._serialized_end = 6752
- _STATCORR._serialized_start = 6755
- _STATCORR._serialized_end = 6892
- _STATAPPROXQUANTILE._serialized_start = 6895
- _STATAPPROXQUANTILE._serialized_end = 7059
- _STATFREQITEMS._serialized_start = 7061
- _STATFREQITEMS._serialized_end = 7186
- _STATSAMPLEBY._serialized_start = 7189
- _STATSAMPLEBY._serialized_end = 7498
- _STATSAMPLEBY_FRACTION._serialized_start = 7390
- _STATSAMPLEBY_FRACTION._serialized_end = 7489
- _NAFILL._serialized_start = 7501
- _NAFILL._serialized_end = 7635
- _NADROP._serialized_start = 7638
- _NADROP._serialized_end = 7772
- _NAREPLACE._serialized_start = 7775
- _NAREPLACE._serialized_end = 8071
- _NAREPLACE_REPLACEMENT._serialized_start = 7930
- _NAREPLACE_REPLACEMENT._serialized_end = 8071
- _TODF._serialized_start = 8073
- _TODF._serialized_end = 8161
- _WITHCOLUMNSRENAMED._serialized_start = 8164
- _WITHCOLUMNSRENAMED._serialized_end = 8403
- _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 8336
- _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 8403
- _WITHCOLUMNS._serialized_start = 8405
- _WITHCOLUMNS._serialized_end = 8524
- _HINT._serialized_start = 8527
- _HINT._serialized_end = 8659
- _UNPIVOT._serialized_start = 8662
- _UNPIVOT._serialized_end = 8908
- _TOSCHEMA._serialized_start = 8910
- _TOSCHEMA._serialized_end = 9016
- _REPARTITIONBYEXPRESSION._serialized_start = 9019
- _REPARTITIONBYEXPRESSION._serialized_end = 9222
+ _SQL._serialized_start = 2527
+ _SQL._serialized_end = 2661
+ _SQL_ARGSENTRY._serialized_start = 2606
+ _SQL_ARGSENTRY._serialized_end = 2661
+ _READ._serialized_start = 2664
+ _READ._serialized_end = 3112
+ _READ_NAMEDTABLE._serialized_start = 2806
+ _READ_NAMEDTABLE._serialized_end = 2867
+ _READ_DATASOURCE._serialized_start = 2870
+ _READ_DATASOURCE._serialized_end = 3099
+ _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 3030
+ _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 3088
+ _PROJECT._serialized_start = 3114
+ _PROJECT._serialized_end = 3231
+ _FILTER._serialized_start = 3233
+ _FILTER._serialized_end = 3345
+ _JOIN._serialized_start = 3348
+ _JOIN._serialized_end = 3819
+ _JOIN_JOINTYPE._serialized_start = 3611
+ _JOIN_JOINTYPE._serialized_end = 3819
+ _SETOPERATION._serialized_start = 3822
+ _SETOPERATION._serialized_end = 4301
+ _SETOPERATION_SETOPTYPE._serialized_start = 4138
+ _SETOPERATION_SETOPTYPE._serialized_end = 4252
+ _LIMIT._serialized_start = 4303
+ _LIMIT._serialized_end = 4379
+ _OFFSET._serialized_start = 4381
+ _OFFSET._serialized_end = 4460
+ _TAIL._serialized_start = 4462
+ _TAIL._serialized_end = 4537
+ _AGGREGATE._serialized_start = 4540
+ _AGGREGATE._serialized_end = 5122
+ _AGGREGATE_PIVOT._serialized_start = 4879
+ _AGGREGATE_PIVOT._serialized_end = 4990
+ _AGGREGATE_GROUPTYPE._serialized_start = 4993
+ _AGGREGATE_GROUPTYPE._serialized_end = 5122
+ _SORT._serialized_start = 5125
+ _SORT._serialized_end = 5285
+ _DROP._serialized_start = 5287
+ _DROP._serialized_end = 5387
+ _DEDUPLICATE._serialized_start = 5390
+ _DEDUPLICATE._serialized_end = 5561
+ _LOCALRELATION._serialized_start = 5563
+ _LOCALRELATION._serialized_end = 5652
+ _SAMPLE._serialized_start = 5655
+ _SAMPLE._serialized_end = 5928
+ _RANGE._serialized_start = 5931
+ _RANGE._serialized_end = 6076
+ _SUBQUERYALIAS._serialized_start = 6078
+ _SUBQUERYALIAS._serialized_end = 6192
+ _REPARTITION._serialized_start = 6195
+ _REPARTITION._serialized_end = 6337
+ _SHOWSTRING._serialized_start = 6340
+ _SHOWSTRING._serialized_end = 6482
+ _STATSUMMARY._serialized_start = 6484
+ _STATSUMMARY._serialized_end = 6576
+ _STATDESCRIBE._serialized_start = 6578
+ _STATDESCRIBE._serialized_end = 6659
+ _STATCROSSTAB._serialized_start = 6661
+ _STATCROSSTAB._serialized_end = 6762
+ _STATCOV._serialized_start = 6764
+ _STATCOV._serialized_end = 6860
+ _STATCORR._serialized_start = 6863
+ _STATCORR._serialized_end = 7000
+ _STATAPPROXQUANTILE._serialized_start = 7003
+ _STATAPPROXQUANTILE._serialized_end = 7167
+ _STATFREQITEMS._serialized_start = 7169
+ _STATFREQITEMS._serialized_end = 7294
+ _STATSAMPLEBY._serialized_start = 7297
+ _STATSAMPLEBY._serialized_end = 7606
+ _STATSAMPLEBY_FRACTION._serialized_start = 7498
+ _STATSAMPLEBY_FRACTION._serialized_end = 7597
+ _NAFILL._serialized_start = 7609
+ _NAFILL._serialized_end = 7743
+ _NADROP._serialized_start = 7746
+ _NADROP._serialized_end = 7880
+ _NAREPLACE._serialized_start = 7883
+ _NAREPLACE._serialized_end = 8179
+ _NAREPLACE_REPLACEMENT._serialized_start = 8038
+ _NAREPLACE_REPLACEMENT._serialized_end = 8179
+ _TODF._serialized_start = 8181
+ _TODF._serialized_end = 8269
+ _WITHCOLUMNSRENAMED._serialized_start = 8272
+ _WITHCOLUMNSRENAMED._serialized_end = 8511
+ _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 8444
+ _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 8511
+ _WITHCOLUMNS._serialized_start = 8513
+ _WITHCOLUMNS._serialized_end = 8632
+ _HINT._serialized_start = 8635
+ _HINT._serialized_end = 8767
+ _UNPIVOT._serialized_start = 8770
+ _UNPIVOT._serialized_end = 9016
+ _TOSCHEMA._serialized_start = 9018
+ _TOSCHEMA._serialized_end = 9124
+ _REPARTITIONBYEXPRESSION._serialized_start = 9127
+ _REPARTITIONBYEXPRESSION._serialized_end = 9330
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi
b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 59c567c962f..647b26b6d31 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -496,15 +496,39 @@ class SQL(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
+ class ArgsEntry(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ KEY_FIELD_NUMBER: builtins.int
+ VALUE_FIELD_NUMBER: builtins.int
+ key: builtins.str
+ value: builtins.str
+ def __init__(
+ self,
+ *,
+ key: builtins.str = ...,
+ value: builtins.str = ...,
+ ) -> None: ...
+ def ClearField(
+ self, field_name: typing_extensions.Literal["key", b"key",
"value", b"value"]
+ ) -> None: ...
+
QUERY_FIELD_NUMBER: builtins.int
+ ARGS_FIELD_NUMBER: builtins.int
query: builtins.str
"""(Required) The SQL query."""
+ @property
+ def args(self) ->
google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.str]:
+ """(Optional) A map of parameter names to literal values."""
def __init__(
self,
*,
query: builtins.str = ...,
+ args: collections.abc.Mapping[builtins.str, builtins.str] | None = ...,
+ ) -> None: ...
+ def ClearField(
+ self, field_name: typing_extensions.Literal["args", b"args", "query",
b"query"]
) -> None: ...
- def ClearField(self, field_name: typing_extensions.Literal["query",
b"query"]) -> None: ...
global___SQL = SQL
diff --git a/python/pyspark/sql/connect/session.py
b/python/pyspark/sql/connect/session.py
index 3c44d06bb1c..75c8e61752e 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -341,8 +341,8 @@ class SparkSession:
createDataFrame.__doc__ = PySparkSession.createDataFrame.__doc__
- def sql(self, sqlQuery: str) -> "DataFrame":
- return DataFrame.withPlan(SQL(sqlQuery), self)
+ def sql(self, sqlQuery: str, args: Optional[Dict[str, str]] = None) ->
"DataFrame":
+ return DataFrame.withPlan(SQL(sqlQuery, args), self)
sql.__doc__ = PySparkSession.sql.__doc__
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 38c93b2d0ac..7019210a4d8 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -1308,7 +1308,7 @@ class SparkSession(SparkConversionMixin):
df._schema = struct
return df
- def sql(self, sqlQuery: str, args: Dict[str, str] = {}, **kwargs: Any) ->
DataFrame:
+ def sql(self, sqlQuery: str, args: Optional[Dict[str, str]] = None,
**kwargs: Any) -> DataFrame:
"""Returns a :class:`DataFrame` representing the result of the given
query.
When ``kwargs`` is specified, this method formats the given string by
using the Python
standard formatter. The method binds named parameters to SQL literals
from `args`.
@@ -1416,7 +1416,7 @@ class SparkSession(SparkConversionMixin):
if len(kwargs) > 0:
sqlQuery = formatter.format(sqlQuery, **kwargs)
try:
- return DataFrame(self._jsparkSession.sql(sqlQuery, args), self)
+ return DataFrame(self._jsparkSession.sql(sqlQuery, args or {}),
self)
finally:
if len(kwargs) > 0:
formatter.clear()
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py
b/python/pyspark/sql/tests/connect/test_connect_basic.py
index b3b241b2d4e..a723163cbe8 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -1105,6 +1105,11 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
pdf = self.connect.sql("SELECT 1").toPandas()
self.assertEqual(1, len(pdf.index))
+ def test_sql_with_args(self):
+ df = self.connect.sql("SELECT * FROM range(10) WHERE id > :minId",
args={"minId": "7"})
+ df2 = self.spark.sql("SELECT * FROM range(10) WHERE id > :minId",
args={"minId": "7"})
+ self.assert_eq(df.toPandas(), df2.toPandas())
+
def test_head(self):
# SPARK-41002: test `head` API in Python Client
df = self.connect.read.table(self.tbl_name)
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan.py
b/python/pyspark/sql/tests/connect/test_connect_plan.py
index 026f286cc56..980f61eb4b1 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan.py
@@ -647,7 +647,8 @@ class SparkConnectPlanTests(PlanOnlyTestFixture):
def test_print(self):
# SPARK-41717: test print
self.assertEqual(
- self.connect.sql("SELECT 1")._plan.print().strip(), "<SQL
query='SELECT 1'>"
+ self.connect.sql("SELECT 1")._plan.print().strip(),
+ "<SQL query='SELECT 1', args='None'>",
)
self.assertEqual(
self.connect.range(1, 10)._plan.print().strip(),
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]