This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new be66a7e0e4 [TVMScript] Sugar T.env_thread + T.launch_thread (#14217)
be66a7e0e4 is described below
commit be66a7e0e46f422b5eb199cfb70c725fe8a91825
Author: Junru Shao <[email protected]>
AuthorDate: Mon Mar 6 22:58:19 2023 -0800
[TVMScript] Sugar T.env_thread + T.launch_thread (#14217)
This PR introduces syntactic sugar that combines T.env_thread and
T.launch_thread into a single call.
Previously, an AttrStmt that specifies a thread extent or a virtual thread
had to be written in two steps:
```python
bx = T.env_thread("blockIdx.x")  # creates an IterVar
with T.launch_thread(bx, 128):   # specifies the iteration domain
    ...
```
With this PR, the two steps can be merged into a single line:
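```python
with T.launch_thread("blockIdx.x", 128) as bx:
    ...
```
Here `bx` is bound to the thread variable, i.e. the `var` of the IterVar that
`T.env_thread("blockIdx.x")` would otherwise have created explicitly.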
---
include/tvm/script/ir_builder/tir/ir.h | 8 +++
python/tvm/script/ir_builder/tir/frame.py | 4 +-
python/tvm/script/ir_builder/tir/ir.py | 13 ++--
src/script/ir_builder/tir/ir.cc | 17 ++++-
src/script/printer/tir/stmt.cc | 75 +++++++++++++++++------
tests/python/unittest/test_tvmscript_roundtrip.py | 11 ++++
6 files changed, 101 insertions(+), 27 deletions(-)
diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h
index 2b89d0e736..8d8b0b42ba 100644
--- a/include/tvm/script/ir_builder/tir/ir.h
+++ b/include/tvm/script/ir_builder/tir/ir.h
@@ -390,6 +390,14 @@ DeclBufferFrame DeclBuffer(Array<PrimExpr> shape, DataType dtype, String buffer_
*/
LaunchThreadFrame LaunchThread(Var var, PrimExpr extent);
+/*!
+ * \brief Launch a new thread.
+ * \param thread_tag The thread type tag.
+ * \param extent The extent of environment thread.
+ * \return The result LaunchThreadFrame.
+ */
+LaunchThreadFrame LaunchThread(String thread_tag, PrimExpr extent);
+
/*!
* \brief Bind a var to thread env.
* \param thread_tag The thread type tag.
diff --git a/python/tvm/script/ir_builder/tir/frame.py b/python/tvm/script/ir_builder/tir/frame.py
index 3e453f2e51..b2229d503b 100644
--- a/python/tvm/script/ir_builder/tir/frame.py
+++ b/python/tvm/script/ir_builder/tir/frame.py
@@ -115,4 +115,6 @@ class DeclBufferFrame(TIRFrame):
@_register_object("script.ir_builder.tir.LaunchThreadFrame")
class LaunchThreadFrame(TIRFrame):
- ...
+ def __enter__(self) -> Var:
+ super().__enter__()
+ return self.iter_var.var
diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py
index 62a0aa8f32..e88597732c 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -31,7 +31,7 @@ import numpy as np # type: ignore
from tvm import tir
from tvm.ir import Range, Type
from tvm.ir.base import deprecated
-from tvm.runtime import convert, ndarray
+from tvm.runtime import String, convert, ndarray
from tvm.target import Target
# pylint: disable=unused-import
@@ -1185,14 +1185,14 @@ def decl_buffer(
def launch_thread(
- iter_var: IterVar, # pylint: disable=redefined-outer-name
+ thread: Union[IterVar, str], # pylint: disable=redefined-outer-name
extent: PrimExpr,
) -> frame.LaunchThreadFrame:
"""Launch a thread.
Parameters
----------
- iter_var : IterVar
+ thread : Union[IterVar, str]
The iteration variable.
extent : PrimExpr
@@ -1213,11 +1213,14 @@ def launch_thread(
T.launch_thread(brow, 1)
"""
- return _ffi_api.LaunchThread(iter_var, extent) # type:
ignore[attr-defined] # pylint: disable=no-member
+
+ if isinstance(thread, str):
+ thread = String(thread)
+ return _ffi_api.LaunchThread(thread, extent) # type: ignore[attr-defined]
# pylint: disable=no-member
def env_thread(thread_tag: str) -> IterVar:
- """Bind a var to thread env"
+ """Bind a var to thread env
Parameters
----------
diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc
index a54f3d926f..487265bff2 100644
--- a/src/script/ir_builder/tir/ir.cc
+++ b/src/script/ir_builder/tir/ir.cc
@@ -427,6 +427,10 @@ LaunchThreadFrame LaunchThread(Var var, PrimExpr extent) {
return LaunchThreadFrame(n);
}
+LaunchThreadFrame LaunchThread(String thread_tag, PrimExpr extent) {
+ return LaunchThread(EnvThread(thread_tag), extent);
+}
+
RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope,
PrimExpr condition) {
ObjectPtr<RealizeFrameNode> n = make_object<RealizeFrameNode>();
@@ -658,7 +662,18 @@ TVM_REGISTER_GLOBAL("script.ir_builder.tir.If").set_body_typed(If);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Then").set_body_typed(Then);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Else").set_body_typed(Else);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.DeclBuffer").set_body_typed(DeclBuffer);
-TVM_REGISTER_GLOBAL("script.ir_builder.tir.LaunchThread").set_body_typed(LaunchThread);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.LaunchThread")
+ .set_body_typed([](ObjectRef thread_tag_or_var, PrimExpr extent) {
+ if (const auto* var = thread_tag_or_var.as<tvm::tir::VarNode>()) {
+ return LaunchThread(GetRef<tvm::tir::Var>(var), extent);
+ } else if (const auto* str = thread_tag_or_var.as<StringObj>()) {
+ return LaunchThread(GetRef<String>(str), extent);
+ } else {
+ LOG(FATAL) << "ValueError: Unexpected type for TIR LaunchThread: "
+ << thread_tag_or_var->GetTypeKey();
+ throw;
+ }
+ });
TVM_REGISTER_GLOBAL("script.ir_builder.tir.EnvThread").set_body_typed(EnvThread);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.BufferStore").set_body_typed(BufferStore);
diff --git a/src/script/printer/tir/stmt.cc b/src/script/printer/tir/stmt.cc
index 92ad41edc9..591d1e3bc1 100644
--- a/src/script/printer/tir/stmt.cc
+++ b/src/script/printer/tir/stmt.cc
@@ -45,6 +45,19 @@ bool AllowConciseScoping(const IRDocsifier& d) {
LOG(FATAL) << "NotImplementedError: fragment printing";
}
+bool IsAncestorOfAllVarUse(const tir::Stmt& node, const ObjectRef& var, const
IRDocsifier& d) {
+ if (!d->common_prefix.count(var.get())) {
+ return false;
+ }
+ const std::vector<const Object*>& path = d->common_prefix.at(var.get());
+ for (auto it = path.rbegin(); it != path.rend(); ++it) {
+ if (*it == node.get()) {
+ return true;
+ }
+ }
+ return false;
+}
+
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::Evaluate>("", [](tir::Evaluate eval, ObjectPath p,
IRDocsifier d) -> Doc {
ExprDoc value = d->AsDoc<ExprDoc>(eval->value, p->Attr("value"));
@@ -322,6 +335,39 @@ ExprDoc DocsifyBufferRealize(const tir::BufferRealizeNode* stmt, Optional<ExprDo
return TIR(d, "realize")->Call(args, kwargs_keys, kwargs_values);
}
+void InsertEnvThread(const tir::IterVar& iter_var, const ObjectPath&
iter_var_p,
+ const IRDocsifier& d) {
+ Frame f = FindLowestVarDef(iter_var->var, d).value();
+ DefineVar(iter_var->var, f, d);
+ ExprDoc rhs = TIR(d, "env_thread")
+ ->Call({LiteralDoc::Str(iter_var->thread_tag, //
+ iter_var_p->Attr("thread_tag"))});
+ ExprDoc lhs = d->AsDoc<ExprDoc>(iter_var->var, iter_var_p->Attr("var"));
+ f->stmts.push_back(AssignDoc(lhs, rhs, NullOpt));
+}
+
+ExprDoc DocsifyLaunchThread(const tir::AttrStmt& attr_stmt, const ObjectPath&
attr_stmt_p,
+ Optional<tir::Var>* define_var, const IRDocsifier&
d) {
+ tir::IterVar iter_var = Downcast<tir::IterVar>(attr_stmt->node);
+ ObjectPath iter_var_p = attr_stmt_p->Attr("node");
+
+ ExprDoc var_doc{nullptr};
+ if (d->IsVarDefined(iter_var->var)) {
+ var_doc = d->AsDoc<ExprDoc>(iter_var->var, iter_var_p->Attr("var"));
+ } else if (IsAncestorOfAllVarUse(attr_stmt, iter_var->var, d)) {
+ var_doc = LiteralDoc::Str(iter_var->thread_tag,
iter_var_p->Attr("thread_tag"));
+ *define_var = iter_var->var;
+ } else {
+ InsertEnvThread(iter_var, iter_var_p, d);
+ var_doc = d->AsDoc<ExprDoc>(iter_var->var, iter_var_p->Attr("var"));
+ }
+ return TIR(d, "launch_thread")
+ ->Call({
+ var_doc,
+ d->AsDoc<ExprDoc>(attr_stmt->value, attr_stmt_p->Attr("value")),
+ });
+}
+
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::BufferRealize>( //
"", [](tir::BufferRealize stmt, ObjectPath p, IRDocsifier d) -> Doc {
@@ -336,7 +382,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::AttrStmt>( //
"", [](tir::AttrStmt stmt, ObjectPath stmt_p, IRDocsifier d) -> Doc {
bool concise = AllowConciseScoping(d);
+ Optional<ExprDoc> lhs = NullOpt;
Optional<ExprDoc> rhs = NullOpt;
+ Optional<tir::Var> define_var = NullOpt;
tir::Stmt body = stmt->body;
ObjectPath body_p = stmt_p->Attr("body");
if (stmt->attr_key == "realize_scope") {
@@ -347,29 +395,13 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
/*value=*/d->AsDoc<ExprDoc>(stmt->value,
stmt_p->Attr("value")),
/*p=*/stmt_p->Attr("body"), d);
body = realize->body;
- body_p = body_p->Attr("body");
+ body_p = stmt_p->Attr("body")->Attr("body");
}
}
}
if (stmt->attr_key == "thread_extent" || stmt->attr_key ==
"virtual_thread") {
- if (const auto* iter_var = stmt->node.as<tir::IterVarNode>()) {
- if (!d->IsVarDefined(iter_var->var)) {
- // `DefineVar` is not used here because a more specific name
is desirable
- ObjectPath iter_var_p = stmt_p->Attr("node");
- Frame f = FindLowestVarDef(iter_var->var, d).value();
- DefineVar(iter_var->var, f, d);
- f->stmts.push_back(
- AssignDoc(d->AsDoc<ExprDoc>(iter_var->var,
iter_var_p->Attr("var")),
- TIR(d, "env_thread")
- ->Call({LiteralDoc::Str(iter_var->thread_tag,
-
iter_var_p->Attr("thread_tag"))}), //
- NullOpt));
- }
- rhs = TIR(d, "launch_thread")
- ->Call({
- d->AsDoc<ExprDoc>(iter_var->var,
stmt_p->Attr("node")),
- d->AsDoc<ExprDoc>(stmt->value,
stmt_p->Attr("value")),
- });
+ if (stmt->node->IsInstance<tir::IterVarNode>()) {
+ rhs = DocsifyLaunchThread(stmt, stmt_p, &define_var, d);
}
}
if (!rhs.defined()) {
@@ -380,8 +412,11 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
});
}
With<TIRFrame> f(d, stmt);
+ if (define_var.defined()) {
+ lhs = DefineVar(define_var.value(), *f, d);
+ }
AsDocBody(body, body_p, f->get(), d);
- return DoConciseScoping(NullOpt, rhs.value(), &(*f)->stmts, concise);
+ return DoConciseScoping(lhs, rhs.value(), &(*f)->stmts, concise);
});
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py
index c91b733751..464e8fd342 100644
--- a/tests/python/unittest/test_tvmscript_roundtrip.py
+++ b/tests/python/unittest/test_tvmscript_roundtrip.py
@@ -171,6 +171,16 @@ def opt_gemm_lower():
return Module
+def launch_env_thread():
+ @T.prim_func
+ def main(inputs: T.Buffer((64, 2, 4), "float32")) -> None:
+ bx = T.launch_thread("blockIdx.x", 64)
+ for i, j in T.grid(2, 4):
+ T.evaluate(inputs[bx, i, j])
+
+ return main
+
+
def opt_gemm_mod_host():
@tvm.script.ir_module
class Module:
@@ -3563,6 +3573,7 @@ def let_stmt_value():
ir_generator = tvm.testing.parameter(
+ launch_env_thread,
opt_gemm_normalize,
opt_gemm_lower,
opt_gemm_mod_host,