This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch alloc-const
in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 5f8eb9574ccb89a78bc7033a76368886a94c1146
Author: tqchen <[email protected]>
AuthorDate: Wed Feb 11 18:04:52 2026 -0500

    [REFACTOR][TIR] Phase out AllocateConst

    This PR phases out the AllocateConst node in TIR. The node was
    originally introduced to embed weights directly into an allocation.
    However, its presence couples data allocation with weight placement,
    which is undesirable, especially when weights get large. A better
    approach is to attach an extra annotation to the allocation and store
    the weights separately, either outside the module or as a
    module/function attribute. Phasing out this node therefore lets us
    simplify logic throughout the codebase.
---
 include/tvm/ir/module.h | 10 --
 include/tvm/script/ir_builder/tir/frame.h | 53 --------
 include/tvm/script/ir_builder/tir/ir.h | 13 --
 include/tvm/tir/stmt.h | 75 -----------
 include/tvm/tir/stmt_functor.h | 4 -
 include/tvm/tir/transform.h | 9 --
 python/tvm/script/ir_builder/tir/frame.py | 7 --
 python/tvm/script/ir_builder/tir/ir.py | 38 ------
 python/tvm/tir/__init__.py | 1 -
 python/tvm/tir/functor.py | 37 ------
 python/tvm/tir/stmt.py | 62 ---------
 python/tvm/tir/transform/transform.py | 11 --
 src/s_tir/schedule/primitive/cache_read_write.cc | 13 +-
 src/s_tir/schedule/transform.cc | 11 +-
 .../plan_update_buffer_allocation_location.cc | 3 +-
 src/script/ir_builder/tir/frame.cc | 6 -
 src/script/ir_builder/tir/ir.cc | 13 --
 src/script/printer/tir/stmt.cc | 82 ------------
 src/target/llvm/codegen_llvm.cc | 12 --
 src/target/llvm/codegen_llvm.h | 1 -
 src/target/source/codegen_c.cc | 32 -----
 src/target/source/codegen_c.h | 1 -
 src/target/source/codegen_webgpu.cc | 4 -
 src/target/source/codegen_webgpu.h | 1 -
 src/te/operation/create_primfunc.cc | 25 +---
 src/te/operation/create_primfunc.h | 18 ---
 src/tir/analysis/estimate_flops.cc | 1 -
 src/tir/analysis/var_use_def_analysis.cc | 5 -
 src/tir/analysis/var_use_def_analysis.h | 2 -
 src/tir/ir/py_functor.cc | 12 --
 src/tir/ir/stmt.cc | 65 ----------
 src/tir/ir/stmt_functor.cc | 19 ---
 src/tir/ir/tir_visitor_with_path.cc | 6 -
 src/tir/ir/tir_visitor_with_path.h | 1 -
 src/tir/transform/bind_params.cc | 138 --------------------
 src/tir/transform/extract_constants.cc | 115 -----------------
 src/tir/transform/ir_utils.cc | 5 -
 src/tir/transform/ir_utils.h | 10 --
 .../remove_weight_layout_rewrite_block.cc | 140 ---------------------
 src/tir/transform/renew_defs.cc | 1 -
 src/tir/transform/storage_rewrite.cc | 49 +-------
 tests/python/ir/test_node_reflection.py | 14 ---
 .../relax/test_meta_schedule_relax_integration.py | 136 +-------------------
 .../schedule/test_tir_schedule_cache_read_write.py | 115 -----------------
 .../s_tir/schedule/test_tir_schedule_compute_at.py | 99 ---------------
 .../test_tir_transform_extract_constants.py | 68 ----------
 tests/python/tvmscript/test_tvmscript_complete.py | 26 ----
 .../tvmscript/test_tvmscript_ir_builder_tir.py | 24 ----
 tests/python/tvmscript/test_tvmscript_roundtrip.py | 69 ----------
 49 files changed, 17 insertions(+), 1645 deletions(-)

diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index 2ccddc85dc..17369dbab6 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -316,15 +316,6 @@ namespace attr { */ constexpr const char* kModuleName = "mod_name"; -/* - * \brief All the runtime::Tensors extracted from PrimFunc tir::AllocateConst nodes. The - * node will record the index into this array. See also kConstNameToConstant below, which is - * the analog for Realy Functions. 
- * - * Type: ffi::Array<runtime::Tensor> - */ -constexpr const char* kConstants = "constants"; - /*! * \brief All the runtime::Modules accumulated during compilation by external codegen. These * modules must be either directly linked or captured in the final compilation artifact. @@ -366,7 +357,6 @@ constexpr const char* kSystemLibPrefix = "system_lib_prefix"; * \brief All the named runtime::Tensors accumulated during compilation by external codegen. * Generally the associated runtime::Module will indicate it requires bindings for these names, * and during module initialization these bindings will be recovered from a ConstLoaderModule. - * See also kConstantsArray above, which is the analog for PrimFuncs. * * Type: ffi::Map<ffi::String, runtime::Tensor> */ diff --git a/include/tvm/script/ir_builder/tir/frame.h b/include/tvm/script/ir_builder/tir/frame.h index 1255a67335..224326c0c8 100644 --- a/include/tvm/script/ir_builder/tir/frame.h +++ b/include/tvm/script/ir_builder/tir/frame.h @@ -526,59 +526,6 @@ class AllocateFrame : public TIRFrame { TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(AllocateFrame, TIRFrame, AllocateFrameNode); }; -/*! - * \brief A frame represents the allocate constant. - * - * \sa AllocateConstFrame - */ -class AllocateConstFrameNode : public TIRFrameNode { - public: - /*! \brief The data type of the buffer. */ - DataType dtype; - /*! \brief The extents of the allocate. */ - ffi::Array<PrimExpr> extents; - /*! \brief The data associated with the constant. */ - tvm::runtime::Tensor data; - /*! \brief The buffer var */ - tvm::tir::Var buffer_var; - /*! \brief Additional annotations about the allocation. */ - ffi::Map<ffi::String, Any> annotations; - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef<AllocateConstFrameNode>() - .def_ro("dtype", &AllocateConstFrameNode::dtype) - .def_ro("extents", &AllocateConstFrameNode::extents) - .def_ro("data", &AllocateConstFrameNode::data) - .def_ro("buffer_var", &AllocateConstFrameNode::buffer_var) - .def_ro("annotations", &AllocateConstFrameNode::annotations); - } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tir.AllocateConstFrame", - AllocateConstFrameNode, TIRFrameNode); - - public: - /*! - * \brief The method called when exiting RAII scope. - * \sa tvm::support::With - */ - void ExitWithScope() final; -}; - -/*! - * \brief Managed reference to AllocateConstFrameNode. - * - * \sa AllocateConstFrameNode - */ -class AllocateConstFrame : public TIRFrame { - public: - explicit AllocateConstFrame(ObjectPtr<AllocateConstFrameNode> data) - : TIRFrame(ffi::UnsafeInit{}) { - TVM_FFI_ICHECK(data != nullptr); - data_ = std::move(data); - } - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(AllocateConstFrame, TIRFrame, - AllocateConstFrameNode); -}; /*! * \brief A frame that represents attribute node. * diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h index 788eb9615c..62aad84645 100644 --- a/include/tvm/script/ir_builder/tir/ir.h +++ b/include/tvm/script/ir_builder/tir/ir.h @@ -28,7 +28,6 @@ namespace script { namespace ir_builder { namespace tir { -using tvm::runtime::Tensor; using tvm::tir::Buffer; using tvm::tir::Var; @@ -327,18 +326,6 @@ AllocateFrame Allocate(ffi::Array<PrimExpr> extents, DataType dtype, ffi::String ffi::Optional<PrimExpr> condition = std::nullopt, ffi::Optional<ffi::Map<ffi::String, Any>> annotations = std::nullopt); -/*! - * \brief The allocate constant node. - * \param data The data associated with the constant. 
- * \param dtype The data type of the buffer. - * \param extents The extents of the allocate. - * \param annotations Additional annotation hints. - * \return The created AllocateConstFrame. - */ -AllocateConstFrame AllocateConst( - Tensor data, DataType dtype, ffi::Array<PrimExpr> extents, - ffi::Optional<ffi::Map<ffi::String, Any>> annotations = std::nullopt); - /*! * \brief Create an attribute. * \param node The node to annotate the attribute. diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index b64dc4beec..86703346cb 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -343,81 +343,6 @@ class Allocate : public Stmt { TVM_DEFINE_OBJECT_REF_COW_METHOD(AllocateNode); }; -/*! - * \brief Allocate a buffer that can be used in body. - */ -class AllocateConstNode : public StmtNode { - public: - /*! \brief The buffer variable. */ - Var buffer_var; - /*! \brief The optional data associated to the constant. - */ - ffi::Optional<runtime::Tensor> data; - /*! - * \brief If the PrimFunc containing the Stmt is added to IRModule, this is an optional index - * to indicate the index within "constants" attribute, that is a ffi::Array<Tensor> of IRModule. - */ - ffi::Optional<Integer> irmod_storage_idx; - /*! \brief The type of the buffer. */ - DataType dtype; - /*! \brief The extents of the buffer. */ - ffi::Array<PrimExpr> extents; - /*! \brief The body to be executed. */ - Stmt body; - /*! - * \brief Additional annotations about the allocation. - * - * These annotations can be used as auxiliary hint - * to future transformations. - */ - ffi::Map<ffi::String, ffi::Any> annotations; - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef<AllocateConstNode>() - .def_ro("buffer_var", &AllocateConstNode::buffer_var, refl::AttachFieldFlag::SEqHashDef()) - .def_ro("data", &AllocateConstNode::data) - .def_ro("irmod_storage_idx", &AllocateConstNode::irmod_storage_idx) - .def_ro("dtype", &AllocateConstNode::dtype) - .def_ro("extents", &AllocateConstNode::extents) - .def_ro("body", &AllocateConstNode::body) - .def_ro("annotations", &AllocateConstNode::annotations); - } - - /*! - * \brief If the buffer size is constant, return the size. - * Otherwise return 0. - * \return The result. - */ - int64_t ConstantAllocationSize() const { return ConstantAllocationSize(extents); } - /*! - * \brief If the buffer size is constant, return the size. - * Otherwise return 0. - * \param extents The extents of the buffer. - * \return The result. - */ - TVM_DLL static int64_t ConstantAllocationSize(const ffi::Array<PrimExpr>& extents); - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.AllocateConst", AllocateConstNode, StmtNode); -}; - -/*! - * \brief Managed reference to AllocateConstNode. - * \sa AllocateConstNode - */ -class AllocateConst : public Stmt { - public: - /* The constructor to create a IRNode with constant data - * depending on the type of ObjectRef, it will either - * create AllocateConstNode with irmod_storage_idx or data - */ - TVM_DLL AllocateConst( - Var buffer_var, DataType dtype, ffi::Array<PrimExpr> extents, ObjectRef data_or_idx, - Stmt body, ffi::Map<ffi::String, ffi::Any> annotations = ffi::Map<ffi::String, ffi::Any>(), - Span span = Span()); - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(AllocateConst, Stmt, AllocateConstNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(AllocateConstNode); -}; - /*! 
\brief Declare a buffer that can be used in the body */ class DeclBufferNode : public StmtNode { public: diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h index d8d06ae1f3..431308f859 100644 --- a/include/tvm/tir/stmt_functor.h +++ b/include/tvm/tir/stmt_functor.h @@ -88,7 +88,6 @@ class StmtFunctor<R(const Stmt& n, Args... args)> { virtual R VisitStmt_(const ForNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const WhileNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const AllocateNode* op, Args... args) STMT_FUNCTOR_DEFAULT; - virtual R VisitStmt_(const AllocateConstNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const DeclBufferNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const BufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const BufferRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT; @@ -112,7 +111,6 @@ class StmtFunctor<R(const Stmt& n, Args... args)> { IR_STMT_FUNCTOR_DISPATCH(ForNode); IR_STMT_FUNCTOR_DISPATCH(WhileNode); IR_STMT_FUNCTOR_DISPATCH(AllocateNode); - IR_STMT_FUNCTOR_DISPATCH(AllocateConstNode); IR_STMT_FUNCTOR_DISPATCH(DeclBufferNode); IR_STMT_FUNCTOR_DISPATCH(AssertStmtNode); IR_STMT_FUNCTOR_DISPATCH(SeqStmtNode); @@ -153,7 +151,6 @@ class TVM_DLL StmtVisitor : protected StmtFunctor<void(const Stmt&)> { void VisitStmt_(const ForNode* op) override; void VisitStmt_(const WhileNode* op) override; void VisitStmt_(const AllocateNode* op) override; - void VisitStmt_(const AllocateConstNode* op) override; void VisitStmt_(const DeclBufferNode* op) override; void VisitStmt_(const BufferStoreNode* op) override; void VisitStmt_(const BufferRealizeNode* op) override; @@ -251,7 +248,6 @@ class TVM_DLL StmtMutator : protected StmtFunctor<Stmt(const Stmt&)> { Stmt VisitStmt_(const ForNode* op) override; Stmt VisitStmt_(const WhileNode* op) override; Stmt VisitStmt_(const AllocateNode* op) override; - Stmt VisitStmt_(const AllocateConstNode* op) override; Stmt VisitStmt_(const DeclBufferNode* op) override; Stmt VisitStmt_(const BufferStoreNode* op) override; Stmt VisitStmt_(const BufferRealizeNode* op) override; diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index bdf8f99aa3..95ef3c18e0 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -443,15 +443,6 @@ TVM_DLL Pass ConvertForLoopsToSerial(); */ TVM_DLL Pass UnifiedStaticMemoryPlanner(); -TVM_DLL Pass BindParams(const ffi::Array<runtime::Tensor>& constants); - -/*! - * \brief Pass to collect tir non-scalar constants into module's 'Constants' attribute. - * - * \return The pass. - */ -TVM_DLL Pass ExtractPrimFuncConstants(); - /*! * \brief Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv()) * \return The pass. diff --git a/python/tvm/script/ir_builder/tir/frame.py b/python/tvm/script/ir_builder/tir/frame.py index ddecd005c8..fd9ae24407 100644 --- a/python/tvm/script/ir_builder/tir/frame.py +++ b/python/tvm/script/ir_builder/tir/frame.py @@ -74,13 +74,6 @@ class AllocateFrame(TIRFrame): return self.buffer_var -@_register_object("script.ir_builder.tir.AllocateConstFrame") -class AllocateConstFrame(TIRFrame): - def __enter__(self) -> Buffer: - super().__enter__() - return self.buffer_var - - @_register_object("script.ir_builder.tir.AttrFrame") class AttrFrame(TIRFrame): ... 
diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index bf8a081801..c419592953 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1083,43 +1083,6 @@ def allocate( ) -def allocate_const( - data: List[PrimExpr], - dtype: str, - extents: List[PrimExpr], - annotations=None, -) -> frame.AllocateConstFrame: - """Allocate constant node. - - Parameters - ---------- - data : List[PrimExpr] - The data associated with the constant. - - dtype : str - The data type of the buffer. - - extents : List[PrimExpr] - The extents of the allocate. - - annotations : Optional[Map] - Additional annotations about the allocation. - """ - np_data = np.asarray(data, dtype=dtype) - prod_extent = 1 - for extent in extents: - prod_extent *= extent - prod_shape = 1 - for shape in np_data.shape: - prod_shape *= shape - if prod_extent == prod_shape: - np_data = np_data.reshape(extents) - - return _ffi_api.AllocateConst( # type: ignore[attr-defined] # pylint: disable=no-member - tensor(np_data), dtype, extents, annotations - ) - - def attr(node: Any, attr_key: str, value: Union[PrimExpr, str]) -> frame.AttrFrame: """Create an attribute node. @@ -2186,7 +2149,6 @@ __all__ = float_types + [ "Assert", "realize", "allocate", - "allocate_const", "attr", "While", "If", diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index e48b08bbd7..d7053fc862 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -32,7 +32,6 @@ from .stmt import ( BufferStore, BufferRealize, Allocate, - AllocateConst, AttrStmt, DeclBuffer, ) diff --git a/python/tvm/tir/functor.py b/python/tvm/tir/functor.py index 1c403ce927..935ddcb2e2 100644 --- a/python/tvm/tir/functor.py +++ b/python/tvm/tir/functor.py @@ -60,7 +60,6 @@ from .expr import ( ) from .stmt import ( Allocate, - AllocateConst, AssertStmt, AttrStmt, SBlock, @@ -165,7 +164,6 @@ class _PyStmtExprVisitor(tvm_ffi.core.Object): f_visit_for: Callable = None, f_visit_while: Callable = None, f_visit_allocate: Callable = None, - f_visit_allocate_const: Callable = None, f_visit_decl_buffer: Callable = None, f_visit_buffer_store: Callable = None, f_visit_buffer_realize: Callable = None, @@ -221,7 +219,6 @@ class _PyStmtExprVisitor(tvm_ffi.core.Object): f_visit_for, f_visit_while, f_visit_allocate, - f_visit_allocate_const, f_visit_decl_buffer, f_visit_buffer_store, f_visit_buffer_realize, @@ -286,7 +283,6 @@ class PyStmtExprVisitor: "visit_for_", "visit_while_", "visit_allocate_", - "visit_allocate_const_", "visit_decl_buffer_", "visit_buffer_store_", "visit_buffer_realize_", @@ -430,19 +426,6 @@ class PyStmtExprVisitor: print("visit_allocate_", op) _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore - def visit_allocate_const_(self, op: AllocateConst) -> None: - """Visit AllocateConst. - Users can customize this function to overwrite VisitStmt_(const AllocateConstNode* op) - on the C++ side. - - Parameters - ---------- - op : AllocateConst - The AllocateConst to be visited. - """ - print("visit_allocate_const_", op) - _ffi_api.PyStmtExprVisitorDefaultVisitStmt(self._outer(), op) # type: ignore - def visit_decl_buffer_(self, op: DeclBuffer) -> None: """Visit DeclBuffer. 
Users can customize this function to overwrite VisitStmt_(const DeclBufferNode* op) @@ -999,7 +982,6 @@ class _PyStmtExprMutator(tvm_ffi.core.Object): f_visit_for: Callable = None, f_visit_while: Callable = None, f_visit_allocate: Callable = None, - f_visit_allocate_const: Callable = None, f_visit_decl_buffer: Callable = None, f_visit_buffer_store: Callable = None, f_visit_buffer_realize: Callable = None, @@ -1055,7 +1037,6 @@ class _PyStmtExprMutator(tvm_ffi.core.Object): f_visit_for, f_visit_while, f_visit_allocate, - f_visit_allocate_const, f_visit_decl_buffer, f_visit_buffer_store, f_visit_buffer_realize, @@ -1120,7 +1101,6 @@ class PyStmtExprMutator: "visit_for_", "visit_while_", "visit_allocate_", - "visit_allocate_const_", "visit_decl_buffer_", "visit_buffer_store_", "visit_buffer_realize_", @@ -1302,23 +1282,6 @@ class PyStmtExprMutator: """ return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op) # type: ignore - def visit_allocate_const_(self, op: AllocateConst) -> Stmt: - """Visit AllocateConst. - Users can customize this function to overwrite VisitStmt_(const AllocateConstNode* op) - on the C++ side. - - Parameters - ---------- - op : AllocateConst - The AllocateConst to be visited. - - Returns - ------- - result : Stmt - The mutated Stmt. - """ - return _ffi_api.PyStmtExprMutatorDefaultVisitStmt(self._outer(), op) # type: ignore - def visit_decl_buffer_(self, op: DeclBuffer) -> Stmt: """Visit DeclBuffer. Users can customize this function to overwrite VisitStmt_(const DeclBufferNode* op) diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index cd8f6e92a7..cbaa3f6034 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -360,68 +360,6 @@ class Allocate(Stmt): ) -@tvm_ffi.register_object("tir.AllocateConst") -class AllocateConst(Stmt): - """Allocate constant node. - - Parameters - ---------- - buffer_var : Var - The buffer variable. - - dtype : str - The data type of the buffer. - - extents : list of Expr - The extents of the allocate - - data_or_idx : Union[Tensor, int] - If an Tensor, this is the const data associated with the - constant. If an integer, this is the index into the - "constants" attribute of the `IRModule` that contains the - `AllocateConst`. - - body : Stmt - The body statement. - - annotations : Optional[Mapping[str, Object]] - Additional annotations about the allocation. - - span : Optional[Span] - The location of the stmt in the source code. - """ - - buffer_var: Var - dtype: str - extents: List[PrimExpr] - data: Optional[Tensor] - irmod_storage_idx: Optional[int] - body: Stmt - annotations: Mapping[str, Object] - span: Optional[Span] - - def __init__( - self, - buffer_var: Var, - dtype: str, - extents: List[PrimExpr], - data_or_idx: Union[Tensor, int], - body: Stmt, - annotations: Optional[Mapping[str, Object]] = None, - span: Optional[Span] = None, - ) -> None: - self.__init_handle_by_constructor__( - _ffi_api.AllocateConst, # type: ignore - buffer_var, - dtype, - extents, - data_or_idx, - body, - annotations, - span, - ) - - @tvm_ffi.register_object("tir.DeclBuffer") class DeclBuffer(Stmt): """DeclBuffer node. 
diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index f7799ff2b5..676e24031e 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -739,17 +739,6 @@ def ConvertForLoopsToSerial(): return _ffi_api.ConvertForLoopsToSerial() # type: ignore -def ExtractPrimFuncConstants(): - """Collects and unificates tir non-scalar constants to module's attr 'Constants' array. - - Returns - ------- - fpass : tvm.transform.Pass - The result pass - """ - return _ffi_api.ExtractPrimFuncConstants() # type: ignore - - def RenormalizeSplitPattern(): """Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv()) diff --git a/src/s_tir/schedule/primitive/cache_read_write.cc b/src/s_tir/schedule/primitive/cache_read_write.cc index c356aed3cf..1d80bc8939 100644 --- a/src/s_tir/schedule/primitive/cache_read_write.cc +++ b/src/s_tir/schedule/primitive/cache_read_write.cc @@ -462,12 +462,7 @@ Stmt InsertCacheStage(const Stmt& stmt, int pos, const Stmt& stage) { std::vector<Stmt> nest; Stmt body = stmt; while (true) { - if (auto opt = body.as<AllocateConst>()) { - auto alloc = opt.value(); - body = alloc->body; - alloc.CopyOnWrite()->body = Evaluate(0); - nest.push_back(alloc); - } else if (auto opt = body.as<DeclBuffer>()) { + if (auto opt = body.as<DeclBuffer>()) { auto decl_buffer = opt.value(); body = decl_buffer->body; decl_buffer.CopyOnWrite()->body = Evaluate(0); @@ -640,11 +635,9 @@ class CacheLocDetector : public StmtVisitor { info->loc_sref = scope_sref; auto block_body = scope_sref->StmtAs<SBlockNode>()->body; - // Find the SeqStmtNode within (potentially nested) AllocateConstNodes + // Find the SeqStmtNode within (potentially nested) DeclBufferNodes while (true) { - if (auto* ptr = block_body.as<AllocateConstNode>()) { - block_body = ptr->body; - } else if (auto* ptr = block_body.as<DeclBufferNode>()) { + if (auto* ptr = block_body.as<DeclBufferNode>()) { block_body = ptr->body; } else { break; diff --git a/src/s_tir/schedule/transform.cc b/src/s_tir/schedule/transform.cc index f5f2356dcd..0e1c626b6d 100644 --- a/src/s_tir/schedule/transform.cc +++ b/src/s_tir/schedule/transform.cc @@ -281,15 +281,10 @@ void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_ } if (const auto* block = sref->StmtAs<SBlockNode>()) { auto body = block->body; - // Peel off AllocateConst nodes at the beginning of the block body. + // Peel off DeclBuffer nodes at the beginning of the block body. 
std::vector<Stmt> allocs; while (true) { - if (auto opt = body.as<AllocateConst>()) { - auto alloc = opt.value(); - body = alloc->body; - alloc.CopyOnWrite()->body = Evaluate(0); - allocs.push_back(alloc); - } else if (auto opt = body.as<DeclBuffer>()) { + if (auto opt = body.as<DeclBuffer>()) { auto decl_buffer = opt.value(); body = decl_buffer->body; decl_buffer.CopyOnWrite()->body = Evaluate(0); @@ -302,7 +297,7 @@ void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_ if (const auto* seq = body.as<SeqStmtNode>()) { ObjectPtr<SBlockNode> n = ffi::make_object<SBlockNode>(*block); auto new_seq = RemoveFromSeqStmt(ffi::GetRef<SeqStmt>(seq), ffi::GetRef<Stmt>(last_stmt)); - // Re-attach AllocateConst nodes + // Re-attach DeclBuffer nodes auto new_body = MergeNest(allocs, new_seq); n->body = new_body; *src_stmt = ffi::GetRef<Stmt>(block); diff --git a/src/s_tir/transform/plan_update_buffer_allocation_location.cc b/src/s_tir/transform/plan_update_buffer_allocation_location.cc index a5fb66ec04..98995e2c36 100644 --- a/src/s_tir/transform/plan_update_buffer_allocation_location.cc +++ b/src/s_tir/transform/plan_update_buffer_allocation_location.cc @@ -73,8 +73,7 @@ class BufferAllocateOrderCollector : public StmtExprVisitor { for (const Buffer& buffer : op->alloc_buffers) { buffer_alloc_recorder_.push_back(buffer); } - // Also visit match_buffers to collect constant buffers associated with AllocateConst nodes. - // These buffers only appear in read and match_buffer regions. + // Also visit match_buffers to collect buffers that only appear in read and match_buffer regions. for (const auto& region : op->match_buffers) { if (!find(region->source->buffer)) { buffer_alloc_recorder_.push_back(region->source->buffer); diff --git a/src/script/ir_builder/tir/frame.cc b/src/script/ir_builder/tir/frame.cc index 2236e4f8b2..e6c8b5d0b1 100644 --- a/src/script/ir_builder/tir/frame.cc +++ b/src/script/ir_builder/tir/frame.cc @@ -39,7 +39,6 @@ TVM_FFI_STATIC_INIT_BLOCK() { LaunchThreadFrameNode::RegisterReflection(); RealizeFrameNode::RegisterReflection(); AllocateFrameNode::RegisterReflection(); - AllocateConstFrameNode::RegisterReflection(); AttrFrameNode::RegisterReflection(); WhileFrameNode::RegisterReflection(); IfFrameNode::RegisterReflection(); @@ -155,11 +154,6 @@ void AllocateFrameNode::ExitWithScope() { tvm::tir::Allocate(buffer_var, dtype, extents, condition, AsStmt(stmts), annotations)); } -void AllocateConstFrameNode::ExitWithScope() { - TIRFrameNode::ExitWithScope(); - AddToParent( - tvm::tir::AllocateConst(buffer_var, dtype, extents, data, AsStmt(stmts), annotations)); -} void AttrFrameNode::ExitWithScope() { TIRFrameNode::ExitWithScope(); AddToParent(tvm::tir::AttrStmt(node, attr_key, value, AsStmt(stmts))); diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index 4b5d5d2fb1..9f11743c05 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -521,18 +521,6 @@ AllocateFrame Allocate(ffi::Array<PrimExpr> extents, DataType dtype, ffi::String return AllocateFrame(n); } -AllocateConstFrame AllocateConst(tvm::runtime::Tensor data, DataType dtype, - ffi::Array<PrimExpr> extents, - ffi::Optional<ffi::Map<ffi::String, Any>> annotations) { - ObjectPtr<AllocateConstFrameNode> n = ffi::make_object<AllocateConstFrameNode>(); - n->dtype = dtype; - n->extents = extents; - n->data = data; - n->annotations = annotations.value_or(ffi::Map<ffi::String, Any>()); - n->buffer_var = Var("", tvm::PointerType(tvm::PrimType(dtype))); - 
return AllocateConstFrame(n); -} - AttrFrame Attr(ffi::Any node, ffi::String attr_key, PrimExpr value) { // convert POD value to PrimExpr if (node.type_index() < ffi::TypeIndex::kTVMFFISmallStr) { @@ -747,7 +735,6 @@ TVM_FFI_STATIC_INIT_BLOCK() { .def("script.ir_builder.tir.LetStmt", LetStmt) .def("script.ir_builder.tir.LegacyLetStmt", LegacyLetStmt) .def("script.ir_builder.tir.Allocate", Allocate) - .def("script.ir_builder.tir.AllocateConst", AllocateConst) .def("script.ir_builder.tir.Realize", Realize) .def("script.ir_builder.tir.Attr", Attr) .def("script.ir_builder.tir.While", While) diff --git a/src/script/printer/tir/stmt.cc b/src/script/printer/tir/stmt.cc index 9ddffe19a8..9df4764dd7 100644 --- a/src/script/printer/tir/stmt.cc +++ b/src/script/printer/tir/stmt.cc @@ -251,87 +251,6 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return DoConciseScoping(lhs, rhs, &(*f)->stmts, concise); }); -template <typename T> -ExprDoc PrintTensor(::tvm::runtime::Tensor arr) { - // FIXME(@junrushao): this is a hack and can be wrong in most of the cases - constexpr int NUM_PRINT = 200; - int ndim = arr->ndim; - int tot_dim = 1; - for (int i = 0; i < ndim; i++) { - tot_dim *= arr->shape[i]; - } - ffi::Array<ExprDoc> result; - T* data_ptr = reinterpret_cast<T*>(arr->data); - runtime::DataType dtype = arr.DataType(); - for (int i = 0; i < tot_dim; i++) { - if (dtype.is_float()) { - result.push_back(LiteralDoc::Float(data_ptr[i], std::nullopt)); - } else { - result.push_back(LiteralDoc::Int(data_ptr[i], std::nullopt)); - } - if (i == NUM_PRINT) { - break; - } - } - return ListDoc(result); -} - -TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch<tir::AllocateConst>( - "", [](tir::AllocateConst stmt, AccessPath stmt_p, IRDocsifier d) -> Doc { - bool concise = AllowConciseScoping(d, stmt); - ffi::String storage_scope = tir::GetPtrStorageScope(stmt->buffer_var); - ffi::Array<ExprDoc> args; - ffi::Array<ffi::String> kwargs_keys; - ffi::Array<ExprDoc> kwargs_values; - ExprDoc data_doc{ffi::UnsafeInit()}; - if (stmt->dtype.is_int()) { - if (stmt->dtype.bits() == 8) { - data_doc = PrintTensor<int8_t>(stmt->data.value()); - } else if (stmt->dtype.bits() == 16) { - data_doc = PrintTensor<int16_t>(stmt->data.value()); - } else if (stmt->dtype.bits() == 32) { - data_doc = PrintTensor<int32_t>(stmt->data.value()); - } else if (stmt->dtype.bits() == 64) { - data_doc = PrintTensor<int64_t>(stmt->data.value()); - } else { - LOG(FATAL) << "DataType not supported"; - } - } else if (stmt->dtype.is_uint()) { - if (stmt->dtype.bits() == 8) { - data_doc = PrintTensor<uint8_t>(stmt->data.value()); - } else if (stmt->dtype.bits() == 16) { - data_doc = PrintTensor<uint16_t>(stmt->data.value()); - } else if (stmt->dtype.bits() == 32) { - data_doc = PrintTensor<uint32_t>(stmt->data.value()); - } else if (stmt->dtype.bits() == 64) { - data_doc = PrintTensor<uint64_t>(stmt->data.value()); - } else { - LOG(FATAL) << "DataType not supported"; - } - } else if (stmt->dtype.is_float()) { - if (stmt->dtype.bits() == 16) { - data_doc = PrintTensor<int16_t>(stmt->data.value()); - } else if (stmt->dtype.bits() == 32) { - data_doc = PrintTensor<float>(stmt->data.value()); - } else if (stmt->dtype.bits() == 64) { - data_doc = PrintTensor<double>(stmt->data.value()); - } else { - LOG(FATAL) << "DataType not supported"; - } - } else { - LOG(FATAL) << "DataType not supported"; - } - args.push_back(data_doc); - args.push_back(LiteralDoc::DataType(stmt->dtype, stmt_p->Attr("dtype"))); - args.push_back(d->AsDoc<ExprDoc>(stmt->extents, 
stmt_p->Attr("extents"))); - ExprDoc rhs = TIR(d, "allocate_const")->Call(args, kwargs_keys, kwargs_values); - With<TIRFrame> f(d, stmt); - ExprDoc lhs = DefineVar(stmt->buffer_var, *f, d); - AsDocBody(stmt->body, stmt_p->Attr("body"), f->get(), d); - return DoConciseScoping(lhs, rhs, &(*f)->stmts, concise); - }); - ExprDoc DocsifyBufferRealize(const tir::BufferRealizeNode* stmt, ffi::Optional<ExprDoc> value, // AccessPath p, IRDocsifier d) { ExprDoc buffer = d->AsDoc<ExprDoc>(stmt->buffer, p->Attr("buffer")); @@ -451,7 +370,6 @@ TVM_SCRIPT_REPR(tir::AttrStmtNode, ReprPrintTIR); TVM_SCRIPT_REPR(tir::AssertStmtNode, ReprPrintTIR); TVM_SCRIPT_REPR(tir::WhileNode, ReprPrintTIR); TVM_SCRIPT_REPR(tir::AllocateNode, ReprPrintTIR); -TVM_SCRIPT_REPR(tir::AllocateConstNode, ReprPrintTIR); TVM_SCRIPT_REPR(tir::DeclBufferNode, ReprPrintTIR); TVM_SCRIPT_REPR(tir::SeqStmtNode, ReprPrintTIR); TVM_SCRIPT_REPR(tir::IfThenElseNode, ReprPrintTIR); diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index ed2f927c0d..0a2ae8b09e 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -2078,18 +2078,6 @@ void CodeGenLLVM::VisitStmt_(const IfThenElseNode* op) { builder_->SetInsertPoint(end_block); } -void CodeGenLLVM::VisitStmt_(const AllocateConstNode* op) { - EmitDebugLocation(op); - auto data = op->data.value(); - auto array = TensorToLLVMArray(llvm_target_->GetContext(), data); - std::string symbol_name = op->buffer_var->name_hint; - llvm::GlobalVariable* param_symbol = new llvm::GlobalVariable( - *module_, array->getType(), true, llvm::GlobalValue::InternalLinkage, array, symbol_name); - - var_map_[op->buffer_var.operator->()] = param_symbol; - this->VisitStmt(op->body); -} - void CodeGenLLVM::VisitStmt_(const AllocateNode* op) { EmitDebugLocation(op); ICHECK_EQ(op->extents.size(), 1) diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index d48e1e2f41..e682581702 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -234,7 +234,6 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>, void VisitStmt_(const WhileNode* op) override; void VisitStmt_(const IfThenElseNode* op) override; void VisitStmt_(const AllocateNode* op) override; - void VisitStmt_(const AllocateConstNode* op) override; void VisitStmt_(const AttrStmtNode* op) override; void VisitStmt_(const AssertStmtNode* op) override; void VisitStmt_(const LetStmtNode* op) override; diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 52ad781669..732774e318 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -761,38 +761,6 @@ void CodeGenC::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, } } -void CodeGenC::VisitStmt_(const AllocateConstNode* op) { - std::string symbol_name = AllocVarID(op->buffer_var.get()); - - int64_t num_elements = 1; - const auto& data = op->data.value(); - - for (int64_t dim : data.Shape()) { - num_elements *= dim; - } - - decl_stream << "\n" - << "#ifdef __cplusplus\n" - << "extern \"C\" {\n" - << "#endif\n" - << "static const "; - - PrintType(data.DataType(), decl_stream); - - // Allocate the global static variable - decl_stream << " __attribute__((section(\".rodata.tvm\"), " - << "aligned(" << constants_byte_alignment_->value << "))) " << symbol_name << "[" - << num_elements << "] = {\n"; - TensorDataToC(data, 4, decl_stream); - - decl_stream << "};\n" - << "#ifdef __cplusplus\n" - << "} // extern \"C\"\n" - << 
"#endif\n"; - var_idmap_[op->buffer_var.operator->()] = symbol_name; - this->PrintStmt(op->body); -} - void CodeGenC::VisitStmt_(const DeclBufferNode* op) { this->PrintStmt(op->body); } void CodeGenC::VisitExpr_(const BufferLoadNode* op, std::ostream& os) { // NOLINT(*) diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index eae7eb7adb..791fe4a015 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -197,7 +197,6 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>, void VisitStmt_(const AssertStmtNode* op) override; void VisitStmt_(const EvaluateNode* op) override; void VisitStmt_(const SeqStmtNode* op) override; - void VisitStmt_(const AllocateConstNode* op) override; void VisitStmt_(const DeclBufferNode* op) override; /*! diff --git a/src/target/source/codegen_webgpu.cc b/src/target/source/codegen_webgpu.cc index cf8176001a..35c6f8c798 100644 --- a/src/target/source/codegen_webgpu.cc +++ b/src/target/source/codegen_webgpu.cc @@ -694,10 +694,6 @@ void CodeGenWebGPU::VisitStmt_(const AssertStmtNode* op) { PrintStmt(op->body); } -void CodeGenWebGPU::VisitStmt_(const AllocateConstNode* op) { - LOG(FATAL) << "WebGPU: do not support alloc const"; -} - void CodeGenWebGPU::VisitStmt_(const WhileNode* op) { PrintIndent(); stream << "while (true) {\n"; diff --git a/src/target/source/codegen_webgpu.h b/src/target/source/codegen_webgpu.h index b8f2f9a79d..b26ff3a4fb 100644 --- a/src/target/source/codegen_webgpu.h +++ b/src/target/source/codegen_webgpu.h @@ -78,7 +78,6 @@ class CodeGenWebGPU final : public CodeGenC { void VisitStmt_(const ForNode* op) final; void VisitStmt_(const AllocateNode* op) final; void VisitStmt_(const AssertStmtNode* op) final; - void VisitStmt_(const AllocateConstNode* op) final; void VisitStmt_(const WhileNode* op) final; private: diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index 6499eac91e..f1ac7358cd 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -38,7 +38,6 @@ #include "../../support/array.h" #include "../../tir/ir/data_type_rewriter.h" #include "../../tir/ir/functor_common.h" -#include "../../tir/transform/ir_utils.h" #include "graph.h" namespace tvm { @@ -753,9 +752,8 @@ PrimFunc GenerateAndCompletePrimFunc(const ffi::Array<te::Tensor>& arg_list, return func; } -PrimFunc CreatePrimFuncWithConstants(const ffi::Array<te::Tensor>& arg_list, - const ffi::Array<runtime::Tensor>& constants, - std::optional<DataType> index_dtype_override) { +PrimFunc CreatePrimFunc(const ffi::Array<te::Tensor>& arg_list, + std::optional<DataType> index_dtype_override) { // Information used in CreatePrimFunc and its sub-functions. CreateFuncInfo info(arg_list); // Root body stmts. @@ -776,7 +774,6 @@ PrimFunc CreatePrimFuncWithConstants(const ffi::Array<te::Tensor>& arg_list, // Step 4. Create func and complete prim func. 
auto func = GenerateAndCompletePrimFunc(arg_list, root_stmts, &info); - func = tir::BindParams(func, constants); if (index_dtype_override.has_value()) { func = IndexDataTypeNormalizer(index_dtype_override.value()).Rewrite(std::move(func)); } @@ -784,11 +781,6 @@ PrimFunc CreatePrimFuncWithConstants(const ffi::Array<te::Tensor>& arg_list, return result; } -PrimFunc CreatePrimFunc(const ffi::Array<te::Tensor>& arg_list, - std::optional<DataType> index_dtype_override) { - return CreatePrimFuncWithConstants(arg_list, {}, index_dtype_override); -} - TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed("te.CreatePrimFunc", [](ffi::PackedArgs args, ffi::Any* ret) { @@ -830,9 +822,8 @@ PrimFunc GenerateAndCompletePrimFunc(const ffi::Array<ObjectRef>& arg_tir_var_li return func; } -PrimFunc CreatePrimFuncWithConstants(const ffi::Array<ObjectRef>& arg_list, - const ffi::Array<runtime::Tensor>& constants, - std::optional<DataType> index_dtype_override) { +PrimFunc CreatePrimFunc(const ffi::Array<ObjectRef>& arg_list, + std::optional<DataType> index_dtype_override) { ffi::Array<te::Tensor> tensor_arg_list; for (const ObjectRef& x : arg_list) { if (auto tensor_node = x.as<te::TensorNode>()) { @@ -840,7 +831,7 @@ PrimFunc CreatePrimFuncWithConstants(const ffi::Array<ObjectRef>& arg_list, tensor_arg_list.push_back(tensor); } } - // Infomations used in CreatePrimFunc and its sub-functions. + // Information used in CreatePrimFunc and its sub-functions. CreateFuncInfo info(tensor_arg_list); // Root body stmts. ffi::Array<Stmt> root_stmts; @@ -858,7 +849,6 @@ PrimFunc CreatePrimFuncWithConstants(const ffi::Array<ObjectRef>& arg_list, RewriteStageToBlock(op, &info, &root_stmts, &analyzer); } auto func = GenerateAndCompletePrimFunc(arg_list, root_stmts, &info); - func = tir::BindParams(func, constants); if (index_dtype_override.has_value()) { func = IndexDataTypeNormalizer(index_dtype_override.value()).Rewrite(std::move(func)); } @@ -866,10 +856,5 @@ PrimFunc CreatePrimFuncWithConstants(const ffi::Array<ObjectRef>& arg_list, return result; } -PrimFunc CreatePrimFunc(const ffi::Array<ObjectRef>& arg_list, - std::optional<DataType> index_dtype_override) { - return CreatePrimFuncWithConstants(arg_list, {}, index_dtype_override); -} - } // namespace tir } // namespace tvm diff --git a/src/te/operation/create_primfunc.h b/src/te/operation/create_primfunc.h index f7ad7e0e1e..a5bd9b16ed 100644 --- a/src/te/operation/create_primfunc.h +++ b/src/te/operation/create_primfunc.h @@ -33,28 +33,10 @@ namespace tir { PrimFunc CreatePrimFunc(const ffi::Array<te::Tensor>& arg_list, std::optional<DataType> index_dtype_override = std::nullopt); -/*! \brief The same as above but create a PrimFunc with AllocateConstNode. If the size of the - * constants array is N, the last N tensors in arg_list will be treated as constant tensors. - * Constant tensors will not be part of the parameters of the created PrimFunc, instead constants - * will be embedded in the body as AllocateConstNode. - */ -PrimFunc CreatePrimFuncWithConstants(const ffi::Array<te::Tensor>& arg_list, - const ffi::Array<runtime::Tensor>& constants, - std::optional<DataType> index_dtype_override = std::nullopt); - /*! \brief Use Tensor Expression to create a schedulable TensorIR func. */ PrimFunc CreatePrimFunc(const ffi::Array<ObjectRef>& arg_list, std::optional<DataType> index_dtype_override); -/*! \brief The same as above but create a PrimFunc with AllocateConstNode. 
If the size of the - * constants array is N, the last N tensors in arg_list will be treated as constant tensors. - * Constant tensors will not be part of the parameters of the created PrimFunc, instead constants - * will be embedded in the body as AllocateConstNode. - */ -PrimFunc CreatePrimFuncWithConstants(const ffi::Array<ObjectRef>& arg_list, - const ffi::Array<runtime::Tensor>& constants, - std::optional<DataType> index_dtype_override); - } // namespace tir } // namespace tvm diff --git a/src/tir/analysis/estimate_flops.cc b/src/tir/analysis/estimate_flops.cc index 6957ee578c..56b9d99265 100644 --- a/src/tir/analysis/estimate_flops.cc +++ b/src/tir/analysis/estimate_flops.cc @@ -208,7 +208,6 @@ class FlopEstimator : private ExprFunctor<TResult(const PrimExpr& n)>, TResult VisitExpr_(const FloatImmNode* op) override { return TResult(); } TResult VisitExpr_(const StringImmNode* op) override { return TResult(); } TResult VisitExpr_(const CastNode* op) override { return VisitExpr(op->value); } - TResult VisitStmt_(const AllocateConstNode* op) override { return VisitStmt(op->body); } TResult VisitStmt_(const AllocateNode* op) override { return VisitStmt(op->body); } TResult VisitStmt_(const DeclBufferNode* op) override { return VisitStmt(op->body); } TResult VisitStmt_(const EvaluateNode* op) override { return TResult(); } diff --git a/src/tir/analysis/var_use_def_analysis.cc b/src/tir/analysis/var_use_def_analysis.cc index becae607fb..360dbb6e44 100644 --- a/src/tir/analysis/var_use_def_analysis.cc +++ b/src/tir/analysis/var_use_def_analysis.cc @@ -74,11 +74,6 @@ void VarUseDefAnalyzer::VisitStmt_(const AllocateNode* op) { StmtExprVisitor::VisitStmt_(op); } -void VarUseDefAnalyzer::VisitStmt_(const AllocateConstNode* op) { - this->HandleDef(op->buffer_var); - StmtExprVisitor::VisitStmt_(op); -} - void VarUseDefAnalyzer::VisitStmt_(const BufferStoreNode* op) { HandleUse(op->buffer); StmtExprVisitor::VisitStmt_(op); diff --git a/src/tir/analysis/var_use_def_analysis.h b/src/tir/analysis/var_use_def_analysis.h index 51323d65d5..3a56afb5e4 100644 --- a/src/tir/analysis/var_use_def_analysis.h +++ b/src/tir/analysis/var_use_def_analysis.h @@ -65,8 +65,6 @@ class VarUseDefAnalyzer : public StmtExprVisitor { void VisitStmt_(const AllocateNode* op) final; - void VisitStmt_(const AllocateConstNode* op) final; - void VisitStmt_(const BufferStoreNode* op) final; void VisitExpr_(const LetNode* op) final; diff --git a/src/tir/ir/py_functor.cc b/src/tir/ir/py_functor.cc index 990aa87717..9ab37903bb 100644 --- a/src/tir/ir/py_functor.cc +++ b/src/tir/ir/py_functor.cc @@ -182,8 +182,6 @@ class PyStmtExprVisitorNode : public Object, public StmtExprVisitor { ffi::Function f_visit_while{nullptr}; /*! \brief The packed function to the `VisitStmt_(const AllocateNode* op)` function. */ ffi::Function f_visit_allocate{nullptr}; - /*! \brief The packed function to the `VisitStmt_(const AllocateConstNode* op)` function. */ - ffi::Function f_visit_allocate_const{nullptr}; /*! \brief The packed function to the `VisitStmt_(const DeclBufferNode* op)` function. */ ffi::Function f_visit_decl_buffer{nullptr}; /*! \brief The packed function to the `VisitStmt_(const BufferStoreNode* op)` function. 
*/ @@ -230,7 +228,6 @@ class PyStmtExprVisitorNode : public Object, public StmtExprVisitor { PY_STMT_VISITOR_DISPATCH(ForNode, f_visit_for); PY_STMT_VISITOR_DISPATCH(WhileNode, f_visit_while); PY_STMT_VISITOR_DISPATCH(AllocateNode, f_visit_allocate); - PY_STMT_VISITOR_DISPATCH(AllocateConstNode, f_visit_allocate_const); PY_STMT_VISITOR_DISPATCH(DeclBufferNode, f_visit_decl_buffer); PY_STMT_VISITOR_DISPATCH(BufferStoreNode, f_visit_buffer_store); PY_STMT_VISITOR_DISPATCH(BufferRealizeNode, f_visit_buffer_realize); @@ -323,7 +320,6 @@ class PyStmtExprVisitorNode : public Object, public StmtExprVisitor { PY_STMT_VISITOR_DEFAULT_DISPATCH(ForNode); PY_STMT_VISITOR_DEFAULT_DISPATCH(WhileNode); PY_STMT_VISITOR_DEFAULT_DISPATCH(AllocateNode); - PY_STMT_VISITOR_DEFAULT_DISPATCH(AllocateConstNode); PY_STMT_VISITOR_DEFAULT_DISPATCH(DeclBufferNode); PY_STMT_VISITOR_DEFAULT_DISPATCH(BufferStoreNode); PY_STMT_VISITOR_DEFAULT_DISPATCH(BufferRealizeNode); @@ -354,7 +350,6 @@ class PyStmtExprVisitor : public ObjectRef { ffi::Function f_visit_for, // ffi::Function f_visit_while, // ffi::Function f_visit_allocate, // - ffi::Function f_visit_allocate_const, // ffi::Function f_visit_decl_buffer, // ffi::Function f_visit_buffer_store, // ffi::Function f_visit_buffer_realize, // @@ -406,7 +401,6 @@ class PyStmtExprVisitor : public ObjectRef { n->f_visit_for = std::move(f_visit_for); n->f_visit_while = std::move(f_visit_while); n->f_visit_allocate = std::move(f_visit_allocate); - n->f_visit_allocate_const = std::move(f_visit_allocate_const); n->f_visit_decl_buffer = std::move(f_visit_decl_buffer); n->f_visit_buffer_store = std::move(f_visit_buffer_store); n->f_visit_buffer_realize = std::move(f_visit_buffer_realize); @@ -549,8 +543,6 @@ class PyStmtExprMutatorNode : public Object, public StmtExprMutator { ffi::Function f_visit_while{nullptr}; /*! \brief The packed function to the `VisitStmt_(const AllocateNode* op)` function. */ ffi::Function f_visit_allocate{nullptr}; - /*! \brief The packed function to the `VisitStmt_(const AllocateConstNode* op)` function. */ - ffi::Function f_visit_allocate_const{nullptr}; /*! \brief The packed function to the `VisitStmt_(const DeclBufferNode* op)` function. */ ffi::Function f_visit_decl_buffer{nullptr}; /*! \brief The packed function to the `VisitStmt_(const BufferStoreNode* op)` function. 
*/ @@ -597,7 +589,6 @@ class PyStmtExprMutatorNode : public Object, public StmtExprMutator { PY_STMT_MUTATOR_DISPATCH(ForNode, f_visit_for); PY_STMT_MUTATOR_DISPATCH(WhileNode, f_visit_while); PY_STMT_MUTATOR_DISPATCH(AllocateNode, f_visit_allocate); - PY_STMT_MUTATOR_DISPATCH(AllocateConstNode, f_visit_allocate_const); PY_STMT_MUTATOR_DISPATCH(DeclBufferNode, f_visit_decl_buffer); PY_STMT_MUTATOR_DISPATCH(BufferStoreNode, f_visit_buffer_store); PY_STMT_MUTATOR_DISPATCH(BufferRealizeNode, f_visit_buffer_realize); @@ -690,7 +681,6 @@ class PyStmtExprMutatorNode : public Object, public StmtExprMutator { PY_STMT_MUTATOR_DEFAULT_DISPATCH(ForNode); PY_STMT_MUTATOR_DEFAULT_DISPATCH(WhileNode); PY_STMT_MUTATOR_DEFAULT_DISPATCH(AllocateNode); - PY_STMT_MUTATOR_DEFAULT_DISPATCH(AllocateConstNode); PY_STMT_MUTATOR_DEFAULT_DISPATCH(DeclBufferNode); PY_STMT_MUTATOR_DEFAULT_DISPATCH(BufferStoreNode); PY_STMT_MUTATOR_DEFAULT_DISPATCH(BufferRealizeNode); @@ -722,7 +712,6 @@ class PyStmtExprMutator : public ObjectRef { ffi::Function f_visit_for, // ffi::Function f_visit_while, // ffi::Function f_visit_allocate, // - ffi::Function f_visit_allocate_const, // ffi::Function f_visit_decl_buffer, // ffi::Function f_visit_buffer_store, // ffi::Function f_visit_buffer_realize, // @@ -774,7 +763,6 @@ class PyStmtExprMutator : public ObjectRef { n->f_visit_for = std::move(f_visit_for); n->f_visit_while = std::move(f_visit_while); n->f_visit_allocate = std::move(f_visit_allocate); - n->f_visit_allocate_const = std::move(f_visit_allocate_const); n->f_visit_decl_buffer = std::move(f_visit_decl_buffer); n->f_visit_buffer_store = std::move(f_visit_buffer_store); n->f_visit_buffer_realize = std::move(f_visit_buffer_realize); diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index d332741eea..b7f28e0aaf 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -40,7 +40,6 @@ TVM_FFI_STATIC_INIT_BLOCK() { BufferStoreNode::RegisterReflection(); BufferRealizeNode::RegisterReflection(); AllocateNode::RegisterReflection(); - AllocateConstNode::RegisterReflection(); DeclBufferNode::RegisterReflection(); SeqStmtNode::RegisterReflection(); EvaluateNode::RegisterReflection(); @@ -296,70 +295,6 @@ TVM_FFI_STATIC_INIT_BLOCK() { }); } -// Const -// The constructor to create a IRNode with constant data -// depending on the type of ObjectRef, it will either -// create AllocateConstNode with irmod_storage_idx or data -AllocateConst::AllocateConst(Var buffer_var, DataType dtype, ffi::Array<PrimExpr> extents, - ObjectRef data_or_idx, Stmt body, - ffi::Map<ffi::String, Any> annotations, Span span) { - ICHECK(IsPointerType(buffer_var->type_annotation, dtype)) - << "The allocated data type (" << dtype - << ") does not match the type annotation of the buffer " << buffer_var << " (" - << buffer_var->type_annotation - << "). 
The data type should be an element of the pointer type."; - - for (size_t i = 0; i < extents.size(); ++i) { - ICHECK(extents[i].defined()); - ICHECK(extents[i].dtype().is_scalar()); - } - ICHECK(body.defined()); - ICHECK(data_or_idx.defined()); - - ObjectPtr<AllocateConstNode> node = ffi::make_object<AllocateConstNode>(); - node->buffer_var = std::move(buffer_var); - node->dtype = dtype; - node->extents = std::move(extents); - node->body = std::move(body); - node->annotations = annotations; - node->span = std::move(span); - if (data_or_idx->IsInstance<runtime::Tensor::ContainerType>()) { - node->data = ffi::Optional<tvm::runtime::Tensor>(Downcast<runtime::Tensor>(data_or_idx)); - node->irmod_storage_idx = ffi::Optional<Integer>(); - } else if (data_or_idx->IsInstance<IntImmNode>()) { - node->data = ffi::Optional<tvm::runtime::Tensor>(); - node->irmod_storage_idx = ffi::Optional<Integer>(Downcast<Integer>(data_or_idx)); - } else { - LOG(FATAL) << "Data type not supported: " << data_or_idx->GetTypeKey(); - } - data_ = std::move(node); -} - -int64_t AllocateConstNode::ConstantAllocationSize(const ffi::Array<PrimExpr>& extents) { - int64_t result = 1; - for (size_t i = 0; i < extents.size(); ++i) { - if (const IntImmNode* int_size = extents[i].as<IntImmNode>()) { - result *= int_size->value; - if (result > std::numeric_limits<int64_t>::max()) { - return 0; - } - } else { - return 0; - } - } - return static_cast<int64_t>(result); -} -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def( - "tir.AllocateConst", - [](Var buffer_var, DataType dtype, ffi::Array<PrimExpr> extents, ObjectRef data_or_idx, - Stmt body, ffi::Optional<ffi::Map<ffi::String, Any>> annotations, Span span) { - return AllocateConst(buffer_var, dtype, extents, data_or_idx, body, - annotations.value_or({}), span); - }); -} - // DeclBuffer DeclBuffer::DeclBuffer(Buffer buffer, Stmt body, Span span) { ObjectPtr<DeclBufferNode> node = ffi::make_object<DeclBufferNode>(); diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index 2d4ec3a1ca..06db54af5d 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -63,11 +63,6 @@ void StmtVisitor::VisitStmt_(const AllocateNode* op) { this->VisitExpr(op->condition); } -void StmtVisitor::VisitStmt_(const AllocateConstNode* op) { - VisitArray(op->extents, [this](const PrimExpr& e) { this->VisitExpr(e); }); - this->VisitStmt(op->body); -} - void StmtVisitor::VisitStmt_(const DeclBufferNode* op) { this->VisitStmt(op->body); } void StmtVisitor::VisitStmt_(const BufferStoreNode* op) { @@ -310,20 +305,6 @@ Stmt StmtMutator::VisitStmt_(const AllocateNode* op) { } } -Stmt StmtMutator::VisitStmt_(const AllocateConstNode* op) { - ffi::Array<PrimExpr> extents = Internal::Mutate(this, op->extents); - Stmt body = this->VisitStmt(op->body); - - if (extents.same_as(op->extents) && body.same_as(op->body)) { - return ffi::GetRef<Stmt>(op); - } else { - auto n = CopyOnWrite(op); - n->extents = std::move(extents); - n->body = std::move(body); - return Stmt(n); - } -} - Stmt StmtMutator::VisitStmt_(const DeclBufferNode* op) { Stmt body = this->VisitStmt(op->body); diff --git a/src/tir/ir/tir_visitor_with_path.cc b/src/tir/ir/tir_visitor_with_path.cc index 712d4d88eb..ba79ab856d 100644 --- a/src/tir/ir/tir_visitor_with_path.cc +++ b/src/tir/ir/tir_visitor_with_path.cc @@ -232,12 +232,6 @@ void TIRVisitorWithPath::VisitStmt_(const AllocateNode* op, AccessPath path) { Visit(op->body, path->Attr("body")); } -void 
TIRVisitorWithPath::VisitStmt_(const AllocateConstNode* op, AccessPath path) { - Visit(op->extents, path->Attr("extents")); - auto context = WithDef(op->buffer_var, path->Attr("buffer_var")); - Visit(op->body, path->Attr("body")); -} - void TIRVisitorWithPath::VisitStmt_(const DeclBufferNode* op, AccessPath path) { auto context = WithDef(op->buffer, path->Attr("buffer")); Visit(op->body, path->Attr("body")); diff --git a/src/tir/ir/tir_visitor_with_path.h b/src/tir/ir/tir_visitor_with_path.h index 1409fb39f5..df28015495 100644 --- a/src/tir/ir/tir_visitor_with_path.h +++ b/src/tir/ir/tir_visitor_with_path.h @@ -106,7 +106,6 @@ class TIRVisitorWithPath void VisitStmt_(const ForNode* op, ffi::reflection::AccessPath path) override; void VisitStmt_(const WhileNode* op, ffi::reflection::AccessPath path) override; void VisitStmt_(const AllocateNode* op, ffi::reflection::AccessPath path) override; - void VisitStmt_(const AllocateConstNode* op, ffi::reflection::AccessPath path) override; void VisitStmt_(const DeclBufferNode* op, ffi::reflection::AccessPath path) override; void VisitStmt_(const BufferStoreNode* op, ffi::reflection::AccessPath path) override; void VisitStmt_(const BufferRealizeNode* op, ffi::reflection::AccessPath path) override; diff --git a/src/tir/transform/bind_params.cc b/src/tir/transform/bind_params.cc deleted file mode 100644 index d62f21be1f..0000000000 --- a/src/tir/transform/bind_params.cc +++ /dev/null @@ -1,138 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file storage_rewrite.cc - * \brief Memory access pattern analysis and optimization. - * Re-write data access to enable memory sharing when possible. 
- */ -#include <tvm/arith/analyzer.h> -#include <tvm/ffi/function.h> -#include <tvm/ir/type.h> -#include <tvm/target/target_info.h> -#include <tvm/tir/analysis.h> -#include <tvm/tir/builtin.h> -#include <tvm/tir/expr.h> -#include <tvm/tir/function.h> -#include <tvm/tir/stmt_functor.h> -#include <tvm/tir/transform.h> - -#include "ir_utils.h" - -namespace tvm { -namespace tir { - -class ParamsCollector : public StmtExprVisitor { - public: - explicit ParamsCollector(const ffi::Map<tir::Var, runtime::Tensor>& constant_map) - : constant_map_(constant_map) {} - std::vector<const tir::VarNode*> CollectParams(tir::Stmt body) { - this->VisitStmt(body); - return constant_list_; - } - - void VisitExpr_(const BufferLoadNode* ln) { - if (constant_map_.find(ln->buffer->data) != constant_map_.end()) { - auto it = std::find(constant_list_.begin(), constant_list_.end(), ln->buffer->data.get()); - if (it == constant_list_.end()) { - constant_list_.push_back(ln->buffer->data.get()); - } - } - StmtExprVisitor::VisitExpr_(ln); - } - - void VisitExpr_(const CallNode* cn) { - if (cn->op.same_as(builtin::tvm_access_ptr())) { - ICHECK_EQ(cn->args.size(), 5U); - const Var& var = Downcast<Var>(cn->args[1]); - const VarNode* buffer = cn->args[1].as<VarNode>(); - auto it = constant_map_.find(var); - if (it != constant_map_.end()) { - auto it = std::find(constant_list_.begin(), constant_list_.end(), buffer); - if (it == constant_list_.end()) { - constant_list_.push_back(buffer); - } - } - } - StmtExprVisitor::VisitExpr_(cn); - } - - private: - std::vector<const tir::VarNode*> constant_list_; - ffi::Map<tir::Var, runtime::Tensor> constant_map_; -}; - -PrimFunc BindParams(PrimFunc f, const ffi::Array<runtime::Tensor>& constants) { - ffi::Map<tir::Var, runtime::Tensor> constant_map; - - // Remove constants from the primfunc signature - size_t num_constants = constants.size(); - size_t start = f->params.size() - num_constants; - ffi::Array<tir::Var> params; - for (unsigned i = 0; i < start; i++) { - params.push_back(f->params[i]); - } - - auto* n = f.CopyOnWrite(); - for (unsigned i = start; i < f->params.size(); i++) { - tir::Var p = n->params[i]; - tir::Var b = n->buffer_map[p]->data; - n->buffer_map.erase(p); - constant_map.Set(b, constants[i - start]); - } - n->params = params; - auto constant_list = ParamsCollector(constant_map).CollectParams(n->body); - - // Allocate constants within the primfunc - for (auto i : constant_list) { - auto var = ffi::GetRef<Var>(i); - int ndim = constant_map[var]->ndim; - ffi::Array<PrimExpr> extents; - - for (int i = 0; i < ndim; i++) { - int shape = constant_map[var]->shape[i]; - extents.push_back(make_const(DataType::Int(32), shape)); - } - DataType dtype = DataType(constant_map[var]->dtype); - - if (n->body->IsInstance<SBlockRealizeNode>()) { - auto* block_realize = n->body.as<SBlockRealizeNode>(); - auto block = block_realize->block; - block.CopyOnWrite()->body = - tir::AllocateConst(var, dtype, extents, constant_map[var], block->body); - n->body = SBlockRealize(block_realize->iter_values, block_realize->predicate, block); - } else { - n->body = tir::AllocateConst(var, dtype, extents, constant_map[var], n->body); - } - } - return f; -} - -namespace transform { - -Pass BindParams(const ffi::Array<runtime::Tensor>& constants) { - auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { - return BindParams(f, constants); - }; - return CreatePrimFuncPass(pass_func, 0, "tir.BindParams", {}); -} -} // namespace transform - -} // namespace tir -} // namespace tvm diff --git 
a/src/tir/transform/extract_constants.cc b/src/tir/transform/extract_constants.cc deleted file mode 100644 index be5da45d9f..0000000000 --- a/src/tir/transform/extract_constants.cc +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file extract_constants.cc - * \brief Collects PrimFunc's constant data into the module's 'tvm::attr::kConstants' attrs array, - * and sets irmod_storage_idx as the index in this array. - * For more information, see the RFC: - * https://github.com/apache/tvm-rfcs/blob/main/rfcs/0022-tir-non-scalar-constants.md - */ -#include <tvm/arith/analyzer.h> -#include <tvm/ffi/function.h> -#include <tvm/ffi/reflection/registry.h> -#include <tvm/ir/transform.h> -#include <tvm/node/structural_equal.h> -#include <tvm/tir/stmt_functor.h> - -#include "ir_utils.h" - -namespace tvm { -namespace tir { - -using ConstArrayType = ffi::Array<runtime::Tensor>; -class Applicator : public tir::StmtMutator { - protected: - // Returns the index of `a` in constant_array_, appending it if not found. - size_t DeDup(const runtime::Tensor& a) { - tvm::StructuralEqual eql; - auto it = std::find_if(constant_array_.begin(), constant_array_.end(), - [&eql, a](const runtime::Tensor& v) { return eql(a, v); }); - if (it != constant_array_.end()) { - return it - constant_array_.begin(); - } - constant_array_.push_back(std::move(a)); - return constant_array_.size() - 1; - } - - public: - Stmt Apply(tir::Stmt body, const ConstArrayType& constant_array) { - constant_array_ = constant_array; - return this->VisitStmt(body); - } - - Stmt VisitStmt_(const tir::AllocateConstNode* acn) override { - // Check whether the data is already defined within the module's attrs - // and add the array index. - ICHECK(acn->data) << "data field should be defined"; - auto node = CopyOnWrite(acn); - node->irmod_storage_idx = ffi::Optional<Integer>(Integer(DeDup(node->data.value()))); - return Stmt(node); - } - - ConstArrayType constant_array_; -}; - -namespace transform { - -tvm::transform::Pass ExtractPrimFuncConstants() { - auto prim_func_pass = [=](PrimFunc foo, IRModule m, tvm::transform::PassContext ctx) { - auto* func = foo.CopyOnWrite(); - if (!m->attrs.defined()) { - m->attrs = DictAttrs(ffi::Map<ffi::String, ffi::Any>()); - } - auto* attrs = m->attrs.CopyOnWrite(); - ConstArrayType constant_array_ = - (attrs->dict.count(tvm::attr::kConstants)) - ? 
Downcast<ConstArrayType>(attrs->dict[tvm::attr::kConstants]) - : ConstArrayType(); - Applicator a = Applicator(); - func->body = a.Apply(func->body, constant_array_); - const ConstArrayType constant_list = a.constant_array_; - if (constant_list.size()) { - attrs->dict.Set(tvm::attr::kConstants, constant_list); - } - return ffi::GetRef<PrimFunc>(func); - }; - - auto pass_func = [=](IRModule module, tvm::transform::PassContext pc) { - auto m = ffi::GetRef<IRModule>(module.CopyOnWrite()); - for (const auto& kv : m->functions) { - if (auto func = kv.second.as<PrimFunc>()) { - m->Update(kv.first, prim_func_pass(func.value(), m, pc)); - } - } - return m; - }; - - return tvm::transform::CreateModulePass(pass_func, 0, "tir.ExtractPrimFuncConstants", {}); -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.transform.ExtractPrimFuncConstants", ExtractPrimFuncConstants); -} - -} // namespace transform - -} // namespace tir -} // namespace tvm diff --git a/src/tir/transform/ir_utils.cc b/src/tir/transform/ir_utils.cc index 0d7a217a0a..15917d891e 100644 --- a/src/tir/transform/ir_utils.cc +++ b/src/tir/transform/ir_utils.cc @@ -76,11 +76,6 @@ Stmt MergeNest(const std::vector<Stmt>& nest, Stmt body) { ICHECK(is_no_op(n->body)); n->body = body; body = Stmt(n); - } else if (const auto* alloc = s.as<AllocateConstNode>()) { - auto n = ffi::make_object<AllocateConstNode>(*alloc); - ICHECK(is_no_op(n->body)); - n->body = body; - body = Stmt(n); } else if (const auto* decl_buffer = s.as<DeclBufferNode>()) { auto n = ffi::make_object<DeclBufferNode>(*decl_buffer); ICHECK(is_no_op(n->body)); diff --git a/src/tir/transform/ir_utils.h b/src/tir/transform/ir_utils.h index 0616e23e7e..c8d72e3b14 100644 --- a/src/tir/transform/ir_utils.h +++ b/src/tir/transform/ir_utils.h @@ -314,16 +314,6 @@ std::unordered_map<const VarNode*, FragmentInfo> GetTensorCoreFragmentInfo(const // attr::async_wait_queue_scope annotation. std::pair<PrimExpr, PrimExpr> GetAsyncWaitAttributes(const AttrStmtNode* op); -/*! - * \brief Bind a subset of parameter tensors to constants, replacing them with AllocateConst nodes. - * \param f The function to bind constants to. - * \param constants Raw constant data. If the size of this array is N, the last N parameter tensors - * will be removed from the signature and instead AllocateConst nodes will be introduced in the - * function body. - * \return The updated function. - */ -PrimFunc BindParams(PrimFunc f, const ffi::Array<runtime::Tensor>& constants); - /*! \brief The quad used by StorageAlign for (buffer_idx, axis, factor, offset) */ using StorageAlignTuple = ffi::Tuple<int32_t, int32_t, int32_t, int32_t>; /*! \brief A list of StorageAlignTuple, used by StorageAlign */ diff --git a/src/tir/transform/remove_weight_layout_rewrite_block.cc b/src/tir/transform/remove_weight_layout_rewrite_block.cc index 86f1ed6400..1ca0dc21b8 100644 --- a/src/tir/transform/remove_weight_layout_rewrite_block.cc +++ b/src/tir/transform/remove_weight_layout_rewrite_block.cc @@ -116,154 +116,14 @@ class RemoveLayoutRewriteBlock : public StmtMutator { std::unordered_map<const VarNode*, ffi::Array<PrimExpr>> buffer_var_to_rewritten_shape_; }; -// After RemoveLayoutRewriteBlock, the body of a compute update block references a -// non-existent buffer. 
For example, fused_constant_2_global below is originally a -// cache_read buffer, whose allocation is removed by RemoveLayoutRewriteBlock: -// -// constant fused_constant_2[float32 * 3 * 3 * 64 * 64] -// conv2d_nhwc[nn, yy, xx, ff] += ... * fused_constant_2_global[ry, -// floordiv(rc, 32), -// floordiv(ff, 16), -// rx, -// floormod(rc, 32), -// floormod(ff, 16)])) -// -// When cache_read is reading from an AllocateConst, we need to replace the reference -// to fused_constant_2_global with the corresponding transformed AllocateConst. -// To do that, we manually rewrite the original constant using the associated index map, -// and let the body of the compute block load from the rewritten constant. -// -// After this transformation, the example above looks like: -// -// constant fused_constant_2[float32 * 3 * 2 * 4 * 3 * 32 * 16] -// conv2d_nhwc[nn, yy, xx, ff] += ... * fused_constant_2[ry, -// floordiv(rc, 32), -// floordiv(ff, 16), -// rx, -// floormod(rc, 32), -// floormod(ff, 16)])) - -using BufferVarMap = std::unordered_map<const tir::VarNode*, const tir::VarNode*>; - -class AllocateConstRewrite : public StmtExprMutator { - public: - AllocateConstRewrite( - const BufferVarMap& buffer_var_map, - const std::unordered_map<const VarNode*, IndexMap>& buffer_var_to_index_map, - const std::unordered_map<const VarNode*, ffi::Array<PrimExpr>>& buffer_var_to_rewritten_shape, - bool skip_tensor_rewrite) - : buffer_var_map_(buffer_var_map), - buffer_var_to_index_map_(buffer_var_to_index_map), - buffer_var_to_rewritten_shape_(buffer_var_to_rewritten_shape), - skip_tensor_rewrite_(skip_tensor_rewrite) {} - - private: - Stmt VisitStmt_(const SBlockNode* op) final { - SBlock block = Downcast<SBlock>(StmtMutator::VisitStmt_(op)); - auto n = CopyOnWrite(block.get()); - ffi::Array<BufferRegion> new_reads; - for (auto read_region : op->reads) { - if (auto it = new_load_buf_.find(read_region->buffer->data.get()); - it != new_load_buf_.end()) { - new_reads.push_back(BufferRegion(it->second, read_region->region)); - } else { - new_reads.push_back(read_region); - } - } - n->reads = new_reads; - return Stmt(n); - } - - Stmt VisitStmt_(const AllocateConstNode* alloc) final { - if (auto it = buffer_var_to_index_map_.find(alloc->buffer_var.get()); - it != buffer_var_to_index_map_.end()) { - ICHECK(buffer_var_to_rewritten_shape_.count(alloc->buffer_var.get())); - auto new_body = StmtMutator::VisitStmt(alloc->body); - auto rewritten_tensor = RewriteTensor( - alloc->data.value(), it->second, buffer_var_to_rewritten_shape_[alloc->buffer_var.get()]); - ffi::Array<PrimExpr> rewritten_extents; - for (auto s : rewritten_tensor.Shape()) { - rewritten_extents.push_back(PrimExpr(static_cast<int>(s))); - } - return AllocateConst(alloc->buffer_var, alloc->dtype, rewritten_extents, rewritten_tensor, - new_body, alloc->annotations, alloc->span); - } - return StmtMutator::VisitStmt_(alloc); - } - - PrimExpr VisitExpr_(const BufferLoadNode* op) final { - if (auto it = buffer_var_map_.find(op->buffer->data.get()); it != buffer_var_map_.end()) { - auto new_buffer = - Buffer(ffi::GetRef<Var>(it->second), op->buffer->dtype, op->buffer->shape, - op->buffer->strides, op->buffer->elem_offset, it->second->name_hint, - op->buffer->data_alignment, op->buffer->offset_factor, op->buffer->buffer_type); - new_load_buf_[op->buffer->data.get()] = new_buffer; - return BufferLoad(new_buffer, op->indices, op->predicate); - } - return ExprMutator::VisitExpr_(op); - } - - runtime::Tensor RewriteTensor(runtime::Tensor src, const IndexMap& 
index_map, - const ffi::Array<PrimExpr>& dst_shape) { - if (skip_tensor_rewrite_) { - // Only the shape of the destination array needs to be correct. - std::vector<int64_t> dst_shape_int; - for (auto s : dst_shape) { - ICHECK(s->IsInstance<IntImmNode>()); - dst_shape_int.push_back(s.as<IntImmNode>()->value); - } - return src.CreateView(dst_shape_int, src.DataType()); - } else { - return index_map->MapTensor(src); - } - } - - /*! \brief Maps a buffer store to a load in a layout rewrite block */ - BufferVarMap buffer_var_map_; - /*! \brief Maps a buffer load to an index map associated with the load / store - in a layout rewrite block. */ - std::unordered_map<const VarNode*, IndexMap> buffer_var_to_index_map_; - /*! \brief Maps a buffer load to the shape of the corresponding rewritten buffer. */ - std::unordered_map<const VarNode*, ffi::Array<PrimExpr>> buffer_var_to_rewritten_shape_; - /*! \brief Maps load buffer variables to newly created buffers */ - std::unordered_map<const VarNode*, Buffer> new_load_buf_; - /*! \brief Whether or not to skip rewriting of Tensor contents */ - bool skip_tensor_rewrite_; -}; - -class CollectAllocateConstBufferVars : public StmtVisitor { - public: - void VisitStmt_(const AllocateConstNode* alloc) final { - StmtVisitor::VisitStmt_(alloc); - constant_buf_var.insert(alloc->buffer_var.get()); - } - - std::unordered_set<const VarNode*> constant_buf_var; -}; - class WeightLayoutRewriteBlockRemover : public StmtMutator { public: static PrimFunc Remove(PrimFunc f, bool skip_tensor_rewrite) { - CollectAllocateConstBufferVars collector; - collector(f->body); - auto [f_, buf_map, buffer_var_to_index_map, buffer_var_to_rewritten_shape] = RemoveLayoutRewriteBlock().Rewrite(f); - BufferVarMap buffer_var_map; - for (const auto& [load_buf, store_buf] : buf_map) { - if (collector.constant_buf_var.find(load_buf->data.get()) != - collector.constant_buf_var.end()) { - buffer_var_map[store_buf->data.get()] = load_buf->data.get(); - } - } - PrimFuncNode* n = f_.CopyOnWrite(); - AllocateConstRewrite rewriter(buffer_var_map, buffer_var_to_index_map, - buffer_var_to_rewritten_shape, skip_tensor_rewrite); - n->body = rewriter(std::move(n->body)); - ffi::Map<tir::Var, Buffer> buffer_map; for (const auto& [param, buffer] : f_->buffer_map) { auto it = buf_map.find(buffer); diff --git a/src/tir/transform/renew_defs.cc b/src/tir/transform/renew_defs.cc index ee72245184..f50eb49832 100644 --- a/src/tir/transform/renew_defs.cc +++ b/src/tir/transform/renew_defs.cc @@ -100,7 +100,6 @@ class RenewDefMutator : public StmtExprMutator { private: STMT_REGENERATE_VAR_DEF(LetStmtNode, var); STMT_REGENERATE_VAR_DEF(AllocateNode, buffer_var); - STMT_REGENERATE_VAR_DEF(AllocateConstNode, buffer_var); STMT_REGENERATE_VAR_DEF(ForNode, loop_var); Stmt VisitStmt_(const SBlockNode* op) final { diff --git a/src/tir/transform/storage_rewrite.cc b/src/tir/transform/storage_rewrite.cc index 830364788c..cb9868d383 100644 --- a/src/tir/transform/storage_rewrite.cc +++ b/src/tir/transform/storage_rewrite.cc @@ -1053,8 +1053,7 @@ struct BufferVarInfo { kPrimFuncParam = (1 << 0), kPrimFuncBufferMap = (1 << 1), kAllocateNode = (1 << 2), - kAllocateConstNode = (1 << 3), - kLetNode = (1 << 4), + kLetNode = (1 << 3), }; // The tir::Var that represents this buffer. @@ -1204,14 +1203,6 @@ class VectorTypeAccessChecker : public StmtExprVisitor { StmtExprVisitor::VisitStmt_(op); } - void VisitStmt_(const AllocateConstNode* op) final { - const ffi::Array<PrimExpr>& extents = op->extents; - PrimExpr extent = extents.size() ? 
extents[extents.size() - 1] : NullValue<PrimExpr>(); - OnArrayDeclaration(op->buffer_var, op->dtype, extent, BufferVarInfo::kAllocateConstNode); - - StmtExprVisitor::VisitStmt_(op); - } - void VisitExpr_(const LetNode* op) final { HandleLetNode(op->var); StmtExprVisitor::VisitExpr_(op); @@ -1418,7 +1409,7 @@ class VectorTypeRewriter : public StmtExprMutator { VectorTypeRewriter(const std::unordered_map<const VarNode*, BufferVarInfo>& info_map, bool rewrite_params = true, bool rewrite_buffer_map = true, bool rewrite_allocate_node = true, bool rewrite_indices = true, - bool rewrite_let_node = true, bool rewrite_allocate_const_node = true, + bool rewrite_let_node = true, bool rewrite_scalar_read_to_vector_shuffle = true) : rewrite_indices_(rewrite_indices) { int rewrite_mask = 0; @@ -1434,10 +1425,6 @@ class VectorTypeRewriter : public StmtExprMutator { if (rewrite_let_node) { rewrite_mask |= BufferVarInfo::kLetNode; } - if (rewrite_allocate_const_node) { - rewrite_mask |= BufferVarInfo::kAllocateConstNode; - } - // Rewrite any buffer variables whose preferred type isn't their current type. for (const auto& pair : info_map) { const auto& var_info = pair.second; @@ -1619,27 +1606,6 @@ class VectorTypeRewriter : public StmtExprMutator { return Allocate(new_buffer_var, info.new_element_dtype, extents, op->condition, op->body); } - Stmt VisitStmt_(const AllocateConstNode* op) final { - Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as<AllocateConstNode>(); - - auto it = rewrite_map_.find(op->buffer_var.get()); - if (it == rewrite_map_.end()) { - return stmt; - } - - const auto& info = it->second; - - Var new_buffer_var = info.new_buffer_var; - - int factor = info.new_element_dtype.lanes() / op->dtype.lanes(); - - ffi::Array<PrimExpr> extents = op->extents; - extents.Set(extents.size() - 1, - extents[extents.size() - 1] / make_const(extents[0].dtype(), factor)); - return AllocateConst(new_buffer_var, info.new_element_dtype, extents, op->data, op->body); - } - /* Update the parameters and all remaining variable references * * Should be called after calling operator() on the body of the @@ -1713,7 +1679,6 @@ PrimFunc PointerValueTypeRewrite(PrimFunc f, bool allow_untyped_pointers = false bool rewrite_params = true, bool rewrite_buffer_map = true, bool rewrite_allocate_node = true, bool rewrite_indices = true, bool rewrite_let_node = true, - bool rewrite_allocate_const_node = true, bool rewrite_scalar_read_to_vector_shuffle = true) { VectorTypeAccessChecker checker(f->params, f->buffer_map, allow_untyped_pointers, rewrite_scalar_read_to_vector_shuffle); @@ -1721,7 +1686,7 @@ PrimFunc PointerValueTypeRewrite(PrimFunc f, bool allow_untyped_pointers = false VectorTypeRewriter rewriter(checker.info_map_, rewrite_params, rewrite_buffer_map, rewrite_allocate_node, rewrite_indices, rewrite_let_node, - rewrite_allocate_const_node, rewrite_scalar_read_to_vector_shuffle); + rewrite_scalar_read_to_vector_shuffle); PrimFuncNode* n = f.CopyOnWrite(); n->body = rewriter(std::move(n->body)); rewriter.Finalize(&f); @@ -1753,13 +1718,7 @@ Pass StorageRewrite() { n->body = StoragePlanRewriter().Rewrite(std::move(n->body), true, enable_reuse, reuse_require_exact_matched_dtype); // Parameters may not be rewritten, but internal allocations may. - // Vectorization of AllocateConst is currently disabled, as indexing - // for types that include padding (e.g. 
int8x3 - // padded out to 32 bits) would require either rewriting - // AllocateConst::data, or would require the code generators to - // handle vectorized constants. - return PointerValueTypeRewrite(std::move(f), true, false, false, true, true, true, false, - false); + return PointerValueTypeRewrite(std::move(f), true, false, false, true, true, true, false); }; return CreatePrimFuncPass(pass_func, 0, "tir.StorageRewrite", {}); } diff --git a/tests/python/ir/test_node_reflection.py b/tests/python/ir/test_node_reflection.py index 52b2a29f59..350d59de8c 100644 --- a/tests/python/ir/test_node_reflection.py +++ b/tests/python/ir/test_node_reflection.py @@ -190,19 +190,5 @@ def test_free_var_equal(): tvm.ir.assert_structural_equal(x, z, map_free_vars=True) -def test_alloc_const(): - dev = tvm.cpu(0) - dtype = "float32" - shape = (16,) - buf = tvm.tir.decl_buffer(shape, dtype) - np_data = np.random.rand(*shape).astype(dtype) - data = tvm.runtime.tensor(np_data, device=dev) - body = tvm.tir.Evaluate(0) - alloc_const = tvm.tir.AllocateConst(buf.data, dtype, shape, data, body) - alloc_const2 = tvm.ir.load_json(tvm.ir.save_json(alloc_const)) - tvm.ir.assert_structural_equal(alloc_const, alloc_const2) - np.testing.assert_array_equal(np_data, alloc_const2.data.numpy()) - - if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_meta_schedule_relax_integration.py b/tests/python/relax/test_meta_schedule_relax_integration.py index 4d27aceed0..3d5a29dce1 100644 --- a/tests/python/relax/test_meta_schedule_relax_integration.py +++ b/tests/python/relax/test_meta_schedule_relax_integration.py @@ -16,15 +16,10 @@ # under the License. """Integration test for MetaSchedule""" -import numpy as np -import pytest -import tempfile import tvm import tvm.testing -from tvm import IRModule from tvm.s_tir import meta_schedule as ms -from tvm import relax, tir -from tvm.ir import transform +from tvm import relax from tvm.script import ir as I from tvm.script import tir as T @@ -51,96 +46,6 @@ class Module0: # fmt: on -# fmt: off [email protected]_module -class Module: - @T.prim_func(private=True) - def conv2d(rxplaceholder: T.Buffer((T.int64(1), T.int64(8), T.int64(8), T.int64(4)), "int32"), DepthwiseConv2d: T.Buffer((T.int64(1), T.int64(8), T.int64(8), T.int64(4)), "int32")): - T.func_attr({"op_pattern": 4, "tir.noalias": True}) - # with T.sblock("root"): - PaddedInput = T.alloc_buffer((T.int64(1), T.int64(10), T.int64(10), T.int64(4)), "int32") - fused_constant = T.allocate_const([-171701247, -1719837685, 1801664104, -634316588, 920159370, -132073802, 2142531563, 1465185701, -1505608067, 1737948828, 1581089391, -1986167320, -1449581822, 35714587, 496324563, -1430879015, -1615680873, 1198514997, 1494683955, 1567376558, 1319924884, -380548171, 296785437, -1546305981, -398644701, -2004794585, -1850413687, 2072643657, 847950121, -544212073, -199532669, -343273682, 953721562, -1930209358, 1573600108, -577689853], "int32", [3, [...] 
- for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(10), T.int64(10), T.int64(4)): - with T.sblock("PaddedInput"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(rxplaceholder[v_i0, v_i1 - T.int64(1), v_i2 - T.int64(1), v_i3]) - T.writes(PaddedInput[v_i0, v_i1, v_i2, v_i3]) - PaddedInput[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(T.int64(1) <= v_i1 and v_i1 < T.int64(9) and T.int64(1) <= v_i2 and v_i2 < T.int64(9), rxplaceholder[v_i0, v_i1 - T.int64(1), v_i2 - T.int64(1), v_i3], 0) - for b, i, j, c, di, dj in T.grid(T.int64(1), T.int64(8), T.int64(8), T.int64(4), T.int64(3), T.int64(3)): - with T.sblock("DepthwiseConv2d"): - v_b, v_i, v_j, v_c, v_di, v_dj = T.axis.remap("SSSSRR", [b, i, j, c, di, dj]) - fused_constant_1 = T.Buffer((3, 3, 4, 1), "int32", data=fused_constant) - T.reads(PaddedInput[v_b, v_i + v_di, v_j + v_dj, v_c], fused_constant_1[v_di, v_dj, v_c, T.int64(0)]) - T.writes(DepthwiseConv2d[v_b, v_i, v_j, v_c]) - with T.init(): - DepthwiseConv2d[v_b, v_i, v_j, v_c] = 0 - DepthwiseConv2d[v_b, v_i, v_j, v_c] = DepthwiseConv2d[v_b, v_i, v_j, v_c] + PaddedInput[v_b, v_i + v_di, v_j + v_dj, v_c] * fused_constant_1[v_di, v_dj, v_c, T.int64(0)] - - @T.prim_func(private=True) - def conv2d0(rxplaceholder0: T.Buffer((T.int64(1), T.int64(8), T.int64(8), T.int64(4)), "int32"), DepthwiseConv2d0: T.Buffer((T.int64(1), T.int64(8), T.int64(8), T.int64(4)), "int32")): - T.func_attr({"op_pattern": 4, "tir.noalias": True}) - # with T.sblock("root"): - PaddedInput0 = T.alloc_buffer((T.int64(1), T.int64(10), T.int64(10), T.int64(4)), "int32") - fused_constant0 = T.allocate_const([2042349344, -2076067063, 1528163722, -1156452837, -2097172051, 1137787079, -601389657, 1907495997, 987801941, 1073738593, -1410339796, -689755358, 90351522, -44886952, -1914103775, -691553659, -1288505112, -1376578817, -2067933148, -1413101824, 1261422027, -156976862, -1185734459, 1608778622, -664209483, 1907479806, 1838595152, 464942526, 877953160, 415131837, -2010736511, 1218242769, -1440127632, 112931, 521745784, -1931145893], "int32", [3, 3 [...] 
- for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(10), T.int64(10), T.int64(4)): - with T.sblock("PaddedInput"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(rxplaceholder0[v_i0, v_i1 - T.int64(1), v_i2 - T.int64(1), v_i3]) - T.writes(PaddedInput0[v_i0, v_i1, v_i2, v_i3]) - PaddedInput0[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(T.int64(1) <= v_i1 and v_i1 < T.int64(9) and T.int64(1) <= v_i2 and v_i2 < T.int64(9), rxplaceholder0[v_i0, v_i1 - T.int64(1), v_i2 - T.int64(1), v_i3], 0) - for b, i, j, c, di, dj in T.grid(T.int64(1), T.int64(8), T.int64(8), T.int64(4), T.int64(3), T.int64(3)): - with T.sblock("DepthwiseConv2d"): - v_b, v_i, v_j, v_c, v_di, v_dj = T.axis.remap("SSSSRR", [b, i, j, c, di, dj]) - fused_constant0_1 = T.Buffer((3, 3, 4, 1), "int32", data=fused_constant0) - T.reads(PaddedInput0[v_b, v_i + v_di, v_j + v_dj, v_c], fused_constant0_1[v_di, v_dj, v_c, T.int64(0)]) - T.writes(DepthwiseConv2d0[v_b, v_i, v_j, v_c]) - with T.init(): - DepthwiseConv2d0[v_b, v_i, v_j, v_c] = 0 - DepthwiseConv2d0[v_b, v_i, v_j, v_c] = DepthwiseConv2d0[v_b, v_i, v_j, v_c] + PaddedInput0[v_b, v_i + v_di, v_j + v_dj, v_c] * fused_constant0_1[v_di, v_dj, v_c, T.int64(0)] - - @T.prim_func(private=True) - def fused_conv2d_add(data: T.Buffer((T.int64(1), T.int64(8), T.int64(8), T.int64(4)), "int32"), T_add: T.Buffer((T.int64(1), T.int64(8), T.int64(8), T.int64(4)), "int32")): - T.func_attr({"tir.noalias": True}) - # with T.sblock("root"): - PaddedInput = T.alloc_buffer((T.int64(1), T.int64(10), T.int64(10), T.int64(4)), "int32") - DepthwiseConv2d = T.alloc_buffer((T.int64(1), T.int64(8), T.int64(8), T.int64(4)), "int32") - fused_nn_conv2d_constant = T.allocate_const([1, 1, 1, 1], "int32", [1, 1, 1, 4]) - fused_constant_2 = T.allocate_const([687940110, -910571705, -901609800, -500525928, 506872399, 1070176297, -305936110, 1625439784, -1565626954, -1705688881, -866370805, -1750740826, 300497007, -626864803, 390295545, 222549121, 319224543, -2003064970, 657992492, 2014175448, 653278589, -768810984, -294555581, -1197167662, 1703154671, -1540759805, -568817430, -1729755444, -275458074, 2078945571, 1683298006, -1029327874, 1315093181, 159010501, 875694807, -223655381], "int32", [3, 3, 4, 1]) - for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(10), T.int64(10), T.int64(4)): - with T.sblock("PaddedInput"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(data[v_i0, v_i1 - T.int64(1), v_i2 - T.int64(1), v_i3]) - T.writes(PaddedInput[v_i0, v_i1, v_i2, v_i3]) - PaddedInput[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(T.int64(1) <= v_i1 and v_i1 < T.int64(9) and T.int64(1) <= v_i2 and v_i2 < T.int64(9), data[v_i0, v_i1 - T.int64(1), v_i2 - T.int64(1), v_i3], 0) - for b, i, j, c, di, dj in T.grid(T.int64(1), T.int64(8), T.int64(8), T.int64(4), T.int64(3), T.int64(3)): - with T.sblock("DepthwiseConv2d"): - v_b, v_i, v_j, v_c, v_di, v_dj = T.axis.remap("SSSSRR", [b, i, j, c, di, dj]) - fused_constant_2_1 = T.Buffer((3, 3, 4, 1), "int32", data=fused_constant_2) - T.reads(PaddedInput[v_b, v_i + v_di, v_j + v_dj, v_c], fused_constant_2_1[v_di, v_dj, v_c, T.int64(0)]) - T.writes(DepthwiseConv2d[v_b, v_i, v_j, v_c]) - with T.init(): - DepthwiseConv2d[v_b, v_i, v_j, v_c] = 0 - DepthwiseConv2d[v_b, v_i, v_j, v_c] = DepthwiseConv2d[v_b, v_i, v_j, v_c] + PaddedInput[v_b, v_i + v_di, v_j + v_dj, v_c] * fused_constant_2_1[v_di, v_dj, v_c, T.int64(0)] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(8), T.int64(8), T.int64(4)): - with T.sblock("T_add"): - v_ax0, v_ax1, 
v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - fused_nn_conv2d_constant_1 = T.Buffer((1, 1, 1, 4), "int32", data=fused_nn_conv2d_constant) - T.reads(DepthwiseConv2d[v_ax0, v_ax1, v_ax2, v_ax3], fused_nn_conv2d_constant_1[v_ax0, T.int64(0), T.int64(0), v_ax3]) - T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3]) - T_add[v_ax0, v_ax1, v_ax2, v_ax3] = DepthwiseConv2d[v_ax0, v_ax1, v_ax2, v_ax3] + fused_nn_conv2d_constant_1[v_ax0, T.int64(0), T.int64(0), v_ax3] - - @R.function - def main(data: R.Tensor((1, 8, 8, 4), dtype="int32")) -> R.Tensor((1, 8, 8, 4), dtype="int32"): - cls = Module - with R.dataflow(): - lv = R.call_tir(cls.fused_conv2d_add, data, out_sinfo=R.Tensor((1, 8, 8, 4), dtype="int32")) - lv2 = R.call_tir(cls.conv2d, lv, out_sinfo=R.Tensor((1, 8, 8, 4), dtype="int32")) - lv3 = R.call_tir(cls.conv2d0, lv2, out_sinfo=R.Tensor((1, 8, 8, 4), dtype="int32")) - gv: R.Tensor((1, 8, 8, 4), dtype="int32") = lv3 - R.output(gv) - return gv -# fmt: on - def test_extracting_tasks(): target = "llvm -mcpu=core-avx2 -num-cores=1" @@ -166,45 +71,6 @@ def test_extracting_tasks(): ) assert len(extracted_tasks) == count - tir_relax_mod = Module - tir_relax_expectation = {"structural": 3, "ignore-tensor": 2, "anchor-block": 1} - for module_equality, count in tir_relax_expectation.items(): - extracted_tasks = ms.relax_integration.extract_tasks( - tir_relax_mod, - target, - {}, - module_equality=module_equality, - ) - assert len(extracted_tasks) == count - - [email protected]("module_equality", ["structural", "ignore-tensor", "anchor-block"]) -def test_using_anchor_trace(module_equality): - relax_mod = Module - target = "llvm -mcpu=core-avx2 -num-cores=1" - - with tempfile.TemporaryDirectory() as work_dir: - database = ms.relax_integration.tune_relax( - mod=relax_mod, - params={}, - target=target, - work_dir=work_dir, - # for faster tuning - max_trials_global=100, - max_trials_per_task=4, - num_trials_per_iter=4, - strategy="replay-trace", - module_equality=module_equality, - seed=0, - ) - - ms.relax_integration.compile_relax( - database, - mod=relax_mod, - target=target, - params={}, - ) - if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/s_tir/schedule/test_tir_schedule_cache_read_write.py b/tests/python/s_tir/schedule/test_tir_schedule_cache_read_write.py index cf4efa1964..8b567a8a3f 100644 --- a/tests/python/s_tir/schedule/test_tir_schedule_cache_read_write.py +++ b/tests/python/s_tir/schedule/test_tir_schedule_cache_read_write.py @@ -1420,46 +1420,6 @@ def test_cache_read_fail_invalid_storage_scope(use_block_name): sch.cache_read(block_b, 0, "test_scope") -def test_cache_read_allocate_const(): - @T.prim_func - def before(A: T.Buffer((8), "float32"), C: T.Buffer((8), "float32")): - B = T.allocate_const([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8]) - B_buf = T.decl_buffer((8), dtype="float32", data=B) - for i in range(8): - with T.sblock("C"): - vi = T.axis.spatial(8, i) - C[vi] = A[vi] + B_buf[vi] - - @T.prim_func - def expected(A: T.Buffer((8), "float32"), C: T.Buffer((8), "float32")): - B_buf_global = T.alloc_buffer((8), dtype="float32") - A_global = T.alloc_buffer((8), dtype="float32") - B = T.allocate_const([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8]) - B_buf = T.decl_buffer((8), data=B) - for ax0 in range(8): - with T.sblock("A_global"): - v0 = T.axis.spatial(8, ax0) - A_global[v0] = A[v0] - for ax0 in range(8): - with T.sblock("B_buf_global"): - v0 = T.axis.spatial(8, ax0) - B_buf_global[v0] = B_buf[v0] - for i in range(8): - with 
T.sblock("C"): - vi = T.axis.spatial(8, i) - C[vi] = A_global[vi] + B_buf_global[vi] - - sch = tvm.s_tir.Schedule(before) - block_c = sch.get_sblock("C") - sch.cache_read(block_c, 1, "global") - sch.cache_read(block_c, 0, "global") - - after = sch.mod["main"] - - assert_structural_equal_ignore_global_symbol(expected, after) - verify_trace_roundtrip(sch=sch, mod=before) - - def test_inplace_cache_read(): sch = tvm.s_tir.Schedule(inplace_func, debug_mask="all") block = sch.get_sblock("copy_in") @@ -1624,81 +1584,6 @@ def test_cache_write_fail_invalid_storage_scope(use_block_name): sch.cache_write(block_b, 0, "test_scope") [email protected]("use_decl_buffer", [True, False]) -def test_cache_write_allocate_const(use_decl_buffer): - def apply_decl_buffer(*args, **kwargs): - if use_decl_buffer: - return T.decl_buffer(*args, **kwargs) - else: - return T.Buffer(*args, **kwargs) - - @T.prim_func - def before(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float16")): - B = T.alloc_buffer([128, 128], dtype="float32") - const1 = T.allocate_const([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8]) - const1_buf = apply_decl_buffer([8], dtype="float32", data=const1) - const2 = T.allocate_const([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8]) - const2_buf = apply_decl_buffer([8], dtype="float32", data=const2) - for i, j in T.grid(128, 128): - for x in range(8): - with T.sblock("B"): - vi, vj, vx = T.axis.remap("SSS", [i, j, x]) - T.reads(A[vi, vj], const1_buf[vx], const2_buf[vx]) - T.writes(B[vi, vj]) - B[vi, vj] = A[vi, vj] * const1_buf[vx] + const2_buf[vx] - for i, j in T.grid(128, 128): - with T.sblock("C"): - vi, vj = T.axis.remap("SS", [i, j]) - T.reads(B[vi, vj]) - T.writes(C[vi, vj]) - C[vi, vj] = B[vi, vj] + 1.0 - - @T.prim_func - def expected(A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float16")): - B = T.alloc_buffer([128, 128], dtype="float32") - A_global = T.alloc_buffer([128, 128], dtype="float32") - C_global = T.alloc_buffer([128, 128], dtype="float16") - const1 = T.allocate_const([0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8]) - const1_buf = apply_decl_buffer([8], dtype="float32", data=const1) - const2 = T.allocate_const([0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7], "float32", [8]) - const2_buf = apply_decl_buffer([8], dtype="float32", data=const2) - for ax0, ax1 in T.grid(128, 128): - with T.sblock("A_global"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(A[v0, v1]) - T.writes(A_global[v0, v1]) - A_global[v0, v1] = A[v0, v1] - for i, j, x in T.grid(128, 128, 8): - with T.sblock("B"): - vi, vj, vx = T.axis.remap("SSS", [i, j, x]) - T.reads(A_global[vi, vj], const1_buf[vx], const2_buf[vx]) - T.writes(B[vi, vj]) - B[vi, vj] = A_global[vi, vj] * const1_buf[vx] + const2_buf[vx] - for i, j in T.grid(128, 128): - with T.sblock("C"): - vi, vj = T.axis.remap("SS", [i, j]) - T.reads(B[vi, vj]) - T.writes(C_global[vi, vj]) - C_global[vi, vj] = B[vi, vj] + T.float32(1) - for ax0, ax1 in T.grid(128, 128): - with T.sblock("C_global"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(C_global[v0, v1]) - T.writes(C[v0, v1]) - C[v0, v1] = C_global[v0, v1] - - sch = tvm.s_tir.Schedule(before) - block_b = sch.get_sblock("B") - block_c = sch.get_sblock("C") - sch.cache_read(block_b, 0, "global") - sch.cache_write(block_c, 0, "global") - - after = sch.mod["main"] - - assert_structural_equal_ignore_global_symbol(expected, after) - verify_trace_roundtrip(sch=sch, mod=before) - - def test_reindex_cache_read(): sch = tvm.s_tir.Schedule(elementwise, 
debug_mask="all") sch.reindex_cache_read("C", 0, "shared", lambda i, j: (j, i // 2, i % 2)) diff --git a/tests/python/s_tir/schedule/test_tir_schedule_compute_at.py b/tests/python/s_tir/schedule/test_tir_schedule_compute_at.py index 3b8a13fedf..b67fc41b54 100644 --- a/tests/python/s_tir/schedule/test_tir_schedule_compute_at.py +++ b/tests/python/s_tir/schedule/test_tir_schedule_compute_at.py @@ -1745,105 +1745,6 @@ def test_reverse_compute_at_layout_trans(): verify_trace_roundtrip(sch=sch, mod=before) [email protected]("use_decl_buffer", [True, False]) [email protected]("use_reverse_compute_at", [True, False]) -def test_compute_at_allocate_const(use_decl_buffer, use_reverse_compute_at): - def apply_decl_buffer(*args, **kwargs): - if use_decl_buffer: - return T.decl_buffer(*args, **kwargs) - else: - return T.Buffer(*args, **kwargs) - - @T.prim_func - def before(A: T.Buffer([4, 256], "float32"), C: T.Buffer([4, 256], "float32")): - B = T.alloc_buffer([4]) - - offset_ptr = T.allocate_const([1.0, 2.0, 3.0, 4.0], dtype="float32", extents=[4]) - offset = apply_decl_buffer([4], data=offset_ptr) - for i in range(4): - with T.sblock("compute_B"): - vi = T.axis.remap("S", [i]) - B[vi] = 10.0 * vi + offset[vi] - - for i, j in T.grid(4, 256): - with T.sblock("compute_C"): - vi, vj = T.axis.remap("SS", [i, j]) - C[vi, vj] = B[vi] + 100.0 * vj - - @T.prim_func - def expected(A: T.Buffer([4, 256], "float32"), C: T.Buffer([4, 256], "float32")): - B = T.alloc_buffer([4]) - - offset_ptr = T.allocate_const([1.0, 2.0, 3.0, 4.0], dtype="float32", extents=[4]) - offset = apply_decl_buffer([4], data=offset_ptr) - for i in range(4): - with T.sblock("compute_B"): - vi = T.axis.remap("S", [i]) - B[vi] = 10.0 * vi + offset[vi] - - for j in range(256): - with T.sblock("compute_C"): - vi, vj = T.axis.remap("SS", [i, j]) - C[vi, vj] = B[vi] + 100.0 * vj - - sch = tvm.s_tir.Schedule(before, debug_mask="all") - if use_reverse_compute_at: - block = sch.get_sblock("compute_C") - axis = sch.get_loops("compute_B")[0] - sch.reverse_compute_at(block, axis) - else: - block = sch.get_sblock("compute_B") - axis = sch.get_loops("compute_C")[0] - sch.compute_at(block, axis) - - after = sch.mod["main"] - - assert_structural_equal_ignore_global_symbol(expected, after) - verify_trace_roundtrip(sch=sch, mod=before) - - [email protected]("use_decl_buffer", [True, False]) -def test_compute_inline_allocate_const(use_decl_buffer): - def apply_decl_buffer(*args, **kwargs): - if use_decl_buffer: - return T.decl_buffer(*args, **kwargs) - else: - return T.Buffer(*args, **kwargs) - - @T.prim_func - def before(A: T.Buffer([4, 256], "float32"), C: T.Buffer([4, 256], "float32")): - B = T.alloc_buffer([4]) - - offset_ptr = T.allocate_const([1.0, 2.0, 3.0, 4.0], dtype="float32", extents=[4]) - offset = apply_decl_buffer([4], data=offset_ptr) - for i in range(4): - with T.sblock("compute_B"): - vi = T.axis.remap("S", [i]) - B[vi] = 10.0 * vi + offset[vi] - - for i, j in T.grid(4, 256): - with T.sblock("compute_C"): - vi, vj = T.axis.remap("SS", [i, j]) - C[vi, vj] = B[vi] + 100.0 * vj - - @T.prim_func - def expected(A: T.Buffer([4, 256], "float32"), C: T.Buffer([4, 256], "float32")): - offset_ptr = T.allocate_const([1.0, 2.0, 3.0, 4.0], dtype="float32", extents=[4]) - offset = apply_decl_buffer([4], data=offset_ptr) - for i, j in T.grid(4, 256): - with T.sblock("compute_C"): - vi, vj = T.axis.remap("SS", [i, j]) - C[vi, vj] = (10.0 * vi + offset[vi]) + 100.0 * vj - - sch = tvm.s_tir.Schedule(before, debug_mask="all") - block = 
sch.get_sblock("compute_B") - sch.compute_inline(block) - after = sch.mod["main"] - - assert_structural_equal_ignore_global_symbol(expected, after) - verify_trace_roundtrip(sch=sch, mod=before) - - def test_shape_var_as_bound(): # fmt: off @T.prim_func diff --git a/tests/python/tir-transform/test_tir_transform_extract_constants.py b/tests/python/tir-transform/test_tir_transform_extract_constants.py deleted file mode 100644 index cbfb6d39bc..0000000000 --- a/tests/python/tir-transform/test_tir_transform_extract_constants.py +++ /dev/null @@ -1,68 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import numpy as np -import tvm -from tvm import tir -from tvm.script import tir as T -import tvm.testing - - [email protected]_module -class Module4: - @T.prim_func - def constant1(a: T.handle) -> None: - A = T.match_buffer(a, (10), "int32") - B = T.alloc_buffer((10), "int32") - K_data = T.allocate_const([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "int32", [10]) - K = T.Buffer(shape=(10), dtype="int32", data=K_data) - for x in T.serial(0, 10): - B[x] = A[x] + K[x] - - @T.prim_func - def constant2(a: T.handle) -> None: - A = T.match_buffer(a, (10), "int32") - B = T.alloc_buffer((10), "int32") - K_data = T.allocate_const([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "int32", [10]) - K = T.Buffer(shape=(10), dtype="int32", data=K_data) - for x in T.serial(0, 10): - B[x] = A[x] + K[x] - - @T.prim_func - def constant3(a: T.handle) -> None: - A = T.match_buffer(a, (10), "int32") - B = T.alloc_buffer((10), "int32") - K_data = T.allocate_const([1, 2, 3, 1, 1, 1, 1, 1, 1, 1], "int32", [10]) - K = T.Buffer(shape=(10), dtype="int32", data=K_data) - for x in T.serial(0, 10): - B[x] = A[x] + K[x] - - -def test_const_extraction(): - mod = tvm.tir.transform.ExtractPrimFuncConstants()(Module4) - constants = mod.attrs["constants"] - assert len(constants) == 2 - - def _visit(stmt): - if isinstance(stmt, tvm.tir.AllocateConst): - assert np.array_equal(stmt.data.numpy(), constants[int(stmt.irmod_storage_idx)].numpy()) - - for n, f in mod.functions.items(): - tvm.tir.stmt_functor.post_order_visit(f.body, _visit) - - -if __name__ == "__main__": - tvm.testing.main() diff --git a/tests/python/tvmscript/test_tvmscript_complete.py b/tests/python/tvmscript/test_tvmscript_complete.py index 353bd547e8..2617d26d86 100644 --- a/tests/python/tvmscript/test_tvmscript_complete.py +++ b/tests/python/tvmscript/test_tvmscript_complete.py @@ -336,31 +336,5 @@ def test_complete_alloc_buffer(): ) -def test_access_region_for_decl_buffer(): - @T.prim_func(private=True) - def automatic_access_regions(A: T.Buffer(4, "int32"), C: T.Buffer(4, "int32")): - B_data = T.allocate_const([1, 2, 3, 4], "int32", extents=[4]) - B = T.decl_buffer(4, "int32", data=B_data) - - for i in range(4): - with T.sblock("compute"): - vi = 
T.axis.remap("S", [i]) - C[vi] = A[vi] + B[vi] - - @T.prim_func(private=True) - def explicit_access_regions(A: T.Buffer(4, "int32"), C: T.Buffer(4, "int32")): - B_data = T.allocate_const([1, 2, 3, 4], "int32", extents=[4]) - B = T.decl_buffer(4, "int32", data=B_data) - - for i in range(4): - with T.sblock("compute"): - vi = T.axis.remap("S", [i]) - T.reads(A[vi], B[vi]) - T.writes(C[vi]) - C[vi] = A[vi] + B[vi] - - tvm.ir.assert_structural_equal(explicit_access_regions, automatic_access_regions) - - if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py b/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py index 834078c3d9..9f98e10e7b 100644 --- a/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py +++ b/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py @@ -373,30 +373,6 @@ def test_ir_builder_tir_allocate(): assert_structural_equal(ir_actual, ir_expected, map_free_vars=True) -def test_ir_builder_tir_allocate_const(): - data = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1] - with IRBuilder() as ib: - with T.allocate_const(data, "int32", [10]): - T.evaluate(1) - - # the allocate const generated by IRBuilder - ir_actual = ib.get() - - # the expected allocate const - buffer_var = tir.Var("v", tvm.ir.PointerType(tvm.ir.PrimType("int32"))) - ir_expected = tir.AllocateConst( - buffer_var, - "int32", - [10], - tvm.runtime.tensor(np.asarray(data, "int32")), - tir.Evaluate(1), - annotations={}, - ) - - # Check if the generated ir is expected - assert_structural_equal(ir_actual, ir_expected, map_free_vars=True) - - def test_ir_builder_tir_while(): with IRBuilder() as ib: with T.While(T.int32() > 0): diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py b/tests/python/tvmscript/test_tvmscript_roundtrip.py index 3819e19edd..1ea5bcb24c 100644 --- a/tests/python/tvmscript/test_tvmscript_roundtrip.py +++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py @@ -2769,73 +2769,6 @@ def test_opaque_block(): assert len(root_block.body.body[1].block.iter_vars) == 0 - -def module_const(): - @tvm.script.ir_module - class Module4: - # There is an ongoing (python)dict->(c++)Map->(python)dict issue which potentially - # changes order of the items in dict after roundtrip because Map does not support order - # of insertion while dict does. 
Hence func 'def A(a: T.handle, c: T.handle) -> None' - # is commented - # - # test: - # d = {"B": 1, "A": 2} - # m = tvm.runtime.convert(d) - # assert d.keys() == m.keys(), f"Order changed from {list(d.keys())} to {list(m.keys())}" - - """ - @T.prim_func - def A(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (10), "int32") - C = T.match_buffer(c, (10), "int32") - B = T.alloc_buffer((10), "int32") - - K1 = T.allocate_const([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "int32", [10]) - for x in T.serial(0, 10): - B[x] = A[x] + T.load("int32", K1, x) - - for x in T.serial(0, 10): - C[x] = B[x] - """ - - @T.prim_func - def B(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (10), "int32") - C = T.match_buffer(c, (10), "int32") - B = T.alloc_buffer((10), "int32") - - K1_data = T.allocate_const([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "int32", [10]) - K1 = T.Buffer(shape=[10], dtype="int32", data=K1_data) - for x in T.serial(0, 10): - B[x] = A[x] + K1[x] - - K2_data = T.allocate_const([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "int32", [10]) - K2 = T.Buffer(shape=[10], dtype="int32", data=K2_data) - for x in T.serial(0, 10): - B[x] = B[x] + K2[x] - - for x in T.serial(0, 10): - C[x] = B[x] - - return Module4 - - -def constant(): - @T.prim_func - def constant(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (10), "int32") - C = T.match_buffer(c, (10), "int32") - B = T.alloc_buffer((10), "int32") - K_data = T.allocate_const([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "int32", [10]) - K = T.Buffer(shape=[10], dtype="int32", data=K_data) - for x in T.serial(0, 10): - B[x] = A[x] + K[x] - - for x in T.serial(0, 10): - C[x] = B[x] - - return constant - - def rank0(): @T.prim_func def rank0(a: T.handle) -> None: @@ -4180,8 +4113,6 @@ ir_generator = tvm.testing.parameter( opt_conv_tensorcore_mod_host, vthread_func, matmul, - module_const, - constant, rank0, rank0_block, select,

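For reference, the deleted tests above show the pattern this commit removes: a weight embedded directly in the IR through T.allocate_const, with a T.Buffer or T.decl_buffer view over the resulting pointer. A minimal sketch of the same kernel without AllocateConst, written in the style of the deleted tests. This is an illustration only, not part of the patch: it assumes the weight is simply passed in as an ordinary buffer parameter and bound outside the IR; the exact annotation-based mechanism for placing weights in module or function attributes is not introduced here and may look different.

import numpy as np
import tvm
from tvm.script import tir as T


@T.prim_func
def add_weight(a: T.handle, k: T.handle) -> None:
    # The weight K is now an ordinary input buffer rather than an
    # AllocateConst node embedded in the function body.
    A = T.match_buffer(a, (10), "int32")
    K = T.match_buffer(k, (10), "int32")
    B = T.alloc_buffer((10), "int32")
    for x in T.serial(0, 10):
        B[x] = A[x] + K[x]


# The constant data stays outside the module as a plain runtime tensor and is
# supplied separately (hypothetical binding step; how the tensor is attached
# to a module or passed at call time may differ):
weight = tvm.runtime.tensor(np.ones((10,), dtype="int32"))

With the weight decoupled from the IR in this way, the function body no longer carries AllocateConst nodes, which is why passes such as tir.BindParams and tir.ExtractPrimFuncConstants, both deleted above, are no longer needed to manage embedded constant data.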