This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 17a39d08ce6 [SPARK-39168][PYTHON] Use all values in a python list when 
inferring ArrayType schema
17a39d08ce6 is described below

commit 17a39d08ce6973866af7ed51ed18e0ea49e73e22
Author: Brian Schaefer <physi...@gmail.com>
AuthorDate: Thu May 19 10:06:10 2022 +0900

    [SPARK-39168][PYTHON] Use all values in a python list when inferring 
ArrayType schema
    
    ### What changes were proposed in this pull request?
    This PR modifies type inference for Python lists to consider all values in the list, not just the first value.
    
    ### Why are the changes needed?
    This enables convenient type inference in the two following cases:
    |  | previous | current |
    | --- | --- | --- |
    | `[None, 1]` | `array<void>` (raises `ValueError`) | `array<bigint>` |
    | `[{"b": 1}, {"c": 2}]` | `array<struct<b:bigint>>` | `array<struct<b:bigint,c:bigint>>` |
    
    ### Does this PR introduce _any_ user-facing change?
    Possible user-facing changes:
    * attempting to infer the schema of an array with mixed types (e.g. `["a", 1]`) may result in a TypeError
       * previously, this was inferred as an `array<string>` and produced a value `["a", "1"]`
    * fields of inferred struct types will differ if dictionaries in an array have different keys
    
    ### How was this patch tested?
    Added unit tests for various cases.
    
    Closes #36545 from physinet/infer_schema_from_full_array.
    
    Authored-by: Brian Schaefer <physi...@gmail.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/docs/source/migration_guide/index.rst       |  1 +
 .../source/migration_guide/pyspark_3.3_to_3.4.rst  | 23 +++++++
 python/pyspark/sql/session.py                      | 15 ++++-
 python/pyspark/sql/tests/test_types.py             | 75 ++++++++++++++++++++++
 python/pyspark/sql/types.py                        | 55 ++++++++++++++--
 .../org/apache/spark/sql/internal/SQLConf.scala    | 12 ++++
 6 files changed, 173 insertions(+), 8 deletions(-)

diff --git a/python/docs/source/migration_guide/index.rst 
b/python/docs/source/migration_guide/index.rst
index 2e61653a9a5..43c836fecba 100644
--- a/python/docs/source/migration_guide/index.rst
+++ b/python/docs/source/migration_guide/index.rst
@@ -25,6 +25,7 @@ This page describes the migration guide specific to PySpark.
 .. toctree::
    :maxdepth: 2
 
+   pyspark_3.3_to_3.4
    pyspark_3.2_to_3.3
    pyspark_3.1_to_3.2
    pyspark_2.4_to_3.0
diff --git a/python/docs/source/migration_guide/pyspark_3.3_to_3.4.rst 
b/python/docs/source/migration_guide/pyspark_3.3_to_3.4.rst
new file mode 100644
index 00000000000..9f8cf545e28
--- /dev/null
+++ b/python/docs/source/migration_guide/pyspark_3.3_to_3.4.rst
@@ -0,0 +1,23 @@
+..  Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements.  See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership.  The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License.  You may obtain a copy of the License at
+
+..    http://www.apache.org/licenses/LICENSE-2.0
+
+..  Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied.  See the License for the
+    specific language governing permissions and limitations
+    under the License.
+
+
+=================================
+Upgrading from PySpark 3.3 to 3.4
+=================================
+
+* In Spark 3.4, the schema of an array column is inferred by merging the 
schemas of all elements in the array. To restore the previous behavior where 
the schema is only inferred from the first element, you can set 
``spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled`` to ``true``.
\ No newline at end of file
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 688d67d10d3..86e99aec819 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -570,10 +570,20 @@ class SparkSession(SparkConversionMixin):
         if not data:
             raise ValueError("can not infer schema from empty dataset")
         infer_dict_as_struct = self._jconf.inferDictAsStruct()
+        infer_array_from_first_element = 
self._jconf.legacyInferArrayTypeFromFirstElement()
         prefer_timestamp_ntz = is_timestamp_ntz_preferred()
         schema = reduce(
             _merge_type,
-            (_infer_schema(row, names, infer_dict_as_struct, 
prefer_timestamp_ntz) for row in data),
+            (
+                _infer_schema(
+                    row,
+                    names,
+                    infer_dict_as_struct=infer_dict_as_struct,
+                    
infer_array_from_first_element=infer_array_from_first_element,
+                    prefer_timestamp_ntz=prefer_timestamp_ntz,
+                )
+                for row in data
+            ),
         )
         if _has_nulltype(schema):
             raise ValueError("Some of types cannot be determined after 
inferring")
@@ -605,6 +615,7 @@ class SparkSession(SparkConversionMixin):
             raise ValueError("The first row in RDD is empty, " "can not infer 
schema")
 
         infer_dict_as_struct = self._jconf.inferDictAsStruct()
+        infer_array_from_first_element = 
self._jconf.legacyInferArrayTypeFromFirstElement()
         prefer_timestamp_ntz = is_timestamp_ntz_preferred()
         if samplingRatio is None:
             schema = _infer_schema(
@@ -621,6 +632,7 @@ class SparkSession(SparkConversionMixin):
                             row,
                             names=names,
                             infer_dict_as_struct=infer_dict_as_struct,
+                            
infer_array_from_first_element=infer_array_from_first_element,
                             prefer_timestamp_ntz=prefer_timestamp_ntz,
                         ),
                     )
@@ -639,6 +651,7 @@ class SparkSession(SparkConversionMixin):
                     row,
                     names,
                     infer_dict_as_struct=infer_dict_as_struct,
+                    
infer_array_from_first_element=infer_array_from_first_element,
                     prefer_timestamp_ntz=prefer_timestamp_ntz,
                 )
             ).reduce(_merge_type)
