This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 9aa42a970c4  [SPARK-41811][PYTHON][CONNECT] Implement SparkSession.sql's string formatter
9aa42a970c4 is described below

commit 9aa42a970c4bd8e54603b1795a0f449bd556b11b
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Thu Jul 13 17:58:00 2023 +0800

    [SPARK-41811][PYTHON][CONNECT] Implement SparkSession.sql's string formatter

    ### What changes were proposed in this pull request?
    Implement SparkSession.sql's string formatter

    ### Why are the changes needed?
    for parity

    ### Does this PR introduce _any_ user-facing change?
    yes

    before:
    ```
    In [1]: spark.createDataFrame([("Alice", 6), ("Bob", 7), ("John", 10)], ['name', 'age']).createOrReplaceTempView("person")

    In [2]: spark.sql("""SELECT * FROM person WHERE age < {age}""", age = 9).show()
    ---------------------------------------------------------------------------
    TypeError                                 Traceback (most recent call last)
    Cell In[2], line 1
    ----> 1 spark.sql("""SELECT * FROM person WHERE age < {age}""", age = 9).show()

    TypeError: sql() got an unexpected keyword argument 'age'
    ```

    after:
    ```
    In [1]: spark.createDataFrame([("Alice", 6), ("Bob", 7), ("John", 10)], ['name', 'age']).createOrReplaceTempView("person")

    In [2]: spark.sql("""SELECT * FROM person WHERE age < {age}""", age = 9).show()
    +-----+---+
    | name|age|
    +-----+---+
    |Alice|  6|
    |  Bob|  7|
    +-----+---+
    ```

    ### How was this patch tested?
    enabled doc test

    Closes #41980 from zhengruifeng/py_connect_sql_formatter.

    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/pandas/sql_formatter.py            |  7 ++---
 python/pyspark/sql/connect/session.py             | 35 ++++++++++++++++-------
 python/pyspark/sql/{ => connect}/sql_formatter.py | 30 ++++++++-----------
 python/pyspark/sql/sql_formatter.py               |  5 ++--
 python/pyspark/sql/utils.py                       |  8 ++++++
 5 files changed, 51 insertions(+), 34 deletions(-)

diff --git a/python/pyspark/pandas/sql_formatter.py b/python/pyspark/pandas/sql_formatter.py
index 8593703bd94..7501e19c038 100644
--- a/python/pyspark/pandas/sql_formatter.py
+++ b/python/pyspark/pandas/sql_formatter.py
@@ -264,10 +264,9 @@ class PandasSQLStringFormatter(string.Formatter):
                 val._to_spark().createOrReplaceTempView(df_name)
                 return df_name
         elif isinstance(val, str):
-            # This is matched to behavior from JVM implementation.
-            # See `sql` definition from `sql/catalyst/src/main/scala/org/apache/spark/
-            # sql/catalyst/expressions/literals.scala`
-            return "'" + val.replace("\\", "\\\\").replace("'", "\\'") + "'"
+            from pyspark.sql.utils import get_lit_sql_str
+
+            return get_lit_sql_str(val)
         else:
             return val

diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index ea88d60d760..13868263174 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -489,13 +489,31 @@ class SparkSession:

     createDataFrame.__doc__ = PySparkSession.createDataFrame.__doc__

-    def sql(self, sqlQuery: str, args: Optional[Union[Dict[str, Any], List]] = None) -> "DataFrame":
-        cmd = SQL(sqlQuery, args)
-        data, properties = self.client.execute_command(cmd.command(self._client))
-        if "sql_command_result" in properties:
-            return DataFrame.withPlan(CachedRelation(properties["sql_command_result"]), self)
-        else:
-            return DataFrame.withPlan(SQL(sqlQuery, args), self)
+    def sql(
+        self,
+        sqlQuery: str,
+        args: Optional[Union[Dict[str, Any], List]] = None,
+        **kwargs: Any,
+    ) -> "DataFrame":
+
+        if len(kwargs) > 0:
+            from pyspark.sql.connect.sql_formatter import SQLStringFormatter
+
+            formatter = SQLStringFormatter(self)
+            sqlQuery = formatter.format(sqlQuery, **kwargs)
+
+        try:
+            cmd = SQL(sqlQuery, args)
+            data, properties = self.client.execute_command(cmd.command(self._client))
+            if "sql_command_result" in properties:
+                return DataFrame.withPlan(CachedRelation(properties["sql_command_result"]), self)
+            else:
+                return DataFrame.withPlan(SQL(sqlQuery, args), self)
+        finally:
+            if len(kwargs) > 0:
+                # TODO: should drop temp views after SPARK-44406 get resolved
+                # formatter.clear()
+                pass

     sql.__doc__ = PySparkSession.sql.__doc__

