This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch refactor-s3
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit ab7a2c3e618c7f5d046e47a319fc52d4d3be3727
Author: tqchen <[email protected]>
AuthorDate: Sun May 4 13:41:42 2025 -0400

    Bring custom hooks to the ffi layer
---
 python/tvm/_ffi/__init__.py                        |  2 +-
 python/tvm/_ffi/base.py                            |  9 ++--
 python/tvm/ffi/cython/dtype.pxi                    |  2 +-
 python/tvm/ffi/cython/object.pxi                   | 62 ++++++++++++++++++++++
 python/tvm/ffi/cython/string.pxi                   |  2 +
 python/tvm/ffi/dtype.py                            |  2 +-
 python/tvm/ir/container.py                         |  2 +
 python/tvm/meta_schedule/cost_model/mlp_model.py   |  4 +-
 python/tvm/rpc/client.py                           |  2 +-
 python/tvm/runtime/__init__.py                     |  2 +-
 python/tvm/runtime/_ffi_node_api.py                |  3 --
 python/tvm/runtime/container.py                    |  3 ++
 python/tvm/runtime/device.py                       |  5 +-
 python/tvm/runtime/module.py                       |  2 +-
 python/tvm/runtime/ndarray.py                      |  2 +-
 python/tvm/runtime/object.py                       | 62 +++++-----------------
 python/tvm/tir/expr.py                             |  4 +-
 src/node/structural_hash.cc                        | 30 +++++++++++
 src/support/ffi_testing.cc                         |  1 -
 tests/python/ffi/test_string.py                    |  9 ++++
 .../python/tvmscript/test_tvmscript_printer_doc.py |  2 -
 .../test_tvmscript_printer_structural_equal.py     |  5 +-
 22 files changed, 140 insertions(+), 77 deletions(-)

diff --git a/python/tvm/_ffi/__init__.py b/python/tvm/_ffi/__init__.py
index 4e8e59b3a8..559ca84635 100644
--- a/python/tvm/_ffi/__init__.py
+++ b/python/tvm/_ffi/__init__.py
@@ -28,4 +28,4 @@ from . import _pyversion
 from . import base
 from .registry import register_object, register_func
 from .registry import _init_api, get_global_func
-from tvm.ffi import register_error
+from ..ffi import register_error
diff --git a/python/tvm/_ffi/base.py b/python/tvm/_ffi/base.py
index 76293cf125..18ed40fb4c 100644
--- a/python/tvm/_ffi/base.py
+++ b/python/tvm/_ffi/base.py
@@ -18,13 +18,9 @@
 # pylint: disable=invalid-name, import-outside-toplevel
 """Base library for TVM FFI."""
 import ctypes
-import functools
 import os
-import re
 import sys
-import types
 
-from typing import Callable, Sequence, Optional
 
 import numpy as np
 
@@ -64,10 +60,11 @@ _LIB, _LIB_NAME = _load_lib()
 # Whether we are runtime only
 _RUNTIME_ONLY = "runtime" in _LIB_NAME
 
-import tvm.ffi.registry
 
 if _RUNTIME_ONLY:
-    tvm.ffi.registry._SKIP_UNKNOWN_OBJECTS = True
+    from ..ffi import registry as _tvm_ffi_registry
+
+    _tvm_ffi_registry._SKIP_UNKNOWN_OBJECTS = True
 
 # The FFI mode of TVM
 _FFI_MODE = os.environ.get("TVM_FFI", "auto")
diff --git a/python/tvm/ffi/cython/dtype.pxi b/python/tvm/ffi/cython/dtype.pxi
index ec045fce65..bbf9e60053 100644
--- a/python/tvm/ffi/cython/dtype.pxi
+++ b/python/tvm/ffi/cython/dtype.pxi
@@ -29,7 +29,7 @@ def _create_dtype_from_tuple(cls, code, bits, lanes):
     cdtype.bits = bits
     cdtype.lanes = lanes
     ret = cls.__new__(cls)
-    (<DLDataType>ret).cdtype = cdtype
+    (<DataType>ret).cdtype = cdtype
     return ret
 
 
diff --git a/python/tvm/ffi/cython/object.pxi b/python/tvm/ffi/cython/object.pxi
index 9abbb40195..1ac32c3bc6 100644
--- a/python/tvm/ffi/cython/object.pxi
+++ b/python/tvm/ffi/cython/object.pxi
@@ -32,6 +32,31 @@ def __object_repr__(obj):
     return type(obj).__name__ + "(" + obj.__ctypes_handle__().value + ")"
 
 
