This is an automated email from the ASF dual-hosted git repository.
fokko pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git
The following commit(s) were added to refs/heads/main by this push:
new 8789fc2 Bug Fix: Handle null vals from later row_groups in StatsAggregator (#379)
8789fc2 is described below
commit 8789fc287755c8e617513f7d136b8cccc4800508
Author: Sung Yun <[email protected]>
AuthorDate: Tue Feb 6 10:26:56 2024 -0500
Bug Fix: Handle null vals from later row_groups in StatsAggregator (#379)
* handle null vals correctly in StatsAggregator
* return
* Apply suggestions from code review
Co-authored-by: Fokko Driesprong <[email protected]>
---------
Co-authored-by: Fokko Driesprong <[email protected]>
---
pyiceberg/io/pyarrow.py | 14 ++++++++++----
tests/io/test_pyarrow.py | 44 ++++++++++++++++++++++++++++++++++++++++++++
2 files changed, 54 insertions(+), 4 deletions(-)
diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index 99c1af5..9726451 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -1331,11 +1331,17 @@ class StatsAggregator:
def serialize(self, value: Any) -> bytes:
return to_bytes(self.primitive_type, value)
- def update_min(self, val: Any) -> None:
- self.current_min = val if self.current_min is None else min(val, self.current_min)
+ def update_min(self, val: Optional[Any]) -> None:
+ if self.current_min is None:
+ self.current_min = val
+ elif val is not None:
+ self.current_min = min(val, self.current_min)
- def update_max(self, val: Any) -> None:
- self.current_max = val if self.current_max is None else max(val, self.current_max)
+ def update_max(self, val: Optional[Any]) -> None:
+ if self.current_max is None:
+ self.current_max = val
+ elif val is not None:
+ self.current_max = max(val, self.current_max)
def min_as_bytes(self) -> Optional[bytes]:
if self.current_min is None:
diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py
index 0628ed4..745de1a 100644
--- a/tests/io/test_pyarrow.py
+++ b/tests/io/test_pyarrow.py
@@ -18,6 +18,7 @@
import os
import tempfile
+from datetime import date
from typing import Any, List, Optional
from unittest.mock import MagicMock, patch
from uuid import uuid4
@@ -59,7 +60,9 @@ from pyiceberg.io.pyarrow import (
ICEBERG_SCHEMA,
PyArrowFile,
PyArrowFileIO,
+ StatsAggregator,
_ConvertToArrowSchema,
+ _primitive_to_physical,
_read_deletes,
expression_to_pyarrow,
project_table,
@@ -84,6 +87,7 @@ from pyiceberg.types import (
LongType,
MapType,
NestedField,
+ PrimitiveType,
StringType,
StructType,
TimestampType,
@@ -1666,3 +1670,43 @@ def test_parse_location() -> None:
def test_make_compatible_name() -> None:
assert make_compatible_name("label/abc") == "label_x2Fabc"
assert make_compatible_name("label?abc") == "label_x3Fabc"
+
+
+@pytest.mark.parametrize(
+ "vals, primitive_type, expected_result",
+ [
+ ([None, 2, 1], IntegerType(), 1),
+ ([1, None, 2], IntegerType(), 1),
+ ([None, None, None], IntegerType(), None),
+ ([None, date(2024, 2, 4), date(2024, 1, 2)], DateType(), date(2024, 1, 2)),
+ ([date(2024, 1, 2), None, date(2024, 2, 4)], DateType(), date(2024, 1, 2)),
+ ([None, None, None], DateType(), None),
+ ],
+)
+def test_stats_aggregator_update_min(vals: List[Any], primitive_type: PrimitiveType, expected_result: Any) -> None:
+ stats = StatsAggregator(primitive_type, _primitive_to_physical(primitive_type))
+
+ for val in vals:
+ stats.update_min(val)
+
+ assert stats.current_min == expected_result
+
+
+@pytest.mark.parametrize(
+ "vals, primitive_type, expected_result",
+ [
+ ([None, 2, 1], IntegerType(), 2),
+ ([1, None, 2], IntegerType(), 2),
+ ([None, None, None], IntegerType(), None),
+ ([None, date(2024, 2, 4), date(2024, 1, 2)], DateType(), date(2024, 2, 4)),
+ ([date(2024, 1, 2), None, date(2024, 2, 4)], DateType(), date(2024, 2, 4)),
+ ([None, None, None], DateType(), None),
+ ],
+)
+def test_stats_aggregator_update_max(vals: List[Any], primitive_type: PrimitiveType, expected_result: Any) -> None:
+ stats = StatsAggregator(primitive_type, _primitive_to_physical(primitive_type))
+
+ for val in vals:
+ stats.update_max(val)
+
+ assert stats.current_max == expected_result