This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new bac7050cf0a Revert "[SPARK-41811][PYTHON][CONNECT] Implement SparkSession.sql's string formatter"
bac7050cf0a is described below

commit bac7050cf0ad18608e921f46e40152d341d53fb8
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Fri Jul 14 09:31:17 2023 +0900

    Revert "[SPARK-41811][PYTHON][CONNECT] Implement SparkSession.sql's string formatter"

    This reverts commit 9aa42a970c4bd8e54603b1795a0f449bd556b11b.
---
 python/pyspark/pandas/sql_formatter.py      |  7 +--
 python/pyspark/sql/connect/session.py       | 35 ++++---------
 python/pyspark/sql/connect/sql_formatter.py | 78 -----------------------------
 python/pyspark/sql/sql_formatter.py         |  5 +-
 python/pyspark/sql/utils.py                 |  8 ---
 5 files changed, 16 insertions(+), 117 deletions(-)

diff --git a/python/pyspark/pandas/sql_formatter.py b/python/pyspark/pandas/sql_formatter.py
index 7501e19c038..8593703bd94 100644
--- a/python/pyspark/pandas/sql_formatter.py
+++ b/python/pyspark/pandas/sql_formatter.py
@@ -264,9 +264,10 @@ class PandasSQLStringFormatter(string.Formatter):
             val._to_spark().createOrReplaceTempView(df_name)
             return df_name
         elif isinstance(val, str):
-            from pyspark.sql.utils import get_lit_sql_str
-
-            return get_lit_sql_str(val)
+            # This is matched to behavior from JVM implementation.
+            # See `sql` definition from `sql/catalyst/src/main/scala/org/apache/spark/
+            # sql/catalyst/expressions/literals.scala`
+            return "'" + val.replace("\\", "\\\\").replace("'", "\\'") + "'"
         else:
             return val

diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 13868263174..ea88d60d760 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -489,31 +489,13 @@ class SparkSession:

     createDataFrame.__doc__ = PySparkSession.createDataFrame.__doc__

-    def sql(
-        self,
-        sqlQuery: str,
-        args: Optional[Union[Dict[str, Any], List]] = None,
-        **kwargs: Any,
-    ) -> "DataFrame":
-
-        if len(kwargs) > 0:
-            from pyspark.sql.connect.sql_formatter import SQLStringFormatter
-
-            formatter = SQLStringFormatter(self)
-            sqlQuery = formatter.format(sqlQuery, **kwargs)
-
-        try:
-            cmd = SQL(sqlQuery, args)
-            data, properties = self.client.execute_command(cmd.command(self._client))
-            if "sql_command_result" in properties:
-                return DataFrame.withPlan(CachedRelation(properties["sql_command_result"]), self)
-            else:
-                return DataFrame.withPlan(SQL(sqlQuery, args), self)
-        finally:
-            if len(kwargs) > 0:
-                # TODO: should drop temp views after SPARK-44406 get resolved
-                # formatter.clear()
-                pass
+    def sql(self, sqlQuery: str, args: Optional[Union[Dict[str, Any], List]] = None) -> "DataFrame":
+        cmd = SQL(sqlQuery, args)
+        data, properties = self.client.execute_command(cmd.command(self._client))
+        if "sql_command_result" in properties:
+            return DataFrame.withPlan(CachedRelation(properties["sql_command_result"]), self)
+        else:
+            return DataFrame.withPlan(SQL(sqlQuery, args), self)

     sql.__doc__ = PySparkSession.sql.__doc__

@@ -826,6 +808,9 @@ def _test() -> None:
     # RDD API is not supported in Spark Connect.
     del pyspark.sql.connect.session.SparkSession.createDataFrame.__doc__