+def __object_save_json__(obj):
+    """Object JSON save function that can be overridden by assigning to it"""
+    raise NotImplementedError("JSON serialization depends on downstream init")
+
+
+def __object_load_json__(json_str):
+    """Object JSON load function that can be overridden by assigning to it"""
+    raise NotImplementedError("JSON serialization depends on downstream init")
+
+
+def __object_dir__(obj):
+    """Object dir function that can be overridden by assigning to it"""
+    return []
+
+
+def __object_getattr__(obj, name):
+    """Object getattr function that can be overridden by assigning to it"""
+    raise AttributeError()
+
+
+def _new_object(cls):
+    """Helper function for pickle"""
+    return cls.__new__(cls)
+
+
 class ObjectGeneric:
     """Base class for all classes that can be converted to object."""
 
@@ -71,9 +96,46 @@ cdef class Object:
         cdef uint64_t chandle = <uint64_t>self.chandle
         return chandle
 
+    def __reduce__(self):
+        cls = type(self)
+        return (_new_object, (cls,), self.__getstate__())
+
+    def __getstate__(self):
+        if not self.__chandle__() == 0:
+            # need to explicit convert to str in case String
+            # returned and triggered another infinite recursion in get state
+            return {"handle": str(__object_save_json__(self))}
+        return {"handle": None}
+
+    def __setstate__(self, state):
+        # pylint: disable=assigning-non-slot, assignment-from-no-return
+        handle = state["handle"]
+        if handle is not None:
+            self.__init_handle_by_constructor__(__object_load_json__, handle)
+        else:
+            self.chandle = NULL
+
+    def __getattr__(self, name):
+        try:
+            return __object_getattr__(self, name)
+        except AttributeError:
+            raise AttributeError(f"{type(self)} has no attribute {name}")
+
+    def __dir__(self):
+        return __object_dir__(self)
+
     def __repr__(self):
         return __object_repr__(self)
 
