This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 3a5bf5e  feat: add kw_only support for dataclass init generation (#384)
3a5bf5e is described below

commit 3a5bf5e68ad1b4108045ef6b336a13efcd2037d9
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Sun Jan 18 15:13:10 2026 +0800

    feat: add kw_only support for dataclass init generation (#384)
    
    ## Related Issue
    
    #356
    
    ## Why
    
    Python's standard dataclasses support a `kw_only` parameter that makes
    fields keyword-only in `__init__`. This feature was missing from the
    `@c_class` decorator.
    
    ## How
    
    - Add a `KW_ONLY` sentinel for marking subsequent fields as keyword-only
    - Add a `kw_only` parameter to the `field()` function and the `@c_class`
      decorator
    - Update `method_init()` to emit keyword-only parameters after a bare `*`
      in the generated `__init__` signature
    - Add tests
---
 python/tvm_ffi/dataclasses/__init__.py   |  4 +--
 python/tvm_ffi/dataclasses/_utils.py     | 58 ++++++++++++++++++++++----------
 python/tvm_ffi/dataclasses/c_class.py    | 47 +++++++++++++++++++++-----
 python/tvm_ffi/dataclasses/field.py      | 35 +++++++++++++++++--
 python/tvm_ffi/testing/__init__.py       |  1 +
 python/tvm_ffi/testing/testing.py        |  8 +++++
 src/ffi/testing/testing.cc               | 20 +++++++++++
 tests/python/test_dataclasses_c_class.py | 55 ++++++++++++++++++++++++++++++
 8 files changed, 199 insertions(+), 29 deletions(-)

diff --git a/python/tvm_ffi/dataclasses/__init__.py 
b/python/tvm_ffi/dataclasses/__init__.py
index 3185413..bfb4404 100644
--- a/python/tvm_ffi/dataclasses/__init__.py
+++ b/python/tvm_ffi/dataclasses/__init__.py
@@ -19,6 +19,6 @@
 from dataclasses import MISSING
 
 from .c_class import c_class
-from .field import Field, field
+from .field import KW_ONLY, Field, field
 
-__all__ = ["MISSING", "Field", "c_class", "field"]
+__all__ = ["KW_ONLY", "MISSING", "Field", "c_class", "field"]
diff --git a/python/tvm_ffi/dataclasses/_utils.py 
b/python/tvm_ffi/dataclasses/_utils.py
index 5ed4e96..7c0afb4 100644
--- a/python/tvm_ffi/dataclasses/_utils.py
+++ b/python/tvm_ffi/dataclasses/_utils.py
@@ -79,7 +79,13 @@ def type_info_to_cls(
     return cast(Type[_InputClsType], new_cls)
 
 
-def fill_dataclass_field(type_cls: type, type_field: TypeField) -> None:
+def fill_dataclass_field(
+    type_cls: type,
+    type_field: TypeField,
+    *,
+    class_kw_only: bool = False,
+    kw_only_from_sentinel: bool = False,
+) -> None:
     from .field import Field, field  # noqa: PLC0415
 
     field_name = type_field.name
@@ -94,6 +100,14 @@ def fill_dataclass_field(type_cls: type, type_field: 
TypeField) -> None:
         raise ValueError(f"Cannot recognize field: {type_field.name}: {rhs}")
     assert isinstance(rhs, Field)
     rhs.name = type_field.name
+
+    # Resolve kw_only: field-level > KW_ONLY sentinel > class-level
+    if rhs.kw_only is MISSING:
+        if kw_only_from_sentinel:
+            rhs.kw_only = True
+        else:
+            rhs.kw_only = class_kw_only
+
     type_field.dataclass_field = rhs
 
 
@@ -148,47 +162,56 @@ def method_repr(type_cls: type, type_info: TypeInfo) -> 
Callable[..., str]:
     return __repr__
 
 
-def method_init(type_cls: type, type_info: TypeInfo) -> Callable[..., None]:
+def method_init(_type_cls: type, type_info: TypeInfo) -> Callable[..., None]:
     """Generate an ``__init__`` that forwards to the FFI constructor.
 
     The generated initializer has a proper Python signature built from the
-    reflected field list, supporting default values and ``__post_init__``.
+    reflected field list, supporting default values, keyword-only args, and 
``__post_init__``.
     """
     # Step 0. Collect all fields from the type hierarchy
     fields = _get_all_fields(type_info)
     # sanity check
-    for type_method in type_info.methods:
-        if type_method.name == "__ffi_init__":
-            break
-    else:
+    if not any(m.name == "__ffi_init__" for m in type_info.methods):
         raise ValueError(f"Cannot find constructor method: 
`{type_info.type_key}.__ffi_init__`")
     # Step 1. Split args into sections and register default factories
-    args_no_defaults: list[str] = []
-    args_with_defaults: list[str] = []
+    pos_no_defaults: list[str] = []
+    pos_with_defaults: list[str] = []
+    kw_no_defaults: list[str] = []
+    kw_with_defaults: list[str] = []
     fields_with_defaults: list[tuple[str, bool]] = []
     ffi_arg_order: list[str] = []
-    exec_globals = {"MISSING": MISSING}
+    exec_globals: dict[str, Any] = {"MISSING": MISSING}
+
     for field in fields:
         assert field.name is not None
         assert field.dataclass_field is not None
         dataclass_field = field.dataclass_field
-        has_default_factory = (default_factory := 
dataclass_field.default_factory) is not MISSING
+        has_default = (default_factory := dataclass_field.default_factory) is 
not MISSING
+        is_kw_only = dataclass_field.kw_only is True
+
         if dataclass_field.init:
             ffi_arg_order.append(field.name)
-            if has_default_factory:
-                args_with_defaults.append(field.name)
+            if has_default:
+                (kw_with_defaults if is_kw_only else 
pos_with_defaults).append(field.name)
                 fields_with_defaults.append((field.name, True))
                 exec_globals[f"_default_factory_{field.name}"] = 
default_factory
             else:
-                args_no_defaults.append(field.name)
-        elif has_default_factory:
+                (kw_no_defaults if is_kw_only else 
pos_no_defaults).append(field.name)
+        elif has_default:
             ffi_arg_order.append(field.name)
             fields_with_defaults.append((field.name, False))
             exec_globals[f"_default_factory_{field.name}"] = default_factory
 
+    # Step 2. Build signature
     args: list[str] = ["self"]
-    args.extend(args_no_defaults)
-    args.extend(f"{name}=MISSING" for name in args_with_defaults)
+    args.extend(pos_no_defaults)
+    args.extend(f"{name}=MISSING" for name in pos_with_defaults)
+    if kw_no_defaults or kw_with_defaults:
+        args.append("*")
+        args.extend(kw_no_defaults)
+        args.extend(f"{name}=MISSING" for name in kw_with_defaults)
+
+    # Step 3. Build body
     body_lines: list[str] = []
     for field_name, is_init in fields_with_defaults:
         if is_init:
@@ -208,6 +231,7 @@ def method_init(type_cls: type, type_info: TypeInfo) -> 
Callable[..., None]:
             "    fn_post_init()",
         ]
     )
+
     source_lines = [f"def __init__({', '.join(args)}):"]
     source_lines.extend(f"    {line}" for line in body_lines)
     source_lines.append("    ...")
diff --git a/python/tvm_ffi/dataclasses/c_class.py 
b/python/tvm_ffi/dataclasses/c_class.py
index 8171b1b..8dd5e5a 100644
--- a/python/tvm_ffi/dataclasses/c_class.py
+++ b/python/tvm_ffi/dataclasses/c_class.py
@@ -34,14 +34,14 @@ from typing_extensions import dataclass_transform
 
 from ..core import TypeField, TypeInfo, 
_lookup_or_register_type_info_from_type_key, _set_type_cls
 from . import _utils
-from .field import field
+from .field import KW_ONLY, field
 
 _InputClsType = TypeVar("_InputClsType")
 
 
-@dataclass_transform(field_specifiers=(field,))
+@dataclass_transform(field_specifiers=(field,), kw_only_default=False)
 def c_class(
-    type_key: str, init: bool = True, repr: bool = True
+    type_key: str, init: bool = True, kw_only: bool = False, repr: bool = True
 ) -> Callable[[Type[_InputClsType]], Type[_InputClsType]]:  # noqa: UP006
     """(Experimental) Create a dataclass-like proxy for a C++ class registered 
with TVM FFI.
 
@@ -71,6 +71,12 @@ def c_class(
         signature.  The generated initializer calls the C++ ``__init__``
         function registered with ``ObjectDef`` and invokes ``__post_init__`` if
         it exists on the Python class.
+
+    kw_only
+        If ``True``, all fields become keyword-only parameters in the generated
+        ``__init__``. Individual fields can override this by setting
+        ``kw_only=False`` in :func:`field`. Additionally, a ``KW_ONLY`` 
sentinel
+        annotation can be used to mark all subsequent fields as keyword-only.
     repr
         If ``True`` and the Python class does not define ``__repr__``, a
         representation method is auto-generated that includes all fields with
@@ -129,9 +135,15 @@ def c_class(
         type_info: TypeInfo = 
_lookup_or_register_type_info_from_type_key(type_key)
         assert type_info.parent_type_info is not None
         # Step 2. Reflect all the fields of the type
-        type_info.fields = _inspect_c_class_fields(super_type_cls, type_info)
-        for type_field in type_info.fields:
-            _utils.fill_dataclass_field(super_type_cls, type_field)
+        type_info.fields, kw_only_start_idx = 
_inspect_c_class_fields(super_type_cls, type_info)
+        for idx, type_field in enumerate(type_info.fields):
+            kw_only_from_sentinel = kw_only_start_idx is not None and idx >= 
kw_only_start_idx
+            _utils.fill_dataclass_field(
+                super_type_cls,
+                type_field,
+                class_kw_only=kw_only,
+                kw_only_from_sentinel=kw_only_from_sentinel,
+            )
         # Step 3. Create the proxy class with the fields as properties
         fn_init = _utils.method_init(super_type_cls, type_info) if init else 
None
         fn_repr = _utils.method_repr(super_type_cls, type_info) if repr else 
None
@@ -146,7 +158,9 @@ def c_class(
     return decorator
 
 
-def _inspect_c_class_fields(type_cls: type, type_info: TypeInfo) -> 
list[TypeField]:
+def _inspect_c_class_fields(
+    type_cls: type, type_info: TypeInfo
+) -> tuple[list[TypeField], int | None]:
     if sys.version_info >= (3, 9):
         type_hints_resolved = get_type_hints(type_cls, include_extras=True)
     else:
@@ -159,7 +173,24 @@ def _inspect_c_class_fields(type_cls: type, type_info: 
TypeInfo) -> list[TypeFie
             ClassVar,
             InitVar,
         ]
+        and type_hints_resolved[name] is not KW_ONLY
     }
+
+    # Detect KW_ONLY sentinel position
+    kw_only_start_idx: int | None = None
+    field_count = 0
+    for name in getattr(type_cls, "__annotations__", {}).keys():
+        resolved_type = type_hints_resolved.get(name)
+        if resolved_type is None:
+            continue
+        if get_origin(resolved_type) in [ClassVar, InitVar]:
+            continue
+        if resolved_type is KW_ONLY:
+            if kw_only_start_idx is not None:
+                raise ValueError(f"KW_ONLY may only be used once per class: 
{type_cls}")
+            kw_only_start_idx = field_count
+            continue
+        field_count += 1
     del type_hints_resolved
 
     type_fields_cxx: dict[str, TypeField] = {f.name: f for f in 
type_info.fields}
@@ -178,4 +209,4 @@ def _inspect_c_class_fields(type_cls: type, type_info: 
TypeInfo) -> list[TypeFie
         raise ValueError(
             f"Missing fields in `{type_cls}`: {extra_fields}. Defined in C++ 
but not in Python"
         )
-    return type_fields
+    return type_fields, kw_only_start_idx
diff --git a/python/tvm_ffi/dataclasses/field.py 
b/python/tvm_ffi/dataclasses/field.py
index d0e27b1..a395e50 100644
--- a/python/tvm_ffi/dataclasses/field.py
+++ b/python/tvm_ffi/dataclasses/field.py
@@ -21,7 +21,17 @@ from __future__ import annotations
 from dataclasses import _MISSING_TYPE, MISSING
 from typing import Any, Callable, TypeVar, cast
 
+try:
+    from dataclasses import KW_ONLY  # type: ignore[attr-defined]
+except ImportError:
+    # Python < 3.10: define our own KW_ONLY sentinel
+    class _KW_ONLY_Sentinel:
+        __slots__ = ()
+
+    KW_ONLY = _KW_ONLY_Sentinel()
+
 _FieldValue = TypeVar("_FieldValue")
+_KW_ONLY_TYPE = type(KW_ONLY)
 
 
 class Field:
@@ -37,7 +47,7 @@ class Field:
     way the decorator understands.
     """
 
-    __slots__ = ("default_factory", "init", "name", "repr")
+    __slots__ = ("default_factory", "init", "kw_only", "name", "repr")
 
     def __init__(
         self,
@@ -46,12 +56,14 @@ class Field:
         default_factory: Callable[[], _FieldValue] | _MISSING_TYPE = MISSING,
         init: bool = True,
         repr: bool = True,
+        kw_only: bool | _MISSING_TYPE = MISSING,
     ) -> None:
         """Do not call directly; use :func:`field` instead."""
         self.name = name
         self.default_factory = default_factory
         self.init = init
         self.repr = repr
+        self.kw_only = kw_only
 
 
 def field(
@@ -60,6 +72,7 @@ def field(
     default_factory: Callable[[], _FieldValue] | _MISSING_TYPE = MISSING,  # 
type: ignore[assignment]
     init: bool = True,
     repr: bool = True,
+    kw_only: bool | _MISSING_TYPE = MISSING,  # type: ignore[assignment]
 ) -> _FieldValue:
     """(Experimental) Declare a dataclass-style field on a :func:`c_class` 
proxy.
 
@@ -84,6 +97,10 @@ def field(
     repr
         If ``True`` the field is included in the generated ``__repr__``.
         If ``False`` the field is omitted from the ``__repr__`` output.
+    kw_only
+        If ``True``, the field is a keyword-only argument in ``__init__``.
+        If ``MISSING``, inherits from the class-level ``kw_only`` setting or
+        from a preceding ``KW_ONLY`` sentinel annotation.
 
     Note
     ----
@@ -124,6 +141,18 @@ def field(
         obj = PyBase(v_i64=4)
         obj.v_i32  # -> 16
 
+    Use ``kw_only=True`` to make a field keyword-only:
+
+    .. code-block:: python
+
+        @c_class("testing.TestCxxClassBase")
+        class PyBase:
+            v_i64: int
+            v_i32: int = field(kw_only=True)
+
+
+        obj = PyBase(4, v_i32=8)  # v_i32 must be keyword
+
     """
     if default is not MISSING and default_factory is not MISSING:
         raise ValueError("Cannot specify both `default` and `default_factory`")
@@ -131,9 +160,11 @@ def field(
         raise TypeError("`init` must be a bool")
     if not isinstance(repr, bool):
         raise TypeError("`repr` must be a bool")
+    if kw_only is not MISSING and not isinstance(kw_only, bool):
+        raise TypeError(f"`kw_only` must be a bool, got 
{type(kw_only).__name__!r}")
     if default is not MISSING:
         default_factory = _make_default_factory(default)
-    ret = Field(default_factory=default_factory, init=init, repr=repr)
+    ret = Field(default_factory=default_factory, init=init, repr=repr, 
kw_only=kw_only)
     return cast(_FieldValue, ret)
 
 
diff --git a/python/tvm_ffi/testing/__init__.py 
b/python/tvm_ffi/testing/__init__.py
index cd35736..af22210 100644
--- a/python/tvm_ffi/testing/__init__.py
+++ b/python/tvm_ffi/testing/__init__.py
@@ -25,6 +25,7 @@ from .testing import (
     _TestCxxClassDerived,
     _TestCxxClassDerivedDerived,
     _TestCxxInitSubset,
+    _TestCxxKwOnly,
     add_one,
     create_object,
     make_unregistered_object,
diff --git a/python/tvm_ffi/testing/testing.py 
b/python/tvm_ffi/testing/testing.py
index b905b5b..0ffeb49 100644
--- a/python/tvm_ffi/testing/testing.py
+++ b/python/tvm_ffi/testing/testing.py
@@ -178,3 +178,11 @@ class _TestCxxInitSubset:
     required_field: int
     optional_field: int = field(init=False)
     note: str = field(default_factory=lambda: "py-default", init=False)
+
+
+@c_class("testing.TestCxxKwOnly", kw_only=True)
+class _TestCxxKwOnly:
+    x: int
+    y: int
+    z: int
+    w: int = 100
diff --git a/src/ffi/testing/testing.cc b/src/ffi/testing/testing.cc
index 7ee6ffd..0df7f1e 100644
--- a/src/ffi/testing/testing.cc
+++ b/src/ffi/testing/testing.cc
@@ -147,6 +147,19 @@ class TestCxxInitSubsetObj : public Object {
   TVM_FFI_DECLARE_OBJECT_INFO("testing.TestCxxInitSubset", 
TestCxxInitSubsetObj, Object);
 };
 
+class TestCxxKwOnly : public Object {
+ public:
+  int64_t x;
+  int64_t y;
+  int64_t z;
+  int64_t w;
+
+  TestCxxKwOnly(int64_t x, int64_t y, int64_t z, int64_t w) : x(x), y(y), 
z(z), w(w) {}
+
+  static constexpr bool _type_mutable = true;
+  TVM_FFI_DECLARE_OBJECT_INFO("testing.TestCxxKwOnly", TestCxxKwOnly, Object);
+};
+
 class TestUnregisteredBaseObject : public Object {
  public:
   int64_t v1;
@@ -229,6 +242,13 @@ TVM_FFI_STATIC_INIT_BLOCK() {
       .def_rw("optional_field", &TestCxxInitSubsetObj::optional_field)
       .def_rw("note", &TestCxxInitSubsetObj::note);
 
+  refl::ObjectDef<TestCxxKwOnly>()
+      .def(refl::init<int64_t, int64_t, int64_t, int64_t>())
+      .def_rw("x", &TestCxxKwOnly::x)
+      .def_rw("y", &TestCxxKwOnly::y)
+      .def_rw("z", &TestCxxKwOnly::z)
+      .def_rw("w", &TestCxxKwOnly::w);
+
   refl::ObjectDef<TestUnregisteredBaseObject>()
       .def(refl::init<int64_t>(), "Constructor of TestUnregisteredBaseObject")
       .def_ro("v1", &TestUnregisteredBaseObject::v1)
diff --git a/tests/python/test_dataclasses_c_class.py 
b/tests/python/test_dataclasses_c_class.py
index 5361cb6..3a757d0 100644
--- a/tests/python/test_dataclasses_c_class.py
+++ b/tests/python/test_dataclasses_c_class.py
@@ -15,12 +15,17 @@
 # specific language governing permissions and limitations
 # under the License.
 import inspect
+from dataclasses import MISSING
 
+import pytest
+from tvm_ffi.dataclasses import KW_ONLY, field
+from tvm_ffi.dataclasses.field import _KW_ONLY_TYPE, Field
 from tvm_ffi.testing import (
     _TestCxxClassBase,
     _TestCxxClassDerived,
     _TestCxxClassDerivedDerived,
     _TestCxxInitSubset,
+    _TestCxxKwOnly,
 )
 
 
@@ -129,3 +134,53 @@ def test_cxx_class_repr_derived_derived() -> None:
         assert "v_i32=456" in repr_str
         assert "v_str='hello'" in repr_str or 'v_str="hello"' in repr_str
         assert "v_bool=True" in repr_str
+
+
+def test_kw_only_class_level_signature() -> None:
+    sig = inspect.signature(_TestCxxKwOnly.__init__)
+    params = sig.parameters
+    assert params["x"].kind == inspect.Parameter.KEYWORD_ONLY
+    assert params["y"].kind == inspect.Parameter.KEYWORD_ONLY
+    assert params["z"].kind == inspect.Parameter.KEYWORD_ONLY
+    assert params["w"].kind == inspect.Parameter.KEYWORD_ONLY
+
+
+def test_kw_only_class_level_call() -> None:
+    obj = _TestCxxKwOnly(x=1, y=2, z=3, w=4)
+    assert obj.x == 1
+    assert obj.y == 2
+    assert obj.z == 3
+    assert obj.w == 4
+
+
+def test_kw_only_class_level_with_default() -> None:
+    obj = _TestCxxKwOnly(x=1, y=2, z=3)
+    assert obj.w == 100
+
+
+def test_kw_only_class_level_rejects_positional() -> None:
+    with pytest.raises(TypeError, match="positional"):
+        _TestCxxKwOnly(1, 2, 3, 4)  # type: ignore[misc]
+
+
+def test_field_kw_only_parameter() -> None:
+    f1: Field = field(kw_only=True)
+    assert isinstance(f1, Field)
+    assert f1.kw_only is True
+
+    f2: Field = field(kw_only=False)
+    assert f2.kw_only is False
+
+    f3: Field = field()
+    assert f3.kw_only is MISSING
+
+
+def test_field_kw_only_with_default() -> None:
+    f = field(default=42, kw_only=True)
+    assert isinstance(f, Field)
+    assert f.kw_only is True
+    assert f.default_factory() == 42
+
+
+def test_kw_only_sentinel_exists() -> None:
+    assert isinstance(KW_ONLY, _KW_ONLY_TYPE)

Reply via email to