This is an automated email from the ASF dual-hosted git repository.
fokko pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git
The following commit(s) were added to refs/heads/main by this push:
new 54a08f3 Replace `type()` calls with `isinstance()` (#188)
54a08f3 is described below
commit 54a08f31892de8a811175e3f633153cc26e070fb
Author: Jayce Slesar <[email protected]>
AuthorDate: Thu Dec 7 10:03:57 2023 -0500
Replace `type()` calls with `isinstance()` (#188)
* WIP
* fix bad ifs that snuck in
* missing this one
* couple more cases
* almost all of them
* undo change
* add a comment and explain why we are doing the explicit type(call)
* standardize naming
* lint
---
pyiceberg/expressions/__init__.py | 4 +-
pyiceberg/expressions/visitors.py | 2 +-
pyiceberg/table/__init__.py | 10 +++--
pyiceberg/table/sorting.py | 2 +-
pyiceberg/transforms.py | 91 ++++++++++++++++-----------------------
5 files changed, 47 insertions(+), 62 deletions(-)
diff --git a/pyiceberg/expressions/__init__.py
b/pyiceberg/expressions/__init__.py
index 7fb42f7..cb46a70 100644
--- a/pyiceberg/expressions/__init__.py
+++ b/pyiceberg/expressions/__init__.py
@@ -459,7 +459,7 @@ class NotNull(UnaryPredicate):
class BoundIsNaN(BoundUnaryPredicate[L]):
def __new__(cls, term: BoundTerm[L]) -> BooleanExpression: # type: ignore
# pylint: disable=W0221
bound_type = term.ref().field.field_type
- if type(bound_type) in {FloatType, DoubleType}:
+ if isinstance(bound_type, (FloatType, DoubleType)):
return super().__new__(cls)
return AlwaysFalse()
@@ -475,7 +475,7 @@ class BoundIsNaN(BoundUnaryPredicate[L]):
class BoundNotNaN(BoundUnaryPredicate[L]):
def __new__(cls, term: BoundTerm[L]) -> BooleanExpression: # type: ignore
# pylint: disable=W0221
bound_type = term.ref().field.field_type
- if type(bound_type) in {FloatType, DoubleType}:
+ if isinstance(bound_type, (FloatType, DoubleType)):
return super().__new__(cls)
return AlwaysTrue()
diff --git a/pyiceberg/expressions/visitors.py
b/pyiceberg/expressions/visitors.py
index a4f311f..a13c1c5 100644
--- a/pyiceberg/expressions/visitors.py
+++ b/pyiceberg/expressions/visitors.py
@@ -620,7 +620,7 @@ class
_ManifestEvalVisitor(BoundBooleanExpressionVisitor[bool]):
# lowerBound is null if all partition values are null
all_null = self.partition_fields[pos].contains_null is True and
self.partition_fields[pos].lower_bound is None
- if all_null and type(term.ref().field.field_type) in {DoubleType,
FloatType}:
+ if all_null and isinstance(term.ref().field.field_type, (DoubleType,
FloatType)):
# floating point types may include NaN values, which we check
separately.
# In case bounds don't include NaN value, contains_nan needs to be
checked against.
all_null = self.partition_fields[pos].contains_nan is False
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index 436266f..4768706 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -150,8 +150,9 @@ class Transaction:
Transaction object with the new updates appended.
"""
for new_update in new_updates:
+ # explicitly get type of new_update as new_update is an
instantiated class
type_new_update = type(new_update)
- if any(type(update) == type_new_update for update in
self._updates):
+ if any(isinstance(update, type_new_update) for update in
self._updates):
raise ValueError(f"Updates in a single commit need to be
unique, duplicate: {type_new_update}")
self._updates = self._updates + new_updates
return self
@@ -168,9 +169,10 @@ class Transaction:
Returns:
Transaction object with the new requirements appended.
"""
- for requirement in new_requirements:
- type_new_requirement = type(requirement)
- if any(type(requirement) == type_new_requirement for update in
self._requirements):
+ for new_requirement in new_requirements:
+ # explicitly get type of new_requirement as new_requirement is an
instantiated class
+ type_new_requirement = type(new_requirement)
+ if any(isinstance(requirement, type_new_requirement) for
requirement in self._requirements):
raise ValueError(f"Requirements in a single commit need to be
unique, duplicate: {type_new_requirement}")
self._requirements = self._requirements + new_requirements
return self
diff --git a/pyiceberg/table/sorting.py b/pyiceberg/table/sorting.py
index 3a97e39..f970d68 100644
--- a/pyiceberg/table/sorting.py
+++ b/pyiceberg/table/sorting.py
@@ -114,7 +114,7 @@ class SortField(IcebergBaseModel):
def __str__(self) -> str:
"""Return the string representation of the SortField class."""
- if type(self.transform) == IdentityTransform:
+ if isinstance(self.transform, IdentityTransform):
# In the case of an identity transform, we can omit the transform
return f"{self.source_id} {self.direction} {self.null_order}"
else:
diff --git a/pyiceberg/transforms.py b/pyiceberg/transforms.py
index 6b373b7..b9afae9 100644
--- a/pyiceberg/transforms.py
+++ b/pyiceberg/transforms.py
@@ -220,38 +220,40 @@ class BucketTransform(Transform[S, int]):
return None
def can_transform(self, source: IcebergType) -> bool:
- return type(source) in {
- IntegerType,
- DateType,
- LongType,
- TimeType,
- TimestampType,
- TimestamptzType,
- DecimalType,
- StringType,
- FixedType,
- BinaryType,
- UUIDType,
- }
+ return isinstance(
+ source,
+ (
+ IntegerType,
+ DateType,
+ LongType,
+ TimeType,
+ TimestampType,
+ TimestamptzType,
+ DecimalType,
+ StringType,
+ FixedType,
+ BinaryType,
+ UUIDType,
+ ),
+ )
def transform(self, source: IcebergType, bucket: bool = True) ->
Callable[[Optional[Any]], Optional[int]]:
- source_type = type(source)
- if source_type in {IntegerType, LongType, DateType, TimeType,
TimestampType, TimestamptzType}:
+ if isinstance(source, (IntegerType, LongType, DateType, TimeType,
TimestampType, TimestamptzType)):
def hash_func(v: Any) -> int:
return mmh3.hash(struct.pack("<q", v))
- elif source_type == DecimalType:
+ elif isinstance(source, DecimalType):
def hash_func(v: Any) -> int:
return mmh3.hash(decimal_to_bytes(v))
- elif source_type in {StringType, FixedType, BinaryType}:
+ elif isinstance(source, (StringType, FixedType, BinaryType)):
def hash_func(v: Any) -> int:
return mmh3.hash(v)
- elif source_type == UUIDType:
+ elif isinstance(source, UUIDType):
def hash_func(v: Any) -> int:
if isinstance(v, UUID):
@@ -330,13 +332,12 @@ class YearTransform(TimeTransform[S]):
root: LiteralType["year"] = Field(default="year") # noqa: F821
def transform(self, source: IcebergType) -> Callable[[Optional[S]],
Optional[int]]:
- source_type = type(source)
- if source_type == DateType:
+ if isinstance(source, DateType):
def year_func(v: Any) -> int:
return datetime.days_to_years(v)
- elif source_type in {TimestampType, TimestamptzType}:
+ elif isinstance(source, (TimestampType, TimestamptzType)):
def year_func(v: Any) -> int:
return datetime.micros_to_years(v)
@@ -347,11 +348,7 @@ class YearTransform(TimeTransform[S]):
return lambda v: year_func(v) if v is not None else None
def can_transform(self, source: IcebergType) -> bool:
- return type(source) in {
- DateType,
- TimestampType,
- TimestamptzType,
- }
+ return isinstance(source, (DateType, TimestampType, TimestamptzType))
@property
def granularity(self) -> TimeResolution:
@@ -377,13 +374,12 @@ class MonthTransform(TimeTransform[S]):
root: LiteralType["month"] = Field(default="month") # noqa: F821
def transform(self, source: IcebergType) -> Callable[[Optional[S]],
Optional[int]]:
- source_type = type(source)
- if source_type == DateType:
+ if isinstance(source, DateType):
def month_func(v: Any) -> int:
return datetime.days_to_months(v)
- elif source_type in {TimestampType, TimestamptzType}:
+ elif isinstance(source, (TimestampType, TimestamptzType)):
def month_func(v: Any) -> int:
return datetime.micros_to_months(v)
@@ -394,11 +390,7 @@ class MonthTransform(TimeTransform[S]):
return lambda v: month_func(v) if v else None
def can_transform(self, source: IcebergType) -> bool:
- return type(source) in {
- DateType,
- TimestampType,
- TimestamptzType,
- }
+ return isinstance(source, (DateType, TimestampType, TimestamptzType))
@property
def granularity(self) -> TimeResolution:
@@ -424,13 +416,12 @@ class DayTransform(TimeTransform[S]):
root: LiteralType["day"] = Field(default="day") # noqa: F821
def transform(self, source: IcebergType) -> Callable[[Optional[S]],
Optional[int]]:
- source_type = type(source)
- if source_type == DateType:
+ if isinstance(source, DateType):
def day_func(v: Any) -> int:
return v
- elif source_type in {TimestampType, TimestamptzType}:
+ elif isinstance(source, (TimestampType, TimestamptzType)):
def day_func(v: Any) -> int:
return datetime.micros_to_days(v)
@@ -441,11 +432,7 @@ class DayTransform(TimeTransform[S]):
return lambda v: day_func(v) if v else None
def can_transform(self, source: IcebergType) -> bool:
- return type(source) in {
- DateType,
- TimestampType,
- TimestamptzType,
- }
+ return isinstance(source, (DateType, TimestampType, TimestamptzType))
def result_type(self, source: IcebergType) -> IcebergType:
return DateType()
@@ -474,7 +461,7 @@ class HourTransform(TimeTransform[S]):
root: LiteralType["hour"] = Field(default="hour") # noqa: F821
def transform(self, source: IcebergType) -> Callable[[Optional[S]],
Optional[int]]:
- if type(source) in {TimestampType, TimestamptzType}:
+ if isinstance(source, (TimestampType, TimestamptzType)):
def hour_func(v: Any) -> int:
return datetime.micros_to_hours(v)
@@ -485,10 +472,7 @@ class HourTransform(TimeTransform[S]):
return lambda v: hour_func(v) if v else None
def can_transform(self, source: IcebergType) -> bool:
- return type(source) in {
- TimestampType,
- TimestamptzType,
- }
+ return isinstance(source, (TimestampType, TimestamptzType))
@property
def granularity(self) -> TimeResolution:
@@ -580,7 +564,7 @@ class TruncateTransform(Transform[S, S]):
self._width = width
def can_transform(self, source: IcebergType) -> bool:
- return type(source) in {IntegerType, LongType, StringType, BinaryType,
DecimalType}
+ return isinstance(source, (IntegerType, LongType, StringType,
BinaryType, DecimalType))
def result_type(self, source: IcebergType) -> IcebergType:
return source
@@ -616,18 +600,17 @@ class TruncateTransform(Transform[S, S]):
return self._width
def transform(self, source: IcebergType) -> Callable[[Optional[S]],
Optional[S]]:
- source_type = type(source)
- if source_type in {IntegerType, LongType}:
+ if isinstance(source, (IntegerType, LongType)):
def truncate_func(v: Any) -> Any:
return v - v % self._width
- elif source_type in {StringType, BinaryType}:
+ elif isinstance(source, (StringType, BinaryType)):
def truncate_func(v: Any) -> Any:
return v[0 : min(self._width, len(v))]
- elif source_type == DecimalType:
+ elif isinstance(source, DecimalType):
def truncate_func(v: Any) -> Any:
return truncate_decimal(v, self._width)
@@ -788,9 +771,9 @@ def _truncate_array(
) -> Optional[UnboundPredicate[Any]]:
boundary = pred.literal
- if type(pred) in {BoundLessThan, BoundLessThanOrEqual}:
+ if isinstance(pred, (BoundLessThan, BoundLessThanOrEqual)):
return LessThanOrEqual(Reference(name), _transform_literal(transform,
boundary))
- elif type(pred) in {BoundGreaterThan, BoundGreaterThanOrEqual}:
+ elif isinstance(pred, (BoundGreaterThan, BoundGreaterThanOrEqual)):
return GreaterThanOrEqual(Reference(name),
_transform_literal(transform, boundary))
if isinstance(pred, BoundEqualTo):
return EqualTo(Reference(name), _transform_literal(transform,
boundary))