This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new b689c2ad756 [SPARK-42028][CONNECT][PYTHON][FOLLOW-UP] Uses the same logic as PySpark, and re-enables skipped test
b689c2ad756 is described below

commit b689c2ad756cc00dfc0f71f142e771dd367bcf4a
Author: Hyukjin Kwon <[email protected]>
AuthorDate: Fri Jan 13 15:20:22 2023 +0900

    [SPARK-42028][CONNECT][PYTHON][FOLLOW-UP] Uses the same logic as PySpark, and re-enables skipped test
    
    ### What changes were proposed in this pull request?
    
    This PR is a follow-up of https://github.com/apache/spark/pull/39469 that uses the same logic as PySpark:
    https://github.com/apache/spark/blob/baa6fa9b148467bfc83e6c2d22ea9fd9fa5b4564/python/pyspark/sql/pandas/conversion.py#L546-L631
    
    and re-enables the skipped test `test_create_dataframe_from_pandas_with_timestamp`.
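
    A minimal sketch of the borrowed logic, for illustration only (the helper name `_arrow_types_for` is hypothetical; in the actual change the logic is inlined in `createDataFrame`):

    ```python
    import pandas as pd
    from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype

    from pyspark.sql.pandas.types import to_arrow_type
    from pyspark.sql.types import TimestampType

    def _arrow_types_for(pdf: pd.DataFrame):
        # Coerce datetime64 columns to Spark-compatible Arrow timestamps;
        # None lets Arrow infer the type for every other column.
        return [
            to_arrow_type(TimestampType())
            if is_datetime64_dtype(t) or is_datetime64tz_dtype(t)
            else None
            for t in pdf.dtypes
        ]
    ```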
    
    This PR also fixes a bug in the process: a naive datetime was previously inferred as `TimestampNTZType`, but it is now inferred as `TimestampType`, matching regular PySpark.
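
    For illustration, a hedged example of the new inference behaviour (assuming `spark` is a Spark Connect session):

    ```python
    import datetime
    import pandas as pd

    # A naive (timezone-unaware) datetime column.
    pdf = pd.DataFrame({"ts": [datetime.datetime(2023, 1, 13, 15, 20, 22)]})

    df = spark.createDataFrame(pdf)
    # Previously inferred as TimestampNTZType in Spark Connect;
    # now inferred as TimestampType, matching regular PySpark.
    print(df.schema["ts"].dataType)  # TimestampType()
    ```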
    
    ### Why are the changes needed?
    
    To deduplicate the logic with PySpark so that future changes apply in one place, and to improve maintainability.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No, not for end users.
    It matches the behaviour of the existing PySpark `createDataFrame(pdf)`.
    
    ### How was this patch tested?
    
    Re-enabled a skipped test; the existing test added in the previous PR should cover this.
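
    For example, the re-enabled test can be run on its own (a hypothetical invocation via the standard `unittest` runner, assuming a working PySpark development environment):

    ```python
    import unittest

    from pyspark.sql.tests.connect.test_parity_dataframe import DataFrameParityTests

    suite = unittest.TestSuite()
    suite.addTest(DataFrameParityTests("test_create_dataframe_from_pandas_with_timestamp"))
    unittest.TextTestRunner(verbosity=2).run(suite)
    ```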
    
    Closes #39544 from HyukjinKwon/SPARK-42028-followup.
    
    Authored-by: Hyukjin Kwon <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/sql/connect/session.py              | 112 ++++++++++-----------
 .../sql/tests/connect/test_parity_dataframe.py     |   7 +-
 python/pyspark/sql/tests/test_dataframe.py         |  12 +--
 3 files changed, 60 insertions(+), 71 deletions(-)

diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 76073fc2717..6aec28d70a8 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -16,17 +16,39 @@
 #
 import os
 import warnings
-from distutils.version import LooseVersion
-from threading import RLock
 from collections.abc import Sized
+from distutils.version import LooseVersion
 from functools import reduce
+from threading import RLock
+from typing import (
+    Optional,
+    Any,
+    Union,
+    Dict,
+    List,
+    Tuple,
+    cast,
+    overload,
+    Iterable,
+    TYPE_CHECKING,
+)
 
 import numpy as np
 import pandas as pd
 import pyarrow as pa
+from pandas.api.types import (  # type: ignore[attr-defined]
+    is_datetime64_dtype,
+    is_datetime64tz_dtype,
+)
 
 from pyspark import SparkContext, SparkConf, __version__
 from pyspark.java_gateway import launch_gateway
+from pyspark.sql.connect.client import SparkConnectClient
+from pyspark.sql.connect.dataframe import DataFrame
+from pyspark.sql.connect.plan import SQL, Range, LocalRelation
+from pyspark.sql.connect.readwriter import DataFrameReader
+from pyspark.sql.pandas.serializers import ArrowStreamPandasSerializer
+from pyspark.sql.pandas.types import to_arrow_type, _get_local_timezone
 from pyspark.sql.session import classproperty, SparkSession as PySparkSession
 from pyspark.sql.types import (
     _infer_schema,
@@ -36,28 +58,10 @@ from pyspark.sql.types import (
     DataType,
     StructType,
     AtomicType,
+    TimestampType,
 )
 from pyspark.sql.utils import to_str
 
-from pyspark.sql.connect.client import SparkConnectClient
-from pyspark.sql.connect.dataframe import DataFrame
-from pyspark.sql.connect.plan import SQL, Range, LocalRelation
-from pyspark.sql.connect.readwriter import DataFrameReader
-
-from typing import (
-    Optional,
-    Any,
-    Union,
-    Dict,
-    List,
-    Tuple,
-    cast,
-    overload,
-    Iterable,
-    TYPE_CHECKING,
-)
-
-
 if TYPE_CHECKING:
     from pyspark.sql.connect._typing import OptionalPrimitiveType
     from pyspark.sql.connect.catalog import Catalog
@@ -221,47 +225,37 @@ class SparkSession:
         _inferred_schema: Optional[StructType] = None
 
         if isinstance(data, pd.DataFrame):
-            from pandas.api.types import (  # type: ignore[attr-defined]
-                is_datetime64_dtype,
-                is_datetime64tz_dtype,
-            )
-            from pyspark.sql.pandas.types import (
-                _check_series_convert_timestamps_internal,
-                _get_local_timezone,
+            # Logic was borrowed from `_create_from_pandas_with_arrow` in
+            # `pyspark.sql.pandas.conversion.py`. Should ideally deduplicate the logics.
+
+            # If no schema supplied by user then get the names of columns only
+            if schema is None:
+                _cols = [str(x) if not isinstance(x, str) else x for x in data.columns]
+
+            # Determine arrow types to coerce data when creating batches
+            if isinstance(schema, StructType):
+                arrow_types = [to_arrow_type(f.dataType) for f in schema.fields]
+                _cols = [str(x) if not isinstance(x, str) else x for x in schema.fieldNames()]
+            elif isinstance(schema, DataType):
+                raise ValueError("Single data type %s is not supported with Arrow" % str(schema))
+            else:
+                # Any timestamps must be coerced to be compatible with Spark
+                arrow_types = [
+                    to_arrow_type(TimestampType())
+                    if is_datetime64_dtype(t) or is_datetime64tz_dtype(t)
+                    else None
+                    for t in data.dtypes
+                ]
+
+            ser = ArrowStreamPandasSerializer(
+                _get_local_timezone(),  # 'spark.session.timezone' should be respected
+                False,  # 'spark.sql.execution.pandas.convertToArrowArraySafely' should be respected
+                True,
             )
 
-            # First, check if we need to create a copy of the input data to adjust
-            # the timestamps.
-            input_data = data
-            has_timestamp_data = any(
-                [is_datetime64_dtype(data[c]) or is_datetime64tz_dtype(data[c]) for c in data]
+            _table = pa.Table.from_batches(
+                [ser._create_batch([(c, t) for (_, c), t in zip(data.items(), arrow_types)])]
             )
-            if has_timestamp_data:
-                input_data = data.copy()
-                # We need double conversions for the truncation, first truncate to microseconds.
-                for col in input_data:
-                    if is_datetime64tz_dtype(input_data[col].dtype):
-                        input_data[col] = _check_series_convert_timestamps_internal(
-                            input_data[col], _get_local_timezone()
-                        ).astype("datetime64[us, UTC]")
-                    elif is_datetime64_dtype(input_data[col].dtype):
-                        input_data[col] = input_data[col].astype("datetime64[us]")
-
-                # Create a new schema and change the types to the truncated microseconds.
-                pd_schema = pa.Schema.from_pandas(input_data)
-                new_schema = pa.schema([])
-                for x in range(len(pd_schema.types)):
-                    f = pd_schema.field(x)
-                    # TODO(SPARK-42027) Add support for struct types.
-                    if isinstance(f.type, pa.TimestampType) and f.type.unit == "ns":
-                        tmp = f.with_type(pa.timestamp("us"))
-                        new_schema = new_schema.append(tmp)
-                    else:
-                        new_schema = new_schema.append(f)
-                new_schema = new_schema.with_metadata(pd_schema.metadata)
-                _table = pa.Table.from_pandas(input_data, schema=new_schema)
-            else:
-                _table = pa.Table.from_pandas(data)
 
         elif isinstance(data, np.ndarray):
             if data.ndim not in [1, 2]:
diff --git a/python/pyspark/sql/tests/connect/test_parity_dataframe.py b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
index 61fd72b6bbf..c722f4693e4 100644
--- a/python/pyspark/sql/tests/connect/test_parity_dataframe.py
+++ b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
@@ -37,16 +37,11 @@ class DataFrameParityTests(DataFrameTestsMixin, ReusedConnectTestCase):
     def test_create_dataframe_from_pandas_with_day_time_interval(self):
         super().test_create_dataframe_from_pandas_with_day_time_interval()
 
-    # TODO(SPARK-41842): Support data type Timestamp(NANOSECOND, null)
+    # TODO(SPARK-41834): Implement SparkSession.conf
     @unittest.skip("Fails in Spark Connect, should enable.")
     def test_create_dataframe_from_pandas_with_dst(self):
         super().test_create_dataframe_from_pandas_with_dst()
 
-    # TODO(SPARK-41842): Support data type Timestamp(NANOSECOND, null)
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_create_dataframe_from_pandas_with_timestamp(self):
-        super().test_create_dataframe_from_pandas_with_timestamp()
-
     # TODO(SPARK-41855): createDataFrame doesn't handle None/NaN properly
     @unittest.skip("Fails in Spark Connect, should enable.")
     def test_create_nan_decimal_dataframe(self):
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index c67f43ecb64..4ab16f231c0 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -1246,15 +1246,15 @@ class DataFrameTestsMixin:
         )
         # test types are inferred correctly without specifying schema
         df = self.spark.createDataFrame(pdf)
-        self.assertTrue(isinstance(df.schema["ts"].dataType, TimestampType))
-        self.assertTrue(isinstance(df.schema["d"].dataType, DateType))
+        self.assertIsInstance(df.schema["ts"].dataType, TimestampType)
+        self.assertIsInstance(df.schema["d"].dataType, DateType)
         # test with schema will accept pdf as input
         df = self.spark.createDataFrame(pdf, schema="d date, ts timestamp")
-        self.assertTrue(isinstance(df.schema["ts"].dataType, TimestampType))
-        self.assertTrue(isinstance(df.schema["d"].dataType, DateType))
+        self.assertIsInstance(df.schema["ts"].dataType, TimestampType)
+        self.assertIsInstance(df.schema["d"].dataType, DateType)
         df = self.spark.createDataFrame(pdf, schema="d date, ts timestamp_ntz")
-        self.assertTrue(isinstance(df.schema["ts"].dataType, TimestampNTZType))
-        self.assertTrue(isinstance(df.schema["d"].dataType, DateType))
+        self.assertIsInstance(df.schema["ts"].dataType, TimestampNTZType)
+        self.assertIsInstance(df.schema["d"].dataType, DateType)
 
     @unittest.skipIf(have_pandas, "Required Pandas was found.")
     def test_create_dataframe_required_pandas_not_found(self):

