This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new 3b22267e3fe [SPARK-42402][CONNECT] Support parameterized SQL by `sql()`
3b22267e3fe is described below

commit 3b22267e3fe6adde8f1d50e99b3f06c86ec81ad8
Author: Takuya UESHIN <[email protected]>
AuthorDate: Sun Feb 12 09:05:02 2023 +0900

    [SPARK-42402][CONNECT] Support parameterized SQL by `sql()`
    
    ### What changes were proposed in this pull request?
    
    Supports parameterized SQL by `sql()`.
    
    Note: `SparkSession.sql` in PySpark also supports string formatter, but it will be handled separately.
    
    ### Why are the changes needed?
    
    Currently `SparkSession.sql` in Spark Connect doesn't support parameterized SQL.
    
    ### Does this PR introduce _any_ user-facing change?
    
    The parameterized SQL will be available.
    
    For example:
    
    ```py
    >>> spark.sql("SELECT * FROM range(10) WHERE id > :minId", args = {"minId" : "7"}).toPandas()
       id
    0   8
    1   9
    ```
    
    ### How was this patch tested?
    
    Added a test.
    
    Closes #39971 from ueshin/issues/SPARK-42402/parameterized_sql.
    
    Authored-by: Takuya UESHIN <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
    (cherry picked from commit 03227e18793c6902a816bbacdce78031ce37b14a)
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../main/protobuf/spark/connect/relations.proto    |   3 +
 .../sql/connect/planner/SparkConnectPlanner.scala  |   5 +-
 python/pyspark/sql/connect/plan.py                 |  14 +-
 python/pyspark/sql/connect/proto/relations_pb2.py  | 205 +++++++++++----------
 python/pyspark/sql/connect/proto/relations_pb2.pyi |  26 ++-
 python/pyspark/sql/connect/session.py              |   4 +-
 python/pyspark/sql/session.py                      |   4 +-
 .../sql/tests/connect/test_connect_basic.py        |   5 +
 .../pyspark/sql/tests/connect/test_connect_plan.py |   3 +-
 9 files changed, 166 insertions(+), 103 deletions(-)

