This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 68c3354267d [SPARK-41810][CONNECT] Infer names from a list of dictionaries in SparkSession.createDataFrame
68c3354267d is described below
commit 68c3354267d30a96765a6592243205957d2cddf1
Author: Hyukjin Kwon <[email protected]>
AuthorDate: Mon Jan 2 21:24:45 2023 +0900
[SPARK-41810][CONNECT] Infer names from a list of dictionaries in SparkSession.createDataFrame
### What changes were proposed in this pull request?
This PR proposes to support inferring field names when the input data is a list of dictionaries in `SparkSession.createDataFrame`.
For example,
```python
spark.createDataFrame([{"course": "dotNET", "earnings": 10000, "year": 2012}]).show()
```
**Before**:
```
+------+-----+----+
|    _1|   _2|  _3|
+------+-----+----+
|dotNET|10000|2012|
+------+-----+----+
```
**After**:
```
+------+--------+----+
|course|earnings|year|
+------+--------+----+
|dotNET|   10000|2012|
+------+--------+----+
```
### Why are the changes needed?
To match the behaviour of regular PySpark.
### Does this PR introduce _any_ user-facing change?
No, because Spark Connect has not been released to end users yet.
### How was this patch tested?
A unit test was added.
Closes #39344 from HyukjinKwon/SPARK-41746.
Authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/sql/connect/session.py | 16 +++++++++------
.../sql/tests/connect/test_connect_basic.py | 24 ++++++++++++----------
2 files changed, 23 insertions(+), 17 deletions(-)
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index a461372c08c..0233bde1c17 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -218,15 +218,21 @@ class SparkSession:
         else:
             _data = list(data)
 
-            pdf = pd.DataFrame(_data)
-            if _schema is None and isinstance(_data[0], Row):
+            if _schema is None and (isinstance(_data[0], Row) or isinstance(_data[0], dict)):
+                if isinstance(_data[0], dict):
+                    # Sort the data to respect inferred schema.
+                    # For dictionaries, we sort the schema in alphabetical order.
+                    _data = [dict(sorted(d.items())) for d in _data]
+
                 _schema = self._inferSchemaFromList(_data, _cols)
                 if _cols is not None:
                     for i, name in enumerate(_cols):
                         _schema.fields[i].name = name
                         _schema.names[i] = name
 
+            pdf = pd.DataFrame(_data)
+
             if _cols is None:
                 _cols = ["_%s" % i for i in range(1, pdf.shape[1] + 1)]
@@ -342,11 +348,9 @@ def _test() -> None:
     # Spark Connect does not support to set master together.
     pyspark.sql.connect.session.SparkSession.__doc__ = None
     del pyspark.sql.connect.session.SparkSession.Builder.master.__doc__
-
-    # TODO(SPARK-41746): SparkSession.createDataFrame does not respect the column names in
-    # dictionary
+    # RDD API is not supported in Spark Connect.
     del pyspark.sql.connect.session.SparkSession.createDataFrame.__doc__
-    del pyspark.sql.connect.session.SparkSession.read.__doc__
+
     # TODO(SPARK-41811): Implement SparkSession.sql's string formatter
     del pyspark.sql.connect.session.SparkSession.sql.__doc__
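For the rename branch kept as context in the first hunk (`_schema.fields[i].name = name`), here is a hedged sketch of what that loop does to an inferred schema, using the public `pyspark.sql.types` API; the schema contents are made up to match the dict example above:
```python
from pyspark.sql.types import LongType, StringType, StructField, StructType

# An inferred schema like the one the dict example would produce.
inferred = StructType(
    [
        StructField("course", StringType()),
        StructField("earnings", LongType()),
        StructField("year", LongType()),
    ]
)

# When the caller supplies column names, each inferred field is
# renamed in place, as in the loop shown in the hunk above.
_cols = ["a", "b", "c"]
for i, name in enumerate(_cols):
    inferred.fields[i].name = name
    inferred.names[i] = name

print(inferred.fieldNames())  # ['a', 'b', 'c']
```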
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 6a65e412dfd..7c17c5f6820 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -389,8 +389,8 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
         self.connect.createDataFrame(data, "col1 int, col2 int, col3 int").show()
 
     def test_with_local_rows(self):
-        # SPARK-41789: Test creating a dataframe with list of Rows
-        data = [
+        # SPARK-41789, SPARK-41810: Test creating a dataframe with list of rows and dictionaries
+        rows = [
             Row(course="dotNET", year=2012, earnings=10000),
             Row(course="Java", year=2012, earnings=20000),
             Row(course="dotNET", year=2012, earnings=5000),
@@ -398,19 +398,21 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
             Row(course="Java", year=2013, earnings=30000),
             Row(course="Scala", year=2022, earnings=None),
         ]
+        dicts = [row.asDict() for row in rows]
 
-        sdf = self.spark.createDataFrame(data)
-        cdf = self.connect.createDataFrame(data)
+        for data in [rows, dicts]:
+            sdf = self.spark.createDataFrame(data)
+            cdf = self.connect.createDataFrame(data)
 
-        self.assertEqual(sdf.schema, cdf.schema)
-        self.assert_eq(sdf.toPandas(), cdf.toPandas())
+            self.assertEqual(sdf.schema, cdf.schema)
+            self.assert_eq(sdf.toPandas(), cdf.toPandas())
 
-        # test with rename
-        sdf = self.spark.createDataFrame(data, schema=["a", "b", "c"])
-        cdf = self.connect.createDataFrame(data, schema=["a", "b", "c"])
+            # test with rename
+            sdf = self.spark.createDataFrame(data, schema=["a", "b", "c"])
+            cdf = self.connect.createDataFrame(data, schema=["a", "b", "c"])
 
-        self.assertEqual(sdf.schema, cdf.schema)
-        self.assert_eq(sdf.toPandas(), cdf.toPandas())
+            self.assertEqual(sdf.schema, cdf.schema)
+            self.assert_eq(sdf.toPandas(), cdf.toPandas())
 
     def test_with_atom_type(self):
         for data in [[(1), (2), (3)], [1, 2, 3]]:
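The updated test drives both input shapes through one loop; the dictionaries are derived from the `Row` fixtures via `asDict()`, a real pyspark API that runs without a live session. A small sketch of that derivation:
```python
from pyspark.sql import Row

rows = [
    Row(course="dotNET", year=2012, earnings=10000),
    Row(course="Java", year=2012, earnings=20000),
]

# Row.asDict() converts each Row to a plain dict; with this patch,
# createDataFrame infers the same field names from either shape.
dicts = [row.asDict() for row in rows]
print(dicts[0])  # {'course': 'dotNET', 'year': 2012, 'earnings': 10000}
```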
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]