This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 80250411e7 [Relax][MetaSchedule] Support CPU weight prepack (#17445)
80250411e7 is described below

commit 80250411e706509fef499e0defe0e625bf6fab28
Author: Siyuan Feng <[email protected]>
AuthorDate: Thu Oct 17 04:36:41 2024 +0800

    [Relax][MetaSchedule] Support CPU weight prepack (#17445)
    
    This PR adds support for CPU weight prepacking. To be specific, this PR
    adds a new pass `AttachAttrLayoutFreeBuffers` to attach layout free buffers
    to the weight parameters, so that we can leverage MetaSchedule to optimize
    the prepacking process.
    
    After the pass and tuning, we introduce a new pass `SplitLayoutRewritePreproc`
    to split the layout rewrite preprocessing into separate functions, so that the
    parameter transformation can be lifted out by the existing `LiftTransformParams` pass.
---
 include/tvm/relax/transform.h                      |  21 ++
 python/tvm/relax/frontend/nn/__init__.py           |   2 +
 python/tvm/relax/pipeline.py                       |  50 +++-
 python/tvm/relax/transform/__init__.py             |   2 +
 python/tvm/relax/transform/transform.py            |  29 ++
 src/meta_schedule/postproc/rewrite_layout.cc       |   8 +-
 .../transform/attach_attr_layout_free_buffers.cc   | 113 +++++++
 .../transform/split_layout_rewrite_preproc.cc      | 327 +++++++++++++++++++++
 .../test_meta_schedule_postproc_rewrite_layout.py  |   3 +-
 ...st_transform_attach_attr_layout_free_buffers.py | 311 ++++++++++++++++++++
 .../test_transform_split_layout_rewrite_preproc.py | 220 ++++++++++++++
 11 files changed, 1083 insertions(+), 3 deletions(-)

diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 5a7b85ac13..eaad44a93a 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -253,6 +253,27 @@ TVM_DLL Pass LegalizeOps(Optional<Map<String, PackedFunc>> cmap, bool enable_war
  */
 TVM_DLL Pass RealizeVDevice();
 
+/*!
+ * \brief Attach layout free buffers to the tir::PrimFunc.
+ *
+ * This pass is used to attach layout free buffers to the tir::PrimFunc according to
+ * the function usage in the relax function. Currently, the layout free buffers are the model
+ * weights and relax constants.
+ *
+ * \note We recommend applying CanonicalizeBindings before this pass.
+ * \return The Pass.
+ */
+TVM_DLL Pass AttachAttrLayoutFreeBuffers();
+
+/*!
+ * \brief Split the layout rewrite preproc block to a separate tir::PrimFunc.
+ *
+ * This pass is used in the prepack weight after meta_schedule tuning.
+ *
+ * \return The Pass.
+ */
+TVM_DLL Pass SplitLayoutRewritePreproc();
+
 /*!
  * \brief Lift transformation of the parameters of a function.
  *
diff --git a/python/tvm/relax/frontend/nn/__init__.py b/python/tvm/relax/frontend/nn/__init__.py
index a8200d8dd6..f490af7062 100644
--- a/python/tvm/relax/frontend/nn/__init__.py
+++ b/python/tvm/relax/frontend/nn/__init__.py
@@ -23,6 +23,8 @@ from .extern import ExternModule, ObjectModule, SourceModule
 from .modules import (
     GELU,
     Conv1D,
+    Conv2D,
+    Conv3D,
     ConvTranspose1D,
     Embedding,
     GroupNorm,
diff --git a/python/tvm/relax/pipeline.py b/python/tvm/relax/pipeline.py
index 582f5111aa..fe3dbc99fc 100644
--- a/python/tvm/relax/pipeline.py
+++ b/python/tvm/relax/pipeline.py
@@ -109,6 +109,7 @@ def static_shape_tuning_pipeline(
     total_trials: int,
     target: Union[str, tvm.target.Target],
     work_dir: str = "tuning_logs",
+    cpu_weight_prepack: bool = False,
 ):
     """Tune the static shape model and store the log to database.
 
@@ -122,18 +123,65 @@ def static_shape_tuning_pipeline(
 
     work_dir : str
         The directory to store the tuning logs.
+
+    cpu_weight_prepack : bool
+        Whether to enable the cpu weight prepack feature.
+
+    Note
+    ----
+    `cpu_weight_prepack` is expected to be `True` when running on CPU for
+    better performance. However, it requires an explicit layout transformation
+    step by calling the corresponding vm function, which changes the interface
+    of deployment. So we disable it by default. Here is an example to enable it:
+
+    .. code-block:: python
+
+        mod = relax.pipeline.static_shape_tuning_pipeline(
+            total_trials=1000,
+            target="llvm -num-cores 16",
+            work_dir="tuning_logs",
+            cpu_weight_prepack=True,
+        )(mod)
+
+        ex = relax.build(mod, target=target)
+        vm = relax.VirtualMachine(ex, device=tvm.cpu())
+
+        # Transform the params using the vm function
+        # the name should be f"{func_name}_transform_params"
+        params = vm["main_transform_params"](params["main"])
+
+        input_data = tvm.nd.array(np.random.randn(1, 3, 224, 224).astype("float32"))
+        out = vm["main"](input_data, *params).numpy()
     """
 
     @tvm.transform.module_pass(opt_level=0)
     def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.IRModule:
+        if cpu_weight_prepack:
+            pre_tuning_layout_rewrite = [transform.AttachAttrLayoutFreeBuffers()]
+            post_tuning_layout_rewrite = [
+                transform.SplitLayoutRewritePreproc(),
+                transform.LiftTransformParams(),
+                transform.FoldConstant(),
+            ]
+        else:
+            pre_tuning_layout_rewrite = []
+            post_tuning_layout_rewrite = []
+
         with tvm.target.Target(target):
             mod = tvm.transform.Sequential(
                 [
                     transform.DecomposeOpsForInference(),
                     transform.CanonicalizeBindings(),
                     zero_pipeline(),
-                    transform.MetaScheduleTuneIRMod({}, work_dir, total_trials),
+                    *pre_tuning_layout_rewrite,
+                    # Skip tuning if total_trials is 0
+                    (
+                        transform.MetaScheduleTuneIRMod({}, work_dir, total_trials)
+                        if total_trials > 0
+                        else tvm.transform.Sequential([])
+                    ),
                     transform.MetaScheduleApplyDatabase(work_dir),
+                    *post_tuning_layout_rewrite,
                 ]
             )(mod)
 
diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py
index 1ce864651c..16e4800ca3 100644
--- a/python/tvm/relax/transform/__init__.py
+++ b/python/tvm/relax/transform/__init__.py
@@ -21,6 +21,7 @@ from .transform import (
     AllocateWorkspace,
     AlterOpImpl,
     AnnotateTIROpPattern,
+    AttachAttrLayoutFreeBuffers,
     AttachGlobalSymbol,
     BindParams,
     BindSymbolicVars,
@@ -73,6 +74,7 @@ from .transform import (
     RewriteDataflowReshape,
     RunCodegen,
     SplitCallTIRByPattern,
+    SplitLayoutRewritePreproc,
     StaticPlanBlockMemory,
     ToMixedPrecision,
     ToNonDataflow,
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index 3330d40987..603211b59e 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -970,6 +970,35 @@ def MergeCompositeFunctions() -> tvm.ir.transform.Pass:
     return _ffi_api.MergeCompositeFunctions()  # type: ignore
 
 
+def AttachAttrLayoutFreeBuffers() -> tvm.ir.transform.Pass:
+    """Attach layout free buffers to the tir::PrimFunc.
+
+    This pass is used to attach layout free buffers to the tir::PrimFunc according to
+    the function usage in the relax function. Currently, the layout free buffers are the model
+    weights and relax constants.
+
+    Note that we recommend applying CanonicalizeBindings before this pass.
+
+    Returns
+    -------
+    ret : tvm.transform.Pass
+        The registered pass for attaching layout free buffers.
+    """
+    return _ffi_api.AttachAttrLayoutFreeBuffers()  # type: ignore
+
+
+def SplitLayoutRewritePreproc() -> tvm.ir.transform.Pass:
+    """Split the TIR layout rewrite into multiple TIR functions.
+    This pass is used in the prepack weight after meta_schedule tuning.
+
+    Returns
+    -------
+    ret : tvm.transform.Pass
+        The registered pass for splitting TIR layout rewrite.
+    """
+    return _ffi_api.SplitLayoutRewritePreproc()  # type: ignore
+
+
 def LiftTransformParams(shared_transform: Union[bool, List[str]] = False) -> tvm.ir.transform.Pass:
     """Lift transformation of the parameters of a function.
 
diff --git a/src/meta_schedule/postproc/rewrite_layout.cc b/src/meta_schedule/postproc/rewrite_layout.cc
index 71ae433871..87fa96f67c 100644
--- a/src/meta_schedule/postproc/rewrite_layout.cc
+++ b/src/meta_schedule/postproc/rewrite_layout.cc
@@ -249,7 +249,13 @@ class RewriteLayoutNode : public PostprocNode {
   void InitializeWithTuneContext(const TuneContext& context) final {}
 
   // Inherited from PostprocNode
-  bool Apply(const tir::Schedule& sch) final { return tir::RewriteLayout(sch); }
+  bool Apply(const tir::Schedule& sch) final {
+    try {
+      return tir::RewriteLayout(sch);
+    } catch (const std::runtime_error& e) {
+      return false;
+    }
+  }
 
   Postproc Clone() const {
     ObjectPtr<RewriteLayoutNode> n = make_object<RewriteLayoutNode>(*this);
diff --git a/src/relax/transform/attach_attr_layout_free_buffers.cc b/src/relax/transform/attach_attr_layout_free_buffers.cc
new file mode 100644
index 0000000000..64062e2243
--- /dev/null
+++ b/src/relax/transform/attach_attr_layout_free_buffers.cc
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file src/relax/transform/attach_attr_layout_free_buffers.cc
+ * \brief Attach layout_free_buffers for layout-free buffers.
+ */
+
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+#include <tvm/tir/stmt_functor.h>
+
+namespace tvm {
+namespace relax {
+
+class AttrAttacher : public ExprMutator {
+ public:
+  static IRModule Transform(const IRModule& mod) {
+    AttrAttacher mutator(mod);
+    for (auto [gvar, func] : mod->functions) {
+      if (func->IsInstance<relax::FunctionNode>()) {
+        // clear the layout_free_exprs_ for each function
+        mutator.layout_free_exprs_.clear();
+        mutator.builder_->UpdateFunction(gvar, Downcast<BaseFunc>(mutator.VisitExpr(func)));
+      }
+    }
+    return mutator.builder_->GetContextIRModule();
+  }
+
+ private:
+  explicit AttrAttacher(IRModule mod) : ExprMutator(mod), mod_(mod) {}
+
+  using ExprMutator::VisitExpr_;
+  Expr VisitExpr_(const FunctionNode* op) final {
+    if (auto opt_num_input = op->attrs.GetAttr<Integer>(attr::kNumInput)) {
+      ICHECK(layout_free_exprs_.empty()) << "meet a non-global function with num_input attr";
+      size_t num_input = opt_num_input.value()->value;
+      for (size_t i = num_input; i < op->params.size(); i++) {
+        layout_free_exprs_.insert(op->params[i].get());
+      }
+    }
+    return ExprMutator::VisitExpr_(op);
+  }
+
+  Expr VisitExpr_(const ConstantNode* op) final {
+    layout_free_exprs_.insert(op);
+    return ExprMutator::VisitExpr_(op);
+  }
+
+  Expr VisitExpr_(const CallNode* op) final {
+    static const Op& call_tir_op_ = Op::Get("relax.call_tir");
+    Call call = Downcast<Call>(ExprMutator::VisitExpr_(op));
+    if (call->op != call_tir_op_) {
+      return call;
+    }
+    GlobalVar gv = Downcast<GlobalVar>(call->args[0]);
+    Array<Expr> call_tir_args = Downcast<Tuple>(call->args[1])->fields;
+    // Compute the layout free buffers
+    Array<Integer> layout_free_buffers;
+    for (size_t i = 0; i < call_tir_args.size(); i++) {
+      if (layout_free_exprs_.count(call_tir_args[i].get())) {
+        layout_free_buffers.push_back(Integer(i));
+      }
+    }
+    // Attach the layout free buffers to the tir::PrimFunc
+    tir::PrimFunc func = WithAttr(Downcast<tir::PrimFunc>(mod_->Lookup(gv)), "layout_free_buffers",
+                                  layout_free_buffers);
+    // Renew defs
+    func = tir::RenewDefs(func);
+    // Add the updated tir::PrimFunc in the IRModule
+    // Note the blockbuilder would automatically combine the same tir function
+    // So we don't need to worry about the duplicate insertion
+    GlobalVar new_gv = builder_->AddFunction(func, gv->name_hint);
+    // Create a new call node with the updated tir::PrimFunc
+    auto n = make_object<CallNode>(*op);
+    n->args = {new_gv, Tuple(call_tir_args)};
+    return Call(n);
+  }
+
+ private:
+  IRModule mod_;
+  std::unordered_set<const ExprNode*> layout_free_exprs_;
+};
+namespace transform {
+
+Pass AttachAttrLayoutFreeBuffers() {
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
+      [=](IRModule mod, PassContext pc) { return AttrAttacher::Transform(mod); };
+  auto pass = CreateModulePass(pass_func, 0, "_AttachAttrLayoutFreeBuffers", {});
+  // Apply DeadCodeElimination to remove unused tir::PrimFunc
+  return tvm::transform::Sequential({pass, DeadCodeElimination()}, "AttachAttrLayoutFreeBuffers");
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.AttachAttrLayoutFreeBuffers")
+    .set_body_typed(AttachAttrLayoutFreeBuffers);
+}  // namespace transform
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/transform/split_layout_rewrite_preproc.cc b/src/relax/transform/split_layout_rewrite_preproc.cc
new file mode 100644
index 0000000000..5fee946c26
--- /dev/null
+++ b/src/relax/transform/split_layout_rewrite_preproc.cc
@@ -0,0 +1,327 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relax/transform/split_tir_layout_rewrite.cc
+ * \brief Use for rewriting the TIRs after meta_schedule layout rewrite post process.
+ */
+#include <tvm/ir/transform.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <algorithm>
+#include <cstddef>
+
+namespace tvm {
+namespace tir {
+class SplitPrimFuncLayoutRewrite : public StmtMutator {
+ public:
+  explicit SplitPrimFuncLayoutRewrite(const PrimFunc& func) : original_func_(func) {}
+  std::tuple<Optional<PrimFunc>, PrimFunc> Transform(const PrimFunc& func) {
+    ICHECK(func->body.as<BlockRealizeNode>()) << "The body of the primfunc should be a root block.";
+    const auto& block = func->body.as<BlockRealizeNode>()->block;
+    visit_root_block(block.get());
+    if (layout_rewrite_preproc_stmts_.size() > 0) {
+      return std::make_tuple(create_layout_rewrite_preproc_func(), create_compute_func());
+    } else {
+      return std::make_tuple(NullOpt, func);
+    }
+  }
+
+ private:
+  void sort_rewrite_infos() {
+    std::sort(
+        rewrite_infos_.begin(), rewrite_infos_.end(),
+        [](const RewriteInfo& a, const RewriteInfo& b) { return a.buffer_index < b.buffer_index; });
+  }
+
+  PrimFunc create_layout_rewrite_preproc_func() const {
+    // Step 1: Check the number of pre_rewrite_buffers and post_rewrite_buffers
+    ICHECK(rewrite_infos_.size() > 0) << "There should be at least one buffer rewrite.";
+
+    // Step 2: Create the params for the new PrimFunc
+    Array<Var> params;
+    Map<Var, Buffer> buffer_map;
+
+    for (const auto& info : rewrite_infos_) {
+      params.push_back(Var(info.pre_rewrite_buffer->name, DataType::Handle()));
+      buffer_map.Set(params.back(), info.pre_rewrite_buffer);
+    }
+    for (const auto& info : rewrite_infos_) {
+      params.push_back(Var(info.post_rewrite_buffer->name, DataType::Handle()));
+      buffer_map.Set(params.back(), info.post_rewrite_buffer);
+    }
+
+    // Step 3: Create the body for the new PrimFunc
+    ICHECK(layout_rewrite_preproc_stmts_.size() > 0)
+        << "There should be at least one layout rewrite preproc stmt.";
+    Stmt body = layout_rewrite_preproc_stmts_.size() == 1 ? layout_rewrite_preproc_stmts_[0]
+                                                          : SeqStmt(layout_rewrite_preproc_stmts_);
+    body = BlockRealize(
+        /*iter_values=*/Array<PrimExpr>(),
+        /*predicate=*/const_true(),
+        /*block=*/
+        Block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{},
+              /*name_hint=*/"root", body));
+
+    PrimFunc func = PrimFunc(params, body, VoidType(), buffer_map);
+
+    return RenewDefs(func);
+  }
+
+  PrimFunc create_compute_func() const {
+    // Step 1: Create the params for the new PrimFunc
+    Array<Var> params = original_func_->params;
+    Map<Var, Buffer> buffer_map = original_func_->buffer_map;
+    for (const auto& info : rewrite_infos_) {
+      const Var& param = params[info.buffer_index];
+      ICHECK(buffer_map[param] == info.pre_rewrite_buffer);
+      buffer_map.Set(param, info.post_rewrite_buffer);
+    }
+
+    // Step 2: Create the body for the new PrimFunc
+    Stmt body = compute_stmts_.size() == 1 ? compute_stmts_[0] : SeqStmt(compute_stmts_);
+    Block original_block = original_func_->body.as<BlockRealizeNode>()->block;
+    Array<Buffer> alloc_buffers;
+    for (const auto& buffer : original_block->alloc_buffers) {
+      auto it =
+          std::find_if(rewrite_infos_.begin(), rewrite_infos_.end(),
+                       [&](const RewriteInfo& info) { return info.post_rewrite_buffer == buffer; });
+      if (it == rewrite_infos_.end()) {
+        alloc_buffers.push_back(buffer);
+      }
+    }
+
+    body = BlockRealize(
+        /*iter_values=*/Array<PrimExpr>(),
+        /*predicate=*/const_true(),
+        /*block=*/
+        Block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{},
+              /*name_hint=*/"root", body,
+              /*init=*/NullOpt,
+              /*alloc_buffers=*/alloc_buffers));
+
+    PrimFunc func = PrimFunc(original_func_->params, body, VoidType(), buffer_map);
+    return RenewDefs(func);
+  }
+
+  void visit_root_block(const BlockNode* op) {
+    Stmt body = op->body;
+    if (const auto* seq_stmt = body.as<SeqStmtNode>()) {
+      for (const auto& stmt : seq_stmt->seq) {
+        current_subtree_ = 0;
+        Stmt new_stmt = this->VisitStmt(stmt);
+        ICHECK(current_subtree_ != 0) << "There should be at least a block in the subtree.";
+        if (current_subtree_ == 1) {
+          layout_rewrite_preproc_stmts_.push_back(new_stmt);
+        } else {
+          compute_stmts_.push_back(new_stmt);
+        }
+      }
+    } else {
+      current_subtree_ = 0;
+      this->VisitStmt(body);
+      ICHECK(current_subtree_ == -1)
+          << "There should be a compute block if there is only one subtree under the root.";
+    }
+  }
+  Stmt VisitStmt_(const BlockNode* op) final {
+    Block block = Downcast<Block>(StmtMutator::VisitStmt_(op));
+    auto it = op->annotations.find(attr::meta_schedule_layout_rewrite_preproc);
+    bool is_layout_rewrite_preproc =
+        it != op->annotations.end() && is_one(Downcast<PrimExpr>((*it).second));
+
+    if (current_subtree_ == 0) {
+      current_subtree_ = is_layout_rewrite_preproc ? 1 : -1;
+    } else if (current_subtree_ == 1) {
+      CHECK(is_layout_rewrite_preproc)
+          << "There is a layout rewrite block in the subtree, but meet a non-layout rewrite block.";
+    } else {
+      CHECK(!is_layout_rewrite_preproc)
+          << "There is a non-layout rewrite block in the subtree, but meet a layout rewrite block.";
+    }
+
+    if (is_layout_rewrite_preproc) {
+      ICHECK(op->reads.size() == 1) << "There should be only one read buffer in the layout rewrite";
+      ICHECK(op->writes.size() == 1)
+          << "There should be only one write buffer in the layout rewrite";
+      ICHECK(op->alloc_buffers.empty()) << "There should be no alloc buffer in the layout rewrite";
+      ICHECK(op->match_buffers.empty()) << "There should be no match buffer in the layout rewrite";
+      const Buffer& preproc_buffer = op->reads[0]->buffer;
+      int buffer_index = -1;
+      for (size_t i = 0; i < original_func_->params.size(); ++i) {
+        const Buffer& buffer = original_func_->buffer_map[original_func_->params[i]];
+        if (buffer == preproc_buffer) {
+          buffer_index = i;
+          break;
+        }
+      }
+      ICHECK(buffer_index != -1) << "The preproc buffer is not found in the original primfunc.";
+      rewrite_infos_.push_back(
+          RewriteInfo{buffer_index, op->reads[0]->buffer, op->writes[0]->buffer});
+
+      auto new_annotations = op->annotations;
+      new_annotations.erase(attr::meta_schedule_layout_rewrite_preproc);
+      auto n = make_object<BlockNode>(*block.get());
+      n->annotations = new_annotations;
+      return Block(n);
+    }
+    return block;
+  }
+
+ public:
+  struct RewriteInfo {
+    int buffer_index;
+    Buffer pre_rewrite_buffer;
+    Buffer post_rewrite_buffer;
+  };
+  std::vector<RewriteInfo> rewrite_infos_;
+
+ private:
+  /*! \brief The stmts that are used for layout rewrite preproc*/
+  Array<Stmt> layout_rewrite_preproc_stmts_;
+  /*! \brief The stmts that are other than layout rewrite preproc*/
+  Array<Stmt> compute_stmts_;
+  /*!
+   \brief Whether the current subtree is a layout rewrite preproc subtree.
+          -1: visited a non-layout rewrite preproc block
+           0: unsure, not visited any block
+           1: visited a layout rewrite preproc block
+  */
+  int current_subtree_;
+  /*! \brief The original primfunc*/
+  PrimFunc original_func_;
+};
+}  // namespace tir
+
+namespace relax {
+class SplitLayoutRewritePreproc : public ExprMutator {
+ public:
+  static IRModule Transform(const IRModule& mod) {
+    SplitLayoutRewritePreproc mutator(mod);
+
+    // Step 1: Split the primfunc into preproc and compute
+    for (auto [gv, func] : mod->functions) {
+      if (func->IsInstance<tir::PrimFuncNode>()) {
+        tir::SplitPrimFuncLayoutRewrite tir_rewriter(Downcast<tir::PrimFunc>(func));
+        auto [preproc_func, compute_func] = tir_rewriter.Transform(Downcast<tir::PrimFunc>(func));
+        if (preproc_func.defined()) {
+          mutator.split_funcs_.emplace(gv.get(),
+                                       std::make_tuple(preproc_func.value(), compute_func));
+          mutator.rewrite_infos_.emplace(gv.get(), tir_rewriter.rewrite_infos_);
+        }
+      }
+    }
+
+    for (auto [gv, func] : mod->functions) {
+      if (func->IsInstance<relax::FunctionNode>()) {
+        auto relax_func = Downcast<relax::Function>(func);
+        mutator.builder_->UpdateFunction(gv, Downcast<relax::Function>(mutator(relax_func)));
+      }
+    }
+    return mutator.builder_->GetContextIRModule();
+  }
+
+ private:
+  explicit SplitLayoutRewritePreproc(const IRModule& mod) : ExprMutator(mod) {}
+  using ExprMutator::VisitExpr_;
+
+  Expr VisitExpr_(const CallNode* op) final {
+    static const Op& call_tir_op = Op::Get("relax.call_tir");
+    Call call = Downcast<Call>(ExprMutator::VisitExpr_(op));
+
+    // Step 1: Skip call to other than `tir.call_tir`
+    if (!call->op.same_as(call_tir_op)) {
+      return call;
+    }
+
+    // Step 2: Skip if there is no preproc stage
+    const GlobalVar gv = Downcast<GlobalVar>(call->args[0]);
+    auto it = split_funcs_.find(gv.get());
+    if (it == split_funcs_.end()) {
+      return call;
+    }
+
+    // Step 3: Get the preproc and compute functions and update the module
+    const auto& [preproc_func, compute_func] = it->second;
+    GlobalVar preproc_gv = builder_->AddFunction(preproc_func, gv->name_hint + "_weight_prepack");
+    GlobalVar compute_gv = builder_->AddFunction(compute_func, gv->name_hint + "_prepacked");
+    // Step 4. Get rewrite infos
+    auto rewrite_infos_it = rewrite_infos_.find(gv.get());
+    ICHECK(rewrite_infos_it != rewrite_infos_.end())
+        << "Rewrite infos are not found for " << gv->name_hint;
+    const auto& rewrite_infos = rewrite_infos_it->second;
+
+    // Step 5: Emit the preproc call
+    Array<Expr> call_tir_args = Downcast<Tuple>(call->args[1])->fields;
+    Array<Expr> preproc_args;
+    Array<StructInfo> preproc_sinfo_list;
+    for (const auto& info : rewrite_infos) {
+      preproc_args.push_back(call_tir_args[info.buffer_index]);
+      tir::Buffer rewritten_buffer = info.post_rewrite_buffer;
+      for (const auto& shape_expr : rewritten_buffer->shape) {
+        CHECK(shape_expr.as<tir::IntImmNode>()) << "Currently does not support rewrite buffer with "
+                                                   "dynamic shape.";
+      }
+      preproc_sinfo_list.push_back(
+          TensorStructInfo(ShapeExpr(rewritten_buffer->shape), rewritten_buffer->dtype));
+    }
+    StructInfo preproc_sinfo = preproc_sinfo_list.size() > 1              //
+                                   ? TupleStructInfo(preproc_sinfo_list)  //
+                                   : preproc_sinfo_list[0];
+
+    // Step 6: Call the preproc function
+    Expr preproc_call =
+        builder_->Emit(Call(call_tir_op, {preproc_gv, Tuple(preproc_args)}, {}, {preproc_sinfo}));
+    if (rewrite_infos.size() == 1) {
+      call_tir_args.Set(rewrite_infos[0].buffer_index, preproc_call);
+    } else {
+      for (size_t i = 0; i < rewrite_infos.size(); ++i) {
+        call_tir_args.Set(rewrite_infos[i].buffer_index, TupleGetItem(preproc_call, i));
+      }
+    }
+    Expr main_call =
+        builder_->Emit(Call(call_tir_op, {compute_gv, Tuple(call_tir_args)}, {}, call->sinfo_args));
+
+    return main_call;
+  }
+
+ private:
+  std::unordered_map<const GlobalVarNode*, std::tuple<tir::PrimFunc, tir::PrimFunc>> split_funcs_;
+  std::unordered_map<const GlobalVarNode*,
+                     std::vector<tir::SplitPrimFuncLayoutRewrite::RewriteInfo>>
+      rewrite_infos_;
+};
+
+}  // namespace relax
+
+namespace transform {
+Pass SplitLayoutRewritePreproc() {
+  auto pass_func = [](IRModule mod, PassContext pc) {
+    return relax::SplitLayoutRewritePreproc::Transform(mod);
+  };
+  auto pass = CreateModulePass(pass_func, 0, "SplitLayoutRewritePreproc", {});
+  return tvm::transform::Sequential({pass, relax::transform::DeadCodeElimination()},
+                                    "SplitLayoutRewritePreproc");
+}
+TVM_REGISTER_GLOBAL("relax.transform.SplitLayoutRewritePreproc")
+    .set_body_typed(SplitLayoutRewritePreproc);
+}  // namespace transform
+}  // namespace tvm
diff --git a/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_layout.py b/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_layout.py
index e2305de2af..8348c57c19 100644
--- a/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_layout.py
+++ b/tests/python/meta_schedule/test_meta_schedule_postproc_rewrite_layout.py
@@ -61,7 +61,8 @@ class BaseBeforeAfter(tvm.testing.CompareBeforeAfter):
             )
             sch = tvm.tir.Schedule(mod, debug_mask="all")
             sch.enter_postproc()
-            assert ctx.space_generator.postprocs[0].apply(sch)
+            if not ctx.space_generator.postprocs[0].apply(sch):
+                raise tvm.TVMError("RewriteLayout postproc failed")
             return sch.mod
 
         return inner
diff --git a/tests/python/relax/test_transform_attach_attr_layout_free_buffers.py b/tests/python/relax/test_transform_attach_attr_layout_free_buffers.py
new file mode 100644
index 0000000000..46f7c8aa87
--- /dev/null
+++ b/tests/python/relax/test_transform_attach_attr_layout_free_buffers.py
@@ -0,0 +1,311 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import numpy as np
+import tvm.testing
+
+from tvm import relax, tir
+from tvm.script import relax as R, tir as T, ir as I
+from tvm.relax.transform import CombineParallelMatmul
+from tvm.script.ir_builder import IRBuilder
+from tvm.script.ir_builder import relax as relax_builder
+
+
+def test_param():  # y follows the num_input=1 input, so buffer index 1 is layout-free
+    @I.ir_module
+    class Before:
+        @T.prim_func(private=True)
+        def matmul(
+            A: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+            B: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+            C: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+        ):
+            for i, j, k in T.grid(T.int64(32), T.int64(32), T.int64(32)):
+                with T.block("C"):
+                    with T.init():
+                        C[i, j] = T.float32(0)
+                    C[i, j] = C[i, j] + A[i, k] * B[k, j]
+
+        @R.function
+        def main(x: R.Tensor((32, 32), "float32"), y: R.Tensor((32, 32), "float32")):
+            R.func_attr({"num_input": 1})
+            cls = Before
+            with R.dataflow():
+                gv = R.call_tir(cls.matmul, (x, y), out_sinfo=R.Tensor((32, 32), "float32"))
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func(private=True)
+        def matmul1(
+            A: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+            B: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+            C: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+        ):
+            T.func_attr({"layout_free_buffers": [1]})
+            for i, j, k in T.grid(T.int64(32), T.int64(32), T.int64(32)):
+                with T.block("C"):
+                    with T.init():
+                        C[i, j] = T.float32(0)
+                    C[i, j] = C[i, j] + A[i, k] * B[k, j]
+
+        @R.function
+        def main(x: R.Tensor((32, 32), "float32"), y: R.Tensor((32, 32), "float32")):
+            R.func_attr({"num_input": 1})
+            cls = Expected
+            with R.dataflow():
+                gv = R.call_tir(cls.matmul1, (x, y), out_sinfo=R.Tensor((32, 32), "float32"))
+                R.output(gv)
+            return gv
+
+    after = relax.transform.AttachAttrLayoutFreeBuffers()(Before)
+    tvm.ir.assert_structural_equal(after, Expected)
+
+
+def test_const():  # a relax.const weight argument is marked layout-free as well
+    const_value = np.ones((32, 32), dtype="float32")
+
+    @I.ir_module
+    class Before:
+        @T.prim_func(private=True)
+        def matmul(
+            A: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+            B: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+            C: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+        ):
+            for i, j, k in T.grid(T.int64(32), T.int64(32), T.int64(32)):
+                with T.block("C"):
+                    with T.init():
+                        C[i, j] = T.float32(0)
+                    C[i, j] = C[i, j] + A[i, k] * B[k, j]
+
+        @R.function
+        def main(x: R.Tensor((32, 32), "float32")):
+            R.func_attr({"num_input": 1})
+            cls = Before
+            with R.dataflow():
+                gv = R.call_tir(
+                    cls.matmul,
+                    (x, relax.const(const_value)),
+                    out_sinfo=R.Tensor((32, 32), "float32"),
+                )
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func(private=True)
+        def matmul1(
+            A: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+            B: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+            C: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+        ):
+            T.func_attr({"layout_free_buffers": [1]})
+            for i, j, k in T.grid(T.int64(32), T.int64(32), T.int64(32)):
+                with T.block("C"):
+                    with T.init():
+                        C[i, j] = T.float32(0)
+                    C[i, j] = C[i, j] + A[i, k] * B[k, j]
+
+        @R.function
+        def main(x: R.Tensor((32, 32), "float32")):
+            R.func_attr({"num_input": 1})
+            cls = Expected
+            with R.dataflow():
+                gv = R.call_tir(
+                    cls.matmul1,
+                    (x, relax.const(const_value)),
+                    out_sinfo=R.Tensor((32, 32), "float32"),
+                )
+                R.output(gv)
+            return gv
+
+    after = relax.transform.AttachAttrLayoutFreeBuffers()(Before)
+    tvm.ir.assert_structural_equal(after, Expected)
+
+
+def test_multiple_same_func():  # both calls pass the weight at index 1: one shared copy
+    @I.ir_module
+    class Before:
+        @T.prim_func(private=True)
+        def matmul(
+            A: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+            B: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+            C: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+        ):
+            for i, j, k in T.grid(T.int64(32), T.int64(32), T.int64(32)):
+                with T.block("C"):
+                    with T.init():
+                        C[i, j] = T.float32(0)
+                    C[i, j] = C[i, j] + A[i, k] * B[k, j]
+
+        @R.function
+        def main(
+            x: R.Tensor((32, 32), "float32"),
+            w1: R.Tensor((32, 32), "float32"),
+            w2: R.Tensor((32, 32), "float32"),
+        ):
+            R.func_attr({"num_input": 1})
+            cls = Before
+            with R.dataflow():
+                lv1 = R.call_tir(
+                    cls.matmul,
+                    (x, w1),
+                    out_sinfo=R.Tensor((32, 32), "float32"),
+                )
+                gv = R.call_tir(
+                    cls.matmul,
+                    (lv1, w2),
+                    out_sinfo=R.Tensor((32, 32), "float32"),
+                )
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func(private=True)
+        def matmul1(
+            A: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+            B: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+            C: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+        ):
+            T.func_attr({"layout_free_buffers": [1]})
+            for i, j, k in T.grid(T.int64(32), T.int64(32), T.int64(32)):
+                with T.block("C"):
+                    with T.init():
+                        C[i, j] = T.float32(0)
+                    C[i, j] = C[i, j] + A[i, k] * B[k, j]
+
+        @R.function
+        def main(
+            x: R.Tensor((32, 32), "float32"),
+            w1: R.Tensor((32, 32), "float32"),
+            w2: R.Tensor((32, 32), "float32"),
+        ):
+            R.func_attr({"num_input": 1})
+            cls = Expected
+            with R.dataflow():
+                lv1 = R.call_tir(
+                    cls.matmul1,
+                    (x, w1),
+                    out_sinfo=R.Tensor((32, 32), "float32"),
+                )
+                gv = R.call_tir(
+                    cls.matmul1,
+                    (lv1, w2),
+                    out_sinfo=R.Tensor((32, 32), "float32"),
+                )
+                R.output(gv)
+            return gv
+
+    after = relax.transform.AttachAttrLayoutFreeBuffers()(Before)
+    tvm.ir.assert_structural_equal(after, Expected)
+
+
+def test_multiple_same_func_with_different_free_buffers():  # weight slot differs: 2 copies
+    @I.ir_module
+    class Before:
+        @T.prim_func(private=True)
+        def matmul(
+            A: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+            B: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+            C: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+        ):
+            for i, j, k in T.grid(T.int64(32), T.int64(32), T.int64(32)):
+                with T.block("C"):
+                    with T.init():
+                        C[i, j] = T.float32(0)
+                    C[i, j] = C[i, j] + A[i, k] * B[k, j]
+
+        @R.function
+        def main(
+            x: R.Tensor((32, 32), "float32"),
+            w1: R.Tensor((32, 32), "float32"),
+            w2: R.Tensor((32, 32), "float32"),
+        ):
+            R.func_attr({"num_input": 1})
+            cls = Before
+            with R.dataflow():
+                lv1 = R.call_tir(
+                    cls.matmul,
+                    (x, w1),
+                    out_sinfo=R.Tensor((32, 32), "float32"),
+                )
+                gv = R.call_tir(
+                    cls.matmul,
+                    (w2, lv1),
+                    out_sinfo=R.Tensor((32, 32), "float32"),
+                )
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func(private=True)
+        def matmul1(
+            A: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+            B: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+            C: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+        ):
+            T.func_attr({"layout_free_buffers": [1]})
+            for i, j, k in T.grid(T.int64(32), T.int64(32), T.int64(32)):
+                with T.block("C"):
+                    with T.init():
+                        C[i, j] = T.float32(0)
+                    C[i, j] = C[i, j] + A[i, k] * B[k, j]
+
+        @T.prim_func(private=True)
+        def matmul2(
+            A: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+            B: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+            C: T.Buffer((T.int64(32), T.int64(32)), "float32"),
+        ):
+            T.func_attr({"layout_free_buffers": [0]})
+            for i, j, k in T.grid(T.int64(32), T.int64(32), T.int64(32)):
+                with T.block("C"):
+                    with T.init():
+                        C[i, j] = T.float32(0)
+                    C[i, j] = C[i, j] + A[i, k] * B[k, j]
+
+        @R.function
+        def main(
+            x: R.Tensor((32, 32), "float32"),
+            w1: R.Tensor((32, 32), "float32"),
+            w2: R.Tensor((32, 32), "float32"),
+        ):
+            R.func_attr({"num_input": 1})
+            cls = Expected
+            with R.dataflow():
+                lv1 = R.call_tir(
+                    cls.matmul1,
+                    (x, w1),
+                    out_sinfo=R.Tensor((32, 32), "float32"),
+                )
+                gv = R.call_tir(
+                    cls.matmul2,
+                    (w2, lv1),
+                    out_sinfo=R.Tensor((32, 32), "float32"),
+                )
+                R.output(gv)
+            return gv
+
+    after = relax.transform.AttachAttrLayoutFreeBuffers()(Before)
+    tvm.ir.assert_structural_equal(after, Expected)
+
+
+if __name__ == "__main__":  # allow running this test file directly
+    tvm.testing.main()
diff --git a/tests/python/relax/test_transform_split_layout_rewrite_preproc.py b/tests/python/relax/test_transform_split_layout_rewrite_preproc.py
new file mode 100644
index 0000000000..e6b4c8ec4e
--- /dev/null
+++ b/tests/python/relax/test_transform_split_layout_rewrite_preproc.py
@@ -0,0 +1,220 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm.testing
+from tvm import relax
+from tvm.script import ir as I
+from tvm.script import relax as R
+from tvm.script import tir as T
+
+
+def test_single_buffer():  # the preproc block moves into a separate *_weight_prepack PrimFunc
+    @I.ir_module
+    class Before:
+        @T.prim_func(private=True)
+        def tir_func(
+            X: T.Buffer((224, 224), "float32"),
+            W: T.Buffer((224, 224), "float32"),
+            Out: T.Buffer((224, 224), "float32"),
+        ):
+            T.func_attr({"layout_free_buffers": [1]})
+            W_rewrite = T.alloc_buffer((4, 4, 56, 56))
+            for i, j in T.grid(224, 224):
+                with T.block("W_rewrite"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    T.block_attr({"meta_schedule.layout_rewrite_preproc": T.bool(True)})
+                    W_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] = W[vi, vj]
+            for i0, j0, i1, j1 in T.grid(4, 4, 56, 56):
+                with T.block("Out"):
+                    vi = T.axis.spatial(224, i0 * 56 + i1)
+                    vj = T.axis.spatial(224, j0 * 56 + j1)
+                    Out[vi, vj] = X[vi, vj] + W_rewrite[vi // 56, vj // 56, vi % 56, vj % 56]
+
+        @R.function
+        def forward(
+            x: R.Tensor((224, 224), dtype="float32"),
+            w: R.Tensor((224, 224), dtype="float32"),
+        ) -> R.Tensor((224, 224), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            cls = Before
+            with R.dataflow():
+                gv = R.call_tir(
+                    cls.tir_func, (x, w), out_sinfo=R.Tensor((224, 224), dtype="float32")
+                )
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class After:
+        @T.prim_func(private=True)
+        def tir_func_prepacked(
+            X: T.Buffer((224, 224), "float32"),
+            W_rewrite: T.Buffer((4, 4, 56, 56), "float32"),
+            Out: T.Buffer((224, 224), "float32"),
+        ):
+            for i0, j0, i1, j1 in T.grid(4, 4, 56, 56):
+                with T.block("Out"):
+                    vi = T.axis.spatial(224, i0 * 56 + i1)
+                    vj = T.axis.spatial(224, j0 * 56 + j1)
+                    Out[vi, vj] = X[vi, vj] + W_rewrite[vi // 56, vj // 56, vi % 56, vj % 56]
+
+        @T.prim_func(private=True)
+        def tir_func_weight_prepack(
+            W: T.Buffer((224, 224), "float32"),
+            W_rewrite: T.Buffer((4, 4, 56, 56), "float32"),
+        ):
+            for i, j in T.grid(224, 224):
+                with T.block("W_rewrite"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    W_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] = W[vi, vj]
+
+        @R.function
+        def forward(
+            x: R.Tensor((224, 224), dtype="float32"),
+            w: R.Tensor((224, 224), dtype="float32"),
+        ) -> R.Tensor((224, 224), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            cls = After
+            with R.dataflow():
+                lv = R.call_tir(
+                    cls.tir_func_weight_prepack, (w,), out_sinfo=R.Tensor((4, 4, 56, 56), "float32")
+                )
+                lv1 = R.call_tir(
+                    cls.tir_func_prepacked, (x, lv), out_sinfo=R.Tensor((224, 224), "float32")
+                )
+                gv: R.Tensor((224, 224), dtype="float32") = lv1
+                R.output(gv)
+            return gv
+
+    mod = relax.transform.SplitLayoutRewritePreproc()(Before)
+    tvm.ir.assert_structural_equal(mod, After)
+
+
+def test_multiple_buffers():  # both preproc blocks go into one prepack func with two outputs
+    @I.ir_module
+    class Before:
+        @T.prim_func(private=True)
+        def tir_func(
+            X: T.Buffer((224, 224), "float32"),
+            W1: T.Buffer((224, 224), "float32"),
+            W2: T.Buffer((224, 224), "float32"),
+            Out: T.Buffer((224, 224), "float32"),
+        ):
+            W1_rewrite = T.alloc_buffer((4, 4, 56, 56))
+            W2_rewrite = T.alloc_buffer((4, 4, 56, 56))
+            for i, j in T.grid(224, 224):
+                with T.block("W1_rewrite"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    T.block_attr({"meta_schedule.layout_rewrite_preproc": T.bool(True)})
+                    W1_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] = W1[vi, vj]
+            for i, j in T.grid(224, 224):
+                with T.block("W2_rewrite"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    T.block_attr({"meta_schedule.layout_rewrite_preproc": T.bool(True)})
+                    W2_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] = W2[vi, vj]
+            for i0, j0, i1, j1 in T.grid(4, 4, 56, 56):
+                with T.block("Out"):
+                    vi = T.axis.spatial(224, i0 * 56 + i1)
+                    vj = T.axis.spatial(224, j0 * 56 + j1)
+                    Out[vi, vj] = (
+                        X[vi, vj]
+                        + W1_rewrite[vi // 56, vj // 56, vi % 56, vj % 56]
+                        + W2_rewrite[vi // 56, vj // 56, vi % 56, vj % 56]
+                    )
+
+        @R.function
+        def forward(
+            x: R.Tensor((224, 224), dtype="float32"),
+            w1: R.Tensor((224, 224), dtype="float32"),
+            w2: R.Tensor((224, 224), dtype="float32"),
+        ) -> R.Tensor((224, 224), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            cls = Before
+            with R.dataflow():
+                gv = R.call_tir(
+                    cls.tir_func, (x, w1, w2), out_sinfo=R.Tensor((224, 224), dtype="float32")
+                )
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class After:
+        @T.prim_func(private=True)
+        def tir_func_prepacked(
+            X: T.Buffer((224, 224), "float32"),
+            W1_rewrite: T.Buffer((4, 4, 56, 56), "float32"),
+            W2_rewrite: T.Buffer((4, 4, 56, 56), "float32"),
+            Out: T.Buffer((224, 224), "float32"),
+        ):
+            for i0, j0, i1, j1 in T.grid(4, 4, 56, 56):
+                with T.block("Out"):
+                    vi = T.axis.spatial(224, i0 * 56 + i1)
+                    vj = T.axis.spatial(224, j0 * 56 + j1)
+                    Out[vi, vj] = (
+                        X[vi, vj]
+                        + W1_rewrite[vi // 56, vj // 56, vi % 56, vj % 56]
+                        + W2_rewrite[vi // 56, vj // 56, vi % 56, vj % 56]
+                    )
+
+        @T.prim_func(private=True)
+        def tir_func_weight_prepack(
+            W1: T.Buffer((224, 224), "float32"),
+            W2: T.Buffer((224, 224), "float32"),
+            W1_rewrite: T.Buffer((4, 4, 56, 56), "float32"),
+            W2_rewrite: T.Buffer((4, 4, 56, 56), "float32"),
+        ):
+            for i, j in T.grid(224, 224):
+                with T.block("W1_rewrite"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    W1_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] = W1[vi, vj]
+            for i, j in T.grid(224, 224):
+                with T.block("W2_rewrite"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    W2_rewrite[vi // 56, vj // 56, vi % 56, vj % 56] = W2[vi, vj]
+
+        @R.function
+        def forward(
+            x: R.Tensor((224, 224), dtype="float32"),
+            w1: R.Tensor((224, 224), dtype="float32"),
+            w2: R.Tensor((224, 224), dtype="float32"),
+        ) -> R.Tensor((224, 224), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            cls = After
+            with R.dataflow():
+                lv0 = R.call_tir(
+                    cls.tir_func_weight_prepack,
+                    (w1, w2),
+                    out_sinfo=[
+                        R.Tensor((4, 4, 56, 56), "float32"),
+                        R.Tensor((4, 4, 56, 56), "float32"),
+                    ],
+                )
+                lv1 = R.call_tir(
+                    cls.tir_func_prepacked,
+                    (x, lv0[0], lv0[1]),
+                    out_sinfo=R.Tensor((224, 224), "float32"),
+                )
+                gv: R.Tensor((224, 224), dtype="float32") = lv1
+                R.output(gv)
+            return gv
+
+    mod = relax.transform.SplitLayoutRewritePreproc()(Before)
+    tvm.ir.assert_structural_equal(mod, After)
+
+
+if __name__ == "__main__":  # allow running this test file directly
+    tvm.testing.main()

Reply via email to