gemini-code-assist[bot] commented on code in PR #52:
URL: https://github.com/apache/tvm-ffi/pull/52#discussion_r2374391699
##########
python/tvm_ffi/dataclasses/field.py:
##########
@@ -72,6 +75,11 @@ def field(
A zero-argument callable that produces the default. This matches the
semantics of :func:`dataclasses.field` and is useful for mutable
defaults such as ``list`` or ``dict``.
+ init : bool, default True
+ If ``True`` the field is included in the generated ``__init__`` and its
+ value is forwarded to the C++ ``__ffi_init__``. When ``False`` the
+ field is omitted from the initializer and, if a default is provided, it
+ is assigned on the Python side after construction.
Review Comment:

The documentation for the `init` parameter is inconsistent with the
implementation. The docstring states that, for `init=False`, a field with a
default is "assigned on the Python side after construction". However, the
implementation in `_utils.py` passes this value to the C++ `__ffi_init__`
constructor, and the current tests rely on that behavior. To avoid confusing
users, please update this docstring to state accurately that the default
value of an `init=False` field is computed and passed to the underlying C++
constructor.
##########
python/tvm_ffi/dataclasses/_utils.py:
##########
@@ -111,96 +110,79 @@ def fill_dataclass_field(type_cls: type, type_field:
TypeField) -> None:
type_field.dataclass_field = rhs
-def method_init(type_cls: type, type_info: TypeInfo) -> Callable[..., None]:
# noqa: PLR0915
+def method_init(type_cls: type, type_info: TypeInfo) -> Callable[..., None]:
"""Generate an ``__init__`` that forwards to the FFI constructor.
The generated initializer has a proper Python signature built from the
reflected field list, supporting default values and ``__post_init__``.
"""
-
- class DefaultFactory(NamedTuple):
- """Wrapper that marks a parameter as having a default factory."""
-
- fn: Callable[[], Any]
-
+ # Step 0. Collect all fields from the type hierarchy
fields: list[TypeField] = []
cur_type_info: TypeInfo | None = type_info
while cur_type_info is not None:
fields.extend(reversed(cur_type_info.fields))
cur_type_info = cur_type_info.parent_type_info
fields.reverse()
-
- annotations: dict[str, Any] = {"return": None}
- # Step 1. Split the parameters into two groups to ensure that
- # those without defaults appear first in the signature.
- params_without_defaults: list[inspect.Parameter] = []
- params_with_defaults: list[inspect.Parameter] = []
- ordering = [0] * len(fields)
- for i, field in enumerate(fields):
- assert field.name is not None
- name: str = field.name
- annotations[name] = Any # NOTE: We might be able to handle
annotations better
- assert field.dataclass_field is not None
- default_factory = field.dataclass_field.default_factory
- if default_factory is MISSING:
- ordering[i] = len(params_without_defaults)
- params_without_defaults.append(
- inspect.Parameter(name=name,
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD)
- )
- else:
- ordering[i] = -len(params_with_defaults) - 1
- params_with_defaults.append(
- inspect.Parameter(
- name=name,
- kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
- default=DefaultFactory(fn=default_factory),
- )
- )
- for i, order in enumerate(ordering):
- if order < 0:
- ordering[i] = len(params_without_defaults) - order - 1
- # Step 2. Create the signature object
- sig = inspect.Signature(parameters=[*params_without_defaults,
*params_with_defaults])
- signature_str = (
- f"{type_cls.__module__}.{type_cls.__qualname__}.__init__("
- + ", ".join(p.name for p in sig.parameters.values())
- + ")"
- )
-
- # Step 3. Create the `binding` method that reorders parameters
- def touch_arg(x: Any) -> Any:
- return x.fn() if isinstance(x, DefaultFactory) else x
-
- def bind_args(*args: Any, **kwargs: Any) -> tuple[Any, ...]:
- bound = sig.bind(*args, **kwargs)
- bound.apply_defaults()
- args = bound.args
- args = tuple(touch_arg(args[i]) for i in ordering)
- return args
-
+ # sanity check
for type_method in type_info.methods:
if type_method.name == "__ffi_init__":
break
else:
raise ValueError(f"Cannot find constructor method:
`{type_info.type_key}.__ffi_init__`")
-
- def __init__(self: Any, *args: Any, **kwargs: Any) -> None:
- e = None
- try:
- args = bind_args(*args, **kwargs)
- del kwargs
- self.__ffi_init__(*args)
- except Exception as _e:
- e = TypeError(f"Error in `{signature_str}`:
{_e}").with_traceback(_e.__traceback__)
- if e is not None:
- raise e
- try:
- fn_post_init = self.__post_init__ # type: ignore[attr-defined]
- except AttributeError:
- pass
+ # Step 1. Split args into sections and register default factories
+ args_no_defaults: list[str] = []
+ args_with_defaults: list[str] = []
+ fields_with_defaults: list[tuple[str, bool]] = []
+ ffi_arg_order: list[str] = []
+ module_globals = globals()
+ for field in fields:
+ assert field.name is not None
+ assert field.dataclass_field is not None
+ dataclass_field = field.dataclass_field
+ has_default_factory = (default_factory :=
dataclass_field.default_factory) is not MISSING
+ if dataclass_field.init:
+ ffi_arg_order.append(field.name)
+ if has_default_factory:
+ args_with_defaults.append(field.name)
+ fields_with_defaults.append((field.name, True))
+ module_globals[f"_default_factory_{field.name}"] =
default_factory
+ else:
+ args_no_defaults.append(field.name)
+ elif has_default_factory:
+ ffi_arg_order.append(field.name)
+ fields_with_defaults.append((field.name, False))
+ module_globals[f"_default_factory_{field.name}"] = default_factory
Review Comment:

Storing default factories in `module_globals` (which is `_utils.py`'s global
namespace) is a critical issue. It's not thread-safe, and it will cause name
collisions if two different `@c_class`-decorated classes have a field with the
same name. For example, if `class A` has `field_x = field(default=1)` and
`class B` has `field_x = field(default=2)`, both register their factory under
the same global name `_default_factory_field_x`, so the later registration
overwrites the earlier one and `A.__init__` might end up using `B`'s default.
A safer approach is to create a local dictionary for the factories within
`method_init` and pass it to `exec`'s `globals` argument. For example:
```python
factories = {}
# ... populate factories ...
exec_globals = {**globals(), **factories}
exec(source, exec_globals, namespace)
```
Also, to be completely safe from collisions within the generated source,
factory names should be unique, e.g., by including the class and field name, or
using `id(field)`.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]