This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c3b168b8ea [FFI][REFACTOR] Introduce UnsafeInit and enhance ObjectRef
null safety (#18284)
c3b168b8ea is described below
commit c3b168b8eaea920a719677163a20e70729f91a70
Author: Tianqi Chen <[email protected]>
AuthorDate: Mon Sep 8 15:01:28 2025 -0400
[FFI][REFACTOR] Introduce UnsafeInit and enhance ObjectRef null safety
(#18284)
This PR enhances the null safety and general type safety of ObjectRef types.
Previously, ObjectRef relied on a constructor from ObjectPtr<Object> for
casting
and for initialization from nullptr.
We introduce a tag, ffi::UnsafeInit, which explicitly states the intent
that the initialization is unsafe and may initialize a non-nullable Ref to
null.
Such a tag should only be used in controlled scenarios.
The general RefType(ObjectPtr<Object>) constructor is now removed.
We still keep RefType(ObjectPtr<ContainerType>) for nullable objects,
but remove the default definition from non-nullable types, since users
can always add it explicitly to the class implementation (with a null check).
---
ffi/include/tvm/ffi/cast.h | 10 ++--
ffi/include/tvm/ffi/container/array.h | 4 ++
ffi/include/tvm/ffi/container/map.h | 4 ++
ffi/include/tvm/ffi/container/shape.h | 17 +++++-
ffi/include/tvm/ffi/container/tensor.h | 4 +-
ffi/include/tvm/ffi/container/tuple.h | 14 ++---
ffi/include/tvm/ffi/container/variant.h | 2 +-
ffi/include/tvm/ffi/extra/module.h | 17 +++++-
ffi/include/tvm/ffi/function.h | 2 +-
ffi/include/tvm/ffi/function_details.h | 2 +-
ffi/include/tvm/ffi/object.h | 62 +++++++++++++++++----
ffi/include/tvm/ffi/optional.h | 17 ++++--
ffi/include/tvm/ffi/reflection/access_path.h | 4 ++
ffi/include/tvm/ffi/reflection/registry.h | 10 ++++
ffi/include/tvm/ffi/rvalue_ref.h | 9 ++-
ffi/include/tvm/ffi/type_traits.h | 17 +++---
ffi/src/ffi/tensor.cc | 2 +-
ffi/tests/cpp/test_object.cc | 8 +++
ffi/tests/cpp/testing_object.h | 10 ++--
include/tvm/ir/attrs.h | 6 +-
include/tvm/ir/env_func.h | 8 +++
include/tvm/ir/expr.h | 6 +-
include/tvm/ir/module.h | 6 +-
include/tvm/ir/transform.h | 11 +++-
include/tvm/meta_schedule/builder.h | 7 +++
include/tvm/meta_schedule/database.h | 10 +++-
include/tvm/meta_schedule/runner.h | 6 +-
include/tvm/meta_schedule/space_generator.h | 7 +++
include/tvm/meta_schedule/task_scheduler.h | 5 +-
include/tvm/meta_schedule/tune_context.h | 7 +++
include/tvm/node/cast.h | 11 ++--
include/tvm/relax/dataflow_pattern.h | 5 ++
include/tvm/relax/expr.h | 3 +-
include/tvm/relax/struct_info.h | 6 ++
include/tvm/runtime/disco/session.h | 1 +
include/tvm/runtime/object.h | 2 +-
include/tvm/runtime/tensor.h | 3 +-
include/tvm/script/ir_builder/base.h | 1 +
include/tvm/script/ir_builder/ir/frame.h | 3 +
include/tvm/script/ir_builder/relax/frame.h | 24 ++++++++
include/tvm/script/ir_builder/tir/frame.h | 65 ++++++++++++++++++++++
include/tvm/script/printer/doc.h | 39 +++++++------
include/tvm/script/printer/ir_docsifier.h | 2 +-
include/tvm/target/target_kind.h | 3 +
include/tvm/te/tensor.h | 1 +
include/tvm/tir/block_scope.h | 7 +++
include/tvm/tir/schedule/state.h | 2 +-
include/tvm/tir/var.h | 6 +-
src/contrib/msc/core/printer/msc_doc.h | 8 +--
src/ir/source_map.cc | 4 +-
src/meta_schedule/database/database.cc | 4 +-
src/meta_schedule/database/json_database.cc | 4 +-
.../postproc/disallow_async_strided_mem_copy.cc | 2 +-
.../postproc/rewrite_parallel_vectorize_unroll.cc | 2 +-
src/meta_schedule/postproc/verify_gpu_code.cc | 6 +-
src/meta_schedule/schedule/cpu/winograd.cc | 2 +-
src/meta_schedule/schedule/cuda/thread_bind.cc | 4 +-
src/meta_schedule/schedule/cuda/winograd.cc | 6 +-
.../schedule_rule/cross_thread_reduction.cc | 18 +++---
.../multi_level_tiling_tensor_core.cc | 2 +-
.../search_strategy/evolutionary_search.cc | 8 +--
src/meta_schedule/utils.h | 2 +-
src/relax/ir/py_expr_functor.cc | 6 ++
src/relax/transform/few_shot_tuning.cc | 2 +-
src/relax/transform/meta_schedule.cc | 2 +-
src/runtime/rpc/rpc_session.h | 3 +
src/script/printer/relax/call.cc | 2 +-
src/script/printer/tir/block.cc | 2 +-
src/script/printer/tir/expr.cc | 12 ++--
src/script/printer/tir/for_loop.cc | 2 +-
src/script/printer/tir/ir.cc | 2 +-
src/script/printer/tir/stmt.cc | 4 +-
src/target/target.cc | 18 +++---
src/tir/ir/py_functor.cc | 6 ++
src/tir/schedule/analysis.h | 7 +++
src/tir/schedule/concrete_schedule.cc | 4 +-
src/tir/transforms/memhammer_tensorcore_rewrite.cc | 4 +-
77 files changed, 473 insertions(+), 153 deletions(-)
diff --git a/ffi/include/tvm/ffi/cast.h b/ffi/include/tvm/ffi/cast.h
index f70df9fe7c..398953ad65 100644
--- a/ffi/include/tvm/ffi/cast.h
+++ b/ffi/include/tvm/ffi/cast.h
@@ -44,18 +44,20 @@ namespace ffi {
*/
template <typename RefType, typename ObjectType>
inline RefType GetRef(const ObjectType* ptr) {
- static_assert(std::is_base_of_v<typename RefType::ContainerType, ObjectType>,
+ using ContainerType = typename RefType::ContainerType;
+ static_assert(std::is_base_of_v<ContainerType, ObjectType>,
"Can only cast to the ref of same container type");
if constexpr (is_optional_type_v<RefType> || RefType::_type_is_nullable) {
if (ptr == nullptr) {
- return RefType(ObjectPtr<Object>(nullptr));
+ return details::ObjectUnsafe::ObjectRefFromObjectPtr<RefType>(nullptr);
}
} else {
TVM_FFI_ICHECK_NOTNULL(ptr);
}
- return RefType(details::ObjectUnsafe::ObjectPtrFromUnowned<Object>(
- const_cast<Object*>(static_cast<const Object*>(ptr))));
+ return details::ObjectUnsafe::ObjectRefFromObjectPtr<RefType>(
+ details::ObjectUnsafe::ObjectPtrFromUnowned<Object>(
+ const_cast<Object*>(static_cast<const Object*>(ptr))));
}
/*!
diff --git a/ffi/include/tvm/ffi/container/array.h
b/ffi/include/tvm/ffi/container/array.h
index 7dbcc1f018..8fab30b8be 100644
--- a/ffi/include/tvm/ffi/container/array.h
+++ b/ffi/include/tvm/ffi/container/array.h
@@ -362,6 +362,10 @@ class Array : public ObjectRef {
/*! \brief The value type of the array */
using value_type = T;
// constructors
+ /*!
+ * \brief Construct an Array with UnsafeInit
+ */
+ explicit Array(UnsafeInit tag) : ObjectRef(tag) {}
/*!
* \brief default constructor
*/
diff --git a/ffi/include/tvm/ffi/container/map.h
b/ffi/include/tvm/ffi/container/map.h
index 27928d20c5..bea2688f7f 100644
--- a/ffi/include/tvm/ffi/container/map.h
+++ b/ffi/include/tvm/ffi/container/map.h
@@ -1381,6 +1381,10 @@ class Map : public ObjectRef {
using mapped_type = V;
/*! \brief The iterator type of the map */
class iterator;
+ /*!
+ * \brief Construct an Map with UnsafeInit
+ */
+ explicit Map(UnsafeInit tag) : ObjectRef(tag) {}
/*!
* \brief default constructor
*/
diff --git a/ffi/include/tvm/ffi/container/shape.h
b/ffi/include/tvm/ffi/container/shape.h
index 39c3ec2739..f5e88d6bb7 100644
--- a/ffi/include/tvm/ffi/container/shape.h
+++ b/ffi/include/tvm/ffi/container/shape.h
@@ -94,13 +94,13 @@ TVM_FFI_INLINE ObjectPtr<ShapeObj>
MakeInplaceShape(IterType begin, IterType end
return p;
}
-TVM_FFI_INLINE ObjectPtr<ShapeObj> MakeStridesFromShape(int64_t ndim, int64_t*
shape) {
+TVM_FFI_INLINE ObjectPtr<ShapeObj> MakeStridesFromShape(const int64_t* data,
int64_t ndim) {
int64_t* strides_data;
ObjectPtr<ShapeObj> strides = details::MakeEmptyShape(ndim, &strides_data);
int64_t stride = 1;
for (int i = ndim - 1; i >= 0; --i) {
strides_data[i] = stride;
- stride *= shape[i];
+ stride *= data[i];
}
return strides;
}
@@ -150,6 +150,16 @@ class Shape : public ObjectRef {
Shape(std::vector<int64_t> other) // NOLINT(*)
: ObjectRef(make_object<details::ShapeObjStdImpl>(std::move(other))) {}
+ /*!
+ * \brief Create a strides from a shape.
+ * \param data The shape data.
+ * \param ndim The number of dimensions.
+ * \return The strides.
+ */
+ static Shape StridesFromShape(const int64_t* data, int64_t ndim) {
+ return Shape(details::MakeStridesFromShape(data, ndim));
+ }
+
/*!
* \brief Return the data pointer
*
@@ -204,6 +214,9 @@ class Shape : public ObjectRef {
/// \cond Doxygen_Suppress
TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Shape, ObjectRef, ShapeObj);
/// \endcond
+
+ private:
+ explicit Shape(ObjectPtr<ShapeObj> ptr) : ObjectRef(ptr) {}
};
inline std::ostream& operator<<(std::ostream& os, const Shape& shape) {
diff --git a/ffi/include/tvm/ffi/container/tensor.h
b/ffi/include/tvm/ffi/container/tensor.h
index 99fb29d108..21c67decfc 100644
--- a/ffi/include/tvm/ffi/container/tensor.h
+++ b/ffi/include/tvm/ffi/container/tensor.h
@@ -203,7 +203,7 @@ class TensorObjFromNDAlloc : public TensorObj {
this->ndim = static_cast<int>(shape.size());
this->dtype = dtype;
this->shape = const_cast<int64_t*>(shape.data());
- Shape strides = Shape(details::MakeStridesFromShape(this->ndim,
this->shape));
+ Shape strides = Shape::StridesFromShape(this->shape, this->ndim);
this->strides = const_cast<int64_t*>(strides.data());
this->byte_offset = 0;
this->shape_data_ = std::move(shape);
@@ -224,7 +224,7 @@ class TensorObjFromDLPack : public TensorObj {
explicit TensorObjFromDLPack(TDLPackManagedTensor* tensor) : tensor_(tensor)
{
*static_cast<DLTensor*>(this) = tensor_->dl_tensor;
if (tensor_->dl_tensor.strides == nullptr) {
- Shape strides = Shape(details::MakeStridesFromShape(ndim, shape));
+ Shape strides = Shape::StridesFromShape(tensor_->dl_tensor.shape,
tensor_->dl_tensor.ndim);
this->strides = const_cast<int64_t*>(strides.data());
this->strides_data_ = std::move(strides);
}
diff --git a/ffi/include/tvm/ffi/container/tuple.h
b/ffi/include/tvm/ffi/container/tuple.h
index 0cb80b963e..75342409ea 100644
--- a/ffi/include/tvm/ffi/container/tuple.h
+++ b/ffi/include/tvm/ffi/container/tuple.h
@@ -47,6 +47,10 @@ class Tuple : public ObjectRef {
"All types used in Tuple<...> must be compatible with Any");
/*! \brief Default constructor */
Tuple() : ObjectRef(MakeDefaultTupleNode()) {}
+ /*!
+ * \brief Constructor with UnsafeInit
+ */
+ explicit Tuple(UnsafeInit tag) : ObjectRef(tag) {}
/*! \brief Copy constructor */
Tuple(const Tuple<Types...>& other) : ObjectRef(other) {}
/*! \brief Move constructor */
@@ -128,13 +132,6 @@ class Tuple : public ObjectRef {
return *this;
}
- /*!
- * \brief Constructor ObjectPtr
- * \param ptr The ObjectPtr
- * \tparam The enable_if_t type
- */
- explicit Tuple(ObjectPtr<Object> ptr) : ObjectRef(ptr) {}
-
/*!
* \brief Get I-th element of the tuple
*
@@ -283,7 +280,8 @@ struct TypeTraits<Tuple<Types...>> : public
ObjectRefTypeTraitsBase<Tuple<Types.
Array<Any> arr = TypeTraits<Array<Any>>::CopyFromAnyViewAfterCheck(src);
Any* ptr = arr.CopyOnWrite()->MutableBegin();
if (TryConvertElements<0, Types...>(ptr)) {
- return
Tuple<Types...>(details::ObjectUnsafe::ObjectPtrFromObjectRef<Object>(arr));
+ return details::ObjectUnsafe::ObjectRefFromObjectPtr<Tuple<Types...>>(
+ details::ObjectUnsafe::ObjectPtrFromObjectRef<Object>(arr));
}
return std::nullopt;
}
diff --git a/ffi/include/tvm/ffi/container/variant.h
b/ffi/include/tvm/ffi/container/variant.h
index 5f66d73a18..cae5a673b8 100644
--- a/ffi/include/tvm/ffi/container/variant.h
+++ b/ffi/include/tvm/ffi/container/variant.h
@@ -68,7 +68,7 @@ class VariantBase<true> : public ObjectRef {
explicit VariantBase(const T& other) : ObjectRef(other) {}
template <typename T>
explicit VariantBase(T&& other) : ObjectRef(std::move(other)) {}
- explicit VariantBase(ObjectPtr<Object> ptr) : ObjectRef(ptr) {}
+ explicit VariantBase(UnsafeInit tag) : ObjectRef(tag) {}
explicit VariantBase(Any other)
:
ObjectRef(details::AnyUnsafe::MoveFromAnyAfterCheck<ObjectRef>(std::move(other)))
{}
diff --git a/ffi/include/tvm/ffi/extra/module.h
b/ffi/include/tvm/ffi/extra/module.h
index 89e0c287a3..a1dc91eebc 100644
--- a/ffi/include/tvm/ffi/extra/module.h
+++ b/ffi/include/tvm/ffi/extra/module.h
@@ -36,6 +36,7 @@ class Module;
/*!
* \brief A module that can dynamically load ffi::Functions or exportable
source code.
+ * \sa Module
*/
class TVM_FFI_EXTRA_CXX_API ModuleObj : public Object {
public:
@@ -168,6 +169,16 @@ class TVM_FFI_EXTRA_CXX_API ModuleObj : public Object {
/*!
* \brief Reference to module object.
+ *
+ * When invoking a function on a ModuleObj, such as GetFunction,
+ * use operator-> to get the ModuleObj pointer and invoke the member functions.
+ *
+ * \code
+ * ffi::Module mod = ffi::Module::LoadFromFile("path/to/module.so");
+ * ffi::Function func = mod->GetFunction(name);
+ * \endcode
+ *
+ * \sa ModuleObj which contains most of the function implementations.
*/
class Module : public ObjectRef {
public:
@@ -202,7 +213,11 @@ class Module : public ObjectRef {
*/
kCompilationExportable = 0b100
};
-
+ /*!
+ * \brief Constructor from ObjectPtr<ModuleObj>.
+ * \param ptr The object pointer.
+ */
+ explicit Module(ObjectPtr<ModuleObj> ptr) : ObjectRef(ptr) {
TVM_FFI_ICHECK(ptr != nullptr); }
/*!
* \brief Load a module from file.
* \param file_name The name of the host function module.
diff --git a/ffi/include/tvm/ffi/function.h b/ffi/include/tvm/ffi/function.h
index 884e46fa44..d27cfc0b61 100644
--- a/ffi/include/tvm/ffi/function.h
+++ b/ffi/include/tvm/ffi/function.h
@@ -403,7 +403,7 @@ class Function : public ObjectRef {
TVM_FFI_CHECK_SAFE_CALL(TVMFFIFunctionGetGlobal(&name_arr, &handle));
if (handle != nullptr) {
return Function(
-
details::ObjectUnsafe::ObjectPtrFromOwned<Object>(static_cast<Object*>(handle)));
+
details::ObjectUnsafe::ObjectPtrFromOwned<FunctionObj>(static_cast<Object*>(handle)));
} else {
return std::nullopt;
}
diff --git a/ffi/include/tvm/ffi/function_details.h
b/ffi/include/tvm/ffi/function_details.h
index d029c19dd1..20ca44cbcb 100644
--- a/ffi/include/tvm/ffi/function_details.h
+++ b/ffi/include/tvm/ffi/function_details.h
@@ -193,7 +193,7 @@ TVM_FFI_INLINE static Error MoveFromSafeCallRaised() {
TVMFFIObjectHandle handle;
TVMFFIErrorMoveFromRaised(&handle);
// handle is owned by caller
- return Error(
+ return details::ObjectUnsafe::ObjectRefFromObjectPtr<Error>(
details::ObjectUnsafe::ObjectPtrFromOwned<Object>(static_cast<TVMFFIObject*>(handle)));
}
diff --git a/ffi/include/tvm/ffi/object.h b/ffi/include/tvm/ffi/object.h
index c1ab9d16d9..478bb27a8f 100644
--- a/ffi/include/tvm/ffi/object.h
+++ b/ffi/include/tvm/ffi/object.h
@@ -44,6 +44,24 @@ using TypeIndex = TVMFFITypeIndex;
*/
using TypeInfo = TVMFFITypeInfo;
+/*!
+ * \brief Helper tag to explicitly request unsafe initialization.
+ *
+ * Constructing an ObjectRefType with UnsafeInit{} will set the data_ member
to nullptr.
+ *
+ * When initializing Object fields, ObjectRef fields can be set to UnsafeInit.
+ * This enables the "construct with UnsafeInit then set all fields" pattern
+ * when the object does not have a default constructor.
+ *
+ * Used for initialization in controlled scenarios where such unsafe
+ * initialization is known to be safe.
+ *
+ * Each ObjectRefType should have a constructor that takes an UnsafeInit tag.
+ *
+ * \note As the name suggests, do not use it in normal code paths.
+ */
+struct UnsafeInit {};
+
/*!
* \brief Known type keys for pre-defined types.
*/
@@ -702,6 +720,8 @@ class ObjectRef {
ObjectRef& operator=(ObjectRef&& other) = default;
/*! \brief Constructor from existing object ptr */
explicit ObjectRef(ObjectPtr<Object> data) : data_(data) {}
+ /*! \brief Constructor from UnsafeInit */
+ explicit ObjectRef(UnsafeInit) : data_(nullptr) {}
/*!
* \brief Comparator
* \param other Another object ref.
@@ -774,7 +794,9 @@ class ObjectRef {
TVM_FFI_INLINE std::optional<ObjectRefType> as() const {
if (data_ != nullptr) {
if (data_->IsInstance<typename ObjectRefType::ContainerType>()) {
- return ObjectRefType(data_);
+ ObjectRefType ref(UnsafeInit{});
+ ref.data_ = data_;
+ return ref;
} else {
return std::nullopt;
}
@@ -782,6 +804,7 @@ class ObjectRef {
return std::nullopt;
}
}
+
/*!
* \brief Get the type index of the ObjectRef
* \return The type index of the ObjectRef
@@ -914,7 +937,8 @@ struct ObjectPtrEqual {
*/
#define TVM_FFI_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)
\
TypeName() = default;
\
- explicit TypeName(::tvm::ffi::ObjectPtr<::tvm::ffi::Object> n) :
ParentType(n) {} \
+ explicit TypeName(::tvm::ffi::ObjectPtr<ObjectName> n) : ParentType(n) {}
\
+ explicit TypeName(::tvm::ffi::UnsafeInit tag) : ParentType(tag) {}
\
TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName)
\
const ObjectName* operator->() const { return static_cast<const
ObjectName*>(data_.get()); } \
const ObjectName* get() const { return operator->(); }
\
@@ -928,7 +952,7 @@ struct ObjectPtrEqual {
* \param ObjectName The type name of the object.
*/
#define TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType,
ObjectName) \
- explicit TypeName(::tvm::ffi::ObjectPtr<::tvm::ffi::Object> n) :
ParentType(n) {} \
+ explicit TypeName(::tvm::ffi::UnsafeInit tag) : ParentType(tag) {}
\
TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName)
\
const ObjectName* operator->() const { return static_cast<const
ObjectName*>(data_.get()); } \
const ObjectName* get() const { return operator->(); }
\
@@ -943,11 +967,12 @@ struct ObjectPtrEqual {
* \note We recommend making objects immutable when possible.
* This macro is only reserved for objects that stores runtime states.
*/
-#define TVM_FFI_DEFINE_MUTABLE_OBJECT_REF_METHODS(TypeName, ParentType,
ObjectName) \
- TypeName() = default;
\
- TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName)
\
- explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) :
ParentType(n) {} \
- ObjectName* operator->() const { return
static_cast<ObjectName*>(data_.get()); } \
+#define TVM_FFI_DEFINE_MUTABLE_OBJECT_REF_METHODS(TypeName, ParentType,
ObjectName) \
+ TypeName() = default;
\
+ explicit TypeName(::tvm::ffi::UnsafeInit tag) : ParentType(tag) {}
\
+ TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName)
\
+ explicit TypeName(::tvm::ffi::ObjectPtr<ObjectName> n) : ParentType(n) {}
\
+ ObjectName* operator->() const { return
static_cast<ObjectName*>(data_.get()); } \
using ContainerType = ObjectName
/*!
@@ -958,7 +983,7 @@ struct ObjectPtrEqual {
* \param ObjectName The type name of the object.
*/
#define TVM_FFI_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName,
ParentType, ObjectName) \
- explicit TypeName(::tvm::ffi::ObjectPtr<::tvm::ffi::Object> n) :
ParentType(n) {} \
+ explicit TypeName(::tvm::ffi::UnsafeInit tag) : ParentType(tag) {}
\
TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName)
\
ObjectName* operator->() const { return
static_cast<ObjectName*>(data_.get()); } \
ObjectName* get() const { return operator->(); }
\
@@ -1021,6 +1046,20 @@ struct ObjectUnsafe {
reinterpret_cast<int64_t>(&(static_cast<Object*>(nullptr)->header_)));
}
+ template <typename T>
+ TVM_FFI_INLINE static T ObjectRefFromObjectPtr(const ObjectPtr<Object>& ptr)
{
+ T ref(UnsafeInit{});
+ ref.data_ = ptr;
+ return ref;
+ }
+
+ template <typename T>
+ TVM_FFI_INLINE static T ObjectRefFromObjectPtr(ObjectPtr<Object>&& ptr) {
+ T ref(UnsafeInit{});
+ ref.data_ = std::move(ptr);
+ return ref;
+ }
+
template <typename T>
TVM_FFI_INLINE static ObjectPtr<T> ObjectPtrFromObjectRef(const ObjectRef&
ref) {
if constexpr (std::is_same_v<T, Object>) {
@@ -1035,7 +1074,10 @@ struct ObjectUnsafe {
if constexpr (std::is_same_v<T, Object>) {
return std::move(ref.data_);
} else {
- return tvm::ffi::ObjectPtr<T>(std::move(ref.data_.data_));
+ ObjectPtr<T> result;
+ result.data_ = std::move(ref.data_.data_);
+ ref.data_.data_ = nullptr;
+ return result;
}
}
diff --git a/ffi/include/tvm/ffi/optional.h b/ffi/include/tvm/ffi/optional.h
index f93a0f0d55..f370a17850 100644
--- a/ffi/include/tvm/ffi/optional.h
+++ b/ffi/include/tvm/ffi/optional.h
@@ -262,7 +262,7 @@ class Optional<T,
std::enable_if_t<use_ptr_based_optional_v<T>>> : public Object
Optional() = default;
Optional(const Optional<T>& other) : ObjectRef(other.data_) {}
Optional(Optional<T>&& other) : ObjectRef(std::move(other.data_)) {}
- explicit Optional(ObjectPtr<Object> ptr) : ObjectRef(ptr) {}
+ explicit Optional(ffi::UnsafeInit tag) : ObjectRef(tag) {}
// nullopt hanlding
Optional(std::nullopt_t) {} // NOLINT(*)
@@ -300,19 +300,20 @@ class Optional<T,
std::enable_if_t<use_ptr_based_optional_v<T>>> : public Object
if (data_ == nullptr) {
TVM_FFI_THROW(RuntimeError) << "Back optional access";
}
- return T(data_);
+ return details::ObjectUnsafe::ObjectRefFromObjectPtr<T>(data_);
}
TVM_FFI_INLINE T value() && {
if (data_ == nullptr) {
TVM_FFI_THROW(RuntimeError) << "Back optional access";
}
- return T(std::move(data_));
+ return details::ObjectUnsafe::ObjectRefFromObjectPtr<T>(std::move(data_));
}
template <typename U = std::remove_cv_t<T>>
TVM_FFI_INLINE T value_or(U&& default_value) const {
- return data_ != nullptr ? T(data_) : T(std::forward<U>(default_value));
+ return data_ != nullptr ?
details::ObjectUnsafe::ObjectRefFromObjectPtr<T>(data_)
+ : T(std::forward<U>(default_value));
}
TVM_FFI_INLINE explicit operator bool() const { return data_ != nullptr; }
@@ -324,14 +325,18 @@ class Optional<T,
std::enable_if_t<use_ptr_based_optional_v<T>>> : public Object
* \return the const reference to the stored value.
* \note only use this function after checking has_value()
*/
- TVM_FFI_INLINE T operator*() const& noexcept { return T(data_); }
+ TVM_FFI_INLINE T operator*() const& noexcept {
+ return details::ObjectUnsafe::ObjectRefFromObjectPtr<T>(data_);
+ }
/*!
* \brief Direct access to the value.
* \return the const reference to the stored value.
* \note only use this function after checking has_value()
*/
- TVM_FFI_INLINE T operator*() && noexcept { return T(std::move(data_)); }
+ TVM_FFI_INLINE T operator*() && noexcept {
+ return details::ObjectUnsafe::ObjectRefFromObjectPtr<T>(std::move(data_));
+ }
TVM_FFI_INLINE bool operator==(std::nullptr_t) const noexcept { return
!has_value(); }
TVM_FFI_INLINE bool operator!=(std::nullptr_t) const noexcept { return
has_value(); }
diff --git a/ffi/include/tvm/ffi/reflection/access_path.h
b/ffi/include/tvm/ffi/reflection/access_path.h
index c614d4ca28..e7aed0a8fc 100644
--- a/ffi/include/tvm/ffi/reflection/access_path.h
+++ b/ffi/include/tvm/ffi/reflection/access_path.h
@@ -360,6 +360,10 @@ class AccessPath : public ObjectRef {
/// \cond Doxygen_Suppress
TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(AccessPath, ObjectRef,
AccessPathObj);
/// \endcond
+
+ private:
+ friend class AccessPathObj;
+ explicit AccessPath(ObjectPtr<AccessPathObj> ptr) : ObjectRef(ptr) {}
};
/*!
diff --git a/ffi/include/tvm/ffi/reflection/registry.h
b/ffi/include/tvm/ffi/reflection/registry.h
index ba723fa394..6a1a9b55d2 100644
--- a/ffi/include/tvm/ffi/reflection/registry.h
+++ b/ffi/include/tvm/ffi/reflection/registry.h
@@ -148,6 +148,14 @@ class ReflectionDefBase {
TVM_FFI_SAFE_CALL_END();
}
+ template <typename T>
+ static int ObjectCreatorUnsafeInit(TVMFFIObjectHandle* result) {
+ TVM_FFI_SAFE_CALL_BEGIN();
+ ObjectPtr<T> obj = make_object<T>(UnsafeInit{});
+ *result =
details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(obj));
+ TVM_FFI_SAFE_CALL_END();
+ }
+
template <typename T>
TVM_FFI_INLINE static void ApplyFieldInfoTrait(TVMFFIFieldInfo* info, const
T& value) {
if constexpr (std::is_base_of_v<FieldInfoTrait, std::decay_t<T>>) {
@@ -413,6 +421,8 @@ class ObjectDef : public ReflectionDefBase {
info.doc = TVMFFIByteArray{nullptr, 0};
if constexpr (std::is_default_constructible_v<Class>) {
info.creator = ObjectCreatorDefault<Class>;
+ } else if constexpr (std::is_constructible_v<Class, UnsafeInit>) {
+ info.creator = ObjectCreatorUnsafeInit<Class>;
}
// apply extra info traits
((ApplyExtraInfoTrait(&info, std::forward<ExtraArgs>(extra_args)), ...));
diff --git a/ffi/include/tvm/ffi/rvalue_ref.h b/ffi/include/tvm/ffi/rvalue_ref.h
index 7c89038cc2..ebbec582e6 100644
--- a/ffi/include/tvm/ffi/rvalue_ref.h
+++ b/ffi/include/tvm/ffi/rvalue_ref.h
@@ -71,15 +71,17 @@ namespace ffi {
template <typename TObjRef, typename =
std::enable_if_t<std::is_base_of_v<ObjectRef, TObjRef>>>
class RValueRef {
public:
+ /*! \brief the container type of the rvalue ref */
+ using ContainerType = typename TObjRef::ContainerType;
/*! \brief only allow move constructor from rvalue of T */
explicit RValueRef(TObjRef&& data)
- :
data_(details::ObjectUnsafe::ObjectPtrFromObjectRef<Object>(std::move(data))) {}
+ :
data_(details::ObjectUnsafe::ObjectPtrFromObjectRef<ContainerType>(std::move(data)))
{}
/*! \brief return the data as rvalue */
TObjRef operator*() && { return TObjRef(std::move(data_)); }
private:
- mutable ObjectPtr<Object> data_;
+ mutable ObjectPtr<ContainerType> data_;
template <typename, typename>
friend struct TypeTraits;
@@ -125,7 +127,8 @@ struct TypeTraits<RValueRef<TObjRef>> : public
TypeTraitsBase {
tmp_any.v_obj = reinterpret_cast<TVMFFIObject*>(rvalue_ref->get());
// fast path, storage type matches, direct move the rvalue ref
if (TypeTraits<TObjRef>::CheckAnyStrict(&tmp_any)) {
- return RValueRef<TObjRef>(TObjRef(std::move(*rvalue_ref)));
+ return RValueRef<TObjRef>(
+
details::ObjectUnsafe::ObjectRefFromObjectPtr<TObjRef>(std::move(*rvalue_ref)));
}
if (std::optional<TObjRef> opt =
TypeTraits<TObjRef>::TryCastFromAnyView(&tmp_any)) {
// object type does not match up, we need to try to convert the object
diff --git a/ffi/include/tvm/ffi/type_traits.h
b/ffi/include/tvm/ffi/type_traits.h
index 1812448ecc..0f1971945a 100644
--- a/ffi/include/tvm/ffi/type_traits.h
+++ b/ffi/include/tvm/ffi/type_traits.h
@@ -551,34 +551,37 @@ struct ObjectRefTypeTraitsBase : public TypeTraitsBase {
TVM_FFI_INLINE static TObjRef CopyFromAnyViewAfterCheck(const TVMFFIAny*
src) {
if constexpr (TObjRef::_type_is_nullable) {
if (src->type_index == TypeIndex::kTVMFFINone) {
- return TObjRef(ObjectPtr<Object>(nullptr));
+ return details::ObjectUnsafe::ObjectRefFromObjectPtr<TObjRef>(nullptr);
}
}
- return
TObjRef(details::ObjectUnsafe::ObjectPtrFromUnowned<Object>(src->v_obj));
+ return details::ObjectUnsafe::ObjectRefFromObjectPtr<TObjRef>(
+ details::ObjectUnsafe::ObjectPtrFromUnowned<Object>(src->v_obj));
}
TVM_FFI_INLINE static TObjRef MoveFromAnyAfterCheck(TVMFFIAny* src) {
if constexpr (TObjRef::_type_is_nullable) {
if (src->type_index == TypeIndex::kTVMFFINone) {
- return TObjRef(ObjectPtr<Object>(nullptr));
+ return details::ObjectUnsafe::ObjectRefFromObjectPtr<TObjRef>(nullptr);
}
}
// move out the object pointer
- ObjectPtr<Object> obj_ptr =
details::ObjectUnsafe::ObjectPtrFromOwned<Object>(src->v_obj);
+ ObjectPtr<ContainerType> obj_ptr =
+ details::ObjectUnsafe::ObjectPtrFromOwned<ContainerType>(src->v_obj);
// reset the src to nullptr
TypeTraits<std::nullptr_t>::MoveToAny(nullptr, src);
- return TObjRef(std::move(obj_ptr));
+ return
details::ObjectUnsafe::ObjectRefFromObjectPtr<TObjRef>(std::move(obj_ptr));
}
TVM_FFI_INLINE static std::optional<TObjRef> TryCastFromAnyView(const
TVMFFIAny* src) {
if constexpr (TObjRef::_type_is_nullable) {
if (src->type_index == TypeIndex::kTVMFFINone) {
- return TObjRef(ObjectPtr<Object>(nullptr));
+ return details::ObjectUnsafe::ObjectRefFromObjectPtr<TObjRef>(nullptr);
}
}
if (src->type_index >= TypeIndex::kTVMFFIStaticObjectBegin) {
if (details::IsObjectInstance<ContainerType>(src->type_index)) {
- return
TObjRef(details::ObjectUnsafe::ObjectPtrFromUnowned<Object>(src->v_obj));
+ return details::ObjectUnsafe::ObjectRefFromObjectPtr<TObjRef>(
+
details::ObjectUnsafe::ObjectPtrFromUnowned<ContainerType>(src->v_obj));
}
}
return std::nullopt;
diff --git a/ffi/src/ffi/tensor.cc b/ffi/src/ffi/tensor.cc
index 7b44e4586b..c166c296c8 100644
--- a/ffi/src/ffi/tensor.cc
+++ b/ffi/src/ffi/tensor.cc
@@ -40,7 +40,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_THROW(ValueError) << "Expect shape to take list of int
arguments";
}
}
- *ret = Shape(shape);
+ *ret = details::ObjectUnsafe::ObjectRefFromObjectPtr<Shape>(shape);
});
});
diff --git a/ffi/tests/cpp/test_object.cc b/ffi/tests/cpp/test_object.cc
index 1d7de990f0..ec5c54c4d7 100644
--- a/ffi/tests/cpp/test_object.cc
+++ b/ffi/tests/cpp/test_object.cc
@@ -97,6 +97,14 @@ TEST(ObjectRef, as) {
EXPECT_EQ(b.as<TFloatObj>()->value, 20);
}
+TEST(ObjectRef, UnsafeInit) {
+ ObjectRef a(UnsafeInit{});
+ EXPECT_TRUE(a.get() == nullptr);
+
+ TInt b(UnsafeInit{});
+ EXPECT_TRUE(b.get() == nullptr);
+}
+
TEST(Object, CAPIAccessor) {
ObjectRef a = TInt(10);
TVMFFIObjectHandle obj = details::ObjectUnsafe::RawObjectPtrFromObjectRef(a);
diff --git a/ffi/tests/cpp/testing_object.h b/ffi/tests/cpp/testing_object.h
index fe3ba1b013..1f6e678226 100644
--- a/ffi/tests/cpp/testing_object.h
+++ b/ffi/tests/cpp/testing_object.h
@@ -59,8 +59,8 @@ class TIntObj : public TNumberObj {
public:
int64_t value;
- TIntObj() = default;
TIntObj(int64_t value) : value(value) {}
+ explicit TIntObj(UnsafeInit) {}
int64_t GetValue() const { return value; }
@@ -165,9 +165,9 @@ class TVarObj : public Object {
public:
std::string name;
- // need default constructor for json serialization
- TVarObj() = default;
TVarObj(std::string name) : name(name) {}
+ // need unsafe init constructor for json serialization
+ explicit TVarObj(UnsafeInit) {}
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
@@ -193,8 +193,8 @@ class TFuncObj : public Object {
Array<ObjectRef> body;
Optional<String> comment;
- // need default constructor for json serialization
- TFuncObj() = default;
+ // need unsafe init constructor or default constructor for json serialization
+ explicit TFuncObj(UnsafeInit) {}
TFuncObj(Array<TVar> params, Array<ObjectRef> body, Optional<String> comment)
: params(params), body(body), comment(comment) {}
diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h
index 5557654916..5c02db36f7 100644
--- a/include/tvm/ir/attrs.h
+++ b/include/tvm/ir/attrs.h
@@ -54,7 +54,7 @@ namespace tvm {
template <typename TObjectRef>
inline TObjectRef NullValue() {
static_assert(TObjectRef::_type_is_nullable, "Can only get NullValue for
nullable types");
- return TObjectRef(ObjectPtr<Object>(nullptr));
+ return TObjectRef(ObjectPtr<typename TObjectRef::ContainerType>(nullptr));
}
template <>
@@ -165,6 +165,10 @@ class DictAttrsNode : public BaseAttrsNode {
*/
class DictAttrs : public Attrs {
public:
+ /*!
+ * \brief constructor with UnsafeInit
+ */
+ explicit DictAttrs(ffi::UnsafeInit tag) : Attrs(tag) {}
/*!
* \brief Consruct a Attrs backed by DictAttrsNode.
* \param dict The attributes.
diff --git a/include/tvm/ir/env_func.h b/include/tvm/ir/env_func.h
index e43575d486..e42cce5279 100644
--- a/include/tvm/ir/env_func.h
+++ b/include/tvm/ir/env_func.h
@@ -71,6 +71,10 @@ class EnvFunc : public ObjectRef {
public:
EnvFunc() {}
explicit EnvFunc(ObjectPtr<Object> n) : ObjectRef(n) {}
+ /*!
+ * \brief constructor with UnsafeInit
+ */
+ explicit EnvFunc(ffi::UnsafeInit tag) : ObjectRef(tag) {}
/*! \return The internal global function pointer */
const EnvFuncNode* operator->() const { return static_cast<const
EnvFuncNode*>(get()); }
/*!
@@ -117,6 +121,10 @@ class TypedEnvFunc<R(Args...)> : public ObjectRef {
using TSelf = TypedEnvFunc<R(Args...)>;
TypedEnvFunc() {}
explicit TypedEnvFunc(ObjectPtr<Object> n) : ObjectRef(n) {}
+ /*!
+ * \brief constructor with UnsafeInit
+ */
+ explicit TypedEnvFunc(ffi::UnsafeInit tag) : ObjectRef(tag) {}
/*!
* \brief Assign global function to a TypedEnvFunc
* \param other Another global function.
diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h
index 65954b83ac..d7e4e0f0d2 100644
--- a/include/tvm/ir/expr.h
+++ b/include/tvm/ir/expr.h
@@ -613,7 +613,11 @@ class Integer : public IntImm {
/*!
* \brief constructor from node.
*/
- explicit Integer(ObjectPtr<Object> node) : IntImm(node) {}
+ explicit Integer(ObjectPtr<IntImmNode> node) : IntImm(node) {}
+ /*!
+ * \brief constructor with UnsafeInit
+ */
+ explicit Integer(ffi::UnsafeInit tag) : IntImm(tag) {}
/*!
* \brief Construct integer from int value.
*/
diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h
index 5da00fb0b3..3deef6fed1 100644
--- a/include/tvm/ir/module.h
+++ b/include/tvm/ir/module.h
@@ -273,7 +273,11 @@ class IRModule : public ObjectRef {
* \brief constructor
* \param n The object pointer.
*/
- explicit IRModule(ObjectPtr<Object> n) : ObjectRef(n) {}
+ explicit IRModule(ObjectPtr<IRModuleNode> n) : ObjectRef(n) {}
+ /*!
+ * \brief constructor with UnsafeInit
+ */
+ explicit IRModule(ffi::UnsafeInit tag) : ObjectRef(tag) {}
/*! \return mutable pointers to the node. */
IRModuleNode* operator->() const {
auto* ptr = get_mutable();
diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h
index e501ace159..e283234cb0 100644
--- a/include/tvm/ir/transform.h
+++ b/include/tvm/ir/transform.h
@@ -156,7 +156,14 @@ class PassContextNode : public Object {
class PassContext : public ObjectRef {
public:
PassContext() {}
- explicit PassContext(ObjectPtr<Object> n) : ObjectRef(n) {}
+ /*!
+ * \brief constructor with UnsafeInit
+ */
+ explicit PassContext(ffi::UnsafeInit tag) : ObjectRef(tag) {}
+ /*!
+ * \brief constructor with ObjectPtr
+ */
+ explicit PassContext(ObjectPtr<PassContextNode> n) : ObjectRef(n) {}
/*!
* \brief const accessor.
* \return const access pointer.
@@ -512,7 +519,7 @@ class Sequential : public Pass {
TVM_DLL Sequential(ffi::Array<Pass> passes, ffi::String name = "sequential");
Sequential() = default;
- explicit Sequential(ObjectPtr<Object> n) : Pass(n) {}
+ explicit Sequential(ObjectPtr<SequentialNode> n) : Pass(n) {}
const SequentialNode* operator->() const;
using ContainerType = SequentialNode;
diff --git a/include/tvm/meta_schedule/builder.h b/include/tvm/meta_schedule/builder.h
index 6a6df29502..0a527ad425 100644
--- a/include/tvm/meta_schedule/builder.h
+++ b/include/tvm/meta_schedule/builder.h
@@ -136,6 +136,13 @@ class BuilderNode : public runtime::Object {
*/
class Builder : public runtime::ObjectRef {
public:
+ /*!
+ * \brief Constructor from ObjectPtr<BuilderNode>.
+ * \param data The object pointer.
+ */
+ explicit Builder(ObjectPtr<BuilderNode> data) : ObjectRef(data) {
+ TVM_FFI_ICHECK(data != nullptr);
+ }
/*!
* \brief Create a builder with customized build method on the python-side.
* \param f_build The packed function to the `Build` function..
diff --git a/include/tvm/meta_schedule/database.h b/include/tvm/meta_schedule/database.h
index fbb09d7852..0768607731 100644
--- a/include/tvm/meta_schedule/database.h
+++ b/include/tvm/meta_schedule/database.h
@@ -71,6 +71,7 @@ class WorkloadNode : public runtime::Object {
class Workload : public runtime::ObjectRef {
public:
using THashCode = WorkloadNode::THashCode;
+ explicit Workload(ObjectPtr<WorkloadNode> data) : ObjectRef(data) {}
/*!
* \brief Constructor of Workload.
* \param mod The workload's IRModule.
@@ -117,7 +118,7 @@ class TuningRecordNode : public runtime::Object {
/*! \brief The trace tuned. */
tir::Trace trace;
/*! \brief The workload. */
- Workload workload{nullptr};
+ Workload workload{ffi::UnsafeInit()};
/*! \brief The profiling result in seconds. */
ffi::Optional<ffi::Array<FloatImm>> run_secs;
/*! \brief The target for tuning. */
@@ -466,6 +467,13 @@ class PyDatabaseNode : public DatabaseNode {
*/
class Database : public runtime::ObjectRef {
public:
+ /*!
+ * \brief Constructor from ObjectPtr<DatabaseNode>.
+ * \param data The object pointer.
+ */
+ explicit Database(ObjectPtr<DatabaseNode> data) : ObjectRef(data) {
+ TVM_FFI_ICHECK(data != nullptr);
+ }
/*!
* \brief An in-memory database.
* \param mod_eq_name A string to specify the module equality testing and hashing method.
diff --git a/include/tvm/meta_schedule/runner.h b/include/tvm/meta_schedule/runner.h
index 2d42b5e590..f2753964ec 100644
--- a/include/tvm/meta_schedule/runner.h
+++ b/include/tvm/meta_schedule/runner.h
@@ -207,7 +207,11 @@ class RunnerNode : public runtime::Object {
class Runner : public runtime::ObjectRef {
public:
using FRun = RunnerNode::FRun;
-
+ /*!
+ * \brief Constructor from ObjectPtr<RunnerNode>.
+ * \param data The object pointer.
+ */
+ explicit Runner(ObjectPtr<RunnerNode> data) : ObjectRef(data) { TVM_FFI_ICHECK(data != nullptr); }
/*!
* \brief Create a runner with customized build method on the python-side.
* \param f_run The packed function to run the built artifacts and get runner futures.
diff --git a/include/tvm/meta_schedule/space_generator.h b/include/tvm/meta_schedule/space_generator.h
index f013934e23..a2bf7a3949 100644
--- a/include/tvm/meta_schedule/space_generator.h
+++ b/include/tvm/meta_schedule/space_generator.h
@@ -123,6 +123,13 @@ class SpaceGeneratorNode : public runtime::Object {
*/
class SpaceGenerator : public runtime::ObjectRef {
public:
+ /*!
+ * \brief Constructor from ObjectPtr<SpaceGeneratorNode>.
+ * \param data The object pointer.
+ */
+ explicit SpaceGenerator(ObjectPtr<SpaceGeneratorNode> data) : ObjectRef(data) {
+ TVM_FFI_ICHECK(data != nullptr);
+ }
/*!
* \brief The function type of `InitializeWithTuneContext` method.
* \param context The tuning context for initialization.
diff --git a/include/tvm/meta_schedule/task_scheduler.h b/include/tvm/meta_schedule/task_scheduler.h
index 0c88cb12c8..a6a53becad 100644
--- a/include/tvm/meta_schedule/task_scheduler.h
+++ b/include/tvm/meta_schedule/task_scheduler.h
@@ -40,7 +40,7 @@ namespace meta_schedule {
class TaskRecordNode : public runtime::Object {
public:
/*! \brief The tune context of the task. */
- TuneContext ctx{nullptr};
+ TuneContext ctx{ffi::UnsafeInit()};
/*! \brief The weight of the task */
double task_weight{1.0};
/*! \brief The FLOP count of the task */
@@ -261,6 +261,9 @@ class PyTaskSchedulerNode : public TaskSchedulerNode {
*/
class TaskScheduler : public runtime::ObjectRef {
public:
+ explicit TaskScheduler(ObjectPtr<TaskSchedulerNode> data) : runtime::ObjectRef(data) {
+ TVM_FFI_ICHECK(data != nullptr);
+ }
/*!
* \brief Create a task scheduler that fetches tasks in a round-robin fashion.
* \param logger The tuning task's logging function.
diff --git a/include/tvm/meta_schedule/tune_context.h b/include/tvm/meta_schedule/tune_context.h
index cd9b8f1b5a..50bdb2586f 100644
--- a/include/tvm/meta_schedule/tune_context.h
+++ b/include/tvm/meta_schedule/tune_context.h
@@ -98,6 +98,13 @@ class TuneContextNode : public runtime::Object {
class TuneContext : public runtime::ObjectRef {
public:
using TRandState = support::LinearCongruentialEngine::TRandState;
+ /*!
+ * \brief Constructor from ObjectPtr<TuneContextNode>.
+ * \param data The object pointer.
+ */
+ explicit TuneContext(ObjectPtr<TuneContextNode> data) : ObjectRef(data) {
+ TVM_FFI_ICHECK(data != nullptr);
+ }
/*!
* \brief Constructor.
* \param mod The workload to be tuned.
diff --git a/include/tvm/node/cast.h b/include/tvm/node/cast.h
index 4ed5f4178c..32d4be7216 100644
--- a/include/tvm/node/cast.h
+++ b/include/tvm/node/cast.h
@@ -45,18 +45,19 @@ namespace tvm {
template <typename SubRef, typename BaseRef,
typename = std::enable_if_t<std::is_base_of_v<ffi::ObjectRef, BaseRef>>>
inline SubRef Downcast(BaseRef ref) {
+ using ContainerType = typename SubRef::ContainerType;
if (ref.defined()) {
- if (!ref->template IsInstance<typename SubRef::ContainerType>()) {
+ if (!ref->template IsInstance<ContainerType>()) {
TVM_FFI_THROW(TypeError) << "Downcast from " << ref->GetTypeKey() << " to "
<< SubRef::ContainerType::_type_key << " failed.";
}
- return SubRef(ffi::details::ObjectUnsafe::ObjectPtrFromObjectRef<ffi::Object>(std::move(ref)));
+ return ffi::details::ObjectUnsafe::ObjectRefFromObjectPtr<SubRef>(
+ ffi::details::ObjectUnsafe::ObjectPtrFromObjectRef<ffi::Object>(std::move(ref)));
} else {
if constexpr (ffi::is_optional_type_v<SubRef> || SubRef::_type_is_nullable) {
- return SubRef(ffi::ObjectPtr<ffi::Object>(nullptr));
+ return ffi::details::ObjectUnsafe::ObjectRefFromObjectPtr<SubRef>(nullptr);
}
- TVM_FFI_THROW(TypeError) << "Downcast from undefined(nullptr) to `"
- << SubRef::ContainerType::_type_key
+ TVM_FFI_THROW(TypeError) << "Downcast from undefined(nullptr) to `" << ContainerType::_type_key
<< "` is not allowed. Use Downcast<ffi::Optional<T>> instead.";
TVM_FFI_UNREACHABLE();
}
diff --git a/include/tvm/relax/dataflow_pattern.h b/include/tvm/relax/dataflow_pattern.h
index 4a7fd73c6a..7c4ee4e43e 100644
--- a/include/tvm/relax/dataflow_pattern.h
+++ b/include/tvm/relax/dataflow_pattern.h
@@ -280,6 +280,7 @@ class PatternContextNode : public Object {
*/
class PatternContext : public ObjectRef {
public:
+ explicit PatternContext(ffi::UnsafeInit tag) : ObjectRef(tag) {}
TVM_DLL explicit PatternContext(ObjectPtr<Object> n) : ObjectRef(n) {}
TVM_DLL explicit PatternContext(bool incremental = false);
@@ -778,6 +779,10 @@ class WildcardPatternNode : public DFPatternNode {
class WildcardPattern : public DFPattern {
public:
WildcardPattern();
+ explicit WildcardPattern(ObjectPtr<WildcardPatternNode> data) : DFPattern(ffi::UnsafeInit{}) {
+ TVM_FFI_ICHECK(data != nullptr);
+ data_ = std::move(data);
+ }
// Declaring WildcardPattern declared as non-nullable avoids the
// default zero-parameter constructor for ObjectRef with `data_ =
diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h
index e0e2f4770f..80fe1e6710 100644
--- a/include/tvm/relax/expr.h
+++ b/include/tvm/relax/expr.h
@@ -607,7 +607,8 @@ class Binding : public ObjectRef {
Binding() = default;
public:
- explicit Binding(ObjectPtr<Object> n) : ObjectRef(n) {}
+ explicit Binding(ObjectPtr<BindingNode> n) : ObjectRef(n) {}
+ explicit Binding(ffi::UnsafeInit tag) : ObjectRef(tag) {}
TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(Binding);
const BindingNode* operator->() const { return static_cast<const BindingNode*>(data_.get()); }
const BindingNode* get() const { return operator->(); }
diff --git a/include/tvm/relax/struct_info.h b/include/tvm/relax/struct_info.h
index 8a97658330..059292806d 100644
--- a/include/tvm/relax/struct_info.h
+++ b/include/tvm/relax/struct_info.h
@@ -27,6 +27,8 @@
#include <tvm/relax/expr.h>
#include <tvm/relax/type.h>
+#include <utility>
+
namespace tvm {
namespace relax {
@@ -317,6 +319,10 @@ class FuncStructInfoNode : public StructInfoNode {
*/
class FuncStructInfo : public StructInfo {
public:
+ explicit FuncStructInfo(ObjectPtr<FuncStructInfoNode> data) : StructInfo(ffi::UnsafeInit{}) {
+ TVM_FFI_ICHECK(data != nullptr);
+ data_ = std::move(data);
+ }
/*!
* \brief Constructor from parameter struct info and return value struct info.
* \param params The struct info of function parameters.
diff --git a/include/tvm/runtime/disco/session.h b/include/tvm/runtime/disco/session.h
index 1506d2548f..671e4bbd67 100644
--- a/include/tvm/runtime/disco/session.h
+++ b/include/tvm/runtime/disco/session.h
@@ -170,6 +170,7 @@ class DRefObj : public Object {
*/
class DRef : public ObjectRef {
public:
+ explicit DRef(ObjectPtr<DRefObj> data) : ObjectRef(data) { TVM_FFI_ICHECK(data != nullptr); }
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(DRef, ObjectRef, DRefObj);
};
diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h
index e04a800400..cf5d93eae6 100644
--- a/include/tvm/runtime/object.h
+++ b/include/tvm/runtime/object.h
@@ -128,7 +128,7 @@ static_assert(static_cast<int>(TypeIndex::kCustomStaticIndex) >=
*/
#define TVM_DEFINE_OBJECT_REF_METHODS_WITHOUT_DEFAULT_CONSTRUCTOR(TypeName, ParentType, \
ObjectName) \
- explicit TypeName(::tvm::ffi::ObjectPtr<::tvm::ffi::Object> n) : ParentType(n) {} \
+ explicit TypeName(::tvm::ffi::ObjectPtr<ObjectName> n) : ParentType(n) {} \
TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \
const ObjectName* operator->() const { return static_cast<const ObjectName*>(data_.get()); } \
const ObjectName* get() const { return operator->(); } \
diff --git a/include/tvm/runtime/tensor.h b/include/tvm/runtime/tensor.h
index 71f8d27be0..97af218a18 100644
--- a/include/tvm/runtime/tensor.h
+++ b/include/tvm/runtime/tensor.h
@@ -58,7 +58,8 @@ class Tensor : public tvm::ffi::Tensor {
* \brief constructor.
* \param data ObjectPtr to the data container.
*/
- explicit Tensor(ObjectPtr<Object> data) : tvm::ffi::Tensor(data) {}
+ explicit Tensor(ObjectPtr<ffi::TensorObj> data) : tvm::ffi::Tensor(data) {}
+ explicit Tensor(ffi::UnsafeInit tag) : tvm::ffi::Tensor(tag) {}
Tensor(ffi::Tensor&& other) : tvm::ffi::Tensor(std::move(other)) {} // NOLINT(*)
Tensor(const ffi::Tensor& other) : tvm::ffi::Tensor(other) {} // NOLINT(*)
diff --git a/include/tvm/script/ir_builder/base.h b/include/tvm/script/ir_builder/base.h
index b2586e9387..75e6fd8061 100644
--- a/include/tvm/script/ir_builder/base.h
+++ b/include/tvm/script/ir_builder/base.h
@@ -107,6 +107,7 @@ class IRBuilderFrame : public runtime::ObjectRef {
protected:
/*! \brief Disallow direct construction of this object. */
IRBuilderFrame() = default;
+ explicit IRBuilderFrame(ObjectPtr<IRBuilderFrameNode> data) : ObjectRef(data) {}
public:
/*!
diff --git a/include/tvm/script/ir_builder/ir/frame.h b/include/tvm/script/ir_builder/ir/frame.h
index e9f98d4a8e..767986fdf7 100644
--- a/include/tvm/script/ir_builder/ir/frame.h
+++ b/include/tvm/script/ir_builder/ir/frame.h
@@ -75,6 +75,9 @@ class IRModuleFrameNode : public IRBuilderFrameNode {
*/
class IRModuleFrame : public IRBuilderFrame {
public:
+ explicit IRModuleFrame(ObjectPtr<IRModuleFrameNode> data) : IRBuilderFrame(data) {
+ TVM_FFI_ICHECK(data != nullptr);
+ }
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRModuleFrame, IRBuilderFrame,
IRModuleFrameNode);
};
diff --git a/include/tvm/script/ir_builder/relax/frame.h b/include/tvm/script/ir_builder/relax/frame.h
index 053f84285f..7ea8c439bf 100644
--- a/include/tvm/script/ir_builder/relax/frame.h
+++ b/include/tvm/script/ir_builder/relax/frame.h
@@ -26,6 +26,8 @@
#include <tvm/script/ir_builder/ir/frame.h>
#include <tvm/script/ir_builder/ir/ir.h>
+#include <utility>
+
namespace tvm {
namespace script {
namespace ir_builder {
@@ -45,6 +47,10 @@ class RelaxFrameNode : public IRBuilderFrameNode {
class RelaxFrame : public IRBuilderFrame {
public:
+ explicit RelaxFrame(ObjectPtr<RelaxFrameNode> data) : IRBuilderFrame(ffi::UnsafeInit{}) {
+ TVM_FFI_ICHECK(data != nullptr);
+ data_ = std::move(data);
+ }
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(RelaxFrame, IRBuilderFrame, RelaxFrameNode);
protected:
@@ -78,6 +84,9 @@ class SeqExprFrameNode : public RelaxFrameNode {
class SeqExprFrame : public RelaxFrame {
public:
+ explicit SeqExprFrame(ObjectPtr<SeqExprFrameNode> data) : RelaxFrame(data) {
+ TVM_FFI_ICHECK(data != nullptr);
+ }
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(SeqExprFrame, RelaxFrame, SeqExprFrameNode);
};
@@ -134,6 +143,9 @@ class FunctionFrameNode : public SeqExprFrameNode {
class FunctionFrame : public SeqExprFrame {
public:
+ explicit FunctionFrame(ObjectPtr<FunctionFrameNode> data) : SeqExprFrame(data) {
+ TVM_FFI_ICHECK(data != nullptr);
+ }
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(FunctionFrame, SeqExprFrame, FunctionFrameNode);
};
@@ -175,6 +187,9 @@ class BlockFrameNode : public RelaxFrameNode {
class BlockFrame : public RelaxFrame {
public:
+ explicit BlockFrame(ObjectPtr<BlockFrameNode> data) : RelaxFrame(data) {
+ TVM_FFI_ICHECK(data != nullptr);
+ }
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockFrame, RelaxFrame, BlockFrameNode);
};
@@ -229,6 +244,9 @@ class IfFrameNode : public RelaxFrameNode {
*/
class IfFrame : public RelaxFrame {
public:
+ explicit IfFrame(ObjectPtr<IfFrameNode> data) : RelaxFrame(data) {
+ TVM_FFI_ICHECK(data != nullptr);
+ }
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IfFrame, RelaxFrame, IfFrameNode);
};
@@ -267,6 +285,9 @@ class ThenFrameNode : public SeqExprFrameNode {
*/
class ThenFrame : public SeqExprFrame {
public:
+ explicit ThenFrame(ObjectPtr<ThenFrameNode> data) : SeqExprFrame(data) {
+ TVM_FFI_ICHECK(data != nullptr);
+ }
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ThenFrame, SeqExprFrame, ThenFrameNode);
};
@@ -305,6 +326,9 @@ class ElseFrameNode : public SeqExprFrameNode {
*/
class ElseFrame : public SeqExprFrame {
public:
+ explicit ElseFrame(ObjectPtr<ElseFrameNode> data) : SeqExprFrame(data) {
+ TVM_FFI_ICHECK(data != nullptr);
+ }
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ElseFrame, SeqExprFrame, ElseFrameNode);
};
diff --git a/include/tvm/script/ir_builder/tir/frame.h b/include/tvm/script/ir_builder/tir/frame.h
index 1c3e199590..fa42ea9911 100644
--- a/include/tvm/script/ir_builder/tir/frame.h
+++ b/include/tvm/script/ir_builder/tir/frame.h
@@ -23,6 +23,8 @@
#include <tvm/script/ir_builder/ir/frame.h>
#include <tvm/tir/stmt.h>
+#include <utility>
+
namespace tvm {
namespace script {
namespace ir_builder {
@@ -58,6 +60,7 @@ class TIRFrame : public IRBuilderFrame {
protected:
TIRFrame() = default;
+ explicit TIRFrame(ObjectPtr<TIRFrameNode> data) : IRBuilderFrame(data) {}
};
/*!
@@ -115,6 +118,10 @@ class PrimFuncFrameNode : public TIRFrameNode {
*/
class PrimFuncFrame : public TIRFrame {
public:
+ explicit PrimFuncFrame(ObjectPtr<PrimFuncFrameNode> data) : TIRFrame(ffi::UnsafeInit{}) {
+ TVM_FFI_ICHECK(data != nullptr);
+ data_ = std::move(data);
+ }
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PrimFuncFrame, TIRFrame, PrimFuncFrameNode);
};
@@ -186,6 +193,10 @@ class BlockFrameNode : public TIRFrameNode {
class BlockFrame : public TIRFrame {
public:
+ explicit BlockFrame(ObjectPtr<BlockFrameNode> data) : TIRFrame(ffi::UnsafeInit{}) {
+ TVM_FFI_ICHECK(data != nullptr);
+ data_ = std::move(data);
+ }
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockFrame, TIRFrame, BlockFrameNode);
};
@@ -224,6 +235,10 @@ class BlockInitFrameNode : public TIRFrameNode {
*/
class BlockInitFrame : public TIRFrame {
public:
+ explicit BlockInitFrame(ObjectPtr<BlockInitFrameNode> data) : TIRFrame(ffi::UnsafeInit{}) {
+ TVM_FFI_ICHECK(data != nullptr);
+ data_ = std::move(data);
+ }
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockInitFrame, TIRFrame, BlockInitFrameNode);
};
@@ -277,6 +292,10 @@ class ForFrameNode : public TIRFrameNode {
*/
class ForFrame : public TIRFrame {
public:
+ explicit ForFrame(ObjectPtr<ForFrameNode> data) : TIRFrame(ffi::UnsafeInit{}) {
+ TVM_FFI_ICHECK(data != nullptr);
+ data_ = std::move(data);
+ }
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ForFrame, TIRFrame, ForFrameNode);
};
@@ -318,6 +337,10 @@ class AssertFrameNode : public TIRFrameNode {
*/
class AssertFrame : public TIRFrame {
public:
+ explicit AssertFrame(ObjectPtr<AssertFrameNode> data) : TIRFrame(ffi::UnsafeInit{}) {
+ TVM_FFI_ICHECK(data != nullptr);
+ data_ = std::move(data);
+ }
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AssertFrame, TIRFrame, AssertFrameNode);
};
@@ -358,6 +381,10 @@ class LetFrameNode : public TIRFrameNode {
*/
class LetFrame : public TIRFrame {
public:
+ explicit LetFrame(ObjectPtr<LetFrameNode> data) : TIRFrame(ffi::UnsafeInit{}) {
+ TVM_FFI_ICHECK(data != nullptr);
+ data_ = std::move(data);
+ }
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(LetFrame, TIRFrame, LetFrameNode);
};
@@ -400,6 +427,10 @@ class LaunchThreadFrameNode : public TIRFrameNode {
*/
class LaunchThreadFrame : public TIRFrame {
public:
+ explicit LaunchThreadFrame(ObjectPtr<LaunchThreadFrameNode> data) : TIRFrame(ffi::UnsafeInit{}) {
+ TVM_FFI_ICHECK(data != nullptr);
+ data_ = std::move(data);
+ }
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(LaunchThreadFrame, TIRFrame,
LaunchThreadFrameNode);
};
@@ -444,6 +475,10 @@ class RealizeFrameNode : public TIRFrameNode {
*/
class RealizeFrame : public TIRFrame {
public:
+ explicit RealizeFrame(ObjectPtr<RealizeFrameNode> data) : TIRFrame(ffi::UnsafeInit{}) {
+ TVM_FFI_ICHECK(data != nullptr);
+ data_ = std::move(data);
+ }
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(RealizeFrame, TIRFrame, RealizeFrameNode);
};
@@ -496,6 +531,10 @@ class AllocateFrameNode : public TIRFrameNode {
*/
class AllocateFrame : public TIRFrame {
public:
+ explicit AllocateFrame(ObjectPtr<AllocateFrameNode> data) : TIRFrame(ffi::UnsafeInit{}) {
+ TVM_FFI_ICHECK(data != nullptr);
+ data_ = std::move(data);
+ }
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AllocateFrame, TIRFrame, AllocateFrameNode);
};
@@ -545,6 +584,11 @@ class AllocateConstFrameNode : public TIRFrameNode {
*/
class AllocateConstFrame : public TIRFrame {
public:
+ explicit AllocateConstFrame(ObjectPtr<AllocateConstFrameNode> data)
+ : TIRFrame(ffi::UnsafeInit{}) {
+ TVM_FFI_ICHECK(data != nullptr);
+ data_ = std::move(data);
+ }
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AllocateConstFrame, TIRFrame,
AllocateConstFrameNode);
};
@@ -588,6 +632,10 @@ class AttrFrameNode : public TIRFrameNode {
*/
class AttrFrame : public TIRFrame {
public:
+ explicit AttrFrame(ObjectPtr<AttrFrameNode> data) : TIRFrame(ffi::UnsafeInit{}) {
+ TVM_FFI_ICHECK(data != nullptr);
+ data_ = std::move(data);
+ }
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(AttrFrame, TIRFrame, AttrFrameNode);
};
@@ -624,6 +672,10 @@ class WhileFrameNode : public TIRFrameNode {
*/
class WhileFrame : public TIRFrame {
public:
+ explicit WhileFrame(ObjectPtr<WhileFrameNode> data) : TIRFrame(ffi::UnsafeInit{}) {
+ TVM_FFI_ICHECK(data != nullptr);
+ data_ = std::move(data);
+ }
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(WhileFrame, TIRFrame, WhileFrameNode);
};
@@ -667,6 +719,9 @@ class IfFrameNode : public TIRFrameNode {
*/
class IfFrame : public TIRFrame {
public:
+ explicit IfFrame(ObjectPtr<IfFrameNode> data) : TIRFrame(data) {
+ TVM_FFI_ICHECK(data != nullptr);
+ }
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IfFrame, TIRFrame, IfFrameNode);
};
@@ -705,6 +760,9 @@ class ThenFrameNode : public TIRFrameNode {
*/
class ThenFrame : public TIRFrame {
public:
+ explicit ThenFrame(ObjectPtr<ThenFrameNode> data) : TIRFrame(data) {
+ TVM_FFI_ICHECK(data != nullptr);
+ }
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ThenFrame, TIRFrame, ThenFrameNode);
};
@@ -743,6 +801,10 @@ class ElseFrameNode : public TIRFrameNode {
*/
class ElseFrame : public TIRFrame {
public:
+ explicit ElseFrame(ObjectPtr<ElseFrameNode> data) : TIRFrame(data) {
+ TVM_FFI_ICHECK(data != nullptr);
+ }
+
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ElseFrame, TIRFrame, ElseFrameNode);
};
@@ -769,6 +831,9 @@ class DeclBufferFrameNode : public TIRFrameNode {
class DeclBufferFrame : public TIRFrame {
public:
+ explicit DeclBufferFrame(ObjectPtr<DeclBufferFrameNode> data) : TIRFrame(data) {
+ TVM_FFI_ICHECK(data != nullptr);
+ }
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(DeclBufferFrame, TIRFrame, DeclBufferFrameNode);
};
diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h
index 976e3183a1..296df34524 100644
--- a/include/tvm/script/printer/doc.h
+++ b/include/tvm/script/printer/doc.h
@@ -88,6 +88,7 @@ class DocNode : public Object {
class Doc : public ObjectRef {
protected:
Doc() = default;
+ explicit Doc(ObjectPtr<DocNode> data) : ObjectRef(data) {}
public:
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Doc, ObjectRef, DocNode);
@@ -156,6 +157,8 @@ class ExprDoc : public Doc {
*/
ExprDoc operator[](ffi::Array<Doc> indices) const;
+ explicit ExprDoc(ObjectPtr<ExprDocNode> data) : Doc(data) { TVM_FFI_ICHECK(data != nullptr); }
+
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ExprDoc, Doc, ExprDocNode);
};
@@ -378,7 +381,7 @@ class IdDoc : public ExprDoc {
class AttrAccessDocNode : public ExprDocNode {
public:
/*! \brief The target expression to be accessed */
- ExprDoc value{nullptr};
+ ExprDoc value{ffi::UnsafeInit()};
/*! \brief The attribute to be accessed */
ffi::String name;
@@ -418,7 +421,7 @@ class AttrAccessDoc : public ExprDoc {
class IndexDocNode : public ExprDocNode {
public:
/*! \brief The container value to be accessed */
- ExprDoc value{nullptr};
+ ExprDoc value{ffi::UnsafeInit()};
/*!
* \brief The indices to access
*
@@ -464,7 +467,7 @@ class IndexDoc : public ExprDoc {
class CallDocNode : public ExprDocNode {
public:
/*! \brief The callee of this function call */
- ExprDoc callee{nullptr};
+ ExprDoc callee{ffi::UnsafeInit()};
/*! \brief The positional arguments */
ffi::Array<ExprDoc> args;
/*! \brief The keys of keyword arguments */
@@ -604,7 +607,7 @@ class LambdaDocNode : public ExprDocNode {
/*! \brief The arguments of this anonymous function */
ffi::Array<IdDoc> args;
/*! \brief The body of this anonymous function */
- ExprDoc body{nullptr};
+ ExprDoc body{ffi::UnsafeInit()};
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
@@ -664,7 +667,7 @@ class TupleDoc : public ExprDoc {
/*!
* \brief Create an empty TupleDoc
*/
- TupleDoc() : TupleDoc(ffi::make_object<TupleDocNode>()) {}
+ TupleDoc() : ExprDoc(ffi::make_object<TupleDocNode>()) {}
/*!
* \brief Constructor of TupleDoc
* \param elements Elements of tuple.
@@ -703,7 +706,7 @@ class ListDoc : public ExprDoc {
/*!
* \brief Create an empty ListDoc
*/
- ListDoc() : ListDoc(ffi::make_object<ListDocNode>()) {}
+ ListDoc() : ExprDoc(ffi::make_object<ListDocNode>()) {}
/*!
* \brief Constructor of ListDoc
* \param elements Elements of list.
@@ -751,7 +754,7 @@ class DictDoc : public ExprDoc {
/*!
* \brief Create an empty dictionary
*/
- DictDoc() : DictDoc(ffi::make_object<DictDocNode>()) {}
+ DictDoc() : ExprDoc(ffi::make_object<DictDocNode>()) {}
/*!
* \brief Constructor of DictDoc
* \param keys Keys of dictionary.
@@ -816,7 +819,7 @@ class SliceDoc : public Doc {
class AssignDocNode : public StmtDocNode {
public:
/*! \brief The left hand side of the assignment */
- ExprDoc lhs{nullptr};
+ ExprDoc lhs{ffi::UnsafeInit()};
/*!
* \brief The right hand side of the assignment.
*
@@ -864,7 +867,7 @@ class AssignDoc : public StmtDoc {
class IfDocNode : public StmtDocNode {
public:
/*! \brief The predicate of the if-then-else statement. */
- ExprDoc predicate{nullptr};
+ ExprDoc predicate{ffi::UnsafeInit()};
/*! \brief The then branch of the if-then-else statement. */
ffi::Array<StmtDoc> then_branch;
/*! \brief The else branch of the if-then-else statement. */
@@ -909,7 +912,7 @@ class IfDoc : public StmtDoc {
class WhileDocNode : public StmtDocNode {
public:
/*! \brief The predicate of the while statement. */
- ExprDoc predicate{nullptr};
+ ExprDoc predicate{ffi::UnsafeInit()};
/*! \brief The body of the while statement. */
ffi::Array<StmtDoc> body;
@@ -953,9 +956,9 @@ class WhileDoc : public StmtDoc {
class ForDocNode : public StmtDocNode {
public:
/*! \brief The left hand side of the assignment of iterating variable. */
- ExprDoc lhs{nullptr};
+ ExprDoc lhs{ffi::UnsafeInit()};
/*! \brief The right hand side of the assignment of iterating variable. */
- ExprDoc rhs{nullptr};
+ ExprDoc rhs{ffi::UnsafeInit()};
/*! \brief The body of the for statement. */
ffi::Array<StmtDoc> body;
@@ -1004,7 +1007,7 @@ class ScopeDocNode : public StmtDocNode {
/*! \brief The name of the scoped variable. */
ffi::Optional<ExprDoc> lhs{std::nullopt};
/*! \brief The value of the scoped variable. */
- ExprDoc rhs{nullptr};
+ ExprDoc rhs{ffi::UnsafeInit()};
/*! \brief The body of the scope doc. */
ffi::Array<StmtDoc> body;
@@ -1054,7 +1057,7 @@ class ScopeDoc : public StmtDoc {
class ExprStmtDocNode : public StmtDocNode {
public:
/*! \brief The expression represented by this doc. */
- ExprDoc expr{nullptr};
+ ExprDoc expr{ffi::UnsafeInit()};
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
@@ -1089,7 +1092,7 @@ class ExprStmtDoc : public StmtDoc {
class AssertDocNode : public StmtDocNode {
public:
/*! \brief The expression to test. */
- ExprDoc test{nullptr};
+ ExprDoc test{ffi::UnsafeInit()};
/*! \brief The optional error message when assertion failed. */
ffi::Optional<ExprDoc> msg{std::nullopt};
@@ -1129,7 +1132,7 @@ class AssertDoc : public StmtDoc {
class ReturnDocNode : public StmtDocNode {
public:
/*! \brief The value to return. */
- ExprDoc value{nullptr};
+ ExprDoc value{ffi::UnsafeInit()};
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
@@ -1164,7 +1167,7 @@ class ReturnDoc : public StmtDoc {
class FunctionDocNode : public StmtDocNode {
public:
/*! \brief The name of function. */
- IdDoc name{nullptr};
+ IdDoc name{ffi::UnsafeInit{}};
/*!
* \brief The arguments of function.
*
@@ -1223,7 +1226,7 @@ class FunctionDoc : public StmtDoc {
class ClassDocNode : public StmtDocNode {
public:
/*! \brief The name of class. */
- IdDoc name{nullptr};
+ IdDoc name{ffi::UnsafeInit{}};
/*! \brief Decorators of class. */
ffi::Array<ExprDoc> decorators;
/*! \brief The body of class. */
diff --git a/include/tvm/script/printer/ir_docsifier.h b/include/tvm/script/printer/ir_docsifier.h
index 6e6be57f9c..a2fc1097ac 100644
--- a/include/tvm/script/printer/ir_docsifier.h
+++ b/include/tvm/script/printer/ir_docsifier.h
@@ -132,7 +132,7 @@ class IRDocsifierNode : public Object {
ffi::Optional<ffi::String> name;
};
/*! \brief The configuration of the printer */
- PrinterConfig cfg{nullptr};
+ PrinterConfig cfg{ffi::UnsafeInit()};
/*!
* \brief The stack of frames.
* \sa FrameNode
diff --git a/include/tvm/target/target_kind.h b/include/tvm/target/target_kind.h
index ad167ce08b..f468f9cbac 100644
--- a/include/tvm/target/target_kind.h
+++ b/include/tvm/target/target_kind.h
@@ -127,6 +127,9 @@ class TargetKindNode : public Object {
class TargetKind : public ObjectRef {
public:
TargetKind() = default;
+ explicit TargetKind(ObjectPtr<TargetKindNode> data) : ObjectRef(data) {
+ TVM_FFI_ICHECK(data != nullptr);
+ }
/*! \brief Get the attribute map given the attribute name */
template <typename ValueType>
static inline TargetKindAttrMap<ValueType> GetAttrMap(const ffi::String& attr_name);
diff --git a/include/tvm/te/tensor.h b/include/tvm/te/tensor.h
index 8bcad6950f..68b2bbf715 100644
--- a/include/tvm/te/tensor.h
+++ b/include/tvm/te/tensor.h
@@ -50,6 +50,7 @@ class Operation : public ObjectRef {
/*! \brief default constructor */
Operation() {}
explicit Operation(ObjectPtr<Object> n) : ObjectRef(n) {}
+ explicit Operation(ffi::UnsafeInit tag) : ObjectRef(tag) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
diff --git a/include/tvm/tir/block_scope.h b/include/tvm/tir/block_scope.h
index 3fc2515d08..f79a456500 100644
--- a/include/tvm/tir/block_scope.h
+++ b/include/tvm/tir/block_scope.h
@@ -297,6 +297,13 @@ class BlockScopeNode : public Object {
*/
class BlockScope : public ObjectRef {
public:
+ /*!
+ * \brief Constructor from ObjectPtr<BlockScopeNode>.
+ * \param data The object pointer.
+ */
+ explicit BlockScope(ObjectPtr<BlockScopeNode> data) : ObjectRef(data) {
+ TVM_FFI_ICHECK(data != nullptr);
+ }
/*! \brief The constructor creating an empty block scope with on dependency information */
TVM_DLL BlockScope();
/*!
diff --git a/include/tvm/tir/schedule/state.h b/include/tvm/tir/schedule/state.h
index 8cb0053df7..22c4c7d7bd 100644
--- a/include/tvm/tir/schedule/state.h
+++ b/include/tvm/tir/schedule/state.h
@@ -43,7 +43,7 @@ namespace tir {
*/
struct BlockInfo {
/*! \brief Property of a block scope rooted at the block, storing dependencies in the scope */
- BlockScope scope{nullptr};
+ BlockScope scope{ffi::UnsafeInit()};
// The properties below are information about the current block realization under its parent scope
/*! \brief Property of a block, indicating the block realization binding is quasi-affine */
bool affine_binding{false};
diff --git a/include/tvm/tir/var.h b/include/tvm/tir/var.h
index 578b00fc08..51100c2292 100644
--- a/include/tvm/tir/var.h
+++ b/include/tvm/tir/var.h
@@ -77,7 +77,8 @@ class VarNode : public PrimExprNode {
/*! \brief a named variable in TIR */
class Var : public PrimExpr {
public:
- explicit Var(ObjectPtr<Object> n) : PrimExpr(n) {}
+ explicit Var(ffi::UnsafeInit tag) : PrimExpr(tag) {}
+ explicit Var(ObjectPtr<VarNode> n) : PrimExpr(n) {}
/*!
* \brief Constructor
* \param name_hint variable name
@@ -143,7 +144,8 @@ class SizeVarNode : public VarNode {
/*! \brief a named variable represents a tensor index size */
class SizeVar : public Var {
public:
- explicit SizeVar(ObjectPtr<Object> n) : Var(n) {}
+ explicit SizeVar(ObjectPtr<SizeVarNode> n) : Var(n) {}
+ explicit SizeVar(ffi::UnsafeInit tag) : Var(tag) {}
/*!
* \brief constructor
* \param name_hint variable name
diff --git a/src/contrib/msc/core/printer/msc_doc.h b/src/contrib/msc/core/printer/msc_doc.h
index ea1cee396b..6433f3de9a 100644
--- a/src/contrib/msc/core/printer/msc_doc.h
+++ b/src/contrib/msc/core/printer/msc_doc.h
@@ -45,7 +45,7 @@ class DeclareDocNode : public ExprDocNode {
/*! \brief The type of the variable */
ffi::Optional<ExprDoc> type;
/*! \brief The variable */
- ExprDoc variable{nullptr};
+ ExprDoc variable{ffi::UnsafeInit{}};
/*! \brief The init arguments for the variable. */
ffi::Array<ExprDoc> init_args;
/*! \brief Whether to use constructor(otherwise initializer) */
@@ -164,7 +164,7 @@ class PointerDoc : public ExprDoc {
class StructDocNode : public StmtDocNode {
public:
/*! \brief The name of class. */
- IdDoc name{nullptr};
+ IdDoc name{ffi::UnsafeInit{}};
/*! \brief Decorators of class. */
ffi::Array<ExprDoc> decorators;
/*! \brief The body of class. */
@@ -207,7 +207,7 @@ class StructDoc : public StmtDoc {
class ConstructorDocNode : public StmtDocNode {
public:
/*! \brief The name of function. */
- IdDoc name{nullptr};
+ IdDoc name{ffi::UnsafeInit{}};
/*!
* \brief The arguments of function.
*
@@ -300,7 +300,7 @@ class SwitchDoc : public StmtDoc {
class LambdaDocNode : public StmtDocNode {
public:
/*! \brief The name of lambda. */
- IdDoc name{nullptr};
+ IdDoc name{ffi::UnsafeInit{}};
/*!
* \brief The arguments of lambda.
*
diff --git a/src/ir/source_map.cc b/src/ir/source_map.cc
index 26fbe07cf6..47727d5297 100644
--- a/src/ir/source_map.cc
+++ b/src/ir/source_map.cc
@@ -46,7 +46,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
.def("__data_from_json__", SourceName::Get);
});
-ObjectPtr<Object> GetSourceNameNode(const ffi::String& name) {
+ObjectPtr<SourceNameNode> GetSourceNameNode(const ffi::String& name) {
// always return pointer as the reference can change as map re-allocate.
// or use another level of indirection by creating a unique_ptr
static std::unordered_map<ffi::String, ObjectPtr<SourceNameNode>> source_map;
@@ -62,7 +62,7 @@ ObjectPtr<Object> GetSourceNameNode(const ffi::String& name) {
}
}
-ObjectPtr<Object> GetSourceNameNodeByStr(const std::string& name) {
+ObjectPtr<SourceNameNode> GetSourceNameNodeByStr(const std::string& name) {
return GetSourceNameNode(name);
}
diff --git a/src/meta_schedule/database/database.cc
b/src/meta_schedule/database/database.cc
index b3c02607bd..8094449bfb 100644
--- a/src/meta_schedule/database/database.cc
+++ b/src/meta_schedule/database/database.cc
@@ -50,7 +50,7 @@ ObjectRef WorkloadNode::AsJSON() const {
}
Workload Workload::FromJSON(const ObjectRef& json_obj) {
- IRModule mod{nullptr};
+ IRModule mod{ffi::UnsafeInit()};
THashCode shash = 0;
try {
const ffi::ArrayObj* json_array = json_obj.as<ffi::ArrayObj>();
@@ -133,7 +133,7 @@ bool TuningRecordNode::IsValid() const {
}
TuningRecord TuningRecord::FromJSON(const ObjectRef& json_obj, const Workload&
workload) {
- tir::Trace trace{nullptr};
+ tir::Trace trace{ffi::UnsafeInit()};
ffi::Optional<ffi::Array<FloatImm>> run_secs;
ffi::Optional<Target> target;
ffi::Optional<ffi::Array<ArgInfo>> args_info;
diff --git a/src/meta_schedule/database/json_database.cc
b/src/meta_schedule/database/json_database.cc
index cef4b6437b..56e179585e 100644
--- a/src/meta_schedule/database/json_database.cc
+++ b/src/meta_schedule/database/json_database.cc
@@ -185,11 +185,11 @@ Database Database::JSONDatabase(ffi::String
path_workload, ffi::String path_tuni
{
std::vector<Any> json_objs = JSONFileReadLines(path_tuning_record,
num_threads, allow_missing);
std::vector<TuningRecord> records;
- records.resize(json_objs.size(), TuningRecord{nullptr});
+ records.resize(json_objs.size(), TuningRecord{ffi::UnsafeInit()});
support::parallel_for_dynamic(
0, json_objs.size(), num_threads, [&](int thread_id, int task_id) {
auto json_obj = json_objs[task_id].cast<ObjectRef>();
- Workload workload{nullptr};
+ Workload workload{ffi::UnsafeInit()};
try {
const ffi::ArrayObj* arr = json_obj.as<ffi::ArrayObj>();
ICHECK_EQ(arr->size(), 2);
diff --git a/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc
b/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc
index 88b6c2c649..a8ac2f05c4 100644
--- a/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc
+++ b/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc
@@ -133,7 +133,7 @@ class DisallowAsyncStridedMemCopyNode : public PostprocNode
{
const GlobalVar& g_var = kv.first;
const BaseFunc& base_func = kv.second;
if (const auto* prim_func = base_func.as<tir::PrimFuncNode>()) {
- IRModule lowered{nullptr};
+ IRModule lowered{ffi::UnsafeInit()};
try {
auto pass_list = ffi::Array<tvm::transform::Pass>();
pass_list.push_back(tir::transform::BindTarget(this->target));
diff --git a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc
b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc
index f0047d688a..5b250a6d2b 100644
--- a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc
+++ b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc
@@ -415,7 +415,7 @@ class RewriteParallelVectorizeUnrollNode : public
PostprocNode {
bool Apply(const Schedule& sch) final {
tir::ParsedAnnotation parsed_root;
- tir::BlockRV root_rv{nullptr};
+ tir::BlockRV root_rv{ffi::UnsafeInit()};
while (tir::FindAnnotatedRootBlock(sch, &parsed_root, &root_rv)) {
for (tir::BlockRV block_rv : sch->GetChildBlocks(root_rv)) {
ffi::Array<tir::LoopRV> loop_rvs = sch->GetLoops(block_rv);
diff --git a/src/meta_schedule/postproc/verify_gpu_code.cc
b/src/meta_schedule/postproc/verify_gpu_code.cc
index 5aaf756d43..7e660dc7cf 100644
--- a/src/meta_schedule/postproc/verify_gpu_code.cc
+++ b/src/meta_schedule/postproc/verify_gpu_code.cc
@@ -114,8 +114,8 @@ Integer Extract(const Target& target, const char* name) {
/*! \brief Verify the correctness of the generated GPU code. */
class VerifyGPUCodeNode : public PostprocNode {
public:
- Target target_{nullptr};
- ffi::Map<ffi::String, PrimExpr> target_constraints_{nullptr};
+ Target target_{ffi::UnsafeInit()};
+ ffi::Map<ffi::String, PrimExpr> target_constraints_{ffi::UnsafeInit()};
int thread_warp_size_ = -1;
void InitializeWithTuneContext(const TuneContext& context) final {
@@ -150,7 +150,7 @@ class VerifyGPUCodeNode : public PostprocNode {
if (!tir::ThreadExtentChecker::Check(prim_func->body,
thread_warp_size_)) {
return false;
}
- IRModule lowered{nullptr};
+ IRModule lowered{ffi::UnsafeInit()};
try {
auto pass_list = ffi::Array<tvm::transform::Pass>();
// Phase 1
diff --git a/src/meta_schedule/schedule/cpu/winograd.cc
b/src/meta_schedule/schedule/cpu/winograd.cc
index e8afb71d6b..6a2b82aa42 100644
--- a/src/meta_schedule/schedule/cpu/winograd.cc
+++ b/src/meta_schedule/schedule/cpu/winograd.cc
@@ -31,7 +31,7 @@ static ffi::Array<tir::LoopRV> ScheduleDataPack(tir::Schedule
sch, tir::BlockRV
using namespace tvm::tir;
ICHECK_EQ(tiled.size(), 2);
ICHECK_EQ(unrolled.size(), 4);
- ffi::Array<ExprRV> factors{nullptr};
+ ffi::Array<ExprRV> factors{ffi::UnsafeInit()};
ffi::Array<LoopRV> loops = sch->GetLoops(block);
ICHECK_EQ(loops.size(), 6);
diff --git a/src/meta_schedule/schedule/cuda/thread_bind.cc
b/src/meta_schedule/schedule/cuda/thread_bind.cc
index b71ea9164e..2a042553d6 100644
--- a/src/meta_schedule/schedule/cuda/thread_bind.cc
+++ b/src/meta_schedule/schedule/cuda/thread_bind.cc
@@ -141,11 +141,11 @@ void BindBlockThreadIdx(tir::Schedule sch, tir::BlockRV
block_rv, //
ICHECK(false) << "Unsupported case, where blockIdx is bound but threadIdx
is not";
throw;
}
- LoopRV loop_rv{nullptr};
+ LoopRV loop_rv{ffi::UnsafeInit()};
{
ffi::Array<LoopRV> loop_rvs = sch->GetLoops(block_rv);
if (i_spatial_loop == -1) {
- LoopRV spatial_loop_rv{nullptr};
+ LoopRV spatial_loop_rv{ffi::UnsafeInit()};
if (loop_rvs.empty()) {
spatial_loop_rv = sch->AddUnitLoop(block_rv);
} else {
diff --git a/src/meta_schedule/schedule/cuda/winograd.cc
b/src/meta_schedule/schedule/cuda/winograd.cc
index 759ab9fc72..2b9f4f78df 100644
--- a/src/meta_schedule/schedule/cuda/winograd.cc
+++ b/src/meta_schedule/schedule/cuda/winograd.cc
@@ -35,7 +35,7 @@ static ffi::Array<tir::LoopRV> ScheduleDataPack(tir::Schedule
sch, tir::BlockRV
using namespace tvm::tir;
ICHECK_EQ(tiled.size(), 2);
ICHECK_EQ(unrolled.size(), 4);
- ffi::Array<ExprRV> factors{nullptr};
+ ffi::Array<ExprRV> factors{ffi::UnsafeInit()};
ffi::Array<LoopRV> loops = sch->GetLoops(block);
ICHECK_EQ(loops.size(), 6);
@@ -109,7 +109,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
int64_t max_threads_per_block = 1024;
BlockRV input_tile = GetWinogradProducerAndInlineConst(sch,
data_pack);
BlockRV data_pad = GetWinogradProducerAndInlineConst(sch,
input_tile);
- LoopRV outer{nullptr};
+ LoopRV outer{ffi::UnsafeInit()};
{
ffi::Array<LoopRV> loops = sch->GetLoops(data_pack);
ICHECK_EQ(loops.size(), 6);
@@ -139,7 +139,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
// loops on top of the inverse block: [CO, P, tile_size,
tile_size, alpha, alpha]
int64_t tile_size =
Downcast<IntImm>(sch->Get(inverse)->writes[0]->buffer->shape[2])->value;
- LoopRV outer{nullptr};
+ LoopRV outer{ffi::UnsafeInit()};
{
BlockRV output = sch->GetConsumers(inverse)[0];
ffi::Array<LoopRV> nchw = sch->GetLoops(output);
diff --git a/src/meta_schedule/schedule_rule/cross_thread_reduction.cc
b/src/meta_schedule/schedule_rule/cross_thread_reduction.cc
index 219e05254e..d399517791 100644
--- a/src/meta_schedule/schedule_rule/cross_thread_reduction.cc
+++ b/src/meta_schedule/schedule_rule/cross_thread_reduction.cc
@@ -171,7 +171,7 @@ class CrossThreadReductionNode : public ScheduleRuleNode {
* \return The extent of "threadIdx.x" in the input schedule
*/
tir::ExprRV GetThreadIdxExtentFromTrace(const tir::Trace& trace) {
- tir::ExprRV extent{nullptr};
+ tir::ExprRV extent{ffi::UnsafeInit()};
for (const tir::Instruction& inst : trace->insts) {
if (inst->kind->name == "Bind" && Downcast<ffi::String>(inst->attrs[0])
== "threadIdx.x") {
if (GetLoopRVExtentSource(trace,
Downcast<tir::LoopRV>(inst->inputs[0]), &extent)) {
@@ -198,8 +198,8 @@ class CrossThreadReductionNode : public ScheduleRuleNode {
// Step 0. Due to technical reason of some primitives (e.g., compute-at),
if the block is doing
// a tuple reduction, fusion is temporarily not supported.
if (sch->Get(block_rv)->writes.size() != 1) {
- return std::make_tuple(false, tir::LoopRV{nullptr},
tir::BlockRV{nullptr},
- tir::LoopRV{nullptr});
+ return std::make_tuple(false, tir::LoopRV{ffi::UnsafeInit()},
tir::BlockRV{ffi::UnsafeInit()},
+ tir::LoopRV{ffi::UnsafeInit()});
}
// Step 1. Get all the consumers of the input block.
@@ -208,8 +208,8 @@ class CrossThreadReductionNode : public ScheduleRuleNode {
// Step 2. If the block has no consumer or the first consumer needs
multi-level tiling, it is
// not fusible.
if (consumers.empty() || tir::NeedsMultiLevelTiling(sch->state(),
sch->GetSRef(consumers[0]))) {
- return std::make_tuple(false, tir::LoopRV{nullptr},
tir::BlockRV{nullptr},
- tir::LoopRV{nullptr});
+ return std::make_tuple(false, tir::LoopRV{ffi::UnsafeInit()},
tir::BlockRV{ffi::UnsafeInit()},
+ tir::LoopRV{ffi::UnsafeInit()});
}
// Step 3. Calculate the lowest common ancestor of all the consumers.
@@ -221,8 +221,8 @@ class CrossThreadReductionNode : public ScheduleRuleNode {
const tir::StmtSRef& lca_sref =
tir::GetSRefLowestCommonAncestor(tir::BlockRVs2StmtSRefs(sch,
consumers));
if (consumers.size() > 1 && lca_sref->StmtAs<tir::BlockNode>() != nullptr)
{
- return std::make_tuple(false, tir::LoopRV{nullptr},
tir::BlockRV{nullptr},
- tir::LoopRV{nullptr});
+ return std::make_tuple(false, tir::LoopRV{ffi::UnsafeInit()},
tir::BlockRV{ffi::UnsafeInit()},
+ tir::LoopRV{ffi::UnsafeInit()});
}
// Step 4. Get the outer loops of the target block, and get the compute-at
position index.
@@ -231,8 +231,8 @@ class CrossThreadReductionNode : public ScheduleRuleNode {
// Step 5. A negative position index means not fusible, and vice-versa.
if (pos < 0) {
- return std::make_tuple(false, tir::LoopRV{nullptr},
tir::BlockRV{nullptr},
- tir::LoopRV{nullptr});
+ return std::make_tuple(false, tir::LoopRV{ffi::UnsafeInit()},
tir::BlockRV{ffi::UnsafeInit()},
+ tir::LoopRV{ffi::UnsafeInit()});
} else {
return std::make_tuple(true, tgt_block_loops[pos], consumers[0],
tgt_block_loops.back());
}
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
index 0bbccbdffe..741f0b6db4 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
@@ -77,7 +77,7 @@ class TensorCoreStateNode : public StateNode {
/*! \brief The tensor core intrinsic group. */
TensorCoreIntrinGroup intrin_group;
/*! \brief The auto tensorization maping info. */
- tir::AutoTensorizeMappingInfo mapping_info{nullptr};
+ tir::AutoTensorizeMappingInfo mapping_info{ffi::UnsafeInit()};
/*! \brief The Tensor Core reindex block A for Tensor Core computation */
tir::BlockRV tensor_core_reindex_A;
/*! \brief The Tensor Core reindex block B for Tensor Core computation */
diff --git a/src/meta_schedule/search_strategy/evolutionary_search.cc
b/src/meta_schedule/search_strategy/evolutionary_search.cc
index 306a3634d9..456fbbf129 100644
--- a/src/meta_schedule/search_strategy/evolutionary_search.cc
+++ b/src/meta_schedule/search_strategy/evolutionary_search.cc
@@ -112,7 +112,7 @@ class SizedHeap {
};
struct PerThreadData {
- IRModule mod{nullptr};
+ IRModule mod{ffi::UnsafeInit()};
TRandState rand_state{-1};
std::function<int32_t()> trace_sampler = nullptr;
std::function<ffi::Optional<Mutator>()> mutator_sampler = nullptr;
@@ -270,11 +270,11 @@ class EvolutionarySearchNode : public SearchStrategyNode {
* */
IRModuleSet measured_workloads_;
/*! \brief A Database for selecting useful candidates. */
- Database database_{nullptr};
+ Database database_{ffi::UnsafeInit()};
/*! \brief A cost model helping to explore the search space */
- CostModel cost_model_{nullptr};
+ CostModel cost_model_{ffi::UnsafeInit()};
/*! \brief The token registered for the given workload in database. */
- Workload token_{nullptr};
+ Workload token_{ffi::UnsafeInit()};
explicit State(EvolutionarySearchNode* self, int max_trials, int
num_trials_per_iter,
ffi::Array<Schedule> design_space_schedules, Database
database,
diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h
index 732a3a083d..ee94b1d2ab 100644
--- a/src/meta_schedule/utils.h
+++ b/src/meta_schedule/utils.h
@@ -360,7 +360,7 @@ struct ThreadedTraceApply {
/*! \brief A helper data structure that stores the fail count for each
postprocessor. */
struct Item {
/*! \brief The postprocessor. */
- Postproc postproc{nullptr};
+ Postproc postproc{ffi::UnsafeInit()};
/*! \brief The thread-safe postprocessor failure counter. */
std::atomic<int> fail_counter{0};
};
diff --git a/src/relax/ir/py_expr_functor.cc b/src/relax/ir/py_expr_functor.cc
index 11867dee6d..a97c5f784d 100644
--- a/src/relax/ir/py_expr_functor.cc
+++ b/src/relax/ir/py_expr_functor.cc
@@ -177,6 +177,9 @@ class PyExprVisitorNode : public Object, public ExprVisitor
{
*/
class PyExprVisitor : public ObjectRef {
public:
+ explicit PyExprVisitor(ObjectPtr<PyExprVisitorNode> data) : ObjectRef(data) {
+ TVM_FFI_ICHECK(data != nullptr);
+ }
/*!
* \brief Create a PyExprVisitor with customized methods on the python-side.
* \param f_visit_expr The packed function of `VisitExpr(const Expr& expr)`.
@@ -461,6 +464,9 @@ class PyExprMutatorNode : public Object, public ExprMutator
{
*/
class PyExprMutator : public ObjectRef {
public:
+ explicit PyExprMutator(ObjectPtr<PyExprMutatorNode> data) : ObjectRef(data) {
+ TVM_FFI_ICHECK(data != nullptr);
+ }
/*!
* \brief Create a PyExprMutator with customized methods on the python-side.
* \param f_visit_expr The packed function of `VisitExpr(const Expr& expr)`.
diff --git a/src/relax/transform/few_shot_tuning.cc
b/src/relax/transform/few_shot_tuning.cc
index 091247272a..7deffaa9f5 100644
--- a/src/relax/transform/few_shot_tuning.cc
+++ b/src/relax/transform/few_shot_tuning.cc
@@ -34,7 +34,7 @@ tir::PrimFunc FewShotTunePrimFunc(const tir::PrimFunc&
prim_func, const Target&
meta_schedule::Builder builder =
f_get_local_builder().cast<meta_schedule::Builder>();
ICHECK(builder.defined()) << "ValueError: The local builder is not defined!";
// fetch a local runner
- meta_schedule::Runner runner{nullptr};
+ meta_schedule::Runner runner{ffi::UnsafeInit()};
if (benchmark) {
static const auto f_get_local_runner =
tvm::ffi::Function::GetGlobalRequired("meta_schedule.runner.get_local_runner");
diff --git a/src/relax/transform/meta_schedule.cc
b/src/relax/transform/meta_schedule.cc
index 2d24f0785a..295937084d 100644
--- a/src/relax/transform/meta_schedule.cc
+++ b/src/relax/transform/meta_schedule.cc
@@ -81,7 +81,7 @@ Pass MetaScheduleApplyDatabase(ffi::Optional<ffi::String>
work_dir, bool enable_
ICHECK(normalize_mod_func_.has_value()) << "Normalization function is not
found.";
auto pass_func = [=](IRModule mod, PassContext ctx) {
- Database database{nullptr};
+ Database database{ffi::UnsafeInit()};
if (Database::Current().defined()) {
database = Database::Current().value();
} else {
diff --git a/src/runtime/rpc/rpc_session.h b/src/runtime/rpc/rpc_session.h
index 265c58f4af..4c456b861e 100644
--- a/src/runtime/rpc/rpc_session.h
+++ b/src/runtime/rpc/rpc_session.h
@@ -333,6 +333,9 @@ class RPCObjectRefObj : public Object {
*/
class RPCObjectRef : public ObjectRef {
public:
+ explicit RPCObjectRef(ObjectPtr<RPCObjectRefObj> data) : ObjectRef(data) {
+ TVM_FFI_ICHECK(data != nullptr);
+ }
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(RPCObjectRef, ObjectRef,
RPCObjectRefObj);
};
diff --git a/src/script/printer/relax/call.cc b/src/script/printer/relax/call.cc
index 9b0d2b966a..666b3839ea 100644
--- a/src/script/printer/relax/call.cc
+++ b/src/script/printer/relax/call.cc
@@ -264,7 +264,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
if (ffi::Optional<ExprDoc> doc = PrintRelaxPrint(n, n_p, d)) {
return doc.value();
}
- ExprDoc prefix{nullptr};
+ ExprDoc prefix{ffi::UnsafeInit()};
ffi::Array<ExprDoc> args;
ffi::Array<ffi::String> kwargs_keys;
ffi::Array<ExprDoc> kwargs_values;
diff --git a/src/script/printer/tir/block.cc b/src/script/printer/tir/block.cc
index 587520d72f..1a33d760a9 100644
--- a/src/script/printer/tir/block.cc
+++ b/src/script/printer/tir/block.cc
@@ -83,7 +83,7 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, AccessPath
block_p, //
LOG(FATAL) << "ValueError: Unknown IterVarType in block signature: "
<< tir::IterVarType2String(iter_var->iter_type);
}
- ExprDoc dom{nullptr};
+ ExprDoc dom{ffi::UnsafeInit()};
if (tir::is_zero(iter_var->dom->min)) {
ExprDoc extent = d->AsDoc<ExprDoc>(iter_var->dom->extent, //
iter_var_p->Attr("dom")->Attr("extent"));
diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tir/expr.cc
index ddcf1b64f1..da525aa35f 100644
--- a/src/script/printer/tir/expr.cc
+++ b/src/script/printer/tir/expr.cc
@@ -27,7 +27,7 @@ namespace printer {
ExprDoc PrintVarCreation(const tir::Var& var, const AccessPath& var_p, const
IRDocsifier& d) {
Type type = var->type_annotation;
AccessPath type_p = var_p->Attr("type_annotation");
- ExprDoc rhs{nullptr};
+ ExprDoc rhs{ffi::UnsafeInit()};
ffi::Array<ffi::String> kwargs_keys;
ffi::Array<ExprDoc> kwargs_values;
@@ -169,7 +169,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::CommReducer>( //
"", [](tir::CommReducer r, AccessPath p, IRDocsifier d) -> Doc {
ICHECK_EQ(r->lhs.size(), r->rhs.size());
- LambdaDoc lambda{nullptr};
+ ffi::Optional<LambdaDoc> lambda;
{
With<TIRFrame> f(d, r);
int n_vars = r->lhs.size();
@@ -194,7 +194,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
}
}
ExprDoc id = d->AsDoc<ExprDoc>(r->identity_element,
p->Attr("identity_element"));
- return TIR(d, "comm_reducer")->Call({lambda, id});
+ return TIR(d, "comm_reducer")->Call({lambda.value(), id});
});
LambdaDoc PrintIndexMap(const ObjectRef& map, const ffi::Array<tir::Var>& vs,
@@ -244,7 +244,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
static const OpAttrMap<tir::TScriptDtypePrintLocation> dtype_locations =
Op::GetAttrMap<tir::TScriptDtypePrintLocation>("TScriptDtypePrintLocation");
tir::ScriptDtypePrintLocation dtype_print_location =
tir::ScriptDtypePrintLocation::kNone;
- ExprDoc prefix{nullptr};
+ ffi::Optional<ExprDoc> prefix;
if (auto optional_op = call->op.as<Op>()) {
auto op = optional_op.value();
ffi::String name = op_names.get(op, op->name);
@@ -279,7 +279,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
if (dtype_print_location == tir::ScriptDtypePrintLocation::kLast) {
args.push_back(LiteralDoc::DataType(call->dtype,
call_p->Attr("dtype")));
}
- return prefix->Call(args);
+ return prefix.value()->Call(args);
}
} else if (call->op.as<GlobalVarNode>()) {
prefix = d->AsDoc<ExprDoc>(call->op, call_p->Attr("op"));
@@ -299,7 +299,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
if (dtype_print_location == tir::ScriptDtypePrintLocation::kLast) {
args.push_back(LiteralDoc::DataType(call->dtype,
call_p->Attr("dtype")));
}
- return prefix->Call(args);
+ return prefix.value()->Call(args);
});
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
diff --git a/src/script/printer/tir/for_loop.cc
b/src/script/printer/tir/for_loop.cc
index 10bb6f756d..742d23f69c 100644
--- a/src/script/printer/tir/for_loop.cc
+++ b/src/script/printer/tir/for_loop.cc
@@ -78,7 +78,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
if (!loop->annotations.empty()) {
annotations = d->AsDoc<ExprDoc>(loop->annotations,
loop_p->Attr("annotations"));
}
- ExprDoc prefix{nullptr};
+ ExprDoc prefix{ffi::UnsafeInit()};
if (loop->kind == tir::ForKind::kSerial) {
if (loop->annotations.empty()) {
prefix = IdDoc("range");
diff --git a/src/script/printer/tir/ir.cc b/src/script/printer/tir/ir.cc
index 0cd38d4c6a..797c726c7c 100644
--- a/src/script/printer/tir/ir.cc
+++ b/src/script/printer/tir/ir.cc
@@ -66,7 +66,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<PointerType>("", [](PointerType ty, AccessPath ty_p,
IRDocsifier d) -> Doc {
- ExprDoc element_type{nullptr};
+ ExprDoc element_type{ffi::UnsafeInit()};
if (const auto* prim_type = ty->element_type.as<PrimTypeNode>()) {
element_type = LiteralDoc::DataType(prim_type->dtype, //
ty_p->Attr("element_type")->Attr("dtype"));
diff --git a/src/script/printer/tir/stmt.cc b/src/script/printer/tir/stmt.cc
index 228fbbc785..1b0774be36 100644
--- a/src/script/printer/tir/stmt.cc
+++ b/src/script/printer/tir/stmt.cc
@@ -284,7 +284,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
ffi::Array<ExprDoc> args;
ffi::Array<ffi::String> kwargs_keys;
ffi::Array<ExprDoc> kwargs_values;
- ExprDoc data_doc{nullptr};
+ ExprDoc data_doc{ffi::UnsafeInit()};
if (stmt->dtype.is_int()) {
if (stmt->dtype.bits() == 8) {
data_doc = PrintTensor<int8_t>(stmt->data.value());
@@ -377,7 +377,7 @@ ExprDoc DocsifyLaunchThread(const tir::AttrStmt& attr_stmt,
const AccessPath& at
tir::IterVar iter_var = Downcast<tir::IterVar>(attr_stmt->node);
AccessPath iter_var_p = attr_stmt_p->Attr("node");
- ExprDoc var_doc{nullptr};
+ ExprDoc var_doc{ffi::UnsafeInit()};
if (d->IsVarDefined(iter_var->var)) {
var_doc = d->AsDoc<ExprDoc>(iter_var->var, iter_var_p->Attr("var"));
} else if (IsAncestorOfAllVarUse(attr_stmt, iter_var->var, d)) {
diff --git a/src/target/target.cc b/src/target/target.cc
index b2c3e8fe8c..e2013aba72 100644
--- a/src/target/target.cc
+++ b/src/target/target.cc
@@ -56,10 +56,10 @@ class TargetInternal {
const ffi::Map<ffi::String, ffi::Any>& attrs);
static Any ParseType(const std::string& str, const
TargetKindNode::ValueTypeInfo& info);
static Any ParseType(const Any& obj, const TargetKindNode::ValueTypeInfo&
info);
- static ObjectPtr<Object> FromString(const ffi::String&
tag_or_config_or_target_str);
- static ObjectPtr<Object> FromConfigString(const ffi::String& config_str);
- static ObjectPtr<Object> FromRawString(const ffi::String& target_str);
- static ObjectPtr<Object> FromConfig(ffi::Map<ffi::String, ffi::Any> config);
+ static ObjectPtr<TargetNode> FromString(const ffi::String&
tag_or_config_or_target_str);
+ static ObjectPtr<TargetNode> FromConfigString(const ffi::String& config_str);
+ static ObjectPtr<TargetNode> FromRawString(const ffi::String& target_str);
+ static ObjectPtr<TargetNode> FromConfig(ffi::Map<ffi::String, ffi::Any>
config);
static void ConstructorDispatcher(ffi::PackedArgs args, ffi::Any* rv);
static Target WithHost(const Target& target, const Target& target_host) {
ObjectPtr<TargetNode> n = ffi::make_object<TargetNode>(*target.get());
@@ -771,10 +771,10 @@ void
TargetInternal::ConstructorDispatcher(ffi::PackedArgs args, ffi::Any* rv) {
LOG(FATAL) << "ValueError: Invalid number of arguments. Expect 1 or 2, but
gets: " << args.size();
}
-ObjectPtr<Object> TargetInternal::FromString(const ffi::String&
tag_or_config_or_target_str) {
+ObjectPtr<TargetNode> TargetInternal::FromString(const ffi::String&
tag_or_config_or_target_str) {
if (ffi::Optional<Target> target =
TargetTag::Get(tag_or_config_or_target_str)) {
Target value = target.value();
- return ffi::details::ObjectUnsafe::ObjectPtrFromObjectRef<Object>(value);
+ return
ffi::details::ObjectUnsafe::ObjectPtrFromObjectRef<TargetNode>(value);
}
if (!tag_or_config_or_target_str.empty() &&
tag_or_config_or_target_str.data()[0] == '{') {
return TargetInternal::FromConfigString(tag_or_config_or_target_str);
@@ -782,7 +782,7 @@ ObjectPtr<Object> TargetInternal::FromString(const
ffi::String& tag_or_config_or
return TargetInternal::FromRawString(tag_or_config_or_target_str);
}
-ObjectPtr<Object> TargetInternal::FromConfigString(const ffi::String&
config_str) {
+ObjectPtr<TargetNode> TargetInternal::FromConfigString(const ffi::String&
config_str) {
const auto loader =
tvm::ffi::Function::GetGlobal("target._load_config_dict");
ICHECK(loader.has_value())
<< "AttributeError: \"target._load_config_dict\" is not registered.
Please check "
@@ -794,7 +794,7 @@ ObjectPtr<Object> TargetInternal::FromConfigString(const
ffi::String& config_str
return TargetInternal::FromConfig({config.value().begin(),
config.value().end()});
}
-ObjectPtr<Object> TargetInternal::FromRawString(const ffi::String& target_str)
{
+ObjectPtr<TargetNode> TargetInternal::FromRawString(const ffi::String&
target_str) {
ICHECK_GT(target_str.length(), 0) << "Cannot parse empty target string";
// Split the string by empty spaces
std::vector<std::string> options = SplitString(std::string(target_str), ' ');
@@ -826,7 +826,7 @@ ObjectPtr<Object> TargetInternal::FromRawString(const
ffi::String& target_str) {
return TargetInternal::FromConfig(config);
}
-ObjectPtr<Object> TargetInternal::FromConfig(ffi::Map<ffi::String, ffi::Any>
config) {
+ObjectPtr<TargetNode> TargetInternal::FromConfig(ffi::Map<ffi::String,
ffi::Any> config) {
const ffi::String kKind = "kind";
const ffi::String kTag = "tag";
const ffi::String kKeys = "keys";
diff --git a/src/tir/ir/py_functor.cc b/src/tir/ir/py_functor.cc
index 871452aeb9..26b55d3bb9 100644
--- a/src/tir/ir/py_functor.cc
+++ b/src/tir/ir/py_functor.cc
@@ -342,6 +342,9 @@ class PyStmtExprVisitorNode : public Object, public
StmtExprVisitor {
*/
class PyStmtExprVisitor : public ObjectRef {
public:
+ explicit PyStmtExprVisitor(ObjectPtr<PyStmtExprVisitorNode> data) :
ObjectRef(data) {
+ TVM_FFI_ICHECK(data != nullptr);
+ }
TVM_DLL static PyStmtExprVisitor MakePyStmtExprVisitor(ffi::Function
f_visit_stmt, //
ffi::Function
f_visit_expr, //
ffi::Function
f_visit_let_stmt, //
@@ -702,6 +705,9 @@ class PyStmtExprMutatorNode : public Object, public
StmtExprMutator {
/*! \brief Managed reference to PyStmtExprMutatorNode. */
class PyStmtExprMutator : public ObjectRef {
public:
+ explicit PyStmtExprMutator(ObjectPtr<PyStmtExprMutatorNode> data) :
ObjectRef(data) {
+ TVM_FFI_ICHECK(data != nullptr);
+ }
/*!
* \brief Create a PyStmtExprMutator with customized methods on the
python-side.
* \return The PyStmtExprMutator created.
diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h
index 8f3372b0ca..910c22aae0 100644
--- a/src/tir/schedule/analysis.h
+++ b/src/tir/schedule/analysis.h
@@ -761,6 +761,9 @@ class TensorizeInfoNode : public Object {
class TensorizeInfo : public ObjectRef {
public:
+ explicit TensorizeInfo(ObjectPtr<TensorizeInfoNode> data) : ObjectRef(data) {
+ TVM_FFI_ICHECK(data != nullptr);
+ }
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TensorizeInfo, ObjectRef,
TensorizeInfoNode);
};
@@ -810,6 +813,10 @@ class AutoTensorizeMappingInfoNode : public Object {
class AutoTensorizeMappingInfo : public ObjectRef {
public:
+ explicit AutoTensorizeMappingInfo(ObjectPtr<AutoTensorizeMappingInfoNode>
data)
+ : ObjectRef(data) {
+ TVM_FFI_ICHECK(data != nullptr);
+ }
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(AutoTensorizeMappingInfo,
ObjectRef,
AutoTensorizeMappingInfoNode);
};
diff --git a/src/tir/schedule/concrete_schedule.cc
b/src/tir/schedule/concrete_schedule.cc
index b333331778..89ece53771 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -604,7 +604,7 @@ void ConcreteScheduleNode::ReorderBlockIterVar(const
BlockRV& block_rv,
}
LoopRV ConcreteScheduleNode::AddUnitLoop(const BlockRV& block_rv) {
- LoopRV result{nullptr};
+ LoopRV result{ffi::UnsafeInit()};
TVM_TIR_SCHEDULE_BEGIN();
result = CreateRV<LoopRV>(tir::AddUnitLoop(state_, GetSRef(block_rv)));
TVM_TIR_SCHEDULE_END("add-unit-loop", this->error_render_level_);
@@ -613,7 +613,7 @@ LoopRV ConcreteScheduleNode::AddUnitLoop(const BlockRV&
block_rv) {
}
LoopRV ConcreteScheduleNode::AddUnitLoop(const LoopRV& loop_rv) {
- LoopRV result{nullptr};
+ LoopRV result{ffi::UnsafeInit()};
TVM_TIR_SCHEDULE_BEGIN();
result = CreateRV<LoopRV>(tir::AddUnitLoop(state_, GetSRef(loop_rv)));
TVM_TIR_SCHEDULE_END("add-unit-loop", this->error_render_level_);
diff --git a/src/tir/transforms/memhammer_tensorcore_rewrite.cc
b/src/tir/transforms/memhammer_tensorcore_rewrite.cc
index c1b303e073..e16c518771 100644
--- a/src/tir/transforms/memhammer_tensorcore_rewrite.cc
+++ b/src/tir/transforms/memhammer_tensorcore_rewrite.cc
@@ -334,7 +334,7 @@ class WmmaToGlobalRewriter : public StmtExprMutator {
Stmt WmmaToGlobal::Rewrite(const Stmt& stmt, const ConstraintSet& constraints,
OutputSet* output) const {
Stmt body{nullptr};
- ffi::Optional<For> compute_location{nullptr};
+ ffi::Optional<For> compute_location;
std::tie(body, compute_location) = TileWmmaBlock(stmt);
SeqStmt seq{nullptr};
Buffer cache_buffer;
@@ -543,7 +543,7 @@ class MmaToGlobalRewriter : public StmtExprMutator {
Stmt MmaToGlobal::Rewrite(const Stmt& stmt, const ConstraintSet& constraints,
OutputSet* output) const {
Stmt body{nullptr};
- ffi::Optional<For> compute_location{nullptr};
+ ffi::Optional<For> compute_location;
std::tie(body, compute_location) = TileMmaToGlobalBlock(stmt);
SeqStmt seq{nullptr};
Buffer cache_buffer;