@@ -808,9 +826,6 @@ def _test() -> None:
     # RDD API is not supported in Spark Connect.
     del pyspark.sql.connect.session.SparkSession.createDataFrame.__doc__

-    # TODO(SPARK-41811): Implement SparkSession.sql's string formatter
-    del pyspark.sql.connect.session.SparkSession.sql.__doc__
-
     (failure_count, test_count) = doctest.testmod(
         pyspark.sql.connect.session,
         globs=globs,
diff --git a/python/pyspark/sql/sql_formatter.py b/python/pyspark/sql/connect/sql_formatter.py
similarity index 76%
copy from python/pyspark/sql/sql_formatter.py
copy to python/pyspark/sql/connect/sql_formatter.py
index 5e79b9ff5ea..ab90a1bb847 100644
--- a/python/pyspark/sql/sql_formatter.py
+++ b/python/pyspark/sql/connect/sql_formatter.py
@@ -20,11 +20,9 @@ import typing
 from typing import Any, Optional, List, Tuple, Sequence, Mapping
 import uuid

-from py4j.java_gateway import is_instance_of
-
 if typing.TYPE_CHECKING:
-    from pyspark.sql import SparkSession, DataFrame
-from pyspark.sql.functions import lit
+    from pyspark.sql.connect.session import SparkSession
+    from pyspark.sql.connect.dataframe import DataFrame


 class SQLStringFormatter(string.Formatter):
@@ -46,20 +44,14 @@ class SQLStringFormatter(string.Formatter):
         """
         Converts the given value into a SQL string.
""" - from pyspark import SparkContext - from pyspark.sql import Column, DataFrame + from pyspark.sql.connect.dataframe import DataFrame + from pyspark.sql.connect.column import Column + from pyspark.sql.connect.expressions import ColumnReference if isinstance(val, Column): - assert SparkContext._gateway is not None - - gw = SparkContext._gateway - jexpr = val._jc.expr() - if is_instance_of( - gw, jexpr, "org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute" - ) or is_instance_of( - gw, jexpr, "org.apache.spark.sql.catalyst.expressions.AttributeReference" - ): - return jexpr.sql() + expr = val._expr + if isinstance(expr, ColumnReference): + return expr._unparsed_identifier else: raise ValueError( "%s in %s should be a plain column reference such as `df.col` " @@ -69,12 +61,14 @@ class SQLStringFormatter(string.Formatter): for df, n in self._temp_views: if df is val: return n - df_name = "_pyspark_%s" % str(uuid.uuid4()).replace("-", "") + df_name = "_pyspark_connect_%s" % str(uuid.uuid4()).replace("-", "") self._temp_views.append((val, df_name)) val.createOrReplaceTempView(df_name) return df_name elif isinstance(val, str): - return lit(val)._jc.expr().sql() # for escaped characters. + from pyspark.sql.utils import get_lit_sql_str + + return get_lit_sql_str(val) else: return val diff --git a/python/pyspark/sql/sql_formatter.py b/python/pyspark/sql/sql_formatter.py index 5e79b9ff5ea..fbaa6c46a26 100644 --- a/python/pyspark/sql/sql_formatter.py +++ b/python/pyspark/sql/sql_formatter.py @@ -24,7 +24,6 @@ from py4j.java_gateway import is_instance_of if typing.TYPE_CHECKING: from pyspark.sql import SparkSession, DataFrame -from pyspark.sql.functions import lit class SQLStringFormatter(string.Formatter): @@ -74,7 +73,9 @@ class SQLStringFormatter(string.Formatter): val.createOrReplaceTempView(df_name) return df_name elif isinstance(val, str): - return lit(val)._jc.expr().sql() # for escaped characters. + from pyspark.sql.utils import get_lit_sql_str + + return get_lit_sql_str(val) else: return val diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index 608ed7e9ac9..f2874ccb10e 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -294,3 +294,11 @@ def get_window_class() -> Type["Window"]: return ConnectWindow # type: ignore[return-value] else: return PySparkWindow + + +def get_lit_sql_str(val: str) -> str: + # Equivalent to `lit(val)._jc.expr().sql()` for string typed val + # This is matched to behavior from JVM implementation. + # See `sql` definition from `sql/catalyst/src/main/scala/org/apache/spark/ + # sql/catalyst/expressions/literals.scala` + return "'" + val.replace("\\", "\\\\").replace("'", "\\'") + "'" --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org