This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 80250411e7 [Relax][MetaSchedule] Support CPU weight prepack (#17445)
80250411e7 is described below
commit 80250411e706509fef499e0defe0e625bf6fab28
Author: Siyuan Feng <[email protected]>
AuthorDate: Thu Oct 17 04:36:41 2024 +0800
[Relax][MetaSchedule] Support CPU weight prepack (#17445)
This PR adds support for CPU weight prepacking. Specifically, it adds a new
pass, `AttachAttrLayoutFreeBuffers`, which attaches the layout-free-buffer
attribute to the weight parameters so that MetaSchedule can optimize the
prepacking process.
After that pass and tuning, a second new pass, `SplitLayoutRewritePreproc`,
splits the layout-rewrite preprocessing out of each tuned function into a
separate function, so that the existing parameter-transform lifting pass
(`LiftTransformParams`) can lift it.
---
include/tvm/relax/transform.h | 21 ++
python/tvm/relax/frontend/nn/__init__.py | 2 +
python/tvm/relax/pipeline.py | 50 +++-
python/tvm/relax/transform/__init__.py | 2 +
python/tvm/relax/transform/transform.py | 29 ++
src/meta_schedule/postproc/rewrite_layout.cc | 8 +-
.../transform/attach_attr_layout_free_buffers.cc | 113 +++++++
.../transform/split_layout_rewrite_preproc.cc | 327 +++++++++++++++++++++
.../test_meta_schedule_postproc_rewrite_layout.py | 3 +-
...st_transform_attach_attr_layout_free_buffers.py | 311 ++++++++++++++++++++
.../test_transform_split_layout_rewrite_preproc.py | 220 ++++++++++++++
11 files changed, 1083 insertions(+), 3 deletions(-)
diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 5a7b85ac13..eaad44a93a 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -253,6 +253,27 @@ TVM_DLL Pass LegalizeOps(Optional<Map<String, PackedFunc>>
cmap, bool enable_war
*/
TVM_DLL Pass RealizeVDevice();
+/*!
+ * \brief Attach layout free buffers to the tir::PrimFunc.
+ *
+ * This pass is used to attach layout free buffers to the tir::PrimFunc
according to
+ * the function usage in the relax function. Currently, the layout free
buffers are the model
+ * weights and relax constants.
+ *
+ * \note We recommend applying CanonicalizeBindings before this pass.
+ * \return The Pass.
+ */
+TVM_DLL Pass AttachAttrLayoutFreeBuffers();
+
+/*!
+ * \brief Split the layout rewrite preproc block to a separate tir::PrimFunc.
+ *
+ * This pass is used for weight prepacking after meta_schedule tuning.
+ *
+ * \return The Pass.
+ */
+TVM_DLL Pass SplitLayoutRewritePreproc();
+
/*!
* \brief Lift transformation of the parameters of a function.
*
diff --git a/python/tvm/relax/frontend/nn/__init__.py
b/python/tvm/relax/frontend/nn/__init__.py
index a8200d8dd6..f490af7062 100644
--- a/python/tvm/relax/frontend/nn/__init__.py
+++ b/python/tvm/relax/frontend/nn/__init__.py
@@ -23,6 +23,8 @@ from .extern import ExternModule, ObjectModule, SourceModule
from .modules import (
GELU,
Conv1D,
+ Conv2D,
+ Conv3D,
ConvTranspose1D,
Embedding,
GroupNorm,
diff --git a/python/tvm/relax/pipeline.py b/python/tvm/relax/pipeline.py
index 582f5111aa..fe3dbc99fc 100644
--- a/python/tvm/relax/pipeline.py
+++ b/python/tvm/relax/pipeline.py
@@ -109,6 +109,7 @@ def static_shape_tuning_pipeline(
total_trials: int,
target: Union[str, tvm.target.Target],
work_dir: str = "tuning_logs",
+ cpu_weight_prepack: bool = False,
):
"""Tune the static shape model and store the log to database.
@@ -122,18 +123,65 @@ def static_shape_tuning_pipeline(
work_dir : str
The directory to store the tuning logs.
+
+ cpu_weight_prepack : bool
+ Whether to enable the cpu weight prepack feature.
+
+ Note
+ ----
+ `cpu_weight_prepack` is expected to be `True` when running on CPU for
+ better performance. However, it requires an explicit layout transformation
+ step by calling the corresponding vm function, which changes the interface
+ of deployment. So we disable it by default. Here is an example to enable
it:
+
+ .. code-block:: python
+
+ mod = relax.pipeline.static_shape_tuning_pipeline(
+ total_trials=1000,
+ target="llvm -num-cores 16",
+ work_dir="tuning_logs",
+ cpu_weight_prepack=True,
+ )(mod)
+
+ ex = relax.build(mod, target=target)
+ vm = relax.VirtualMachine(ex, device=tvm.cpu())
+
+ # Transform the params using the vm function
+ # the name should be f"{func_name}_transform_params"
+ params = vm["main_transform_params"](params["main"])
+
+ input_data = tvm.nd.array(np.random.randn(1, 3, 224,
224).astype("float32"))
+ out = vm["main"](input_data, *params).numpy()
"""
@tvm.transform.module_pass(opt_level=0)
def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) ->
tvm.ir.IRModule:
+ if cpu_weight_prepack:
+ pre_tuning_layout_rewrite =
[transform.AttachAttrLayoutFreeBuffers()]
+ post_tuning_layout_rewrite = [
+ transform.SplitLayoutRewritePreproc(),
+ transform.LiftTransformParams(),
+ transform.FoldConstant(),
+ ]
+ else:
+ pre_tuning_layout_rewrite = []
+ post_tuning_layout_rewrite = []
+
with tvm.target.Target(target):
mod = tvm.transform.Sequential(
[
transform.DecomposeOpsForInference(),
transform.CanonicalizeBindings(),
zero_pipeline(),
- transform.MetaScheduleTuneIRMod({}, work_dir,
total_trials),
+ *pre_tuning_layout_rewrite,
+ # Skip tuning if total_trials is 0
+ (
+ transform.MetaScheduleTuneIRMod({}, work_dir,
total_trials)
+ if total_trials > 0
+ else tvm.transform.Sequential([])
+ ),
transform.MetaScheduleApplyDatabase(work_dir),
+ *post_tuning_layout_rewrite,
]
)(mod)
diff --git a/python/tvm/relax/transform/__init__.py
b/python/tvm/relax/transform/__init__.py
index 1ce864651c..16e4800ca3 100644
--- a/python/tvm/relax/transform/__init__.py
+++ b/python/tvm/relax/transform/__init__.py
@@ -21,6 +21,7 @@ from .transform import (
AllocateWorkspace,
AlterOpImpl,
AnnotateTIROpPattern,
+ AttachAttrLayoutFreeBuffers,
AttachGlobalSymbol,
BindParams,
BindSymbolicVars,
@@ -73,6 +74,7 @@ from .transform import (
RewriteDataflowReshape,
RunCodegen,
SplitCallTIRByPattern,
+ SplitLayoutRewritePreproc,
StaticPlanBlockMemory,
ToMixedPrecision,
ToNonDataflow,
diff --git a/python/tvm/relax/transform/transform.py
b/python/tvm/relax/transform/transform.py
index 3330d40987..603211b59e 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -970,6 +970,35 @@ def MergeCompositeFunctions() -> tvm.ir.transform.Pass:
return _ffi_api.MergeCompositeFunctions() # type: ignore
+def AttachAttrLayoutFreeBuffers() -> tvm.ir.transform.Pass:
+ """Attach layout free buffers to the tir::PrimFunc.
+
+ This pass is used to attach layout free buffers to the tir::PrimFunc
according to
+ the function usage in the relax function. Currently, the layout free
buffers are the model
+ weights and relax constants.
+
+ Note that we recommend applying CanonicalizeBindings before this pass.
+
+ Returns
+ -------
+ ret : tvm.transform.Pass
+ The registered pass for attaching layout free buffers.
+ """
+ return _ffi_api.AttachAttrLayoutFreeBuffers() # type: ignore
+
+
+def SplitLayoutRewritePreproc() -> tvm.ir.transform.Pass:
+ """Split the TIR layout rewrite into multiple TIR functions.
+ This pass is used for weight prepacking after meta_schedule tuning.
+
+ Returns
+ -------
+ ret : tvm.transform.Pass
+ The registered pass for splitting TIR layout rewrite.
+ """
+ return _ffi_api.SplitLayoutRewritePreproc() # type: ignore
+
+
def LiftTransformParams(shared_transform: Union[bool, List[str]] = False) ->
tvm.ir.transform.Pass:
"""Lift transformation of the parameters of a function.
diff --git a/src/meta_schedule/postproc/rewrite_layout.cc
b/src/meta_schedule/postproc/rewrite_layout.cc
index 71ae433871..87fa96f67c 100644
--- a/src/meta_schedule/postproc/rewrite_layout.cc
+++ b/src/meta_schedule/postproc/rewrite_layout.cc
@@ -249,7 +249,13 @@ class RewriteLayoutNode : public PostprocNode {
void InitializeWithTuneContext(const TuneContext& context) final {}
// Inherited from PostprocNode
- bool Apply(const tir::Schedule& sch) final { return tir::RewriteLayout(sch);
}
+ bool Apply(const tir::Schedule& sch) final {
+ try {
+ return tir::RewriteLayout(sch);
+ } catch (const std::runtime_error& e) {
+ return false;
+ }
+ }
Postproc Clone() const {
ObjectPtr<RewriteLayoutNode> n = make_object<RewriteLayoutNode>(*this);
diff --git a/src/relax/transform/attach_attr_layout_free_buffers.cc
b/src/relax/transform/attach_attr_layout_free_buffers.cc
new file mode 100644
index 0000000000..64062e2243
--- /dev/null
+++ b/src/relax/transform/attach_attr_layout_free_buffers.cc
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file src/relax/transform/attach_attr_layout_free_buffers.cc
+ * \brief Attach layout_free_buffers for layout-free buffers.
+ */
+
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+#include <tvm/tir/stmt_functor.h>
+
+namespace tvm {
+namespace relax {
+
+class AttrAttacher : public ExprMutator {
+ public:
+ static IRModule Transform(const IRModule& mod) {
+ AttrAttacher mutator(mod);
+ for (auto [gvar, func] : mod->functions) {
+ if (func->IsInstance<relax::FunctionNode>()) {
+ // clear the layout_free_exprs_ for each function
+ mutator.layout_free_exprs_.clear();
+ mutator.builder_->UpdateFunction(gvar,
Downcast<BaseFunc>(mutator.VisitExpr(func)));
+ }
+ }
+ return mutator.builder_->GetContextIRModule();
+ }
+
+ private:
+ explicit AttrAttacher(IRModule mod) : ExprMutator(mod), mod_(mod) {}
+
+ using ExprMutator::VisitExpr_;
+ Expr VisitExpr_(const FunctionNode* op) final {
+ if (auto opt_num_input = op->attrs.GetAttr<Integer>(attr::kNumInput)) {
+ ICHECK(layout_free_exprs_.empty()) << "meet a non-global function with
num_input attr";
+ size_t num_input = opt_num_input.value()->value;
+ for (size_t i = num_input; i < op->params.size(); i++) {
+ layout_free_exprs_.insert(op->params[i].get());
+ }
+ }
+ return ExprMutator::VisitExpr_(op);
+ }
+
+ Expr VisitExpr_(const ConstantNode* op) final {
+ layout_free_exprs_.insert(op);
+ return ExprMutator::VisitExpr_(op);
+ }
+
+ Expr VisitExpr_(const CallNode* op) final {
+ static const Op& call_tir_op_ = Op::Get("relax.call_tir");
+ Call call = Downcast<Call>(ExprMutator::VisitExpr_(op));
+ if (call->op != call_tir_op_) {
+ return call;
+ }
+ GlobalVar gv = Downcast<GlobalVar>(call->args[0]);
+ Array<Expr> call_tir_args = Downcast<Tuple>(call->args[1])->fields;
+ // Compute the layout free buffers
+ Array<Integer> layout_free_buffers;
+ for (size_t i = 0; i < call_tir_args.size(); i++) {
+ if (layout_free_exprs_.count(call_tir_args[i].get())) {
+ layout_free_buffers.push_back(Integer(i));
+ }
+ }
+ // Attach the layout free buffers to the tir::PrimFunc
+ tir::PrimFunc func = WithAttr(Downcast<tir::PrimFunc>(mod_->Lookup(gv)),
"layout_free_buffers",
+ layout_free_buffers);
+ // Renew defs
+ func = tir::RenewDefs(func);
+ // Add the updated tir::PrimFunc in the IRModule
+ // Note the blockbuilder would automatically combine the same tir function
+ // So we don't need to worry about the duplicate insertion
+ GlobalVar new_gv = builder_->AddFunction(func, gv->name_hint);
+ // Create a new call node with the updated tir::PrimFunc
+ auto n = make_object<CallNode>(*op);
+ n->args = {new_gv, Tuple(call_tir_args)};
+ return Call(n);
+ }
+
+ private:
+ IRModule mod_;
+ std::unordered_set<const ExprNode*> layout_free_exprs_;
+};
+namespace transform {
+
+Pass AttachAttrLayoutFreeBuffers() {
+ runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
+ [=](IRModule mod, PassContext pc) { return AttrAttacher::Transform(mod);
};
+ auto pass = CreateModulePass(pass_func, 0, "_AttachAttrLayoutFreeBuffers",
{});
+ // Apply DeadCodeElimination to remove unused tir::PrimFunc
+ return tvm::transform::Sequential({pass, DeadCodeElimination()},
"AttachAttrLayoutFreeBuffers");
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.AttachAttrLayoutFreeBuffers")
+ .set_body_typed(AttachAttrLayoutFreeBuffers);
+} // namespace transform
+} // namespace relax
+} // namespace tvm
diff --git a/src/relax/transform/split_layout_rewrite_preproc.cc
b/src/relax/transform/split_layout_rewrite_preproc.cc
new file mode 100644
index 0000000000..5fee946c26
--- /dev/null
+++ b/src/relax/transform/split_layout_rewrite_preproc.cc
@@ -0,0 +1,327 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relax/transform/split_layout_rewrite_preproc.cc
+ * \brief Used for rewriting the TIR functions after the meta_schedule layout rewrite
postprocessor.
+ */
+#include <tvm/ir/transform.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <algorithm>
+#include <cstddef>
+
+namespace tvm {
+namespace tir {
+class SplitPrimFuncLayoutRewrite : public StmtMutator {
+ public:
+ explicit SplitPrimFuncLayoutRewrite(const PrimFunc& func) :
original_func_(func) {}
+ std::tuple<Optional<PrimFunc>, PrimFunc> Transform(const PrimFunc& func) {
+ ICHECK(func->body.as<BlockRealizeNode>()) << "The body of the primfunc
should be a root block.";
+ const auto& block = func->body.as<BlockRealizeNode>()->block;
+ visit_root_block(block.get());
+ if (layout_rewrite_preproc_stmts_.size() > 0) {
+ return std::make_tuple(create_layout_rewrite_preproc_func(),
create_compute_func());
+ } else {
+ return std::make_tuple(NullOpt, func);
+ }
+ }
+
+ private:
+ void sort_rewrite_infos() {
+ std::sort(
+ rewrite_infos_.begin(), rewrite_infos_.end(),
+ [](const RewriteInfo& a, const RewriteInfo& b) { return a.buffer_index
< b.buffer_index; });
+ }
+
+ PrimFunc create_layout_rewrite_preproc_func() const {
+ // Step 1: Check that there is at least one buffer rewrite
+ ICHECK(rewrite_infos_.size() > 0) << "There should be at least one buffer
rewrite.";
+
+ // Step 2: Create the params for the new PrimFunc
+ Array<Var> params;
+ Map<Var, Buffer> buffer_map;
+
+ for (const auto& info : rewrite_infos_) {
+ params.push_back(Var(info.pre_rewrite_buffer->name, DataType::Handle()));
+ buffer_map.Set(params.back(), info.pre_rewrite_buffer);
+ }
+ for (const auto& info : rewrite_infos_) {
+ params.push_back(Var(info.post_rewrite_buffer->name,
DataType::Handle()));
+ buffer_map.Set(params.back(), info.post_rewrite_buffer);
+ }
+
+ // Step 3: Create the body for the new PrimFunc
+ ICHECK(layout_rewrite_preproc_stmts_.size() > 0)
+ << "There should be at least one layout rewrite preproc stmt.";
+ Stmt body = layout_rewrite_preproc_stmts_.size() == 1 ?
layout_rewrite_preproc_stmts_[0]
+ :
SeqStmt(layout_rewrite_preproc_stmts_);
+ body = BlockRealize(
+ /*iter_values=*/Array<PrimExpr>(),
+ /*predicate=*/const_true(),
+ /*block=*/
+ Block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{},
+ /*name_hint=*/"root", body));
+
+ PrimFunc func = PrimFunc(params, body, VoidType(), buffer_map);
+
+ return RenewDefs(func);
+ }
+
+ PrimFunc create_compute_func() const {
+ // Step 1: Create the params for the new PrimFunc
+ Array<Var> params = original_func_->params;
+ Map<Var, Buffer> buffer_map = original_func_->buffer_map;
+ for (const auto& info : rewrite_infos_) {
+ const Var& param = params[info.buffer_index];
+ ICHECK(buffer_map[param] == info.pre_rewrite_buffer);
+ buffer_map.Set(param, info.post_rewrite_buffer);
+ }
+
+ // Step 2: Create the body for the new PrimFunc
+ Stmt body = compute_stmts_.size() == 1 ? compute_stmts_[0] :
SeqStmt(compute_stmts_);
+ Block original_block = original_func_->body.as<BlockRealizeNode>()->block;
+ Array<Buffer> alloc_buffers;
+ for (const auto& buffer : original_block->alloc_buffers) {
+ auto it =
+ std::find_if(rewrite_infos_.begin(), rewrite_infos_.end(),
+ [&](const RewriteInfo& info) { return
info.post_rewrite_buffer == buffer; });
+ if (it == rewrite_infos_.end()) {
+ alloc_buffers.push_back(buffer);
+ }
+ }
+
+ body = BlockRealize(
+ /*iter_values=*/Array<PrimExpr>(),
+ /*predicate=*/const_true(),
+ /*block=*/
+ Block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{},
+ /*name_hint=*/"root", body,
+ /*init=*/NullOpt,
+ /*alloc_buffers=*/alloc_buffers));
+
+ PrimFunc func = PrimFunc(original_func_->params, body, VoidType(),
buffer_map);
+ return RenewDefs(func);
+ }
+
+ void visit_root_block(const BlockNode* op) {
+ Stmt body = op->body;
+ if (const auto* seq_stmt = body.as<SeqStmtNode>()) {
+ for (const auto& stmt : seq_stmt->seq) {
+ current_subtree_ = 0;
+ Stmt new_stmt = this->VisitStmt(stmt);
+ ICHECK(current_subtree_ != 0) << "There should be at least a block in
the subtree.";
+ if (current_subtree_ == 1) {
+ layout_rewrite_preproc_stmts_.push_back(new_stmt);
+ } else {
+ compute_stmts_.push_back(new_stmt);
+ }
+ }
+ } else {
+ current_subtree_ = 0;
+ this->VisitStmt(body);
+ ICHECK(current_subtree_ == -1)
+ << "There should be a compute block if there is only one subtree
under the root.";
+ }
+ }
+ Stmt VisitStmt_(const BlockNode* op) final {
+ Block block = Downcast<Block>(StmtMutator::VisitStmt_(op));
+ auto it = op->annotations.find(attr::meta_schedule_layout_rewrite_preproc);
+ bool is_layout_rewrite_preproc =
+ it != op->annotations.end() &&
is_one(Downcast<PrimExpr>((*it).second));
+
+ if (current_subtree_ == 0) {
+ current_subtree_ = is_layout_rewrite_preproc ? 1 : -1;
+ } else if (current_subtree_ == 1) {
+ CHECK(is_layout_rewrite_preproc)
+ << "There is a layout rewrite block in the subtree, but meet a
non-layout rewrite block.";
+ } else {
+ CHECK(!is_layout_rewrite_preproc)
+ << "There is a non-layout rewrite block in the subtree, but meet a
layout rewrite block.";
+ }
+
+ if (is_layout_rewrite_preproc) {
+ ICHECK(op->reads.size() == 1) << "There should be only one read buffer
in the layout rewrite";
+ ICHECK(op->writes.size() == 1)
+ << "There should be only one write buffer in the layout rewrite";
+ ICHECK(op->alloc_buffers.empty()) << "There should be no alloc buffer in
the layout rewrite";
+ ICHECK(op->match_buffers.empty()) << "There should be no match buffer in
the layout rewrite";
+ const Buffer& preproc_buffer = op->reads[0]->buffer;
+ int buffer_index = -1;
+ for (size_t i = 0; i < original_func_->params.size(); ++i) {
+ const Buffer& buffer =
original_func_->buffer_map[original_func_->params[i]];
+ if (buffer == preproc_buffer) {
+ buffer_index = i;
+ break;
+ }
+ }
+ ICHECK(buffer_index != -1) << "The preproc buffer is not found in the
original primfunc.";
+ rewrite_infos_.push_back(
+ RewriteInfo{buffer_index, op->reads[0]->buffer,
op->writes[0]->buffer});
+
+ auto new_annotations = op->annotations;
+ new_annotations.erase(attr::meta_schedule_layout_rewrite_preproc);
+ auto n = make_object<BlockNode>(*block.get());
+ n->annotations = new_annotations;
+ return Block(n);
+ }
+ return block;
+ }
+
+ public:
+ struct RewriteInfo {
+ int buffer_index;
+ Buffer pre_rewrite_buffer;
+ Buffer post_rewrite_buffer;
+ };
+ std::vector<RewriteInfo> rewrite_infos_;
+
+ private:
+ /*! \brief The stmts that are used for layout rewrite preproc*/
+ Array<Stmt> layout_rewrite_preproc_stmts_;
+ /*! \brief The stmts that are other than layout rewrite preproc*/
+ Array<Stmt> compute_stmts_;
+ /*!
+ \brief Whether the current subtree is a layout rewrite preproc subtree.
+ -1: visited a non-layout rewrite preproc block
+ 0: unsure, not visited any block
+ 1: visited a layout rewrite preproc block
+ */
+ int current_subtree_;
+ /*! \brief The original primfunc*/
+ PrimFunc original_func_;
+};
+} // namespace tir
+
+namespace relax {
+class SplitLayoutRewritePreproc : public ExprMutator {
+ public:
+ static IRModule Transform(const IRModule& mod) {
+ SplitLayoutRewritePreproc mutator(mod);
+
+ // Step 1: Split the primfunc into preproc and compute
+ for (auto [gv, func] : mod->functions) {
+ if (func->IsInstance<tir::PrimFuncNode>()) {
+ tir::SplitPrimFuncLayoutRewrite
tir_rewriter(Downcast<tir::PrimFunc>(func));
+ auto [preproc_func, compute_func] =
tir_rewriter.Transform(Downcast<tir::PrimFunc>(func));
+ if (preproc_func.defined()) {
+ mutator.split_funcs_.emplace(gv.get(),
+ std::make_tuple(preproc_func.value(),
compute_func));
+ mutator.rewrite_infos_.emplace(gv.get(),
tir_rewriter.rewrite_infos_);
+ }
+ }
+ }
+
+ for (auto [gv, func] : mod->functions) {
+ if (func->IsInstance<relax::FunctionNode>()) {
+ auto relax_func = Downcast<relax::Function>(func);
+ mutator.builder_->UpdateFunction(gv,
Downcast<relax::Function>(mutator(relax_func)));
+ }
+ }
+ return mutator.builder_->GetContextIRModule();
+ }
+
+ private:
+ explicit SplitLayoutRewritePreproc(const IRModule& mod) : ExprMutator(mod) {}
+ using ExprMutator::VisitExpr_;
+
+ Expr VisitExpr_(const CallNode* op) final {
+ static const Op& call_tir_op = Op::Get("relax.call_tir");
+ Call call = Downcast<Call>(ExprMutator::VisitExpr_(op));
+
+ // Step 1: Skip calls to anything other than `relax.call_tir`
+ if (!call->op.same_as(call_tir_op)) {
+ return call;
+ }
+
+ // Step 2: Skip if there is no preproc stage
+ const GlobalVar gv = Downcast<GlobalVar>(call->args[0]);
+ auto it = split_funcs_.find(gv.get());
+ if (it == split_funcs_.end()) {
+ return call;
+ }
+
+ // Step 3: Get the preproc and compute functions and update the module
+ const auto& [preproc_func, compute_func] = it->second;
+ GlobalVar preproc_gv = builder_->AddFunction(preproc_func, gv->name_hint +
"_weight_prepack");
+ GlobalVar compute_gv = builder_->AddFunction(compute_func, gv->name_hint +
"_prepacked");
+ // Step 4. Get rewrite infos
+ auto rewrite_infos_it = rewrite_infos_.find(gv.get());
+ ICHECK(rewrite_infos_it != rewrite_infos_.end())
+ << "Rewrite infos are not found for " << gv->name_hint;
+ const auto& rewrite_infos = rewrite_infos_it->second;
+
+ // Step 5: Emit the preproc call
+ Array<Expr> call_tir_args = Downcast<Tuple>(call->args[1])->fields;
+ Array<Expr> preproc_args;
+ Array<StructInfo> preproc_sinfo_list;
+ for (const auto& info : rewrite_infos) {
+ preproc_args.push_back(call_tir_args[info.buffer_index]);
+ tir::Buffer rewritten_buffer = info.post_rewrite_buffer;
+ for (const auto& shape_expr : rewritten_buffer->shape) {
+ CHECK(shape_expr.as<tir::IntImmNode>()) << "Currently does not support
rewrite buffer with "
+ "dynamic shape.";
+ }
+ preproc_sinfo_list.push_back(
+ TensorStructInfo(ShapeExpr(rewritten_buffer->shape),
rewritten_buffer->dtype));
+ }
+ StructInfo preproc_sinfo = preproc_sinfo_list.size() > 1 //
+ ? TupleStructInfo(preproc_sinfo_list) //
+ : preproc_sinfo_list[0];
+
+ // Step 6: Call the preproc function
+ Expr preproc_call =
+ builder_->Emit(Call(call_tir_op, {preproc_gv, Tuple(preproc_args)},
{}, {preproc_sinfo}));
+ if (rewrite_infos.size() == 1) {
+ call_tir_args.Set(rewrite_infos[0].buffer_index, preproc_call);
+ } else {
+ for (size_t i = 0; i < rewrite_infos.size(); ++i) {
+ call_tir_args.Set(rewrite_infos[i].buffer_index,
TupleGetItem(preproc_call, i));
+ }
+ }
+ Expr main_call =
+ builder_->Emit(Call(call_tir_op, {compute_gv, Tuple(call_tir_args)},
{}, call->sinfo_args));
+
+ return main_call;
+ }
+
+ private:
+ std::unordered_map<const GlobalVarNode*, std::tuple<tir::PrimFunc,
tir::PrimFunc>> split_funcs_;
+ std::unordered_map<const GlobalVarNode*,
+ std::vector<tir::SplitPrimFuncLayoutRewrite::RewriteInfo>>
+ rewrite_infos_;
+};
+
+} // namespace relax
+
+namespace transform {
+Pass SplitLayoutRewritePreproc() {
+ auto pass_func = [](IRModule mod, PassContext pc) {
+ return relax::SplitLayoutRewritePreproc::Transform(mod);
+ };
+ auto pass = CreateModulePass(pass_func, 0, "SplitLayoutRewritePreproc", {});
+ return tvm::transform::Sequential({pass,
relax::transform::DeadCodeElimination()},
+ "SplitLayoutRewritePreproc");
+}
+TVM_REGISTER_GLOBAL("relax.transform.SplitLayoutRewritePreproc")
+ .set_body_typed(SplitLayoutRewritePreproc);
+} // namespace transform
+} // namespace tvm
diff --git
a/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_layout.py
b/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_layout.py
index e2305de2af..8348c57c19 100644
--- a/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_layout.py
+++ b/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_layout.py
@@ -61,7 +61,8 @@ class BaseBeforeAfter(tvm.testing.CompareBeforeAfter):
)
sch = tvm.tir.Schedule(mod, debug_mask="all")
sch.enter_postproc()
- assert ctx.space_generator.postprocs[0].apply(sch)
+ if not ctx.space_generator.postprocs[0].apply(sch):
+ raise tvm.TVMError("RewriteLayout postproc failed")
return sch.mod
return inner
diff --git
a/tests/python/relax/test_transform_attach_attr_layout_free_buffers.py
b/tests/python/relax/test_transform_attach_attr_layout_free_buffers.py
new file mode 100644
index 0000000000..46f7c8aa87
--- /dev/null
+++ b/tests/python/relax/test_transform_attach_attr_layout_free_buffers.py
@@ -0,0 +1,311 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import numpy as np
+import tvm.testing
+
+from tvm import relax, tir
+from tvm.script import relax as R, tir as T, ir as I
+from tvm.relax.transform import CombineParallelMatmul
+from tvm.script.ir_builder import IRBuilder
+from tvm.script.ir_builder import relax as relax_builder
+
+
+def test_param():
+ @I.ir_module
+ class Before:
+ @T.prim_func(private=True)
+ def matmul(
+ A: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+ B: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+ C: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+ ):
+ for i, j, k in T.grid(T.int64(32), T.int64(32), T.int64(32)):
+ with T.block("C"):
+ with T.init():
+ C[i, j] = T.float32(0)
+ C[i, j] = C[i, j] + A[i, k] * B[k, j]
+
+ @R.function
+ def main(x: R.Tensor((32, 32), "float32"), y: R.Tensor((32, 32),
"float32")):
+ R.func_attr({"num_input": 1})
+ cls = Before
+ with R.dataflow():
+ gv = R.call_tir(cls.matmul, (x, y), out_sinfo=R.Tensor((32,
32), "float32"))
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func(private=True)
+ def matmul1(
+ A: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+ B: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+ C: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+ ):
+ T.func_attr({"layout_free_buffers": [1]})
+ for i, j, k in T.grid(T.int64(32), T.int64(32), T.int64(32)):
+ with T.block("C"):
+ with T.init():
+ C[i, j] = T.float32(0)
+ C[i, j] = C[i, j] + A[i, k] * B[k, j]
+
+ @R.function
+ def main(x: R.Tensor((32, 32), "float32"), y: R.Tensor((32, 32),
"float32")):
+ R.func_attr({"num_input": 1})
+ cls = Expected
+ with R.dataflow():
+ gv = R.call_tir(cls.matmul1, (x, y), out_sinfo=R.Tensor((32,
32), "float32"))
+ R.output(gv)
+ return gv
+
+ after = relax.transform.AttachAttrLayoutFreeBuffers()(Before)
+ tvm.ir.assert_structural_equal(after, Expected)
+
+
+def test_const():
+ const_value = np.ones((32, 32), dtype="float32")
+
+ @I.ir_module
+ class Before:
+ @T.prim_func(private=True)
+ def matmul(
+ A: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+ B: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+ C: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+ ):
+ for i, j, k in T.grid(T.int64(32), T.int64(32), T.int64(32)):
+ with T.block("C"):
+ with T.init():
+ C[i, j] = T.float32(0)
+ C[i, j] = C[i, j] + A[i, k] * B[k, j]
+
+ @R.function
+ def main(x: R.Tensor((32, 32), "float32")):
+ R.func_attr({"num_input": 1})
+ cls = Before
+ with R.dataflow():
+ gv = R.call_tir(
+ cls.matmul,
+ (x, relax.const(const_value)),
+ out_sinfo=R.Tensor((32, 32), "float32"),
+ )
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func(private=True)
+ def matmul1(
+ A: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+ B: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+ C: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+ ):
+ T.func_attr({"layout_free_buffers": [1]})
+ for i, j, k in T.grid(T.int64(32), T.int64(32), T.int64(32)):
+ with T.block("C"):
+ with T.init():
+ C[i, j] = T.float32(0)
+ C[i, j] = C[i, j] + A[i, k] * B[k, j]
+
+ @R.function
+ def main(x: R.Tensor((32, 32), "float32")):
+ R.func_attr({"num_input": 1})
+ cls = Expected
+ with R.dataflow():
+ gv = R.call_tir(
+ cls.matmul1,
+ (x, relax.const(const_value)),
+ out_sinfo=R.Tensor((32, 32), "float32"),
+ )
+ R.output(gv)
+ return gv
+
+ after = relax.transform.AttachAttrLayoutFreeBuffers()(Before)
+ tvm.ir.assert_structural_equal(after, Expected)
+
+
+def test_multiple_same_func():
+ @I.ir_module
+ class Before:
+ @T.prim_func(private=True)
+ def matmul(
+ A: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+ B: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+ C: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+ ):
+ for i, j, k in T.grid(T.int64(32), T.int64(32), T.int64(32)):
+ with T.block("C"):
+ with T.init():
+ C[i, j] = T.float32(0)
+ C[i, j] = C[i, j] + A[i, k] * B[k, j]
+
+ @R.function
+ def main(
+ x: R.Tensor((32, 32), "float32"),
+ w1: R.Tensor((32, 32), "float32"),
+ w2: R.Tensor((32, 32), "float32"),
+ ):
+ R.func_attr({"num_input": 1})
+ cls = Before
+ with R.dataflow():
+ lv1 = R.call_tir(
+ cls.matmul,
+ (x, w1),
+ out_sinfo=R.Tensor((32, 32), "float32"),
+ )
+ gv = R.call_tir(
+ cls.matmul,
+ (lv1, w2),
+ out_sinfo=R.Tensor((32, 32), "float32"),
+ )
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func(private=True)
+ def matmul1(
+ A: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+ B: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+ C: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+ ):
+ T.func_attr({"layout_free_buffers": [1]})
+ for i, j, k in T.grid(T.int64(32), T.int64(32), T.int64(32)):
+ with T.block("C"):
+ with T.init():
+ C[i, j] = T.float32(0)
+ C[i, j] = C[i, j] + A[i, k] * B[k, j]
+
+ @R.function
+ def main(
+ x: R.Tensor((32, 32), "float32"),
+ w1: R.Tensor((32, 32), "float32"),
+ w2: R.Tensor((32, 32), "float32"),
+ ):
+ R.func_attr({"num_input": 1})
+ cls = Expected
+ with R.dataflow():
+ lv1 = R.call_tir(
+ cls.matmul1,
+ (x, w1),
+ out_sinfo=R.Tensor((32, 32), "float32"),
+ )
+ gv = R.call_tir(
+ cls.matmul1,
+ (lv1, w2),
+ out_sinfo=R.Tensor((32, 32), "float32"),
+ )
+ R.output(gv)
+ return gv
+
+ after = relax.transform.AttachAttrLayoutFreeBuffers()(Before)
+ tvm.ir.assert_structural_equal(after, Expected)
+
+
+def test_multiple_same_func_with_different_free_buffers():
+ @I.ir_module
+ class Before:
+ @T.prim_func(private=True)
+ def matmul(
+ A: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+ B: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+ C: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+ ):
+ for i, j, k in T.grid(T.int64(32), T.int64(32), T.int64(32)):
+ with T.block("C"):
+ with T.init():
+ C[i, j] = T.float32(0)
+ C[i, j] = C[i, j] + A[i, k] * B[k, j]
+
+ @R.function
+ def main(
+ x: R.Tensor((32, 32), "float32"),
+ w1: R.Tensor((32, 32), "float32"),
+ w2: R.Tensor((32, 32), "float32"),
+ ):
+ R.func_attr({"num_input": 1})
+ cls = Before
+ with R.dataflow():
+ lv1 = R.call_tir(
+ cls.matmul,
+ (x, w1),
+ out_sinfo=R.Tensor((32, 32), "float32"),
+ )
+ gv = R.call_tir(
+ cls.matmul,
+ (w2, lv1),
+ out_sinfo=R.Tensor((32, 32), "float32"),
+ )
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func(private=True)
+ def matmul1(
+ A: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+ B: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+ C: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+ ):
+ T.func_attr({"layout_free_buffers": [1]})
+ for i, j, k in T.grid(T.int64(32), T.int64(32), T.int64(32)):
+ with T.block("C"):
+ with T.init():
+ C[i, j] = T.float32(0)
+ C[i, j] = C[i, j] + A[i, k] * B[k, j]
+
+ @T.prim_func(private=True)
+ def matmul2(
+ A: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+ B: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+ C: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+ ):
+ T.func_attr({"layout_free_buffers": [0]})
+ for i, j, k in T.grid(T.int64(32), T.int64(32), T.int64(32)):
+ with T.block("C"):
+ with T.init():
+ C[i, j] = T.float32(0)
+ C[i, j] = C[i, j] + A[i, k] * B[k, j]
+
+ @R.function
+ def main(
+ x: R.Tensor((32, 32), "float32"),
+ w1: R.Tensor((32, 32), "float32"),
+ w2: R.Tensor((32, 32), "float32"),
+ ):
+ R.func_attr({"num_input": 1})
+ cls = Expected
+ with R.dataflow():
+ lv1 = R.call_tir(
+ cls.matmul1,
+ (x, w1),
+ out_sinfo=R.Tensor((32, 32), "float32"),
+ )
+ gv = R.call_tir(
+ cls.matmul2,
+ (w2, lv1),
+ out_sinfo=R.Tensor((32, 32), "float32"),
+ )
+ R.output(gv)
+ return gv
+
+ after = relax.transform.AttachAttrLayoutFreeBuffers()(Before)
+ tvm.ir.assert_structural_equal(after, Expected)
+
+
+if __name__ == "__main__":
+ tvm.testing.main()
diff --git a/tests/python/relax/test_transform_split_layout_rewrite_preproc.py
b/tests/python/relax/test_transform_split_layout_rewrite_preproc.py
new file mode 100644
index 0000000000..e6b4c8ec4e
--- /dev/null
+++ b/tests/python/relax/test_transform_split_layout_rewrite_preproc.py
@@ -0,0 +1,220 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm.testing
+from tvm import relax
+from tvm.script import ir as I
+from tvm.script import relax as R
+from tvm.script import tir as T
+
+
+def test_single_buffer():  # Checks SplitLayoutRewritePreproc on a PrimFunc that repacks a single weight buffer (W).
+ @I.ir_module
+ class Before:  # Input module: tir_func both repacks W into W_rewrite and computes Out.
+ @T.prim_func(private=True)
+ def tir_func(
+ X: T.Buffer((224, 224), "float32"),
+ W: T.Buffer((224, 224), "float32"),
+ Out: T.Buffer((224, 224), "float32"),
+ ):
+ T.func_attr({"layout_free_buffers": [1]})  # W is parameter index 1 of tir_func.
+ W_rewrite = T.alloc_buffer((4, 4, 56, 56))  # Repacked copy of W, allocated inside the function in Before.
+ for i, j in T.grid(224, 224):
+ with T.block("W_rewrite"):  # Block tagged with the meta_schedule.layout_rewrite_preproc attribute.
+ vi, vj = T.axis.remap("SS", [i, j])
+ T.block_attr({"meta_schedule.layout_rewrite_preproc":
T.bool(True)})
+ W_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] = W[vi, vj]
+ for i0, j0, i1, j1 in T.grid(4, 4, 56, 56):
+ with T.block("Out"):  # Compute block: reads X and the repacked W_rewrite.
+ vi = T.axis.spatial(224, i0 * 56 + i1)
+ vj = T.axis.spatial(224, j0 * 56 + j1)
+ Out[vi, vj] = X[vi, vj] + W_rewrite[vi // 56, vj // 56, vi
% 56, vj % 56]
+
+ @R.function
+ def forward(
+ x: R.Tensor((224, 224), dtype="float32"),
+ w: R.Tensor((224, 224), dtype="float32"),
+ ) -> R.Tensor((224, 224), dtype="float32"):
+ R.func_attr({"num_input": 1})
+ cls = Before
+ with R.dataflow():
+ gv = R.call_tir(
+ cls.tir_func, (x, w), out_sinfo=R.Tensor((224, 224),
dtype="float32")
+ )
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class After:  # Expected module: repacking and compute live in two separate PrimFuncs.
+ @T.prim_func(private=True)
+ def tir_func_prepacked(  # Compute-only function: W_rewrite is a parameter here, not a local allocation.
+ X: T.Buffer((224, 224), "float32"),
+ W_rewrite: T.Buffer((4, 4, 56, 56), "float32"),
+ Out: T.Buffer((224, 224), "float32"),
+ ):
+ for i0, j0, i1, j1 in T.grid(4, 4, 56, 56):
+ with T.block("Out"):
+ vi = T.axis.spatial(224, i0 * 56 + i1)
+ vj = T.axis.spatial(224, j0 * 56 + j1)
+ Out[vi, vj] = X[vi, vj] + W_rewrite[vi // 56, vj // 56, vi
% 56, vj % 56]
+
+ @T.prim_func(private=True)
+ def tir_func_weight_prepack(  # Repack-only function: reads W and writes W_rewrite.
+ W: T.Buffer((224, 224), "float32"),
+ W_rewrite: T.Buffer((4, 4, 56, 56), "float32"),
+ ):
+ for i, j in T.grid(224, 224):
+ with T.block("W_rewrite"):  # Same block as in Before, without the layout_rewrite_preproc attribute.
+ vi, vj = T.axis.remap("SS", [i, j])
+ W_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] = W[vi, vj]
+
+ @R.function
+ def forward(
+ x: R.Tensor((224, 224), dtype="float32"),
+ w: R.Tensor((224, 224), dtype="float32"),
+ ) -> R.Tensor((224, 224), dtype="float32"):
+ R.func_attr({"num_input": 1})
+ cls = After
+ with R.dataflow():
+ lv = R.call_tir(  # Step 1: repack w into the (4, 4, 56, 56) layout.
+ cls.tir_func_weight_prepack, (w,), out_sinfo=R.Tensor((4,
4, 56, 56), "float32")
+ )
+ lv1 = R.call_tir(  # Step 2: run the compute function on x and the repacked weight.
+ cls.tir_func_prepacked, (x, lv), out_sinfo=R.Tensor((224,
224), "float32")
+ )
+ gv: R.Tensor((224, 224), dtype="float32") = lv1
+ R.output(gv)
+ return gv
+
+ mod = relax.transform.SplitLayoutRewritePreproc()(Before)  # Apply the pass under test to the input module.
+ tvm.ir.assert_structural_equal(mod, After)  # The transformed module must be structurally equal to After.
+
+
+def test_multiple_buffers():  # Checks SplitLayoutRewritePreproc on a PrimFunc that repacks two weight buffers (W1, W2).
+ @I.ir_module
+ class Before:  # Input module: tir_func repacks W1 and W2 and computes Out, all in one PrimFunc.
+ @T.prim_func(private=True)
+ def tir_func(
+ X: T.Buffer((224, 224), "float32"),
+ W1: T.Buffer((224, 224), "float32"),
+ W2: T.Buffer((224, 224), "float32"),
+ Out: T.Buffer((224, 224), "float32"),
+ ):
+ W1_rewrite = T.alloc_buffer((4, 4, 56, 56))  # Repacked copy of W1, local to the function in Before.
+ W2_rewrite = T.alloc_buffer((4, 4, 56, 56))  # Repacked copy of W2, local to the function in Before.
+ for i, j in T.grid(224, 224):
+ with T.block("W1_rewrite"):  # First block tagged with meta_schedule.layout_rewrite_preproc.
+ vi, vj = T.axis.remap("SS", [i, j])
+ T.block_attr({"meta_schedule.layout_rewrite_preproc":
T.bool(True)})
+ W1_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] = W1[vi,
vj]
+ for i, j in T.grid(224, 224):
+ with T.block("W2_rewrite"):  # Second block tagged with meta_schedule.layout_rewrite_preproc.
+ vi, vj = T.axis.remap("SS", [i, j])
+ T.block_attr({"meta_schedule.layout_rewrite_preproc":
T.bool(True)})
+ W2_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] = W2[vi,
vj]
+ for i0, j0, i1, j1 in T.grid(4, 4, 56, 56):
+ with T.block("Out"):  # Compute block: reads X and both repacked buffers.
+ vi = T.axis.spatial(224, i0 * 56 + i1)
+ vj = T.axis.spatial(224, j0 * 56 + j1)
+ Out[vi, vj] = (
+ X[vi, vj]
+ + W1_rewrite[vi // 56, vj // 56, vi % 56, vj % 56]
+ + W2_rewrite[vi // 56, vj // 56, vi % 56, vj % 56]
+ )
+
+ @R.function
+ def forward(
+ x: R.Tensor((224, 224), dtype="float32"),
+ w1: R.Tensor((224, 224), dtype="float32"),
+ w2: R.Tensor((224, 224), dtype="float32"),
+ ) -> R.Tensor((224, 224), dtype="float32"):
+ R.func_attr({"num_input": 1})
+ cls = Before
+ with R.dataflow():
+ gv = R.call_tir(
+ cls.tir_func, (x, w1, w2), out_sinfo=R.Tensor((224, 224),
dtype="float32")
+ )
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class After:  # Expected module: one repack PrimFunc covering both weights, plus one compute PrimFunc.
+ @T.prim_func(private=True)
+ def tir_func_prepacked(  # Compute-only function: both repacked buffers are parameters here.
+ X: T.Buffer((224, 224), "float32"),
+ W1_rewrite: T.Buffer((4, 4, 56, 56), "float32"),
+ W2_rewrite: T.Buffer((4, 4, 56, 56), "float32"),
+ Out: T.Buffer((224, 224), "float32"),
+ ):
+ for i0, j0, i1, j1 in T.grid(4, 4, 56, 56):
+ with T.block("Out"):
+ vi = T.axis.spatial(224, i0 * 56 + i1)
+ vj = T.axis.spatial(224, j0 * 56 + j1)
+ Out[vi, vj] = (
+ X[vi, vj]
+ + W1_rewrite[vi // 56, vj // 56, vi % 56, vj % 56]
+ + W2_rewrite[vi // 56, vj // 56, vi % 56, vj % 56]
+ )
+
+ @T.prim_func(private=True)
+ def tir_func_weight_prepack(  # Repack-only function: reads W1 and W2, writes W1_rewrite and W2_rewrite.
+ W1: T.Buffer((224, 224), "float32"),
+ W2: T.Buffer((224, 224), "float32"),
+ W1_rewrite: T.Buffer((4, 4, 56, 56), "float32"),
+ W2_rewrite: T.Buffer((4, 4, 56, 56), "float32"),
+ ):
+ for i, j in T.grid(224, 224):
+ with T.block("W1_rewrite"):  # Same block as in Before, without the layout_rewrite_preproc attribute.
+ vi, vj = T.axis.remap("SS", [i, j])
+ W1_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] = W1[vi,
vj]
+ for i, j in T.grid(224, 224):
+ with T.block("W2_rewrite"):  # Same block as in Before, without the layout_rewrite_preproc attribute.
+ vi, vj = T.axis.remap("SS", [i, j])
+ W2_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] = W2[vi,
vj]
+
+ @R.function
+ def forward(
+ x: R.Tensor((224, 224), dtype="float32"),
+ w1: R.Tensor((224, 224), dtype="float32"),
+ w2: R.Tensor((224, 224), dtype="float32"),
+ ) -> R.Tensor((224, 224), dtype="float32"):
+ R.func_attr({"num_input": 1})
+ cls = After
+ with R.dataflow():
+ lv0 = R.call_tir(  # Step 1: a single call repacks both weights and yields two tensors.
+ cls.tir_func_weight_prepack,
+ (w1, w2),
+ out_sinfo=[
+ R.Tensor((4, 4, 56, 56), "float32"),
+ R.Tensor((4, 4, 56, 56), "float32"),
+ ],
+ )
+ lv1 = R.call_tir(  # Step 2: run the compute function on x and the two repacked weights.
+ cls.tir_func_prepacked,
+ (x, lv0[0], lv0[1]),  # lv0[0] and lv0[1] are the two outputs of the repack call above.
+ out_sinfo=R.Tensor((224, 224), "float32"),
+ )
+ gv: R.Tensor((224, 224), dtype="float32") = lv1
+ R.output(gv)
+ return gv
+
+ mod = relax.transform.SplitLayoutRewritePreproc()(Before)  # Apply the pass under test to the input module.
+ tvm.ir.assert_structural_equal(mod, After)  # The transformed module must be structurally equal to After.
+
+
+if __name__ == "__main__":  # Script entry point: only taken when this file is executed directly.
+ tvm.testing.main()  # Delegates to TVM's testing entry point.