This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new a3ee59253e [FFI][REFACTOR] Phase out getattr based attribute handling 
(#18189)
a3ee59253e is described below

commit a3ee59253e714d8b1cf95d6c8e0be94391d3d22c
Author: Tianqi Chen <[email protected]>
AuthorDate: Wed Aug 6 15:40:44 2025 -0400

    [FFI][REFACTOR] Phase out getattr based attribute handling (#18189)
    
    [REFACTOR] Phase out getattr based attribute handling
    
    This PR phases out getattr-based attribute handling, as it is slower
    and introduces an extra code path.
    
    This does mean that if an Object is not explicitly registered
    on the Python side, we will no longer be able to access its fields by name.
    This is likely also desirable, as we would like to enable faster usage that
    updates the Python end and does not rely on this behavior.
---
 docs/reference/api/python/relax/op.rst             |   1 +
 docs/reference/api/python/tir/transform.rst        |   1 +
 ffi/src/ffi/extra/serialization.cc                 |  13 +-
 include/tvm/relax/attrs/op.h                       |  30 ++--
 include/tvm/script/printer/doc.h                   |   5 +-
 python/tvm/arith/iter_affine_map.py                |   6 +
 python/tvm/contrib/msc/core/ir/graph.py            |   4 +-
 python/tvm/ffi/__init__.py                         |   1 +
 python/tvm/ffi/cython/function.pxi                 |   2 +
 python/tvm/ffi/cython/object.pxi                   |  72 ++++------
 python/tvm/ffi/serialization.py                    |  67 +++++++++
 python/tvm/ir/attrs.py                             |  14 +-
 python/tvm/relax/dpl/pattern.py                    |   2 +-
 python/tvm/relax/expr.py                           |   5 +
 python/tvm/relax/op/_op_gradient.py                |   4 +-
 python/tvm/relax/op/manipulate.py                  |   1 +
 python/tvm/relax/op/op_attrs.py                    | 155 +++++++++++++++++++++
 python/tvm/runtime/_ffi_node_api.py                |   8 --
 python/tvm/runtime/object.py                       |  11 --
 python/tvm/script/printer/doc.py                   |  35 +----
 python/tvm/te/tensor.py                            |  20 ---
 python/tvm/testing/__init__.py                     |   1 +
 python/tvm/testing/{__init__.py => attrs.py}       |  39 ++----
 python/tvm/tir/transform/transform.py              |  42 ++++++
 src/node/reflection.cc                             |  74 +---------
 src/relax/ir/emit_te.h                             |   2 +-
 src/tir/transforms/hoist_expression.cc             |   4 +-
 tests/python/ffi/test_container.py                 |   7 +
 tests/python/ir/test_ir_attrs.py                   |   4 +-
 .../test_transform_legalize_ops_manipulate.py      |   7 +-
 tests/python/runtime/test_runtime_rpc.py           |   1 -
 31 files changed, 389 insertions(+), 249 deletions(-)

diff --git a/docs/reference/api/python/relax/op.rst 
b/docs/reference/api/python/relax/op.rst
index 21f638442a..922af768f5 100644
--- a/docs/reference/api/python/relax/op.rst
+++ b/docs/reference/api/python/relax/op.rst
@@ -70,3 +70,4 @@ tvm.relax.op.op_attrs
 *********************
 .. automodule:: tvm.relax.op.op_attrs
    :members:
+   :exclude-members: Attrs
diff --git a/docs/reference/api/python/tir/transform.rst 
b/docs/reference/api/python/tir/transform.rst
index 8ce641b6d3..29f1bcbbf0 100644
--- a/docs/reference/api/python/tir/transform.rst
+++ b/docs/reference/api/python/tir/transform.rst
@@ -20,4 +20,5 @@ tvm.tir.transform
 -----------------
 .. automodule:: tvm.tir.transform
    :members:
+   :exclude-members: Attrs
    :imported-members:
diff --git a/ffi/src/ffi/extra/serialization.cc 
b/ffi/src/ffi/extra/serialization.cc
index 8d9df03361..ea9a96b696 100644
--- a/ffi/src/ffi/extra/serialization.cc
+++ b/ffi/src/ffi/extra/serialization.cc
@@ -408,9 +408,20 @@ class ObjectGraphDeserializer {
 
 Any FromJSONGraph(const json::Value& value) { return 
ObjectGraphDeserializer::Deserialize(value); }
 
+// string version of the api
+Any FromJSONGraphString(const String& value) { return 
FromJSONGraph(json::Parse(value)); }
+
+String ToJSONGraphString(const Any& value, const Any& metadata) {
+  return json::Stringify(ToJSONGraph(value, metadata));
+}
+
 TVM_FFI_STATIC_INIT_BLOCK({
   namespace refl = tvm::ffi::reflection;
-  refl::GlobalDef().def("ffi.ToJSONGraph", 
ToJSONGraph).def("ffi.FromJSONGraph", FromJSONGraph);
+  refl::GlobalDef()
+      .def("ffi.ToJSONGraph", ToJSONGraph)
+      .def("ffi.ToJSONGraphString", ToJSONGraphString)
+      .def("ffi.FromJSONGraph", FromJSONGraph)
+      .def("ffi.FromJSONGraphString", FromJSONGraphString);
   refl::EnsureTypeAttrColumn("__data_to_json__");
   refl::EnsureTypeAttrColumn("__data_from_json__");
 });
diff --git a/include/tvm/relax/attrs/op.h b/include/tvm/relax/attrs/op.h
index cce78e9fd6..337f8dc4cb 100644
--- a/include/tvm/relax/attrs/op.h
+++ b/include/tvm/relax/attrs/op.h
@@ -51,16 +51,19 @@ struct CallTIRWithGradAttrs : public 
AttrsNodeReflAdapter<CallTIRWithGradAttrs>
 
 /*! \brief Attributes used in call_tir_inplace */
 struct CallTIRInplaceAttrs : public AttrsNodeReflAdapter<CallTIRInplaceAttrs> {
+  /*!
+   * \brief Indices that describe which input corresponds to which output.
+   *
+   * If the `i`th member has the value `k` >= 0, then that means that input 
`k` should be used to
+   * store the `i`th output. If an element has the value -1, that means a new 
tensor should be
+   * allocated for that output.
+   */
   Array<Integer> inplace_indices;
 
   static void RegisterReflection() {
     namespace refl = tvm::ffi::reflection;
-    refl::ObjectDef<CallTIRInplaceAttrs>().def_ro(
-        "inplace_indices", &CallTIRInplaceAttrs::inplace_indices,
-        "Indices that describe which input corresponds to which output. If the 
`i`th member "
-        "has the value `k` >= 0, then that means that input `k` should be used 
to store the "
-        "`i`th output. If an element has the value -1, that means a new tensor 
should be "
-        "allocated for that output.");
+    refl::ObjectDef<CallTIRInplaceAttrs>().def_ro("inplace_indices",
+                                                  
&CallTIRInplaceAttrs::inplace_indices);
   }
 
   static constexpr const char* _type_key = "relax.attrs.CallTIRInplaceAttrs";
@@ -69,16 +72,19 @@ struct CallTIRInplaceAttrs : public 
AttrsNodeReflAdapter<CallTIRInplaceAttrs> {
 
 /*! \brief Attributes used in call_inplace_packed */
 struct CallInplacePackedAttrs : public 
AttrsNodeReflAdapter<CallInplacePackedAttrs> {
+  /*!
+   * \brief Indices that describe which input corresponds to which output.
+   *
+   * If the `i`th member has the value `k` >= 0, then that means that input 
`k` should be used to
+   * store the `i`th output. If an element has the value -1, that means the 
output will be newly
+   * allocated.
+   */
   Array<Integer> inplace_indices;
 
   static void RegisterReflection() {
     namespace refl = tvm::ffi::reflection;
-    refl::ObjectDef<CallInplacePackedAttrs>().def_ro(
-        "inplace_indices", &CallInplacePackedAttrs::inplace_indices,
-        "Indices that describe which input corresponds to which output. If the 
`i`th member "
-        "has the value `k` >= 0, then that means that input `k` should be used 
to store the "
-        "`i`th output. If an element has the value -1, that means the output 
will be newly "
-        "allocated.");
+    refl::ObjectDef<CallInplacePackedAttrs>().def_ro("inplace_indices",
+                                                     
&CallInplacePackedAttrs::inplace_indices);
   }
 
   static constexpr const char* _type_key = 
"relax.attrs.CallInplacePackedAttrs";
diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h
index b19bcab4c3..de3fb0bbad 100644
--- a/include/tvm/script/printer/doc.h
+++ b/include/tvm/script/printer/doc.h
@@ -65,10 +65,11 @@ class DocNode : public Object {
 
   static void RegisterReflection() {
     namespace refl = tvm::ffi::reflection;
-    refl::ObjectDef<DocNode>().def_ro("source_paths", &DocNode::source_paths);
+    refl::ObjectDef<DocNode>().def_rw("source_paths", &DocNode::source_paths);
   }
 
   static constexpr const char* _type_key = "script.printer.Doc";
+  static constexpr bool _type_mutable = true;
 
   TVM_DECLARE_BASE_OBJECT_INFO(DocNode, Object);
 
@@ -174,7 +175,7 @@ class StmtDocNode : public DocNode {
 
   static void RegisterReflection() {
     namespace refl = tvm::ffi::reflection;
-    refl::ObjectDef<StmtDocNode>().def_ro("comment", &StmtDocNode::comment);
+    refl::ObjectDef<StmtDocNode>().def_rw("comment", &StmtDocNode::comment);
   }
 
   static constexpr const char* _type_key = "script.printer.StmtDoc";
diff --git a/python/tvm/arith/iter_affine_map.py 
b/python/tvm/arith/iter_affine_map.py
index dbb4087f32..328bb052b8 100644
--- a/python/tvm/arith/iter_affine_map.py
+++ b/python/tvm/arith/iter_affine_map.py
@@ -22,6 +22,7 @@ from tvm.ir import PrimExpr
 from . import _ffi_api
 
 
+@tvm.ffi.register_object("arith.IterMapExpr")
 class IterMapExpr(PrimExpr):
     """Base class of all IterMap expressions."""
 
@@ -89,6 +90,11 @@ class IterSumExpr(IterMapExpr):
         self.__init_handle_by_constructor__(_ffi_api.IterSumExpr, args, base)
 
 
+@tvm.ffi.register_object("arith.IterMapResult")
+class IterMapResult(Object):
+    """Result of iter map detection."""
+
+
 class IterMapLevel(IntEnum):
     """Possible kinds of iter mapping check level."""
 
diff --git a/python/tvm/contrib/msc/core/ir/graph.py 
b/python/tvm/contrib/msc/core/ir/graph.py
index 9aa5bde933..7bd88df5f6 100644
--- a/python/tvm/contrib/msc/core/ir/graph.py
+++ b/python/tvm/contrib/msc/core/ir/graph.py
@@ -194,6 +194,7 @@ class MSCTensor(Object):
         return len(self.shape)
 
 
+@tvm.ffi.register_object("msc.core.BaseJoint")
 class BaseJoint(Object):
     """Base class of all MSC Nodes."""
 
@@ -561,6 +562,7 @@ class WeightJoint(BaseJoint):
         return bool(_ffi_api.WeightJointHasAttr(self, key))
 
 
+@tvm.ffi.register_object("msc.core.BaseGraph")
 class BaseGraph(Object):
     """Base class of all MSC Graphs."""
 
@@ -955,7 +957,7 @@ class MSCGraph(BaseGraph):
 
 
 @tvm.ffi.register_object("msc.core.WeightGraph")
-class WeightGraph(Object):
+class WeightGraph(BaseGraph):
     """The WeightGraph
 
     Parameters
diff --git a/python/tvm/ffi/__init__.py b/python/tvm/ffi/__init__.py
index b507064e34..43a20e751c 100644
--- a/python/tvm/ffi/__init__.py
+++ b/python/tvm/ffi/__init__.py
@@ -30,6 +30,7 @@ from .ndarray import Device, device
 from .ndarray import cpu, cuda, rocm, opencl, metal, vpi, vulkan, ext_dev, 
hexagon, webgpu
 from .ndarray import from_dlpack, NDArray, Shape
 from .container import Array, Map
+from . import serialization
 from . import testing
 
 
diff --git a/python/tvm/ffi/cython/function.pxi 
b/python/tvm/ffi/cython/function.pxi
index cbff3fecf1..8c9df19642 100644
--- a/python/tvm/ffi/cython/function.pxi
+++ b/python/tvm/ffi/cython/function.pxi
@@ -426,3 +426,5 @@ def _convert_to_ffi_func(object pyfunc):
 
 _STR_CONSTRUCTOR = _get_global_func("ffi.String", False)
 _BYTES_CONSTRUCTOR = _get_global_func("ffi.Bytes", False)
+_OBJECT_FROM_JSON_GRAPH_STR = _get_global_func("ffi.FromJSONGraphString", True)
+_OBJECT_TO_JSON_GRAPH_STR = _get_global_func("ffi.ToJSONGraphString", True)
diff --git a/python/tvm/ffi/cython/object.pxi b/python/tvm/ffi/cython/object.pxi
index 4efedf35d8..7df5f7a19a 100644
--- a/python/tvm/ffi/cython/object.pxi
+++ b/python/tvm/ffi/cython/object.pxi
@@ -14,10 +14,12 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import warnings
 
 _CLASS_OBJECT = None
 _FUNC_CONVERT_TO_OBJECT = None
 
+
 def _set_class_object(cls):
     global _CLASS_OBJECT
     _CLASS_OBJECT = cls
@@ -32,31 +34,15 @@ def __object_repr__(obj):
     return type(obj).__name__ + "(" + obj.__ctypes_handle__().value + ")"
 
 
-def __object_save_json__(obj):
-    """Object repr function that can be overridden by assigning to it"""
-    raise NotImplementedError("JSON serialization depends on downstream init")
-
-
-def __object_load_json__(json_str):
-    """Object repr function that can be overridden by assigning to it"""
-    raise NotImplementedError("JSON serialization depends on downstream init")
-
-
-def __object_dir__(obj):
-    """Object dir function that can be overridden by assigning to it"""
-    return []
-
-
-def __object_getattr__(obj, name):
-    """Object getattr function that can be overridden by assigning to it"""
-    raise AttributeError()
-
-
 def _new_object(cls):
     """Helper function for pickle"""
     return cls.__new__(cls)
 
 
+_OBJECT_FROM_JSON_GRAPH_STR = None
+_OBJECT_TO_JSON_GRAPH_STR = None
+
+
 class ObjectGeneric:
     """Base class for all classes that can be converted to object."""
 
@@ -107,34 +93,24 @@ cdef class Object:
         return (_new_object, (cls,), self.__getstate__())
 
     def __getstate__(self):
+        if _OBJECT_TO_JSON_GRAPH_STR is None:
+            raise RuntimeError("ffi.ToJSONGraphString is not registered, make 
sure build project with extra API")
         if not self.__chandle__() == 0:
             # need to explicit convert to str in case String
             # returned and triggered another infinite recursion in get state
-            return {"handle": str(__object_save_json__(self))}
+            return {"handle": str(_OBJECT_TO_JSON_GRAPH_STR(self, None))}
         return {"handle": None}
 
     def __setstate__(self, state):
         # pylint: disable=assigning-non-slot, assignment-from-no-return
+        if _OBJECT_FROM_JSON_GRAPH_STR is None:
+            raise RuntimeError("ffi.FromJSONGraphString is not registered, 
make sure build project with extra API")
         handle = state["handle"]
         if handle is not None:
-            self.__init_handle_by_constructor__(__object_load_json__, handle)
+            self.__init_handle_by_constructor__(_OBJECT_FROM_JSON_GRAPH_STR, 
handle)
         else:
             self.chandle = NULL
 
-    def __getattr__(self, name):
-        if self.chandle == NULL:
-            raise AttributeError(f"{type(self)} has no attribute {name}")
-        try:
-            return __object_getattr__(self, name)
-        except AttributeError:
-            raise AttributeError(f"{type(self)} has no attribute {name}")
-
-    def __dir__(self):
-        # exception safety handling for chandle=None
-        if self.chandle == NULL:
-            return []
-        return __object_dir__(self)
-
     def __repr__(self):
         # exception safety handling for chandle=None
         if self.chandle == NULL:
@@ -147,9 +123,6 @@ cdef class Object:
     def __ne__(self, other):
         return not self.__eq__(other)
 
-    def __init_handle_by_load_json__(self, json_str):
-        raise NotImplementedError("JSON serialization depends on downstream 
init")
-
     def __init_handle_by_constructor__(self, fconstructor, *args):
         """Initialize the handle by calling constructor function.
 
@@ -269,6 +242,15 @@ def _object_type_key_to_index(str type_key):
         return tidx
     return None
 
+cdef inline str _type_index_to_key(int32_t tindex):
+    """get the type key of object class"""
+    cdef const TVMFFITypeInfo* info = TVMFFIGetTypeInfo(tindex)
+    cdef const TVMFFIByteArray* type_key
+    if info == NULL:
+        return "<unknown>"
+    type_key = &(info.type_key)
+    return py_str(PyBytes_FromStringAndSize(type_key.data, type_key.size))
+
 
 cdef inline object make_ret_object(TVMFFIAny result):
     global OBJECT_TYPE
@@ -284,10 +266,14 @@ cdef inline object make_ret_object(TVMFFIAny result):
                 (<Object>obj).chandle = result.v_obj
                 return cls.__from_tvm_ffi_object__(cls, obj)
             obj = cls.__new__(cls)
-        else:
-            obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
-    else:
-        obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
+            (<Object>obj).chandle = result.v_obj
+            return obj
+
+    # object is not found in registered entry
+    # in this case we need to report an warning
+    type_key = _type_index_to_key(tindex)
+    warnings.warn(f"Returning type `{type_key}` which is not registered via 
register_object, fallback to Object")
+    obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
     (<Object>obj).chandle = result.v_obj
     return obj
 
diff --git a/python/tvm/ffi/serialization.py b/python/tvm/ffi/serialization.py
new file mode 100644
index 0000000000..25d9bcefb8
--- /dev/null
+++ b/python/tvm/ffi/serialization.py
@@ -0,0 +1,67 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Serialization related utilities to enable some object can be pickled"""
+
+from typing import Optional, Any
+from . import _ffi_api
+
+
+def to_json_graph_str(obj: Any, metadata: Optional[dict] = None):
+    """
+    Dump an object to a JSON graph string.
+
+    The JSON graph string is a string representation of of the object
+    graph includes the reference information of same objects, which can
+    be used for serialization and debugging.
+
+    Parameters
+    ----------
+    obj : Any
+        The object to save.
+
+    metadata : Optional[dict], optional
+        Extra metadata to save into the json graph string.
+
+    Returns
+    -------
+    json_str : str
+        The JSON graph string.
+    """
+    return _ffi_api.ToJSONGraphString(obj, metadata)
+
+
+def from_json_graph_str(json_str: str):
+    """
+    Load an object from a JSON graph string.
+
+    The JSON graph string is a string representation of of the object
+    graph that also includes the reference information.
+
+    Parameters
+    ----------
+    json_str : str
+        The JSON graph string to load.
+
+    Returns
+    -------
+    obj : Any
+        The loaded object.
+    """
+    return _ffi_api.FromJSONGraphString(json_str)
+
+
+__all__ = ["from_json_graph_str", "to_json_graph_str"]
diff --git a/python/tvm/ir/attrs.py b/python/tvm/ir/attrs.py
index e7de1a9f90..cab982f4e7 100644
--- a/python/tvm/ir/attrs.py
+++ b/python/tvm/ir/attrs.py
@@ -41,7 +41,7 @@ class Attrs(Object):
         -------
         value: Tuple of int
         """
-        return tuple(x if isinstance(x, int) else x.value for x in 
self.__getattr__(key))
+        return tuple(x if isinstance(x, int) else x.value for x in 
getattr(self, key))
 
     def get_int(self, key):
         """Get a python int value of a key
@@ -54,7 +54,7 @@ class Attrs(Object):
         -------
         value: int
         """
-        return self.__getattr__(key)
+        return getattr(self, key)
 
     def get_str(self, key):
         """Get a python int value of a key
@@ -67,10 +67,10 @@ class Attrs(Object):
         -------
         value: int
         """
-        return self.__getattr__(key)
+        return getattr(self, key)
 
     def __getitem__(self, item):
-        return self.__getattr__(item)
+        return getattr(self, item)
 
 
 @tvm.ffi.register_object("ir.DictAttrs")
@@ -101,6 +101,12 @@ class DictAttrs(Attrs):
     def __contains__(self, k):
         return self._dict().__contains__(k)
 
+    def __getattr__(self, name):
+        try:
+            return self._dict().__getitem__(name)
+        except KeyError:
+            raise AttributeError(f"DictAttrs has no attribute {name}")
+
     def items(self):
         """Get items from the map."""
         return self._dict().items()
diff --git a/python/tvm/relax/dpl/pattern.py b/python/tvm/relax/dpl/pattern.py
index 633c2c6790..eca885e03a 100644
--- a/python/tvm/relax/dpl/pattern.py
+++ b/python/tvm/relax/dpl/pattern.py
@@ -326,7 +326,7 @@ class VarPattern(DFPattern):
 
 
 @register_df_node
-class DataflowVarPattern(DFPattern):
+class DataflowVarPattern(VarPattern):
     """A pattern for DataflowVar.
 
     Parameters
diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py
index 9ddaf52e72..ee9caf3a83 100644
--- a/python/tvm/relax/expr.py
+++ b/python/tvm/relax/expr.py
@@ -1177,6 +1177,11 @@ def const(
     return Constant(value)
 
 
+@tvm.ffi.register_object("relax.TEPlaceholderOp")
+class TEPlaceholderOp(tvm.te.tensor.Operation):
+    """The placeholder op that represents a relax expression."""
+
+
 def te_tensor(
     value: Expr, tir_var_map: Dict[tvm.tir.Var, tvm.tir.PrimExpr], name: str = 
"rxplaceholder"
 ):
diff --git a/python/tvm/relax/op/_op_gradient.py 
b/python/tvm/relax/op/_op_gradient.py
index 41eaa5de50..fd80f1e313 100644
--- a/python/tvm/relax/op/_op_gradient.py
+++ b/python/tvm/relax/op/_op_gradient.py
@@ -829,8 +829,8 @@ def cumsum_grad(
         The "reversed" cumsum along the same axis. Implemented by some tricks 
now.
     """
 
-    axis = orig_call.attrs["axis"]
-    dtype = orig_call.attrs["dtype"]
+    axis = orig_call.attrs.axis
+    dtype = orig_call.attrs.dtype
     x_shape = _get_shape(orig_call.args[0])
 
     if axis is not None:
diff --git a/python/tvm/relax/op/manipulate.py 
b/python/tvm/relax/op/manipulate.py
index 864eb3fec7..bb134f1148 100644
--- a/python/tvm/relax/op/manipulate.py
+++ b/python/tvm/relax/op/manipulate.py
@@ -624,6 +624,7 @@ def index_put(
     Examples
     --------
     .. code-block:: python
+
         # inputs
         data = torch.zeros(3, 3)
         indices = (torch.tensor([0, 2]), torch.tensor([1, 1]))
diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py
index 3e0f87c487..9c15cdd966 100644
--- a/python/tvm/relax/op/op_attrs.py
+++ b/python/tvm/relax/op/op_attrs.py
@@ -202,3 +202,158 @@ class FlipAttrs(Attrs):
 @tvm.ffi.register_object("relax.attrs.PadAttrs")
 class PadAttrs(Attrs):
     """Attributes used in pad operator"""
+
+
+@tvm.ffi.register_object("relax.attrs.MultinomialFromUniformAttrs")
+class MultinomialFromUniformAttrs(Attrs):
+    """Attributes for multinomial_from_uniform operator"""
+
+
+@tvm.ffi.register_object("relax.attrs.CallInplacePackedAttrs")
+class CallInplacePackedAttrs(Attrs):
+    """Attributes used in call_inplace_packed operator"""
+
+
+@tvm.ffi.register_object("relax.attrs.CallTIRInplaceAttrs")
+class CallTIRInplaceAttrs(Attrs):
+    """Attributes used in call_tir_inplace operator"""
+
+
+@tvm.ffi.register_object("relax.attrs.ToVDeviceAttrs")
+class ToVDeviceAttrs(Attrs):
+    """Attributes used in to_vdevice operator"""
+
+
+@tvm.ffi.register_object("relax.attrs.HintOnDeviceAttrs")
+class HintOnDeviceAttrs(Attrs):
+    """Attributes used in hint_on_device operator"""
+
+
+@tvm.ffi.register_object("relax.attrs.ScatterCollectiveAttrs")
+class ScatterCollectiveAttrs(Attrs):
+    """Attributes used in scatter collective operators"""
+
+
+@tvm.ffi.register_object("relax.attrs.AttentionAttrs")
+class AttentionAttrs(Attrs):
+    """Attributes used in attention operator"""
+
+
+@tvm.ffi.register_object("relax.attrs.Conv1DAttrs")
+class Conv1DAttrs(Attrs):
+    """Attributes for nn.conv1d"""
+
+
+@tvm.ffi.register_object("relax.attrs.Conv1DTransposeAttrs")
+class Conv1DTransposeAttrs(Attrs):
+    """Attributes for nn.conv1d_transpose"""
+
+
+@tvm.ffi.register_object("relax.attrs.Pool1DAttrs")
+class Pool1DAttrs(Attrs):
+    """Attributes for nn.max_pool1d and nn.avg_pool1d"""
+
+
+@tvm.ffi.register_object("relax.attrs.Pool3DAttrs")
+class Pool3DAttrs(Attrs):
+    """Attributes for nn.max_pool3d and nn.avg_pool3d"""
+
+
+@tvm.ffi.register_object("relax.attrs.AdaptivePool1DAttrs")
+class AdaptivePool1DAttrs(Attrs):
+    """Attributes for 1d adaptive pool operator"""
+
+
+@tvm.ffi.register_object("relax.attrs.AdaptivePool3DAttrs")
+class AdaptivePool3DAttrs(Attrs):
+    """Attributes for 3d adaptive pool operator"""
+
+
+@tvm.ffi.register_object("relax.attrs.LeakyReluAttrs")
+class LeakyReluAttrs(Attrs):
+    """Attributes used in leaky_relu operator"""
+
+
+@tvm.ffi.register_object("relax.attrs.SoftplusAttrs")
+class SoftplusAttrs(Attrs):
+    """Attributes used in softplus operator"""
+
+
+@tvm.ffi.register_object("relax.attrs.PReluAttrs")
+class PReluAttrs(Attrs):
+    """Attributes used in prelu operator"""
+
+
+@tvm.ffi.register_object("relax.attrs.PixelShuffleAttrs")
+class PixelShuffleAttrs(Attrs):
+    """Attributes used in pixel_shuffle operator"""
+
+
+@tvm.ffi.register_object("relax.attrs.GroupNormAttrs")
+class GroupNormAttrs(Attrs):
+    """Attributes used in group_norm operator"""
+
+
+@tvm.ffi.register_object("relax.attrs.RMSNormAttrs")
+class RMSNormAttrs(Attrs):
+    """Attributes used in rms_norm operator"""
+
+
+@tvm.ffi.register_object("relax.attrs.NLLLossAttrs")
+class NLLLossAttrs(Attrs):
+    """Attributes used in nll_loss operator"""
+
+
+@tvm.ffi.register_object("relax.attrs.AllReduceAttrs")
+class AllReduceAttrs(Attrs):
+    """Attributes used in allreduce operator"""
+
+
+@tvm.ffi.register_object("relax.attrs.AllGatherAttrs")
+class AllGatherAttrs(Attrs):
+    """Attributes used in allgather operator"""
+
+
+@tvm.ffi.register_object("relax.attrs.WrapParamAttrs")
+class WrapParamAttrs(Attrs):
+    """Attributes used in wrap_param operator"""
+
+
+@tvm.ffi.register_object("relax.attrs.QuantizeAttrs")
+class QuantizeAttrs(Attrs):
+    """Attributes used in quantize/dequantize operators"""
+
+
+@tvm.ffi.register_object("relax.attrs.GatherElementsAttrs")
+class GatherElementsAttrs(Attrs):
+    """Attributes for gather_elements operator"""
+
+
+@tvm.ffi.register_object("relax.attrs.GatherNDAttrs")
+class GatherNDAttrs(Attrs):
+    """Attributes for gather_nd operator"""
+
+
+@tvm.ffi.register_object("relax.attrs.MeshgridAttrs")
+class MeshgridAttrs(Attrs):
+    """Attributes for meshgrid operator"""
+
+
+@tvm.ffi.register_object("relax.attrs.ScatterElementsAttrs")
+class ScatterElementsAttrs(Attrs):
+    """Attributes for scatter_elements operator"""
+
+
+@tvm.ffi.register_object("relax.attrs.ScatterNDAttrs")
+class ScatterNDAttrs(Attrs):
+    """Attributes for scatter_nd operator"""
+
+
+@tvm.ffi.register_object("relax.attrs.SliceScatterAttrs")
+class SliceScatterAttrs(Attrs):
+    """Attributes for slice_scatter operator"""
+
+
+@tvm.ffi.register_object("relax.attrs.OneHotAttrs")
+class OneHotAttrs(Attrs):
+    """Attributes for one_hot operator"""
diff --git a/python/tvm/runtime/_ffi_node_api.py 
b/python/tvm/runtime/_ffi_node_api.py
index aef9ded9cc..4a0edd449c 100644
--- a/python/tvm/runtime/_ffi_node_api.py
+++ b/python/tvm/runtime/_ffi_node_api.py
@@ -28,14 +28,6 @@ def AsRepr(obj):
     return type(obj).__name__ + "(" + obj.__ctypes_handle__().value + ")"
 
 
-def NodeListAttrNames(obj):
-    return lambda x: 0
-
-
-def NodeGetAttr(obj, name):
-    raise AttributeError()
-
-
 def SaveJSON(obj):
     raise RuntimeError("Do not support object serialization in runtime only 
mode")
 
diff --git a/python/tvm/runtime/object.py b/python/tvm/runtime/object.py
index 688682d197..b2fcddc40a 100644
--- a/python/tvm/runtime/object.py
+++ b/python/tvm/runtime/object.py
@@ -22,17 +22,6 @@ import tvm.ffi.core
 from . import _ffi_node_api
 
 
-def __object_dir__(obj):
-    class_names = dir(obj.__class__)
-    fnames = _ffi_node_api.NodeListAttrNames(obj)
-    size = fnames(-1)
-    return sorted([fnames(i) for i in range(size)] + class_names)
-
-
 tvm.ffi.core._set_class_object(Object)
 # override the default repr function for tvm.ffi.core.Object
 tvm.ffi.core.__object_repr__ = _ffi_node_api.AsRepr
-tvm.ffi.core.__object_save_json__ = _ffi_node_api.SaveJSON
-tvm.ffi.core.__object_load_json__ = _ffi_node_api.LoadJSON
-tvm.ffi.core.__object_getattr__ = _ffi_node_api.NodeGetAttr
-tvm.ffi.core.__object_dir__ = __object_dir__
diff --git a/python/tvm/script/printer/doc.py b/python/tvm/script/printer/doc.py
index 02a67e916b..bf468b17ec 100644
--- a/python/tvm/script/printer/doc.py
+++ b/python/tvm/script/printer/doc.py
@@ -26,25 +26,12 @@ from tvm.tir import FloatImm, IntImm
 from . import _ffi_api
 
 
+@register_object("script.printer.Doc")
 class Doc(Object):
     """Base class of all Docs"""
 
-    @property
-    def source_paths(self) -> Sequence[ObjectPath]:
-        """
-        The list of object paths of the source IR node.
-
-        This is used to trace back to the IR node position where
-        this Doc is generated, in order to position the diagnostic
-        message.
-        """
-        return self.__getattr__("source_paths")  # pylint: 
disable=unnecessary-dunder-call
-
-    @source_paths.setter
-    def source_paths(self, value):
-        return _ffi_api.DocSetSourcePaths(self, value)  # type: ignore # 
pylint: disable=no-member
-
 
+@register_object("script.printer.ExprDoc")
 class ExprDoc(Doc):
     """Base class of all expression Docs"""
 
@@ -114,26 +101,10 @@ class ExprDoc(Doc):
         raise RuntimeError(f"{self.__class__} cannot be used as iterable.")
 
 
+@register_object("script.printer.StmtDoc")
 class StmtDoc(Doc):
     """Base class of statement doc"""
 
-    @property
-    def comment(self) -> Optional[str]:
-        """
-        The comment of this doc.
-
-        The actual position of the comment depends on the type of Doc
-        and also the DocPrinter implementation. It could be on the same
-        line as the statement, or the line above, or inside the statement
-        if it spans over multiple lines.
-        """
-        # It has to call the dunder method to avoid infinite recursion
-        return self.__getattr__("comment")  # pylint: 
disable=unnecessary-dunder-call
-
-    @comment.setter
-    def comment(self, value):
-        return _ffi_api.StmtDocSetComment(self, value)  # type: ignore # 
pylint: disable=no-member
-
 
 @register_object("script.printer.StmtBlockDoc")
 class StmtBlockDoc(Doc):
diff --git a/python/tvm/te/tensor.py b/python/tvm/te/tensor.py
index 489ec38ba5..73b995a45e 100644
--- a/python/tvm/te/tensor.py
+++ b/python/tvm/te/tensor.py
@@ -84,26 +84,6 @@ class Tensor(DataProducer, _expr.ExprOp):
         """Dimension of the tensor."""
         return len(self.shape)
 
-    @property
-    def axis(self):
-        """Axis of the tensor."""
-        return self.__getattr__("axis")
-
-    @property
-    def op(self):
-        """The corressponding :py:class:`Operation`."""
-        return self.__getattr__("op")
-
-    @property
-    def value_index(self):
-        """The output value index the tensor corresponds to."""
-        return self.__getattr__("value_index")
-
-    @property
-    def shape(self):
-        """The output shape of the tensor."""
-        return self.__getattr__("shape")
-
     @property
     def name(self):
         op = self.op
diff --git a/python/tvm/testing/__init__.py b/python/tvm/testing/__init__.py
index ea798242b4..620a66351d 100644
--- a/python/tvm/testing/__init__.py
+++ b/python/tvm/testing/__init__.py
@@ -43,3 +43,4 @@ from .popen_pool import (
 )
 from .runner import local_run, rpc_run
 from .utils import *
+from .attrs import *
diff --git a/python/tvm/testing/__init__.py b/python/tvm/testing/attrs.py
similarity index 56%
copy from python/tvm/testing/__init__.py
copy to python/tvm/testing/attrs.py
index ea798242b4..ea6f1b1af6 100644
--- a/python/tvm/testing/__init__.py
+++ b/python/tvm/testing/attrs.py
@@ -14,32 +14,15 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+# pylint: disable=invalid-name, import-outside-toplevel, unused-variable
+"""Testing utilities for attrs"""
+from ..ir import Attrs
+from ..ffi import register_object
 
-# pylint: disable=redefined-builtin, wildcard-import
-"""Utility Python functions for TVM testing"""
-from ._ffi_api import (
-    ErrorTest,
-    FrontendTestModule,
-    device_test,
-    echo,
-    identity_cpp,
-    nop,
-    object_use_count,
-    run_check_signal,
-    test_check_eq_callback,
-    test_raise_error,
-    test_wrap_callback,
-)
-from .popen_pool import (
-    after_initializer,
-    call_cpp_ffi,
-    call_cpp_py_ffi,
-    call_py_ffi,
-    fast_summation,
-    initializer,
-    register_ffi,
-    slow_summation,
-    timeout_job,
-)
-from .runner import local_run, rpc_run
-from .utils import *
+
+@register_object("attrs.TestAttrs")
+class TestAttrs(Attrs):
+    """Attrs used for testing purposes"""
+
+
+__all__ = ["TestAttrs"]
diff --git a/python/tvm/tir/transform/transform.py 
b/python/tvm/tir/transform/transform.py
index 81ce63b797..93a182ca3b 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -23,6 +23,8 @@ from typing import Callable, Optional
 
 from . import _ffi_api
 from . import function_pass as _fpass
+from ... import ir as _ir
+from ... import ffi as _ffi
 
 
 def Apply(ftransform):
@@ -48,6 +50,11 @@ def Apply(ftransform):
     return _fpass.prim_func_pass(_transform, opt_level=0, name="Apply")  # 
type: ignore
 
 
+@_ffi.register_object("tir.transform.LoopPartitionConfig")
+class LoopPartitionConfig(_ir.Attrs):
+    """Config for loop partition pass"""
+
+
 def LoopPartition():
     """Inject virtual thread loops.
 
@@ -87,6 +94,11 @@ def InjectVirtualThread():
     return _ffi_api.InjectVirtualThread()  # type: ignore
 
 
+@_ffi.register_object("tir.transform.InjectDoubleBufferConfig")
+class InjectDoubleBufferConfig(_ir.Attrs):
+    """Config for inject double buffer pass"""
+
+
 def InjectDoubleBuffer():
     """Inject double buffer statements.
 
@@ -149,6 +161,11 @@ def PointerValueTypeRewrite():
     return _ffi_api.PointerValueTypeRewrite()  # type: ignore
 
 
+@_ffi.register_object("tir.transform.UnrollLoopConfig")
+class UnrollLoopConfig(_ir.Attrs):
+    """Config for unroll loop pass"""
+
+
 def UnrollLoop():
     """Unroll the constant loop marked by unroll.
 
@@ -162,6 +179,11 @@ def UnrollLoop():
     return _ffi_api.UnrollLoop()  # type: ignore
 
 
+@_ffi.register_object("tir.transform.ReduceBranchingThroughOvercomputeConfig")
+class ReduceBranchingThroughOvercomputeConfig(_ir.Attrs):
+    """Config for reduce branching through overcompute pass"""
+
+
 def ReduceBranchingThroughOvercompute():
     """Reduce branching by introducing overcompute
 
@@ -173,6 +195,11 @@ def ReduceBranchingThroughOvercompute():
     return _ffi_api.ReduceBranchingThroughOvercompute()  # type: ignore
 
 
+@_ffi.register_object("tir.transform.RemoveNoOpConfig")
+class RemoveNoOpConfig(_ir.Attrs):
+    """Config for remove no op pass"""
+
+
 def RemoveNoOp():
     """Remove No Op from the Stmt.
 
@@ -277,6 +304,11 @@ def RewriteUnsafeSelect():
     return _ffi_api.RewriteUnsafeSelect()  # type: ignore
 
 
+@_ffi.register_object("tir.transform.SimplifyConfig")
+class SimplifyConfig(_ir.Attrs):
+    """Config for simplify pass"""
+
+
 def Simplify():
     """Run arithmetic simplifications on the statements and expressions.
 
@@ -607,6 +639,11 @@ def VerifyVTCMLimit(limit=None):
     return _ffi_api.VerifyVTCMLimit(limit)  # type: ignore
 
 
+@_ffi.register_object("tir.transform.HoistIfThenElseConfig")
+class HoistIfThenElseConfig(_ir.Attrs):
+    """Config for hoist if then else pass"""
+
+
 # pylint: disable=no-else-return,inconsistent-return-statements
 def HoistIfThenElse(variant: Optional[str] = None):
     """Hoist loop-invariant IfThenElse nodes to outside the eligible loops.
@@ -686,6 +723,11 @@ class HoistedLetBindings(enum.Flag):
     """ Enable all hoisting of let bindings """
 
 
+@_ffi.register_object("tir.transform.HoistExpressionConfig")
+class HoistExpressionConfig(_ir.Attrs):
+    """Config for hoist expression pass"""
+
+
 def HoistExpression():
     """Generalized verison of HoistIfThenElse.
 
diff --git a/src/node/reflection.cc b/src/node/reflection.cc
index 6db751a80f..e666b434f8 100644
--- a/src/node/reflection.cc
+++ b/src/node/reflection.cc
@@ -33,75 +33,6 @@ using ffi::Any;
 using ffi::Function;
 using ffi::PackedArgs;
 
-// Expose to FFI APIs.
-void NodeGetAttr(ffi::PackedArgs args, ffi::Any* ret) {
-  Object* self = const_cast<Object*>(args[0].cast<const Object*>());
-  String field_name = args[1].cast<String>();
-
-  bool success;
-  if (field_name == "type_key") {
-    *ret = self->GetTypeKey();
-    success = true;
-  } else if (!self->IsInstance<DictAttrsNode>()) {
-    const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(self->type_index());
-    success = false;
-    // use new reflection mechanism
-    if (type_info->metadata != nullptr) {
-      ffi::reflection::ForEachFieldInfo(type_info, [&](const TVMFFIFieldInfo* 
field_info) {
-        if (field_name.compare(field_info->name) == 0) {
-          ffi::reflection::FieldGetter field_getter(field_info);
-          *ret = field_getter(self);
-          success = true;
-        }
-      });
-    }
-  } else {
-    // specially handle dict attr
-    DictAttrsNode* dnode = static_cast<DictAttrsNode*>(self);
-    auto it = dnode->dict.find(field_name);
-    if (it != dnode->dict.end()) {
-      success = true;
-      *ret = (*it).second;
-    } else {
-      success = false;
-    }
-  }
-  if (!success) {
-    TVM_FFI_THROW(AttributeError) << self->GetTypeKey() << " object has no 
attribute `"
-                                  << field_name << "`";
-  }
-}
-
-void NodeListAttrNames(ffi::PackedArgs args, ffi::Any* ret) {
-  Object* self = const_cast<Object*>(args[0].cast<const Object*>());
-
-  std::vector<String> names;
-  if (!self->IsInstance<DictAttrsNode>()) {
-    const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(self->type_index());
-    if (type_info->metadata != nullptr) {
-      // use new reflection mechanism
-      ffi::reflection::ForEachFieldInfo(type_info, [&](const TVMFFIFieldInfo* 
field_info) {
-        names.push_back(std::string(field_info->name.data, 
field_info->name.size));
-      });
-    }
-  } else {
-    // specially handle dict attr
-    DictAttrsNode* dnode = static_cast<DictAttrsNode*>(self);
-    for (const auto& kv : dnode->dict) {
-      names.push_back(kv.first);
-    }
-  }
-
-  *ret = ffi::Function::FromPacked([names](ffi::PackedArgs args, ffi::Any* rv) 
{
-    int64_t i = args[0].cast<int64_t>();
-    if (i == -1) {
-      *rv = static_cast<int64_t>(names.size());
-    } else {
-      *rv = names[i];
-    }
-  });
-}
-
 // API function to make node.
 // args format:
 //   key1, value1, ..., key_n, value_n
@@ -123,10 +54,7 @@ void MakeNode(const ffi::PackedArgs& args, ffi::Any* rv) {
 
 TVM_FFI_STATIC_INIT_BLOCK({
   namespace refl = tvm::ffi::reflection;
-  refl::GlobalDef()
-      .def_packed("node.NodeGetAttr", NodeGetAttr)
-      .def_packed("node.NodeListAttrNames", NodeListAttrNames)
-      .def_packed("node.MakeNode", MakeNode);
+  refl::GlobalDef().def_packed("node.MakeNode", MakeNode);
 });
 
 }  // namespace tvm
diff --git a/src/relax/ir/emit_te.h b/src/relax/ir/emit_te.h
index bc4b90a373..aa7cb9db53 100644
--- a/src/relax/ir/emit_te.h
+++ b/src/relax/ir/emit_te.h
@@ -52,7 +52,7 @@ class RXPlaceholderOpNode : public te::PlaceholderOpNode {
         .def_ro("dtype", &RXPlaceholderOpNode::dtype);
   }
 
-  static constexpr const char* _type_key = "RXPlaceholderOp";
+  static constexpr const char* _type_key = "relax.TEPlaceholderOp";
   TVM_DECLARE_FINAL_OBJECT_INFO(RXPlaceholderOpNode, te::PlaceholderOpNode);
 };
 
diff --git a/src/tir/transforms/hoist_expression.cc 
b/src/tir/transforms/hoist_expression.cc
index d89114c68a..1548ea1da6 100644
--- a/src/tir/transforms/hoist_expression.cc
+++ b/src/tir/transforms/hoist_expression.cc
@@ -82,7 +82,7 @@ struct HoistExpressionConfigNode : public 
AttrsNodeReflAdapter<HoistExpressionCo
     return static_cast<int>(flag) & hoisted_let_bindings;
   }
 
-  static constexpr const char* _type_key = 
"tir.transforms.HoistExpressionConfig";
+  static constexpr const char* _type_key = 
"tir.transform.HoistExpressionConfig";
   TVM_DECLARE_FINAL_OBJECT_INFO(HoistExpressionConfigNode, Object);
 };
 
@@ -112,7 +112,7 @@ struct HoistIfThenElseConfigNode : public 
AttrsNodeReflAdapter<HoistIfThenElseCo
         "Hoist if cond with block scope variables", refl::DefaultValue(false));
   }
 
-  static constexpr const char* _type_key = 
"tir.transforms.HoistIfThenElseConfig";
+  static constexpr const char* _type_key = 
"tir.transform.HoistIfThenElseConfig";
   TVM_DECLARE_FINAL_OBJECT_INFO(HoistIfThenElseConfigNode, Object);
 };
 
diff --git a/tests/python/ffi/test_container.py 
b/tests/python/ffi/test_container.py
index 5ac3af1799..25468f452a 100644
--- a/tests/python/ffi/test_container.py
+++ b/tests/python/ffi/test_container.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import pytest
+import pickle
 import tvm.ffi as tvm_ffi
 
 
@@ -93,3 +94,9 @@ def test_repr():
 
     smap = tvm_ffi.convert({"a": 1, "b": 2})
     assert str(smap) == "{'a': 1, 'b': 2}"
+
+
+def test_serialization():
+    a = tvm_ffi.convert([1, 2, 3])
+    b = pickle.loads(pickle.dumps(a))
+    assert str(b) == "[1, 2, 3]"
diff --git a/tests/python/ir/test_ir_attrs.py b/tests/python/ir/test_ir_attrs.py
index 48c38c1556..905069059f 100644
--- a/tests/python/ir/test_ir_attrs.py
+++ b/tests/python/ir/test_ir_attrs.py
@@ -15,8 +15,10 @@
 # specific language governing permissions and limitations
 # under the License.
 import tvm
+
+# needed for attrs
+import tvm.testing
 import pytest
-import tvm.ir._ffi_api
 
 
 def test_make_attrs():
diff --git a/tests/python/relax/test_transform_legalize_ops_manipulate.py 
b/tests/python/relax/test_transform_legalize_ops_manipulate.py
index 23d48894a2..ae865f1fb1 100644
--- a/tests/python/relax/test_transform_legalize_ops_manipulate.py
+++ b/tests/python/relax/test_transform_legalize_ops_manipulate.py
@@ -14,11 +14,6 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
-import sys
-
-sys.path.append("/ssd1/htalendr/tvm/python")
-
 import tvm
 from tvm import relax
 from tvm.relax.transform import LegalizeOps
@@ -676,7 +671,7 @@ def test_reshape_symbolic():
 
         @R.function
         def main(
-            x: R.Tensor((10, "b"), dtype="float32")
+            x: R.Tensor((10, "b"), dtype="float32"),
         ) -> R.Tensor((5, "b * 2"), dtype="float32"):
             b = T.int64()
             lv: R.Shape([5, b * 2]) = R.shape([5, b * 2])
diff --git a/tests/python/runtime/test_runtime_rpc.py 
b/tests/python/runtime/test_runtime_rpc.py
index 6711ccf92f..e696cbcf08 100644
--- a/tests/python/runtime/test_runtime_rpc.py
+++ b/tests/python/runtime/test_runtime_rpc.py
@@ -413,7 +413,6 @@ def test_rpc_return_remote_object():
         get_elem = client.get_function("testing.GetShapeElem")
         get_size = client.get_function("testing.GetShapeSize")
         shape = make_shape(2, 3)
-        assert shape.type_key == "runtime.RPCObjectRef"
         assert get_elem(shape, 0) == 2
         assert get_elem(shape, 1) == 3
         assert get_size(shape) == 2


Reply via email to