diff --git a/python/pyspark/sql/tests/test_types.py 
b/python/pyspark/sql/tests/test_types.py
index d9ad2344ac5..ef0ad82dbb9 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -285,6 +285,81 @@ class TypesTests(ReusedSQLTestCase):
             df = self.spark.createDataFrame(data)
             self.assertEqual(Row(f1=[Row(payment=200.5, name="A")], f2=[1, 
2]), df.first())
 
+    def test_infer_array_merge_element_types(self):
+        # SPARK-39168: Test inferring array element type from all values in 
array
+        ArrayRow = Row("f1", "f2")
+
+        data = [ArrayRow([1, None], [None, 2])]
+
+        rdd = self.sc.parallelize(data)
+        df = self.spark.createDataFrame(rdd)
+        self.assertEqual(Row(f1=[1, None], f2=[None, 2]), df.first())
+
+        df = self.spark.createDataFrame(data)
+        self.assertEqual(Row(f1=[1, None], f2=[None, 2]), df.first())
+
+        # Test legacy behavior inferring only from the first element
+        with self.sql_conf(
+            
{"spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled": True}
+        ):
+            # Legacy: f2 schema inferred as an array of nulls, should raise 
error
+            self.assertRaises(ValueError, lambda: 
self.spark.createDataFrame(data))
+
+        # an array with only null values should raise an error
+        data2 = [ArrayRow([1], [None])]
+        self.assertRaises(ValueError, lambda: 
self.spark.createDataFrame(data2))
+
+        # an array with no values should raise an error
+        data3 = [ArrayRow([1], [])]
+        self.assertRaises(ValueError, lambda: 
self.spark.createDataFrame(data3))
+
+        # an array with conflicting types should raise an error
+        data4 = [ArrayRow([1, "1"], [None])]
+        self.assertRaises(TypeError, lambda: self.spark.createDataFrame(data4))
+
+    def test_infer_array_element_type_empty(self):
+        # SPARK-39168: Test inferring array element type from all rows
+        ArrayRow = Row("f1")
+
+        data = [ArrayRow([]), ArrayRow([None]), ArrayRow([1])]
+
+        rdd = self.sc.parallelize(data)
+        df = self.spark.createDataFrame(rdd)
+        rows = df.collect()
+        self.assertEqual(Row(f1=[]), rows[0])
+        self.assertEqual(Row(f1=[None]), rows[1])
+        self.assertEqual(Row(f1=[1]), rows[2])
+
+        df = self.spark.createDataFrame(data)
+        self.assertEqual(Row(f1=[]), rows[0])
+        self.assertEqual(Row(f1=[None]), rows[1])
+        self.assertEqual(Row(f1=[1]), rows[2])
+
+    def test_infer_array_element_type_with_struct(self):
+        # SPARK-39168: Test inferring array of struct type from all struct 
values
+        NestedRow = Row("f1")
+
+        with 
self.sql_conf({"spark.sql.pyspark.inferNestedDictAsStruct.enabled": True}):
+            data = [NestedRow([{"payment": 200.5}, {"name": "A"}])]
+
+            nestedRdd = self.sc.parallelize(data)
+            df = self.spark.createDataFrame(nestedRdd)
+            self.assertEqual(
+                Row(f1=[Row(payment=200.5, name=None), Row(payment=None, 
name="A")]), df.first()
+            )
+
+            df = self.spark.createDataFrame(data)
+            self.assertEqual(
+                Row(f1=[Row(payment=200.5, name=None), Row(payment=None, 
name="A")]), df.first()
+            )
+
+            # Test legacy behavior inferring only from the first element; 
excludes "name" field
+            with self.sql_conf(
+                
{"spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled": True}
+            ):
+                df = self.spark.createDataFrame(data)
+                self.assertEqual(Row(f1=[Row(payment=200.5), 
Row(payment=None)]), df.first())
+
     def test_create_dataframe_from_dict_respects_schema(self):
         df = self.spark.createDataFrame([{"a": 1}], ["b"])
         self.assertEqual(df.columns, ["b"])
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 123fd628980..6602a94f54d 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -27,6 +27,7 @@ import base64
 from array import array
 import ctypes
 from collections.abc import Iterable
