This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 26f4953  [SPARK-37516][PYTHON][SQL] Uses Python's standard string formatter for SQL API in PySpark
26f4953 is described below

commit 26f495370fb45071f52cde6fff199d7f4b674bc7
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Wed Dec 8 13:57:35 2021 +0900

    [SPARK-37516][PYTHON][SQL] Uses Python's standard string formatter for SQL API in PySpark
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to use [Python's standard string formatter](https://docs.python.org/3/library/string.html#custom-string-formatting) in `SparkSession.sql`, see also https://github.com/apache/spark/pull/34677.
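    
    For context, Python's standard `string.Formatter` lets a subclass intercept how each replacement field is looked up by overriding `get_field`; the new `SQLStringFormatter` builds on this hook to translate PySpark objects into SQL fragments. A minimal sketch of that mechanism with plain Python objects (no Spark involved; the class and names below are illustrative only):
    
    ```python
    import string
    
    class UpperFormatter(string.Formatter):
        # Resolve each {field} as usual, then post-process the looked-up value.
        def get_field(self, field_name, args, kwargs):
            obj, first = super().get_field(field_name, args, kwargs)
            return str(obj).upper(), first
    
    class Table:
        name = "people"
    
    # "{tbl.name}"-style attribute lookups are routed through get_field.
    print(UpperFormatter().format("SELECT * FROM {tbl.name}", tbl=Table()))
    # SELECT * FROM PEOPLE
    ```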
    
    ### Why are the changes needed?
    
    To improve usability in PySpark. It works with Python's standard string formatter.
    
    ### Does this PR introduce _any_ user-facing change?
    
    By default, there is no user-facing change. If `kwargs` is specified, yes.
    
    1. Attribute and item access from a frame (standard Python formatter support):
    
        ```python
        mydf = spark.range(10)
        spark.sql("SELECT {tbl.id}, {tbl[id]} FROM {tbl}", tbl=mydf)
        ```
    
    2. Understanding a `DataFrame`:
    
        ```python
        mydf = spark.range(10)
        spark.sql("SELECT * FROM {tbl}", tbl=mydf)
        ```
    
    3. Understanding a `Column` (explicit column references only):
    
        ```python
        from pyspark.sql.functions import col
        mydf = spark.range(10)
        spark.sql("SELECT {c} FROM {tbl}", c=col("id"), tbl=mydf)
        ```
    
    4. Leveraging other Python string formatting (a combined sketch of the conversion rules follows this list):
    
        ```python
        mydf = spark.range(10)
        spark.sql(
            "SELECT {col} FROM {mydf} WHERE id IN {x}",
            col=mydf.id, mydf=mydf, x=tuple(range(4)))
        ```
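    
    Taken together, the substitution rules are roughly: a `DataFrame` is exposed through an auto-generated temporary view, a `Column` must be a plain reference and is rendered as its SQL name, a `str` becomes an escaped SQL string literal, and any other value is substituted as with ordinary `str.format`. A combined sketch (the data and names are examples only):
    
    ```python
    from pyspark.sql.functions import col
    
    mydf = spark.createDataFrame([(1, "a"), (2, "b'c")], ["id", "name"])
    spark.sql(
        "SELECT {c} FROM {tbl} WHERE name = {s} AND id < {n}",
        c=col("id"),   # plain column reference -> rendered as the column's SQL name
        tbl=mydf,      # DataFrame -> registered as a temporary view for this query
        s="b'c",       # str -> escaped SQL string literal
        n=2,           # other values -> plain str.format substitution
    )
    ```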
    
    ### How was this patch tested?
    
    Doctests were added.
    
    Closes #34774 from HyukjinKwon/SPARK-37516.
    
    Authored-by: Hyukjin Kwon <gurwls...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/pandas/sql_formatter.py   | 10 ++--
 python/pyspark/pandas/tests/test_sql.py  |  4 --
 python/pyspark/sql/session.py            | 90 +++++++++++++++++++++++++++++---
 python/pyspark/sql/sql_formatter.py      | 84 +++++++++++++++++++++++++++++
 python/pyspark/sql/tests/test_session.py | 10 +++-
 5 files changed, 182 insertions(+), 16 deletions(-)

diff --git a/python/pyspark/pandas/sql_formatter.py b/python/pyspark/pandas/sql_formatter.py
index 685ee25..4ade2b9 100644
--- a/python/pyspark/pandas/sql_formatter.py
+++ b/python/pyspark/pandas/sql_formatter.py
@@ -163,7 +163,7 @@ def sql(
         return sql_processor.sql(query, index_col=index_col, **kwargs)
 
     session = default_session()
-    formatter = SQLStringFormatter(session)
+    formatter = PandasSQLStringFormatter(session)
     try:
         sdf = session.sql(formatter.format(query, **kwargs))
     finally:
@@ -178,7 +178,7 @@ def sql(
     )
 
 
-class SQLStringFormatter(string.Formatter):
+class PandasSQLStringFormatter(string.Formatter):
     """
     A standard ``string.Formatter`` in Python that can understand pandas-on-Spark instances
     with basic Python objects. This object has to be clear after the use for single SQL
@@ -191,7 +191,7 @@ class SQLStringFormatter(string.Formatter):
         self._ref_sers: List[Tuple[Series, str]] = []
 
     def vformat(self, format_string: str, args: Sequence[Any], kwargs: Mapping[str, Any]) -> str:
-        ret = super(SQLStringFormatter, self).vformat(format_string, args, kwargs)
+        ret = super(PandasSQLStringFormatter, self).vformat(format_string, args, kwargs)
 
         for ref, n in self._ref_sers:
             if not any((ref is v for v in df._pssers.values()) for df, _ in self._temp_views):
@@ -200,7 +200,7 @@ class SQLStringFormatter(string.Formatter):
         return ret
 
     def get_field(self, field_name: str, args: Sequence[Any], kwargs: Mapping[str, Any]) -> Any:
-        obj, first = super(SQLStringFormatter, self).get_field(field_name, args, kwargs)
+        obj, first = super(PandasSQLStringFormatter, self).get_field(field_name, args, kwargs)
         return self._convert_value(obj, field_name), first
 
     def _convert_value(self, val: Any, name: str) -> Optional[str]:
@@ -256,7 +256,7 @@ def _test() -> None:
     globs["ps"] = pyspark.pandas
     spark = (
         SparkSession.builder.master("local[4]")
-        .appName("pyspark.pandas.sql_processor tests")
+        .appName("pyspark.pandas.sql_formatter tests")
         .getOrCreate()
     )
     (failure_count, test_count) = doctest.testmod(
diff --git a/python/pyspark/pandas/tests/test_sql.py b/python/pyspark/pandas/tests/test_sql.py
index ca0dd99..5a5d6d4 100644
--- a/python/pyspark/pandas/tests/test_sql.py
+++ b/python/pyspark/pandas/tests/test_sql.py
@@ -26,10 +26,6 @@ class SQLTest(PandasOnSparkTestCase, SQLTestUtils):
         with self.assertRaisesRegex(KeyError, "variable_foo"):
             ps.sql("select * from {variable_foo}")
 
-    def test_error_unsupported_type(self):
-        with self.assertRaisesRegex(KeyError, "some_dict"):
-            ps.sql("select * from {some_dict}")
-
     def test_error_bad_sql(self):
         with self.assertRaises(ParseException):
             ps.sql("this is not valid sql")
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 586af62..6ff63bc 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -44,6 +44,7 @@ from pyspark.sql.conf import RuntimeConfig
 from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.pandas.conversion import SparkConversionMixin
 from pyspark.sql.readwriter import DataFrameReader
+from pyspark.sql.sql_formatter import SQLStringFormatter
 from pyspark.sql.streaming import DataStreamReader
 from pyspark.sql.types import (
     AtomicType,
@@ -924,23 +925,100 @@ class SparkSession(SparkConversionMixin):
         df._schema = struct
         return df
 
-    def sql(self, sqlQuery: str) -> DataFrame:
+    def sql(self, sqlQuery: str, **kwargs: Any) -> DataFrame:
         """Returns a :class:`DataFrame` representing the result of the given 
query.
+        When ``kwargs`` is specified, this method formats the given string by 
using the Python
+        standard formatter.
 
         .. versionadded:: 2.0.0
 
+        Parameters
+        ----------
+        sqlQuery : str
+            SQL query string.
+        kwargs : dict
+            Other variables that the user wants to set that can be referenced in the query
+
+            .. versionchanged:: 3.3.0
+               Added optional argument ``kwargs`` to specify the mapping of variables in the query.
+               This feature is experimental and unstable.
+
         Returns
         -------
         :class:`DataFrame`
 
         Examples
         --------
-        >>> df.createOrReplaceTempView("table1")
-        >>> df2 = spark.sql("SELECT field1 AS f1, field2 as f2 from table1")
-        >>> df2.collect()
-        [Row(f1=1, f2='row1'), Row(f1=2, f2='row2'), Row(f1=3, f2='row3')]
+        Executing a SQL query.
+
+        >>> spark.sql("SELECT * FROM range(10) where id > 7").show()
+        +---+
+        | id|
+        +---+
+        |  8|
+        |  9|
+        +---+
+
+        Executing a SQL query with variables as Python formatter standard.
+
+        >>> spark.sql(
+        ...     "SELECT * FROM range(10) WHERE id > {bound1} AND id < 
{bound2}", bound1=7, bound2=9
+        ... ).show()
+        +---+
+        | id|
+        +---+
+        |  8|
+        +---+
+
+        >>> mydf = spark.range(10)
+        >>> spark.sql(
+        ...     "SELECT {col} FROM {mydf} WHERE id IN {x}",
+        ...     col=mydf.id, mydf=mydf, x=tuple(range(4))).show()
+        +---+
+        | id|
+        +---+
+        |  0|
+        |  1|
+        |  2|
+        |  3|
+        +---+
+
+        >>> spark.sql('''
+        ...   SELECT m1.a, m2.b
+        ...   FROM {table1} m1 INNER JOIN {table2} m2
+        ...   ON m1.key = m2.key
+        ...   ORDER BY m1.a, m2.b''',
+        ...   table1=spark.createDataFrame([(1, "a"), (2, "b")], ["a", "key"]),
+        ...   table2=spark.createDataFrame([(3, "a"), (4, "b"), (5, "b")], ["b", "key"])).show()
+        +---+---+
+        |  a|  b|
+        +---+---+
+        |  1|  3|
+        |  2|  4|
+        |  2|  5|
+        +---+---+
+
+        Also, it is possible to query using a :class:`Column` from a :class:`DataFrame`.
+
+        >>> mydf = spark.createDataFrame([(1, 4), (2, 4), (3, 6)], ["A", "B"])
+        >>> spark.sql("SELECT {df.A}, {df[B]} FROM {df}", df=mydf).show()
+        +---+---+
+        |  A|  B|
+        +---+---+
+        |  1|  4|
+        |  2|  4|
+        |  3|  6|
+        +---+---+
         """
-        return DataFrame(self._jsparkSession.sql(sqlQuery), self._wrapped)
+
+        formatter = SQLStringFormatter(self)
+        if len(kwargs) > 0:
+            sqlQuery = formatter.format(sqlQuery, **kwargs)
+        try:
+            return DataFrame(self._jsparkSession.sql(sqlQuery), self._wrapped)
+        finally:
+            if len(kwargs) > 0:
+                formatter.clear()
 
     def table(self, tableName: str) -> DataFrame:
         """Returns the specified table as a :class:`DataFrame`.
diff --git a/python/pyspark/sql/sql_formatter.py b/python/pyspark/sql/sql_formatter.py
new file mode 100644
index 0000000..8528dd3
--- /dev/null
+++ b/python/pyspark/sql/sql_formatter.py
@@ -0,0 +1,84 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import string
+import typing
+from typing import Any, Optional, List, Tuple, Sequence, Mapping
+import uuid
+
+from py4j.java_gateway import is_instance_of
+
+if typing.TYPE_CHECKING:
+    from pyspark.sql import SparkSession, DataFrame
+from pyspark.sql.functions import lit
+
+
+class SQLStringFormatter(string.Formatter):
+    """
+    A standard ``string.Formatter`` in Python that can understand PySpark instances
+    with basic Python objects. This object has to be cleared after use for a single SQL
+    query; it cannot be reused across multiple SQL queries without cleaning.
+    """
+
+    def __init__(self, session: "SparkSession") -> None:
+        self._session: "SparkSession" = session
+        self._temp_views: List[Tuple[DataFrame, str]] = []
+
+    def get_field(self, field_name: str, args: Sequence[Any], kwargs: Mapping[str, Any]) -> Any:
+        obj, first = super(SQLStringFormatter, self).get_field(field_name, args, kwargs)
+        return self._convert_value(obj, field_name), first
+
+    def _convert_value(self, val: Any, field_name: str) -> Optional[str]:
+        """
+        Converts the given value into a SQL string.
+        """
+        from pyspark import SparkContext
+        from pyspark.sql import Column, DataFrame
+
+        if isinstance(val, Column):
+            assert SparkContext._gateway is not None  # type: ignore[attr-defined]
+
+            gw = SparkContext._gateway  # type: ignore[attr-defined]
+            jexpr = val._jc.expr()
+            if is_instance_of(
+                gw, jexpr, "org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute"
+            ) or is_instance_of(
+                gw, jexpr, "org.apache.spark.sql.catalyst.expressions.AttributeReference"
+            ):
+                return jexpr.sql()
+            else:
+                raise ValueError(
+                    "%s in %s should be a plain column reference such as 
`df.col` "
+                    "or `col('column')`" % (val, field_name)
+                )
+        elif isinstance(val, DataFrame):
+            for df, n in self._temp_views:
+                if df is val:
+                    return n
+            df_name = "_pyspark_%s" % str(uuid.uuid4()).replace("-", "")
+            self._temp_views.append((val, df_name))
+            val.createOrReplaceTempView(df_name)
+            return df_name
+        elif isinstance(val, str):
+            return lit(val)._jc.expr().sql()  # for escaped characters.
+        else:
+            return val
+
+    def clear(self) -> None:
+        for _, n in self._temp_views:
+            self._session.catalog.dropTempView(n)
+        self._temp_views = []
diff --git a/python/pyspark/sql/tests/test_session.py b/python/pyspark/sql/tests/test_session.py
index 84fa23d..1262e52 100644
--- a/python/pyspark/sql/tests/test_session.py
+++ b/python/pyspark/sql/tests/test_session.py
@@ -20,6 +20,7 @@ import unittest
 
 from pyspark import SparkConf, SparkContext
 from pyspark.sql import SparkSession, SQLContext, Row
+from pyspark.sql.functions import col
 from pyspark.testing.sqlutils import ReusedSQLTestCase
 from pyspark.testing.utils import PySparkTestCase
 
@@ -93,7 +94,7 @@ class SparkSessionTests3(unittest.TestCase):
         active = SparkSession.getActiveSession()
         self.assertEqual(active, None)
 
-    def test_SparkSession(self):
+    def test_spark_session(self):
         spark = SparkSession.builder.master("local").config("some-config", "v2").getOrCreate()
         try:
             self.assertEqual(spark.conf.get("some-config"), "v2")
@@ -105,6 +106,13 @@ class SparkSessionTests3(unittest.TestCase):
             spark.sql("CREATE TABLE table1 (name STRING, age INT) USING 
parquet")
             self.assertEqual(spark.table("table1").columns, ["name", "age"])
             self.assertEqual(spark.range(3).count(), 3)
+
+            # SPARK-37516: Only plain column references work as a variable in SQL.
+            self.assertEqual(
+                spark.sql("select {c} from range(1)", c=col("id")).first(), 
spark.range(1).first()
+            )
+            with self.assertRaisesRegex(ValueError, "Column"):
+                spark.sql("select {c} from range(10)", c=col("id") + 1)
         finally:
             spark.sql("DROP DATABASE test_db CASCADE")
             spark.stop()
