This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a2bab5efc5b [SPARK-45235][CONNECT][PYTHON] Support map and array parameters by `sql()`
a2bab5efc5b is described below

commit a2bab5efc5b5f0e841e9b34ccbfd2cb99af5923e
Author: Max Gekk <max.g...@gmail.com>
AuthorDate: Thu Sep 21 09:05:30 2023 +0300

    [SPARK-45235][CONNECT][PYTHON] Support map and array parameters by `sql()`
    
    ### What changes were proposed in this pull request?
    In this PR, I propose to change the Python Connect client to support `Column` as a parameter of `sql()`.
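    
    For illustration, a minimal sketch of the call this enables, mirroring the modified tests below (it assumes a running `spark` session):
    ```
    from pyspark.sql import functions as F
    
    # Named parameters: a plain literal plus a map column built with create_map()
    df = spark.sql(
        "SELECT *, element_at(:m, 'a') FROM range(10) WHERE id > :minId",
        args={"minId": 7, "m": F.create_map(F.lit("a"), F.lit(1))},
    )
    
    # Positional parameters: an array column, then a plain literal
    df2 = spark.sql(
        "SELECT *, element_at(?, 1) FROM range(10) WHERE id > ?",
        args=[F.array(F.lit(1)), 7],
    )
    ```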
    
    ### Why are the changes needed?
    To achieve feature parity with regular PySpark, which supports map and array parameters in `sql()`; see https://github.com/apache/spark/pull/42996.
    
    ### Does this PR introduce _any_ user-facing change?
    No. It fixes a bug.
    
    ### How was this patch tested?
    By running the modified tests:
    ```
    $ python/run-tests --parallelism=1 --testnames 'pyspark.sql.tests.connect.test_connect_basic SparkConnectBasicTests.test_sql_with_named_args'
    $ python/run-tests --parallelism=1 --testnames 'pyspark.sql.tests.connect.test_connect_basic SparkConnectBasicTests.test_sql_with_pos_args'
    ```
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #43014 from MaxGekk/map-sql-parameterized-python-connect-2.
    
    Authored-by: Max Gekk <max.g...@gmail.com>
    Signed-off-by: Max Gekk <max.g...@gmail.com>
---
 python/pyspark/sql/connect/plan.py                 | 22 ++++++++++------------
 python/pyspark/sql/connect/session.py              |  2 +-
 .../sql/tests/connect/test_connect_basic.py        | 12 ++++++++----
 3 files changed, 19 insertions(+), 17 deletions(-)

diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 3e8db2aae09..d069081e1af 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -1049,6 +1049,12 @@ class SQL(LogicalPlan):
         self._query = query
         self._args = args
 
+    def _to_expr(self, session: "SparkConnectClient", v: Any) -> proto.Expression:
+        if isinstance(v, Column):
+            return v.to_plan(session)
+        else:
+            return LiteralExpression._from_value(v).to_plan(session)
+
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         plan = self._create_proto_relation()
         plan.sql.query = self._query
@@ -1056,14 +1062,10 @@ class SQL(LogicalPlan):
         if self._args is not None and len(self._args) > 0:
             if isinstance(self._args, Dict):
                 for k, v in self._args.items():
-                    plan.sql.args[k].CopyFrom(
-                        LiteralExpression._from_value(v).to_plan(session).literal
-                    )
+                    plan.sql.named_arguments[k].CopyFrom(self._to_expr(session, v))
             else:
                 for v in self._args:
-                    plan.sql.pos_args.append(
-                        LiteralExpression._from_value(v).to_plan(session).literal
-                    )
+                    plan.sql.pos_arguments.append(self._to_expr(session, v))
 
         return plan
 
@@ -1073,14 +1075,10 @@ class SQL(LogicalPlan):
         if self._args is not None and len(self._args) > 0:
             if isinstance(self._args, Dict):
                 for k, v in self._args.items():
-                    cmd.sql_command.args[k].CopyFrom(
-                        LiteralExpression._from_value(v).to_plan(session).literal
-                    )
+                    cmd.sql_command.named_arguments[k].CopyFrom(self._to_expr(session, v))
             else:
                 for v in self._args:
-                    cmd.sql_command.pos_args.append(
-                        LiteralExpression._from_value(v).to_plan(session).literal
-                    )
+                    cmd.sql_command.pos_arguments.append(self._to_expr(session, v))
 
         return cmd
 
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 7582fe86ff2..e5d1d95a699 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -557,7 +557,7 @@ class SparkSession:
         if "sql_command_result" in properties:
             return DataFrame.withPlan(CachedRelation(properties["sql_command_result"]), self)
         else:
-            return DataFrame.withPlan(SQL(sqlQuery, args), self)
+            return DataFrame.withPlan(cmd, self)
 
     sql.__doc__ = PySparkSession.sql.__doc__
 
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 2b979570618..c5a127136d6 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -1237,13 +1237,17 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
         self.assertEqual(1, len(pdf.index))
 
     def test_sql_with_named_args(self):
-        df = self.connect.sql("SELECT * FROM range(10) WHERE id > :minId", args={"minId": 7})
-        df2 = self.spark.sql("SELECT * FROM range(10) WHERE id > :minId", args={"minId": 7})
+        sqlText = "SELECT *, element_at(:m, 'a') FROM range(10) WHERE id > :minId"
+        df = self.connect.sql(
+            sqlText, args={"minId": 7, "m": CF.create_map(CF.lit("a"), CF.lit(1))}
+        )
+        df2 = self.spark.sql(sqlText, args={"minId": 7, "m": SF.create_map(SF.lit("a"), SF.lit(1))})
         self.assert_eq(df.toPandas(), df2.toPandas())
 
     def test_sql_with_pos_args(self):
-        df = self.connect.sql("SELECT * FROM range(10) WHERE id > ?", args=[7])
-        df2 = self.spark.sql("SELECT * FROM range(10) WHERE id > ?", args=[7])
+        sqlText = "SELECT *, element_at(?, 1) FROM range(10) WHERE id > ?"
+        df = self.connect.sql(sqlText, args=[CF.array(CF.lit(1)), 7])
+        df2 = self.spark.sql(sqlText, args=[SF.array(SF.lit(1)), 7])
         self.assert_eq(df.toPandas(), df2.toPandas())
 
     def test_head(self):