+from functools import reduce
 from typing import (
     cast,
     overload,
@@ -1215,6 +1216,7 @@ if sys.version_info[0] < 4:
 def _infer_type(
     obj: Any,
     infer_dict_as_struct: bool = False,
+    infer_array_from_first_element: bool = False,
     prefer_timestamp_ntz: bool = False,
 ) -> DataType:
     """Infer the DataType from obj"""
@@ -1241,24 +1243,49 @@ def _infer_type(
             for key, value in obj.items():
                 if key is not None and value is not None:
                     struct.add(
-                        key, _infer_type(value, infer_dict_as_struct, 
prefer_timestamp_ntz), True
+                        key,
+                        _infer_type(
+                            value,
+                            infer_dict_as_struct,
+                            infer_array_from_first_element,
+                            prefer_timestamp_ntz,
+                        ),
+                        True,
                     )
             return struct
         else:
             for key, value in obj.items():
                 if key is not None and value is not None:
                     return MapType(
-                        _infer_type(key, infer_dict_as_struct, 
prefer_timestamp_ntz),
-                        _infer_type(value, infer_dict_as_struct, 
prefer_timestamp_ntz),
+                        _infer_type(
+                            key,
+                            infer_dict_as_struct,
+                            infer_array_from_first_element,
+                            prefer_timestamp_ntz,
+                        ),
+                        _infer_type(
+                            value,
+                            infer_dict_as_struct,
+                            infer_array_from_first_element,
+                            prefer_timestamp_ntz,
+                        ),
                         True,
                     )
             return MapType(NullType(), NullType(), True)
     elif isinstance(obj, list):
-        for v in obj:
-            if v is not None:
+        if len(obj) > 0:
+            if infer_array_from_first_element:
                 return ArrayType(
                     _infer_type(obj[0], infer_dict_as_struct, 
prefer_timestamp_ntz), True
                 )
+            else:
+                return ArrayType(
+                    reduce(
+                        _merge_type,
+                        (_infer_type(v, infer_dict_as_struct, 
prefer_timestamp_ntz) for v in obj),
+                    ),
+                    True,
+                )
         return ArrayType(NullType(), True)
     elif isinstance(obj, array):
         if obj.typecode in _array_type_mappings:
@@ -1267,7 +1294,11 @@ def _infer_type(
             raise TypeError("not supported type: array(%s)" % obj.typecode)
     else:
         try:
-            return _infer_schema(obj, 
infer_dict_as_struct=infer_dict_as_struct)
+            return _infer_schema(
+                obj,
+                infer_dict_as_struct=infer_dict_as_struct,
+                infer_array_from_first_element=infer_array_from_first_element,
+            )
         except TypeError:
             raise TypeError("not supported type: %s" % type(obj))
 
@@ -1276,6 +1307,7 @@ def _infer_schema(
     row: Any,
     names: Optional[List[str]] = None,
     infer_dict_as_struct: bool = False,
+    infer_array_from_first_element: bool = False,
     prefer_timestamp_ntz: bool = False,
 ) -> StructType:
     """Infer the schema from dict/namedtuple/object"""
@@ -1305,7 +1337,16 @@ def _infer_schema(
     for k, v in items:
         try:
             fields.append(
-                StructField(k, _infer_type(v, infer_dict_as_struct, 
prefer_timestamp_ntz), True)
+                StructField(
+                    k,
+                    _infer_type(
+                        v,
+                        infer_dict_as_struct,
+                        infer_array_from_first_element,
+                        prefer_timestamp_ntz,
+                    ),
+                    True,
+                )
             )
         except TypeError as e:
             raise TypeError("Unable to infer the type of the field 
{}.".format(k)) from e
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 950e3f8f6a2..6dfc46f6a3f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -3758,6 +3758,15 @@ object SQLConf {
     .booleanConf
     .createWithDefault(false)
 
+  val LEGACY_INFER_ARRAY_TYPE_FROM_FIRST_ELEMENT =
+    
buildConf("spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled")
+      .doc("PySpark's SparkSession.createDataFrame infers the element type of 
an array from all " +
+        "values in the array by default. If this config is set to true, it 
restores the legacy " +
+        "behavior of only inferring the type from the first array element.")
+      .version("3.4.0")
+      .booleanConf
+      .createWithDefault(false)
+
   val LEGACY_USE_V1_COMMAND =
     buildConf("spark.sql.legacy.useV1Command")
       .internal()
@@ -4554,6 +4563,9 @@ class SQLConf extends Serializable with Logging {
 
   def inferDictAsStruct: Boolean = getConf(SQLConf.INFER_NESTED_DICT_AS_STRUCT)
 
+  def legacyInferArrayTypeFromFirstElement: Boolean = getConf(
+    SQLConf.LEGACY_INFER_ARRAY_TYPE_FROM_FIRST_ELEMENT)
+
   def parquetFieldIdReadEnabled: Boolean = 
getConf(SQLConf.PARQUET_FIELD_ID_READ_ENABLED)
 
   def parquetFieldIdWriteEnabled: Boolean = 
getConf(SQLConf.PARQUET_FIELD_ID_WRITE_ENABLED)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to