This is an automated email from the ASF dual-hosted git repository. maxgekk pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new a2bab5efc5b [SPARK-45235][CONNECT][PYTHON] Support map and array parameters by `sql()` a2bab5efc5b is described below commit a2bab5efc5b5f0e841e9b34ccbfd2cb99af5923e Author: Max Gekk <max.g...@gmail.com> AuthorDate: Thu Sep 21 09:05:30 2023 +0300 [SPARK-45235][CONNECT][PYTHON] Support map and array parameters by `sql()` ### What changes were proposed in this pull request? In the PR, I propose to change the Python connect client to support `Column` as a parameter of `sql()`. ### Why are the changes needed? To achieve feature parity with regular PySpark, which supports maps and arrays as parameters of `sql()`, see https://github.com/apache/spark/pull/42996. ### Does this PR introduce _any_ user-facing change? No. It fixes a bug. ### How was this patch tested? By running the modified tests: ``` $ python/run-tests --parallelism=1 --testnames 'pyspark.sql.tests.connect.test_connect_basic SparkConnectBasicTests.test_sql_with_named_args' $ python/run-tests --parallelism=1 --testnames 'pyspark.sql.tests.connect.test_connect_basic SparkConnectBasicTests.test_sql_with_pos_args' ``` ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43014 from MaxGekk/map-sql-parameterized-python-connect-2. 
Authored-by: Max Gekk <max.g...@gmail.com> Signed-off-by: Max Gekk <max.g...@gmail.com> --- python/pyspark/sql/connect/plan.py | 22 ++++++++++------------ python/pyspark/sql/connect/session.py | 2 +- .../sql/tests/connect/test_connect_basic.py | 12 ++++++++---- 3 files changed, 19 insertions(+), 17 deletions(-) diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py index 3e8db2aae09..d069081e1af 100644 --- a/python/pyspark/sql/connect/plan.py +++ b/python/pyspark/sql/connect/plan.py @@ -1049,6 +1049,12 @@ class SQL(LogicalPlan): self._query = query self._args = args + def _to_expr(self, session: "SparkConnectClient", v: Any) -> proto.Expression: + if isinstance(v, Column): + return v.to_plan(session) + else: + return LiteralExpression._from_value(v).to_plan(session) + def plan(self, session: "SparkConnectClient") -> proto.Relation: plan = self._create_proto_relation() plan.sql.query = self._query @@ -1056,14 +1062,10 @@ class SQL(LogicalPlan): if self._args is not None and len(self._args) > 0: if isinstance(self._args, Dict): for k, v in self._args.items(): - plan.sql.args[k].CopyFrom( - LiteralExpression._from_value(v).to_plan(session).literal - ) + plan.sql.named_arguments[k].CopyFrom(self._to_expr(session, v)) else: for v in self._args: - plan.sql.pos_args.append( - LiteralExpression._from_value(v).to_plan(session).literal - ) + plan.sql.pos_arguments.append(self._to_expr(session, v)) return plan @@ -1073,14 +1075,10 @@ class SQL(LogicalPlan): if self._args is not None and len(self._args) > 0: if isinstance(self._args, Dict): for k, v in self._args.items(): - cmd.sql_command.args[k].CopyFrom( - LiteralExpression._from_value(v).to_plan(session).literal - ) + cmd.sql_command.named_arguments[k].CopyFrom(self._to_expr(session, v)) else: for v in self._args: - cmd.sql_command.pos_args.append( - LiteralExpression._from_value(v).to_plan(session).literal - ) + cmd.sql_command.pos_arguments.append(self._to_expr(session, v)) return cmd diff 
--git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index 7582fe86ff2..e5d1d95a699 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -557,7 +557,7 @@ class SparkSession: if "sql_command_result" in properties: return DataFrame.withPlan(CachedRelation(properties["sql_command_result"]), self) else: - return DataFrame.withPlan(SQL(sqlQuery, args), self) + return DataFrame.withPlan(cmd, self) sql.__doc__ = PySparkSession.sql.__doc__ diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index 2b979570618..c5a127136d6 100644 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -1237,13 +1237,17 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase): self.assertEqual(1, len(pdf.index)) def test_sql_with_named_args(self): - df = self.connect.sql("SELECT * FROM range(10) WHERE id > :minId", args={"minId": 7}) - df2 = self.spark.sql("SELECT * FROM range(10) WHERE id > :minId", args={"minId": 7}) + sqlText = "SELECT *, element_at(:m, 'a') FROM range(10) WHERE id > :minId" + df = self.connect.sql( + sqlText, args={"minId": 7, "m": CF.create_map(CF.lit("a"), CF.lit(1))} + ) + df2 = self.spark.sql(sqlText, args={"minId": 7, "m": SF.create_map(SF.lit("a"), SF.lit(1))}) self.assert_eq(df.toPandas(), df2.toPandas()) def test_sql_with_pos_args(self): - df = self.connect.sql("SELECT * FROM range(10) WHERE id > ?", args=[7]) - df2 = self.spark.sql("SELECT * FROM range(10) WHERE id > ?", args=[7]) + sqlText = "SELECT *, element_at(?, 1) FROM range(10) WHERE id > ?" 
+ df = self.connect.sql(sqlText, args=[CF.array(CF.lit(1)), 7]) + df2 = self.spark.sql(sqlText, args=[SF.array(SF.lit(1)), 7]) self.assert_eq(df.toPandas(), df2.toPandas()) def test_head(self): --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org