This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new bd17b10fff [FFI][REFACTOR] Isolate out extra API (#18177)
bd17b10fff is described below
commit bd17b10fff1046c03a942a129cb73a2bed53e90b
Author: Tianqi Chen <[email protected]>
AuthorDate: Wed Jul 30 15:19:56 2025 -0400
[FFI][REFACTOR] Isolate out extra API (#18177)
This PR formalizes the extra API in FFI. The extra APIs are a minimal set of
APIs that are not required by the core mechanism but are still helpful.
Move structural equal/hash to the extra API.
---
ffi/CMakeLists.txt | 6 +-
ffi/include/tvm/ffi/c_api.h | 21 ----
ffi/include/tvm/ffi/extra/base.h | 48 +++++++++
.../ffi/{reflection => extra}/structural_equal.h | 13 ++-
.../ffi/{reflection => extra}/structural_hash.h | 11 +-
ffi/include/tvm/ffi/reflection/access_path.h | 18 ++--
.../ffi/{reflection => extra}/structural_equal.cc | 68 ++++++------
.../ffi/{reflection => extra}/structural_hash.cc | 10 +-
...equal_hash.cc => test_structural_equal_hash.cc} | 114 ++++++++++-----------
src/meta_schedule/module_equality.cc | 26 ++---
src/node/structural_equal.cc | 18 ++--
src/node/structural_hash.cc | 6 +-
src/relax/ir/block_builder.cc | 6 +-
src/relax/transform/lift_transform_params.cc | 6 +-
14 files changed, 198 insertions(+), 173 deletions(-)
diff --git a/ffi/CMakeLists.txt b/ffi/CMakeLists.txt
index d6c56c8112..76b2901c7a 100644
--- a/ffi/CMakeLists.txt
+++ b/ffi/CMakeLists.txt
@@ -59,13 +59,13 @@ set(tvm_ffi_objs_sources
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/dtype.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/testing.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/container.cc"
+ "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/reflection/access_path.cc"
)
if (TVM_FFI_USE_EXTRA_CXX_API)
list(APPEND tvm_ffi_objs_sources
- "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/reflection/access_path.cc"
- "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/reflection/structural_equal.cc"
- "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/reflection/structural_hash.cc"
+ "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/structural_equal.cc"
+ "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/structural_hash.cc"
)
endif()
diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h
index 60743b82c6..d99832af01 100644
--- a/ffi/include/tvm/ffi/c_api.h
+++ b/ffi/include/tvm/ffi/c_api.h
@@ -56,27 +56,6 @@
#define TVM_FFI_DLL_EXPORT __attribute__((visibility("default")))
#endif
-/*!
- * \brief Marks the API as extra c++ api that is defined in cc files.
- *
- * These APIs are extra features that depend on, but are not required to
- * support essential core functionality, such as function calling and object
- * access.
- *
- * They are implemented in cc files to reduce compile-time overhead.
- * The input/output only uses POD/Any/ObjectRef for ABI stability.
- * However, these extra APIs may have an issue across MSVC/Itanium ABI,
- *
- * Related features are also available through reflection based function
- * that is fully based on C API
- *
- * The project aims to minimize the number of extra C++ APIs and only
- * restrict the use to non-core functionalities.
- */
-#ifndef TVM_FFI_EXTRA_CXX_API
-#define TVM_FFI_EXTRA_CXX_API TVM_FFI_DLL
-#endif
-
#ifdef __cplusplus
extern "C" {
#endif
diff --git a/ffi/include/tvm/ffi/extra/base.h b/ffi/include/tvm/ffi/extra/base.h
new file mode 100644
index 0000000000..b09b3540a8
--- /dev/null
+++ b/ffi/include/tvm/ffi/extra/base.h
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/ffi/extra/base.h
+ * \brief Base header for Extra API.
+ *
+ * The extra APIs contain a minimal set of APIs that are not
+ * required to support essential core functionality.
+ */
+#ifndef TVM_FFI_EXTRA_BASE_H_
+#define TVM_FFI_EXTRA_BASE_H_
+
+#include <tvm/ffi/c_api.h>
+
+/*!
+ * \brief Marks the API as extra c++ api that is defined in cc files.
+ *
+ * They are implemented in cc files to reduce compile-time overhead.
+ * The input/output only uses POD/Any/ObjectRef for ABI stability.
+ * However, these extra APIs may have issues across the MSVC/Itanium C++ ABI boundary.
+ *
+ * Related features are also available through reflection-based functions
+ * that are built entirely on the C API.
+ *
+ * The project aims to minimize the number of extra C++ APIs to keep things
+ * lightweight, and to restrict their use to non-core functionality.
+ */
+#ifndef TVM_FFI_EXTRA_CXX_API
+#define TVM_FFI_EXTRA_CXX_API TVM_FFI_DLL
+#endif
+
+#endif // TVM_FFI_EXTRA_BASE_H_
diff --git a/ffi/include/tvm/ffi/reflection/structural_equal.h
b/ffi/include/tvm/ffi/extra/structural_equal.h
similarity index 90%
rename from ffi/include/tvm/ffi/reflection/structural_equal.h
rename to ffi/include/tvm/ffi/extra/structural_equal.h
index 860222644c..9727940297 100644
--- a/ffi/include/tvm/ffi/reflection/structural_equal.h
+++ b/ffi/include/tvm/ffi/extra/structural_equal.h
@@ -17,19 +17,19 @@
* under the License.
*/
/*!
- * \file tvm/ffi/reflection/structural_equal.h
+ * \file tvm/ffi/extra/structural_equal.h
* \brief Structural equal implementation
*/
-#ifndef TVM_FFI_REFLECTION_STRUCTURAL_EQUAL_H_
-#define TVM_FFI_REFLECTION_STRUCTURAL_EQUAL_H_
+#ifndef TVM_FFI_EXTRA_STRUCTURAL_EQUAL_H_
+#define TVM_FFI_EXTRA_STRUCTURAL_EQUAL_H_
#include <tvm/ffi/any.h>
+#include <tvm/ffi/extra/base.h>
#include <tvm/ffi/optional.h>
#include <tvm/ffi/reflection/access_path.h>
namespace tvm {
namespace ffi {
-namespace reflection {
/*
* \brief Structural equality comparators
*/
@@ -59,7 +59,7 @@ class StructuralEqual {
* \return If comparison fails, return the first mismatch AccessPath pair,
* otherwise return std::nullopt.
*/
- TVM_FFI_EXTRA_CXX_API static Optional<AccessPathPair> GetFirstMismatch(
+ TVM_FFI_EXTRA_CXX_API static Optional<reflection::AccessPathPair>
GetFirstMismatch(
const Any& lhs, const Any& rhs, bool map_free_vars = false,
bool skip_ndarray_content = false);
@@ -74,7 +74,6 @@ class StructuralEqual {
}
};
-} // namespace reflection
} // namespace ffi
} // namespace tvm
-#endif // TVM_FFI_REFLECTION_STRUCTURAL_EQUAL_H_
+#endif // TVM_FFI_EXTRA_STRUCTURAL_EQUAL_H_
diff --git a/ffi/include/tvm/ffi/reflection/structural_hash.h
b/ffi/include/tvm/ffi/extra/structural_hash.h
similarity index 87%
rename from ffi/include/tvm/ffi/reflection/structural_hash.h
rename to ffi/include/tvm/ffi/extra/structural_hash.h
index b0d17cf8bf..9cb08a1c0f 100644
--- a/ffi/include/tvm/ffi/reflection/structural_hash.h
+++ b/ffi/include/tvm/ffi/extra/structural_hash.h
@@ -17,17 +17,17 @@
* under the License.
*/
/*!
- * \file tvm/ffi/reflection/structural_hash.h
+ * \file tvm/ffi/extra/structural_hash.h
* \brief Structural hash
*/
-#ifndef TVM_FFI_REFLECTION_STRUCTURAL_HASH_H_
-#define TVM_FFI_REFLECTION_STRUCTURAL_HASH_H_
+#ifndef TVM_FFI_EXTRA_STRUCTURAL_HASH_H_
+#define TVM_FFI_EXTRA_STRUCTURAL_HASH_H_
#include <tvm/ffi/any.h>
+#include <tvm/ffi/extra/base.h>
namespace tvm {
namespace ffi {
-namespace reflection {
/*
* \brief Structural hash
@@ -52,7 +52,6 @@ class StructuralHash {
TVM_FFI_INLINE uint64_t operator()(const Any& value) const { return
Hash(value); }
};
-} // namespace reflection
} // namespace ffi
} // namespace tvm
-#endif // TVM_FFI_REFLECTION_STRUCTURAL_HASH_H_
+#endif // TVM_FFI_EXTRA_STRUCTURAL_HASH_H_
diff --git a/ffi/include/tvm/ffi/reflection/access_path.h
b/ffi/include/tvm/ffi/reflection/access_path.h
index e37b3f410c..a4f40f485e 100644
--- a/ffi/include/tvm/ffi/reflection/access_path.h
+++ b/ffi/include/tvm/ffi/reflection/access_path.h
@@ -35,12 +35,12 @@ namespace reflection {
enum class AccessKind : int32_t {
kObjectField = 0,
- kArrayIndex = 1,
- kMapKey = 2,
+ kArrayItem = 1,
+ kMapItem = 2,
// the following two are used for error reporting when
// the supposed access field is not available
- kArrayIndexMissing = 3,
- kMapKeyMissing = 4,
+ kArrayItemMissing = 3,
+ kMapItemMissing = 4,
};
/*!
@@ -86,15 +86,15 @@ class AccessStep : public ObjectRef {
return AccessStep(AccessKind::kObjectField, field_name);
}
- static AccessStep ArrayIndex(int64_t index) { return
AccessStep(AccessKind::kArrayIndex, index); }
+ static AccessStep ArrayItem(int64_t index) { return
AccessStep(AccessKind::kArrayItem, index); }
- static AccessStep ArrayIndexMissing(int64_t index) {
- return AccessStep(AccessKind::kArrayIndexMissing, index);
+ static AccessStep ArrayItemMissing(int64_t index) {
+ return AccessStep(AccessKind::kArrayItemMissing, index);
}
- static AccessStep MapKey(Any key) { return AccessStep(AccessKind::kMapKey,
key); }
+ static AccessStep MapItem(Any key) { return AccessStep(AccessKind::kMapItem,
key); }
- static AccessStep MapKeyMissing(Any key) { return
AccessStep(AccessKind::kMapKeyMissing, key); }
+ static AccessStep MapItemMissing(Any key) { return
AccessStep(AccessKind::kMapItemMissing, key); }
TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(AccessStep, ObjectRef,
AccessStepObj);
};
diff --git a/ffi/src/ffi/reflection/structural_equal.cc
b/ffi/src/ffi/extra/structural_equal.cc
similarity index 83%
rename from ffi/src/ffi/reflection/structural_equal.cc
rename to ffi/src/ffi/extra/structural_equal.cc
index e44a0c3256..a73c07713f 100644
--- a/ffi/src/ffi/reflection/structural_equal.cc
+++ b/ffi/src/ffi/extra/structural_equal.cc
@@ -25,8 +25,8 @@
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/container/ndarray.h>
#include <tvm/ffi/container/shape.h>
+#include <tvm/ffi/extra/structural_equal.h>
#include <tvm/ffi/reflection/accessor.h>
-#include <tvm/ffi/reflection/structural_equal.h>
#include <tvm/ffi/string.h>
#include <cmath>
@@ -34,7 +34,6 @@
namespace tvm {
namespace ffi {
-namespace reflection {
/**
* \brief Internal Handler class for structural equal comparison.
@@ -135,11 +134,11 @@ class StructEqualHandler {
bool success = true;
if (custom_s_equal[type_info->type_index] == nullptr) {
// We recursively compare the fields the object
- ForEachFieldInfoWithEarlyStop(type_info, [&](const TVMFFIFieldInfo*
field_info) {
+ reflection::ForEachFieldInfoWithEarlyStop(type_info, [&](const
TVMFFIFieldInfo* field_info) {
// skip fields that are marked as structural eq hash ignore
if (field_info->flags & kTVMFFIFieldFlagBitMaskSEqHashIgnore) return
false;
// get the field value from both side
- FieldGetter getter(field_info);
+ reflection::FieldGetter getter(field_info);
Any lhs_value = getter(lhs);
Any rhs_value = getter(rhs);
// field is in def region, enable free var mapping
@@ -155,9 +154,9 @@ class StructEqualHandler {
// record the first mismatching field if we sub-rountine compare
failed
if (mismatch_lhs_reverse_path_ != nullptr) {
mismatch_lhs_reverse_path_->emplace_back(
- AccessStep::ObjectField(String(field_info->name)));
+ reflection::AccessStep::ObjectField(String(field_info->name)));
mismatch_rhs_reverse_path_->emplace_back(
- AccessStep::ObjectField(String(field_info->name)));
+ reflection::AccessStep::ObjectField(String(field_info->name)));
}
// return true to indicate early stop
return true;
@@ -185,8 +184,10 @@ class StructEqualHandler {
if (!success) {
if (mismatch_lhs_reverse_path_ != nullptr) {
String field_name_str = field_name.cast<String>();
-
mismatch_lhs_reverse_path_->emplace_back(AccessStep::ObjectField(field_name_str));
-
mismatch_rhs_reverse_path_->emplace_back(AccessStep::ObjectField(field_name_str));
+ mismatch_lhs_reverse_path_->emplace_back(
+ reflection::AccessStep::ObjectField(field_name_str));
+ mismatch_rhs_reverse_path_->emplace_back(
+ reflection::AccessStep::ObjectField(field_name_str));
}
}
return success;
@@ -235,16 +236,16 @@ class StructEqualHandler {
auto it = rhs.find(rhs_key);
if (it == rhs.end()) {
if (mismatch_lhs_reverse_path_ != nullptr) {
-
mismatch_lhs_reverse_path_->emplace_back(AccessStep::MapKey(kv.first));
-
mismatch_rhs_reverse_path_->emplace_back(AccessStep::MapKeyMissing(rhs_key));
+
mismatch_lhs_reverse_path_->emplace_back(reflection::AccessStep::MapItem(kv.first));
+
mismatch_rhs_reverse_path_->emplace_back(reflection::AccessStep::MapItemMissing(rhs_key));
}
return false;
}
// now recursively compare value
if (!CompareAny(kv.second, (*it).second)) {
if (mismatch_lhs_reverse_path_ != nullptr) {
-
mismatch_lhs_reverse_path_->emplace_back(AccessStep::MapKey(kv.first));
-
mismatch_rhs_reverse_path_->emplace_back(AccessStep::MapKey(rhs_key));
+
mismatch_lhs_reverse_path_->emplace_back(reflection::AccessStep::MapItem(kv.first));
+
mismatch_rhs_reverse_path_->emplace_back(reflection::AccessStep::MapItem(rhs_key));
}
return false;
}
@@ -258,8 +259,8 @@ class StructEqualHandler {
auto it = lhs.find(lhs_key);
if (it == lhs.end()) {
if (mismatch_lhs_reverse_path_ != nullptr) {
-
mismatch_lhs_reverse_path_->emplace_back(AccessStep::MapKeyMissing(lhs_key));
-
mismatch_rhs_reverse_path_->emplace_back(AccessStep::MapKey(kv.first));
+
mismatch_lhs_reverse_path_->emplace_back(reflection::AccessStep::MapItemMissing(lhs_key));
+
mismatch_rhs_reverse_path_->emplace_back(reflection::AccessStep::MapItem(kv.first));
}
return false;
}
@@ -276,8 +277,8 @@ class StructEqualHandler {
for (size_t i = 0; i < std::min(lhs.size(), rhs.size()); ++i) {
if (!CompareAny(lhs[i], rhs[i])) {
if (mismatch_lhs_reverse_path_ != nullptr) {
- mismatch_lhs_reverse_path_->emplace_back(AccessStep::ArrayIndex(i));
- mismatch_rhs_reverse_path_->emplace_back(AccessStep::ArrayIndex(i));
+
mismatch_lhs_reverse_path_->emplace_back(reflection::AccessStep::ArrayItem(i));
+
mismatch_rhs_reverse_path_->emplace_back(reflection::AccessStep::ArrayItem(i));
}
return false;
}
@@ -285,11 +286,13 @@ class StructEqualHandler {
if (lhs.size() == rhs.size()) return true;
if (mismatch_lhs_reverse_path_ != nullptr) {
if (lhs.size() > rhs.size()) {
-
mismatch_lhs_reverse_path_->emplace_back(AccessStep::ArrayIndex(rhs.size()));
-
mismatch_rhs_reverse_path_->emplace_back(AccessStep::ArrayIndexMissing(rhs.size()));
+
mismatch_lhs_reverse_path_->emplace_back(reflection::AccessStep::ArrayItem(rhs.size()));
+ mismatch_rhs_reverse_path_->emplace_back(
+ reflection::AccessStep::ArrayItemMissing(rhs.size()));
} else {
-
mismatch_lhs_reverse_path_->emplace_back(AccessStep::ArrayIndexMissing(lhs.size()));
-
mismatch_rhs_reverse_path_->emplace_back(AccessStep::ArrayIndex(lhs.size()));
+ mismatch_lhs_reverse_path_->emplace_back(
+ reflection::AccessStep::ArrayItemMissing(lhs.size()));
+
mismatch_rhs_reverse_path_->emplace_back(reflection::AccessStep::ArrayItem(lhs.size()));
}
}
return false;
@@ -354,8 +357,8 @@ class StructEqualHandler {
// whether we compare ndarray data
bool skip_ndarray_content_{false};
// the root lhs for result printing
- std::vector<AccessStep>* mismatch_lhs_reverse_path_ = nullptr;
- std::vector<AccessStep>* mismatch_rhs_reverse_path_ = nullptr;
+ std::vector<reflection::AccessStep>* mismatch_lhs_reverse_path_ = nullptr;
+ std::vector<reflection::AccessStep>* mismatch_rhs_reverse_path_ = nullptr;
// lazily initialize custom equal function
ffi::Function s_equal_callback_ = nullptr;
// map from lhs to rhs
@@ -372,32 +375,31 @@ bool StructuralEqual::Equal(const Any& lhs, const Any&
rhs, bool map_free_vars,
return handler.CompareAny(lhs, rhs);
}
-Optional<AccessPathPair> StructuralEqual::GetFirstMismatch(const Any& lhs,
const Any& rhs,
- bool map_free_vars,
- bool
skip_ndarray_content) {
+Optional<reflection::AccessPathPair> StructuralEqual::GetFirstMismatch(const
Any& lhs,
+ const
Any& rhs,
+ bool
map_free_vars,
+ bool
skip_ndarray_content) {
StructEqualHandler handler;
handler.map_free_vars_ = map_free_vars;
handler.skip_ndarray_content_ = skip_ndarray_content;
- std::vector<AccessStep> lhs_reverse_path;
- std::vector<AccessStep> rhs_reverse_path;
+ std::vector<reflection::AccessStep> lhs_reverse_path;
+ std::vector<reflection::AccessStep> rhs_reverse_path;
handler.mismatch_lhs_reverse_path_ = &lhs_reverse_path;
handler.mismatch_rhs_reverse_path_ = &rhs_reverse_path;
if (handler.CompareAny(lhs, rhs)) {
return std::nullopt;
}
- AccessPath lhs_path(lhs_reverse_path.rbegin(), lhs_reverse_path.rend());
- AccessPath rhs_path(rhs_reverse_path.rbegin(), rhs_reverse_path.rend());
- return AccessPathPair(lhs_path, rhs_path);
+ reflection::AccessPath lhs_path(lhs_reverse_path.rbegin(),
lhs_reverse_path.rend());
+ reflection::AccessPath rhs_path(rhs_reverse_path.rbegin(),
rhs_reverse_path.rend());
+ return reflection::AccessPathPair(lhs_path, rhs_path);
}
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
- refl::GlobalDef().def("ffi.reflection.GetFirstStructuralMismatch",
- StructuralEqual::GetFirstMismatch);
+ refl::GlobalDef().def("ffi.GetFirstStructuralMismatch",
StructuralEqual::GetFirstMismatch);
// ensure the type attribute column is presented in the system even if it is
empty.
refl::EnsureTypeAttrColumn("__s_equal__");
});
-} // namespace reflection
} // namespace ffi
} // namespace tvm
diff --git a/ffi/src/ffi/reflection/structural_hash.cc
b/ffi/src/ffi/extra/structural_hash.cc
similarity index 97%
rename from ffi/src/ffi/reflection/structural_hash.cc
rename to ffi/src/ffi/extra/structural_hash.cc
index e8ffcf6d2a..e47fbbacc8 100644
--- a/ffi/src/ffi/reflection/structural_hash.cc
+++ b/ffi/src/ffi/extra/structural_hash.cc
@@ -25,9 +25,9 @@
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/container/ndarray.h>
#include <tvm/ffi/container/shape.h>
+#include <tvm/ffi/extra/structural_hash.h>
#include <tvm/ffi/reflection/accessor.h>
#include <tvm/ffi/reflection/registry.h>
-#include <tvm/ffi/reflection/structural_hash.h>
#include <tvm/ffi/string.h>
#include <cmath>
@@ -37,7 +37,6 @@
namespace tvm {
namespace ffi {
-namespace reflection {
/**
* \brief Internal Handler class for structural hash.
*/
@@ -119,11 +118,11 @@ class StructuralHashHandler {
uint64_t hash_value = obj->GetTypeKeyHash();
if (custom_s_hash[type_info->type_index] == nullptr) {
// go over the content and hash the fields
- ForEachFieldInfo(type_info, [&](const TVMFFIFieldInfo* field_info) {
+ reflection::ForEachFieldInfo(type_info, [&](const TVMFFIFieldInfo*
field_info) {
// skip fields that are marked as structural eq hash ignore
if (!(field_info->flags & kTVMFFIFieldFlagBitMaskSEqHashIgnore)) {
// get the field value from both side
- FieldGetter getter(field_info);
+ reflection::FieldGetter getter(field_info);
Any field_value = getter(obj);
// field is in def region, enable free var mapping
if (field_info->flags & kTVMFFIFieldFlagBitMaskSEqHashDef) {
@@ -297,10 +296,9 @@ uint64_t StructuralHash::Hash(const Any& value, bool
map_free_vars, bool skip_nd
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
- refl::GlobalDef().def("ffi.reflection.StructuralHash", StructuralHash::Hash);
+ refl::GlobalDef().def("ffi.StructuralHash", StructuralHash::Hash);
refl::EnsureTypeAttrColumn("__s_hash__");
});
-} // namespace reflection
} // namespace ffi
} // namespace tvm
diff --git a/ffi/tests/cpp/extra/test_reflection_structural_equal_hash.cc
b/ffi/tests/cpp/extra/test_structural_equal_hash.cc
similarity index 70%
rename from ffi/tests/cpp/extra/test_reflection_structural_equal_hash.cc
rename to ffi/tests/cpp/extra/test_structural_equal_hash.cc
index d3353b782d..76c485d906 100644
--- a/ffi/tests/cpp/extra/test_reflection_structural_equal_hash.cc
+++ b/ffi/tests/cpp/extra/test_structural_equal_hash.cc
@@ -20,10 +20,10 @@
#include <gtest/gtest.h>
#include <tvm/ffi/container/array.h>
#include <tvm/ffi/container/map.h>
+#include <tvm/ffi/extra/structural_equal.h>
+#include <tvm/ffi/extra/structural_hash.h>
#include <tvm/ffi/object.h>
#include <tvm/ffi/reflection/registry.h>
-#include <tvm/ffi/reflection/structural_equal.h>
-#include <tvm/ffi/reflection/structural_hash.h>
#include <tvm/ffi/string.h>
#include "../testing_object.h"
@@ -37,18 +37,18 @@ namespace refl = tvm::ffi::reflection;
TEST(StructuralEqualHash, Array) {
Array<int> a = {1, 2, 3};
Array<int> b = {1, 2, 3};
- EXPECT_TRUE(refl::StructuralEqual()(a, b));
- EXPECT_EQ(refl::StructuralHash()(a), refl::StructuralHash()(b));
+ EXPECT_TRUE(StructuralEqual()(a, b));
+ EXPECT_EQ(StructuralHash()(a), StructuralHash()(b));
Array<int> c = {1, 3};
- EXPECT_FALSE(refl::StructuralEqual()(a, c));
- EXPECT_NE(refl::StructuralHash()(a), refl::StructuralHash()(c));
- auto diff_a_c = refl::StructuralEqual::GetFirstMismatch(a, c);
+ EXPECT_FALSE(StructuralEqual()(a, c));
+ EXPECT_NE(StructuralHash()(a), StructuralHash()(c));
+ auto diff_a_c = StructuralEqual::GetFirstMismatch(a, c);
// first directly interepret diff,
EXPECT_TRUE(diff_a_c.has_value());
- EXPECT_EQ((*diff_a_c).get<0>()[0]->kind, refl::AccessKind::kArrayIndex);
- EXPECT_EQ((*diff_a_c).get<1>()[0]->kind, refl::AccessKind::kArrayIndex);
+ EXPECT_EQ((*diff_a_c).get<0>()[0]->kind, refl::AccessKind::kArrayItem);
+ EXPECT_EQ((*diff_a_c).get<1>()[0]->kind, refl::AccessKind::kArrayItem);
EXPECT_EQ((*diff_a_c).get<0>()[0]->key.cast<int64_t>(), 1);
EXPECT_EQ((*diff_a_c).get<1>()[0]->key.cast<int64_t>(), 1);
EXPECT_EQ((*diff_a_c).get<0>().size(), 1);
@@ -57,90 +57,90 @@ TEST(StructuralEqualHash, Array) {
// use structural equal for checking in future parts
// given we have done some basic checks above by directly interepret diff,
Array<int> d = {1, 2};
- auto diff_a_d = refl::StructuralEqual::GetFirstMismatch(a, d);
+ auto diff_a_d = StructuralEqual::GetFirstMismatch(a, d);
auto expected_diff_a_d = refl::AccessPathPair(refl::AccessPath({
-
refl::AccessStep::ArrayIndex(2),
+
refl::AccessStep::ArrayItem(2),
}),
refl::AccessPath({
-
refl::AccessStep::ArrayIndexMissing(2),
+
refl::AccessStep::ArrayItemMissing(2),
}));
// then use structural equal to check it
- EXPECT_TRUE(refl::StructuralEqual()(diff_a_d, expected_diff_a_d));
+ EXPECT_TRUE(StructuralEqual()(diff_a_d, expected_diff_a_d));
}
TEST(StructuralEqualHash, Map) {
// same map but different insertion order
Map<String, int> a = {{"a", 1}, {"b", 2}, {"c", 3}};
Map<String, int> b = {{"b", 2}, {"c", 3}, {"a", 1}};
- EXPECT_TRUE(refl::StructuralEqual()(a, b));
- EXPECT_EQ(refl::StructuralHash()(a), refl::StructuralHash()(b));
+ EXPECT_TRUE(StructuralEqual()(a, b));
+ EXPECT_EQ(StructuralHash()(a), StructuralHash()(b));
Map<String, int> c = {{"a", 1}, {"b", 2}, {"c", 4}};
- EXPECT_FALSE(refl::StructuralEqual()(a, c));
- EXPECT_NE(refl::StructuralHash()(a), refl::StructuralHash()(c));
+ EXPECT_FALSE(StructuralEqual()(a, c));
+ EXPECT_NE(StructuralHash()(a), StructuralHash()(c));
- auto diff_a_c = refl::StructuralEqual::GetFirstMismatch(a, c);
+ auto diff_a_c = StructuralEqual::GetFirstMismatch(a, c);
auto expected_diff_a_c = refl::AccessPathPair(refl::AccessPath({
-
refl::AccessStep::MapKey("c"),
+
refl::AccessStep::MapItem("c"),
}),
refl::AccessPath({
-
refl::AccessStep::MapKey("c"),
+
refl::AccessStep::MapItem("c"),
}));
EXPECT_TRUE(diff_a_c.has_value());
- EXPECT_TRUE(refl::StructuralEqual()(diff_a_c, expected_diff_a_c));
+ EXPECT_TRUE(StructuralEqual()(diff_a_c, expected_diff_a_c));
}
TEST(StructuralEqualHash, NestedMapArray) {
Map<String, Array<Any>> a = {{"a", {1, 2, 3}}, {"b", {4, "hello", 6}}};
Map<String, Array<Any>> b = {{"a", {1, 2, 3}}, {"b", {4, "hello", 6}}};
- EXPECT_TRUE(refl::StructuralEqual()(a, b));
- EXPECT_EQ(refl::StructuralHash()(a), refl::StructuralHash()(b));
+ EXPECT_TRUE(StructuralEqual()(a, b));
+ EXPECT_EQ(StructuralHash()(a), StructuralHash()(b));
Map<String, Array<Any>> c = {{"a", {1, 2, 3}}, {"b", {4, "world", 6}}};
- EXPECT_FALSE(refl::StructuralEqual()(a, c));
- EXPECT_NE(refl::StructuralHash()(a), refl::StructuralHash()(c));
+ EXPECT_FALSE(StructuralEqual()(a, c));
+ EXPECT_NE(StructuralHash()(a), StructuralHash()(c));
- auto diff_a_c = refl::StructuralEqual::GetFirstMismatch(a, c);
+ auto diff_a_c = StructuralEqual::GetFirstMismatch(a, c);
auto expected_diff_a_c = refl::AccessPathPair(refl::AccessPath({
-
refl::AccessStep::MapKey("b"),
-
refl::AccessStep::ArrayIndex(1),
+
refl::AccessStep::MapItem("b"),
+
refl::AccessStep::ArrayItem(1),
}),
refl::AccessPath({
-
refl::AccessStep::MapKey("b"),
-
refl::AccessStep::ArrayIndex(1),
+
refl::AccessStep::MapItem("b"),
+
refl::AccessStep::ArrayItem(1),
}));
EXPECT_TRUE(diff_a_c.has_value());
- EXPECT_TRUE(refl::StructuralEqual()(diff_a_c, expected_diff_a_c));
+ EXPECT_TRUE(StructuralEqual()(diff_a_c, expected_diff_a_c));
Map<String, Array<Any>> d = {{"a", {1, 2, 3}}};
- auto diff_a_d = refl::StructuralEqual::GetFirstMismatch(a, d);
+ auto diff_a_d = StructuralEqual::GetFirstMismatch(a, d);
auto expected_diff_a_d = refl::AccessPathPair(refl::AccessPath({
-
refl::AccessStep::MapKey("b"),
+
refl::AccessStep::MapItem("b"),
}),
refl::AccessPath({
-
refl::AccessStep::MapKeyMissing("b"),
+
refl::AccessStep::MapItemMissing("b"),
}));
EXPECT_TRUE(diff_a_d.has_value());
- EXPECT_TRUE(refl::StructuralEqual()(diff_a_d, expected_diff_a_d));
+ EXPECT_TRUE(StructuralEqual()(diff_a_d, expected_diff_a_d));
- auto diff_d_a = refl::StructuralEqual::GetFirstMismatch(d, a);
+ auto diff_d_a = StructuralEqual::GetFirstMismatch(d, a);
auto expected_diff_d_a = refl::AccessPathPair(refl::AccessPath({
-
refl::AccessStep::MapKeyMissing("b"),
+
refl::AccessStep::MapItemMissing("b"),
}),
refl::AccessPath({
-
refl::AccessStep::MapKey("b"),
+
refl::AccessStep::MapItem("b"),
}));
}
TEST(StructuralEqualHash, FreeVar) {
TVar a = TVar("a");
TVar b = TVar("b");
- EXPECT_TRUE(refl::StructuralEqual::Equal(a, b, /*map_free_vars=*/true));
- EXPECT_FALSE(refl::StructuralEqual::Equal(a, b));
+ EXPECT_TRUE(StructuralEqual::Equal(a, b, /*map_free_vars=*/true));
+ EXPECT_FALSE(StructuralEqual::Equal(a, b));
- EXPECT_NE(refl::StructuralHash()(a), refl::StructuralHash()(b));
- EXPECT_EQ(refl::StructuralHash::Hash(a, /*map_free_vars=*/true),
- refl::StructuralHash::Hash(b, /*map_free_vars=*/true));
+ EXPECT_NE(StructuralHash()(a), StructuralHash()(b));
+ EXPECT_EQ(StructuralHash::Hash(a, /*map_free_vars=*/true),
+ StructuralHash::Hash(b, /*map_free_vars=*/true));
}
TEST(StructuralEqualHash, FuncDefAndIgnoreField) {
@@ -152,21 +152,21 @@ TEST(StructuralEqualHash, FuncDefAndIgnoreField) {
TFunc fc = TFunc({x}, {TInt(1), TInt(2)}, "comment c");
- EXPECT_TRUE(refl::StructuralEqual()(fa, fb));
- EXPECT_EQ(refl::StructuralHash()(fa), refl::StructuralHash()(fb));
+ EXPECT_TRUE(StructuralEqual()(fa, fb));
+ EXPECT_EQ(StructuralHash()(fa), StructuralHash()(fb));
- EXPECT_FALSE(refl::StructuralEqual()(fa, fc));
- auto diff_fa_fc = refl::StructuralEqual::GetFirstMismatch(fa, fc);
+ EXPECT_FALSE(StructuralEqual()(fa, fc));
+ auto diff_fa_fc = StructuralEqual::GetFirstMismatch(fa, fc);
auto expected_diff_fa_fc = refl::AccessPathPair(refl::AccessPath({
refl::AccessStep::ObjectField("body"),
-
refl::AccessStep::ArrayIndex(1),
+
refl::AccessStep::ArrayItem(1),
}),
refl::AccessPath({
refl::AccessStep::ObjectField("body"),
-
refl::AccessStep::ArrayIndex(1),
+
refl::AccessStep::ArrayItem(1),
}));
EXPECT_TRUE(diff_fa_fc.has_value());
- EXPECT_TRUE(refl::StructuralEqual()(diff_fa_fc, expected_diff_fa_fc));
+ EXPECT_TRUE(StructuralEqual()(diff_fa_fc, expected_diff_fa_fc));
}
TEST(StructuralEqualHash, CustomTreeNode) {
@@ -178,21 +178,21 @@ TEST(StructuralEqualHash, CustomTreeNode) {
TCustomFunc fc = TCustomFunc({x}, {TInt(1), TInt(2)}, "comment c");
- EXPECT_TRUE(refl::StructuralEqual()(fa, fb));
- EXPECT_EQ(refl::StructuralHash()(fa), refl::StructuralHash()(fb));
+ EXPECT_TRUE(StructuralEqual()(fa, fb));
+ EXPECT_EQ(StructuralHash()(fa), StructuralHash()(fb));
- EXPECT_FALSE(refl::StructuralEqual()(fa, fc));
- auto diff_fa_fc = refl::StructuralEqual::GetFirstMismatch(fa, fc);
+ EXPECT_FALSE(StructuralEqual()(fa, fc));
+ auto diff_fa_fc = StructuralEqual::GetFirstMismatch(fa, fc);
auto expected_diff_fa_fc = refl::AccessPathPair(refl::AccessPath({
refl::AccessStep::ObjectField("body"),
-
refl::AccessStep::ArrayIndex(1),
+
refl::AccessStep::ArrayItem(1),
}),
refl::AccessPath({
refl::AccessStep::ObjectField("body"),
-
refl::AccessStep::ArrayIndex(1),
+
refl::AccessStep::ArrayItem(1),
}));
EXPECT_TRUE(diff_fa_fc.has_value());
- EXPECT_TRUE(refl::StructuralEqual()(diff_fa_fc, expected_diff_fa_fc));
+ EXPECT_TRUE(StructuralEqual()(diff_fa_fc, expected_diff_fa_fc));
}
} // namespace
diff --git a/src/meta_schedule/module_equality.cc
b/src/meta_schedule/module_equality.cc
index 501d55b8ef..df8c45b5e6 100644
--- a/src/meta_schedule/module_equality.cc
+++ b/src/meta_schedule/module_equality.cc
@@ -18,8 +18,8 @@
*/
#include "module_equality.h"
-#include <tvm/ffi/reflection/structural_equal.h>
-#include <tvm/ffi/reflection/structural_hash.h>
+#include <tvm/ffi/extra/structural_equal.h>
+#include <tvm/ffi/extra/structural_hash.h>
#include <tvm/ir/module.h>
#include <tvm/node/structural_equal.h>
#include <tvm/node/structural_hash.h>
@@ -40,12 +40,12 @@ class ModuleEqualityStructural : public ModuleEquality {
class ModuleEqualityIgnoreNDArray : public ModuleEquality {
public:
size_t Hash(IRModule mod) const {
- return tvm::ffi::reflection::StructuralHash::Hash(mod,
/*map_free_vars=*/false,
-
/*skip_ndarray_content=*/true);
+ return tvm::ffi::StructuralHash::Hash(mod, /*map_free_vars=*/false,
+ /*skip_ndarray_content=*/true);
}
bool Equal(IRModule lhs, IRModule rhs) const {
- return tvm::ffi::reflection::StructuralEqual::Equal(lhs, rhs,
/*map_free_vars=*/false,
-
/*skip_ndarray_content=*/true);
+ return tvm::ffi::StructuralEqual::Equal(lhs, rhs, /*map_free_vars=*/false,
+ /*skip_ndarray_content=*/true);
}
String GetName() const { return "ignore-ndarray"; }
};
@@ -56,9 +56,9 @@ class ModuleEqualityAnchorBlock : public ModuleEquality {
size_t Hash(IRModule mod) const {
auto anchor_block = tir::FindAnchorBlock(mod);
if (anchor_block) {
- return
ffi::reflection::StructuralHash::Hash(GetRef<tir::Block>(anchor_block),
- /*map_free_vars=*/false,
-
/*skip_ndarray_content=*/true);
+ return ffi::StructuralHash::Hash(GetRef<tir::Block>(anchor_block),
+ /*map_free_vars=*/false,
+ /*skip_ndarray_content=*/true);
}
return ModuleEqualityIgnoreNDArray().Hash(mod);
}
@@ -66,10 +66,10 @@ class ModuleEqualityAnchorBlock : public ModuleEquality {
auto anchor_block_lhs = tir::FindAnchorBlock(lhs);
auto anchor_block_rhs = tir::FindAnchorBlock(rhs);
if (anchor_block_lhs && anchor_block_rhs) {
- return
tvm::ffi::reflection::StructuralEqual::Equal(GetRef<tir::Block>(anchor_block_lhs),
-
GetRef<tir::Block>(anchor_block_rhs),
-
/*map_free_vars=*/false,
-
/*skip_ndarray_content=*/true);
+ return
tvm::ffi::StructuralEqual::Equal(GetRef<tir::Block>(anchor_block_lhs),
+
GetRef<tir::Block>(anchor_block_rhs),
+ /*map_free_vars=*/false,
+ /*skip_ndarray_content=*/true);
}
return ModuleEqualityIgnoreNDArray().Equal(lhs, rhs);
}
diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc
index 186f509472..5474be6676 100644
--- a/src/node/structural_equal.cc
+++ b/src/node/structural_equal.cc
@@ -19,10 +19,10 @@
/*!
* \file src/node/structural_equal.cc
*/
+#include <tvm/ffi/extra/structural_equal.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/access_path.h>
#include <tvm/ffi/reflection/registry.h>
-#include <tvm/ffi/reflection/structural_equal.h>
#include <tvm/ir/module.h>
#include <tvm/node/functor.h>
#include <tvm/node/node.h>
@@ -64,19 +64,19 @@ Optional<ObjectPathPair> ObjectPathPairFromAccessPathPair(
result = result->Attr(step->key.cast<String>());
break;
}
- case ffi::reflection::AccessKind::kArrayIndex: {
+ case ffi::reflection::AccessKind::kArrayItem: {
result = result->ArrayIndex(step->key.cast<int64_t>());
break;
}
- case ffi::reflection::AccessKind::kMapKey: {
+ case ffi::reflection::AccessKind::kMapItem: {
result = result->MapValue(step->key);
break;
}
- case ffi::reflection::AccessKind::kArrayIndexMissing: {
+ case ffi::reflection::AccessKind::kArrayItemMissing: {
result = result->MissingArrayElement(step->key.cast<int64_t>());
break;
}
- case ffi::reflection::AccessKind::kMapKeyMissing: {
+ case ffi::reflection::AccessKind::kMapItemMissing: {
result = result->MissingMapEntry();
break;
}
@@ -96,7 +96,7 @@ bool NodeStructuralEqualAdapter(const Any& lhs, const Any&
rhs, bool assert_mode
bool map_free_vars) {
if (assert_mode) {
auto first_mismatch = ObjectPathPairFromAccessPathPair(
- ffi::reflection::StructuralEqual::GetFirstMismatch(lhs, rhs,
map_free_vars));
+ ffi::StructuralEqual::GetFirstMismatch(lhs, rhs, map_free_vars));
if (first_mismatch.has_value()) {
std::ostringstream oss;
oss << "StructuralEqual check failed, caused by lhs";
@@ -129,7 +129,7 @@ bool NodeStructuralEqualAdapter(const Any& lhs, const Any&
rhs, bool assert_mode
}
return true;
} else {
- return ffi::reflection::StructuralEqual::Equal(lhs, rhs, map_free_vars);
+ return ffi::StructuralEqual::Equal(lhs, rhs, map_free_vars);
}
}
@@ -147,12 +147,12 @@ TVM_FFI_STATIC_INIT_BLOCK({
return first_mismatch;
*/
return ObjectPathPairFromAccessPathPair(
- ffi::reflection::StructuralEqual::GetFirstMismatch(lhs, rhs,
map_free_vars));
+ ffi::StructuralEqual::GetFirstMismatch(lhs, rhs,
map_free_vars));
});
});
bool StructuralEqual::operator()(const ffi::Any& lhs, const ffi::Any& rhs,
bool map_free_params) const {
- return ffi::reflection::StructuralEqual::Equal(lhs, rhs, map_free_params);
+ return ffi::StructuralEqual::Equal(lhs, rhs, map_free_params);
}
} // namespace tvm
diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc
index 3a5f1de041..383f344fac 100644
--- a/src/node/structural_hash.cc
+++ b/src/node/structural_hash.cc
@@ -20,9 +20,9 @@
* \file src/node/structural_hash.cc
*/
#include <dmlc/memory_io.h>
+#include <tvm/ffi/extra/structural_hash.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
-#include <tvm/ffi/reflection/structural_hash.h>
#include <tvm/node/functor.h>
#include <tvm/node/node.h>
#include <tvm/node/object_path.h>
@@ -44,12 +44,12 @@ TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("node.StructuralHash",
[](const Any& object, bool map_free_vars) -> int64_t {
- return ffi::reflection::StructuralHash::Hash(object,
map_free_vars);
+ return ffi::StructuralHash::Hash(object,
map_free_vars);
});
});
uint64_t StructuralHash::operator()(const ffi::Any& object) const {
- return ffi::reflection::StructuralHash::Hash(object, false);
+ return ffi::StructuralHash::Hash(object, false);
}
struct RefToObjectPtr : public ObjectRef {
diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc
index 037a9f3021..f87f531a87 100644
--- a/src/relax/ir/block_builder.cc
+++ b/src/relax/ir/block_builder.cc
@@ -21,9 +21,9 @@
* \file src/relax/block_builder.cc
*/
#include <tvm/arith/analyzer.h>
+#include <tvm/ffi/extra/structural_hash.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
-#include <tvm/ffi/reflection/structural_hash.h>
#include <tvm/relax/analysis.h>
#include <tvm/relax/block_builder.h>
#include <tvm/relax/expr_functor.h>
@@ -431,8 +431,8 @@ class BlockBuilderImpl : public BlockBuilderNode {
class StructuralHashIgnoreNDarray {
public:
uint64_t operator()(const ObjectRef& key) const {
- return ffi::reflection::StructuralHash::Hash(key,
/*map_free_vars=*/false,
-
/*skip_ndarray_content=*/true);
+ return ffi::StructuralHash::Hash(key, /*map_free_vars=*/false,
+ /*skip_ndarray_content=*/true);
}
};
diff --git a/src/relax/transform/lift_transform_params.cc
b/src/relax/transform/lift_transform_params.cc
index 83d978f27d..40a1c307ce 100644
--- a/src/relax/transform/lift_transform_params.cc
+++ b/src/relax/transform/lift_transform_params.cc
@@ -22,8 +22,8 @@
* \brief Lift local functions into global functions.
*/
+#include <tvm/ffi/extra/structural_equal.h>
#include <tvm/ffi/reflection/registry.h>
-#include <tvm/ffi/reflection/structural_equal.h>
#include <tvm/relax/analysis.h>
#include <tvm/relax/expr.h>
#include <tvm/relax/expr_functor.h>
@@ -541,8 +541,8 @@ class ParamRemapper : private ExprFunctor<void(const Expr&,
const Expr&)> {
} else {
var_remap_.Set(GetRef<Var>(lhs_var), rhs_var);
}
- CHECK(tvm::ffi::reflection::StructuralEqual::Equal(lhs_var->struct_info_,
rhs_var->struct_info_,
- /*map_free_vars=*/true))
+ CHECK(tvm::ffi::StructuralEqual::Equal(lhs_var->struct_info_,
rhs_var->struct_info_,
+ /*map_free_vars=*/true))
<< "The struct info of the parameters should be the same for all
target functions";
auto lhs_tir_vars =
DefinableTIRVarsInStructInfo(GetStructInfo(GetRef<Var>(lhs_var)));
auto rhs_tir_vars = DefinableTIRVarsInStructInfo(GetStructInfo(rhs_expr));