This is an automated email from the ASF dual-hosted git repository.
honahx pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git
The following commit(s) were added to refs/heads/main by this push:
new 32e8f88e support PyArrow timestamptz with Etc/UTC (#910)
32e8f88e is described below
commit 32e8f88ebf8e45ae0a7f60a848ea44044a9564ef
Author: Sung Yun <[email protected]>
AuthorDate: Fri Jul 12 15:26:00 2024 -0400
support PyArrow timestamptz with Etc/UTC (#910)
Co-authored-by: Fokko Driesprong <[email protected]>
---
pyiceberg/io/pyarrow.py | 51 ++++++---
pyiceberg/table/__init__.py | 8 --
tests/conftest.py | 116 ++++++++++++++++++++-
tests/integration/test_add_files.py | 1 +
.../test_writes/test_partitioned_writes.py | 14 +--
tests/integration/test_writes/test_writes.py | 81 +++++---------
tests/io/test_pyarrow.py | 33 ++++++
7 files changed, 218 insertions(+), 86 deletions(-)
diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index 91745a58..1ef9fc9b 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -174,6 +174,7 @@ LIST_ELEMENT_NAME = "element"
MAP_KEY_NAME = "key"
MAP_VALUE_NAME = "value"
DOC = "doc"
+UTC_ALIASES = {"UTC", "+00:00", "Etc/UTC", "Z"}
T = TypeVar("T")
@@ -937,7 +938,7 @@ class
_ConvertToIceberg(PyArrowSchemaVisitor[Union[IcebergType, Schema]]):
else:
raise TypeError(f"Unsupported precision for timestamp type:
{primitive.unit}")
- if primitive.tz == "UTC" or primitive.tz == "+00:00":
+ if primitive.tz in UTC_ALIASES:
return TimestamptzType()
elif primitive.tz is None:
return TimestampType()
@@ -1073,7 +1074,7 @@ def _task_to_record_batches(
arrow_table = pa.Table.from_batches([batch])
arrow_table = arrow_table.filter(pyarrow_filter)
batch = arrow_table.to_batches()[0]
- yield to_requested_schema(projected_schema, file_project_schema,
batch, downcast_ns_timestamp_to_us=True)
+ yield _to_requested_schema(projected_schema, file_project_schema,
batch, downcast_ns_timestamp_to_us=True)
current_index += len(batch)
@@ -1278,7 +1279,7 @@ def project_batches(
total_row_count += len(batch)
-def to_requested_schema(
+def _to_requested_schema(
requested_schema: Schema,
file_schema: Schema,
batch: pa.RecordBatch,
@@ -1296,16 +1297,17 @@ def to_requested_schema(
class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array,
Optional[pa.Array]]):
- file_schema: Schema
+ _file_schema: Schema
_include_field_ids: bool
+ _downcast_ns_timestamp_to_us: bool
def __init__(self, file_schema: Schema, downcast_ns_timestamp_to_us: bool
= False, include_field_ids: bool = False) -> None:
- self.file_schema = file_schema
+ self._file_schema = file_schema
self._include_field_ids = include_field_ids
- self.downcast_ns_timestamp_to_us = downcast_ns_timestamp_to_us
+ self._downcast_ns_timestamp_to_us = downcast_ns_timestamp_to_us
def _cast_if_needed(self, field: NestedField, values: pa.Array) ->
pa.Array:
- file_field = self.file_schema.find_field(field.field_id)
+ file_field = self._file_schema.find_field(field.field_id)
if field.field_type.is_primitive:
if field.field_type != file_field.field_type:
@@ -1313,14 +1315,31 @@ class
ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, Optional[pa.Arra
schema_to_pyarrow(promote(file_field.field_type,
field.field_type), include_field_ids=self._include_field_ids)
)
elif (target_type := schema_to_pyarrow(field.field_type,
include_field_ids=self._include_field_ids)) != values.type:
- # Downcasting of nanoseconds to microseconds
- if (
- pa.types.is_timestamp(target_type)
- and target_type.unit == "us"
- and pa.types.is_timestamp(values.type)
- and values.type.unit == "ns"
- ):
- return values.cast(target_type, safe=False)
+ if field.field_type == TimestampType():
+ # Downcasting of nanoseconds to microseconds
+ if (
+ pa.types.is_timestamp(target_type)
+ and not target_type.tz
+ and pa.types.is_timestamp(values.type)
+ and not values.type.tz
+ ):
+ if target_type.unit == "us" and values.type.unit ==
"ns" and self._downcast_ns_timestamp_to_us:
+ return values.cast(target_type, safe=False)
+ elif target_type.unit == "us" and values.type.unit in
{"s", "ms"}:
+ return values.cast(target_type)
+ raise ValueError(f"Unsupported schema projection from
{values.type} to {target_type}")
+ elif field.field_type == TimestamptzType():
+ if (
+ pa.types.is_timestamp(target_type)
+ and target_type.tz == "UTC"
+ and pa.types.is_timestamp(values.type)
+ and values.type.tz in UTC_ALIASES
+ ):
+ if target_type.unit == "us" and values.type.unit ==
"ns" and self._downcast_ns_timestamp_to_us:
+ return values.cast(target_type, safe=False)
+ elif target_type.unit == "us" and values.type.unit in
{"s", "ms", "us"}:
+ return values.cast(target_type)
+ raise ValueError(f"Unsupported schema projection from
{values.type} to {target_type}")
return values
def _construct_field(self, field: NestedField, arrow_type: pa.DataType) ->
pa.Field:
@@ -1970,7 +1989,7 @@ def write_file(io: FileIO, table_metadata: TableMetadata,
tasks: Iterator[WriteT
downcast_ns_timestamp_to_us =
Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
batches = [
- to_requested_schema(
+ _to_requested_schema(
requested_schema=file_schema,
file_schema=table_schema,
batch=batch,
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index 62440c47..b43dc320 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -484,10 +484,6 @@ class Transaction:
_check_schema_compatible(
self._table.schema(), other_schema=df.schema,
downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
)
- # cast if the two schemas are compatible but not equal
- table_arrow_schema = self._table.schema().as_arrow()
- if table_arrow_schema != df.schema:
- df = df.cast(table_arrow_schema)
manifest_merge_enabled = PropertyUtil.property_as_bool(
self.table_metadata.properties,
@@ -545,10 +541,6 @@ class Transaction:
_check_schema_compatible(
self._table.schema(), other_schema=df.schema,
downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
)
- # cast if the two schemas are compatible but not equal
- table_arrow_schema = self._table.schema().as_arrow()
- if table_arrow_schema != df.schema:
- df = df.cast(table_arrow_schema)
self.delete(delete_filter=overwrite_filter,
snapshot_properties=snapshot_properties)
diff --git a/tests/conftest.py b/tests/conftest.py
index 95e1128a..6b1a2b43 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -2382,10 +2382,122 @@ def arrow_table_date_timestamps() -> "pa.Table":
@pytest.fixture(scope="session")
-def arrow_table_date_timestamps_schema() -> Schema:
- """Pyarrow table Schema with only date, timestamp and timestamptz
values."""
+def table_date_timestamps_schema() -> Schema:
+ """Iceberg table Schema with only date, timestamp and timestamptz
values."""
return Schema(
NestedField(field_id=1, name="date", field_type=DateType(),
required=False),
NestedField(field_id=2, name="timestamp", field_type=TimestampType(),
required=False),
NestedField(field_id=3, name="timestamptz",
field_type=TimestamptzType(), required=False),
)
+
+
+@pytest.fixture(scope="session")
+def arrow_table_schema_with_all_timestamp_precisions() -> "pa.Schema":
+ """Pyarrow Schema with all supported timestamp types."""
+ import pyarrow as pa
+
+ return pa.schema([
+ ("timestamp_s", pa.timestamp(unit="s")),
+ ("timestamptz_s", pa.timestamp(unit="s", tz="UTC")),
+ ("timestamp_ms", pa.timestamp(unit="ms")),
+ ("timestamptz_ms", pa.timestamp(unit="ms", tz="UTC")),
+ ("timestamp_us", pa.timestamp(unit="us")),
+ ("timestamptz_us", pa.timestamp(unit="us", tz="UTC")),
+ ("timestamp_ns", pa.timestamp(unit="ns")),
+ ("timestamptz_ns", pa.timestamp(unit="ns", tz="UTC")),
+ ("timestamptz_us_etc_utc", pa.timestamp(unit="us", tz="Etc/UTC")),
+ ("timestamptz_ns_z", pa.timestamp(unit="ns", tz="Z")),
+ ("timestamptz_s_0000", pa.timestamp(unit="s", tz="+00:00")),
+ ])
+
+
+@pytest.fixture(scope="session")
+def
arrow_table_with_all_timestamp_precisions(arrow_table_schema_with_all_timestamp_precisions:
"pa.Schema") -> "pa.Table":
+ """Pyarrow table with all supported timestamp types."""
+ import pandas as pd
+ import pyarrow as pa
+
+ test_data = pd.DataFrame({
+ "timestamp_s": [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023,
3, 1, 19, 25, 00)],
+ "timestamptz_s": [
+ datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
+ None,
+ datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
+ ],
+ "timestamp_ms": [datetime(2023, 1, 1, 19, 25, 00), None,
datetime(2023, 3, 1, 19, 25, 00)],
+ "timestamptz_ms": [
+ datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
+ None,
+ datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
+ ],
+ "timestamp_us": [datetime(2023, 1, 1, 19, 25, 00), None,
datetime(2023, 3, 1, 19, 25, 00)],
+ "timestamptz_us": [
+ datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
+ None,
+ datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
+ ],
+ "timestamp_ns": [
+ pd.Timestamp(year=2024, month=7, day=11, hour=3, minute=30,
second=0, microsecond=12, nanosecond=6),
+ None,
+ pd.Timestamp(year=2024, month=7, day=11, hour=3, minute=30,
second=0, microsecond=12, nanosecond=7),
+ ],
+ "timestamptz_ns": [
+ datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
+ None,
+ datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
+ ],
+ "timestamptz_us_etc_utc": [
+ datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
+ None,
+ datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
+ ],
+ "timestamptz_ns_z": [
+ pd.Timestamp(year=2024, month=7, day=11, hour=3, minute=30,
second=0, microsecond=12, nanosecond=6, tz="UTC"),
+ None,
+ pd.Timestamp(year=2024, month=7, day=11, hour=3, minute=30,
second=0, microsecond=12, nanosecond=7, tz="UTC"),
+ ],
+ "timestamptz_s_0000": [
+ datetime(2023, 1, 1, 19, 25, 1, tzinfo=timezone.utc),
+ None,
+ datetime(2023, 3, 1, 19, 25, 1, tzinfo=timezone.utc),
+ ],
+ })
+ return pa.Table.from_pandas(test_data,
schema=arrow_table_schema_with_all_timestamp_precisions)
+
+
+@pytest.fixture(scope="session")
+def arrow_table_schema_with_all_microseconds_timestamp_precisions() ->
"pa.Schema":
+ """Pyarrow Schema with all microseconds timestamp."""
+ import pyarrow as pa
+
+ return pa.schema([
+ ("timestamp_s", pa.timestamp(unit="us")),
+ ("timestamptz_s", pa.timestamp(unit="us", tz="UTC")),
+ ("timestamp_ms", pa.timestamp(unit="us")),
+ ("timestamptz_ms", pa.timestamp(unit="us", tz="UTC")),
+ ("timestamp_us", pa.timestamp(unit="us")),
+ ("timestamptz_us", pa.timestamp(unit="us", tz="UTC")),
+ ("timestamp_ns", pa.timestamp(unit="us")),
+ ("timestamptz_ns", pa.timestamp(unit="us", tz="UTC")),
+ ("timestamptz_us_etc_utc", pa.timestamp(unit="us", tz="UTC")),
+ ("timestamptz_ns_z", pa.timestamp(unit="us", tz="UTC")),
+ ("timestamptz_s_0000", pa.timestamp(unit="us", tz="UTC")),
+ ])
+
+
+@pytest.fixture(scope="session")
+def table_schema_with_all_microseconds_timestamp_precision() -> Schema:
+ """Iceberg table Schema with only date, timestamp and timestamptz
values."""
+ return Schema(
+ NestedField(field_id=1, name="timestamp_s",
field_type=TimestampType(), required=False),
+ NestedField(field_id=2, name="timestamptz_s",
field_type=TimestamptzType(), required=False),
+ NestedField(field_id=3, name="timestamp_ms",
field_type=TimestampType(), required=False),
+ NestedField(field_id=4, name="timestamptz_ms",
field_type=TimestamptzType(), required=False),
+ NestedField(field_id=5, name="timestamp_us",
field_type=TimestampType(), required=False),
+ NestedField(field_id=6, name="timestamptz_us",
field_type=TimestamptzType(), required=False),
+ NestedField(field_id=7, name="timestamp_ns",
field_type=TimestampType(), required=False),
+ NestedField(field_id=8, name="timestamptz_ns",
field_type=TimestamptzType(), required=False),
+ NestedField(field_id=9, name="timestamptz_us_etc_utc",
field_type=TimestamptzType(), required=False),
+ NestedField(field_id=10, name="timestamptz_ns_z",
field_type=TimestamptzType(), required=False),
+ NestedField(field_id=11, name="timestamptz_s_0000",
field_type=TimestamptzType(), required=False),
+ )
diff --git a/tests/integration/test_add_files.py
b/tests/integration/test_add_files.py
index 984c7d11..b8fd6d09 100644
--- a/tests/integration/test_add_files.py
+++ b/tests/integration/test_add_files.py
@@ -570,6 +570,7 @@ def test_add_files_with_large_and_regular_schema(spark:
SparkSession, session_ca
assert table_schema == arrow_schema_large
+@pytest.mark.integration
def test_add_files_with_timestamp_tz_ns_fails(session_catalog: Catalog,
format_version: int, mocker: MockerFixture) -> None:
nanoseconds_schema_iceberg = Schema(NestedField(1, "quux",
TimestamptzType()))
diff --git a/tests/integration/test_writes/test_partitioned_writes.py
b/tests/integration/test_writes/test_partitioned_writes.py
index 12da9c92..b199f002 100644
--- a/tests/integration/test_writes/test_partitioned_writes.py
+++ b/tests/integration/test_writes/test_partitioned_writes.py
@@ -461,7 +461,7 @@ def test_append_transform_partition_verify_partitions_count(
session_catalog: Catalog,
spark: SparkSession,
arrow_table_date_timestamps: pa.Table,
- arrow_table_date_timestamps_schema: Schema,
+ table_date_timestamps_schema: Schema,
transform: Transform[Any, Any],
expected_partitions: Set[Any],
format_version: int,
@@ -469,7 +469,7 @@ def test_append_transform_partition_verify_partitions_count(
# Given
part_col = "timestamptz"
identifier =
f"default.arrow_table_v{format_version}_with_{str(transform)}_transform_partitioned_on_col_{part_col}"
- nested_field = arrow_table_date_timestamps_schema.find_field(part_col)
+ nested_field = table_date_timestamps_schema.find_field(part_col)
partition_spec = PartitionSpec(
PartitionField(source_id=nested_field.field_id, field_id=1001,
transform=transform, name=part_col),
)
@@ -481,7 +481,7 @@ def test_append_transform_partition_verify_partitions_count(
properties={"format-version": str(format_version)},
data=[arrow_table_date_timestamps],
partition_spec=partition_spec,
- schema=arrow_table_date_timestamps_schema,
+ schema=table_date_timestamps_schema,
)
# Then
@@ -510,20 +510,20 @@ def test_append_multiple_partitions(
session_catalog: Catalog,
spark: SparkSession,
arrow_table_date_timestamps: pa.Table,
- arrow_table_date_timestamps_schema: Schema,
+ table_date_timestamps_schema: Schema,
format_version: int,
) -> None:
# Given
identifier =
f"default.arrow_table_v{format_version}_with_multiple_partitions"
partition_spec = PartitionSpec(
PartitionField(
-
source_id=arrow_table_date_timestamps_schema.find_field("date").field_id,
+ source_id=table_date_timestamps_schema.find_field("date").field_id,
field_id=1001,
transform=YearTransform(),
name="date_year",
),
PartitionField(
-
source_id=arrow_table_date_timestamps_schema.find_field("timestamptz").field_id,
+
source_id=table_date_timestamps_schema.find_field("timestamptz").field_id,
field_id=1000,
transform=HourTransform(),
name="timestamptz_hour",
@@ -537,7 +537,7 @@ def test_append_multiple_partitions(
properties={"format-version": str(format_version)},
data=[arrow_table_date_timestamps],
partition_spec=partition_spec,
- schema=arrow_table_date_timestamps_schema,
+ schema=table_date_timestamps_schema,
)
# Then
diff --git a/tests/integration/test_writes/test_writes.py
b/tests/integration/test_writes/test_writes.py
index af626718..41bc6fb5 100644
--- a/tests/integration/test_writes/test_writes.py
+++ b/tests/integration/test_writes/test_writes.py
@@ -18,11 +18,12 @@
import math
import os
import time
-from datetime import date, datetime, timezone
+from datetime import date, datetime
from pathlib import Path
from typing import Any, Dict
from urllib.parse import urlparse
+import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import pytest
@@ -977,69 +978,43 @@ def table_write_subset_of_schema(session_catalog:
Catalog, arrow_table_with_null
@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
-def test_write_all_timestamp_precision(mocker: MockerFixture, session_catalog:
Catalog, format_version: int) -> None:
+def test_write_all_timestamp_precision(
+ mocker: MockerFixture,
+ spark: SparkSession,
+ session_catalog: Catalog,
+ format_version: int,
+ arrow_table_schema_with_all_timestamp_precisions: pa.Schema,
+ arrow_table_with_all_timestamp_precisions: pa.Table,
+ arrow_table_schema_with_all_microseconds_timestamp_precisions: pa.Schema,
+) -> None:
identifier = "default.table_all_timestamp_precision"
- arrow_table_schema_with_all_timestamp_precisions = pa.schema([
- ("timestamp_s", pa.timestamp(unit="s")),
- ("timestamptz_s", pa.timestamp(unit="s", tz="UTC")),
- ("timestamp_ms", pa.timestamp(unit="ms")),
- ("timestamptz_ms", pa.timestamp(unit="ms", tz="UTC")),
- ("timestamp_us", pa.timestamp(unit="us")),
- ("timestamptz_us", pa.timestamp(unit="us", tz="UTC")),
- ("timestamp_ns", pa.timestamp(unit="ns")),
- ("timestamptz_ns", pa.timestamp(unit="ns", tz="UTC")),
- ])
- TEST_DATA_WITH_NULL = {
- "timestamp_s": [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023,
3, 1, 19, 25, 00)],
- "timestamptz_s": [
- datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
- None,
- datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
- ],
- "timestamp_ms": [datetime(2023, 1, 1, 19, 25, 00), None,
datetime(2023, 3, 1, 19, 25, 00)],
- "timestamptz_ms": [
- datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
- None,
- datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
- ],
- "timestamp_us": [datetime(2023, 1, 1, 19, 25, 00), None,
datetime(2023, 3, 1, 19, 25, 00)],
- "timestamptz_us": [
- datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
- None,
- datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
- ],
- "timestamp_ns": [datetime(2023, 1, 1, 19, 25, 00), None,
datetime(2023, 3, 1, 19, 25, 00)],
- "timestamptz_ns": [
- datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
- None,
- datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
- ],
- }
- input_arrow_table = pa.Table.from_pydict(TEST_DATA_WITH_NULL,
schema=arrow_table_schema_with_all_timestamp_precisions)
mocker.patch.dict(os.environ,
values={"PYICEBERG_DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE": "True"})
tbl = _create_table(
session_catalog,
identifier,
{"format-version": format_version},
- data=[input_arrow_table],
+ data=[arrow_table_with_all_timestamp_precisions],
schema=arrow_table_schema_with_all_timestamp_precisions,
)
- tbl.overwrite(input_arrow_table)
+ tbl.overwrite(arrow_table_with_all_timestamp_precisions)
written_arrow_table = tbl.scan().to_arrow()
- expected_schema_in_all_us = pa.schema([
- ("timestamp_s", pa.timestamp(unit="us")),
- ("timestamptz_s", pa.timestamp(unit="us", tz="UTC")),
- ("timestamp_ms", pa.timestamp(unit="us")),
- ("timestamptz_ms", pa.timestamp(unit="us", tz="UTC")),
- ("timestamp_us", pa.timestamp(unit="us")),
- ("timestamptz_us", pa.timestamp(unit="us", tz="UTC")),
- ("timestamp_ns", pa.timestamp(unit="us")),
- ("timestamptz_ns", pa.timestamp(unit="us", tz="UTC")),
- ])
- assert written_arrow_table.schema == expected_schema_in_all_us
- assert written_arrow_table ==
input_arrow_table.cast(expected_schema_in_all_us)
+ assert written_arrow_table.schema ==
arrow_table_schema_with_all_microseconds_timestamp_precisions
+ assert written_arrow_table ==
arrow_table_with_all_timestamp_precisions.cast(
+ arrow_table_schema_with_all_microseconds_timestamp_precisions,
safe=False
+ )
+ lhs = spark.table(f"{identifier}").toPandas()
+ rhs = written_arrow_table.to_pandas()
+
+ for column in written_arrow_table.column_names:
+ for left, right in zip(lhs[column].to_list(), rhs[column].to_list()):
+ if pd.isnull(left):
+ assert pd.isnull(right)
+ else:
+ # Check only upto microsecond precision since Spark loaded
dtype is timezone unaware
+ # and supports upto microsecond precision
+ assert left.timestamp() == right.timestamp(), f"Difference in
column {column}: {left} != {right}"
@pytest.mark.integration
diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py
index 326eeff1..37198b7e 100644
--- a/tests/io/test_pyarrow.py
+++ b/tests/io/test_pyarrow.py
@@ -65,6 +65,7 @@ from pyiceberg.io.pyarrow import (
_determine_partitions,
_primitive_to_physical,
_read_deletes,
+ _to_requested_schema,
bin_pack_arrow_table,
expression_to_pyarrow,
project_table,
@@ -1889,3 +1890,35 @@ def test_identity_partition_on_multi_columns() -> None:
("n_legs", "ascending"),
("animal", "ascending"),
]) == arrow_table.sort_by([("born_year", "ascending"), ("n_legs",
"ascending"), ("animal", "ascending")])
+
+
+def test__to_requested_schema_timestamps(
+ arrow_table_schema_with_all_timestamp_precisions: pa.Schema,
+ arrow_table_with_all_timestamp_precisions: pa.Table,
+ arrow_table_schema_with_all_microseconds_timestamp_precisions: pa.Schema,
+ table_schema_with_all_microseconds_timestamp_precision: Schema,
+) -> None:
+ requested_schema = table_schema_with_all_microseconds_timestamp_precision
+ file_schema = requested_schema
+ batch = arrow_table_with_all_timestamp_precisions.to_batches()[0]
+ result = _to_requested_schema(requested_schema, file_schema, batch,
downcast_ns_timestamp_to_us=True, include_field_ids=False)
+
+ expected = arrow_table_with_all_timestamp_precisions.cast(
+ arrow_table_schema_with_all_microseconds_timestamp_precisions,
safe=False
+ ).to_batches()[0]
+ assert result == expected
+
+
+def test__to_requested_schema_timestamps_without_downcast_raises_exception(
+ arrow_table_schema_with_all_timestamp_precisions: pa.Schema,
+ arrow_table_with_all_timestamp_precisions: pa.Table,
+ arrow_table_schema_with_all_microseconds_timestamp_precisions: pa.Schema,
+ table_schema_with_all_microseconds_timestamp_precision: Schema,
+) -> None:
+ requested_schema = table_schema_with_all_microseconds_timestamp_precision
+ file_schema = requested_schema
+ batch = arrow_table_with_all_timestamp_precisions.to_batches()[0]
+ with pytest.raises(ValueError) as exc_info:
+ _to_requested_schema(requested_schema, file_schema, batch,
downcast_ns_timestamp_to_us=False, include_field_ids=False)
+
+ assert "Unsupported schema projection from timestamp[ns] to timestamp[us]"
in str(exc_info.value)