+    def __eq__(self, other):
+        return self.same_as(other)
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
+    def __init_handle_by_load_json__(self, json_str):
+        raise NotImplementedError("JSON serialization depends on downstream init")
+
     def __init_handle_by_constructor__(self, fconstructor, *args):
         """Initialize the handle by calling constructor function.
 
diff --git a/python/tvm/ffi/cython/string.pxi b/python/tvm/ffi/cython/string.pxi
index 00ec92b7ec..733ea90301 100644
--- a/python/tvm/ffi/cython/string.pxi
+++ b/python/tvm/ffi/cython/string.pxi
@@ -65,6 +65,7 @@ class String(str, PyNativeObject):
         val.__tvm_ffi_object__ = obj
         return val
 
+
 _register_object_by_index(kTVMFFIStr, String)
 
 
@@ -89,6 +90,7 @@ class Bytes(bytes, PyNativeObject):
         val.__tvm_ffi_object__ = obj
         return val
 
+
 _register_object_by_index(kTVMFFIBytes, Bytes)
 
 # We special handle str/bytes constructor in cython to avoid extra cyclic deps
diff --git a/python/tvm/ffi/dtype.py b/python/tvm/ffi/dtype.py
index ca98726560..56b888316d 100644
--- a/python/tvm/ffi/dtype.py
+++ b/python/tvm/ffi/dtype.py
@@ -17,9 +17,9 @@
 """dtype class."""
 # pylint: disable=invalid-name
 from enum import IntEnum
+import numpy as np
 
 from . import core
-import numpy as np
 
 
 class DataTypeCode(IntEnum):
diff --git a/python/tvm/ir/container.py b/python/tvm/ir/container.py
index 6a013bce8c..4bc6fcae21 100644
--- a/python/tvm/ir/container.py
+++ b/python/tvm/ir/container.py
@@ -16,3 +16,5 @@
 # under the License.
 """Additional container data structures used across IR variants."""
 from tvm.ffi import Array, Map
+
+__all__ = ["Array", "Map"]
diff --git a/python/tvm/meta_schedule/cost_model/mlp_model.py b/python/tvm/meta_schedule/cost_model/mlp_model.py
index 8bd050b689..4ee5ba838d 100644
--- a/python/tvm/meta_schedule/cost_model/mlp_model.py
+++ b/python/tvm/meta_schedule/cost_model/mlp_model.py
@@ -541,7 +541,9 @@ class State:
                                 "_workload.json", "_candidates.json"
                             ),
                         )
-                    except tvm._ffi.base.TVMError:  # pylint: disable=protected-access
+                    except (
+                        tvm._ffi.base.TVMError
+                    ):  # pylint: disable=protected-access,broad-exception-caught
                         continue
                     candidates, results = [], []
                     tuning_records = database.get_all_tuning_records()
diff --git a/python/tvm/rpc/client.py b/python/tvm/rpc/client.py
index cf9706c348..f9e677e49e 100644
--- a/python/tvm/rpc/client.py
+++ b/python/tvm/rpc/client.py
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=used-before-assignment
+# pylint: disable=used-before-assignment,broad-exception-caught
 """RPC client tools"""
 import os
 import socket
diff --git a/python/tvm/runtime/__init__.py b/python/tvm/runtime/__init__.py
index a630ce1101..774c8dd635 100644
--- a/python/tvm/runtime/__init__.py
+++ b/python/tvm/runtime/__init__.py
@@ -41,7 +41,7 @@ from .params import (
     load_param_dict_from_file,
 )
 
-from tvm.ffi import convert, dtype as DataType, DataTypeCode
 from . import disco
 
 from .support import _regex_match
+from ..ffi import convert, dtype as DataType, DataTypeCode
diff --git a/python/tvm/runtime/_ffi_node_api.py b/python/tvm/runtime/_ffi_node_api.py
index 9623a87351..395496d16b 100644
--- a/python/tvm/runtime/_ffi_node_api.py
+++ b/python/tvm/runtime/_ffi_node_api.py
@@ -48,6 +48,3 @@ def LoadJSON(json_str):
 # Exports functions registered via TVM_REGISTER_GLOBAL with the "node" prefix.
 # e.g. TVM_REGISTER_GLOBAL("node.AsRepr")
 tvm._ffi._init_api("node", __name__)
-
-# override the default repr function for tvm.ffi.core.Object
-tvm.ffi.core.__object_repr__ = AsRepr
diff --git a/python/tvm/runtime/container.py b/python/tvm/runtime/container.py
index 052701f0d3..3bf149d6b2 100644
--- a/python/tvm/runtime/container.py
+++ b/python/tvm/runtime/container.py
@@ -16,3 +16,6 @@
 # under the License.
 """Runtime container structures."""
 from tvm.ffi import String, Shape as ShapeTuple
+
+
+__all__ = ["ShapeTuple", "String"]
diff --git a/python/tvm/runtime/device.py b/python/tvm/runtime/device.py
index b83bb8cceb..d9d6abce50 100644
--- a/python/tvm/runtime/device.py
+++ b/python/tvm/runtime/device.py
@@ -16,13 +16,12 @@
 # under the License.
 """Common runtime ctypes."""
 # pylint: disable=invalid-name
+import json
+
 import tvm.ffi
-from tvm.ffi import dtype as DataType, DataTypeCode
 
 from . import _ffi_api
 
-import json
-
 
 RPC_SESS_MASK = 128
 
diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py
index 657120f907..bb1fbb5fe3 100644
--- a/python/tvm/runtime/module.py
+++ b/python/tvm/runtime/module.py
@@ -102,7 +102,7 @@ class Module(tvm.ffi.Object):
     """Runtime Module."""
 
     def __new__(cls):
-        instance = super().__new__(cls)
+        instance = super(Module, cls).__new__(cls)  # pylint: disable=no-value-for-parameter
         instance.entry_name = "__tvm_main__"
         instance._entry = None
         return instance
diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py
index f1a7d1c2f7..5581adbbc1 100644
--- a/python/tvm/runtime/ndarray.py
+++ b/python/tvm/runtime/ndarray.py
@@ -33,7 +33,7 @@ import tvm.ffi
 from . import _ffi_api
 
 
-from tvm.ffi import (
+from ..ffi import (
     device,
     cpu,
     cuda,
diff --git a/python/tvm/runtime/object.py b/python/tvm/runtime/object.py
index 2aa2c08632..688682d197 100644
--- a/python/tvm/runtime/object.py
+++ b/python/tvm/runtime/object.py
@@ -17,58 +17,22 @@
 # pylint: disable=invalid-name, unused-import
 """Runtime Object API"""
 
+from tvm.ffi.core import Object
 import tvm.ffi.core
+from . import _ffi_node_api
 
-from . import _ffi_api, _ffi_node_api
 
-
-def _new_object(cls):
-    """Helper function for pickle"""
-    return cls.__new__(cls)
-
-
-class Object(tvm.ffi.core.Object):
-    """Base class for all tvm's runtime objects."""
-
-    __slots__ = []
-
-    def __dir__(self):
-        class_names = dir(self.__class__)
-        fnames = _ffi_node_api.NodeListAttrNames(self)
-        size = fnames(-1)
-        return sorted([fnames(i) for i in range(size)] + class_names)
-
-    def __getattr__(self, name):
-        try:
-            return _ffi_node_api.NodeGetAttr(self, name)
-        except AttributeError:
-            raise AttributeError(f"{type(self)} has no attribute {name}") from None
-
-    def __hash__(self):
-        return _ffi_api.ObjectPtrHash(self)
-
-    def __eq__(self, other):
-        return self.same_as(other)
-
-    def __ne__(self, other):
-        return not self.__eq__(other)
-
-    def __reduce__(self):
-        cls = type(self)
-        return (_new_object, (cls,), self.__getstate__())
-
-    def __getstate__(self):
-        if not self.__chandle__() == 0:
-            # need to explicit convert to str in case String
-            # returned and triggered another infinite recursion in get state
-            return {"handle": str(_ffi_node_api.SaveJSON(self))}
-        return {"handle": None}
-
-    def __setstate__(self, state):
-        # pylint: disable=assigning-non-slot, assignment-from-no-return
-        handle = state["handle"]
-        if handle is not None:
-            self.__init_handle_by_constructor__(_ffi_node_api.LoadJSON, handle)
+def __object_dir__(obj):
+    class_names = dir(obj.__class__)
+    fnames = _ffi_node_api.NodeListAttrNames(obj)
+    size = fnames(-1)
+    return sorted([fnames(i) for i in range(size)] + class_names)
 
 
 tvm.ffi.core._set_class_object(Object)
