This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.2 by this push:
new 3a18864 [SPARK-35809][PYTHON] Add `index_col` argument for ps.sql
3a18864 is described below
commit 3a18864c5fa063354108a4ebc2edf2e466bd628e
Author: itholic <[email protected]>
AuthorDate: Thu Jul 22 17:08:34 2021 +0900
[SPARK-35809][PYTHON] Add `index_col` argument for ps.sql
### What changes were proposed in this pull request?
This PR proposes adding an argument `index_col` for `ps.sql` function, to
preserve the index when users want.
NOTE that `reset_index()` has to be performed before using `ps.sql`
with `index_col`.
```python
>>> psdf
A B
a 1 4
b 2 5
c 3 6
>>> psdf_reset_index = psdf.reset_index()
>>> ps.sql("SELECT * from {psdf_reset_index} WHERE A > 1",
index_col="index")
A B
index
b 2 5
c 3 6
```
Otherwise, the index is always lost.
```python
>>> ps.sql("SELECT * from {psdf} WHERE A > 1")
A B
0 2 5
1 3 6
```
### Why are the changes needed?
The index is one of the key objects for existing pandas users, so we should
provide a way to keep the index after computing `ps.sql`.
### Does this PR introduce _any_ user-facing change?
Yes, the new argument is added.
### How was this patch tested?
Add a unit test and manually check that the build passes.
Closes #33450 from itholic/SPARK-35809.
Authored-by: itholic <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
(cherry picked from commit 6578f0b135e1feee87eeb4b8b1bd1a3b6d9dcf0f)
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/pandas/sql_processor.py | 58 ++++++++++++++++++++++++++++++---
python/pyspark/pandas/tests/test_sql.py | 26 +++++++++++++++
2 files changed, 79 insertions(+), 5 deletions(-)
diff --git a/python/pyspark/pandas/sql_processor.py
b/python/pyspark/pandas/sql_processor.py
index 757e664..8b2b8fd 100644
--- a/python/pyspark/pandas/sql_processor.py
+++ b/python/pyspark/pandas/sql_processor.py
@@ -16,7 +16,7 @@
#
import _string # type: ignore
-from typing import Any, Dict, Optional # noqa: F401 (SPARK-34943)
+from typing import Any, Dict, Optional, Union, List # noqa: F401 (SPARK-34943)
import inspect
import pandas as pd
@@ -26,6 +26,8 @@ from pyspark import pandas as ps # For running doctests and
reference resolutio
from pyspark.pandas.utils import default_session
from pyspark.pandas.frame import DataFrame
from pyspark.pandas.series import Series
+from pyspark.pandas.internal import InternalFrame
+from pyspark.pandas.namespace import _get_index_map
__all__ = ["sql"]
@@ -36,6 +38,7 @@ from builtins import locals as builtin_locals
def sql(
query: str,
+ index_col: Optional[Union[str, List[str]]] = None,
globals: Optional[Dict[str, Any]] = None,
locals: Optional[Dict[str, Any]] = None,
**kwargs: Any
@@ -65,6 +68,44 @@ def sql(
----------
query : str
the SQL query
+ index_col : str or list of str, optional
+ Column names to be used in Spark to represent pandas-on-Spark's index.
The index name
+ in pandas-on-Spark is ignored. By default, the index is always lost.
+
+ .. note:: If you want to preserve the index, explicitly use
:func:`DataFrame.reset_index`,
+ and pass it to the sql statement with `index_col` parameter.
+
+ For example,
+
+ >>> psdf = ps.DataFrame({"A": [1, 2, 3], "B":[4, 5, 6]},
index=['a', 'b', 'c'])
+ >>> psdf_reset_index = psdf.reset_index()
+ >>> ps.sql("SELECT * FROM {psdf_reset_index}", index_col="index")
+ ... # doctest: +NORMALIZE_WHITESPACE
+ A B
+ index
+ a 1 4
+ b 2 5
+ c 3 6
+
+ For MultiIndex,
+
+ >>> psdf = ps.DataFrame(
+ ... {"A": [1, 2, 3], "B": [4, 5, 6]},
+ ... index=pd.MultiIndex.from_tuples(
+ ... [("a", "b"), ("c", "d"), ("e", "f")], names=["index1",
"index2"]
+ ... ),
+ ... )
+ >>> psdf_reset_index = psdf.reset_index()
+ >>> ps.sql("SELECT * FROM {psdf_reset_index}",
index_col=["index1", "index2"])
+ ... # doctest: +NORMALIZE_WHITESPACE
+ A B
+ index1 index2
+ a b 1 4
+ c d 2 5
+ e f 3 6
+
+ Also note that the index name(s) should be matched to the existing
name.
+
globals : dict, optional
the dictionary of global variables, if explicitly set by the user
locals : dict, optional
@@ -151,7 +192,7 @@ def sql(
_dict.update(_locals)
# Highest order of precedence is the locals
_dict.update(kwargs)
- return SQLProcessor(_dict, query, default_session()).execute()
+ return SQLProcessor(_dict, query, default_session()).execute(index_col)
_CAPTURE_SCOPES = 2
@@ -221,12 +262,12 @@ class SQLProcessor(object):
# The normalized form is typically a string
self._cached_vars = {} # type: Dict[str, Any]
# The SQL statement after:
- # - all the dataframes have been have been registered as temporary
views
+ # - all the dataframes have been registered as temporary views
# - all the values have been converted normalized to equivalent SQL
representations
self._normalized_statement = None # type: Optional[str]
self._session = session
- def execute(self) -> DataFrame:
+ def execute(self, index_col: Optional[Union[str, List[str]]]) -> DataFrame:
"""
Returns a DataFrame for which the SQL statement has been executed by
the underlying SQL engine.
@@ -260,7 +301,14 @@ class SQLProcessor(object):
finally:
for v in self._temp_views:
self._session.catalog.dropTempView(v)
- return DataFrame(sdf)
+
+ index_spark_columns, index_names = _get_index_map(sdf, index_col)
+
+ return DataFrame(
+ InternalFrame(
+ spark_frame=sdf, index_spark_columns=index_spark_columns,
index_names=index_names
+ )
+ )
def _convert(self, key: str) -> Any:
"""
diff --git a/python/pyspark/pandas/tests/test_sql.py
b/python/pyspark/pandas/tests/test_sql.py
index 4cf83b8..306ea16 100644
--- a/python/pyspark/pandas/tests/test_sql.py
+++ b/python/pyspark/pandas/tests/test_sql.py
@@ -37,6 +37,32 @@ class SQLTest(PandasOnSparkTestCase, SQLTestUtils):
with self.assertRaises(ParseException):
ps.sql("this is not valid sql")
+ def test_sql_with_index_col(self):
+ import pandas as pd
+
+ # Index
+ psdf = ps.DataFrame(
+ {"A": [1, 2, 3], "B": [4, 5, 6]}, index=pd.Index(["a", "b", "c"],
name="index")
+ )
+ psdf_reset_index = psdf.reset_index()
+ actual = ps.sql("select * from {psdf_reset_index} where A > 1",
index_col="index")
+ expected = psdf.iloc[[1, 2]]
+ self.assert_eq(actual, expected)
+
+ # MultiIndex
+ psdf = ps.DataFrame(
+ {"A": [1, 2, 3], "B": [4, 5, 6]},
+ index=pd.MultiIndex.from_tuples(
+ [("a", "b"), ("c", "d"), ("e", "f")], names=["index1",
"index2"]
+ ),
+ )
+ psdf_reset_index = psdf.reset_index()
+ actual = ps.sql(
+ "select * from {psdf_reset_index} where A > 1",
index_col=["index1", "index2"]
+ )
+ expected = psdf.iloc[[1, 2]]
+ self.assert_eq(actual, expected)
+
if __name__ == "__main__":
import unittest
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]