gemini-code-assist[bot] commented on code in PR #488:
URL: https://github.com/apache/tvm-ffi/pull/488#discussion_r2867769818


##########
python/tvm_ffi/registry.py:
##########
@@ -531,6 +530,108 @@ def _replace_unsupported(self: Any, **kwargs: Any) -> Any:
     )
 
 
+def _install_dataclass_dunders(
+    cls: type,
+    *,
+    init: bool,
+    repr: bool,
+    eq: bool,
+    order: bool,
+    unsafe_hash: bool,
+) -> None:
+    """Install structural dunder methods on *cls*.
+
+    Each dunder delegates to the corresponding C++ recursive structural
+    operation.  If the user already defined a dunder in the class body
+    (i.e. it exists in ``cls.__dict__``), it is left untouched.
+
+    Parameters
+    ----------
+    cls
+        The class to install dunders on.
+    init
+        If True, install ``__init__`` from C++ reflection metadata.
+    repr
+        If True, install :func:`~tvm_ffi.core.object_repr` as ``__repr__``.
+    eq
+        If True, install ``__eq__`` and ``__ne__``.
+    order
+        If True, install ``__lt__``, ``__le__``, ``__gt__``, ``__ge__``.
+    unsafe_hash
+        If True, install ``__hash__``.
+
+    """
+    _install_init(cls, enabled=init)
+
+    if repr and "__repr__" not in cls.__dict__:
+        from .core import object_repr  # noqa: PLC0415
+
+        cls.__repr__ = object_repr  # type: ignore[attr-defined]
+
+    from . import _ffi_api  # noqa: PLC0415
+
+    dunders: dict[str, Any] = {}
+
+    if eq:
+        recursive_eq = _ffi_api.RecursiveEq
+
+        def __eq__(self: Any, other: Any) -> bool:
+            if not isinstance(other, type(self)) and not isinstance(self, type(other)):
+                return NotImplemented
+            return recursive_eq(self, other)

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   This type compatibility check is repeated in `__ne__` and all the ordering 
methods (`__lt__`, `__le__`, `__gt__`, `__ge__`). To avoid duplication and 
improve maintainability, you could extract this logic into a helper function 
defined within the `_install_dataclass_dunders` scope. This would make the code 
more DRY (Don't Repeat Yourself).



##########
tests/python/test_dataclass_c_class.py:
##########
@@ -0,0 +1,365 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Tests for the c_class decorator (register_object + structural dunders)."""
+
+from __future__ import annotations
+
+import inspect
+
+import pytest
+from tvm_ffi.core import Object
+from tvm_ffi.dataclasses import c_class
+from tvm_ffi.testing import (
+    _TestCxxClassBase,
+    _TestCxxClassDerived,
+    _TestCxxClassDerivedDerived,
+    _TestCxxInitSubset,
+    _TestCxxKwOnly,
+)
+
+# ---------------------------------------------------------------------------
+# 1. Custom __init__ preservation
+# ---------------------------------------------------------------------------
+
+
+def test_c_class_custom_init() -> None:
+    """c_class preserves user-defined __init__."""
+    obj = _TestCxxClassBase(v_i64=10, v_i32=20)
+    assert obj.v_i64 == 11  # +1 from custom __init__
+    assert obj.v_i32 == 22  # +2 from custom __init__
+
+
+# ---------------------------------------------------------------------------
+# 2. Auto-generated __init__ with defaults
+# ---------------------------------------------------------------------------
+
+
+def test_c_class_auto_init_defaults() -> None:
+    """Derived classes use auto-generated __init__ with C++ defaults."""
+    obj = _TestCxxClassDerived(v_i64=1, v_i32=2, v_f64=3.0)
+    assert obj.v_i64 == 1
+    assert obj.v_i32 == 2
+    assert obj.v_f64 == 3.0
+    assert obj.v_f32 == 8.0  # default from C++
+
+
+def test_c_class_auto_init_all_explicit() -> None:
+    """Auto-generated __init__ accepts all fields explicitly."""
+    obj = _TestCxxClassDerived(v_i64=123, v_i32=456, v_f64=4.0, v_f32=9.0)
+    assert obj.v_i64 == 123
+    assert obj.v_i32 == 456
+    assert obj.v_f64 == 4.0
+    assert obj.v_f32 == 9.0
+
+
+# ---------------------------------------------------------------------------
+# 3. Structural equality (__eq__)
+# ---------------------------------------------------------------------------
+
+
+def test_c_class_eq() -> None:
+    """c_class installs __eq__ using RecursiveEq."""
+    a = _TestCxxClassDerived(1, 2, 3.0, 4.0)
+    b = _TestCxxClassDerived(1, 2, 3.0, 4.0)
+    assert a == b
+    assert a is not b  # different objects
+    c = _TestCxxClassDerived(1, 2, 3.0, 5.0)
+    assert a != c
+
+
+def test_c_class_eq_reflexive() -> None:
+    """Equality is reflexive: an object equals itself."""
+    a = _TestCxxClassDerived(1, 2, 3.0, 4.0)
+    b = a  # alias, same object
+    assert a == b
+
+
+def test_c_class_eq_symmetric() -> None:
+    """Equality is symmetric: a == b implies b == a."""
+    a = _TestCxxClassDerived(1, 2, 3.0, 4.0)
+    b = _TestCxxClassDerived(1, 2, 3.0, 4.0)
+    assert a == b
+    assert b == a
+
+
+# ---------------------------------------------------------------------------
+# 4. Structural hash (__hash__)
+# ---------------------------------------------------------------------------
+
+
+def test_c_class_hash() -> None:
+    """c_class installs __hash__ using RecursiveHash."""
+    a = _TestCxxClassDerived(1, 2, 3.0, 4.0)
+    b = _TestCxxClassDerived(1, 2, 3.0, 4.0)
+    assert hash(a) == hash(b)
+
+
+def test_c_class_hash_as_dict_key() -> None:
+    """Equal objects can be used interchangeably as dict keys."""
+    a = _TestCxxClassDerived(1, 2, 3.0, 4.0)
+    b = _TestCxxClassDerived(1, 2, 3.0, 4.0)
+    d = {a: "value"}
+    assert d[b] == "value"
+
+
+# ---------------------------------------------------------------------------
+# 5. Ordering (__lt__, __le__, __gt__, __ge__)
+# ---------------------------------------------------------------------------
+
+
+def test_c_class_ordering() -> None:
+    """c_class installs ordering operators."""
+    small = _TestCxxClassDerived(0, 0, 0.0, 0.0)
+    big = _TestCxxClassDerived(100, 100, 100.0, 100.0)
+    assert small < big  # ty: ignore[unsupported-operator]
+    assert small <= big  # ty: ignore[unsupported-operator]
+    assert big > small  # ty: ignore[unsupported-operator]
+    assert big >= small  # ty: ignore[unsupported-operator]
+    assert not (big < small)  # ty: ignore[unsupported-operator]
+    assert not (small > big)  # ty: ignore[unsupported-operator]
+
+
+def test_c_class_ordering_reflexive() -> None:
+    """<= and >= are reflexive."""
+    a = _TestCxxClassDerived(1, 2, 3.0, 4.0)
+    b = a  # alias, same object
+    assert a <= b  # ty: ignore[unsupported-operator]
+    assert a >= b  # ty: ignore[unsupported-operator]
+
+
+def test_c_class_ordering_antisymmetric() -> None:
+    """If a < b then not b < a."""
+    a = _TestCxxClassDerived(0, 0, 0.0, 0.0)
+    b = _TestCxxClassDerived(100, 100, 100.0, 100.0)
+    if a < b:  # ty: ignore[unsupported-operator]
+        assert not (b < a)  # ty: ignore[unsupported-operator]
+    else:
+        assert not (a < b)  # ty: ignore[unsupported-operator]
+
+
+# ---------------------------------------------------------------------------
+# 6. Equality with different types returns NotImplemented
+# ---------------------------------------------------------------------------
+
+
+def test_c_class_eq_different_type() -> None:
+    """__eq__ returns NotImplemented for unrelated types."""
+    a = _TestCxxClassDerived(1, 2, 3.0, 4.0)
+    assert a != "hello"
+    assert a != 42
+    assert a != 3.14
+    assert a != None

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   Comparing with `None` using `!=` is not idiomatic Python; `is not` should be 
used instead. However, for this test, the goal is to check that `__eq__` 
returns `NotImplemented` for unrelated types. The other assertions (`!= 
"hello"`, `!= 42`) already cover this well, making the `None` check redundant 
and stylistically poor.



##########
python/tvm_ffi/registry.py:
##########
@@ -531,6 +530,108 @@ def _replace_unsupported(self: Any, **kwargs: Any) -> Any:
     )
 
 
+def _install_dataclass_dunders(
+    cls: type,
+    *,
+    init: bool,
+    repr: bool,
+    eq: bool,
+    order: bool,
+    unsafe_hash: bool,
+) -> None:
+    """Install structural dunder methods on *cls*.
+
+    Each dunder delegates to the corresponding C++ recursive structural
+    operation.  If the user already defined a dunder in the class body
+    (i.e. it exists in ``cls.__dict__``), it is left untouched.
+
+    Parameters
+    ----------
+    cls
+        The class to install dunders on.
+    init
+        If True, install ``__init__`` from C++ reflection metadata.
+    repr
+        If True, install :func:`~tvm_ffi.core.object_repr` as ``__repr__``.
+    eq
+        If True, install ``__eq__`` and ``__ne__``.
+    order
+        If True, install ``__lt__``, ``__le__``, ``__gt__``, ``__ge__``.
+    unsafe_hash
+        If True, install ``__hash__``.
+
+    """
+    _install_init(cls, enabled=init)
+
+    if repr and "__repr__" not in cls.__dict__:
+        from .core import object_repr  # noqa: PLC0415
+
+        cls.__repr__ = object_repr  # type: ignore[attr-defined]
+
+    from . import _ffi_api  # noqa: PLC0415
+
+    dunders: dict[str, Any] = {}
+
+    if eq:
+        recursive_eq = _ffi_api.RecursiveEq
+
+        def __eq__(self: Any, other: Any) -> bool:
+            if not isinstance(other, type(self)) and not isinstance(self, type(other)):
+                return NotImplemented
+            return recursive_eq(self, other)
+
+        def __ne__(self: Any, other: Any) -> bool:
+            if not isinstance(other, type(self)) and not isinstance(self, type(other)):
+                return NotImplemented
+            return not recursive_eq(self, other)
+
+        dunders["__eq__"] = __eq__
+        dunders["__ne__"] = __ne__
+
+    if unsafe_hash:
+        recursive_hash = _ffi_api.RecursiveHash
+
+        def __hash__(self: Any) -> int:
+            return recursive_hash(self) & 0xFFFFFFFFFFFFFFFF

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The bitwise AND with `0xFFFFFFFFFFFFFFFF` is unnecessary and potentially 
confusing. The `recursive_hash` function returns a C++ `int64_t`, which is 
converted to a standard Python `int`. This is already a valid return value for 
`__hash__`. The masking alters the hash value for negative numbers (e.g., `-1` 
becomes a large positive integer). Simply returning the result of 
`recursive_hash(self)` would be clearer and sufficient.
   
   ```suggestion
               return recursive_hash(self)
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to