This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 031864875674 [SPARK-47543][CONNECT][PYTHON] Inferring `dict` as `MapType` from Pandas DataFrame to allow DataFrame creation
031864875674 is described below
commit 03186487567443d7457bfe5d32692501b96e9d90
Author: Haejoon Lee <[email protected]>
AuthorDate: Thu Mar 28 09:30:43 2024 +0900
[SPARK-47543][CONNECT][PYTHON] Inferring `dict` as `MapType` from Pandas DataFrame to allow DataFrame creation
### What changes were proposed in this pull request?
This PR proposes to infer `dict` as `MapType` from a Pandas DataFrame to allow DataFrame creation.
This PR also introduces a new config, `INFER_PANDAS_DICT_AS_MAP`, to control this behavior.
### Why are the changes needed?
Currently, PyArrow infers a Pandas dictionary field as `StructType` instead of `MapType`:
```python
>>> pdf = pd.DataFrame({"str_col": ['second'], "dict_col": [{'first': 0.7, 'second': 0.3}]})
>>> pa.Schema.from_pandas(pdf)
str_col: string
dict_col: struct<first: double, second: double>
child 0, first: double
child 1, second: double
```
The problem is that this behavior prevents Spark from handling the schema properly:
```python
>>> sdf = spark.createDataFrame(pdf)
>>> sdf.withColumn("test", F.col("dict_col")[F.col("str_col")]).show()
[INVALID_EXTRACT_FIELD_TYPE] Field name should be a non-null string literal, but it's "str_col". SQLSTATE: 42000
```
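Before this change, the usual workaround was to pass an explicit schema so the dictionary column is read as a map rather than a struct. A minimal sketch of that workaround (the explicit `MapType` schema below is illustrative, not part of this PR):
```python
>>> from pyspark.sql.types import StructType, StructField, StringType, MapType, DoubleType
>>> schema = StructType([
...     StructField("str_col", StringType()),
...     StructField("dict_col", MapType(StringType(), DoubleType())),
... ])
>>> sdf = spark.createDataFrame(pdf, schema=schema)
>>> sdf.withColumn("test", F.col("dict_col")[F.col("str_col")]).show()
```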
### Does this PR introduce _any_ user-facing change?
No default behavior changes, but `createDataFrame` infers `dict` as a `MapType` from a Pandas DataFrame when `INFER_PANDAS_DICT_AS_MAP` is set to `True`:
**Before**
```python
>>> pdf = pd.DataFrame({"str_col": ['second'], "dict_col": [{'first': 0.7, 'second': 0.3}]})
>>> sdf = spark.createDataFrame(pdf)
>>> sdf.withColumn("test", F.col("dict_col")[F.col("str_col")]).show()
[INVALID_EXTRACT_FIELD_TYPE] Field name should be a non-null string literal, but it's "str_col". SQLSTATE: 42000
```
**After**
```python
>>> spark.conf.set("spark.sql.execution.pandas.inferPandasDictAsMap", True)
>>> pdf = pd.DataFrame({"str_col": ['second'], "dict_col": [{'first': 0.7, 'second': 0.3}]})
>>> sdf = spark.createDataFrame(pdf)
>>> sdf.withColumn("test", F.col("dict_col")[F.col("str_col")]).show(truncate=False)
+-------+-----------------------------+----+
|str_col|dict_col |test|
+-------+-----------------------------+----+
|second |{first -> 0.7, second -> 0.3}|0.3 |
+-------+-----------------------------+----+
```
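For reference, the schema inferred under the new config should look roughly as follows (output sketched from the standard `printSchema` format):
```python
>>> sdf.printSchema()
root
 |-- str_col: string (nullable = true)
 |-- dict_col: map (nullable = true)
 |    |-- key: string
 |    |-- value: double (valueContainsNull = true)
```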
### How was this patch tested?
Added UT.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #45699 from itholic/schema_issue.
Authored-by: Haejoon Lee <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../source/user_guide/sql/type_conversions.rst | 3 ++
python/pyspark/sql/connect/session.py | 31 ++++++++++++-
python/pyspark/sql/pandas/conversion.py | 46 ++++++++++++++++---
python/pyspark/sql/tests/test_creation.py | 52 ++++++++++++++++++++++
.../org/apache/spark/sql/internal/SQLConf.scala | 10 +++++
5 files changed, 134 insertions(+), 8 deletions(-)
diff --git a/python/docs/source/user_guide/sql/type_conversions.rst b/python/docs/source/user_guide/sql/type_conversions.rst
index b63e7dfa8851..2f13701995ef 100644
--- a/python/docs/source/user_guide/sql/type_conversions.rst
+++ b/python/docs/source/user_guide/sql/type_conversions.rst
@@ -64,6 +64,9 @@ are listed below:
    * - spark.sql.timestampType
      - If set to `TIMESTAMP_NTZ`, the default timestamp type is ``TimestampNTZType``. Otherwise, the default timestamp type is TimestampType.
      - ""
+    * - spark.sql.execution.pandas.inferPandasDictAsMap
+      - When enabled, Pandas dictionaries are inferred as MapType. Otherwise, they are inferred as StructType.
+      - False
All Conversions
---------------
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 9de8579b12b4..8b7e403667cf 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -68,7 +68,12 @@ from pyspark.sql.connect.readwriter import DataFrameReader
from pyspark.sql.connect.streaming.readwriter import DataStreamReader
from pyspark.sql.connect.streaming.query import StreamingQueryManager
from pyspark.sql.pandas.serializers import ArrowStreamPandasSerializer
-from pyspark.sql.pandas.types import to_arrow_schema, to_arrow_type, _deduplicate_field_names
+from pyspark.sql.pandas.types import (
+    to_arrow_schema,
+    to_arrow_type,
+    _deduplicate_field_names,
+    from_arrow_type,
+)
from pyspark.sql.profiler import Profile
from pyspark.sql.session import classproperty, SparkSession as PySparkSession
from pyspark.sql.types import (
@@ -81,6 +86,8 @@ from pyspark.sql.types import (
     StructType,
     AtomicType,
     TimestampType,
+    MapType,
+    StringType,
 )
from pyspark.sql.utils import to_str
from pyspark.errors import (
@@ -419,6 +426,28 @@ class SparkSession:
             # If no schema supplied by user then get the names of columns only
             if schema is None:
                 _cols = [str(x) if not isinstance(x, str) else x for x in data.columns]
+                infer_pandas_dict_as_map = (
+                    str(self.conf.get("spark.sql.execution.pandas.inferPandasDictAsMap")).lower()
+                    == "true"
+                )
+                if infer_pandas_dict_as_map:
+                    struct = StructType()
+                    pa_schema = pa.Schema.from_pandas(data)
+                    spark_type: Union[MapType, DataType]
+                    for field in pa_schema:
+                        field_type = field.type
+                        if isinstance(field_type, pa.StructType):
+                            if len(field_type) == 0:
+                                raise PySparkValueError(
+                                    error_class="CANNOT_INFER_EMPTY_SCHEMA",
+                                    message_parameters={},
+                                )
+                            arrow_type = field_type.field(0).type
+                            spark_type = MapType(StringType(), from_arrow_type(arrow_type))
+                        else:
+                            spark_type = from_arrow_type(field_type)
+                        struct.add(field.name, spark_type, nullable=field.nullable)
+                    schema = struct
             elif isinstance(schema, (list, tuple)) and cast(int, _num_cols) < len(data.columns):
                 assert isinstance(_cols, list)
                 _cols.extend([f"_{i + 1}" for i in range(cast(int, _num_cols), len(data.columns))])
diff --git a/python/pyspark/sql/pandas/conversion.py b/python/pyspark/sql/pandas/conversion.py
index d958b95795b7..891bab63b3da 100644
--- a/python/pyspark/sql/pandas/conversion.py
+++ b/python/pyspark/sql/pandas/conversion.py
@@ -32,10 +32,18 @@ from pyspark.loose_version import LooseVersion
from pyspark.rdd import _load_from_socket
from pyspark.sql.pandas.serializers import ArrowCollectSerializer
from pyspark.sql.pandas.types import _dedup_names
-from pyspark.sql.types import ArrayType, MapType, TimestampType, StructType, DataType, _create_row
+from pyspark.sql.types import (
+    ArrayType,
+    MapType,
+    TimestampType,
+    StructType,
+    DataType,
+    _create_row,
+    StringType,
+)
from pyspark.sql.utils import is_timestamp_ntz_preferred
from pyspark.traceback_utils import SCCallSiteSync
-from pyspark.errors import PySparkTypeError
+from pyspark.errors import PySparkTypeError, PySparkValueError
if TYPE_CHECKING:
import numpy as np
@@ -600,15 +608,39 @@ class SparkConversionMixin:
         )
         import pyarrow as pa
+        infer_pandas_dict_as_map = (
+            str(self.conf.get("spark.sql.execution.pandas.inferPandasDictAsMap")).lower()
+            == "true"
+        )
+
         # Create the Spark schema from list of names passed in with Arrow types
         if isinstance(schema, (list, tuple)):
             arrow_schema = pa.Schema.from_pandas(pdf, preserve_index=False)
-            struct = StructType()
             prefer_timestamp_ntz = is_timestamp_ntz_preferred()
-            for name, field in zip(schema, arrow_schema):
-                struct.add(
-                    name, from_arrow_type(field.type, prefer_timestamp_ntz), nullable=field.nullable
-                )
+            struct = StructType()
+            if infer_pandas_dict_as_map:
+                spark_type: Union[MapType, DataType]
+                for name, field in zip(schema, arrow_schema):
+                    field_type = field.type
+                    if isinstance(field_type, pa.StructType):
+                        if len(field_type) == 0:
+                            raise PySparkValueError(
+                                error_class="CANNOT_INFER_EMPTY_SCHEMA",
+                                message_parameters={},
+                            )
+                        arrow_type = field_type.field(0).type
+                        spark_type = MapType(
+                            StringType(), from_arrow_type(arrow_type, prefer_timestamp_ntz)
+                        )
+                    else:
+                        spark_type = from_arrow_type(field_type)
+                    struct.add(name, spark_type, nullable=field.nullable)
+            else:
+                for name, field in zip(schema, arrow_schema):
+                    struct.add(
+                        name,
+                        from_arrow_type(field.type, prefer_timestamp_ntz),
+                        nullable=field.nullable,
+                    )
             schema = struct
         # Determine arrow types to coerce data when creating batches
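Both hunks above implement the same rule: when PyArrow infers a column as a struct, the keys become `StringType` and the first struct field's type stands in for the map's value type. A self-contained sketch of that mapping, written against public `pyarrow` APIs plus pyspark's `from_arrow_type` helper (an internal utility the PR itself uses); the function name `infer_dict_columns_as_map` is hypothetical, not from this PR:
```python
import pandas as pd
import pyarrow as pa

from pyspark.sql.pandas.types import from_arrow_type
from pyspark.sql.types import DataType, MapType, StringType, StructType


def infer_dict_columns_as_map(pdf: pd.DataFrame) -> StructType:
    """Infer a Spark schema from a Pandas DataFrame, treating dict columns as maps."""
    struct = StructType()
    for field in pa.Schema.from_pandas(pdf, preserve_index=False):
        if isinstance(field.type, pa.StructType):
            if len(field.type) == 0:
                # Mirrors the PR's CANNOT_INFER_EMPTY_SCHEMA error for empty dicts.
                raise ValueError("cannot infer a schema from an empty dict")
            # Keys become strings; the first struct field's type stands in for the
            # value type, so all dict values are assumed to share one type.
            spark_type: DataType = MapType(StringType(), from_arrow_type(field.type.field(0).type))
        else:
            spark_type = from_arrow_type(field.type)
        struct.add(field.name, spark_type, nullable=field.nullable)
    return struct


pdf = pd.DataFrame({"str_col": ["second"], "dict_col": [{"first": 0.7, "second": 0.3}]})
print(infer_dict_columns_as_map(pdf))  # expect dict_col: MapType(StringType(), DoubleType())
```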
diff --git a/python/pyspark/sql/tests/test_creation.py b/python/pyspark/sql/tests/test_creation.py
index 272f7de2b4dc..cd1fd922a6a8 100644
--- a/python/pyspark/sql/tests/test_creation.py
+++ b/python/pyspark/sql/tests/test_creation.py
@@ -22,6 +22,7 @@ import time
import unittest
from pyspark.sql import Row
+import pyspark.sql.functions as F
from pyspark.sql.types import (
DateType,
TimestampType,
@@ -157,6 +158,57 @@ class DataFrameCreationTestsMixin:
message_parameters={},
)
+    def test_schema_inference_from_pandas_with_dict(self):
+        # SPARK-47543: test for verifying if inferring `dict` as `MapType` work properly.
+        import pandas as pd
+
+        pdf = pd.DataFrame({"str_col": ["second"], "dict_col": [{"first": 0.7, "second": 0.3}]})
+
+        with self.sql_conf(
+            {
+                "spark.sql.execution.arrow.pyspark.enabled": True,
+                "spark.sql.execution.arrow.pyspark.fallback.enabled": False,
+                "spark.sql.execution.pandas.inferPandasDictAsMap": True,
+            }
+        ):
+            sdf = self.spark.createDataFrame(pdf)
+            self.assertEqual(
+                sdf.withColumn("test", F.col("dict_col")[F.col("str_col")]).collect(),
+                [Row(str_col="second", dict_col={"first": 0.7, "second": 0.3}, test=0.3)],
+            )
+
+            # Empty dict should fail
+            pdf_empty_struct = pd.DataFrame({"str_col": ["second"], "dict_col": [{}]})
+
+            with self.assertRaises(PySparkValueError) as pe:
+                self.spark.createDataFrame(pdf_empty_struct)
+
+            self.check_error(
+                exception=pe.exception,
+                error_class="CANNOT_INFER_EMPTY_SCHEMA",
+                message_parameters={},
+            )
+
+            # Dict has different types of values should fail
+            pdf_different_type = pd.DataFrame(
+                {"str_col": ["second"], "dict_col": [{"first": 0.7, "second": "0.3"}]}
+            )
+            self.assertRaises(
+                PySparkValueError, lambda: self.spark.createDataFrame(pdf_different_type)
+            )
+
+        with self.sql_conf(
+            {
+                "spark.sql.execution.arrow.pyspark.enabled": False,
+                "spark.sql.execution.pandas.inferPandasDictAsMap": True,
+            }
+        ):
+            sdf = self.spark.createDataFrame(pdf)
+            self.assertEqual(
+                sdf.withColumn("test", F.col("dict_col")[F.col("str_col")]).collect(),
+                [Row(str_col="second", dict_col={"first": 0.7, "second": 0.3}, test=0.3)],
+            )
+
class DataFrameCreationTests(
DataFrameCreationTestsMixin,
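A note on the mixed-value-type case exercised above: `MapType` has a single value type, and the inference takes it from the struct's first field. With `{"first": 0.7, "second": "0.3"}` the value type is inferred as double, so the string value cannot be converted and creation fails. A REPL sketch of that failure mode (error message elided):
```python
>>> spark.conf.set("spark.sql.execution.pandas.inferPandasDictAsMap", True)
>>> pdf_mixed = pd.DataFrame({"dict_col": [{"first": 0.7, "second": "0.3"}]})
>>> spark.createDataFrame(pdf_mixed)  # raises PySparkValueError
```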
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 9a5e6b271a15..72831d8b32a7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -4409,6 +4409,14 @@ object SQLConf {
.booleanConf
.createWithDefault(false)
+  val INFER_PANDAS_DICT_AS_MAP = buildConf("spark.sql.execution.pandas.inferPandasDictAsMap")
+    .doc("When true, spark.createDataFrame will infer dict from Pandas DataFrame " +
+      "as a MapType. When false, spark.createDataFrame infers dict from Pandas DataFrame " +
+      "as a StructType which is default inferring from PyArrow.")
+    .version("4.0.0")
+    .booleanConf
+    .createWithDefault(false)
+
   val LEGACY_INFER_ARRAY_TYPE_FROM_FIRST_ELEMENT =
     buildConf("spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled")
       .doc("PySpark's SparkSession.createDataFrame infers the element type of an array from all " +
@@ -5601,6 +5609,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
   def inferDictAsStruct: Boolean = getConf(SQLConf.INFER_NESTED_DICT_AS_STRUCT)
+
+  def inferPandasDictAsMap: Boolean = getConf(SQLConf.INFER_PANDAS_DICT_AS_MAP)
+
   def legacyInferArrayTypeFromFirstElement: Boolean = getConf(
     SQLConf.LEGACY_INFER_ARRAY_TYPE_FROM_FIRST_ELEMENT)
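On the Python side, the new entry behaves like any other SQL conf and can be inspected through `spark.conf`; a quick check of the default (string-valued output sketched):
```python
>>> spark.conf.get("spark.sql.execution.pandas.inferPandasDictAsMap")
'false'
>>> spark.conf.set("spark.sql.execution.pandas.inferPandasDictAsMap", True)
>>> spark.conf.get("spark.sql.execution.pandas.inferPandasDictAsMap")
'true'
```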
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]