+    # TODO(SPARK-41811): Implement SparkSession.sql's string formatter
+    del pyspark.sql.connect.session.SparkSession.sql.__doc__
+
     (failure_count, test_count) = doctest.testmod(
         pyspark.sql.connect.session,
         globs=globs,
diff --git a/python/pyspark/sql/connect/sql_formatter.py b/python/pyspark/sql/connect/sql_formatter.py
deleted file mode 100644
index ab90a1bb847..00000000000
--- a/python/pyspark/sql/connect/sql_formatter.py
+++ /dev/null
@@ -1,78 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements.  See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License.  You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-import string
-import typing
-from typing import Any, Optional, List, Tuple, Sequence, Mapping
-import uuid
-
-if typing.TYPE_CHECKING:
-    from pyspark.sql.connect.session import SparkSession
-    from pyspark.sql.connect.dataframe import DataFrame
-
-
-class SQLStringFormatter(string.Formatter):
-    """
-    A standard ``string.Formatter`` in Python that can understand PySpark instances
-    with basic Python objects. This object has to be clear after the use for single SQL
-    query; cannot be reused across multiple SQL queries without cleaning.
-    """
-
-    def __init__(self, session: "SparkSession") -> None:
-        self._session: "SparkSession" = session
-        self._temp_views: List[Tuple[DataFrame, str]] = []
-
-    def get_field(self, field_name: str, args: Sequence[Any], kwargs: Mapping[str, Any]) -> Any:
-        obj, first = super(SQLStringFormatter, self).get_field(field_name, args, kwargs)
-        return self._convert_value(obj, field_name), first
-
-    def _convert_value(self, val: Any, field_name: str) -> Optional[str]:
-        """
-        Converts the given value into a SQL string.
- """ - from pyspark.sql.connect.dataframe import DataFrame - from pyspark.sql.connect.column import Column - from pyspark.sql.connect.expressions import ColumnReference - - if isinstance(val, Column): - expr = val._expr - if isinstance(expr, ColumnReference): - return expr._unparsed_identifier - else: - raise ValueError( - "%s in %s should be a plain column reference such as `df.col` " - "or `col('column')`" % (val, field_name) - ) - elif isinstance(val, DataFrame): - for df, n in self._temp_views: - if df is val: - return n - df_name = "_pyspark_connect_%s" % str(uuid.uuid4()).replace("-", "") - self._temp_views.append((val, df_name)) - val.createOrReplaceTempView(df_name) - return df_name - elif isinstance(val, str): - from pyspark.sql.utils import get_lit_sql_str - - return get_lit_sql_str(val) - else: - return val - - def clear(self) -> None: - for _, n in self._temp_views: - self._session.catalog.dropTempView(n) - self._temp_views = [] diff --git a/python/pyspark/sql/sql_formatter.py b/python/pyspark/sql/sql_formatter.py index fbaa6c46a26..5e79b9ff5ea 100644 --- a/python/pyspark/sql/sql_formatter.py +++ b/python/pyspark/sql/sql_formatter.py @@ -24,6 +24,7 @@ from py4j.java_gateway import is_instance_of if typing.TYPE_CHECKING: from pyspark.sql import SparkSession, DataFrame +from pyspark.sql.functions import lit class SQLStringFormatter(string.Formatter): @@ -73,9 +74,7 @@ class SQLStringFormatter(string.Formatter): val.createOrReplaceTempView(df_name) return df_name elif isinstance(val, str): - from pyspark.sql.utils import get_lit_sql_str - - return get_lit_sql_str(val) + return lit(val)._jc.expr().sql() # for escaped characters. else: return val diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index f2874ccb10e..608ed7e9ac9 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -294,11 +294,3 @@ def get_window_class() -> Type["Window"]: return ConnectWindow # type: ignore[return-value] else: return PySparkWindow - - -def get_lit_sql_str(val: str) -> str: - # Equivalent to `lit(val)._jc.expr().sql()` for string typed val - # This is matched to behavior from JVM implementation. - # See `sql` definition from `sql/catalyst/src/main/scala/org/apache/spark/ - # sql/catalyst/expressions/literals.scala` - return "'" + val.replace("\\", "\\\\").replace("'", "\\'") + "'" --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org