+# override the default repr function for tvm.ffi.core.Object
+tvm.ffi.core.__object_repr__ = _ffi_node_api.AsRepr
+tvm.ffi.core.__object_save_json__ = _ffi_node_api.SaveJSON
+tvm.ffi.core.__object_load_json__ = _ffi_node_api.LoadJSON
+tvm.ffi.core.__object_getattr__ = _ffi_node_api.NodeGetAttr
+tvm.ffi.core.__object_dir__ = __object_dir__
diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py
index 13e10ba3ac..b293343cae 100644
--- a/python/tvm/tir/expr.py
+++ b/python/tvm/tir/expr.py
@@ -69,7 +69,7 @@ def _dtype_is_float(value):
     )  # type: ignore
 
 
-class ExprOp(object):
+class ExprOp:
     """Operator overloading for Expr like expressions."""
 
     # TODO(tkonolige): use inspect to add source information to these objects
@@ -395,7 +395,7 @@ class SizeVar(Var):
 
 
 @tvm._ffi.register_object("tir.IterVar")
-class IterVar(Object, ExprOp, Scriptable):
+class IterVar(ExprOp, Object, Scriptable):
     """Represent iteration variable.
 
     IterVar represents axis iterations in the computation.
diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc
index bba8a59647..94e6768203 100644
--- a/src/node/structural_hash.cc
+++ b/src/node/structural_hash.cc
@@ -328,6 +328,22 @@ struct StringObjTrait {
   }
 };
 
+struct BytesObjTrait {
+  static constexpr const std::nullptr_t VisitAttrs = nullptr;
+
+  static void SHashReduce(const ffi::BytesObj* key, SHashReducer hash_reduce) {
+    hash_reduce->SHashReduceHashedValue(ffi::details::StableHashBytes(key->data, key->size));
+  }
+
+  static bool SEqualReduce(const ffi::BytesObj* lhs, const ffi::BytesObj* rhs,
+                           SEqualReducer equal) {
+    if (lhs == rhs) return true;
+    if (lhs->size != rhs->size) return false;
+    if (lhs->data == rhs->data) return true;
+    return std::memcmp(lhs->data, rhs->data, lhs->size) == 0;
+  }
+};
+
 struct RefToObjectPtr : public ObjectRef {
   static ObjectPtr<Object> Get(const ObjectRef& ref) {
     return ffi::details::ObjectUnsafe::ObjectPtrFromObjectRef<Object>(ref);
@@ -350,6 +366,20 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
       p->stream << '"' << support::StrEscape(op->data, op->size) << '"';
     });
 
+TVM_REGISTER_REFLECTION_VTABLE(ffi::BytesObj, BytesObjTrait)
+    .set_creator([](const std::string& bytes) {
+      return RefToObjectPtr::Get(runtime::String(bytes));
+    })
+    .set_repr_bytes([](const Object* n) -> std::string {
+      return GetRef<ffi::Bytes>(static_cast<const ffi::BytesObj*>(n)).operator std::string();
+    });
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<ffi::BytesObj>([](const ObjectRef& node, ReprPrinter* p) {
+      auto* op = static_cast<const ffi::BytesObj*>(node.get());
+      p->stream << "b\"" << support::StrEscape(op->data, op->size) << '"';
+    });
+
 struct ModuleNodeTrait {
   static constexpr const std::nullptr_t VisitAttrs = nullptr;
   static constexpr const std::nullptr_t SHashReduce = nullptr;
diff --git a/src/support/ffi_testing.cc b/src/support/ffi_testing.cc
index cfd5c42f6e..9e727f06d4 100644
--- a/src/support/ffi_testing.cc
+++ b/src/support/ffi_testing.cc
@@ -71,7 +71,6 @@ TVM_REGISTER_GLOBAL("testing.test_wrap_callback_suppress_err")
       *ret = result;
     });
 
-
 TVM_REGISTER_GLOBAL("testing.test_check_eq_callback")
     .set_body_packed([](TVMArgs args, TVMRetValue* ret) {
       auto msg = args[0].cast<std::string>();
diff --git a/tests/python/ffi/test_string.py b/tests/python/ffi/test_string.py
index 67c82d52f7..98eab5bcb7 100644
--- a/tests/python/ffi/test_string.py
+++ b/tests/python/ffi/test_string.py
@@ -1,3 +1,4 @@
+import pickle
 from tvm import ffi as tvm_ffi
 
 
@@ -12,6 +13,10 @@ def test_string():
     assert isinstance(s3, tvm_ffi.String)
     assert isinstance(s3, str)
 
+    s4 = pickle.loads(pickle.dumps(s))
+    assert s4 == "hello"
+    assert isinstance(s4, tvm_ffi.String)
+
 
 def test_bytes():
     fecho = tvm_ffi.get_global_func("testing.echo")
@@ -27,3 +32,7 @@ def test_bytes():
     b4 = tvm_ffi.convert(bytearray(b"hello"))
     assert isinstance(b4, tvm_ffi.Bytes)
     assert isinstance(b4, bytes)
+
+    b5 = pickle.loads(pickle.dumps(b))
+    assert b5 == b"hello"
+    assert isinstance(b5, tvm_ffi.Bytes)
diff --git a/tests/python/tvmscript/test_tvmscript_printer_doc.py b/tests/python/tvmscript/test_tvmscript_printer_doc.py
index 6353627c58..e3d1280b32 100644
--- a/tests/python/tvmscript/test_tvmscript_printer_doc.py
+++ b/tests/python/tvmscript/test_tvmscript_printer_doc.py
@@ -307,7 +307,6 @@ def test_assign_doc(lhs, rhs, annotation):
 def test_invalid_assign_doc(lhs, rhs, annotation):
     with pytest.raises(ValueError) as e:
         AssignDoc(lhs, rhs, annotation)
-    assert "AssignDoc" in str(e.value)
 
 
 @pytest.mark.parametrize(
@@ -332,7 +331,6 @@ def test_if_doc(then_branch, else_branch):
     if not then_branch and not else_branch:
         with pytest.raises(ValueError) as e:
             IfDoc(predicate, then_branch, else_branch)
-        assert "IfDoc" in str(e.value)
         return
     else:
         doc = IfDoc(predicate, then_branch, else_branch)
diff --git a/tests/python/tvmscript/test_tvmscript_printer_structural_equal.py b/tests/python/tvmscript/test_tvmscript_printer_structural_equal.py
index 6f67733a28..58d9402e6f 100644
--- a/tests/python/tvmscript/test_tvmscript_printer_structural_equal.py
+++ b/tests/python/tvmscript/test_tvmscript_printer_structural_equal.py
@@ -24,12 +24,11 @@ from tvm.script import ir as I, tir as T
 
 
 def _error_message(exception):
-    splitter = "ValueError: StructuralEqual"
-    return splitter + str(exception).split(splitter)[1]
+    return str(exception)
 
 
 def _expected_result(func1, func2, objpath1, objpath2):
-    return f"""ValueError: StructuralEqual check failed, caused by lhs at {objpath1}:
+    return f"""StructuralEqual check failed, caused by lhs at {objpath1}:
 {func1.script(path_to_underline=[objpath1], syntax_sugar=False)}
 and rhs at {objpath2}:
 {func2.script(path_to_underline=[objpath2], syntax_sugar=False)}"""

Reply via email to