This is an automated email from the ASF dual-hosted git repository.
maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new a2bab5efc5b [SPARK-45235][CONNECT][PYTHON] Support map and array
parameters by `sql()`
a2bab5efc5b is described below
commit a2bab5efc5b5f0e841e9b34ccbfd2cb99af5923e
Author: Max Gekk <[email protected]>
AuthorDate: Thu Sep 21 09:05:30 2023 +0300
[SPARK-45235][CONNECT][PYTHON] Support map and array parameters by `sql()`
### What changes were proposed in this pull request?
In this PR, I propose to change the Python Connect client to accept a
`Column` as a parameter value of `sql()`.
### Why are the changes needed?
To achieve feature parity w/ regular PySpark, which supports maps and arrays
as parameters of `sql()`; see https://github.com/apache/spark/pull/42996.
### Does this PR introduce _any_ user-facing change?
No. It fixes a bug.
### How was this patch tested?
By running the modified tests:
```
$ python/run-tests --parallelism=1 --testnames
'pyspark.sql.tests.connect.test_connect_basic
SparkConnectBasicTests.test_sql_with_named_args'
$ python/run-tests --parallelism=1 --testnames
'pyspark.sql.tests.connect.test_connect_basic
SparkConnectBasicTests.test_sql_with_pos_args'
```
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #43014 from MaxGekk/map-sql-parameterized-python-connect-2.
Authored-by: Max Gekk <[email protected]>
Signed-off-by: Max Gekk <[email protected]>
---
python/pyspark/sql/connect/plan.py | 22 ++++++++++------------
python/pyspark/sql/connect/session.py | 2 +-
.../sql/tests/connect/test_connect_basic.py | 12 ++++++++----
3 files changed, 19 insertions(+), 17 deletions(-)
diff --git a/python/pyspark/sql/connect/plan.py
b/python/pyspark/sql/connect/plan.py
index 3e8db2aae09..d069081e1af 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -1049,6 +1049,12 @@ class SQL(LogicalPlan):
self._query = query
self._args = args
+ def _to_expr(self, session: "SparkConnectClient", v: Any) ->
proto.Expression:
+ if isinstance(v, Column):
+ return v.to_plan(session)
+ else:
+ return LiteralExpression._from_value(v).to_plan(session)
+
def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = self._create_proto_relation()
plan.sql.query = self._query
@@ -1056,14 +1062,10 @@ class SQL(LogicalPlan):
if self._args is not None and len(self._args) > 0:
if isinstance(self._args, Dict):
for k, v in self._args.items():
- plan.sql.args[k].CopyFrom(
-
LiteralExpression._from_value(v).to_plan(session).literal
- )
+
plan.sql.named_arguments[k].CopyFrom(self._to_expr(session, v))
else:
for v in self._args:
- plan.sql.pos_args.append(
-
LiteralExpression._from_value(v).to_plan(session).literal
- )
+ plan.sql.pos_arguments.append(self._to_expr(session, v))
return plan
@@ -1073,14 +1075,10 @@ class SQL(LogicalPlan):
if self._args is not None and len(self._args) > 0:
if isinstance(self._args, Dict):
for k, v in self._args.items():
- cmd.sql_command.args[k].CopyFrom(
-
LiteralExpression._from_value(v).to_plan(session).literal
- )
+
cmd.sql_command.named_arguments[k].CopyFrom(self._to_expr(session, v))
else:
for v in self._args:
- cmd.sql_command.pos_args.append(
-
LiteralExpression._from_value(v).to_plan(session).literal
- )
+
cmd.sql_command.pos_arguments.append(self._to_expr(session, v))
return cmd
diff --git a/python/pyspark/sql/connect/session.py
b/python/pyspark/sql/connect/session.py
index 7582fe86ff2..e5d1d95a699 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -557,7 +557,7 @@ class SparkSession:
if "sql_command_result" in properties:
return
DataFrame.withPlan(CachedRelation(properties["sql_command_result"]), self)
else:
- return DataFrame.withPlan(SQL(sqlQuery, args), self)
+ return DataFrame.withPlan(cmd, self)
sql.__doc__ = PySparkSession.sql.__doc__
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py
b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 2b979570618..c5a127136d6 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -1237,13 +1237,17 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
self.assertEqual(1, len(pdf.index))
def test_sql_with_named_args(self):
- df = self.connect.sql("SELECT * FROM range(10) WHERE id > :minId",
args={"minId": 7})
- df2 = self.spark.sql("SELECT * FROM range(10) WHERE id > :minId",
args={"minId": 7})
+ sqlText = "SELECT *, element_at(:m, 'a') FROM range(10) WHERE id >
:minId"
+ df = self.connect.sql(
+ sqlText, args={"minId": 7, "m": CF.create_map(CF.lit("a"),
CF.lit(1))}
+ )
+ df2 = self.spark.sql(sqlText, args={"minId": 7, "m":
SF.create_map(SF.lit("a"), SF.lit(1))})
self.assert_eq(df.toPandas(), df2.toPandas())
def test_sql_with_pos_args(self):
- df = self.connect.sql("SELECT * FROM range(10) WHERE id > ?", args=[7])
- df2 = self.spark.sql("SELECT * FROM range(10) WHERE id > ?", args=[7])
+ sqlText = "SELECT *, element_at(?, 1) FROM range(10) WHERE id > ?"
+ df = self.connect.sql(sqlText, args=[CF.array(CF.lit(1)), 7])
+ df2 = self.spark.sql(sqlText, args=[SF.array(SF.lit(1)), 7])
self.assert_eq(df.toPandas(), df2.toPandas())
def test_head(self):
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]