This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new d37b6abd56 [REFACTOR][IR] Phase out src/ir/structural_{hash,equal}.cc
to tvm-ffi (#19613)
d37b6abd56 is described below
commit d37b6abd5632536de64ab439527a08af4f1b2283
Author: Tianqi Chen <[email protected]>
AuthorDate: Tue May 26 15:29:50 2026 -0400
[REFACTOR][IR] Phase out src/ir/structural_{hash,equal}.cc to tvm-ffi
(#19613)
## Summary
The tvm-ffi layer now provides fully featured structural-hash and
structural-equal APIs (including `GetFirstStructuralMismatch` with
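`AccessPath` pair output). The two translation units (TUs)
`src/ir/structural_hash.cc` and `src/ir/structural_equal.cc` had become
thin adapters with no structural-hash/equal logic of their own — they
forwarded to tvm-ffi and exposed those forwarding functions as
`node.Structural*` globals for Python to call. (`structural_hash.cc` also
hosted the module/tensor JSON-serialization registrations that Commit A
moves.) This PR removes the indirection.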
- **Commit A** (`[REFACTOR][IR]`): relocates the `ffi::ModuleObj` and
`ffi::TensorObj` `__data_to_json__`/`__data_from_json__` `TypeAttrDef`
registrations from `structural_hash.cc` into `src/ir/module.cc` and
`src/runtime/tensor.cc` respectively; each of those files already has a
`TVM_FFI_STATIC_INIT_BLOCK`, and the registrations are added to it.
- **Commit B** (`[REFACTOR][PYTHON]`): rewrites the four Python wrappers
in `tvm.ir.base` (`structural_equal`, `get_first_structural_mismatch`,
`assert_structural_equal`, `structural_hash`) to call the `tvm_ffi`
structural APIs (`tvm_ffi.structural_equal`,
`tvm_ffi.get_first_structural_mismatch`, `tvm_ffi.structural_hash`)
directly, bypassing the now-redundant `node.Structural*` globals.
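`assert_structural_equal` reconstructs the same diagnostic message in
Python: it prints each side with `TVMScriptPrinterScript` (via
`tvm.runtime.script_printer._script`) using `syntax_sugar=False` and
`path_to_underline` set to that side's mismatch path, then raises
`ValueError` if a mismatch is found.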
- **Commit C** (`[REFACTOR][IR]`): deletes `src/ir/structural_hash.cc`
and `src/ir/structural_equal.cc` whose remaining content (the
`node.Structural*` FFI global registrations) is now unused.
---
python/tvm/ir/base.py | 31 ++++++++++------
src/ir/module.cc | 15 ++++++++
src/ir/structural_equal.cc | 83 ------------------------------------------
src/ir/structural_hash.cc | 89 ----------------------------------------------
src/runtime/tensor.cc | 21 +++++++++++
5 files changed, 56 insertions(+), 183 deletions(-)
diff --git a/python/tvm/ir/base.py b/python/tvm/ir/base.py
index b65a241450..cff43bb8c1 100644
--- a/python/tvm/ir/base.py
+++ b/python/tvm/ir/base.py
@@ -16,9 +16,9 @@
# under the License.
"""Common base structures."""
+import tvm_ffi
from tvm_ffi import get_global_func, register_object
-import tvm.error
from tvm.runtime import Object, _ffi_node_api
from . import _ffi_api, json_compact
@@ -205,9 +205,7 @@ def structural_equal(lhs, rhs, map_free_vars=False):
structural_hash
assert_strucural_equal
"""
- lhs = tvm.runtime.convert(lhs)
- rhs = tvm.runtime.convert(rhs)
- return bool(_ffi_node_api.StructuralEqual(lhs, rhs, False, map_free_vars))
# type: ignore # pylint: disable=no-member
+ return tvm_ffi.structural_equal(lhs, rhs, map_free_vars)
def get_first_structural_mismatch(lhs, rhs, map_free_vars=False,
skip_tensor_content=False):
@@ -234,9 +232,7 @@ def get_first_structural_mismatch(lhs, rhs,
map_free_vars=False, skip_tensor_con
`None` if `lhs` and `rhs` are structurally equal.
Otherwise, a tuple of two AccessPath objects that point to the first
detected mismtach.
"""
- lhs = tvm.runtime.convert(lhs)
- rhs = tvm.runtime.convert(rhs)
- return _ffi_node_api.GetFirstStructuralMismatch(lhs, rhs, map_free_vars,
skip_tensor_content) # type: ignore # pylint: disable=no-member
+ return tvm_ffi.get_first_structural_mismatch(lhs, rhs, map_free_vars,
skip_tensor_content)
def assert_structural_equal(lhs, rhs, map_free_vars=False):
@@ -262,9 +258,22 @@ def assert_structural_equal(lhs, rhs, map_free_vars=False):
--------
structural_equal
"""
- lhs = tvm.runtime.convert(lhs)
- rhs = tvm.runtime.convert(rhs)
- _ffi_node_api.StructuralEqual(lhs, rhs, True, map_free_vars) # type:
ignore # pylint: disable=no-member
+ first_mismatch = tvm_ffi.get_first_structural_mismatch(lhs, rhs,
map_free_vars)
+ if first_mismatch is not None:
+ from tvm.runtime.script_printer import ( # pylint:
disable=import-outside-toplevel
+ PrinterConfig,
+ _script,
+ )
+
+ lhs_path, rhs_path = first_mismatch
+ lhs_script = _script(lhs, PrinterConfig(syntax_sugar=False,
path_to_underline=[lhs_path]))
+ rhs_script = _script(rhs, PrinterConfig(syntax_sugar=False,
path_to_underline=[rhs_path]))
+ raise ValueError(
+ f"StructuralEqual check failed, caused by lhs at {lhs_path}:\n"
+ f"{lhs_script}\n"
+ f"and rhs at {rhs_path}:\n"
+ f"{rhs_script}"
+ )
def structural_hash(node, map_free_vars=False):
@@ -306,7 +315,7 @@ def structural_hash(node, map_free_vars=False):
--------
structrual_equal
"""
- return _ffi_node_api.StructuralHash(node, map_free_vars) # type: ignore #
pylint: disable=no-member
+ return tvm_ffi.structural_hash(node, map_free_vars)
def deprecated(
diff --git a/src/ir/module.cc b/src/ir/module.cc
index be74c6ba8d..a09780d94d 100644
--- a/src/ir/module.cc
+++ b/src/ir/module.cc
@@ -22,6 +22,8 @@
*/
#include <tvm/ffi/cast.h>
#include <tvm/ffi/container/variant.h>
+#include <tvm/ffi/extra/base64.h>
+#include <tvm/ffi/extra/module.h>
#include <tvm/ffi/extra/structural_equal.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
@@ -29,6 +31,7 @@
#include <tvm/ir/global_var_supply.h>
#include <tvm/ir/module.h>
#include <tvm/ir/type_functor.h>
+#include <tvm/target/codegen.h>
#include <algorithm>
#include <fstream>
@@ -230,6 +233,18 @@ IRModule IRModule::FromExpr(const RelaxExpr& expr,
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
+ refl::TypeAttrDef<ffi::ModuleObj>()
+ .def("__data_to_json__",
+ [](const ffi::ModuleObj* node) {
+ std::string bytes =
codegen::SerializeModuleToBytes(ffi::GetRef<ffi::Module>(node),
+
/*export_dso*/ false);
+ return ffi::Base64Encode(ffi::Bytes(bytes));
+ })
+ .def("__data_from_json__", [](const ffi::String& base64_bytes) {
+ ffi::Bytes bytes = ffi::Base64Decode(base64_bytes);
+ ffi::Module rtmod = codegen::DeserializeModuleFromBytes(bytes.operator
std::string());
+ return rtmod;
+ });
refl::GlobalDef()
.def("ir.IRModule",
[](tvm::ffi::Map<GlobalVar, BaseFunc> funcs, tvm::ffi::ObjectRef
attrs,
diff --git a/src/ir/structural_equal.cc b/src/ir/structural_equal.cc
deleted file mode 100644
index 4dcf2a32a6..0000000000
--- a/src/ir/structural_equal.cc
+++ /dev/null
@@ -1,83 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-/*!
- * \file src/ir/structural_equal.cc
- */
-#include <tvm/ffi/extra/structural_equal.h>
-#include <tvm/ffi/function.h>
-#include <tvm/ffi/reflection/access_path.h>
-#include <tvm/ffi/reflection/registry.h>
-#include <tvm/ir/module.h>
-#include <tvm/ir/node_functor.h>
-#include <tvm/ir/repr.h>
-#include <tvm/script/printer/config.h>
-
-#include <optional>
-#include <unordered_map>
-
-namespace tvm {
-
-bool NodeStructuralEqualAdapter(const Any& lhs, const Any& rhs, bool
assert_mode,
- bool map_free_vars) {
- if (assert_mode) {
- auto first_mismatch = ffi::StructuralEqual::GetFirstMismatch(lhs, rhs,
map_free_vars);
- if (first_mismatch.has_value()) {
- std::ostringstream oss;
- oss << "StructuralEqual check failed, caused by lhs";
- oss << " at " << (*first_mismatch).get<0>();
- {
- // print lhs
- PrinterConfig cfg;
- cfg->syntax_sugar = false;
- cfg->path_to_underline.push_back((*first_mismatch).get<0>());
- // The TVMScriptPrinter::Script will fallback to Repr printer,
- // if the root node to print is not supported yet,
- // e.g. Relax nodes, ArrayObj, MapObj, etc.
- oss << ":" << std::endl <<
TVMScriptPrinter::Script(lhs.cast<ffi::ObjectRef>(), cfg);
- }
- oss << std::endl << "and rhs";
- {
- // print rhs
- oss << " at " << (*first_mismatch).get<1>();
- {
- PrinterConfig cfg;
- cfg->syntax_sugar = false;
- cfg->path_to_underline.push_back((*first_mismatch).get<1>());
- // The TVMScriptPrinter::Script will fallback to Repr printer,
- // if the root node to print is not supported yet,
- // e.g. Relax nodes, ArrayObj, MapObj, etc.
- oss << ":" << std::endl <<
TVMScriptPrinter::Script(rhs.cast<ffi::ObjectRef>(), cfg);
- }
- }
- TVM_FFI_THROW(ValueError) << oss.str();
- }
- return true;
- } else {
- return ffi::StructuralEqual::Equal(lhs, rhs, map_free_vars);
- }
-}
-
-TVM_FFI_STATIC_INIT_BLOCK() {
- namespace refl = tvm::ffi::reflection;
- refl::GlobalDef()
- .def("node.StructuralEqual", NodeStructuralEqualAdapter)
- .def("node.GetFirstStructuralMismatch",
ffi::StructuralEqual::GetFirstMismatch);
-}
-
-} // namespace tvm
diff --git a/src/ir/structural_hash.cc b/src/ir/structural_hash.cc
deleted file mode 100644
index 9f33c2f50a..0000000000
--- a/src/ir/structural_hash.cc
+++ /dev/null
@@ -1,89 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-/*!
- * \file src/ir/structural_hash.cc
- */
-#include <tvm/ffi/cast.h>
-#include <tvm/ffi/extra/base64.h>
-#include <tvm/ffi/extra/module.h>
-#include <tvm/ffi/extra/structural_hash.h>
-#include <tvm/ffi/function.h>
-#include <tvm/ffi/reflection/access_path.h>
-#include <tvm/ffi/reflection/registry.h>
-#include <tvm/ir/node_functor.h>
-#include <tvm/runtime/tensor.h>
-#include <tvm/support/io.h>
-#include <tvm/target/codegen.h>
-
-#include <algorithm>
-#include <unordered_map>
-
-#include "../support/base64.h"
-#include "../support/bytes_io.h"
-#include "../support/str_escape.h"
-#include "../support/utils.h"
-
-namespace tvm {
-
-TVM_FFI_STATIC_INIT_BLOCK() {
- namespace refl = tvm::ffi::reflection;
- refl::GlobalDef().def("node.StructuralHash",
- [](const Any& object, bool map_free_vars) -> int64_t {
- return ffi::StructuralHash::Hash(object,
map_free_vars);
- });
- refl::TypeAttrDef<ffi::ModuleObj>()
- .def("__data_to_json__",
- [](const ffi::ModuleObj* node) {
- std::string bytes =
codegen::SerializeModuleToBytes(ffi::GetRef<ffi::Module>(node),
-
/*export_dso*/ false);
- return ffi::Base64Encode(ffi::Bytes(bytes));
- })
- .def("__data_from_json__", [](const ffi::String& base64_bytes) {
- ffi::Bytes bytes = ffi::Base64Decode(base64_bytes);
- ffi::Module rtmod = codegen::DeserializeModuleFromBytes(bytes.operator
std::string());
- return rtmod;
- });
-
- refl::TypeAttrDef<ffi::TensorObj>()
- .def("__data_to_json__",
- [](const ffi::TensorObj* node) {
- std::string result;
- support::BytesOutStream mstrm(&result);
- support::Base64OutStream b64strm(&mstrm);
- runtime::SaveDLTensor(&b64strm, node);
- b64strm.Finish();
- return ffi::String(std::move(result));
- })
- .def("__data_from_json__", [](const std::string& blob) {
- support::BytesInStream mstrm(blob);
- support::Base64InStream b64strm(&mstrm);
- b64strm.InitPosition();
- runtime::Tensor temp;
- TVM_FFI_ICHECK(temp.Load(&b64strm));
- return temp;
- });
-}
-
-struct RefToObjectPtr : public ffi::ObjectRef {
- static ffi::ObjectPtr<ffi::Object> Get(const ffi::ObjectRef& ref) {
- return
ffi::details::ObjectUnsafe::ObjectPtrFromObjectRef<ffi::Object>(ref);
- }
-};
-
-} // namespace tvm
diff --git a/src/runtime/tensor.cc b/src/runtime/tensor.cc
index bb526ef843..2b694b1742 100644
--- a/src/runtime/tensor.cc
+++ b/src/runtime/tensor.cc
@@ -22,12 +22,15 @@
* \brief Tensor container infratructure.
*/
#include <tvm/ffi/error.h>
+#include <tvm/ffi/extra/base64.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/base.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/tensor.h>
+#include "../support/base64.h"
+#include "../support/bytes_io.h"
#include "tvm/runtime/data_type.h"
namespace tvm {
@@ -241,6 +244,24 @@ using namespace tvm::runtime;
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
+ refl::TypeAttrDef<tvm::ffi::TensorObj>()
+ .def("__data_to_json__",
+ [](const tvm::ffi::TensorObj* node) {
+ std::string result;
+ tvm::support::BytesOutStream mstrm(&result);
+ tvm::support::Base64OutStream b64strm(&mstrm);
+ tvm::runtime::SaveDLTensor(&b64strm, node);
+ b64strm.Finish();
+ return tvm::ffi::String(std::move(result));
+ })
+ .def("__data_from_json__", [](const std::string& blob) {
+ tvm::support::BytesInStream mstrm(blob);
+ tvm::support::Base64InStream b64strm(&mstrm);
+ b64strm.InitPosition();
+ tvm::runtime::Tensor temp;
+ TVM_FFI_ICHECK(temp.Load(&b64strm));
+ return temp;
+ });
refl::GlobalDef()
.def("runtime.TVMTensorAllocWithScope", Tensor::Empty)
.def_method("runtime.TVMTensorCreateView", &Tensor::CreateView)