This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.2 by this push:
     new 3a18864  [SPARK-35809][PYTHON] Add `index_col` argument for ps.sql
3a18864 is described below

commit 3a18864c5fa063354108a4ebc2edf2e466bd628e
Author: itholic <[email protected]>
AuthorDate: Thu Jul 22 17:08:34 2021 +0900

    [SPARK-35809][PYTHON] Add `index_col` argument for ps.sql
    
    ### What changes were proposed in this pull request?
    
    This PR proposes adding an `index_col` argument to the `ps.sql` function,
    so users can preserve the index when they want to.
    
    NOTE that `reset_index()` has to be performed before using `ps.sql`
    with `index_col`.
    
    ```python
    >>> psdf
       A  B
    a  1  4
    b  2  5
    c  3  6
    >>> psdf_reset_index = psdf.reset_index()
    >>> ps.sql("SELECT * from {psdf_reset_index} WHERE A > 1", 
index_col="index")
           A  B
    index
    b      2  5
    c      3  6
    ```
    
    Otherwise, the index is always lost.
    
    ```python
    >>> ps.sql("SELECT * from {psdf} WHERE A > 1")
       A  B
    0  2  5
    1  3  6
    ```
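
    The same works for a MultiIndex: reset the index and pass the resulting
    column names as a list. A minimal sketch (assuming `import pandas as pd`),
    mirroring the doctest and unit test added in this patch:

    ```python
    >>> psdf = ps.DataFrame(
    ...     {"A": [1, 2, 3], "B": [4, 5, 6]},
    ...     index=pd.MultiIndex.from_tuples(
    ...         [("a", "b"), ("c", "d"), ("e", "f")], names=["index1", "index2"]
    ...     ),
    ... )
    >>> psdf_reset_index = psdf.reset_index()
    >>> ps.sql("SELECT * from {psdf_reset_index} WHERE A > 1", index_col=["index1", "index2"])
                   A  B
    index1 index2
    c      d       2  5
    e      f       3  6
    ```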
    
    ### Why are the changes needed?
    
    The index is one of the key objects for existing pandas users, so we
    should provide a way to keep it after running `ps.sql`.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, a new `index_col` argument is added to `ps.sql`.
    
    ### How was this patch tested?
    
    Added a unit test and manually checked that the build passes.
    
    Closes #33450 from itholic/SPARK-35809.
    
    Authored-by: itholic <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
    (cherry picked from commit 6578f0b135e1feee87eeb4b8b1bd1a3b6d9dcf0f)
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/pandas/sql_processor.py  | 58 ++++++++++++++++++++++++++++++---
 python/pyspark/pandas/tests/test_sql.py | 26 +++++++++++++++
 2 files changed, 79 insertions(+), 5 deletions(-)

diff --git a/python/pyspark/pandas/sql_processor.py b/python/pyspark/pandas/sql_processor.py
index 757e664..8b2b8fd 100644
--- a/python/pyspark/pandas/sql_processor.py
+++ b/python/pyspark/pandas/sql_processor.py
@@ -16,7 +16,7 @@
 #
 
 import _string  # type: ignore
-from typing import Any, Dict, Optional  # noqa: F401 (SPARK-34943)
+from typing import Any, Dict, Optional, Union, List  # noqa: F401 (SPARK-34943)
 import inspect
 import pandas as pd
 
@@ -26,6 +26,8 @@ from pyspark import pandas as ps  # For running doctests and reference resolutio
 from pyspark.pandas.utils import default_session
 from pyspark.pandas.frame import DataFrame
 from pyspark.pandas.series import Series
+from pyspark.pandas.internal import InternalFrame
+from pyspark.pandas.namespace import _get_index_map
 
 
 __all__ = ["sql"]
@@ -36,6 +38,7 @@ from builtins import locals as builtin_locals
 
 def sql(
     query: str,
+    index_col: Optional[Union[str, List[str]]] = None,
     globals: Optional[Dict[str, Any]] = None,
     locals: Optional[Dict[str, Any]] = None,
     **kwargs: Any
@@ -65,6 +68,44 @@ def sql(
     ----------
     query : str
         the SQL query
+    index_col : str or list of str, optional
+        Column names to be used in Spark to represent pandas-on-Spark's index. The index name
+        in pandas-on-Spark is ignored. By default, the index is always lost.
+
+        .. note:: If you want to preserve the index, explicitly use :func:`DataFrame.reset_index`,
+            and pass it to the sql statement with `index_col` parameter.
+
+            For example,
+
+            >>> psdf = ps.DataFrame({"A": [1, 2, 3], "B":[4, 5, 6]}, index=['a', 'b', 'c'])
+            >>> psdf_reset_index = psdf.reset_index()
+            >>> ps.sql("SELECT * FROM {psdf_reset_index}", index_col="index")
+            ... # doctest: +NORMALIZE_WHITESPACE
+                   A  B
+            index
+            a      1  4
+            b      2  5
+            c      3  6
+
+            For MultiIndex,
+
+            >>> psdf = ps.DataFrame(
+            ...     {"A": [1, 2, 3], "B": [4, 5, 6]},
+            ...     index=pd.MultiIndex.from_tuples(
+            ...         [("a", "b"), ("c", "d"), ("e", "f")], names=["index1", "index2"]
+            ...     ),
+            ... )
+            >>> psdf_reset_index = psdf.reset_index()
+            >>> ps.sql("SELECT * FROM {psdf_reset_index}", index_col=["index1", "index2"])
+            ... # doctest: +NORMALIZE_WHITESPACE
+                           A  B
+            index1 index2
+            a      b       1  4
+            c      d       2  5
+            e      f       3  6
+
+            Also note that the index name(s) should be matched to the existing name.
+
     globals : dict, optional
         the dictionary of global variables, if explicitly set by the user
     locals : dict, optional
@@ -151,7 +192,7 @@ def sql(
     _dict.update(_locals)
     # Highest order of precedence is the locals
     _dict.update(kwargs)
-    return SQLProcessor(_dict, query, default_session()).execute()
+    return SQLProcessor(_dict, query, default_session()).execute(index_col)
 
 
 _CAPTURE_SCOPES = 2
@@ -221,12 +262,12 @@ class SQLProcessor(object):
         # The normalized form is typically a string
         self._cached_vars = {}  # type: Dict[str, Any]
         # The SQL statement after:
-        # - all the dataframes have been have been registered as temporary views
+        # - all the dataframes have been registered as temporary views
         # - all the values have been converted normalized to equivalent SQL representations
         self._normalized_statement = None  # type: Optional[str]
         self._session = session
 
-    def execute(self) -> DataFrame:
+    def execute(self, index_col: Optional[Union[str, List[str]]]) -> DataFrame:
         """
         Returns a DataFrame for which the SQL statement has been executed by
         the underlying SQL engine.
@@ -260,7 +301,14 @@ class SQLProcessor(object):
         finally:
             for v in self._temp_views:
                 self._session.catalog.dropTempView(v)
-        return DataFrame(sdf)
+
+        index_spark_columns, index_names = _get_index_map(sdf, index_col)
+
+        return DataFrame(
+            InternalFrame(
+                spark_frame=sdf, index_spark_columns=index_spark_columns, index_names=index_names
+            )
+        )
 
     def _convert(self, key: str) -> Any:
         """
diff --git a/python/pyspark/pandas/tests/test_sql.py b/python/pyspark/pandas/tests/test_sql.py
index 4cf83b8..306ea16 100644
--- a/python/pyspark/pandas/tests/test_sql.py
+++ b/python/pyspark/pandas/tests/test_sql.py
@@ -37,6 +37,32 @@ class SQLTest(PandasOnSparkTestCase, SQLTestUtils):
         with self.assertRaises(ParseException):
             ps.sql("this is not valid sql")
 
+    def test_sql_with_index_col(self):
+        import pandas as pd
+
+        # Index
+        psdf = ps.DataFrame(
+            {"A": [1, 2, 3], "B": [4, 5, 6]}, index=pd.Index(["a", "b", "c"], 
name="index")
+        )
+        psdf_reset_index = psdf.reset_index()
+        actual = ps.sql("select * from {psdf_reset_index} where A > 1", index_col="index")
+        expected = psdf.iloc[[1, 2]]
+        self.assert_eq(actual, expected)
+
+        # MultiIndex
+        psdf = ps.DataFrame(
+            {"A": [1, 2, 3], "B": [4, 5, 6]},
+            index=pd.MultiIndex.from_tuples(
+                [("a", "b"), ("c", "d"), ("e", "f")], names=["index1", 
"index2"]
+            ),
+        )
+        psdf_reset_index = psdf.reset_index()
+        actual = ps.sql(
+            "select * from {psdf_reset_index} where A > 1", 
index_col=["index1", "index2"]
+        )
+        expected = psdf.iloc[[1, 2]]
+        self.assert_eq(actual, expected)
+
 
 if __name__ == "__main__":
     import unittest
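
For readers skimming the patch: `execute(index_col)` now builds an `InternalFrame` from the SQL result using whatever `_get_index_map` returns for the requested columns. The real helper lives in `pyspark.pandas.namespace` and is not shown in this diff; purely as a hedged illustration of the contract the new code relies on (a hypothetical stand-in, not the actual Spark source), it needs to do roughly this:

```python
from typing import List, Optional, Tuple, Union

from pyspark.sql import Column, DataFrame as SparkDataFrame


def get_index_map_sketch(
    sdf: SparkDataFrame, index_col: Optional[Union[str, List[str]]] = None
) -> Tuple[Optional[List[Column]], Optional[List[Tuple[str, ...]]]]:
    """Hypothetical stand-in for pyspark.pandas.namespace._get_index_map."""
    if index_col is None:
        # No index requested: InternalFrame then attaches a default integer
        # index, which is why the index is "always lost" without index_col.
        return None, None
    names = [index_col] if isinstance(index_col, str) else list(index_col)
    missing = [name for name in names if name not in sdf.columns]
    if missing:
        raise KeyError("index_col not found in the SQL result: %s" % missing)
    # One Spark column and one single-element name tuple per index level,
    # matching how pandas-on-Spark labels index levels internally.
    return [sdf[name] for name in names], [(name,) for name in names]
```

This also shows why the commit message insists on calling `reset_index()` first: the columns named in `index_col` must actually be present in the SQL output before they can be turned back into index levels.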
