masahi commented on code in PR #12489:
URL: https://github.com/apache/tvm/pull/12489#discussion_r949074983


##########
src/tir/transforms/bind_params.cc:
##########
@@ -84,44 +84,57 @@ class ParamsCollector : public StmtExprVisitor {
   Map<tir::Var, runtime::NDArray> constant_map_;
 };
 
-namespace transform {
+PrimFunc BindParams(PrimFunc f, const Array<runtime::NDArray>& constants) {
+  Map<tir::Var, runtime::NDArray> constant_map;
 
-Pass BindParams(const Array<runtime::NDArray>& constants) {
-  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
-    Map<tir::Var, runtime::NDArray> constant_map;
-
-    // Remove constants from the primfunc signature
-    size_t num_constants = constants.size();
-    size_t start = f->params.size() - num_constants;
-    Array<tir::Var> params;
-    for (unsigned i = 0; i < start; i++) {
-      params.push_back(f->params[i]);
-    }
+  // Remove constants from the primfunc signature
+  size_t num_constants = constants.size();
+  size_t start = f->params.size() - num_constants;
+  Array<tir::Var> params;
+  for (unsigned i = 0; i < start; i++) {
+    params.push_back(f->params[i]);
+  }
 
-    auto* n = f.CopyOnWrite();
-    for (unsigned i = start; i < f->params.size(); i++) {
-      tir::Var p = n->params[i];
-      tir::Var b = n->buffer_map[p]->data;
-      n->buffer_map.erase(p);
-      constant_map.Set(b, constants[i - start]);
+  auto* n = f.CopyOnWrite();
+  for (unsigned i = start; i < f->params.size(); i++) {
+    tir::Var p = n->params[i];
+    tir::Var b = n->buffer_map[p]->data;
+    n->buffer_map.erase(p);
+    constant_map.Set(b, constants[i - start]);
+  }
+  n->params = params;
+  auto constant_list = ParamsCollector(constant_map).CollectParams(n->body);
+
+  // Allocate constants within the primfunc
+  for (auto i : constant_list) {
+    auto var = GetRef<Var>(i);
+    int ndim = constant_map[var]->ndim;
+    Array<PrimExpr> extents;
+
+    for (int i = 0; i < ndim; i++) {
+      int shape = constant_map[var]->shape[i];
+      extents.push_back(make_const(DataType::Int(32), shape));
     }
-    n->params = params;
-    auto constant_list = ParamsCollector(constant_map).CollectParams(n->body);
-
-    // Allocate constants within the primfunc
-    for (auto i : constant_list) {
-      auto var = GetRef<Var>(i);
-      int ndim = constant_map[var]->ndim;
-      Array<PrimExpr> extents;
-
-      for (int i = 0; i < ndim; i++) {
-        int shape = constant_map[var]->shape[i];
-        extents.push_back(make_const(DataType::Int(32), shape));
-      }
-      DataType dtype = DataType(constant_map[var]->dtype);
+    DataType dtype = DataType(constant_map[var]->dtype);
+
+    if (n->body->IsInstance<BlockRealizeNode>()) {
+      auto* block_realize = n->body.as<BlockRealizeNode>();
+      auto block = block_realize->block;
+      block.CopyOnWrite()->body =
+          tir::AllocateConst(var, dtype, extents, constant_map[var], 
block->body);
+      n->body = BlockRealize(block_realize->iter_values, 
block_realize->predicate, block);

Review Comment:
   Please note this change. This places `AllocateConst` at the beginning of the 
body of `BlockRealize`. I found that making `BlockRealize` the body of 
`AllocateConst` leads to many issues, since many places in the TIR code assume 
that the body of a primfunc starts with a `BlockRealize`.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to