This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 2c4af88563 [TVMScript] Sugar Var Definition in TIR Buffer (#14223)
2c4af88563 is described below

commit 2c4af88563a2e877f75daadb90278e80eb40aeeb
Author: Junru Shao <[email protected]>
AuthorDate: Tue Mar 7 05:37:07 2023 -0800

    [TVMScript] Sugar Var Definition in TIR Buffer (#14223)
    
    This PR introduces syntactic sugar for TIR buffer declarations.
    
    Sugar 1. Strides can be defined conveniently by giving a string name.
    
    ```python
    # Previous
    stride = T.int32()
    A = T.match_buffer(..., strides=(stride,))
    
    # This PR
    A = T.match_buffer(..., strides=("s0",))
    ```
    
    Sugar 2. Multiple definitions of TIR Vars from a single buffer are now
    merged into a single line (the example below also uses Sugar 1 for strides).
    
    ```python
    # Previous
    m = T.int32()
    n = T.int32()
    stride = T.int32()
    stride_1 = T.int32()
    A = T.match_buffer(a, (m, n), strides=(stride, stride_1))
    
    # This PR
    m, n = T.int32(), T.int32()
    A = T.match_buffer(a, (m, n), strides=("s0", "s1"))
    ```
---
 python/tvm/script/ir_builder/tir/ir.py             |  16 ++-
 src/script/ir_builder/tir/ir.cc                    |   4 +-
 .../printer/doc_printer/python_doc_printer.cc      |   6 +-
 src/script/printer/tir/buffer.cc                   | 140 ++++++++++++++++-----
 src/script/printer/tir/expr.cc                     |  51 ++++----
 src/script/printer/tir/function.cc                 |   7 +-
 src/script/printer/tir/utils.h                     |   9 ++
 tests/python/unittest/test_tvmscript_roundtrip.py  |  53 ++++++++
 8 files changed, 222 insertions(+), 64 deletions(-)

diff --git a/python/tvm/script/ir_builder/tir/ir.py 
b/python/tvm/script/ir_builder/tir/ir.py
index e88597732c..7826278431 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -138,6 +138,10 @@ def buffer(
         The declared buffer.
     """
     shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
+    if strides is not None:
+        strides = [Var(s, "int32") if isinstance(s, str) else s for s in 
strides]
+    else:
+        strides = []
     return _ffi_api.Buffer(  # type: ignore[attr-defined] # pylint: 
disable=no-member
         shape,
         dtype,
@@ -304,7 +308,9 @@ def match_buffer(
         else:
             raise ValueError("Shape must be specified when binding input 
param")
     shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
-    if strides is None:
+    if strides is not None:
+        strides = [Var(s, "int32") if isinstance(s, str) else s for s in 
strides]
+    else:
         strides = []
     return _ffi_api.MatchBuffer(  # type: ignore[attr-defined] # pylint: 
disable=no-member
         param,
@@ -472,7 +478,9 @@ def alloc_buffer(
         The allocated buffer.
     """
     shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
-    if strides is None:
+    if strides is not None:
+        strides = [Var(s, "int32") if isinstance(s, str) else s for s in 
strides]
+    else:
         strides = []
     return _ffi_api.AllocBuffer(  # type: ignore[attr-defined] # pylint: 
disable=no-member
         shape,
@@ -1169,6 +1177,10 @@ def decl_buffer(
         The result DeclBufferFrame.
     """
     shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
+    if strides is not None:
+        strides = [Var(s, "int32") if isinstance(s, str) else s for s in 
strides]
+    else:
+        strides = []
     return _ffi_api.DeclBuffer(  # type: ignore[attr-defined] # pylint: 
disable=no-member
         shape,
         dtype,
diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc
index 487265bff2..8d6c51be3a 100644
--- a/src/script/ir_builder/tir/ir.cc
+++ b/src/script/ir_builder/tir/ir.cc
@@ -32,6 +32,8 @@ Buffer BufferDecl(Array<PrimExpr> shape, DataType dtype, 
String buffer_name, Opt
                   Optional<Array<PrimExpr>> strides, Optional<PrimExpr> 
elem_offset,
                   String storage_scope, int align, int offset_factor, String 
buffer_type,
                   Optional<Array<IntImm>> axis_separators) {
+  CHECK(buffer_type == "auto" || buffer_type == "default" || 
buffer_type.empty())
+      << "ValueError: `buffer_type` must be `auto` or `default` or empty";
   Var buffer_data;
   if (!data.defined()) {
     DataType storage_dtype = dtype;
@@ -48,7 +50,7 @@ Buffer BufferDecl(Array<PrimExpr> shape, DataType dtype, 
String buffer_name, Opt
   }
   return Buffer(buffer_data, dtype, shape, strides.value_or(Array<PrimExpr>()),
                 elem_offset.value_or(PrimExpr()), buffer_name, align, 
offset_factor,
-                (buffer_type == "auto_broadcast") ? tvm::tir::kAutoBroadcast : 
tvm::tir::kDefault,
+                (buffer_type == "auto" ? tvm::tir::kAutoBroadcast : 
tvm::tir::kDefault),
                 axis_separators.value_or(Array<IntImm>()));
 }
 
diff --git a/src/script/printer/doc_printer/python_doc_printer.cc 
b/src/script/printer/doc_printer/python_doc_printer.cc
index e9a3b3567e..994d048a2e 100644
--- a/src/script/printer/doc_printer/python_doc_printer.cc
+++ b/src/script/printer/doc_printer/python_doc_printer.cc
@@ -548,7 +548,11 @@ void PythonDocPrinter::PrintTypedDoc(const AssignDoc& doc) 
{
   }
   if (doc->rhs) {
     output_ << " = ";
-    PrintDoc(doc->rhs.value());
+    if (const auto* tuple_doc = doc->rhs.as<TupleDocNode>()) {
+      PrintJoinedDocs(tuple_doc->elements, ", ");
+    } else {
+      PrintDoc(doc->rhs.value());
+    }
   }
   MaybePrintCommentInline(doc);
 }
diff --git a/src/script/printer/tir/buffer.cc b/src/script/printer/tir/buffer.cc
index 19f3dc7ef5..ed8b707176 100644
--- a/src/script/printer/tir/buffer.cc
+++ b/src/script/printer/tir/buffer.cc
@@ -24,55 +24,120 @@ namespace tvm {
 namespace script {
 namespace printer {
 
-Map<String, ExprDoc> BufferAttrs(const tir::Buffer& buffer, const ObjectPath& 
p, const Frame& frame,
+Map<String, ExprDoc> BufferAttrs(tir::Buffer buffer, const ObjectPath& 
buffer_p, const Frame& frame,
                                  const IRDocsifier& d) {
+  using tvm::tir::Var;
+  using tvm::tir::VarNode;
   Map<String, ExprDoc> kwargs;
-  auto implicit_var_def = [&](const PrimExpr& e, const ObjectPath& p, const 
String& key) {
-    if (Optional<ExprDoc> doc = d->GetVarDoc(e)) {
-      kwargs.Set(key, doc.value());
-      return false;
-    }
-    if (e->IsInstance<tir::VarNode>()) {
-      d->Define(e, frame, [=]() { return d->AsDoc<IdDoc>(buffer, 
p)->Attr(key); });
+  Array<ExprDoc> var_def_lhs;
+  Array<ExprDoc> var_def_rhs;
+
+  // Step 0. Set up statistics
+  std::unordered_map<const Object*, int> use_count;
+  auto update_use_count = [&](const PrimExpr& e) {
+    tir::PostOrderVisit(e, [&](const ObjectRef& n) {
+      if (const VarNode* var = n.as<VarNode>()) {
+        ++use_count[var];
+      }
+    });
+  };
+  update_use_count(buffer->elem_offset);
+  update_use_count(buffer->data);
+  for (const PrimExpr& e : buffer->strides) {
+    update_use_count(e);
+  }
+  for (const PrimExpr& e : buffer->shape) {
+    update_use_count(e);
+  }
+  auto is_new_var = [&](const PrimExpr& e) {
+    return e->IsInstance<VarNode>() && !d->IsVarDefined(e);
+  };
+  auto add_out_of_line_var_def = [&](const Var& var, const ObjectPath& var_p) {
+    ICHECK(!d->IsVarDefined(var));
+    ExprDoc lhs = DefineVar(var, frame, d);
+    lhs->source_paths.push_back(var_p);
+    var_def_lhs.push_back(lhs);
+    var_def_rhs.push_back(PrintVarCreation(var, var_p, d));
+  };
+  auto try_inline_def = [&](const PrimExpr& e, const ObjectPath& e_p,
+                            std::function<ExprDoc()> inline_f) {
+    ICHECK(is_new_var(e));
+    Var var = Downcast<Var>(e);
+    if (use_count[var.get()] == 1) {
+      d->Define(e, frame, inline_f);
       return true;
+    } else {
+      add_out_of_line_var_def(var, e_p);
+      return false;
     }
-    kwargs.Set(key, d->AsDoc<ExprDoc>(e, p));
-    return false;
   };
-  auto array_out_line_var_def = [&](const Array<PrimExpr>& array, const 
ObjectPath& p,
-                                    const String& key) {
-    int n = array.size();
+  // Step 1. Handle `buffer.shape`
+  {
+    const Array<PrimExpr>& shape = buffer->shape;
+    ObjectPath shape_p = buffer_p->Attr("shape");
+    int n = shape.size();
     Array<ExprDoc> results;
     results.reserve(n);
     for (int i = 0; i < n; ++i) {
-      PrimExpr s = array[i];
-      ObjectPath s_path = p->ArrayIndex(i);
-      // Add out-of-line definition for a new Var in shape
-      results.push_back(d->AsDoc<ExprDoc>(s, s_path));
+      PrimExpr e = shape[i];
+      ObjectPath e_p = shape_p->ArrayIndex(i);
+      if (is_new_var(e)) {
+        add_out_of_line_var_def(Downcast<Var>(e), e_p);
+      }
+      results.push_back(d->AsDoc<ExprDoc>(e, e_p));
     }
-    kwargs.Set(key, TupleDoc(results));
-  };
-  // Step 1. Handle `buffer.shape`
-  array_out_line_var_def(buffer->shape, p->Attr("shape"), "shape");
+    kwargs.Set("shape", TupleDoc(results));
+  }
   // Step 2. Handle `buffer.dtype`
   if (buffer->dtype != d->cfg->buffer_dtype) {
-    kwargs.Set("dtype", LiteralDoc::DataType(buffer->dtype, p->Attr("dtype")));
+    kwargs.Set("dtype", LiteralDoc::DataType(buffer->dtype, 
buffer_p->Attr("dtype")));
   }
   // Step 3. Handle `buffer.data`
-  implicit_var_def(buffer->data, p->Attr("data"), "data");
+  if (!is_new_var(buffer->data)) {
+    kwargs.Set("data", d->AsDoc<ExprDoc>(buffer->data, 
buffer_p->Attr("data")));
+  } else {
+    try_inline_def(buffer->data, buffer_p->Attr("data"),
+                   [=]() { return d->AsDoc<ExprDoc>(buffer, 
buffer_p)->Attr("data"); });
+  }
   // Step 4. Handle `buffer.strides`
   if (!buffer->strides.empty()) {
-    array_out_line_var_def(buffer->strides, p->Attr("strides"), "strides");
+    const Array<PrimExpr>& strides = buffer->strides;
+    ObjectPath strides_p = buffer_p->Attr("strides");
+    int n = strides.size();
+    Array<ExprDoc> results;
+    results.reserve(n);
+    for (int i = 0; i < n; ++i) {
+      PrimExpr e = strides[i];
+      ObjectPath e_p = strides_p->ArrayIndex(i);
+      if (is_new_var(e)) {
+        if (try_inline_def(e, e_p, [=]() {
+              return d->AsDoc<ExprDoc>(buffer, buffer_p)
+                  ->Attr("strides")[{LiteralDoc::Int(i, NullOpt)}];
+            })) {
+          results.push_back(LiteralDoc::Str(Downcast<Var>(e)->name_hint, e_p));
+          continue;
+        }
+      }
+      results.push_back(d->AsDoc<ExprDoc>(e, e_p));
+    }
+    kwargs.Set("strides", TupleDoc(results));
   }
   // Step 5. Handle `buffer.elem_offset`
   bool needs_print_factor = false;
   if (const auto* int_imm = buffer->elem_offset.as<IntImmNode>()) {
     if (int_imm->value != 0) {
-      kwargs.Set("elem_offset", d->AsDoc<ExprDoc>(buffer->elem_offset, 
p->Attr("elem_offset")));
+      kwargs.Set("elem_offset",
+                 d->AsDoc<ExprDoc>(buffer->elem_offset,  //
+                                   buffer_p->Attr("elem_offset")));
     }
+  } else if (is_new_var(buffer->elem_offset)) {
+    try_inline_def(buffer->elem_offset, buffer_p->Attr("elem_offset"),
+                   [=]() { return d->AsDoc<ExprDoc>(buffer, 
buffer_p)->Attr("elem_offset"); });
+    needs_print_factor = true;
   } else {
-    needs_print_factor =
-        implicit_var_def(buffer->elem_offset, p->Attr("elem_offset"), 
"elem_offset");
+    kwargs.Set("elem_offset",
+               d->AsDoc<ExprDoc>(buffer->elem_offset,  //
+                                 buffer_p->Attr("elem_offset")));
   }
   // Step 6. Handle `buffer.scope`
   {
@@ -80,25 +145,32 @@ Map<String, ExprDoc> BufferAttrs(const tir::Buffer& 
buffer, const ObjectPath& p,
     if (scope != "global") {
       kwargs.Set(
           "scope",
-          LiteralDoc::Str(scope, 
p->Attr("data")->Attr("type_annotation")->Attr("storage_scope")));
+          LiteralDoc::Str(scope,
+                          
buffer_p->Attr("data")->Attr("type_annotation")->Attr("storage_scope")));
     }
   }
   // Step 7. Handle `buffer.data_alignment`
   if (buffer->data_alignment != runtime::kAllocAlignment) {
-    kwargs.Set("align", LiteralDoc::Int(buffer->data_alignment, 
p->Attr("data_alignment")));
+    kwargs.Set("align", LiteralDoc::Int(buffer->data_alignment, 
buffer_p->Attr("data_alignment")));
   }
   // Step 8. Handle `buffer.offset_factor`
   if (needs_print_factor || buffer->offset_factor != 1) {
-    kwargs.Set("offset_factor", LiteralDoc::Int(buffer->offset_factor, 
p->Attr("offset_factor")));
+    kwargs.Set("offset_factor",
+               LiteralDoc::Int(buffer->offset_factor, 
buffer_p->Attr("offset_factor")));
   }
   // Step 9. Handle `buffer.buffer_type`
   if (buffer->buffer_type != tir::BufferType::kDefault) {
-    kwargs.Set("type", LiteralDoc::Str("auto", p->Attr("buffer_type")));
+    kwargs.Set("buffer_type", LiteralDoc::Str("auto", 
buffer_p->Attr("buffer_type")));
   }
   // Step 10. Handle `buffer.axis_separator`
   if (!buffer->axis_separators.empty()) {
     kwargs.Set("axis_separators",
-               d->AsDoc<ExprDoc>(buffer->axis_separators, 
p->Attr("axis_separators")));
+               d->AsDoc<ExprDoc>(buffer->axis_separators, 
buffer_p->Attr("axis_separators")));
+  }
+  if (var_def_lhs.size() == 1) {
+    frame->stmts.push_back(AssignDoc(var_def_lhs[0], var_def_rhs[0], NullOpt));
+  } else if (var_def_lhs.size() > 1) {
+    frame->stmts.push_back(AssignDoc(TupleDoc(var_def_lhs), 
TupleDoc(var_def_rhs), NullOpt));
   }
   return kwargs;
 }
@@ -111,8 +183,8 @@ ExprDoc BufferCall(const ExprDoc& prefix, const Map<String, 
ExprDoc>& attrs, Arr
       args.push_back(doc.value());
     }
   }
-  for (String s : {"data", "strides", "elem_offset", "scope", "align", 
"offset_factor", "type",
-                   "axis_separators"}) {
+  for (String s : {"data", "strides", "elem_offset", "scope", "align", 
"offset_factor",
+                   "buffer_type", "axis_separators"}) {
     if (Optional<ExprDoc> doc = attrs.Get(s)) {
       kwargs_keys.push_back(s);
       kwargs_values.push_back(doc.value());
diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tir/expr.cc
index 02ec269b0e..655f69c32d 100644
--- a/src/script/printer/tir/expr.cc
+++ b/src/script/printer/tir/expr.cc
@@ -24,33 +24,38 @@ namespace tvm {
 namespace script {
 namespace printer {
 
+ExprDoc PrintVarCreation(const tir::Var& var, const ObjectPath& var_p, const 
IRDocsifier& d) {
+  Type type = var->type_annotation;
+  ObjectPath type_p = var_p->Attr("type_annotation");
+  ExprDoc rhs{nullptr};
+  if (const auto* ptr_type = type.as<PointerTypeNode>()) {
+    const auto* prim_type = ptr_type->element_type.as<PrimTypeNode>();
+    ICHECK(prim_type);
+    ExprDoc element_type =
+        LiteralDoc::DataType(prim_type->dtype, 
type_p->Attr("element_type")->Attr("dtype"));
+    rhs = TIR(d, "handle");
+    rhs->source_paths.push_back(var_p->Attr("dtype"));
+    if (ptr_type->storage_scope == "") {
+      rhs = rhs->Call({element_type});
+    } else {
+      rhs = rhs->Call({element_type,
+                       LiteralDoc::Str(ptr_type->storage_scope,  //
+                                       type_p->Attr("storage_scope"))});
+    }
+  } else {
+    rhs = TIR(d, DType2Str(var->dtype));
+    rhs->source_paths.push_back(var_p->Attr("dtype"));
+    rhs = rhs->Call({});
+  }
+  rhs->source_paths.push_back(type_p);
+  return rhs;
+}
+
 Doc PrintVar(const tir::Var& var, const ObjectPath& var_p, const IRDocsifier& 
d) {
   if (!d->IsVarDefined(var)) {
     if (Optional<Frame> opt_f = FindLowestVarDef(var, d)) {
       ExprDoc lhs = DefineVar(var, opt_f.value(), d);
-      Type type = var->type_annotation;
-      ObjectPath type_p = var_p->Attr("type_annotation");
-      ExprDoc rhs{nullptr};
-      if (const auto* ptr_type = type.as<PointerTypeNode>()) {
-        const auto* prim_type = ptr_type->element_type.as<PrimTypeNode>();
-        ICHECK(prim_type);
-        ExprDoc element_type =
-            LiteralDoc::DataType(prim_type->dtype, 
type_p->Attr("element_type")->Attr("dtype"));
-        rhs = TIR(d, "handle");
-        rhs->source_paths.push_back(var_p->Attr("dtype"));
-        if (ptr_type->storage_scope == "") {
-          rhs = rhs->Call({element_type});
-        } else {
-          rhs = rhs->Call({element_type,
-                           LiteralDoc::Str(ptr_type->storage_scope,  //
-                                           type_p->Attr("storage_scope"))});
-        }
-      } else {
-        rhs = TIR(d, DType2Str(var->dtype));
-        rhs->source_paths.push_back(var_p->Attr("dtype"));
-        rhs = rhs->Call({});
-      }
-      rhs->source_paths.push_back(type_p);
+      ExprDoc rhs = PrintVarCreation(var, var_p, d);
       opt_f.value()->stmts.push_back(AssignDoc(lhs, rhs, NullOpt));
     } else {
       LOG(WARNING) << "Didn't find variable definition for: " << 
var->name_hint;
diff --git a/src/script/printer/tir/function.cc 
b/src/script/printer/tir/function.cc
index 6a4df34a3a..f40d7818d7 100644
--- a/src/script/printer/tir/function.cc
+++ b/src/script/printer/tir/function.cc
@@ -92,8 +92,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
           tir::Buffer buffer = func->buffer_map[var];
           if (IsSimpleBuffer(buffer) && 
buffer_data_counter.at(buffer->data.get()) == 1) {
             ObjectPath buffer_p = p->Attr("buffer_map")->MapValue(var);
-            args.push_back(AssignDoc(DefineBuffer(buffer, *f, d), NullOpt,
-                                     BufferAttn(buffer, buffer_p, *f, d)));
+            IdDoc lhs = DefineBuffer(buffer, *f, d);
+            ExprDoc annotation = BufferAttn(buffer, buffer_p, *f, d);
+            args.push_back(AssignDoc(lhs, NullOpt, annotation));
             buffer_inlined.insert(buffer.get());
             continue;
           }
@@ -117,7 +118,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
           }
           ExprDoc param_doc = args[i]->lhs;
           ObjectPath buffer_p = p->Attr("buffer_map")->MapValue(param);
-          ExprDoc lhs = DefineBuffer(buffer, *f, d);  // TODO(@junrushao): 
switch `lhs` and `rhs`
+          ExprDoc lhs = DefineBuffer(buffer, *f, d);
           ExprDoc rhs = BufferDecl(buffer, "match_buffer", {param_doc}, 
buffer_p, *f, d);
           (*f)->stmts.push_back(AssignDoc(lhs, rhs, NullOpt));
         }
diff --git a/src/script/printer/tir/utils.h b/src/script/printer/tir/utils.h
index 08eb12bfa7..cee5fbd0f0 100644
--- a/src/script/printer/tir/utils.h
+++ b/src/script/printer/tir/utils.h
@@ -201,6 +201,15 @@ ExprDoc BufferDecl(const tir::Buffer& buffer, const 
String& method, const Array<
 ExprDoc BufferAttn(const tir::Buffer& buffer, const ObjectPath& p, const 
Frame& frame,
                    const IRDocsifier& d);
 
+/*!
+ * \brief Print the creation of a Var
+ * \param var The Var to be printed
+ * \param var_p The object path of the Var
+ * \param d The IRDocsifier
+ * \return The ExprDoc corresponding to the Var creation
+ */
+ExprDoc PrintVarCreation(const tir::Var& var, const ObjectPath& var_p, const 
IRDocsifier& d);
+
 /*! \brief A Var occurrence counter visitor */
 class OccurrenceCounter : public tir::StmtExprVisitor {
  public:
diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py 
b/tests/python/unittest/test_tvmscript_roundtrip.py
index 464e8fd342..c956f3bb02 100644
--- a/tests/python/unittest/test_tvmscript_roundtrip.py
+++ b/tests/python/unittest/test_tvmscript_roundtrip.py
@@ -3572,6 +3572,57 @@ def let_stmt_value():
     return func
 
 
+def string_stride():
+    @T.prim_func
+    def main(a: T.handle, b: T.handle):
+        T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", 
"tir.noalias": True})
+        n = T.int32()
+        A = T.match_buffer(a, (n,), strides=("A_s0",), buffer_type="auto")
+        B = T.match_buffer(b, (n,), strides=("B_s0",), buffer_type="auto")
+        blockIdx_x = T.launch_thread("blockIdx.x", (n + 63) // 64)
+        threadIdx_x = T.launch_thread("threadIdx.x", 64)
+        if T.likely(blockIdx_x * 64 + threadIdx_x < n):
+            B2 = T.Buffer((B.strides[0] * n,), data=B.data)
+            A2 = T.Buffer((A.strides[0] * n,), data=A.data)
+            B2[(blockIdx_x * 64 + threadIdx_x) * B.strides[0]] = A2[
+                (blockIdx_x * 64 + threadIdx_x) * A.strides[0]
+            ] * T.float32(2)
+
+    return main
+
+
+def merge_shape_var_def():
+    @T.prim_func
+    def main(A: T.handle, B: T.handle):
+        T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", 
"tir.noalias": True})
+        m, n = T.int32(), T.int32()
+        A_1 = T.match_buffer(A, (m, n), strides=("A_1_s0", "A_1_s1"), 
buffer_type="auto")
+        B_1 = T.match_buffer(B, (m, n), strides=("B_1_s0", "B_1_s1"), 
buffer_type="auto")
+        for i_outer, j_outer, i_inner in T.grid((m + 9) // 10, (n + 4) // 5, 
10):
+            if T.likely(i_outer * 10 + i_inner < m):
+                for j_inner in range(5):
+                    if T.likely(j_outer * 5 + j_inner < n):
+                        cse_var_2: T.int32 = j_outer * 5 + j_inner
+                        cse_var_1: T.int32 = i_outer * 10 + i_inner
+                        B_2 = T.Buffer(
+                            (B_1.strides[0] * m,),
+                            data=B_1.data,
+                            strides=("B_2_s0",),
+                            buffer_type="auto",
+                        )
+                        A_2 = T.Buffer(
+                            (A_1.strides[0] * m,),
+                            data=A_1.data,
+                            strides=("A_2_s0",),
+                            buffer_type="auto",
+                        )
+                        B_2[cse_var_1 * B_1.strides[0] + cse_var_2 * 
B_1.strides[1]] = A_2[
+                            cse_var_1 * A_1.strides[0] + cse_var_2 * 
A_1.strides[1]
+                        ]
+
+    return main
+
+
 ir_generator = tvm.testing.parameter(
     launch_env_thread,
     opt_gemm_normalize,
@@ -3633,6 +3684,8 @@ ir_generator = tvm.testing.parameter(
     intrinsic_pow,
     let_stmt_var,
     let_stmt_value,
+    string_stride,
+    merge_shape_var_def,
 )
 
 

Reply via email to