This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 0df4103675 [Bugfix] Restrict CopyOnWrite to _type_final (#17132)
0df4103675 is described below
commit 0df4103675a52cc5b9e6356cb003bb17c66bc1a4
Author: Eric Lunderberg <[email protected]>
AuthorDate: Tue Jul 2 10:18:08 2024 -0500
[Bugfix] Restrict CopyOnWrite to _type_final (#17132)
Prior to this commit, the `TVM_DEFINE_OBJECT_REF_COW_METHOD` could be
used in any `ObjectRef` subclass to provide a `CopyOnWrite` method.
However, the implementation of this method was invalid if the
object's `ContainerType` could itself be subclassed. In that case,
using `obj.CopyOnWrite()` when the object contains a subclass, and
when a copy is required, would silently replace the subclass instance held by
`obj` with a new instance of the base class.
This commit adds a `static_assert` to the
`TVM_DEFINE_OBJECT_REF_COW_METHOD` macro, preventing the macro from being
used in classes that would have incorrect usage.
Compilation with this change found two classes, `relax::Var` and
`relax::BindingBlock`, that were susceptible to this error, and the macro
has been removed from these classes. For backwards-compatibility, the
`CopyOnWrite` function for these two classes is provided explicitly.
---
include/tvm/relax/expr.h | 7 ++++---
include/tvm/runtime/object.h | 20 ++++++++++++--------
src/relax/ir/expr.cc | 38 ++++++++++++++++++++++++++++++++++++++
3 files changed, 54 insertions(+), 11 deletions(-)
diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h
index 401aaa9248..60032c3462 100644
--- a/include/tvm/relax/expr.h
+++ b/include/tvm/relax/expr.h
@@ -427,7 +427,8 @@ class Var : public LeafExpr {
TVM_DLL explicit Var(Id vid, Optional<StructInfo> struct_info_annotation,
Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Var, LeafExpr, VarNode);
- TVM_DEFINE_OBJECT_REF_COW_METHOD(VarNode);
+
+ VarNode* CopyOnWrite();
};
/*! \brief A sub-type of the variable node used to mark dataflow variables from
@@ -784,10 +785,10 @@ class BindingBlock : public ObjectRef {
public:
TVM_DLL explicit BindingBlock(Array<Binding> bindings, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(BindingBlock, ObjectRef, BindingBlockNode);
- TVM_DEFINE_OBJECT_REF_COW_METHOD(BindingBlockNode);
+
+ BindingBlockNode* CopyOnWrite();
};
-class DataflowBlock;
class DataflowBlockNode : public BindingBlockNode {
public:
bool SEqualReduce(const DataflowBlockNode* other, SEqualReducer equal) const
{
diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h
index 172316daae..4483867f3c 100644
--- a/include/tvm/runtime/object.h
+++ b/include/tvm/runtime/object.h
@@ -823,14 +823,18 @@ struct ObjectPtrEqual {
*
* \endcode
*/
-#define TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName) \
- ObjectName* CopyOnWrite() { \
- ICHECK(data_ != nullptr); \
- if (!data_.unique()) { \
- auto n = make_object<ObjectName>(*(operator->())); \
- ObjectPtr<Object>(std::move(n)).swap(data_); \
- } \
- return static_cast<ObjectName*>(data_.get()); \
+#define TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName) \
+ static_assert(ObjectName::_type_final, /* a subclass instance would be copied as a plain ObjectName below */ \
+ "TVM's CopyOnWrite may only be used for " \
+ "Object types that are declared as final, " \
+ "using the TVM_DECLARE_FINAL_OBJECT_INFO macro."); \
+ ObjectName* CopyOnWrite() { /* mutable access; copies the node first if it is shared */ \
+ ICHECK(data_ != nullptr); \
+ if (!data_.unique()) { \
+ auto n = make_object<ObjectName>(*(operator->())); /* copy-construct a new ObjectName */ \
+ ObjectPtr<Object>(std::move(n)).swap(data_); /* this reference now holds the new copy */ \
+ } \
+ return static_cast<ObjectName*>(data_.get()); \
}
// Implementations details below
diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc
index 59b6a0aeb7..a14ba1d9aa 100644
--- a/src/relax/ir/expr.cc
+++ b/src/relax/ir/expr.cc
@@ -265,6 +265,25 @@ Var::Var(Id vid, Optional<StructInfo>
struct_info_annotation, Span span) {
data_ = std::move(n);
}
+VarNode* Var::CopyOnWrite() {
+ // The `TVM_DEFINE_OBJECT_REF_COW_METHOD` cannot be used for
+ // Var, because it is the base class for `DataflowVar`.
+ // If the `TVM_DEFINE_OBJECT_REF_COW_METHOD` were used, the
+ // automatic implementation would erroneously convert from a
+ // `DataflowVar` to a `Var`.
+ ICHECK(data_ != nullptr);
+ if (!data_.unique()) {  // copy only when the node is shared
+ ObjectPtr<VarNode> node;
+ if (auto dataflow_var = as<DataflowVarNode>()) {  // keep the DataflowVarNode dynamic type
+ node = make_object<DataflowVarNode>(*dataflow_var);
+ } else {
+ node = make_object<VarNode>(*(operator->()));  // NOTE(review): any other VarNode subclass is copied as a plain VarNode
+ }
+ ObjectPtr<Object>(std::move(node)).swap(data_);
+ }
+ return static_cast<VarNode*>(data_.get());  // data_ now refers to an unshared node
+}
+
TVM_REGISTER_GLOBAL("relax.Var")
.set_body_typed([](String name_hint, Optional<StructInfo>
struct_info_annotation, Span span) {
return Var(name_hint, struct_info_annotation, span);
@@ -473,6 +492,25 @@ BindingBlock::BindingBlock(Array<Binding> bindings, Span
span) {
data_ = std::move(n);
}
+BindingBlockNode* BindingBlock::CopyOnWrite() {
+ // The `TVM_DEFINE_OBJECT_REF_COW_METHOD` cannot be used for
+ // BindingBlock, because it is the base class for `DataflowBlock`.
+ // If the `TVM_DEFINE_OBJECT_REF_COW_METHOD` were used, the
+ // automatic implementation would erroneously convert from a
+ // `DataflowBlock` to a `BindingBlock`.
+ ICHECK(data_ != nullptr);
+ if (!data_.unique()) {  // copy only when the node is shared
+ ObjectPtr<BindingBlockNode> node;
+ if (auto dataflow_block = as<DataflowBlockNode>()) {  // keep the DataflowBlockNode dynamic type
+ node = make_object<DataflowBlockNode>(*dataflow_block);
+ } else {
+ node = make_object<BindingBlockNode>(*(operator->()));  // NOTE(review): any other BindingBlockNode subclass is copied as a plain BindingBlockNode
+ }
+ ObjectPtr<Object>(std::move(node)).swap(data_);
+ }
+ return static_cast<BindingBlockNode*>(data_.get());  // data_ now refers to an unshared node
+}
+
TVM_REGISTER_GLOBAL("relax.BindingBlock").set_body_typed([](Array<Binding>
bindings, Span span) {
return BindingBlock(bindings, span);
});