This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b1e1566f82 [REFACTOR][IR] Cleanup attrs.h: drop NullValue,
AttrsNodeReflAdapter, legacy BaseAttrsNode methods (#19607)
b1e1566f82 is described below
commit b1e1566f8264240a5e5a8cae836a212e9b43687f
Author: Tianqi Chen <[email protected]>
AuthorDate: Tue May 26 10:03:17 2026 -0400
[REFACTOR][IR] Cleanup attrs.h: drop NullValue, AttrsNodeReflAdapter,
legacy BaseAttrsNode methods (#19607)
## Overview
This PR cleans up `include/tvm/ir/attrs.h` by removing four deprecated
abstractions:
1. `NullValue<T>()` sentinel helpers (replaced by `ffi::Optional<T>`)
2. `AttrsNodeReflAdapter<DerivedType>` shim template (Attrs structs now
inherit `BaseAttrsNode` directly)
3. `BaseAttrsNode::InitBySeq` / `InitByPackedArgs` legacy initialization
methods
4. `DictAttrsNode::InitByPackedArgs` override
It also migrates 9 pass-config classes from
`Attrs`/`AttrsNodeReflAdapter` to `ffi::Object`, since they are pass
configuration objects, not IR attributes.
## Changes
**Commit A — Replace NullValue<T>() call sites** (`[REFACTOR][IR]
Replace NullValue<T>() call sites with default construction`)
- 11 source files: replace `NullValue<T>()` with `T()`, `std::nullopt`,
or `DataType::Void()`
- `manipulate.h`/`manipulate.cc`: `FlipAttrs::axis` changed from
`Integer` to `ffi::Optional<int64_t>`
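Commit A's move away from null sentinels can be sketched in a standalone way. Here `std::optional` stands in for `ffi::Optional`, and `FlipAttrsSketch` and `NormalizeAxis` are hypothetical names, not TVM code. The sketch follows the bullet's description of `FlipAttrs::axis` becoming optional. The point is that "no axis given" is an explicit empty optional, so call sites no longer need a `NullValue<Integer>()` sentinel to mean "unset".

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <stdexcept>

// Hypothetical stand-in for the attrs struct; std::optional plays the role of
// ffi::Optional<int64_t>. Not the actual TVM definition.
struct FlipAttrsSketch {
  std::optional<int64_t> axis;  // empty means "not provided"; no null sentinel needed
};

// Hypothetical helper: resolve a possibly negative axis against a tensor rank.
// Throws if the axis is missing or out of range.
int64_t NormalizeAxis(const FlipAttrsSketch& attrs, int64_t ndim) {
  if (!attrs.axis.has_value()) {
    throw std::invalid_argument("flip requires an axis");
  }
  int64_t axis = *attrs.axis;
  if (axis < -ndim || axis >= ndim) {
    throw std::out_of_range("axis out of range");
  }
  return axis < 0 ? axis + ndim : axis;
}
```

Checking `has_value()` replaces a comparison against a sentinel object, and the type itself documents that the field may be absent.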
**Commit B — Drop NullValue, AttrsNodeReflAdapter, legacy BaseAttrsNode
methods** (`[REFACTOR][IR] Drop NullValue declaration,
AttrsNodeReflAdapter, BaseAttrsNode legacy methods`)
- `include/tvm/ir/attrs.h`: removes `NullValue<T>`, `InitBySeq`,
`InitByPackedArgs`, `AttrsNodeReflAdapter<T>`
- `src/ir/attrs.cc`: removes `DictAttrsNode::InitByPackedArgs`
definition
- `AttrsWithDefaultValues<T>()` broadened to accept any `ffi::ObjectRef`
subtype (needed for Commit D)
- Removes unused includes: `reflection/accessor.h`, `<functional>`,
`<vector>`
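The broadened `AttrsWithDefaultValues<TObj>()` keeps one compile-time check: `TObj` must derive from `ffi::ObjectRef`. It then builds the object from its type index alone by calling `ffi.MakeObjectFromPackedArgs`. The standalone C++17 sketch below mimics that shape under stated assumptions:

- `Object`, `ObjectRef`, `DefaultRegistry`, and `WithDefaultValuesSketch` are hypothetical stand-ins for the FFI object model and reflection registry.
- `ExampleConfigNode`/`ExampleConfig` is a hypothetical pass-config pair.

The sketch shows why a non-`Attrs` config type can now go through the same factory.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <stdexcept>
#include <type_traits>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>

// Hypothetical stand-ins for ffi::Object / ffi::ObjectRef (not the TVM types).
struct Object {
  virtual ~Object() = default;
};
struct ObjectRef {
  std::shared_ptr<Object> data;
};

// Hypothetical registry standing in for reflection: maps a container type to a
// factory that returns an instance with every field at its default value.
using DefaultFactory = std::function<std::shared_ptr<Object>()>;
inline std::unordered_map<std::type_index, DefaultFactory>& DefaultRegistry() {
  static std::unordered_map<std::type_index, DefaultFactory> registry;
  return registry;
}

// Same shape as the broadened helper: only an ObjectRef bound is enforced,
// and the object is built from its type identity alone.
template <typename TObj>
TObj WithDefaultValuesSketch() {
  static_assert(std::is_base_of_v<ObjectRef, TObj>, "Can only create ObjectRef-derived types");
  using ContainerType = typename TObj::ContainerType;
  auto it = DefaultRegistry().find(std::type_index(typeid(ContainerType)));
  if (it == DefaultRegistry().end()) throw std::runtime_error("type not registered");
  TObj ref;
  ref.data = it->second();
  return ref;
}

// A hypothetical pass-config pair: a plain Object/ObjectRef, not an Attrs type.
struct ExampleConfigNode : Object {
  bool enabled = false;  // plays the role of a reflection-declared default
  int max_depth = 4;
};
struct ExampleConfig : ObjectRef {
  using ContainerType = ExampleConfigNode;
  const ContainerType* get() const { return static_cast<const ContainerType*>(data.get()); }
};
```

Usage: register `ExampleConfigNode` once with a factory such as `[] { return std::make_shared<ExampleConfigNode>(); }`. After that, `WithDefaultValuesSketch<ExampleConfig>()` returns a config whose fields hold their declared defaults. Because the bound is `ObjectRef` rather than `Attrs`, no `Attrs` base class is needed.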
**Commit C — Subclass BaseAttrsNode directly** (`[REFACTOR][IR] Subclass
BaseAttrsNode directly, drop AttrsNodeReflAdapter`)
- 17 attrs headers in `include/tvm/relax/attrs/` +
`include/tvm/target/virtual_device.h`
- All `struct FooAttrs : public AttrsNodeReflAdapter<FooAttrs>` →
`struct FooAttrs : public BaseAttrsNode`
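Before Commit C, each attrs struct passed itself as the template argument of `AttrsNodeReflAdapter<T>`. The adapter's only remaining job was to override the pure-virtual `InitByPackedArgs` with a stub that threw. Once Commit B removed that virtual from `BaseAttrsNode`, the adapter carried nothing. The standalone sketch below contrasts the two shapes using hypothetical `*Sketch` types; it is a simplification, not the TVM code.

```cpp
#include <cassert>
#include <stdexcept>
#include <type_traits>

// Old shape (simplified): the base declared a pure-virtual legacy initializer...
struct OldBaseAttrsSketch {
  virtual ~OldBaseAttrsSketch() = default;
  virtual void InitByPackedArgs() = 0;
};

// ...and a CRTP adapter satisfied it with a stub that always threw.
template <typename Derived>
struct AdapterSketch : OldBaseAttrsSketch {
  void InitByPackedArgs() final {
    throw std::logic_error("uses new reflection mechanism for init");
  }
};

struct OldSoftmaxAttrsSketch : AdapterSketch<OldSoftmaxAttrsSketch> {
  int axis = -1;
};

// New shape: the legacy virtual is gone, so attrs derive from the base directly.
struct NewBaseAttrsSketch {
  virtual ~NewBaseAttrsSketch() = default;
};

struct NewSoftmaxAttrsSketch : NewBaseAttrsSketch {
  int axis = -1;
};

static_assert(std::is_base_of_v<OldBaseAttrsSketch, OldSoftmaxAttrsSketch>);
static_assert(std::is_base_of_v<NewBaseAttrsSketch, NewSoftmaxAttrsSketch>);
```

Dropping the intermediate template also removes one instantiation per attrs type. The only thing it carried, a method that could never succeed, goes with it.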
**Commit D — Migrate pass-config classes to ffi::Object** (`[REFACTOR]
Migrate pass-config classes to subclass ffi::Object`)
- 9 pass-config classes in `src/s_tir/`, `src/tirx/`,
`src/relax/backend/contrib/`
- `XConfigNode : public ffi::Object` (was
`AttrsNodeReflAdapter<XConfigNode>`)
- `XConfig : public ffi::ObjectRef` (was `Attrs`)
- Python bindings updated: 7 classes changed from `_ir.Attrs` to
`_ffi.Object`
## Design Decisions
**`AttrFieldInfo` / `OpNode::arguments` kept**: Pre-flight check
revealed `GetArgStructInfo()` in `op_common.h` and `op_common.cc`
actively reads `op->arguments` (names, counts). These were not dead
metadata — deleting them would break Relax op argument validation. They
are kept as-is.
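To see why registered argument metadata is load-bearing, consider a minimal standalone sketch of an arity check driven by an op's argument list. `ArgFieldSketch`, `OpSketch`, and `CheckArity` are hypothetical. They only approximate the kind of count-and-name check described above; the real `GetArgStructInfo()` in `op_common.h`/`op_common.cc` is not reproduced here.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-ins for AttrFieldInfo and an op with registered arguments.
struct ArgFieldSketch {
  std::string name;
};

struct OpSketch {
  std::string name;
  std::vector<ArgFieldSketch> arguments;  // plays the role of OpNode::arguments
};

// Check a call's argument count against the registered list and name the
// expected arguments in the error. Deleting `arguments` would remove both
// the count and the names this check relies on.
void CheckArity(const OpSketch& op, std::size_t num_call_args) {
  if (num_call_args == op.arguments.size()) return;
  std::string expected;
  for (const auto& arg : op.arguments) {
    if (!expected.empty()) expected += ", ";
    expected += arg.name;
  }
  throw std::invalid_argument(op.name + " expects " + std::to_string(op.arguments.size()) +
                              " arguments (" + expected + "), got " +
                              std::to_string(num_call_args));
}
```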
**Commit E (trim attrs.h includes) reduced in scope**: Removing
`structural_equal.h`, `structural_hash.h`, and `<unordered_map>` from
`attrs.h` caused 47 downstream files to fail compilation. Rather than
adding explicit includes to 47 files, only clearly-unused includes
(`reflection/accessor.h`, `<functional>`, `<vector>`) were removed in
Commit B.
## Testing
- Build: clean compile with `-DUSE_CUDA=OFF -DUSE_LLVM=ON`
- Tests passing:
  - `tests/python/ir/` (93 passed)
  - `tests/python/relax/test_analysis.py`, `test_blockbuilder_core.py`,
    `test_op_manipulate.py`, `test_transform.py` (209 passed)
  - `tests/python/s_tir/transform/test_s_tir_transform_loop_partition.py`,
    `test_s_tir_transform_unify_thread_binding.py` (30 passed)
  - `tests/python/tirx-transform/test_tir_transform_unroll_loop.py`,
    `test_tir_transform_simplify.py`, `test_tir_transform_remove_no_op.py`
    (108 passed, 6 xfailed)
- Pre-existing failures (unrelated to this PR):
  `test_s_tir_transform_lower_opaque_block`,
  `test_s_tir_transform_compact_buffer_region::TestLetBinding::test_compact`,
  `test_tir_transform_vectorize::test_vectorize_llvm_pure_intrin_fail`
---
include/tvm/ir/attrs.h | 93 ++++------------------
include/tvm/relax/attrs/ccl.h | 6 +-
include/tvm/relax/attrs/create.h | 4 +-
include/tvm/relax/attrs/datatype.h | 4 +-
include/tvm/relax/attrs/distributed.h | 2 +-
include/tvm/relax/attrs/image.h | 6 +-
include/tvm/relax/attrs/index.h | 4 +-
include/tvm/relax/attrs/linear_algebra.h | 4 +-
include/tvm/relax/attrs/manipulate.h | 41 +++++-----
include/tvm/relax/attrs/nn.h | 52 ++++++------
include/tvm/relax/attrs/op.h | 10 +--
include/tvm/relax/attrs/qdq.h | 2 +-
include/tvm/relax/attrs/sampling.h | 2 +-
include/tvm/relax/attrs/search.h | 4 +-
include/tvm/relax/attrs/sorting.h | 10 +--
include/tvm/relax/attrs/statistical.h | 4 +-
include/tvm/relax/attrs/vision.h | 13 ++-
include/tvm/target/virtual_device.h | 2 +-
python/tvm/relax/op/manipulate.py | 2 +-
python/tvm/s_tir/transform/transform.py | 5 +-
python/tvm/tirx/transform/transform.py | 11 ++-
src/ir/attrs.cc | 8 --
src/relax/backend/contrib/clml/codegen.cc | 8 +-
src/relax/backend/contrib/tensorrt/codegen.cc | 8 +-
src/relax/op/tensor/manipulate.cc | 10 +--
src/relax/op/tensor/manipulate.h | 2 +-
src/s_tir/schedule/concrete_schedule.cc | 4 +-
src/s_tir/schedule/traced_schedule.cc | 4 +-
src/s_tir/transform/hoist_expression.cc | 12 +--
src/s_tir/transform/inject_double_buffer.cc | 8 +-
src/s_tir/transform/loop_partition.cc | 8 +-
.../transform/lower_cross_thread_reduction.cc | 2 +-
src/s_tir/transform/storage_access.h | 2 +-
src/s_tir/transform/unify_thread_binding.cc | 3 +-
src/tirx/analysis/stmt_finding.cc | 2 +-
src/tirx/script/builder/frame.cc | 2 +-
src/tirx/transform/remove_no_op.cc | 9 ++-
src/tirx/transform/simplify.cc | 14 ++--
src/tirx/transform/unroll_loop.cc | 9 ++-
tests/cpp/ir_functor_test.cc | 2 +-
40 files changed, 163 insertions(+), 235 deletions(-)
diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h
index fa3dfa5b3e..287a263517 100644
--- a/include/tvm/ir/attrs.h
+++ b/include/tvm/ir/attrs.h
@@ -23,7 +23,7 @@
* This module enables declaration of named attributes
* which support default value setup and bound checking.
*
- * \sa AttrsNode, TVM_DECLARE_ATTRS, TVM_ATTR_FIELD
+ * \sa BaseAttrsNode, AttrsWithDefaultValues
*/
#ifndef TVM_IR_ATTRS_H_
#define TVM_IR_ATTRS_H_
@@ -32,36 +32,17 @@
#include <tvm/ffi/extra/structural_equal.h>
#include <tvm/ffi/extra/structural_hash.h>
#include <tvm/ffi/function.h>
-#include <tvm/ffi/reflection/accessor.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/cow.h>
#include <tvm/ir/expr.h>
-#include <functional>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
-#include <vector>
namespace tvm {
-/*!
- * \brief Create a NodeRef type that represents null.
- * \tparam TNodeRef the type to be created.
- * \return A instance that will represent None.
- */
-template <typename TObjectRef>
-inline TObjectRef NullValue() {
- static_assert(TObjectRef::_type_is_nullable, "Can only get NullValue for nullable types");
- return TObjectRef(ffi::ObjectPtr<typename TObjectRef::ContainerType>(nullptr));
-}
-
-template <>
-inline DataType NullValue<DataType>() {
- return DataType(DataType::kHandle, 0, 0);
-}
-
/*!
* \brief Information about attribute fields in string representations.
*/
@@ -103,22 +84,6 @@ class BaseAttrsNode : public ffi::Object {
public:
/*! \brief virtual destructor */
virtual ~BaseAttrsNode() {}
- /*!
- * \brief Initialize the attributes by sequence of arguments
- * \param args The positional arguments in the form
- * [key0, value0, key1, value1, ..., key_n, value_n]
- */
- template <typename... Args>
- inline void InitBySeq(Args&&... args);
- /*!
- * \brief Initialize the attributes by arguments.
- * \param kwargs The key value pairs for initialization.
- * [key0, value0, key1, value1, ..., key_n, value_n]
- * \param allow_unknown Whether allow additional unknown fields.
- * \note This function throws when the required field is not present.
- */
- TVM_DLL virtual void InitByPackedArgs(const ffi::PackedArgs& kwargs,
- bool allow_unknown = false) = 0;
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
TVM_FFI_DECLARE_OBJECT_INFO("ir.Attrs", BaseAttrsNode, ffi::Object);
@@ -149,8 +114,6 @@ class DictAttrsNode : public BaseAttrsNode {
rfl::ObjectDef<DictAttrsNode>().def_ro("__dict__", &DictAttrsNode::dict);
}
- void InitByPackedArgs(const ffi::PackedArgs& args, bool allow_unknown) final;
-
// type info
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.DictAttrs", DictAttrsNode, BaseAttrsNode);
};
@@ -380,48 +343,20 @@ inline TFunc WithoutAttr(TFunc input, const std::string& attr_key) {
}
/*!
- * \brief Adapter for AttrsNode with the new reflection API.
- *
- * We will phaseout the old AttrsNode in future in favor of the new reflection API.
- * This adapter allows us to gradually migrate to the new reflection API.
- *
- * \tparam DerivedType The final attribute type.
+ * \brief Create an object with all default values, using the reflection defaults.
+ * \tparam TObj the ObjectRef type to be created.
+ * \return An instance with all reflection-defined default values applied.
*/
-template <typename DerivedType>
-class AttrsNodeReflAdapter : public BaseAttrsNode {
- public:
- void InitByPackedArgs(const ffi::PackedArgs& args, bool allow_unknown) final {
- TVM_FFI_THROW(InternalError) << "`" << DerivedType::_type_key
- << "` uses new reflection mechanism for init";
- }
-
- private:
- DerivedType* self() const {
- return const_cast<DerivedType*>(static_cast<const DerivedType*>(this));
- }
-};
-
-/*!
- * \brief Create an Attr object with all default values.
- * \tparam TAttrNode the type to be created.
- * \return A instance that will represent None.
- */
-template <typename TAttrs>
-inline TAttrs AttrsWithDefaultValues() {
- static_assert(std::is_base_of_v<Attrs, TAttrs>, "Can only take attr nodes");
- using ContainerType = typename TAttrs::ContainerType;
- if constexpr (std::is_base_of_v<AttrsNodeReflAdapter<ContainerType>, ContainerType>) {
- static auto finit_object = ffi::Function::GetGlobalRequired("ffi.MakeObjectFromPackedArgs");
- AnyView packed_args[1];
- packed_args[0] = ContainerType::RuntimeTypeIndex();
- ffi::Any rv;
- finit_object.CallPacked(ffi::PackedArgs(packed_args, 1), &rv);
- return rv.cast<TAttrs>();
- } else {
- auto n = ffi::make_object<ContainerType>();
- n->InitByPackedArgs(ffi::PackedArgs(nullptr, 0), false);
- return TAttrs(n);
- }
+template <typename TObj>
+inline TObj AttrsWithDefaultValues() {
+ static_assert(std::is_base_of_v<ffi::ObjectRef, TObj>, "Can only create ObjectRef-derived types");
+ using ContainerType = typename TObj::ContainerType;
+ static auto finit_object = ffi::Function::GetGlobalRequired("ffi.MakeObjectFromPackedArgs");
+ AnyView packed_args[1];
+ packed_args[0] = ContainerType::RuntimeTypeIndex();
+ ffi::Any rv;
+ finit_object.CallPacked(ffi::PackedArgs(packed_args, 1), &rv);
+ return rv.cast<TObj>();
}
} // namespace tvm
diff --git a/include/tvm/relax/attrs/ccl.h b/include/tvm/relax/attrs/ccl.h
index 09d40b4ed9..7e0624706b 100644
--- a/include/tvm/relax/attrs/ccl.h
+++ b/include/tvm/relax/attrs/ccl.h
@@ -31,7 +31,7 @@ namespace tvm {
namespace relax {
/*! \brief Attributes used in allreduce operators */
-struct AllReduceAttrs : public tvm::AttrsNodeReflAdapter<AllReduceAttrs> {
+struct AllReduceAttrs : public tvm::BaseAttrsNode {
ffi::String op_type;
bool in_group;
@@ -49,7 +49,7 @@ struct AllReduceAttrs : public tvm::AttrsNodeReflAdapter<AllReduceAttrs> {
}; // struct AllReduceAttrs
/*! \brief Attributes used in allgather operators */
-struct AllGatherAttrs : public tvm::AttrsNodeReflAdapter<AllGatherAttrs> {
+struct AllGatherAttrs : public tvm::BaseAttrsNode {
int num_workers;
bool in_group;
@@ -67,7 +67,7 @@ struct AllGatherAttrs : public tvm::AttrsNodeReflAdapter<AllGatherAttrs> {
}; // struct AllGatherAttrs
/*! \brief Attributes used in scatter operators */
-struct ScatterCollectiveAttrs : public tvm::AttrsNodeReflAdapter<ScatterCollectiveAttrs> {
+struct ScatterCollectiveAttrs : public tvm::BaseAttrsNode {
int num_workers;
int axis;
diff --git a/include/tvm/relax/attrs/create.h b/include/tvm/relax/attrs/create.h
index c631fd3b4e..9a9e453263 100644
--- a/include/tvm/relax/attrs/create.h
+++ b/include/tvm/relax/attrs/create.h
@@ -30,7 +30,7 @@ namespace tvm {
namespace relax {
/*! \brief Attributes used in full/full_like, ones/ones_like, and zeros/zeros_like operators */
-struct InitAttrs : public AttrsNodeReflAdapter<InitAttrs> {
+struct InitAttrs : public BaseAttrsNode {
DataType dtype;
static void RegisterReflection() {
@@ -42,7 +42,7 @@ struct InitAttrs : public AttrsNodeReflAdapter<InitAttrs> {
}; // struct InitAttrs
/*! \brief Attributes used in tril and triu operator */
-struct TriluAttrs : public AttrsNodeReflAdapter<TriluAttrs> {
+struct TriluAttrs : public BaseAttrsNode {
int k;
static void RegisterReflection() {
diff --git a/include/tvm/relax/attrs/datatype.h b/include/tvm/relax/attrs/datatype.h
index dd07e3b548..a187059703 100644
--- a/include/tvm/relax/attrs/datatype.h
+++ b/include/tvm/relax/attrs/datatype.h
@@ -30,7 +30,7 @@ namespace tvm {
namespace relax {
/*! \brief Attributes used in astype operator */
-struct AstypeAttrs : public AttrsNodeReflAdapter<AstypeAttrs> {
+struct AstypeAttrs : public BaseAttrsNode {
DataType dtype;
static void RegisterReflection() {
@@ -41,7 +41,7 @@ struct AstypeAttrs : public AttrsNodeReflAdapter<AstypeAttrs> {
}; // struct AstypeAttrs.
/*! \brief Attributes used in wrap_param operator */
-struct WrapParamAttrs : public AttrsNodeReflAdapter<WrapParamAttrs> {
+struct WrapParamAttrs : public BaseAttrsNode {
DataType dtype;
static void RegisterReflection() {
diff --git a/include/tvm/relax/attrs/distributed.h b/include/tvm/relax/attrs/distributed.h
index 356a248ba2..cce508ef1d 100644
--- a/include/tvm/relax/attrs/distributed.h
+++ b/include/tvm/relax/attrs/distributed.h
@@ -32,7 +32,7 @@ namespace tvm {
namespace relax {
/*! \brief Attributes for redistribute and annotate_sharding operator */
-struct DistributionAttrs : public AttrsNodeReflAdapter<DistributionAttrs> {
+struct DistributionAttrs : public BaseAttrsNode {
distributed::DeviceMesh device_mesh;
distributed::Placement placement;
diff --git a/include/tvm/relax/attrs/image.h b/include/tvm/relax/attrs/image.h
index 52aac58dcd..8cc5e36734 100644
--- a/include/tvm/relax/attrs/image.h
+++ b/include/tvm/relax/attrs/image.h
@@ -30,7 +30,7 @@ namespace tvm {
namespace relax {
/*! \brief Attributes used in image resize2d operator */
-struct Resize2DAttrs : public AttrsNodeReflAdapter<Resize2DAttrs> {
+struct Resize2DAttrs : public BaseAttrsNode {
ffi::Array<FloatImm> roi;
ffi::String layout;
ffi::String method;
@@ -79,7 +79,7 @@ struct Resize2DAttrs : public AttrsNodeReflAdapter<Resize2DAttrs> {
}; // struct Resize2dAttrs
/*! \brief Attributes used in image resize3d operator */
-struct Resize3DAttrs : public AttrsNodeReflAdapter<Resize3DAttrs> {
+struct Resize3DAttrs : public BaseAttrsNode {
ffi::Array<FloatImm> roi;
ffi::String layout;
ffi::String method;
@@ -128,7 +128,7 @@ struct Resize3DAttrs : public AttrsNodeReflAdapter<Resize3DAttrs> {
}; // struct Resize3DAttrs
/*! \brief Attributes used in image grid_sample operator */
-struct GridSampleAttrs : public AttrsNodeReflAdapter<GridSampleAttrs> {
+struct GridSampleAttrs : public BaseAttrsNode {
ffi::String method;
ffi::String layout;
ffi::String padding_mode;
diff --git a/include/tvm/relax/attrs/index.h b/include/tvm/relax/attrs/index.h
index 0ea7c06bac..7b4c446bb8 100644
--- a/include/tvm/relax/attrs/index.h
+++ b/include/tvm/relax/attrs/index.h
@@ -30,7 +30,7 @@ namespace tvm {
namespace relax {
/*! \brief Attributes used in take operator */
-struct TakeAttrs : public AttrsNodeReflAdapter<TakeAttrs> {
+struct TakeAttrs : public BaseAttrsNode {
ffi::Optional<int64_t> axis;
ffi::String mode;
@@ -45,7 +45,7 @@ struct TakeAttrs : public AttrsNodeReflAdapter<TakeAttrs> {
}; // struct TakeAttrs
/*! \brief Attributes used in strided_slice operator */
-struct StridedSliceAttrs : public AttrsNodeReflAdapter<StridedSliceAttrs> {
+struct StridedSliceAttrs : public BaseAttrsNode {
bool assume_inbound;
static void RegisterReflection() {
diff --git a/include/tvm/relax/attrs/linear_algebra.h b/include/tvm/relax/attrs/linear_algebra.h
index f95d817f1e..2627dafcf6 100644
--- a/include/tvm/relax/attrs/linear_algebra.h
+++ b/include/tvm/relax/attrs/linear_algebra.h
@@ -30,7 +30,7 @@ namespace tvm {
namespace relax {
/*! \brief Attributes for matmul operator */
-struct MatmulAttrs : public AttrsNodeReflAdapter<MatmulAttrs> {
+struct MatmulAttrs : public BaseAttrsNode {
DataType out_dtype;
static void RegisterReflection() {
@@ -42,7 +42,7 @@ struct MatmulAttrs : public AttrsNodeReflAdapter<MatmulAttrs> {
}; // struct MatmulAttrs
/*! \brief Attributes used in einsum operator */
-struct EinsumAttrs : public AttrsNodeReflAdapter<EinsumAttrs> {
+struct EinsumAttrs : public BaseAttrsNode {
ffi::String subscripts;
static void RegisterReflection() {
diff --git a/include/tvm/relax/attrs/manipulate.h b/include/tvm/relax/attrs/manipulate.h
index f2ba7af0d9..71fb7b0b95 100644
--- a/include/tvm/relax/attrs/manipulate.h
+++ b/include/tvm/relax/attrs/manipulate.h
@@ -31,7 +31,7 @@ namespace tvm {
namespace relax {
/*! \brief Attributes used in concat operators */
-struct ConcatAttrs : public AttrsNodeReflAdapter<ConcatAttrs> {
+struct ConcatAttrs : public BaseAttrsNode {
ffi::Optional<int64_t> axis;
static void RegisterReflection() {
@@ -44,7 +44,7 @@ struct ConcatAttrs : public AttrsNodeReflAdapter<ConcatAttrs> {
}; // struct ConcatAttrs
/*! \brief Attributes used in expand_dims operators */
-struct ExpandDimsAttrs : public AttrsNodeReflAdapter<ExpandDimsAttrs> {
+struct ExpandDimsAttrs : public BaseAttrsNode {
ffi::Array<Integer> axis;
static void RegisterReflection() {
@@ -59,7 +59,7 @@ struct ExpandDimsAttrs : public AttrsNodeReflAdapter<ExpandDimsAttrs> {
}; // struct ExpandDimsAttrs
/*! \brief Attributes used in layout_transform operator */
-struct LayoutTransformAttrs : public AttrsNodeReflAdapter<LayoutTransformAttrs> {
+struct LayoutTransformAttrs : public BaseAttrsNode {
tirx::IndexMap index_map;
// pad_value is chosen to be of PrimValue type, as it represents constant TIR POD expression. This
// needs to be revisited in case PrimValue is evolved to represent symbolic expression in future.
@@ -97,7 +97,7 @@ struct LayoutTransformAttrs : public AttrsNodeReflAdapter<LayoutTransformAttrs>
}; // struct LayoutTransformAttrs
/*! \brief Attributes used in permute_dims operator */
-struct PermuteDimsAttrs : public AttrsNodeReflAdapter<PermuteDimsAttrs> {
+struct PermuteDimsAttrs : public BaseAttrsNode {
ffi::Optional<ffi::Array<Integer>> axes;
static void RegisterReflection() {
@@ -110,7 +110,7 @@ struct PermuteDimsAttrs : public AttrsNodeReflAdapter<PermuteDimsAttrs> {
}; // struct PermuteDimsAttrs
/*! \brief Attributes used in split operator */
-struct SplitAttrs : public AttrsNodeReflAdapter<SplitAttrs> {
+struct SplitAttrs : public BaseAttrsNode {
ffi::ObjectRef indices_or_sections;
int axis;
@@ -125,7 +125,7 @@ struct SplitAttrs : public AttrsNodeReflAdapter<SplitAttrs> {
}; // struct SplitAttrs
/*! \brief Attributes used in squeeze operators */
-struct SqueezeAttrs : public AttrsNodeReflAdapter<SqueezeAttrs> {
+struct SqueezeAttrs : public BaseAttrsNode {
ffi::Optional<ffi::Array<Integer>> axis;
static void RegisterReflection() {
@@ -140,7 +140,7 @@ struct SqueezeAttrs : public AttrsNodeReflAdapter<SqueezeAttrs> {
}; // struct SqueezeAttrs
/*! \brief Attributes used in stack operators */
-struct StackAttrs : public AttrsNodeReflAdapter<StackAttrs> {
+struct StackAttrs : public BaseAttrsNode {
ffi::Optional<Integer> axis;
static void RegisterReflection() {
@@ -156,7 +156,7 @@ struct StackAttrs : public AttrsNodeReflAdapter<StackAttrs> {
}; // struct StackAttrs
/*! \brief Attributes used in repeat operators */
-struct RepeatAttrs : public AttrsNodeReflAdapter<RepeatAttrs> {
+struct RepeatAttrs : public BaseAttrsNode {
int repeats;
ffi::Optional<int64_t> axis;
@@ -173,7 +173,7 @@ struct RepeatAttrs : public AttrsNodeReflAdapter<RepeatAttrs> {
}; // struct RepeatAttrs
/*! \brief Attributes used in tile operators */
-struct TileAttrs : public AttrsNodeReflAdapter<TileAttrs> {
+struct TileAttrs : public BaseAttrsNode {
ffi::Array<Integer> repeats;
static void RegisterReflection() {
@@ -185,20 +185,19 @@ struct TileAttrs : public AttrsNodeReflAdapter<TileAttrs> {
}; // struct TileAttrs
/*! \brief Attributes used in flip operators */
-struct FlipAttrs : public AttrsNodeReflAdapter<FlipAttrs> {
- Integer axis;
+struct FlipAttrs : public BaseAttrsNode {
+ int64_t axis;
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<FlipAttrs>().def_ro("axis", &FlipAttrs::axis,
- "The axis along which to flip over.",
- refl::DefaultValue(NullValue<Integer>()));
+ "The axis along which to flip over.");
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.FlipAttrs", FlipAttrs, BaseAttrsNode);
}; // struct FlipAttrs
/*! \brief Attributes used in gather_elements operators */
-struct GatherElementsAttrs : public AttrsNodeReflAdapter<GatherElementsAttrs> {
+struct GatherElementsAttrs : public BaseAttrsNode {
Integer axis;
static void RegisterReflection() {
@@ -212,7 +211,7 @@ struct GatherElementsAttrs : public AttrsNodeReflAdapter<GatherElementsAttrs> {
}; // struct GatherElementsAttrs
/*! \brief Attributes used in gather_nd operators */
-struct GatherNDAttrs : public AttrsNodeReflAdapter<GatherNDAttrs> {
+struct GatherNDAttrs : public BaseAttrsNode {
Integer batch_dims;
static void RegisterReflection() {
@@ -224,7 +223,7 @@ struct GatherNDAttrs : public AttrsNodeReflAdapter<GatherNDAttrs> {
}; // struct GatherNDAttrs
/*! \brief Attributes used in index_put operator */
-struct IndexPutAttrs : public AttrsNodeReflAdapter<IndexPutAttrs> {
+struct IndexPutAttrs : public BaseAttrsNode {
bool accumulate;
static void RegisterReflection() {
@@ -240,7 +239,7 @@ struct IndexPutAttrs : public AttrsNodeReflAdapter<IndexPutAttrs> {
}; // struct IndexPutAttrs
/*! \brief Attribute used in meshgrid operator */
-struct MeshgridAttrs : public AttrsNodeReflAdapter<MeshgridAttrs> {
+struct MeshgridAttrs : public BaseAttrsNode {
ffi::Optional<ffi::String> indexing;
static void RegisterReflection() {
@@ -252,7 +251,7 @@ struct MeshgridAttrs : public AttrsNodeReflAdapter<MeshgridAttrs> {
};
/*! \brief Attributes used in scatter_elements operators */
-struct ScatterElementsAttrs : public AttrsNodeReflAdapter<ScatterElementsAttrs> {
+struct ScatterElementsAttrs : public BaseAttrsNode {
Integer axis;
ffi::String reduction;
@@ -271,7 +270,7 @@ struct ScatterElementsAttrs : public AttrsNodeReflAdapter<ScatterElementsAttrs>
}; // struct ScatterElementsAttrs
/*! \brief Attributes used in scatter_nd operators */
-struct ScatterNDAttrs : public AttrsNodeReflAdapter<ScatterNDAttrs> {
+struct ScatterNDAttrs : public BaseAttrsNode {
ffi::String reduction;
static void RegisterReflection() {
@@ -286,7 +285,7 @@ struct ScatterNDAttrs : public AttrsNodeReflAdapter<ScatterNDAttrs> {
}; // struct ScatterNDAttrs
/*! \brief Attributes used in slice_scatter operator */
-struct SliceScatterAttrs : public AttrsNodeReflAdapter<SliceScatterAttrs> {
+struct SliceScatterAttrs : public BaseAttrsNode {
int axis;
static void RegisterReflection() {
@@ -300,7 +299,7 @@ struct SliceScatterAttrs : public AttrsNodeReflAdapter<SliceScatterAttrs> {
}; // struct SliceScatterAttrs
/*! \brief Attributes used in one_hot operator */
-struct OneHotAttrs : public AttrsNodeReflAdapter<OneHotAttrs> {
+struct OneHotAttrs : public BaseAttrsNode {
int depth;
int axis;
diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h
index 45abeb9d5b..bfc85dfd5a 100644
--- a/include/tvm/relax/attrs/nn.h
+++ b/include/tvm/relax/attrs/nn.h
@@ -30,7 +30,7 @@ namespace tvm {
namespace relax {
/*! \brief Attributes used in Conv1d operator */
-struct Conv1DAttrs : public AttrsNodeReflAdapter<Conv1DAttrs> {
+struct Conv1DAttrs : public BaseAttrsNode {
ffi::Array<int64_t> strides;
ffi::Array<int64_t> padding;
ffi::Array<int64_t> dilation;
@@ -74,7 +74,7 @@ struct Conv1DAttrs : public AttrsNodeReflAdapter<Conv1DAttrs> {
}; // struct Conv1dAttrs
/*! \brief Attributes used in Conv2d operator */
-struct Conv2DAttrs : public AttrsNodeReflAdapter<Conv2DAttrs> {
+struct Conv2DAttrs : public BaseAttrsNode {
ffi::Array<int64_t> strides;
ffi::Array<int64_t> padding;
ffi::Array<int64_t> dilation;
@@ -120,7 +120,7 @@ struct Conv2DAttrs : public AttrsNodeReflAdapter<Conv2DAttrs> {
}; // struct Conv2dAttrs
/*! \brief Attributes used in Conv3d operator */
-struct Conv3DAttrs : public AttrsNodeReflAdapter<Conv3DAttrs> {
+struct Conv3DAttrs : public BaseAttrsNode {
ffi::Array<int64_t> strides;
ffi::Array<int64_t> padding;
ffi::Array<int64_t> dilation;
@@ -168,7 +168,7 @@ struct Conv3DAttrs : public AttrsNodeReflAdapter<Conv3DAttrs> {
}; // struct Conv3dAttrs
/*! \brief Attributes used in Conv1DTranspose operator */
-struct Conv1DTransposeAttrs : public AttrsNodeReflAdapter<Conv1DTransposeAttrs> {
+struct Conv1DTransposeAttrs : public BaseAttrsNode {
ffi::Array<int64_t> strides;
ffi::Array<int64_t> padding;
ffi::Array<int64_t> output_padding;
@@ -217,7 +217,7 @@ struct Conv1DTransposeAttrs : public AttrsNodeReflAdapter<Conv1DTransposeAttrs>
}; // struct Conv1DTransposeAttrs
/*! \brief Attributes used in Conv2d operator */
-struct Conv2DTransposeAttrs : public AttrsNodeReflAdapter<Conv2DTransposeAttrs> {
+struct Conv2DTransposeAttrs : public BaseAttrsNode {
ffi::Array<int64_t> strides;
ffi::Array<int64_t> padding;
ffi::Array<int64_t> output_padding;
@@ -268,7 +268,7 @@ struct Conv2DTransposeAttrs : public AttrsNodeReflAdapter<Conv2DTransposeAttrs>
}; // struct Conv2DTransposeAttrs
/*! \brief Attributes used in Conv3dTranspose operator */
-struct Conv3DTransposeAttrs : public AttrsNodeReflAdapter<Conv3DTransposeAttrs> {
+struct Conv3DTransposeAttrs : public BaseAttrsNode {
ffi::Array<int64_t> strides;
ffi::Array<int64_t> padding;
ffi::Array<int64_t> output_padding;
@@ -321,7 +321,7 @@ struct Conv3DTransposeAttrs : public AttrsNodeReflAdapter<Conv3DTransposeAttrs>
}; // struct Conv3DTransposeAttrs
/*! \brief Attributes used in max_pool1d and avg_pool1d operator */
-struct Pool1DAttrs : public AttrsNodeReflAdapter<Pool1DAttrs> {
+struct Pool1DAttrs : public BaseAttrsNode {
ffi::Array<int64_t> pool_size;
ffi::Array<int64_t> strides;
ffi::Array<int64_t> padding;
@@ -362,7 +362,7 @@ struct Pool1DAttrs : public AttrsNodeReflAdapter<Pool1DAttrs> {
}; // struct Pool1dAttrs
/*! \brief Attributes used in max_pool2d and avg_pool2d operator */
-struct Pool2DAttrs : public AttrsNodeReflAdapter<Pool2DAttrs> {
+struct Pool2DAttrs : public BaseAttrsNode {
ffi::Array<int64_t> pool_size;
ffi::Array<int64_t> strides;
ffi::Array<int64_t> padding;
@@ -405,7 +405,7 @@ struct Pool2DAttrs : public AttrsNodeReflAdapter<Pool2DAttrs> {
}; // struct Pool2dAttrs
/*! \brief Attributes used in max_pool3d and avg_pool3d operator */
-struct Pool3DAttrs : public AttrsNodeReflAdapter<Pool3DAttrs> {
+struct Pool3DAttrs : public BaseAttrsNode {
ffi::Array<int64_t> pool_size;
ffi::Array<int64_t> strides;
ffi::Array<int64_t> padding;
@@ -448,7 +448,7 @@ struct Pool3DAttrs : public AttrsNodeReflAdapter<Pool3DAttrs> {
}; // struct Pool3dAttrs
/*! \brief Attributes for 1d adaptive pool operator */
-struct AdaptivePool1DAttrs : public AttrsNodeReflAdapter<AdaptivePool1DAttrs> {
+struct AdaptivePool1DAttrs : public BaseAttrsNode {
ffi::Optional<ffi::Array<int64_t>> output_size;
ffi::String layout;
ffi::String out_layout;
@@ -473,7 +473,7 @@ struct AdaptivePool1DAttrs : public AttrsNodeReflAdapter<AdaptivePool1DAttrs> {
}; // struct AdaptivePool1DAttrs
/*! \brief Attributes for 2d adaptive pool operator */
-struct AdaptivePool2DAttrs : public AttrsNodeReflAdapter<AdaptivePool2DAttrs> {
+struct AdaptivePool2DAttrs : public BaseAttrsNode {
ffi::Optional<ffi::Array<int64_t>> output_size;
ffi::String layout;
ffi::String out_layout;
@@ -498,7 +498,7 @@ struct AdaptivePool2DAttrs : public AttrsNodeReflAdapter<AdaptivePool2DAttrs> {
}; // struct AdaptivePool2DAttrs
/*! \brief Attributes for 3d adaptive pool operator */
-struct AdaptivePool3DAttrs : public AttrsNodeReflAdapter<AdaptivePool3DAttrs> {
+struct AdaptivePool3DAttrs : public BaseAttrsNode {
ffi::Optional<ffi::Array<int64_t>> output_size;
ffi::String layout;
ffi::String out_layout;
@@ -523,7 +523,7 @@ struct AdaptivePool3DAttrs : public AttrsNodeReflAdapter<AdaptivePool3DAttrs> {
}; // struct AdaptivePool3DAttrs
/*! \brief Attributes used in softmax operators */
-struct SoftmaxAttrs : public AttrsNodeReflAdapter<SoftmaxAttrs> {
+struct SoftmaxAttrs : public BaseAttrsNode {
int axis;
static void RegisterReflection() {
@@ -535,7 +535,7 @@ struct SoftmaxAttrs : public AttrsNodeReflAdapter<SoftmaxAttrs> {
};
/*! \brief Attributes used in softmax operators */
-struct LeakyReluAttrs : public AttrsNodeReflAdapter<LeakyReluAttrs> {
+struct LeakyReluAttrs : public BaseAttrsNode {
double alpha;
static void RegisterReflection() {
@@ -547,7 +547,7 @@ struct LeakyReluAttrs : public AttrsNodeReflAdapter<LeakyReluAttrs> {
};
/*! \brief Attributes used in softplus operators */
-struct SoftplusAttrs : public AttrsNodeReflAdapter<SoftplusAttrs> {
+struct SoftplusAttrs : public BaseAttrsNode {
double beta;
double threshold;
@@ -563,7 +563,7 @@ struct SoftplusAttrs : public AttrsNodeReflAdapter<SoftplusAttrs> {
};
/*! \brief Attributes used in PReLU operator */
-struct PReluAttrs : public AttrsNodeReflAdapter<PReluAttrs> {
+struct PReluAttrs : public BaseAttrsNode {
int axis;
static void RegisterReflection() {
@@ -575,7 +575,7 @@ struct PReluAttrs : public AttrsNodeReflAdapter<PReluAttrs> {
};
/*! \brief Attributes used in batch_norm operator */
-struct BatchNormAttrs : public AttrsNodeReflAdapter<BatchNormAttrs> {
+struct BatchNormAttrs : public BaseAttrsNode {
int axis;
double epsilon;
bool center;
@@ -602,7 +602,7 @@ struct BatchNormAttrs : public AttrsNodeReflAdapter<BatchNormAttrs> {
}; // struct BatchNormAttrs
/*! \brief Attributes used in layer_norm operator */
-struct LayerNormAttrs : public AttrsNodeReflAdapter<LayerNormAttrs> {
+struct LayerNormAttrs : public BaseAttrsNode {
ffi::Array<Integer> axes;
double epsilon;
bool center;
@@ -624,7 +624,7 @@ struct LayerNormAttrs : public AttrsNodeReflAdapter<LayerNormAttrs> {
}; // struct LayerNormAttrs
/*! \brief Attributes used in group_norm operator */
-struct GroupNormAttrs : public AttrsNodeReflAdapter<GroupNormAttrs> {
+struct GroupNormAttrs : public BaseAttrsNode {
int num_groups;
int channel_axis;
ffi::Array<Integer> axes;
@@ -653,7 +653,7 @@ struct GroupNormAttrs : public AttrsNodeReflAdapter<GroupNormAttrs> {
}; // struct GroupNormAttrs
/*! \brief Attributes used in instance_norm operator */
-struct InstanceNormAttrs : public AttrsNodeReflAdapter<InstanceNormAttrs> {
+struct InstanceNormAttrs : public BaseAttrsNode {
int channel_axis;
ffi::Array<Integer> axes;
double epsilon;
@@ -679,7 +679,7 @@ struct InstanceNormAttrs : public AttrsNodeReflAdapter<InstanceNormAttrs> {
}; // struct InstanceNormAttrs
/*! \brief Attributes used in rms_norm operator */
-struct RMSNormAttrs : public AttrsNodeReflAdapter<RMSNormAttrs> {
+struct RMSNormAttrs : public BaseAttrsNode {
ffi::Array<Integer> axes;
double epsilon;
@@ -695,7 +695,7 @@ struct RMSNormAttrs : public AttrsNodeReflAdapter<RMSNormAttrs> {
}; // struct RMSNormAttrs
/*! \brief Attributes used in nll_loss operator */
-struct NLLLossAttrs : public AttrsNodeReflAdapter<NLLLossAttrs> {
+struct NLLLossAttrs : public BaseAttrsNode {
ffi::String reduction;
int ignore_index;
@@ -712,7 +712,7 @@ struct NLLLossAttrs : public AttrsNodeReflAdapter<NLLLossAttrs> {
}; // struct NLLLossAttrs
/*! \brief Attributes used in dropout operator */
-struct DropoutAttrs : public AttrsNodeReflAdapter<DropoutAttrs> {
+struct DropoutAttrs : public BaseAttrsNode {
double rate;
static void RegisterReflection() {
@@ -725,7 +725,7 @@ struct DropoutAttrs : public AttrsNodeReflAdapter<DropoutAttrs> {
}; // struct DropoutAttrs
/*! \brief Attributes used in Attention operator */
-struct AttentionAttrs : public AttrsNodeReflAdapter<AttentionAttrs> {
+struct AttentionAttrs : public BaseAttrsNode {
ffi::Optional<FloatImm> scale;
ffi::Optional<ffi::String> causal_mask;
ffi::Optional<IntImm> window_size;
@@ -745,7 +745,7 @@ struct AttentionAttrs : public AttrsNodeReflAdapter<AttentionAttrs> {
}; // struct AttentionAttrs
/*! \brief Attributes used for the padding operator */
-struct PadAttrs : public AttrsNodeReflAdapter<PadAttrs> {
+struct PadAttrs : public BaseAttrsNode {
ffi::Array<Integer> pad_width;
double pad_value = 0.0;
tvm::ffi::String pad_mode;
@@ -768,7 +768,7 @@ struct PadAttrs : public AttrsNodeReflAdapter<PadAttrs> {
};
/*! \brief Attributes used for the pixel shuffle operator */
-struct PixelShuffleAttrs : public AttrsNodeReflAdapter<PixelShuffleAttrs> {
+struct PixelShuffleAttrs : public BaseAttrsNode {
int upscale_factor;
static void RegisterReflection() {
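Each hunk in this header makes the same mechanical change. The attrs struct stops inheriting from the `AttrsNodeReflAdapter<Derived>` CRTP shim and inherits `BaseAttrsNode` directly. The standalone sketch below is a simplified model, not TVM code. `BaseNode` and `ReflAdapter` are hypothetical stand-ins, and the sketch assumes the shim adds no state of its own, which is what makes the swap a pure base-class substitution.

```cpp
#include <cassert>
#include <string>
#include <type_traits>

// Hypothetical stand-in for a base node class such as BaseAttrsNode.
struct BaseNode {
  virtual ~BaseNode() = default;
  // Behaviour every attrs node inherits from the base.
  virtual std::string Kind() const { return "attrs"; }
};

// Hypothetical CRTP shim in the style of AttrsNodeReflAdapter<Derived>:
// it only sits between the derived type and BaseNode, adding no members.
template <typename Derived>
struct ReflAdapter : public BaseNode {};

// Before: the attrs struct inherits through the shim.
struct OldPadAttrs : public ReflAdapter<OldPadAttrs> {
  double pad_value = 0.0;
};

// After: the attrs struct inherits the base directly.
struct NewPadAttrs : public BaseNode {
  double pad_value = 0.0;
};

// Both variants are usable wherever a BaseNode reference is expected.
inline std::string KindOf(const BaseNode& node) { return node.Kind(); }
```

Dropping the shim also removes one class-template instantiation per attrs type.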
diff --git a/include/tvm/relax/attrs/op.h b/include/tvm/relax/attrs/op.h
index 54640901ff..79e00d590a 100644
--- a/include/tvm/relax/attrs/op.h
+++ b/include/tvm/relax/attrs/op.h
@@ -31,7 +31,7 @@ namespace tvm {
namespace relax {
/*! \brief Attributes used in call_tir_with_grad */
-struct CallTIRWithGradAttrs : public AttrsNodeReflAdapter<CallTIRWithGradAttrs> {
+struct CallTIRWithGradAttrs : public BaseAttrsNode {
ffi::String te_grad_name;
ffi::Map<ffi::String, Any> te_grad_kwargs;
@@ -49,7 +49,7 @@ struct CallTIRWithGradAttrs : public AttrsNodeReflAdapter<CallTIRWithGradAttrs>
}; // struct CallTIRAttrs
/*! \brief Attributes used in call_tir_inplace */
-struct CallTIRInplaceAttrs : public AttrsNodeReflAdapter<CallTIRInplaceAttrs> {
+struct CallTIRInplaceAttrs : public BaseAttrsNode {
/*!
* \brief Indices that describe which input corresponds to which output.
*
@@ -69,7 +69,7 @@ struct CallTIRInplaceAttrs : public AttrsNodeReflAdapter<CallTIRInplaceAttrs> {
}; // struct CallTIRInplaceAttrs
/*! \brief Attributes used in call_inplace_packed */
-struct CallInplacePackedAttrs : public AttrsNodeReflAdapter<CallInplacePackedAttrs> {
+struct CallInplacePackedAttrs : public BaseAttrsNode {
/*!
* \brief Indices that describe which input corresponds to which output.
*
@@ -89,7 +89,7 @@ struct CallInplacePackedAttrs : public AttrsNodeReflAdapter<CallInplacePackedAtt
}; // struct CallInplacePackedAttrs
/*! \brief Attributes used in to_vdevice */
-struct ToVDeviceAttrs : public AttrsNodeReflAdapter<ToVDeviceAttrs> {
+struct ToVDeviceAttrs : public BaseAttrsNode {
VDevice dst_vdevice;
static void RegisterReflection() {
@@ -101,7 +101,7 @@ struct ToVDeviceAttrs : public AttrsNodeReflAdapter<ToVDeviceAttrs> {
}; // struct ToVDeviceAttrs
/*! \brief Attributes used in hint_on_device */
-struct HintOnDeviceAttrs : public AttrsNodeReflAdapter<HintOnDeviceAttrs> {
+struct HintOnDeviceAttrs : public BaseAttrsNode {
int32_t device_type;
int32_t index;
MemoryScope memory_scope;
diff --git a/include/tvm/relax/attrs/qdq.h b/include/tvm/relax/attrs/qdq.h
index ffb554994f..08bc054dc5 100644
--- a/include/tvm/relax/attrs/qdq.h
+++ b/include/tvm/relax/attrs/qdq.h
@@ -30,7 +30,7 @@ namespace tvm {
namespace relax {
/*! \brief Attributes for relax.quantize/relax.dequantize operator */
-struct QuantizeAttrs : public AttrsNodeReflAdapter<QuantizeAttrs> {
+struct QuantizeAttrs : public BaseAttrsNode {
DataType out_dtype;
int axis;
diff --git a/include/tvm/relax/attrs/sampling.h b/include/tvm/relax/attrs/sampling.h
index 53fd3a1404..2d7421cc20 100644
--- a/include/tvm/relax/attrs/sampling.h
+++ b/include/tvm/relax/attrs/sampling.h
@@ -30,7 +30,7 @@ namespace tvm {
namespace relax {
/*! \brief Attributes used in multinomial_from_uniform operator */
-struct MultinomialFromUniformAttrs : public AttrsNodeReflAdapter<MultinomialFromUniformAttrs> {
+struct MultinomialFromUniformAttrs : public BaseAttrsNode {
DataType dtype;
static void RegisterReflection() {
diff --git a/include/tvm/relax/attrs/search.h b/include/tvm/relax/attrs/search.h
index 32327c160d..015e5d8edc 100644
--- a/include/tvm/relax/attrs/search.h
+++ b/include/tvm/relax/attrs/search.h
@@ -30,7 +30,7 @@ namespace tvm {
namespace relax {
/*! \brief Attributes for search operators */
-struct ArgmaxArgminAttrs : public AttrsNodeReflAdapter<ArgmaxArgminAttrs> {
+struct ArgmaxArgminAttrs : public BaseAttrsNode {
ffi::Optional<int64_t> axis;
bool keepdims;
@@ -49,7 +49,7 @@ struct ArgmaxArgminAttrs : public AttrsNodeReflAdapter<ArgmaxArgminAttrs> {
}; // struct ArgmaxArgminAttrs
/*! \brief Attributes for bucketize operator */
-struct BucketizeAttrs : public tvm::AttrsNodeReflAdapter<BucketizeAttrs> {
+struct BucketizeAttrs : public tvm::BaseAttrsNode {
bool out_int32;
bool right;
diff --git a/include/tvm/relax/attrs/sorting.h b/include/tvm/relax/attrs/sorting.h
index 354b770472..e32d47239f 100644
--- a/include/tvm/relax/attrs/sorting.h
+++ b/include/tvm/relax/attrs/sorting.h
@@ -31,7 +31,7 @@ namespace tvm {
namespace relax {
/*! \brief Attributes used in sort operator */
-struct SortAttrs : public AttrsNodeReflAdapter<SortAttrs> {
+struct SortAttrs : public BaseAttrsNode {
int axis;
bool descending;
@@ -51,7 +51,7 @@ struct SortAttrs : public AttrsNodeReflAdapter<SortAttrs> {
}; // struct SortAttrs
/*! \brief Attributes used in argsort operator */
-struct ArgsortAttrs : public AttrsNodeReflAdapter<ArgsortAttrs> {
+struct ArgsortAttrs : public BaseAttrsNode {
int axis;
bool descending;
DataType dtype;
@@ -68,13 +68,13 @@ struct ArgsortAttrs : public AttrsNodeReflAdapter<ArgsortAttrs> {
"If it is not specified, it defaults to the ascending order.",
refl::DefaultValue(false))
.def_ro("dtype", &ArgsortAttrs::dtype, "DType of the output indices.",
- refl::DefaultValue(NullValue<DataType>()));
+ refl::DefaultValue(DataType::Void()));
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ArgsortAttrs", ArgsortAttrs,
BaseAttrsNode);
}; // struct ArgsortAttrs
/*! \brief Attributes used in topk operator */
-struct TopKAttrs : public AttrsNodeReflAdapter<TopKAttrs> {
+struct TopKAttrs : public BaseAttrsNode {
int k;
int axis;
bool largest;
@@ -98,7 +98,7 @@ struct TopKAttrs : public AttrsNodeReflAdapter<TopKAttrs> {
"By default, return the largest k elements.",
refl::DefaultValue(true))
.def_ro("dtype", &TopKAttrs::dtype, "Data type of the output indices.",
- refl::DefaultValue(NullValue<DataType>()));
+ refl::DefaultValue(DataType::Void()));
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.TopKAttrs", TopKAttrs,
BaseAttrsNode);
}; // struct TopKAttrs
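In this file, `ArgsortAttrs::dtype` and `TopKAttrs::dtype` now default to `DataType::Void()` instead of `NullValue<DataType>()`. A dtype is a plain value, not a nullable reference, so "not specified" has to be stored as an in-band sentinel that consumers test for explicitly. The sketch below models that with a hypothetical `MiniDType`. The names `kVoid` and `is_void()`, and the rule that falls back to int64, are illustrative assumptions, not TVM's API or the operators' actual inference logic.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical, simplified value-type dtype. Because it is not a nullable
// reference, "no dtype given" is encoded as an explicit Void sentinel.
struct MiniDType {
  enum Code : uint8_t { kInt, kFloat, kVoid };
  Code code;
  int bits;

  static MiniDType Int(int bits) { return {kInt, bits}; }
  static MiniDType Void() { return {kVoid, 0}; }
  bool is_void() const { return code == kVoid; }
};

// Hypothetical attrs whose dtype defaults to the Void sentinel.
struct MiniArgsortAttrs {
  int axis = -1;
  bool descending = false;
  MiniDType dtype = MiniDType::Void();
};

// A consumer resolves the sentinel to a concrete index type
// (int64 is an illustrative choice, not taken from the diff).
inline MiniDType ResolveIndexDType(const MiniArgsortAttrs& attrs) {
  return attrs.dtype.is_void() ? MiniDType::Int(64) : attrs.dtype;
}
```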
diff --git a/include/tvm/relax/attrs/statistical.h b/include/tvm/relax/attrs/statistical.h
index 433524116d..367869f1ab 100644
--- a/include/tvm/relax/attrs/statistical.h
+++ b/include/tvm/relax/attrs/statistical.h
@@ -30,7 +30,7 @@ namespace tvm {
namespace relax {
/*! \brief Attributes for statistical operators */
-struct StatisticalAttrs : public AttrsNodeReflAdapter<StatisticalAttrs> {
+struct StatisticalAttrs : public BaseAttrsNode {
ffi::Optional<ffi::Array<Integer>> axis;
bool keepdims;
@@ -49,7 +49,7 @@ struct StatisticalAttrs : public AttrsNodeReflAdapter<StatisticalAttrs> {
}; // struct StatisticalAttrs
/*! \brief Attributes used in scan operators like cumsum, cumprod */
-struct ScanopAttrs : public AttrsNodeReflAdapter<ScanopAttrs> {
+struct ScanopAttrs : public BaseAttrsNode {
ffi::Optional<int64_t> axis;
DataType dtype;
Bool exclusive = Bool(false);
diff --git a/include/tvm/relax/attrs/vision.h b/include/tvm/relax/attrs/vision.h
index 55ed162674..37ec77cbbf 100644
--- a/include/tvm/relax/attrs/vision.h
+++ b/include/tvm/relax/attrs/vision.h
@@ -32,8 +32,7 @@ namespace tvm {
namespace relax {
/*! \brief Attributes used in AllClassNonMaximumSuppression operator */
-struct AllClassNonMaximumSuppressionAttrs
- : public AttrsNodeReflAdapter<AllClassNonMaximumSuppressionAttrs> {
+struct AllClassNonMaximumSuppressionAttrs : public BaseAttrsNode {
ffi::String output_format;
static void RegisterReflection() {
@@ -48,7 +47,7 @@ struct AllClassNonMaximumSuppressionAttrs
}; // struct AllClassNonMaximumSuppressionAttrs
/*! \brief Attributes used in ROIAlign operator */
-struct ROIAlignAttrs : public AttrsNodeReflAdapter<ROIAlignAttrs> {
+struct ROIAlignAttrs : public BaseAttrsNode {
ffi::Array<int64_t> pooled_size;
double spatial_scale;
int sample_ratio;
@@ -73,7 +72,7 @@ struct ROIAlignAttrs : public AttrsNodeReflAdapter<ROIAlignAttrs> {
}; // struct ROIAlignAttrs
/*! \brief Attributes used in ROIPool operator */
-struct ROIPoolAttrs : public AttrsNodeReflAdapter<ROIPoolAttrs> {
+struct ROIPoolAttrs : public BaseAttrsNode {
ffi::Array<int64_t> pooled_size;
double spatial_scale;
ffi::String layout;
@@ -90,7 +89,7 @@ struct ROIPoolAttrs : public AttrsNodeReflAdapter<ROIPoolAttrs> {
}; // struct ROIPoolAttrs
/*! \brief Attributes used in GetValidCounts operator */
-struct GetValidCountsAttrs : public AttrsNodeReflAdapter<GetValidCountsAttrs> {
+struct GetValidCountsAttrs : public BaseAttrsNode {
double score_threshold;
int id_index;
int score_index;
@@ -110,7 +109,7 @@ struct GetValidCountsAttrs : public AttrsNodeReflAdapter<GetValidCountsAttrs> {
}; // struct GetValidCountsAttrs
/*! \brief Attributes used in NonMaximumSuppression operator */
-struct NonMaximumSuppressionAttrs : public AttrsNodeReflAdapter<NonMaximumSuppressionAttrs> {
+struct NonMaximumSuppressionAttrs : public BaseAttrsNode {
int max_output_size;
double iou_threshold;
bool force_suppress;
@@ -154,7 +153,7 @@ struct NonMaximumSuppressionAttrs : public AttrsNodeReflAdapter<NonMaximumSuppre
}; // struct NonMaximumSuppressionAttrs
/*! \brief Attributes for multibox_transform_loc (SSD / TFLite-style box
decode). */
-struct MultiboxTransformLocAttrs : public AttrsNodeReflAdapter<MultiboxTransformLocAttrs> {
+struct MultiboxTransformLocAttrs : public BaseAttrsNode {
bool clip;
double threshold;
ffi::Array<double> variances;
diff --git a/include/tvm/target/virtual_device.h b/include/tvm/target/virtual_device.h
index 5ff282adb6..79475262c4 100644
--- a/include/tvm/target/virtual_device.h
+++ b/include/tvm/target/virtual_device.h
@@ -169,7 +169,7 @@ constexpr int kInvalidDeviceType = -1;
* These operations are needed during device planning.
*/
-class VirtualDeviceNode : public AttrsNodeReflAdapter<VirtualDeviceNode> {
+class VirtualDeviceNode : public BaseAttrsNode {
private:
/*!
* \brief The \p DLDeviceType (represented as an int) of the virtual device.
If \p target is
diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py
index 3ce70fc545..21fd7b565c 100644
--- a/python/tvm/relax/op/manipulate.py
+++ b/python/tvm/relax/op/manipulate.py
@@ -441,7 +441,7 @@ def flip(data, axis):
The input data to the operator.
axis: int
- axis to flip on
+ The axis along which to flip over.
Returns
-------
diff --git a/python/tvm/s_tir/transform/transform.py b/python/tvm/s_tir/transform/transform.py
index e8d14171b3..af4ec493cc 100644
--- a/python/tvm/s_tir/transform/transform.py
+++ b/python/tvm/s_tir/transform/transform.py
@@ -18,7 +18,6 @@
# pylint: disable=invalid-name, unsupported-binary-operation
from ... import ffi as _ffi
-from ... import ir as _ir
from . import _ffi_api
@@ -213,7 +212,7 @@ def AnnotateIrregularLoop():
@_ffi.register_object("s_tir.transform.LoopPartitionConfig")
-class LoopPartitionConfig(_ir.Attrs):
+class LoopPartitionConfig(_ffi.Object):
"""Config for loop partition pass"""
@@ -240,7 +239,7 @@ def InjectVirtualThread():
@_ffi.register_object("s_tir.transform.InjectDoubleBufferConfig")
-class InjectDoubleBufferConfig(_ir.Attrs):
+class InjectDoubleBufferConfig(_ffi.Object):
"""Config for inject double buffer pass"""
diff --git a/python/tvm/tirx/transform/transform.py b/python/tvm/tirx/transform/transform.py
index 8082d864c1..fbf07b5f48 100644
--- a/python/tvm/tirx/transform/transform.py
+++ b/python/tvm/tirx/transform/transform.py
@@ -21,7 +21,6 @@ import enum
from collections.abc import Callable
from ... import ffi as _ffi
-from ... import ir as _ir
from . import _ffi_api
from . import function_pass as _fpass
@@ -107,7 +106,7 @@ def PointerValueTypeRewrite():
@_ffi.register_object("tirx.transform.UnrollLoopConfig")
-class UnrollLoopConfig(_ir.Attrs):
+class UnrollLoopConfig(_ffi.Object):
"""Config for unroll loop pass"""
@@ -125,7 +124,7 @@ def UnrollLoop():
@_ffi.register_object("tirx.transform.RemoveNoOpConfig")
-class RemoveNoOpConfig(_ir.Attrs):
+class RemoveNoOpConfig(_ffi.Object):
"""Config for remove no op pass"""
@@ -212,7 +211,7 @@ def CommonSubexprElim():
@_ffi.register_object("tirx.transform.SimplifyConfig")
-class SimplifyConfig(_ir.Attrs):
+class SimplifyConfig(_ffi.Object):
"""Config for simplify pass"""
@@ -429,7 +428,7 @@ def VerifyMemory():
@_ffi.register_object("s_tir.transform.HoistIfThenElseConfig")
-class HoistIfThenElseConfig(_ir.Attrs):
+class HoistIfThenElseConfig(_ffi.Object):
"""Config for hoist if then else pass"""
@@ -483,7 +482,7 @@ class HoistedLetBindings(enum.Flag):
@_ffi.register_object("s_tir.transform.HoistExpressionConfig")
-class HoistExpressionConfig(_ir.Attrs):
+class HoistExpressionConfig(_ffi.Object):
"""Config for hoist expression pass"""
diff --git a/src/ir/attrs.cc b/src/ir/attrs.cc
index cfe269e4eb..e7d9b90828 100644
--- a/src/ir/attrs.cc
+++ b/src/ir/attrs.cc
@@ -53,14 +53,6 @@ DictAttrs WithoutAttr(DictAttrs attrs, const std::string& key) {
return attrs;
}
-void DictAttrsNode::InitByPackedArgs(const ffi::PackedArgs& args, bool allow_unknown) {
- for (int i = 0; i < args.size(); i += 2) {
- ffi::String key = args[i].cast<ffi::String>();
- ffi::AnyView val = args[i + 1];
- dict.Set(key, val);
- }
-}
-
DictAttrs::DictAttrs(ffi::Map<ffi::String, Any> dict) {
ffi::ObjectPtr<DictAttrsNode> n = ffi::make_object<DictAttrsNode>();
n->dict = std::move(dict);
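The deleted `DictAttrsNode::InitByPackedArgs` read its arguments as alternating key/value pairs and inserted each pair into `dict`. Construction now goes through the `DictAttrs(ffi::Map<ffi::String, Any>)` constructor shown in the context lines. The standalone model below reproduces the old pairing loop, using `std::map` and `std::string` in place of the FFI types (an assumption made for illustration). It also adds the even-count check that the removed loop lacked: that loop read `args[i + 1]` unconditionally.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

// Models the removed loop: args are read as key0, val0, key1, val1, ...
// std::string stands in for both ffi::String keys and Any values.
inline std::map<std::string, std::string> DictFromPairs(const std::vector<std::string>& args) {
  // Added for the sketch: the original loop had no such guard.
  if (args.size() % 2 != 0) {
    throw std::invalid_argument("expected an even number of key/value arguments");
  }
  std::map<std::string, std::string> dict;
  for (std::size_t i = 0; i < args.size(); i += 2) {
    // A repeated key overwrites the earlier value, as Map::Set does.
    dict[args[i]] = args[i + 1];
  }
  return dict;
}
```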
diff --git a/src/relax/backend/contrib/clml/codegen.cc b/src/relax/backend/contrib/clml/codegen.cc
index eaa57f8315..dd71e8a68a 100644
--- a/src/relax/backend/contrib/clml/codegen.cc
+++ b/src/relax/backend/contrib/clml/codegen.cc
@@ -41,7 +41,7 @@ namespace relax {
namespace contrib {
/*! \brief Attributes to store the compiler options for OpenCLML. */
-struct OpenCLMLCompilerConfigNode : public AttrsNodeReflAdapter<OpenCLMLCompilerConfigNode> {
+struct OpenCLMLCompilerConfigNode : public ffi::Object {
Integer clml_version;
static void RegisterReflection() {
@@ -51,12 +51,12 @@ struct OpenCLMLCompilerConfigNode : public AttrsNodeReflAdapter<OpenCLMLCompiler
"OpenCLML version as (major, minor, patch).",
refl::DefaultValue(Integer(3)));
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.ext.attrs.OpenCLMLCompilerConfig",
- OpenCLMLCompilerConfigNode, BaseAttrsNode);
+ OpenCLMLCompilerConfigNode, ffi::Object);
};
-class OpenCLMLCompilerConfig : public Attrs {
+class OpenCLMLCompilerConfig : public ffi::ObjectRef {
public:
- TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(OpenCLMLCompilerConfig, Attrs,
+ TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(OpenCLMLCompilerConfig, ffi::ObjectRef,
OpenCLMLCompilerConfigNode);
};
diff --git a/src/relax/backend/contrib/tensorrt/codegen.cc b/src/relax/backend/contrib/tensorrt/codegen.cc
index 2be214ed94..38b2dc405f 100644
--- a/src/relax/backend/contrib/tensorrt/codegen.cc
+++ b/src/relax/backend/contrib/tensorrt/codegen.cc
@@ -46,7 +46,7 @@ namespace relax {
namespace contrib {
/*! \brief Attributes to store the compiler options for TensorRT. */
-struct TensorRTCompilerConfigNode : public AttrsNodeReflAdapter<TensorRTCompilerConfigNode> {
+struct TensorRTCompilerConfigNode : public ffi::Object {
ffi::Array<Integer> tensorrt_version;
bool use_implicit_batch;
size_t max_workspace_size;
@@ -72,12 +72,12 @@ struct TensorRTCompilerConfigNode : public AttrsNodeReflAdapter<TensorRTCompiler
refl::DefaultValue(false));
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.ext.attrs.TensorRTCompilerConfig",
- TensorRTCompilerConfigNode, BaseAttrsNode);
+ TensorRTCompilerConfigNode, ffi::Object);
};
-class TensorRTCompilerConfig : public Attrs {
+class TensorRTCompilerConfig : public ffi::ObjectRef {
public:
- TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(TensorRTCompilerConfig, Attrs,
+ TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(TensorRTCompilerConfig, ffi::ObjectRef,
TensorRTCompilerConfigNode);
};
diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index 461faf3fba..f6fc45deaa 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -2025,9 +2025,9 @@ TVM_REGISTER_OP("relax.tile")
/* relax.flip */
-Expr flip(Expr data, Integer axis) {
+Expr flip(Expr data, int64_t axis) {
auto attrs = ffi::make_object<FlipAttrs>();
- attrs->axis = std::move(axis);
+ attrs->axis = axis;
static const Op& op = Op::Get("relax.flip");
return Call(op, {std::move(data)}, Attrs{attrs}, {});
}
@@ -2043,7 +2043,7 @@ StructInfo InferStructInfoFlip(const Call& call, const BlockBuilder& ctx) {
}
TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
const auto* attrs = call->attrs.as<FlipAttrs>();
- int axis = attrs->axis.IntValue();
+ int axis = static_cast<int>(attrs->axis);
if (!data_sinfo->IsUnknownNdim()) {
int ndim = data_sinfo->ndim;
if (axis < -ndim || axis >= ndim) {
@@ -2073,7 +2073,7 @@ InferLayoutOutput InferLayoutFlip(
existing_layout = LayoutDecision(InitialLayout(ndim));
}
- int axis = attrs->axis.IntValue();
+ int axis = static_cast<int>(attrs->axis);
if (axis < 0) {
axis += ndim;
}
@@ -2082,7 +2082,7 @@ InferLayoutOutput InferLayoutFlip(
TVM_FFI_ICHECK_GE(new_axis, 0) << "Failed to find transformed axis";
ffi::ObjectPtr<FlipAttrs> new_attrs = ffi::make_object<FlipAttrs>(*attrs);
- new_attrs->axis = Integer(new_axis);
+ new_attrs->axis = static_cast<int64_t>(new_axis);
return InferLayoutOutput({existing_layout}, {existing_layout},
Attrs(new_attrs));
}
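After this change, `InferStructInfoFlip` and `InferLayoutFlip` both read the flip axis as a plain integer. The first rejects any axis outside `[-ndim, ndim)`. The second maps a negative axis into `[0, ndim)` by adding `ndim`. The minimal sketch below combines those two steps. `NormalizeFlipAxis` is a hypothetical helper, and it throws an exception in the out-of-range case, whose handling in the real code is not shown in this diff.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>

// Hypothetical helper combining the two steps from the diff:
//   1. reject axis outside [-ndim, ndim)    (range check in InferStructInfoFlip)
//   2. map a negative axis into [0, ndim)   (`axis += ndim` in InferLayoutFlip)
inline int NormalizeFlipAxis(int64_t axis, int64_t ndim) {
  // Check in 64-bit before narrowing so a huge value cannot wrap into range.
  if (axis < -ndim || axis >= ndim) {
    throw std::out_of_range("flip axis out of range");
  }
  if (axis < 0) {
    axis += ndim;
  }
  return static_cast<int>(axis);
}
```

The diff itself narrows with `static_cast<int>` before the range check. Doing the check first in the sketch is a choice made for the example.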
diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h
index 260d27f1ef..a6efffff46 100644
--- a/src/relax/op/tensor/manipulate.h
+++ b/src/relax/op/tensor/manipulate.h
@@ -179,7 +179,7 @@ Expr tile(Expr data, ffi::Array<Integer> repeats);
* \param axis The axis to flip on
* \return The computed result.
*/
-Expr flip(Expr data, Integer axis);
+Expr flip(Expr data, int64_t axis);
/*!
* \brief Gather elements from a tensor using indices.
diff --git a/src/s_tir/schedule/concrete_schedule.cc b/src/s_tir/schedule/concrete_schedule.cc
index 21f5454040..94402c44d7 100644
--- a/src/s_tir/schedule/concrete_schedule.cc
+++ b/src/s_tir/schedule/concrete_schedule.cc
@@ -35,7 +35,7 @@ Schedule Schedule::Concrete(IRModule mod, LinearCongruentialEngine::TRandState s
n->symbol_table_ = {};
n->analyzer_ = std::make_unique<arith::Analyzer>();
n->Seed(seed);
- GlobalVar gv = NullValue<GlobalVar>();
+ GlobalVar gv;
if (FindEntryFunc(mod, &gv) != nullptr) {
n->func_working_on_ = gv;
} else {
@@ -316,7 +316,7 @@ SBlockRV ConcreteScheduleNode::GetSBlock(const ffi::String& name,
IRModule mod_;
ffi::Array<SBlock> blocks_;
};
- GlobalVar gv = NullValue<GlobalVar>();
+ GlobalVar gv;
if (func_name.has_value()) {
gv = state_->mod->GetGlobalVar(func_name.value());
} else if (func_working_on_.has_value()) {
diff --git a/src/s_tir/schedule/traced_schedule.cc b/src/s_tir/schedule/traced_schedule.cc
index e12cdd69de..6357e1ae19 100644
--- a/src/s_tir/schedule/traced_schedule.cc
+++ b/src/s_tir/schedule/traced_schedule.cc
@@ -31,7 +31,7 @@ Schedule Schedule::Traced(IRModule mod, LinearCongruentialEngine::TRandState see
n->analyzer_ = std::make_unique<arith::Analyzer>();
n->trace_ = Trace();
n->Seed(seed);
- GlobalVar gv = NullValue<GlobalVar>();
+ GlobalVar gv;
if (FindEntryFunc(mod, &gv) != nullptr) {
n->func_working_on_ = gv;
} else {
@@ -118,7 +118,7 @@ LoopRV TracedScheduleNode::SampleComputeLocation(const SBlockRV& block_rv,
SBlockRV TracedScheduleNode::GetSBlock(const ffi::String& name,
const ffi::Optional<ffi::String>&
func_name) {
- GlobalVar gv = NullValue<GlobalVar>();
+ GlobalVar gv;
if (func_name.has_value()) {
gv = state_->mod->GetGlobalVar(func_name.value());
} else if (func_working_on_.defined()) {
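In both schedule constructors and in `GetSBlock`, `GlobalVar gv = NullValue<GlobalVar>();` becomes plain `GlobalVar gv;`. That works because `GlobalVar` is a nullable object reference whose default constructor yields an undefined (null) handle, which is then filled in through an out-parameter or checked with `defined()`. The sketch below models the pattern with a hypothetical `NullableRef` built on `std::shared_ptr`. `FindEntry` is an illustrative lookup, not `FindEntryFunc`.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

struct VarNode {
  std::string name;
};

// Hypothetical nullable reference: default construction yields a null handle,
// so no separate NullValue<T>() helper is needed to spell "empty".
class NullableRef {
 public:
  NullableRef() = default;
  explicit NullableRef(std::shared_ptr<VarNode> n) : data_(std::move(n)) {}
  bool defined() const { return data_ != nullptr; }
  const VarNode* operator->() const { return data_.get(); }

 private:
  std::shared_ptr<VarNode> data_;
};

// Illustrative out-parameter lookup: leaves *out untouched (null) on a miss.
inline bool FindEntry(const std::vector<std::string>& names, const std::string& want,
                      NullableRef* out) {
  for (const auto& n : names) {
    if (n == want) {
      *out = NullableRef(std::make_shared<VarNode>(VarNode{n}));
      return true;
    }
  }
  return false;
}
```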
diff --git a/src/s_tir/transform/hoist_expression.cc b/src/s_tir/transform/hoist_expression.cc
index ac3987b6a0..dbe389e84a 100644
--- a/src/s_tir/transform/hoist_expression.cc
+++ b/src/s_tir/transform/hoist_expression.cc
@@ -58,7 +58,7 @@ enum class HoistedLetBindings : int {
kLetExpr = (1 << 2),
};
-struct HoistExpressionConfigNode : public AttrsNodeReflAdapter<HoistExpressionConfigNode> {
+struct HoistExpressionConfigNode : public ffi::Object {
int hoisted_conditionals;
int hoisted_let_bindings;
@@ -87,7 +87,7 @@ struct HoistExpressionConfigNode : public AttrsNodeReflAdapter<HoistExpressionCo
HoistExpressionConfigNode, ffi::Object);
};
-class HoistExpressionConfig : public Attrs {
+class HoistExpressionConfig : public ffi::ObjectRef {
public:
HoistExpressionConfig(int hoisted_conditionals, int hoisted_let_bindings) {
auto node = ffi::make_object<HoistExpressionConfigNode>();
@@ -95,7 +95,7 @@ class HoistExpressionConfig : public Attrs {
node->hoisted_let_bindings = hoisted_let_bindings;
data_ = std::move(node);
}
- TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(HoistExpressionConfig, Attrs,
+ TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(HoistExpressionConfig, ffi::ObjectRef,
HoistExpressionConfigNode);
};
@@ -103,7 +103,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { HoistExpressionConfigNode::RegisterReflection(); }
TVM_REGISTER_PASS_CONFIG_OPTION("s_tir.HoistExpression",
HoistExpressionConfig);
-struct HoistIfThenElseConfigNode : public AttrsNodeReflAdapter<HoistIfThenElseConfigNode> {
+struct HoistIfThenElseConfigNode : public ffi::Object {
bool support_block_scope_hoisting;
static void RegisterReflection() {
@@ -116,9 +116,9 @@ struct HoistIfThenElseConfigNode : public AttrsNodeReflAdapter<HoistIfThenElseCo
HoistIfThenElseConfigNode, ffi::Object);
};
-class HoistIfThenElseConfig : public Attrs {
+class HoistIfThenElseConfig : public ffi::ObjectRef {
public:
- TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(HoistIfThenElseConfig, Attrs,
+ TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(HoistIfThenElseConfig, ffi::ObjectRef,
HoistIfThenElseConfigNode);
};
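`HoistExpressionConfigNode` stores `hoisted_conditionals` and `hoisted_let_bindings` as plain `int` bitmasks. Their bits come from enums such as `HoistedLetBindings`, whose `kLetExpr = (1 << 2)` member is visible in the context above. The sketch below shows how such flags combine and how to test them. Only the value of `kLetExpr` comes from the diff. `kFlagA`, `kFlagB`, `Combine`, and `IsEnabled` are hypothetical.

```cpp
#include <cassert>

// kLetExpr matches the value shown in the diff; kFlagA and kFlagB are
// hypothetical placeholders for the enum's other members.
enum class LetFlags : int {
  kNone = 0,
  kFlagA = (1 << 0),
  kFlagB = (1 << 1),
  kLetExpr = (1 << 2),
};

// The config keeps the combined flags in an int, so flags are OR-ed and
// tested through their integer representation.
inline int Combine(LetFlags a, LetFlags b) {
  return static_cast<int>(a) | static_cast<int>(b);
}

inline bool IsEnabled(int config_bits, LetFlags flag) {
  return (config_bits & static_cast<int>(flag)) != 0;
}
```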
diff --git a/src/s_tir/transform/inject_double_buffer.cc b/src/s_tir/transform/inject_double_buffer.cc
index b476f0dca6..0c934ddbcd 100644
--- a/src/s_tir/transform/inject_double_buffer.cc
+++ b/src/s_tir/transform/inject_double_buffer.cc
@@ -36,7 +36,7 @@ namespace tvm {
namespace s_tir {
using namespace tvm::tirx;
-struct InjectDoubleBufferConfigNode : public AttrsNodeReflAdapter<InjectDoubleBufferConfigNode> {
+struct InjectDoubleBufferConfigNode : public ffi::Object {
int split_loop;
static void RegisterReflection() {
@@ -46,12 +46,12 @@ struct InjectDoubleBufferConfigNode : public AttrsNodeReflAdapter<InjectDoubleBu
refl::DefaultValue(1));
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.transform.InjectDoubleBufferConfig",
- InjectDoubleBufferConfigNode, BaseAttrsNode);
+ InjectDoubleBufferConfigNode, ffi::Object);
};
-class InjectDoubleBufferConfig : public Attrs {
+class InjectDoubleBufferConfig : public ffi::ObjectRef {
public:
- TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(InjectDoubleBufferConfig, Attrs,
+ TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(InjectDoubleBufferConfig, ffi::ObjectRef,
InjectDoubleBufferConfigNode);
};
diff --git a/src/s_tir/transform/loop_partition.cc b/src/s_tir/transform/loop_partition.cc
index e5f03c29f5..bf2dca776c 100644
--- a/src/s_tir/transform/loop_partition.cc
+++ b/src/s_tir/transform/loop_partition.cc
@@ -45,7 +45,7 @@ namespace tvm {
namespace s_tir {
using namespace tvm::tirx;
-struct LoopPartitionConfigNode : public AttrsNodeReflAdapter<LoopPartitionConfigNode> {
+struct LoopPartitionConfigNode : public ffi::Object {
bool partition_const_loop;
bool no_unroll_loop_with_extent_one;
bool unroll_loop_with_partition_hint_no_interval;
@@ -64,14 +64,14 @@ struct LoopPartitionConfigNode : public AttrsNodeReflAdapter<LoopPartitionConfig
refl::DefaultValue(false));
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.transform.LoopPartitionConfig",
LoopPartitionConfigNode,
- BaseAttrsNode);
+ ffi::Object);
};
TVM_FFI_STATIC_INIT_BLOCK() { LoopPartitionConfigNode::RegisterReflection(); }
-class LoopPartitionConfig : public Attrs {
+class LoopPartitionConfig : public ffi::ObjectRef {
public:
- TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(LoopPartitionConfig, Attrs,
+ TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(LoopPartitionConfig, ffi::ObjectRef,
LoopPartitionConfigNode);
};
diff --git a/src/s_tir/transform/lower_cross_thread_reduction.cc b/src/s_tir/transform/lower_cross_thread_reduction.cc
index 5bb0f6b767..ba7dd69625 100644
--- a/src/s_tir/transform/lower_cross_thread_reduction.cc
+++ b/src/s_tir/transform/lower_cross_thread_reduction.cc
@@ -881,7 +881,7 @@ class CrossThreadReductionTransformer : public StmtMutator {
/*kind=*/ForKind::kThreadBinding, //
/*body=*/body, //
/*thread_binding=*/
- IterVar(NullValue<Range>(), Var("", loop_vars[i]->dtype), IterVarType::kThreadIndex,
+ IterVar(Range(), Var("", loop_vars[i]->dtype), IterVarType::kThreadIndex,
"threadIdx." + dim_index),
/*annotations=*/{},
/*step=*/std::nullopt);
diff --git a/src/s_tir/transform/storage_access.h b/src/s_tir/transform/storage_access.h
index 2aa3850774..d85dc5a3c3 100644
--- a/src/s_tir/transform/storage_access.h
+++ b/src/s_tir/transform/storage_access.h
@@ -59,7 +59,7 @@ class StorageAccessVisitor : public StmtExprVisitor {
/*! \brief The thread index that access this entry */
ffi::Array<IterVar> threads;
/*! \brief The buffer variable, if any */
- Var buffer = NullValue<Var>();
+ Var buffer = Var(ffi::ObjectPtr<VarNode>(nullptr));
/*! \brief The access data type */
DataType dtype;
/*! \brief The touched access range
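In `StorageAccessVisitor`, the null default for `buffer` is written `Var(ffi::ObjectPtr<VarNode>(nullptr))` rather than `Var()`. The likely reason, stated here as an assumption, is that `Var`'s constructor has default arguments that create a fresh, named variable, so a default-constructed `Var` would not be null. The hypothetical `MiniVar` below reproduces that shape. Its default constructor allocates a node, and a separate explicit pointer constructor is the only way to obtain a null handle.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>

struct MiniVarNode {
  std::string name_hint;
};

class MiniVar {
 public:
  // All-defaulted arguments: MiniVar() creates a NEW variable named "v".
  explicit MiniVar(std::string name_hint = "v")
      : data_(std::make_shared<MiniVarNode>(MiniVarNode{std::move(name_hint)})) {}
  // Explicit pointer constructor: the only way to get a null handle.
  explicit MiniVar(std::shared_ptr<MiniVarNode> n) : data_(std::move(n)) {}

  bool defined() const { return data_ != nullptr; }

 private:
  std::shared_ptr<MiniVarNode> data_;
};

struct AccessEntry {
  // Null by default, mirroring `Var buffer = Var(ffi::ObjectPtr<VarNode>(nullptr));`
  MiniVar buffer = MiniVar(std::shared_ptr<MiniVarNode>(nullptr));
};
```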
diff --git a/src/s_tir/transform/unify_thread_binding.cc b/src/s_tir/transform/unify_thread_binding.cc
index 3ee465223a..85333b6efc 100644
--- a/src/s_tir/transform/unify_thread_binding.cc
+++ b/src/s_tir/transform/unify_thread_binding.cc
@@ -159,8 +159,7 @@ class ThreadBindingUnifier : public StmtExprMutator {
// necessary for unit tests.
result = For(thread_binding->var, thread_binding->dom->min,
thread_binding->dom->extent,
ForKind::kThreadBinding, result,
- IterVar(NullValue<Range>(), Var(""), IterVarType::kThreadIndex,
- thread_binding->thread_tag),
+ IterVar(Range(), Var(""), IterVarType::kThreadIndex, thread_binding->thread_tag),
{}, std::nullopt);
launch_threads_.pop_back();
}
diff --git a/src/tirx/analysis/stmt_finding.cc b/src/tirx/analysis/stmt_finding.cc
index 0ba6146213..6dc3d07b4f 100644
--- a/src/tirx/analysis/stmt_finding.cc
+++ b/src/tirx/analysis/stmt_finding.cc
@@ -24,7 +24,7 @@ namespace tvm {
namespace tirx {
const PrimFuncNode* FindEntryFunc(const IRModule& mod, GlobalVar*
result_g_var) {
- GlobalVar result = NullValue<GlobalVar>();
+ GlobalVar result;
// Priority 1: PrimFunc marked as `tirx::attr::kIsEntryFunc`
int num_prim_func = 0;
const tirx::PrimFuncNode* main_func = nullptr;
diff --git a/src/tirx/script/builder/frame.cc b/src/tirx/script/builder/frame.cc
index 5e971d7361..e57b794cf3 100644
--- a/src/tirx/script/builder/frame.cc
+++ b/src/tirx/script/builder/frame.cc
@@ -145,7 +145,7 @@ void PrimFuncFrameNode::ExitWithScope() {
/*body=*/body,
/*ret_type=*/ret_type.value_or(TupleType::Empty()),
/*buffer_map=*/effective_buffer_map,
- /*attrs=*/attrs.defined() ? DictAttrs(attrs) : NullValue<DictAttrs>(),
+ /*attrs=*/attrs.defined() ? DictAttrs(attrs) : DictAttrs(),
/*span=*/tvm::Span());
func = tvm::tirx::ScriptComplete(func, effective_root_alloc_buffers, s_tir);
IRBuilder builder = IRBuilder::Current();
diff --git a/src/tirx/transform/remove_no_op.cc b/src/tirx/transform/remove_no_op.cc
index fcc7519334..4bdb5c083c 100644
--- a/src/tirx/transform/remove_no_op.cc
+++ b/src/tirx/transform/remove_no_op.cc
@@ -44,7 +44,7 @@
namespace tvm {
namespace tirx {
-struct RemoveNoOpConfigNode : public AttrsNodeReflAdapter<RemoveNoOpConfigNode> {
+struct RemoveNoOpConfigNode : public ffi::Object {
bool use_dataflow_analysis;
int64_t max_simplification_steps;
bool ignore_profiler_call;
@@ -65,12 +65,13 @@ struct RemoveNoOpConfigNode : public AttrsNodeReflAdapter<RemoveNoOpConfigNode>
"If true, profiler calls are rendered as no-ops.",
refl::DefaultValue(false));
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.transform.RemoveNoOpConfig",
RemoveNoOpConfigNode,
- BaseAttrsNode);
+ ffi::Object);
};
-class RemoveNoOpConfig : public Attrs {
+class RemoveNoOpConfig : public ffi::ObjectRef {
public:
- TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(RemoveNoOpConfig, Attrs, RemoveNoOpConfigNode);
+ TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(RemoveNoOpConfig, ffi::ObjectRef,
+ RemoveNoOpConfigNode);
};
TVM_FFI_STATIC_INIT_BLOCK() { RemoveNoOpConfigNode::RegisterReflection(); }
diff --git a/src/tirx/transform/simplify.cc b/src/tirx/transform/simplify.cc
index f193fb502d..bf80ad00a4 100644
--- a/src/tirx/transform/simplify.cc
+++ b/src/tirx/transform/simplify.cc
@@ -44,7 +44,7 @@ namespace arith {
using namespace tirx;
-struct SimplifyConfigNode : public AttrsNodeReflAdapter<SimplifyConfigNode> {
+struct SimplifyConfigNode : public ffi::Object {
bool transitively_prove_inequalities;
bool propagate_knowns_to_prove_conditional;
bool propagate_knowns_to_simplify_expressions;
@@ -78,7 +78,7 @@ struct SimplifyConfigNode : public AttrsNodeReflAdapter<SimplifyConfigNode> {
refl::DefaultValue(false));
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.transform.SimplifyConfig",
SimplifyConfigNode,
- BaseAttrsNode);
+ ffi::Object);
RewriteSimplifier::Extension GetEnabledExtensions() const {
RewriteSimplifier::Extension flags = RewriteSimplifier::kNone;
@@ -97,11 +97,15 @@ struct SimplifyConfigNode : public AttrsNodeReflAdapter<SimplifyConfigNode> {
}
};
-class SimplifyConfig : public Attrs {
+class SimplifyConfig : public ffi::ObjectRef {
public:
- TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(SimplifyConfig, Attrs, SimplifyConfigNode);
+ TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(SimplifyConfig, ffi::ObjectRef, SimplifyConfigNode);
};
+static SimplifyConfig MakeDefaultSimplifyConfig() {
+ return AttrsWithDefaultValues<SimplifyConfig>();
+}
+
TVM_FFI_STATIC_INIT_BLOCK() { SimplifyConfigNode::RegisterReflection(); }
TVM_REGISTER_PASS_CONFIG_OPTION("tirx.Simplify", SimplifyConfig);
@@ -110,7 +114,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
public:
static PrimFunc Apply(PrimFunc func, Analyzer* analyzer,
ffi::Optional<SimplifyConfig> config_opt =
std::nullopt) {
- auto config = config_opt.value_or(AttrsWithDefaultValues<arith::SimplifyConfig>());
+ auto config = config_opt.value_or(MakeDefaultSimplifyConfig());
analyzer->rewrite_simplify.SetEnabledExtensions(config->GetEnabledExtensions());
std::optional<ControlFlowGraph> touch_pattern = std::nullopt;
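`StmtSimplifier::Apply` falls back to a default config through `config_opt.value_or(MakeDefaultSimplifyConfig())`. Because `value_or` is an ordinary function call, its argument is evaluated before the call. The default config is therefore constructed even when a config was supplied. Whether that cost matters here is not something the diff shows. The sketch below makes the difference observable with `std::optional`, a hypothetical `Config`, and a construction counter, and compares the eager form against an explicit branch that builds the default only when needed.

```cpp
#include <cassert>
#include <optional>

struct Config {
  bool propagate_knowns = false;
};

// Hypothetical counter so the example can observe how often the default is built.
inline int& DefaultBuilds() {
  static int count = 0;
  return count;
}

inline Config MakeDefaultConfig() {
  ++DefaultBuilds();
  return Config{};
}

// Eager: the fallback is built even when `opt` already holds a value.
inline Config ResolveEager(const std::optional<Config>& opt) {
  return opt.value_or(MakeDefaultConfig());
}

// Lazy: the fallback is built only when `opt` is empty.
inline Config ResolveLazy(const std::optional<Config>& opt) {
  return opt.has_value() ? *opt : MakeDefaultConfig();
}
```

For a small config object the eager form is usually fine and reads more compactly. The branch only pays off if building the default becomes expensive.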
diff --git a/src/tirx/transform/unroll_loop.cc b/src/tirx/transform/unroll_loop.cc
index 3aea9ddd04..faf1ec2d67 100644
--- a/src/tirx/transform/unroll_loop.cc
+++ b/src/tirx/transform/unroll_loop.cc
@@ -39,7 +39,7 @@
namespace tvm {
namespace tirx {
-struct UnrollLoopConfigNode : public AttrsNodeReflAdapter<UnrollLoopConfigNode> {
+struct UnrollLoopConfigNode : public ffi::Object {
int auto_max_step;
int auto_max_depth;
int auto_max_extent;
@@ -64,12 +64,13 @@ struct UnrollLoopConfigNode : public AttrsNodeReflAdapter<UnrollLoopConfigNode>
"Whether to always unroll local access",
refl::DefaultValue(false));
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.transform.UnrollLoopConfig",
UnrollLoopConfigNode,
- BaseAttrsNode);
+ ffi::Object);
};
-class UnrollLoopConfig : public Attrs {
+class UnrollLoopConfig : public ffi::ObjectRef {
public:
- TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(UnrollLoopConfig, Attrs, UnrollLoopConfigNode);
+ TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(UnrollLoopConfig, ffi::ObjectRef,
+ UnrollLoopConfigNode);
};
TVM_FFI_STATIC_INIT_BLOCK() { UnrollLoopConfigNode::RegisterReflection(); }
diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc
index 0743c3db68..e7a1715cc7 100644
--- a/tests/cpp/ir_functor_test.cc
+++ b/tests/cpp/ir_functor_test.cc
@@ -338,7 +338,7 @@ TEST(IRF, Substitute) {
/*dtype=*/DataType::Float(32),
/*shape=*/{n},
/*strides=*/{},
- /*elem_offset=*/NullValue<PrimExpr>(),
+ /*elem_offset=*/PrimExpr(),
/*name=*/"buf",
/*data_alignment=*/1,
/*offset_factor=*/1,