jorisvandenbossche commented on code in PR #46529:
URL: https://github.com/apache/spark/pull/46529#discussion_r1609953958
##########
python/pyspark/sql/connect/session.py:
##########
@@ -413,11 +416,14 @@ def _inferSchemaFromList(
def createDataFrame(
self,
- data: Union["pd.DataFrame", "np.ndarray", Iterable[Any]],
+ data: Union["pd.DataFrame", "np.ndarray", "pa.Table", Iterable[Any]],
schema: Optional[Union[AtomicType, StructType, str, List[str],
Tuple[str, ...]]] = None,
samplingRatio: Optional[float] = None,
verifySchema: Optional[bool] = None,
) -> "ParentDataFrame":
+ import pandas as pd
+ import pyarrow as pa
+
Review Comment:
```suggestion
```
Aren't those packages already imported at the top level of this file?
##########
python/pyspark/sql/pandas/types.py:
##########
@@ -232,6 +312,124 @@ def _get_local_timezone() -> str:
return os.environ.get("TZ", "dateutil/:")
+def _check_arrow_table_timestamps_localize(
+ table: "pa.Table", schema: StructType, truncate: bool = True, timezone:
Optional[str] = None
+) -> "pa.Table":
+ """
+ Convert timestamps in a PyArrow Table to timezone-naive in the specified
timezone if the
+ corresponding Spark data type is TimestampType in the specified Spark
schema is TimestampType,
+ and optionally truncate nanosecond timestamps to microseconds.
+
+ Parameters
+ ----------
+ table : :class:`pyarrow.Table`
+ schema : :class:`StructType`
+ The Spark schema corresponding to the schema of the Arrow Table.
+ truncate : bool, default True
+ Whether to truncate nanosecond timestamps to microseconds. (default
``True``)
+ timezone : str, optional
+ The timezone to convert from. If there is a timestamp type, it's
required.
+
+ Returns
+ -------
+ :class:`pyarrow.Table`
+ """
+ import pyarrow as pa
+
+ assert len(table.schema) == len(schema.fields)
+
+ return pa.Table.from_arrays(
+ [
+ _check_arrow_array_timestamps_localize(a, f.dataType, truncate,
timezone)
+ for a, f in zip(table.columns, schema.fields)
+ ],
+ schema=table.schema,
+ )
+
+
+def _check_arrow_array_timestamps_localize(
Review Comment:
Generally looks good to me. But seeing this code reminds me that we
should really have some kernels to apply a scalar function to all values in a
nested array, such as a list array. That would only clean up the list case
in the function below (not the struct and dict cases), though, so I'm not sure
whether we want something more generic, like a visitor that visits each child
array and applies some function to it.
##########
python/pyspark/sql/pandas/types.py:
##########
@@ -232,6 +312,124 @@ def _get_local_timezone() -> str:
return os.environ.get("TZ", "dateutil/:")
+def _check_arrow_array_timestamps_localize(
+ a: Union["pa.Array", "pa.ChunkedArray"],
+ dt: DataType,
+ truncate: bool = True,
+ timezone: Optional[str] = None,
+) -> Union["pa.Array", "pa.ChunkedArray"]:
+ """
+ Convert Arrow timestamps to timezone-naive in the specified timezone if
the specified Spark
+ data type is TimestampType, and optionally truncate nanosecond timestamps
to microseconds.
+
+ This function works on Arrow Arrays and ChunkedArrays, and it recurses to
convert nested
+ timestamps.
+
+ Parameters
+ ----------
+ a : :class:`pyarrow.Array` or :class:`pyarrow.ChunkedArray`
+ dt : :class:`DataType`
+ The Spark data type corresponding to the Arrow Array to be converted.
+ truncate : bool, default True
+ Whether to truncate nanosecond timestamps to microseconds. (default
``True``)
+ timezone : str, optional
+ The timezone to convert from. If there is a timestamp type, it's
required.
+
+ Returns
+ -------
+ :class:`pyarrow.Array` or :class:`pyarrow.ChunkedArray`
+ """
+ import pyarrow.types as types
+ import pyarrow as pa
+ import pyarrow.compute as pc
+
+ if isinstance(a, pa.ChunkedArray) and (types.is_nested(a.type) or
types.is_dictionary(a.type)):
+ return pa.chunked_array(
+ [
+ _check_arrow_array_timestamps_localize(chunk, dt, truncate,
timezone)
+ for chunk in a.iterchunks()
+ ]
+ )
+
+ if types.is_timestamp(a.type) and truncate and a.type.unit == "ns":
+ a = pc.floor_temporal(a, unit="microsecond")
+
+ if types.is_timestamp(a.type) and a.type.tz is None and type(dt) ==
TimestampType:
+ assert timezone is not None
+
+ # Only localize timestamps that will become Spark TimestampType
columns.
+ # Do not localize timestamps that will become Spark TimestampNTZType
columns.
+ return pc.assume_timezone(a, timezone)
+ if types.is_list(a.type):
+ at: ArrayType = cast(ArrayType, dt)
+ return pa.ListArray.from_arrays(
+ a.offsets,
+ _check_arrow_array_timestamps_localize(a.values, at.elementType,
truncate, timezone),
+ )
Review Comment:
Also, won't this fail to preserve nulls? (This is similar to the issue you raised for
the map type, although ListArray already has a `mask` keyword. We should also
add something like https://github.com/apache/arrow/issues/23380 to simply apply
an existing validity bitmap buffer.)
##########
python/pyspark/sql/pandas/types.py:
##########
@@ -232,6 +312,124 @@ def _get_local_timezone() -> str:
return os.environ.get("TZ", "dateutil/:")
+def _check_arrow_array_timestamps_localize(
+ a: Union["pa.Array", "pa.ChunkedArray"],
+ dt: DataType,
+ truncate: bool = True,
+ timezone: Optional[str] = None,
+) -> Union["pa.Array", "pa.ChunkedArray"]:
+ """
+ Convert Arrow timestamps to timezone-naive in the specified timezone if
the specified Spark
+ data type is TimestampType, and optionally truncate nanosecond timestamps
to microseconds.
+
+ This function works on Arrow Arrays and ChunkedArrays, and it recurses to
convert nested
+ timestamps.
+
+ Parameters
+ ----------
+ a : :class:`pyarrow.Array` or :class:`pyarrow.ChunkedArray`
+ dt : :class:`DataType`
+ The Spark data type corresponding to the Arrow Array to be converted.
+ truncate : bool, default True
+ Whether to truncate nanosecond timestamps to microseconds. (default
``True``)
+ timezone : str, optional
+ The timezone to convert from. If there is a timestamp type, it's
required.
+
+ Returns
+ -------
+ :class:`pyarrow.Array` or :class:`pyarrow.ChunkedArray`
+ """
+ import pyarrow.types as types
+ import pyarrow as pa
+ import pyarrow.compute as pc
+
+ if isinstance(a, pa.ChunkedArray) and (types.is_nested(a.type) or
types.is_dictionary(a.type)):
+ return pa.chunked_array(
+ [
+ _check_arrow_array_timestamps_localize(chunk, dt, truncate,
timezone)
+ for chunk in a.iterchunks()
+ ]
+ )
+
+ if types.is_timestamp(a.type) and truncate and a.type.unit == "ns":
+ a = pc.floor_temporal(a, unit="microsecond")
+
+ if types.is_timestamp(a.type) and a.type.tz is None and type(dt) ==
TimestampType:
+ assert timezone is not None
+
+ # Only localize timestamps that will become Spark TimestampType
columns.
+ # Do not localize timestamps that will become Spark TimestampNTZType
columns.
+ return pc.assume_timezone(a, timezone)
+ if types.is_list(a.type):
+ at: ArrayType = cast(ArrayType, dt)
+ return pa.ListArray.from_arrays(
+ a.offsets,
+ _check_arrow_array_timestamps_localize(a.values, at.elementType,
truncate, timezone),
+ )
+ if types.is_map(a.type):
+ mt: MapType = cast(MapType, dt)
+ # TODO(SPARK-48302): Do not replace nulls in MapArray with empty lists
+ return pa.MapArray.from_arrays(
+ a.offsets,
+ _check_arrow_array_timestamps_localize(a.keys, mt.keyType,
truncate, timezone),
+ _check_arrow_array_timestamps_localize(a.items, mt.valueType,
truncate, timezone),
+ )
+ if types.is_struct(a.type):
Review Comment:
Similar comment here: a StructArray can have top-level nulls (a validity
bitmap) that will be lost here (though I don't know whether Spark supports this).
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]