junrushao commented on code in PR #8:
URL: https://github.com/apache/tvm-ffi/pull/8#discussion_r2364380545


##########
python/tvm_ffi/dataclasses/c_class.py:
##########
@@ -0,0 +1,115 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""A decorator that builds Python dataclasses from C++ via TVM FFI.
+
+The ``c_class`` decorator reflects fields and methods from the underlying
+FFI type (identified by ``type_key``) and attaches them to the decorated
+Python class.
+"""
+
+from collections.abc import Callable
+from dataclasses import InitVar
+from typing import ClassVar, TypeVar, get_origin, get_type_hints
+
+from ..core import TypeField, TypeInfo
+from . import _utils, field
+
+try:
+    from typing import dataclass_transform
+except ImportError:
+    from typing_extensions import dataclass_transform
+
+
+_InputClsType = TypeVar("_InputClsType")
+
+
+@dataclass_transform(field_specifiers=(field.field, field.Field))
+def c_class(
+    type_key: str, init: bool = True
+) -> Callable[[type[_InputClsType]], type[_InputClsType]]:
+    """Create a decorator that binds a Python dataclass to an FFI type from C++.
+
+    Parameters
+    ----------
+    type_key : str
+        Type key registered in the TVM FFI registry.
+
+    init : bool, default True
+        If True and the class does not implement ``__init__``, generate an
+        ``__init__`` that forwards to the FFI constructor.
+
+    Returns
+    -------
+    Callable[[type], type]
+        A class decorator that returns the finalized proxy class.
+
+    """
+
+    def decorator(super_type_cls: type[_InputClsType]) -> type[_InputClsType]:
+        nonlocal init
+        init = init and "__init__" not in super_type_cls.__dict__
+        # Step 1. Retrieve `type_info` from registry
+        type_info: TypeInfo = _utils._lookup_type_info_from_type_key(type_key)
+        assert type_info.parent_type_info is None, f"Already registered type: {type_key}"
+        type_info.parent_type_info = _utils.get_parent_type_info(super_type_cls)
+        # Step 2. Reflect all the fields of the type
+        type_info.fields = _inspect_c_class_fields(super_type_cls, type_info)
+        for type_field in type_info.fields:
+            _utils.fill_dataclass_field(super_type_cls, type_field)
+        # Step 3. Create the proxy class with the fields as properties
+        fn_init = _utils.method_init(super_type_cls, type_info) if init else None
+        type_cls: type[_InputClsType] = _utils.type_info_to_cls(
+            type_info=type_info,
+            cls=super_type_cls,
+            methods={"__init__": fn_init},
+        )
+        type_info.type_cls = type_cls
+        return type_cls
+
+    return decorator
+
+
+def _inspect_c_class_fields(type_cls: type, type_info: TypeInfo) -> list[TypeField]:
+    type_hints_resolved = get_type_hints(type_cls, include_extras=True)
+    type_hints_py = {
+        name: type_hints_resolved[name]
+        for name in getattr(type_cls, "__annotations__", {}).keys()
+        if get_origin(type_hints_resolved[name])
+        not in [  # ignore non-field annotations
+            ClassVar,
+            InitVar,
+        ]
+    }
+    del type_hints_resolved
+
+    type_fields_cxx: dict[str, TypeField] = {f.name: f for f in type_info.fields}
+    type_fields: list[TypeField] = []
+    for field_name, _field_ty_py in type_hints_py.items():
+        if field_name.startswith("__tvm_ffi"):  # TVM's private fields - skip
+            continue
+        type_field: TypeField = type_fields_cxx.pop(field_name, None)
+        if type_field is None:
+            raise ValueError(
+                f"Extraneous field `{type_cls}.{field_name}`. Defined in Python but not in C"

Review Comment:
   good catch!



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to