This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 031864875674 [SPARK-47543][CONNECT][PYTHON] Inferring `dict` as `MapType` from Pandas DataFrame to allow DataFrame creation
031864875674 is described below

commit 03186487567443d7457bfe5d32692501b96e9d90
Author: Haejoon Lee <[email protected]>
AuthorDate: Thu Mar 28 09:30:43 2024 +0900

    [SPARK-47543][CONNECT][PYTHON] Inferring `dict` as `MapType` from Pandas DataFrame to allow DataFrame creation
    
    ### What changes were proposed in this pull request?
    
    This PR proposes inferring `dict` as `MapType` from a Pandas DataFrame to allow DataFrame creation.
    
    This PR also introduces a new config, `INFER_PANDAS_DICT_AS_MAP`, to reach this goal.
    
    ### Why are the changes needed?
    
    Currently, PyArrow infers a Pandas dictionary field as `StructType` instead of `MapType`:
    
    ```python
    >>> pdf = pd.DataFrame({"str_col": ['second'], "dict_col": [{'first': 0.7, 'second': 0.3}]})
    >>> pa.Schema.from_pandas(pdf)
    str_col: string
    dict_col: struct<first: double, second: double>
      child 0, first: double
      child 1, second: double
    ```
    
    The problem is that this behavior prevents Spark from handling the schema properly:
    
    ```python
    >>> sdf = spark.createDataFrame(pdf)
    >>> sdf.withColumn("test", F.col("dict_col")[F.col("str_col")]).show()
    [INVALID_EXTRACT_FIELD_TYPE] Field name should be a non-null string literal, but it's "str_col". SQLSTATE: 42000
    ```
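
    To address this, the new code path maps such a struct field to a `MapType`. Below is a minimal sketch of that mapping, mirroring the added logic (the key type is fixed to `StringType` and the value type is taken from the first struct field; `pdf` is the DataFrame from the example above, and the repr output is a sketch):

    ```python
    >>> # Sketch: mirrors the new inference path; expected repr shown below
    >>> from pyspark.sql.types import MapType, StringType
    >>> from pyspark.sql.pandas.types import from_arrow_type
    >>> field_type = pa.Schema.from_pandas(pdf).field("dict_col").type
    >>> MapType(StringType(), from_arrow_type(field_type.field(0).type))
    MapType(StringType(), DoubleType(), True)
    ```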
    
    ### Does this PR introduce _any_ user-facing change?
    
    No default behavior changes, but `createDataFrame` infers `dict` as a `MapType` from a Pandas DataFrame when `INFER_PANDAS_DICT_AS_MAP` is set to `True`:
    
    **Before**
    ```python
    >>> pdf = pd.DataFrame({"str_col": ['second'], "dict_col": [{'first': 0.7, 'second': 0.3}]})
    >>> sdf = spark.createDataFrame(pdf)
    >>> sdf.withColumn("test", F.col("dict_col")[F.col("str_col")]).show()
    [INVALID_EXTRACT_FIELD_TYPE] Field name should be a non-null string literal, but it's "str_col". SQLSTATE: 42000
    ```
    
    **After**
    ```python
    >>> spark.conf.set("spark.sql.execution.pandas.inferPandasDictAsMap", True)
    >>> pdf = pd.DataFrame({"str_col": ['second'], "dict_col": [{'first': 0.7, 'second': 0.3}]})
    >>> sdf = spark.createDataFrame(pdf)
    >>> sdf.withColumn("test", F.col("dict_col")[F.col("str_col")]).show(truncate=False)
    +-------+-----------------------------+----+
    |str_col|dict_col                     |test|
    +-------+-----------------------------+----+
    |second |{first -> 0.7, second -> 0.3}|0.3 |
    +-------+-----------------------------+----+
    ```
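
    The inferred schema can be checked directly as well; a sketch of the expected output, assuming all dictionary values are doubles as in the example above:

    ```python
    >>> # Expected output (sketch):
    >>> sdf.printSchema()
    root
     |-- str_col: string (nullable = true)
     |-- dict_col: map (nullable = true)
     |    |-- key: string
     |    |-- value: double (valueContainsNull = true)
    ```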
    
    ### How was this patch tested?
    
    Added UT.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #45699 from itholic/schema_issue.
    
    Authored-by: Haejoon Lee <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../source/user_guide/sql/type_conversions.rst     |  3 ++
 python/pyspark/sql/connect/session.py              | 31 ++++++++++++-
 python/pyspark/sql/pandas/conversion.py            | 46 ++++++++++++++++---
 python/pyspark/sql/tests/test_creation.py          | 52 ++++++++++++++++++++++
 .../org/apache/spark/sql/internal/SQLConf.scala    | 10 +++++
 5 files changed, 134 insertions(+), 8 deletions(-)

diff --git a/python/docs/source/user_guide/sql/type_conversions.rst b/python/docs/source/user_guide/sql/type_conversions.rst
index b63e7dfa8851..2f13701995ef 100644
--- a/python/docs/source/user_guide/sql/type_conversions.rst
+++ b/python/docs/source/user_guide/sql/type_conversions.rst
@@ -64,6 +64,9 @@ are listed below:
     * - spark.sql.timestampType
      - If set to `TIMESTAMP_NTZ`, the default timestamp type is ``TimestampNTZType``. Otherwise, the default timestamp type is TimestampType.
       - ""
+    * - spark.sql.execution.pandas.inferPandasDictAsMap
+      - When enabled, Pandas dictionaries are inferred as MapType. Otherwise, they are inferred as StructType.
+      - False
 
 All Conversions
 ---------------
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 9de8579b12b4..8b7e403667cf 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -68,7 +68,12 @@ from pyspark.sql.connect.readwriter import DataFrameReader
 from pyspark.sql.connect.streaming.readwriter import DataStreamReader
 from pyspark.sql.connect.streaming.query import StreamingQueryManager
 from pyspark.sql.pandas.serializers import ArrowStreamPandasSerializer
-from pyspark.sql.pandas.types import to_arrow_schema, to_arrow_type, _deduplicate_field_names
+from pyspark.sql.pandas.types import (
+    to_arrow_schema,
+    to_arrow_type,
+    _deduplicate_field_names,
+    from_arrow_type,
+)
 from pyspark.sql.profiler import Profile
 from pyspark.sql.session import classproperty, SparkSession as PySparkSession
 from pyspark.sql.types import (
@@ -81,6 +86,8 @@ from pyspark.sql.types import (
     StructType,
     AtomicType,
     TimestampType,
+    MapType,
+    StringType,
 )
 from pyspark.sql.utils import to_str
 from pyspark.errors import (
@@ -419,6 +426,28 @@ class SparkSession:
             # If no schema supplied by user then get the names of columns only
             if schema is None:
                 _cols = [str(x) if not isinstance(x, str) else x for x in data.columns]
+                infer_pandas_dict_as_map = (
+                    str(self.conf.get("spark.sql.execution.pandas.inferPandasDictAsMap")).lower()
+                    == "true"
+                )
+                if infer_pandas_dict_as_map:
+                    struct = StructType()
+                    pa_schema = pa.Schema.from_pandas(data)
+                    spark_type: Union[MapType, DataType]
+                    for field in pa_schema:
+                        field_type = field.type
+                        if isinstance(field_type, pa.StructType):
+                            if len(field_type) == 0:
+                                raise PySparkValueError(
+                                    error_class="CANNOT_INFER_EMPTY_SCHEMA",
+                                    message_parameters={},
+                                )
+                            arrow_type = field_type.field(0).type
+                            spark_type = MapType(StringType(), from_arrow_type(arrow_type))
+                        else:
+                            spark_type = from_arrow_type(field_type)
+                        struct.add(field.name, spark_type, nullable=field.nullable)
+                    schema = struct
             elif isinstance(schema, (list, tuple)) and cast(int, _num_cols) < len(data.columns):
                 assert isinstance(_cols, list)
                 _cols.extend([f"_{i + 1}" for i in range(cast(int, _num_cols), len(data.columns))])
diff --git a/python/pyspark/sql/pandas/conversion.py b/python/pyspark/sql/pandas/conversion.py
index d958b95795b7..891bab63b3da 100644
--- a/python/pyspark/sql/pandas/conversion.py
+++ b/python/pyspark/sql/pandas/conversion.py
@@ -32,10 +32,18 @@ from pyspark.loose_version import LooseVersion
 from pyspark.rdd import _load_from_socket
 from pyspark.sql.pandas.serializers import ArrowCollectSerializer
 from pyspark.sql.pandas.types import _dedup_names
-from pyspark.sql.types import ArrayType, MapType, TimestampType, StructType, DataType, _create_row
+from pyspark.sql.types import (
+    ArrayType,
+    MapType,
+    TimestampType,
+    StructType,
+    DataType,
+    _create_row,
+    StringType,
+)
 from pyspark.sql.utils import is_timestamp_ntz_preferred
 from pyspark.traceback_utils import SCCallSiteSync
-from pyspark.errors import PySparkTypeError
+from pyspark.errors import PySparkTypeError, PySparkValueError
 
 if TYPE_CHECKING:
     import numpy as np
@@ -600,15 +608,39 @@ class SparkConversionMixin:
         )
         import pyarrow as pa
 
+        infer_pandas_dict_as_map = (
+            str(self.conf.get("spark.sql.execution.pandas.inferPandasDictAsMap")).lower() == "true"
+        )
+
         # Create the Spark schema from list of names passed in with Arrow types
         if isinstance(schema, (list, tuple)):
             arrow_schema = pa.Schema.from_pandas(pdf, preserve_index=False)
-            struct = StructType()
             prefer_timestamp_ntz = is_timestamp_ntz_preferred()
-            for name, field in zip(schema, arrow_schema):
-                struct.add(
-                    name, from_arrow_type(field.type, prefer_timestamp_ntz), nullable=field.nullable
-                )
+            struct = StructType()
+            if infer_pandas_dict_as_map:
+                spark_type: Union[MapType, DataType]
+                for name, field in zip(schema, arrow_schema):
+                    field_type = field.type
+                    if isinstance(field_type, pa.StructType):
+                        if len(field_type) == 0:
+                            raise PySparkValueError(
+                                error_class="CANNOT_INFER_EMPTY_SCHEMA",
+                                message_parameters={},
+                            )
+                        arrow_type = field_type.field(0).type
+                        spark_type = MapType(
+                            StringType(), from_arrow_type(arrow_type, prefer_timestamp_ntz)
+                        )
+                    else:
+                        spark_type = from_arrow_type(field_type)
+                    struct.add(name, spark_type, nullable=field.nullable)
+            else:
+                for name, field in zip(schema, arrow_schema):
+                    struct.add(
+                        name,
+                        from_arrow_type(field.type, prefer_timestamp_ntz),
+                        nullable=field.nullable,
+                    )
             schema = struct
 
         # Determine arrow types to coerce data when creating batches
diff --git a/python/pyspark/sql/tests/test_creation.py b/python/pyspark/sql/tests/test_creation.py
index 272f7de2b4dc..cd1fd922a6a8 100644
--- a/python/pyspark/sql/tests/test_creation.py
+++ b/python/pyspark/sql/tests/test_creation.py
@@ -22,6 +22,7 @@ import time
 import unittest
 
 from pyspark.sql import Row
+import pyspark.sql.functions as F
 from pyspark.sql.types import (
     DateType,
     TimestampType,
@@ -157,6 +158,57 @@ class DataFrameCreationTestsMixin:
             message_parameters={},
         )
 
+    def test_schema_inference_from_pandas_with_dict(self):
+        # SPARK-47543: test for verifying that inferring `dict` as `MapType` works properly.
+        import pandas as pd
+
+        pdf = pd.DataFrame({"str_col": ["second"], "dict_col": [{"first": 0.7, "second": 0.3}]})
+
+        with self.sql_conf(
+            {
+                "spark.sql.execution.arrow.pyspark.enabled": True,
+                "spark.sql.execution.arrow.pyspark.fallback.enabled": False,
+                "spark.sql.execution.pandas.inferPandasDictAsMap": True,
+            }
+        ):
+            sdf = self.spark.createDataFrame(pdf)
+            self.assertEqual(
+                sdf.withColumn("test", F.col("dict_col")[F.col("str_col")]).collect(),
+                [Row(str_col="second", dict_col={"first": 0.7, "second": 0.3}, test=0.3)],
+            )
+
+            # Empty dict should fail
+            pdf_empty_struct = pd.DataFrame({"str_col": ["second"], "dict_col": [{}]})
+
+            with self.assertRaises(PySparkValueError) as pe:
+                self.spark.createDataFrame(pdf_empty_struct)
+
+            self.check_error(
+                exception=pe.exception,
+                error_class="CANNOT_INFER_EMPTY_SCHEMA",
+                message_parameters={},
+            )
+
+            # Dict with different types of values should fail
+            pdf_different_type = pd.DataFrame(
+                {"str_col": ["second"], "dict_col": [{"first": 0.7, "second": "0.3"}]}
+            )
+            self.assertRaises(
+                PySparkValueError, lambda: self.spark.createDataFrame(pdf_different_type)
+            )
+
+        with self.sql_conf(
+            {
+                "spark.sql.execution.arrow.pyspark.enabled": False,
+                "spark.sql.execution.pandas.inferPandasDictAsMap": True,
+            }
+        ):
+            sdf = self.spark.createDataFrame(pdf)
+            self.assertEqual(
+                sdf.withColumn("test", F.col("dict_col")[F.col("str_col")]).collect(),
+                [Row(str_col="second", dict_col={"first": 0.7, "second": 0.3}, test=0.3)],
+            )
+
 
 class DataFrameCreationTests(
     DataFrameCreationTestsMixin,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 9a5e6b271a15..72831d8b32a7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -4409,6 +4409,14 @@ object SQLConf {
     .booleanConf
     .createWithDefault(false)
 
+  val INFER_PANDAS_DICT_AS_MAP = buildConf("spark.sql.execution.pandas.inferPandasDictAsMap")
+    .doc("When true, spark.createDataFrame will infer dict from Pandas DataFrame " +
+      "as a MapType. When false, spark.createDataFrame infers dict from Pandas DataFrame " +
+      "as a StructType, which is the default inference from PyArrow.")
+    .version("4.0.0")
+    .booleanConf
+    .createWithDefault(false)
+
   val LEGACY_INFER_ARRAY_TYPE_FROM_FIRST_ELEMENT =
     buildConf("spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled")
      .doc("PySpark's SparkSession.createDataFrame infers the element type of an array from all " +
@@ -5601,6 +5609,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
 
   def inferDictAsStruct: Boolean = getConf(SQLConf.INFER_NESTED_DICT_AS_STRUCT)
 
+  def inferPandasDictAsMap: Boolean = getConf(SQLConf.INFER_PANDAS_DICT_AS_MAP)
+
   def legacyInferArrayTypeFromFirstElement: Boolean = getConf(
     SQLConf.LEGACY_INFER_ARRAY_TYPE_FROM_FIRST_ELEMENT)
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
