This is an automated email from the ASF dual-hosted git repository.

fokko pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git


The following commit(s) were added to refs/heads/main by this push:
     new 8789fc2  Bug Fix: Handle null vals from later row_groups in StatsAggregator (#379)
8789fc2 is described below

commit 8789fc287755c8e617513f7d136b8cccc4800508
Author: Sung Yun <[email protected]>
AuthorDate: Tue Feb 6 10:26:56 2024 -0500

    Bug Fix: Handle null vals from later row_groups in StatsAggregator (#379)
    
    * handle null vals correctly in StatsAggregator
    
    * return
    
    * Apply suggestions from code review
    
    Co-authored-by: Fokko Driesprong <[email protected]>
    
    ---------
    
    Co-authored-by: Fokko Driesprong <[email protected]>
---
 pyiceberg/io/pyarrow.py  | 14 ++++++++++----
 tests/io/test_pyarrow.py | 44 ++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 54 insertions(+), 4 deletions(-)

diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index 99c1af5..9726451 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -1331,11 +1331,17 @@ class StatsAggregator:
     def serialize(self, value: Any) -> bytes:
         return to_bytes(self.primitive_type, value)
 
-    def update_min(self, val: Any) -> None:
-        self.current_min = val if self.current_min is None else min(val, self.current_min)
+    def update_min(self, val: Optional[Any]) -> None:
+        if self.current_min is None:
+            self.current_min = val
+        elif val is not None:
+            self.current_min = min(val, self.current_min)
 
-    def update_max(self, val: Any) -> None:
-        self.current_max = val if self.current_max is None else max(val, self.current_max)
+    def update_max(self, val: Optional[Any]) -> None:
+        if self.current_max is None:
+            self.current_max = val
+        elif val is not None:
+            self.current_max = max(val, self.current_max)
 
     def min_as_bytes(self) -> Optional[bytes]:
         if self.current_min is None:
diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py
index 0628ed4..745de1a 100644
--- a/tests/io/test_pyarrow.py
+++ b/tests/io/test_pyarrow.py
@@ -18,6 +18,7 @@
 
 import os
 import tempfile
+from datetime import date
 from typing import Any, List, Optional
 from unittest.mock import MagicMock, patch
 from uuid import uuid4
@@ -59,7 +60,9 @@ from pyiceberg.io.pyarrow import (
     ICEBERG_SCHEMA,
     PyArrowFile,
     PyArrowFileIO,
+    StatsAggregator,
     _ConvertToArrowSchema,
+    _primitive_to_physical,
     _read_deletes,
     expression_to_pyarrow,
     project_table,
@@ -84,6 +87,7 @@ from pyiceberg.types import (
     LongType,
     MapType,
     NestedField,
+    PrimitiveType,
     StringType,
     StructType,
     TimestampType,
@@ -1666,3 +1670,43 @@ def test_parse_location() -> None:
 def test_make_compatible_name() -> None:
     assert make_compatible_name("label/abc") == "label_x2Fabc"
     assert make_compatible_name("label?abc") == "label_x3Fabc"
+
+
+@pytest.mark.parametrize(
+    "vals, primitive_type, expected_result",
+    [
+        ([None, 2, 1], IntegerType(), 1),
+        ([1, None, 2], IntegerType(), 1),
+        ([None, None, None], IntegerType(), None),
+        ([None, date(2024, 2, 4), date(2024, 1, 2)], DateType(), date(2024, 1, 2)),
+        ([date(2024, 1, 2), None, date(2024, 2, 4)], DateType(), date(2024, 1, 2)),
+        ([None, None, None], DateType(), None),
+    ],
+)
+def test_stats_aggregator_update_min(vals: List[Any], primitive_type: PrimitiveType, expected_result: Any) -> None:
+    stats = StatsAggregator(primitive_type, _primitive_to_physical(primitive_type))
+
+    for val in vals:
+        stats.update_min(val)
+
+    assert stats.current_min == expected_result
+
+
+@pytest.mark.parametrize(
+    "vals, primitive_type, expected_result",
+    [
+        ([None, 2, 1], IntegerType(), 2),
+        ([1, None, 2], IntegerType(), 2),
+        ([None, None, None], IntegerType(), None),
+        ([None, date(2024, 2, 4), date(2024, 1, 2)], DateType(), date(2024, 2, 4)),
+        ([date(2024, 1, 2), None, date(2024, 2, 4)], DateType(), date(2024, 2, 4)),
+        ([None, None, None], DateType(), None),
+    ],
+)
+def test_stats_aggregator_update_max(vals: List[Any], primitive_type: PrimitiveType, expected_result: Any) -> None:
+    stats = StatsAggregator(primitive_type, _primitive_to_physical(primitive_type))
+
+    for val in vals:
+        stats.update_max(val)
+
+    assert stats.current_max == expected_result

Reply via email to