This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 2c4af88563 [TVMScript] Sugar Var Definition in TIR Buffer (#14223)
2c4af88563 is described below
commit 2c4af88563a2e877f75daadb90278e80eb40aeeb
Author: Junru Shao <[email protected]>
AuthorDate: Tue Mar 7 05:37:07 2023 -0800
[TVMScript] Sugar Var Definition in TIR Buffer (#14223)
This PR introduces sugars for TIR buffer declaration.
Sugar 1. Convenient stride definition by string.
```python
# Previous
stride = T.int32()
A = T.match_buffer(..., strides=(stride,))
# This PR
A = T.match_buffer(..., strides=("s0", ))
```
Sugar 2. Multiple definitions of TIR Vars from a single buffer are now
merged into a single line.
```python
# Previous
m = T.int32()
n = T.int32()
stride = T.int32()
stride_1 = T.int32()
A = T.match_buffer(a, (m, n), strides=(stride, stride_1))
# This PR
m, n = T.int32(), T.int32()
A = T.match_buffer(a, (m, n), strides=("s0", "s1"))
```
---
python/tvm/script/ir_builder/tir/ir.py | 16 ++-
src/script/ir_builder/tir/ir.cc | 4 +-
.../printer/doc_printer/python_doc_printer.cc | 6 +-
src/script/printer/tir/buffer.cc | 140 ++++++++++++++++-----
src/script/printer/tir/expr.cc | 51 ++++----
src/script/printer/tir/function.cc | 7 +-
src/script/printer/tir/utils.h | 9 ++
tests/python/unittest/test_tvmscript_roundtrip.py | 53 ++++++++
8 files changed, 222 insertions(+), 64 deletions(-)
diff --git a/python/tvm/script/ir_builder/tir/ir.py
b/python/tvm/script/ir_builder/tir/ir.py
index e88597732c..7826278431 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -138,6 +138,10 @@ def buffer(
The declared buffer.
"""
shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
+ if strides is not None:
+ strides = [Var(s, "int32") if isinstance(s, str) else s for s in
strides]
+ else:
+ strides = []
return _ffi_api.Buffer( # type: ignore[attr-defined] # pylint:
disable=no-member
shape,
dtype,
@@ -304,7 +308,9 @@ def match_buffer(
else:
raise ValueError("Shape must be specified when binding input
param")
shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
- if strides is None:
+ if strides is not None:
+ strides = [Var(s, "int32") if isinstance(s, str) else s for s in
strides]
+ else:
strides = []
return _ffi_api.MatchBuffer( # type: ignore[attr-defined] # pylint:
disable=no-member
param,
@@ -472,7 +478,9 @@ def alloc_buffer(
The allocated buffer.
"""
shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
- if strides is None:
+ if strides is not None:
+ strides = [Var(s, "int32") if isinstance(s, str) else s for s in
strides]
+ else:
strides = []
return _ffi_api.AllocBuffer( # type: ignore[attr-defined] # pylint:
disable=no-member
shape,
@@ -1169,6 +1177,10 @@ def decl_buffer(
The result DeclBufferFrame.
"""
shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
+ if strides is not None:
+ strides = [Var(s, "int32") if isinstance(s, str) else s for s in
strides]
+ else:
+ strides = []
return _ffi_api.DeclBuffer( # type: ignore[attr-defined] # pylint:
disable=no-member
shape,
dtype,
diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc
index 487265bff2..8d6c51be3a 100644
--- a/src/script/ir_builder/tir/ir.cc
+++ b/src/script/ir_builder/tir/ir.cc
@@ -32,6 +32,8 @@ Buffer BufferDecl(Array<PrimExpr> shape, DataType dtype,
String buffer_name, Opt
Optional<Array<PrimExpr>> strides, Optional<PrimExpr>
elem_offset,
String storage_scope, int align, int offset_factor, String
buffer_type,
Optional<Array<IntImm>> axis_separators) {
+ CHECK(buffer_type == "auto" || buffer_type == "default" ||
buffer_type.empty())
+ << "ValueError: `buffer_type` must be `auto` or `default` or empty";
Var buffer_data;
if (!data.defined()) {
DataType storage_dtype = dtype;
@@ -48,7 +50,7 @@ Buffer BufferDecl(Array<PrimExpr> shape, DataType dtype,
String buffer_name, Opt
}
return Buffer(buffer_data, dtype, shape, strides.value_or(Array<PrimExpr>()),
elem_offset.value_or(PrimExpr()), buffer_name, align,
offset_factor,
- (buffer_type == "auto_broadcast") ? tvm::tir::kAutoBroadcast :
tvm::tir::kDefault,
+ (buffer_type == "auto" ? tvm::tir::kAutoBroadcast :
tvm::tir::kDefault),
axis_separators.value_or(Array<IntImm>()));
}
diff --git a/src/script/printer/doc_printer/python_doc_printer.cc
b/src/script/printer/doc_printer/python_doc_printer.cc
index e9a3b3567e..994d048a2e 100644
--- a/src/script/printer/doc_printer/python_doc_printer.cc
+++ b/src/script/printer/doc_printer/python_doc_printer.cc
@@ -548,7 +548,11 @@ void PythonDocPrinter::PrintTypedDoc(const AssignDoc& doc)
{
}
if (doc->rhs) {
output_ << " = ";
- PrintDoc(doc->rhs.value());
+ if (const auto* tuple_doc = doc->rhs.as<TupleDocNode>()) {
+ PrintJoinedDocs(tuple_doc->elements, ", ");
+ } else {
+ PrintDoc(doc->rhs.value());
+ }
}
MaybePrintCommentInline(doc);
}
diff --git a/src/script/printer/tir/buffer.cc b/src/script/printer/tir/buffer.cc
index 19f3dc7ef5..ed8b707176 100644
--- a/src/script/printer/tir/buffer.cc
+++ b/src/script/printer/tir/buffer.cc
@@ -24,55 +24,120 @@ namespace tvm {
namespace script {
namespace printer {
-Map<String, ExprDoc> BufferAttrs(const tir::Buffer& buffer, const ObjectPath&
p, const Frame& frame,
+Map<String, ExprDoc> BufferAttrs(tir::Buffer buffer, const ObjectPath&
buffer_p, const Frame& frame,
const IRDocsifier& d) {
+ using tvm::tir::Var;
+ using tvm::tir::VarNode;
Map<String, ExprDoc> kwargs;
- auto implicit_var_def = [&](const PrimExpr& e, const ObjectPath& p, const
String& key) {
- if (Optional<ExprDoc> doc = d->GetVarDoc(e)) {
- kwargs.Set(key, doc.value());
- return false;
- }
- if (e->IsInstance<tir::VarNode>()) {
- d->Define(e, frame, [=]() { return d->AsDoc<IdDoc>(buffer,
p)->Attr(key); });
+ Array<ExprDoc> var_def_lhs;
+ Array<ExprDoc> var_def_rhs;
+
+ // Step 0. Set up statistics
+ std::unordered_map<const Object*, int> use_count;
+ auto update_use_count = [&](const PrimExpr& e) {
+ tir::PostOrderVisit(e, [&](const ObjectRef& n) {
+ if (const VarNode* var = n.as<VarNode>()) {
+ ++use_count[var];
+ }
+ });
+ };
+ update_use_count(buffer->elem_offset);
+ update_use_count(buffer->data);
+ for (const PrimExpr& e : buffer->strides) {
+ update_use_count(e);
+ }
+ for (const PrimExpr& e : buffer->shape) {
+ update_use_count(e);
+ }
+ auto is_new_var = [&](const PrimExpr& e) {
+ return e->IsInstance<VarNode>() && !d->IsVarDefined(e);
+ };
+ auto add_out_of_line_var_def = [&](const Var& var, const ObjectPath& var_p) {
+ ICHECK(!d->IsVarDefined(var));
+ ExprDoc lhs = DefineVar(var, frame, d);
+ lhs->source_paths.push_back(var_p);
+ var_def_lhs.push_back(lhs);
+ var_def_rhs.push_back(PrintVarCreation(var, var_p, d));
+ };
+ auto try_inline_def = [&](const PrimExpr& e, const ObjectPath& e_p,
+ std::function<ExprDoc()> inline_f) {
+ ICHECK(is_new_var(e));
+ Var var = Downcast<Var>(e);
+ if (use_count[var.get()] == 1) {
+ d->Define(e, frame, inline_f);
return true;
+ } else {
+ add_out_of_line_var_def(var, e_p);
+ return false;
}
- kwargs.Set(key, d->AsDoc<ExprDoc>(e, p));
- return false;
};
- auto array_out_line_var_def = [&](const Array<PrimExpr>& array, const
ObjectPath& p,
- const String& key) {
- int n = array.size();
+ // Step 1. Handle `buffer.shape`
+ {
+ const Array<PrimExpr>& shape = buffer->shape;
+ ObjectPath shape_p = buffer_p->Attr("shape");
+ int n = shape.size();
Array<ExprDoc> results;
results.reserve(n);
for (int i = 0; i < n; ++i) {
- PrimExpr s = array[i];
- ObjectPath s_path = p->ArrayIndex(i);
- // Add out-of-line definition for a new Var in shape
- results.push_back(d->AsDoc<ExprDoc>(s, s_path));
+ PrimExpr e = shape[i];
+ ObjectPath e_p = shape_p->ArrayIndex(i);
+ if (is_new_var(e)) {
+ add_out_of_line_var_def(Downcast<Var>(e), e_p);
+ }
+ results.push_back(d->AsDoc<ExprDoc>(e, e_p));
}
- kwargs.Set(key, TupleDoc(results));
- };
- // Step 1. Handle `buffer.shape`
- array_out_line_var_def(buffer->shape, p->Attr("shape"), "shape");
+ kwargs.Set("shape", TupleDoc(results));
+ }
// Step 2. Handle `buffer.dtype`
if (buffer->dtype != d->cfg->buffer_dtype) {
- kwargs.Set("dtype", LiteralDoc::DataType(buffer->dtype, p->Attr("dtype")));
+ kwargs.Set("dtype", LiteralDoc::DataType(buffer->dtype,
buffer_p->Attr("dtype")));
}
// Step 3. Handle `buffer.data`
- implicit_var_def(buffer->data, p->Attr("data"), "data");
+ if (!is_new_var(buffer->data)) {
+ kwargs.Set("data", d->AsDoc<ExprDoc>(buffer->data,
buffer_p->Attr("data")));
+ } else {
+ try_inline_def(buffer->data, buffer_p->Attr("data"),
+ [=]() { return d->AsDoc<ExprDoc>(buffer,
buffer_p)->Attr("data"); });
+ }
// Step 4. Handle `buffer.strides`
if (!buffer->strides.empty()) {
- array_out_line_var_def(buffer->strides, p->Attr("strides"), "strides");
+ const Array<PrimExpr>& strides = buffer->strides;
+ ObjectPath strides_p = buffer_p->Attr("strides");
+ int n = strides.size();
+ Array<ExprDoc> results;
+ results.reserve(n);
+ for (int i = 0; i < n; ++i) {
+ PrimExpr e = strides[i];
+ ObjectPath e_p = strides_p->ArrayIndex(i);
+ if (is_new_var(e)) {
+ if (try_inline_def(e, e_p, [=]() {
+ return d->AsDoc<ExprDoc>(buffer, buffer_p)
+ ->Attr("strides")[{LiteralDoc::Int(i, NullOpt)}];
+ })) {
+ results.push_back(LiteralDoc::Str(Downcast<Var>(e)->name_hint, e_p));
+ continue;
+ }
+ }
+ results.push_back(d->AsDoc<ExprDoc>(e, e_p));
+ }
+ kwargs.Set("strides", TupleDoc(results));
}
// Step 5. Handle `buffer.elem_offset`
bool needs_print_factor = false;
if (const auto* int_imm = buffer->elem_offset.as<IntImmNode>()) {
if (int_imm->value != 0) {
- kwargs.Set("elem_offset", d->AsDoc<ExprDoc>(buffer->elem_offset,
p->Attr("elem_offset")));
+ kwargs.Set("elem_offset",
+ d->AsDoc<ExprDoc>(buffer->elem_offset, //
+ buffer_p->Attr("elem_offset")));
}
+ } else if (is_new_var(buffer->elem_offset)) {
+ try_inline_def(buffer->elem_offset, buffer_p->Attr("elem_offset"),
+ [=]() { return d->AsDoc<ExprDoc>(buffer,
buffer_p)->Attr("elem_offset"); });
+ needs_print_factor = true;
} else {
- needs_print_factor =
- implicit_var_def(buffer->elem_offset, p->Attr("elem_offset"),
"elem_offset");
+ kwargs.Set("elem_offset",
+ d->AsDoc<ExprDoc>(buffer->elem_offset, //
+ buffer_p->Attr("elem_offset")));
}
// Step 6. Handle `buffer.scope`
{
@@ -80,25 +145,32 @@ Map<String, ExprDoc> BufferAttrs(const tir::Buffer&
buffer, const ObjectPath& p,
if (scope != "global") {
kwargs.Set(
"scope",
- LiteralDoc::Str(scope,
p->Attr("data")->Attr("type_annotation")->Attr("storage_scope")));
+ LiteralDoc::Str(scope,
+
buffer_p->Attr("data")->Attr("type_annotation")->Attr("storage_scope")));
}
}
// Step 7. Handle `buffer.data_alignment`
if (buffer->data_alignment != runtime::kAllocAlignment) {
- kwargs.Set("align", LiteralDoc::Int(buffer->data_alignment,
p->Attr("data_alignment")));
+ kwargs.Set("align", LiteralDoc::Int(buffer->data_alignment,
buffer_p->Attr("data_alignment")));
}
// Step 8. Handle `buffer.offset_factor`
if (needs_print_factor || buffer->offset_factor != 1) {
- kwargs.Set("offset_factor", LiteralDoc::Int(buffer->offset_factor,
p->Attr("offset_factor")));
+ kwargs.Set("offset_factor",
+ LiteralDoc::Int(buffer->offset_factor,
buffer_p->Attr("offset_factor")));
}
// Step 9. Handle `buffer.buffer_type`
if (buffer->buffer_type != tir::BufferType::kDefault) {
- kwargs.Set("type", LiteralDoc::Str("auto", p->Attr("buffer_type")));
+ kwargs.Set("buffer_type", LiteralDoc::Str("auto",
buffer_p->Attr("buffer_type")));
}
// Step 10. Handle `buffer.axis_separator`
if (!buffer->axis_separators.empty()) {
kwargs.Set("axis_separators",
- d->AsDoc<ExprDoc>(buffer->axis_separators,
p->Attr("axis_separators")));
+ d->AsDoc<ExprDoc>(buffer->axis_separators,
buffer_p->Attr("axis_separators")));
+ }
+ if (var_def_lhs.size() == 1) {
+ frame->stmts.push_back(AssignDoc(var_def_lhs[0], var_def_rhs[0], NullOpt));
+ } else if (var_def_lhs.size() > 1) {
+ frame->stmts.push_back(AssignDoc(TupleDoc(var_def_lhs),
TupleDoc(var_def_rhs), NullOpt));
}
return kwargs;
}
@@ -111,8 +183,8 @@ ExprDoc BufferCall(const ExprDoc& prefix, const Map<String,
ExprDoc>& attrs, Arr
args.push_back(doc.value());
}
}
- for (String s : {"data", "strides", "elem_offset", "scope", "align",
"offset_factor", "type",
- "axis_separators"}) {
+ for (String s : {"data", "strides", "elem_offset", "scope", "align",
"offset_factor",
+ "buffer_type", "axis_separators"}) {
if (Optional<ExprDoc> doc = attrs.Get(s)) {
kwargs_keys.push_back(s);
kwargs_values.push_back(doc.value());
diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tir/expr.cc
index 02ec269b0e..655f69c32d 100644
--- a/src/script/printer/tir/expr.cc
+++ b/src/script/printer/tir/expr.cc
@@ -24,33 +24,38 @@ namespace tvm {
namespace script {
namespace printer {
+ExprDoc PrintVarCreation(const tir::Var& var, const ObjectPath& var_p, const
IRDocsifier& d) {
+ Type type = var->type_annotation;
+ ObjectPath type_p = var_p->Attr("type_annotation");
+ ExprDoc rhs{nullptr};
+ if (const auto* ptr_type = type.as<PointerTypeNode>()) {
+ const auto* prim_type = ptr_type->element_type.as<PrimTypeNode>();
+ ICHECK(prim_type);
+ ExprDoc element_type =
+ LiteralDoc::DataType(prim_type->dtype,
type_p->Attr("element_type")->Attr("dtype"));
+ rhs = TIR(d, "handle");
+ rhs->source_paths.push_back(var_p->Attr("dtype"));
+ if (ptr_type->storage_scope == "") {
+ rhs = rhs->Call({element_type});
+ } else {
+ rhs = rhs->Call({element_type,
+ LiteralDoc::Str(ptr_type->storage_scope, //
+ type_p->Attr("storage_scope"))});
+ }
+ } else {
+ rhs = TIR(d, DType2Str(var->dtype));
+ rhs->source_paths.push_back(var_p->Attr("dtype"));
+ rhs = rhs->Call({});
+ }
+ rhs->source_paths.push_back(type_p);
+ return rhs;
+}
+
Doc PrintVar(const tir::Var& var, const ObjectPath& var_p, const IRDocsifier&
d) {
if (!d->IsVarDefined(var)) {
if (Optional<Frame> opt_f = FindLowestVarDef(var, d)) {
ExprDoc lhs = DefineVar(var, opt_f.value(), d);
- Type type = var->type_annotation;
- ObjectPath type_p = var_p->Attr("type_annotation");
- ExprDoc rhs{nullptr};
- if (const auto* ptr_type = type.as<PointerTypeNode>()) {
- const auto* prim_type = ptr_type->element_type.as<PrimTypeNode>();
- ICHECK(prim_type);
- ExprDoc element_type =
- LiteralDoc::DataType(prim_type->dtype,
type_p->Attr("element_type")->Attr("dtype"));
- rhs = TIR(d, "handle");
- rhs->source_paths.push_back(var_p->Attr("dtype"));
- if (ptr_type->storage_scope == "") {
- rhs = rhs->Call({element_type});
- } else {
- rhs = rhs->Call({element_type,
- LiteralDoc::Str(ptr_type->storage_scope, //
- type_p->Attr("storage_scope"))});
- }
- } else {
- rhs = TIR(d, DType2Str(var->dtype));
- rhs->source_paths.push_back(var_p->Attr("dtype"));
- rhs = rhs->Call({});
- }
- rhs->source_paths.push_back(type_p);
+ ExprDoc rhs = PrintVarCreation(var, var_p, d);
opt_f.value()->stmts.push_back(AssignDoc(lhs, rhs, NullOpt));
} else {
LOG(WARNING) << "Didn't find variable definition for: " <<
var->name_hint;
diff --git a/src/script/printer/tir/function.cc
b/src/script/printer/tir/function.cc
index 6a4df34a3a..f40d7818d7 100644
--- a/src/script/printer/tir/function.cc
+++ b/src/script/printer/tir/function.cc
@@ -92,8 +92,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
tir::Buffer buffer = func->buffer_map[var];
if (IsSimpleBuffer(buffer) &&
buffer_data_counter.at(buffer->data.get()) == 1) {
ObjectPath buffer_p = p->Attr("buffer_map")->MapValue(var);
- args.push_back(AssignDoc(DefineBuffer(buffer, *f, d), NullOpt,
- BufferAttn(buffer, buffer_p, *f, d)));
+ IdDoc lhs = DefineBuffer(buffer, *f, d);
+ ExprDoc annotation = BufferAttn(buffer, buffer_p, *f, d);
+ args.push_back(AssignDoc(lhs, NullOpt, annotation));
buffer_inlined.insert(buffer.get());
continue;
}
@@ -117,7 +118,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
}
ExprDoc param_doc = args[i]->lhs;
ObjectPath buffer_p = p->Attr("buffer_map")->MapValue(param);
- ExprDoc lhs = DefineBuffer(buffer, *f, d); // TODO(@junrushao):
switch `lhs` and `rhs`
+ ExprDoc lhs = DefineBuffer(buffer, *f, d);
ExprDoc rhs = BufferDecl(buffer, "match_buffer", {param_doc},
buffer_p, *f, d);
(*f)->stmts.push_back(AssignDoc(lhs, rhs, NullOpt));
}
diff --git a/src/script/printer/tir/utils.h b/src/script/printer/tir/utils.h
index 08eb12bfa7..cee5fbd0f0 100644
--- a/src/script/printer/tir/utils.h
+++ b/src/script/printer/tir/utils.h
@@ -201,6 +201,15 @@ ExprDoc BufferDecl(const tir::Buffer& buffer, const
String& method, const Array<
ExprDoc BufferAttn(const tir::Buffer& buffer, const ObjectPath& p, const
Frame& frame,
const IRDocsifier& d);
+/*!
+ * \brief Print the creation of a Var
+ * \param var The Var to be printed
+ * \param var_p The object path of the Var
+ * \param d The IRDocsifier
+ * \return The ExprDoc corresponding to the Var creation
+ */
+ExprDoc PrintVarCreation(const tir::Var& var, const ObjectPath& var_p, const
IRDocsifier& d);
+
/*! \brief A Var occurrence counter visitor */
class OccurrenceCounter : public tir::StmtExprVisitor {
public:
diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py
b/tests/python/unittest/test_tvmscript_roundtrip.py
index 464e8fd342..c956f3bb02 100644
--- a/tests/python/unittest/test_tvmscript_roundtrip.py
+++ b/tests/python/unittest/test_tvmscript_roundtrip.py
@@ -3572,6 +3572,57 @@ def let_stmt_value():
return func
+def string_stride():
+ @T.prim_func
+ def main(a: T.handle, b: T.handle):
+ T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main",
"tir.noalias": True})
+ n = T.int32()
+ A = T.match_buffer(a, (n,), strides=("A_s0",), buffer_type="auto")
+ B = T.match_buffer(b, (n,), strides=("B_s0",), buffer_type="auto")
+ blockIdx_x = T.launch_thread("blockIdx.x", (n + 63) // 64)
+ threadIdx_x = T.launch_thread("threadIdx.x", 64)
+ if T.likely(blockIdx_x * 64 + threadIdx_x < n):
+ B2 = T.Buffer((B.strides[0] * n,), data=B.data)
+ A2 = T.Buffer((A.strides[0] * n,), data=A.data)
+ B2[(blockIdx_x * 64 + threadIdx_x) * B.strides[0]] = A2[
+ (blockIdx_x * 64 + threadIdx_x) * A.strides[0]
+ ] * T.float32(2)
+
+ return main
+
+
+def merge_shape_var_def():
+ @T.prim_func
+ def main(A: T.handle, B: T.handle):
+ T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main",
"tir.noalias": True})
+ m, n = T.int32(), T.int32()
+ A_1 = T.match_buffer(A, (m, n), strides=("A_1_s0", "A_1_s1"),
buffer_type="auto")
+ B_1 = T.match_buffer(B, (m, n), strides=("B_1_s0", "B_1_s1"),
buffer_type="auto")
+ for i_outer, j_outer, i_inner in T.grid((m + 9) // 10, (n + 4) // 5,
10):
+ if T.likely(i_outer * 10 + i_inner < m):
+ for j_inner in range(5):
+ if T.likely(j_outer * 5 + j_inner < n):
+ cse_var_2: T.int32 = j_outer * 5 + j_inner
+ cse_var_1: T.int32 = i_outer * 10 + i_inner
+ B_2 = T.Buffer(
+ (B_1.strides[0] * m,),
+ data=B_1.data,
+ strides=("B_2_s0",),
+ buffer_type="auto",
+ )
+ A_2 = T.Buffer(
+ (A_1.strides[0] * m,),
+ data=A_1.data,
+ strides=("A_2_s0",),
+ buffer_type="auto",
+ )
+ B_2[cse_var_1 * B_1.strides[0] + cse_var_2 *
B_1.strides[1]] = A_2[
+ cse_var_1 * A_1.strides[0] + cse_var_2 *
A_1.strides[1]
+ ]
+
+ return main
+
+
ir_generator = tvm.testing.parameter(
launch_env_thread,
opt_gemm_normalize,
@@ -3633,6 +3684,8 @@ ir_generator = tvm.testing.parameter(
intrinsic_pow,
let_stmt_var,
let_stmt_value,
+ string_stride,
+ merge_shape_var_def,
)