diff --git 
a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto 
b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
index 837727432f9..3d597fd2744 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -99,6 +99,9 @@ message RelationCommon {
 message SQL {
   // (Required) The SQL query.
   string query = 1;
+
+  // (Optional) A map of parameter names to literal values.
+  map<string, string> args = 2;
 }
 
 // Relation that reads from a file / table or other data source. Does not have 
additional
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 3bf5d2b1d30..75581851b5f 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -196,7 +196,10 @@ class SparkConnectPlanner(val session: SparkSession) {
   }
 
   private def transformSql(sql: proto.SQL): LogicalPlan = {
-    session.sessionState.sqlParser.parsePlan(sql.getQuery)
+    val args = sql.getArgsMap.asScala.toMap
+    val parser = session.sessionState.sqlParser
+    val parsedArgs = args.mapValues(parser.parseExpression).toMap
+    Parameter.bind(parser.parsePlan(sql.getQuery), parsedArgs)
   }
 
   private def transformSubqueryAlias(alias: proto.SubqueryAlias): LogicalPlan 
= {
diff --git a/python/pyspark/sql/connect/plan.py 
b/python/pyspark/sql/connect/plan.py
index 39b32f065ea..4675631627a 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -919,13 +919,25 @@ class SubqueryAlias(LogicalPlan):
 
 
 class SQL(LogicalPlan):
-    def __init__(self, query: str) -> None:
+    def __init__(self, query: str, args: Optional[Dict[str, str]] = None) -> 
None:
         super().__init__(None)
+
+        if args is not None:
+            for k, v in args.items():
+                assert isinstance(k, str)
+                assert isinstance(v, str)
+
         self._query = query
+        self._args = args
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         rel = proto.Relation()
         rel.sql.query = self._query
+
+        if self._args is not None and len(self._args) > 0:
+            for k, v in self._args.items():
+                rel.sql.args[k] = v
+
         return rel
 
 
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py 
b/python/pyspark/sql/connect/proto/relations_pb2.py
index 4ef9a7278f9..1a24628ef30 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import catalog_pb2 as 
spark_dot_connect_dot_catal
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    
b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xf9\x11\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01
 
\x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02
 
\x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 
\x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
+    
b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xf9\x11\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01
 
\x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02
 
\x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 
\x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
 )
 
 
@@ -44,6 +44,7 @@ _RELATION = DESCRIPTOR.message_types_by_name["Relation"]
 _UNKNOWN = DESCRIPTOR.message_types_by_name["Unknown"]
 _RELATIONCOMMON = DESCRIPTOR.message_types_by_name["RelationCommon"]
 _SQL = DESCRIPTOR.message_types_by_name["SQL"]
+_SQL_ARGSENTRY = _SQL.nested_types_by_name["ArgsEntry"]
 _READ = DESCRIPTOR.message_types_by_name["Read"]
 _READ_NAMEDTABLE = _READ.nested_types_by_name["NamedTable"]
 _READ_DATASOURCE = _READ.nested_types_by_name["DataSource"]
@@ -129,12 +130,22 @@ SQL = _reflection.GeneratedProtocolMessageType(
     "SQL",
     (_message.Message,),
     {
+        "ArgsEntry": _reflection.GeneratedProtocolMessageType(
+            "ArgsEntry",
+            (_message.Message,),
+            {
+                "DESCRIPTOR": _SQL_ARGSENTRY,
+                "__module__": "spark.connect.relations_pb2"
+                # 
@@protoc_insertion_point(class_scope:spark.connect.SQL.ArgsEntry)
+            },
+        ),
         "DESCRIPTOR": _SQL,
         "__module__": "spark.connect.relations_pb2"
         # @@protoc_insertion_point(class_scope:spark.connect.SQL)
     },
 )
 _sym_db.RegisterMessage(SQL)
+_sym_db.RegisterMessage(SQL.ArgsEntry)
 
 Read = _reflection.GeneratedProtocolMessageType(
     "Read",
@@ -606,6 +617,8 @@ if _descriptor._USE_C_DESCRIPTORS == False:
 
     DESCRIPTOR._options = None
     DESCRIPTOR._serialized_options = 
b"\n\036org.apache.spark.connect.protoP\001"
+    _SQL_ARGSENTRY._options = None
+    _SQL_ARGSENTRY._serialized_options = b"8\001"
     _READ_DATASOURCE_OPTIONSENTRY._options = None
     _READ_DATASOURCE_OPTIONSENTRY._serialized_options = b"8\001"
     _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._options = None
@@ -616,98 +629,100 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _UNKNOWN._serialized_end = 2473
     _RELATIONCOMMON._serialized_start = 2475
     _RELATIONCOMMON._serialized_end = 2524
-    _SQL._serialized_start = 2526
-    _SQL._serialized_end = 2553
-    _READ._serialized_start = 2556
-    _READ._serialized_end = 3004
-    _READ_NAMEDTABLE._serialized_start = 2698
-    _READ_NAMEDTABLE._serialized_end = 2759
-    _READ_DATASOURCE._serialized_start = 2762
-    _READ_DATASOURCE._serialized_end = 2991
-    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 2922
-    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 2980
-    _PROJECT._serialized_start = 3006
-    _PROJECT._serialized_end = 3123
-    _FILTER._serialized_start = 3125
-    _FILTER._serialized_end = 3237
-    _JOIN._serialized_start = 3240
-    _JOIN._serialized_end = 3711
-    _JOIN_JOINTYPE._serialized_start = 3503
-    _JOIN_JOINTYPE._serialized_end = 3711
-    _SETOPERATION._serialized_start = 3714
-    _SETOPERATION._serialized_end = 4193
-    _SETOPERATION_SETOPTYPE._serialized_start = 4030
-    _SETOPERATION_SETOPTYPE._serialized_end = 4144
-    _LIMIT._serialized_start = 4195
-    _LIMIT._serialized_end = 4271
-    _OFFSET._serialized_start = 4273
-    _OFFSET._serialized_end = 4352
-    _TAIL._serialized_start = 4354
-    _TAIL._serialized_end = 4429
-    _AGGREGATE._serialized_start = 4432
-    _AGGREGATE._serialized_end = 5014
-    _AGGREGATE_PIVOT._serialized_start = 4771
-    _AGGREGATE_PIVOT._serialized_end = 4882
-    _AGGREGATE_GROUPTYPE._serialized_start = 4885
-    _AGGREGATE_GROUPTYPE._serialized_end = 5014
-    _SORT._serialized_start = 5017
-    _SORT._serialized_end = 5177
-    _DROP._serialized_start = 5179
-    _DROP._serialized_end = 5279
-    _DEDUPLICATE._serialized_start = 5282
-    _DEDUPLICATE._serialized_end = 5453
-    _LOCALRELATION._serialized_start = 5455
-    _LOCALRELATION._serialized_end = 5544
-    _SAMPLE._serialized_start = 5547
-    _SAMPLE._serialized_end = 5820
-    _RANGE._serialized_start = 5823
-    _RANGE._serialized_end = 5968
-    _SUBQUERYALIAS._serialized_start = 5970
-    _SUBQUERYALIAS._serialized_end = 6084
-    _REPARTITION._serialized_start = 6087
-    _REPARTITION._serialized_end = 6229
-    _SHOWSTRING._serialized_start = 6232
-    _SHOWSTRING._serialized_end = 6374
-    _STATSUMMARY._serialized_start = 6376
-    _STATSUMMARY._serialized_end = 6468
-    _STATDESCRIBE._serialized_start = 6470
-    _STATDESCRIBE._serialized_end = 6551
-    _STATCROSSTAB._serialized_start = 6553
-    _STATCROSSTAB._serialized_end = 6654
-    _STATCOV._serialized_start = 6656
-    _STATCOV._serialized_end = 6752
-    _STATCORR._serialized_start = 6755
-    _STATCORR._serialized_end = 6892
-    _STATAPPROXQUANTILE._serialized_start = 6895
-    _STATAPPROXQUANTILE._serialized_end = 7059
-    _STATFREQITEMS._serialized_start = 7061
-    _STATFREQITEMS._serialized_end = 7186
-    _STATSAMPLEBY._serialized_start = 7189
-    _STATSAMPLEBY._serialized_end = 7498
-    _STATSAMPLEBY_FRACTION._serialized_start = 7390
-    _STATSAMPLEBY_FRACTION._serialized_end = 7489
-    _NAFILL._serialized_start = 7501
-    _NAFILL._serialized_end = 7635
-    _NADROP._serialized_start = 7638
-    _NADROP._serialized_end = 7772
-    _NAREPLACE._serialized_start = 7775
-    _NAREPLACE._serialized_end = 8071
-    _NAREPLACE_REPLACEMENT._serialized_start = 7930
-    _NAREPLACE_REPLACEMENT._serialized_end = 8071
-    _TODF._serialized_start = 8073
-    _TODF._serialized_end = 8161
-    _WITHCOLUMNSRENAMED._serialized_start = 8164
-    _WITHCOLUMNSRENAMED._serialized_end = 8403
-    _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 8336
-    _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 8403
-    _WITHCOLUMNS._serialized_start = 8405
-    _WITHCOLUMNS._serialized_end = 8524
-    _HINT._serialized_start = 8527
-    _HINT._serialized_end = 8659
-    _UNPIVOT._serialized_start = 8662
-    _UNPIVOT._serialized_end = 8908
-    _TOSCHEMA._serialized_start = 8910
-    _TOSCHEMA._serialized_end = 9016
-    _REPARTITIONBYEXPRESSION._serialized_start = 9019
-    _REPARTITIONBYEXPRESSION._serialized_end = 9222
+    _SQL._serialized_start = 2527
+    _SQL._serialized_end = 2661
+    _SQL_ARGSENTRY._serialized_start = 2606
+    _SQL_ARGSENTRY._serialized_end = 2661
+    _READ._serialized_start = 2664
+    _READ._serialized_end = 3112
+    _READ_NAMEDTABLE._serialized_start = 2806
+    _READ_NAMEDTABLE._serialized_end = 2867
+    _READ_DATASOURCE._serialized_start = 2870
+    _READ_DATASOURCE._serialized_end = 3099
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 3030
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 3088
+    _PROJECT._serialized_start = 3114
+    _PROJECT._serialized_end = 3231
+    _FILTER._serialized_start = 3233
+    _FILTER._serialized_end = 3345
+    _JOIN._serialized_start = 3348
+    _JOIN._serialized_end = 3819
+    _JOIN_JOINTYPE._serialized_start = 3611
+    _JOIN_JOINTYPE._serialized_end = 3819
+    _SETOPERATION._serialized_start = 3822
+    _SETOPERATION._serialized_end = 4301
+    _SETOPERATION_SETOPTYPE._serialized_start = 4138
+    _SETOPERATION_SETOPTYPE._serialized_end = 4252
+    _LIMIT._serialized_start = 4303
+    _LIMIT._serialized_end = 4379
+    _OFFSET._serialized_start = 4381
+    _OFFSET._serialized_end = 4460
+    _TAIL._serialized_start = 4462
+    _TAIL._serialized_end = 4537
+    _AGGREGATE._serialized_start = 4540
+    _AGGREGATE._serialized_end = 5122
+    _AGGREGATE_PIVOT._serialized_start = 4879
+    _AGGREGATE_PIVOT._serialized_end = 4990
+    _AGGREGATE_GROUPTYPE._serialized_start = 4993
+    _AGGREGATE_GROUPTYPE._serialized_end = 5122
+    _SORT._serialized_start = 5125
+    _SORT._serialized_end = 5285
+    _DROP._serialized_start = 5287
+    _DROP._serialized_end = 5387
+    _DEDUPLICATE._serialized_start = 5390
+    _DEDUPLICATE._serialized_end = 5561
+    _LOCALRELATION._serialized_start = 5563
+    _LOCALRELATION._serialized_end = 5652
+    _SAMPLE._serialized_start = 5655
+    _SAMPLE._serialized_end = 5928
+    _RANGE._serialized_start = 5931
+    _RANGE._serialized_end = 6076
+    _SUBQUERYALIAS._serialized_start = 6078
+    _SUBQUERYALIAS._serialized_end = 6192
+    _REPARTITION._serialized_start = 6195
+    _REPARTITION._serialized_end = 6337
+    _SHOWSTRING._serialized_start = 6340
+    _SHOWSTRING._serialized_end = 6482
+    _STATSUMMARY._serialized_start = 6484
+    _STATSUMMARY._serialized_end = 6576
+    _STATDESCRIBE._serialized_start = 6578
+    _STATDESCRIBE._serialized_end = 6659
+    _STATCROSSTAB._serialized_start = 6661
+    _STATCROSSTAB._serialized_end = 6762
+    _STATCOV._serialized_start = 6764
+    _STATCOV._serialized_end = 6860
+    _STATCORR._serialized_start = 6863
+    _STATCORR._serialized_end = 7000
+    _STATAPPROXQUANTILE._serialized_start = 7003
+    _STATAPPROXQUANTILE._serialized_end = 7167
+    _STATFREQITEMS._serialized_start = 7169
+    _STATFREQITEMS._serialized_end = 7294
+    _STATSAMPLEBY._serialized_start = 7297
+    _STATSAMPLEBY._serialized_end = 7606
+    _STATSAMPLEBY_FRACTION._serialized_start = 7498
+    _STATSAMPLEBY_FRACTION._serialized_end = 7597
+    _NAFILL._serialized_start = 7609
+    _NAFILL._serialized_end = 7743
+    _NADROP._serialized_start = 7746
+    _NADROP._serialized_end = 7880
+    _NAREPLACE._serialized_start = 7883
+    _NAREPLACE._serialized_end = 8179
+    _NAREPLACE_REPLACEMENT._serialized_start = 8038
+    _NAREPLACE_REPLACEMENT._serialized_end = 8179
+    _TODF._serialized_start = 8181
+    _TODF._serialized_end = 8269
+    _WITHCOLUMNSRENAMED._serialized_start = 8272
+    _WITHCOLUMNSRENAMED._serialized_end = 8511
+    _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 8444
+    _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 8511
+    _WITHCOLUMNS._serialized_start = 8513
+    _WITHCOLUMNS._serialized_end = 8632
+    _HINT._serialized_start = 8635
+    _HINT._serialized_end = 8767
+    _UNPIVOT._serialized_start = 8770
+    _UNPIVOT._serialized_end = 9016
+    _TOSCHEMA._serialized_start = 9018
+    _TOSCHEMA._serialized_end = 9124
+    _REPARTITIONBYEXPRESSION._serialized_start = 9127
+    _REPARTITIONBYEXPRESSION._serialized_end = 9330
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi 
b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 59c567c962f..647b26b6d31 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -496,15 +496,39 @@ class SQL(google.protobuf.message.Message):
 
     DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
+    class ArgsEntry(google.protobuf.message.Message):
+        DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+        KEY_FIELD_NUMBER: builtins.int
+        VALUE_FIELD_NUMBER: builtins.int
+        key: builtins.str
+        value: builtins.str
+        def __init__(
+            self,
+            *,
+            key: builtins.str = ...,
+            value: builtins.str = ...,
+        ) -> None: ...
+        def ClearField(
+            self, field_name: typing_extensions.Literal["key", b"key", 
"value", b"value"]
+        ) -> None: ...
+
     QUERY_FIELD_NUMBER: builtins.int
+    ARGS_FIELD_NUMBER: builtins.int
     query: builtins.str
     """(Required) The SQL query."""
+    @property
+    def args(self) -> 
google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.str]:
+        """(Optional) A map of parameter names to literal values."""
     def __init__(
         self,
         *,
         query: builtins.str = ...,
+        args: collections.abc.Mapping[builtins.str, builtins.str] | None = ...,
+    ) -> None: ...
+    def ClearField(
+        self, field_name: typing_extensions.Literal["args", b"args", "query", 
b"query"]
     ) -> None: ...
-    def ClearField(self, field_name: typing_extensions.Literal["query", 
b"query"]) -> None: ...
 
 global___SQL = SQL
 
diff --git a/python/pyspark/sql/connect/session.py 
b/python/pyspark/sql/connect/session.py
index 3c44d06bb1c..75c8e61752e 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -341,8 +341,8 @@ class SparkSession:
 
     createDataFrame.__doc__ = PySparkSession.createDataFrame.__doc__
 
-    def sql(self, sqlQuery: str) -> "DataFrame":
-        return DataFrame.withPlan(SQL(sqlQuery), self)
+    def sql(self, sqlQuery: str, args: Optional[Dict[str, str]] = None) -> 
"DataFrame":
+        return DataFrame.withPlan(SQL(sqlQuery, args), self)
 
     sql.__doc__ = PySparkSession.sql.__doc__
 
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 38c93b2d0ac..7019210a4d8 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -1308,7 +1308,7 @@ class SparkSession(SparkConversionMixin):
         df._schema = struct
         return df
 
-    def sql(self, sqlQuery: str, args: Dict[str, str] = {}, **kwargs: Any) -> 
DataFrame:
+    def sql(self, sqlQuery: str, args: Optional[Dict[str, str]] = None, 
**kwargs: Any) -> DataFrame:
         """Returns a :class:`DataFrame` representing the result of the given 
query.
         When ``kwargs`` is specified, this method formats the given string by 
using the Python
         standard formatter. The method binds named parameters to SQL literals 
from `args`.
@@ -1416,7 +1416,7 @@ class SparkSession(SparkConversionMixin):
         if len(kwargs) > 0:
             sqlQuery = formatter.format(sqlQuery, **kwargs)
         try:
-            return DataFrame(self._jsparkSession.sql(sqlQuery, args), self)
+            return DataFrame(self._jsparkSession.sql(sqlQuery, args or {}), 
self)
         finally:
             if len(kwargs) > 0:
                 formatter.clear()
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py 
b/python/pyspark/sql/tests/connect/test_connect_basic.py
index b3b241b2d4e..a723163cbe8 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -1105,6 +1105,11 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
         pdf = self.connect.sql("SELECT 1").toPandas()
         self.assertEqual(1, len(pdf.index))
 
+    def test_sql_with_args(self):
+        df = self.connect.sql("SELECT * FROM range(10) WHERE id > :minId", 
args={"minId": "7"})
+        df2 = self.spark.sql("SELECT * FROM range(10) WHERE id > :minId", 
args={"minId": "7"})
+        self.assert_eq(df.toPandas(), df2.toPandas())
+
     def test_head(self):
         # SPARK-41002: test `head` API in Python Client
         df = self.connect.read.table(self.tbl_name)
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan.py 
b/python/pyspark/sql/tests/connect/test_connect_plan.py
index 026f286cc56..980f61eb4b1 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan.py
@@ -647,7 +647,8 @@ class SparkConnectPlanTests(PlanOnlyTestFixture):
     def test_print(self):
         # SPARK-41717: test print
         self.assertEqual(
-            self.connect.sql("SELECT 1")._plan.print().strip(), "<SQL 
query='SELECT 1'>"
+            self.connect.sql("SELECT 1")._plan.print().strip(),
+            "<SQL query='SELECT 1', args='None'>",
         )
         self.assertEqual(
             self.connect.range(1, 10)._plan.print().strip(),


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to