This is an automated email from the ASF dual-hosted git repository.

fokko pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git


The following commit(s) were added to refs/heads/main by this push:
     new 54a08f3  Replace `type()` calls with `isinstance()` (#188)
54a08f3 is described below

commit 54a08f31892de8a811175e3f633153cc26e070fb
Author: Jayce Slesar <[email protected]>
AuthorDate: Thu Dec 7 10:03:57 2023 -0500

    Replace `type()` calls with `isinstance()` (#188)
    
    * WIP
    
    * fix bad ifs that snuck in
    
    * missing this one
    
    * couple more cases
    
    * almost all of them
    
    * undo change
    
    * add a comment and explain why we are doing the explicit type(call)
    
    * standardize naming
    
    * lint
---
 pyiceberg/expressions/__init__.py |  4 +-
 pyiceberg/expressions/visitors.py |  2 +-
 pyiceberg/table/__init__.py       | 10 +++--
 pyiceberg/table/sorting.py        |  2 +-
 pyiceberg/transforms.py           | 91 ++++++++++++++++-----------------------
 5 files changed, 47 insertions(+), 62 deletions(-)

diff --git a/pyiceberg/expressions/__init__.py b/pyiceberg/expressions/__init__.py
index 7fb42f7..cb46a70 100644
--- a/pyiceberg/expressions/__init__.py
+++ b/pyiceberg/expressions/__init__.py
@@ -459,7 +459,7 @@ class NotNull(UnaryPredicate):
 class BoundIsNaN(BoundUnaryPredicate[L]):
     def __new__(cls, term: BoundTerm[L]) -> BooleanExpression:  # type: ignore 
 # pylint: disable=W0221
         bound_type = term.ref().field.field_type
-        if type(bound_type) in {FloatType, DoubleType}:
+        if isinstance(bound_type, (FloatType, DoubleType)):
             return super().__new__(cls)
         return AlwaysFalse()
 
@@ -475,7 +475,7 @@ class BoundIsNaN(BoundUnaryPredicate[L]):
 class BoundNotNaN(BoundUnaryPredicate[L]):
     def __new__(cls, term: BoundTerm[L]) -> BooleanExpression:  # type: ignore 
 # pylint: disable=W0221
         bound_type = term.ref().field.field_type
-        if type(bound_type) in {FloatType, DoubleType}:
+        if isinstance(bound_type, (FloatType, DoubleType)):
             return super().__new__(cls)
         return AlwaysTrue()
 
diff --git a/pyiceberg/expressions/visitors.py b/pyiceberg/expressions/visitors.py
index a4f311f..a13c1c5 100644
--- a/pyiceberg/expressions/visitors.py
+++ b/pyiceberg/expressions/visitors.py
@@ -620,7 +620,7 @@ class _ManifestEvalVisitor(BoundBooleanExpressionVisitor[bool]):
         # lowerBound is null if all partition values are null
         all_null = self.partition_fields[pos].contains_null is True and 
self.partition_fields[pos].lower_bound is None
 
-        if all_null and type(term.ref().field.field_type) in {DoubleType, 
FloatType}:
+        if all_null and isinstance(term.ref().field.field_type, (DoubleType, 
FloatType)):
             # floating point types may include NaN values, which we check 
separately.
             # In case bounds don't include NaN value, contains_nan needs to be 
checked against.
             all_null = self.partition_fields[pos].contains_nan is False
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index 436266f..4768706 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -150,8 +150,9 @@ class Transaction:
             Transaction object with the new updates appended.
         """
         for new_update in new_updates:
+            # explicitly get type of new_update as new_update is an 
instantiated class
             type_new_update = type(new_update)
-            if any(type(update) == type_new_update for update in 
self._updates):
+            if any(isinstance(update, type_new_update) for update in 
self._updates):
                 raise ValueError(f"Updates in a single commit need to be 
unique, duplicate: {type_new_update}")
         self._updates = self._updates + new_updates
         return self
@@ -168,9 +169,10 @@ class Transaction:
         Returns:
             Transaction object with the new requirements appended.
         """
-        for requirement in new_requirements:
-            type_new_requirement = type(requirement)
-            if any(type(requirement) == type_new_requirement for update in 
self._requirements):
+        for new_requirement in new_requirements:
+            # explicitly get type of new_requirement as new_requirement is an instantiated class
+            type_new_requirement = type(new_requirement)
+            if any(isinstance(requirement, type_new_requirement) for 
requirement in self._requirements):
                 raise ValueError(f"Requirements in a single commit need to be 
unique, duplicate: {type_new_requirement}")
         self._requirements = self._requirements + new_requirements
         return self
diff --git a/pyiceberg/table/sorting.py b/pyiceberg/table/sorting.py
index 3a97e39..f970d68 100644
--- a/pyiceberg/table/sorting.py
+++ b/pyiceberg/table/sorting.py
@@ -114,7 +114,7 @@ class SortField(IcebergBaseModel):
 
     def __str__(self) -> str:
         """Return the string representation of the SortField class."""
-        if type(self.transform) == IdentityTransform:
+        if isinstance(self.transform, IdentityTransform):
             # In the case of an identity transform, we can omit the transform
             return f"{self.source_id} {self.direction} {self.null_order}"
         else:
diff --git a/pyiceberg/transforms.py b/pyiceberg/transforms.py
index 6b373b7..b9afae9 100644
--- a/pyiceberg/transforms.py
+++ b/pyiceberg/transforms.py
@@ -220,38 +220,40 @@ class BucketTransform(Transform[S, int]):
             return None
 
     def can_transform(self, source: IcebergType) -> bool:
-        return type(source) in {
-            IntegerType,
-            DateType,
-            LongType,
-            TimeType,
-            TimestampType,
-            TimestamptzType,
-            DecimalType,
-            StringType,
-            FixedType,
-            BinaryType,
-            UUIDType,
-        }
+        return isinstance(
+            source,
+            (
+                IntegerType,
+                DateType,
+                LongType,
+                TimeType,
+                TimestampType,
+                TimestamptzType,
+                DecimalType,
+                StringType,
+                FixedType,
+                BinaryType,
+                UUIDType,
+            ),
+        )
 
     def transform(self, source: IcebergType, bucket: bool = True) -> 
Callable[[Optional[Any]], Optional[int]]:
-        source_type = type(source)
-        if source_type in {IntegerType, LongType, DateType, TimeType, 
TimestampType, TimestamptzType}:
+        if isinstance(source, (IntegerType, LongType, DateType, TimeType, 
TimestampType, TimestamptzType)):
 
             def hash_func(v: Any) -> int:
                 return mmh3.hash(struct.pack("<q", v))
 
-        elif source_type == DecimalType:
+        elif isinstance(source, DecimalType):
 
             def hash_func(v: Any) -> int:
                 return mmh3.hash(decimal_to_bytes(v))
 
-        elif source_type in {StringType, FixedType, BinaryType}:
+        elif isinstance(source, (StringType, FixedType, BinaryType)):
 
             def hash_func(v: Any) -> int:
                 return mmh3.hash(v)
 
-        elif source_type == UUIDType:
+        elif isinstance(source, UUIDType):
 
             def hash_func(v: Any) -> int:
                 if isinstance(v, UUID):
@@ -330,13 +332,12 @@ class YearTransform(TimeTransform[S]):
     root: LiteralType["year"] = Field(default="year")  # noqa: F821
 
     def transform(self, source: IcebergType) -> Callable[[Optional[S]], 
Optional[int]]:
-        source_type = type(source)
-        if source_type == DateType:
+        if isinstance(source, DateType):
 
             def year_func(v: Any) -> int:
                 return datetime.days_to_years(v)
 
-        elif source_type in {TimestampType, TimestamptzType}:
+        elif isinstance(source, (TimestampType, TimestamptzType)):
 
             def year_func(v: Any) -> int:
                 return datetime.micros_to_years(v)
@@ -347,11 +348,7 @@ class YearTransform(TimeTransform[S]):
         return lambda v: year_func(v) if v is not None else None
 
     def can_transform(self, source: IcebergType) -> bool:
-        return type(source) in {
-            DateType,
-            TimestampType,
-            TimestamptzType,
-        }
+        return isinstance(source, (DateType, TimestampType, TimestamptzType))
 
     @property
     def granularity(self) -> TimeResolution:
@@ -377,13 +374,12 @@ class MonthTransform(TimeTransform[S]):
     root: LiteralType["month"] = Field(default="month")  # noqa: F821
 
     def transform(self, source: IcebergType) -> Callable[[Optional[S]], 
Optional[int]]:
-        source_type = type(source)
-        if source_type == DateType:
+        if isinstance(source, DateType):
 
             def month_func(v: Any) -> int:
                 return datetime.days_to_months(v)
 
-        elif source_type in {TimestampType, TimestamptzType}:
+        elif isinstance(source, (TimestampType, TimestamptzType)):
 
             def month_func(v: Any) -> int:
                 return datetime.micros_to_months(v)
@@ -394,11 +390,7 @@ class MonthTransform(TimeTransform[S]):
         return lambda v: month_func(v) if v else None
 
     def can_transform(self, source: IcebergType) -> bool:
-        return type(source) in {
-            DateType,
-            TimestampType,
-            TimestamptzType,
-        }
+        return isinstance(source, (DateType, TimestampType, TimestamptzType))
 
     @property
     def granularity(self) -> TimeResolution:
@@ -424,13 +416,12 @@ class DayTransform(TimeTransform[S]):
     root: LiteralType["day"] = Field(default="day")  # noqa: F821
 
     def transform(self, source: IcebergType) -> Callable[[Optional[S]], 
Optional[int]]:
-        source_type = type(source)
-        if source_type == DateType:
+        if isinstance(source, DateType):
 
             def day_func(v: Any) -> int:
                 return v
 
-        elif source_type in {TimestampType, TimestamptzType}:
+        elif isinstance(source, (TimestampType, TimestamptzType)):
 
             def day_func(v: Any) -> int:
                 return datetime.micros_to_days(v)
@@ -441,11 +432,7 @@ class DayTransform(TimeTransform[S]):
         return lambda v: day_func(v) if v else None
 
     def can_transform(self, source: IcebergType) -> bool:
-        return type(source) in {
-            DateType,
-            TimestampType,
-            TimestamptzType,
-        }
+        return isinstance(source, (DateType, TimestampType, TimestamptzType))
 
     def result_type(self, source: IcebergType) -> IcebergType:
         return DateType()
@@ -474,7 +461,7 @@ class HourTransform(TimeTransform[S]):
     root: LiteralType["hour"] = Field(default="hour")  # noqa: F821
 
     def transform(self, source: IcebergType) -> Callable[[Optional[S]], 
Optional[int]]:
-        if type(source) in {TimestampType, TimestamptzType}:
+        if isinstance(source, (TimestampType, TimestamptzType)):
 
             def hour_func(v: Any) -> int:
                 return datetime.micros_to_hours(v)
@@ -485,10 +472,7 @@ class HourTransform(TimeTransform[S]):
         return lambda v: hour_func(v) if v else None
 
     def can_transform(self, source: IcebergType) -> bool:
-        return type(source) in {
-            TimestampType,
-            TimestamptzType,
-        }
+        return isinstance(source, (TimestampType, TimestamptzType))
 
     @property
     def granularity(self) -> TimeResolution:
@@ -580,7 +564,7 @@ class TruncateTransform(Transform[S, S]):
         self._width = width
 
     def can_transform(self, source: IcebergType) -> bool:
-        return type(source) in {IntegerType, LongType, StringType, BinaryType, 
DecimalType}
+        return isinstance(source, (IntegerType, LongType, StringType, 
BinaryType, DecimalType))
 
     def result_type(self, source: IcebergType) -> IcebergType:
         return source
@@ -616,18 +600,17 @@ class TruncateTransform(Transform[S, S]):
         return self._width
 
     def transform(self, source: IcebergType) -> Callable[[Optional[S]], 
Optional[S]]:
-        source_type = type(source)
-        if source_type in {IntegerType, LongType}:
+        if isinstance(source, (IntegerType, LongType)):
 
             def truncate_func(v: Any) -> Any:
                 return v - v % self._width
 
-        elif source_type in {StringType, BinaryType}:
+        elif isinstance(source, (StringType, BinaryType)):
 
             def truncate_func(v: Any) -> Any:
                 return v[0 : min(self._width, len(v))]
 
-        elif source_type == DecimalType:
+        elif isinstance(source, DecimalType):
 
             def truncate_func(v: Any) -> Any:
                 return truncate_decimal(v, self._width)
@@ -788,9 +771,9 @@ def _truncate_array(
 ) -> Optional[UnboundPredicate[Any]]:
     boundary = pred.literal
 
-    if type(pred) in {BoundLessThan, BoundLessThanOrEqual}:
+    if isinstance(pred, (BoundLessThan, BoundLessThanOrEqual)):
         return LessThanOrEqual(Reference(name), _transform_literal(transform, 
boundary))
-    elif type(pred) in {BoundGreaterThan, BoundGreaterThanOrEqual}:
+    elif isinstance(pred, (BoundGreaterThan, BoundGreaterThanOrEqual)):
         return GreaterThanOrEqual(Reference(name), 
_transform_literal(transform, boundary))
     if isinstance(pred, BoundEqualTo):
         return EqualTo(Reference(name), _transform_literal(transform, 
boundary))

Reply via email to