This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new be66a7e0e4 [TVMScript] Sugar T.env_thread + T.launch_thread (#14217)
be66a7e0e4 is described below

commit be66a7e0e46f422b5eb199cfb70c725fe8a91825
Author: Junru Shao <[email protected]>
AuthorDate: Mon Mar 6 22:58:19 2023 -0800

    [TVMScript] Sugar T.env_thread + T.launch_thread (#14217)
    
    This PR introduces a syntactic sugar that combines T.env_thread and
    T.launch_thread.
    
    Previously, an AttrStmt that specifies a thread extent or a virtual
    thread had to be written in two steps:
    
    ```python
    bx = T.env_thread("blockIdx.x")  # creates an IterVar
    with T.launch_thread(bx, 128):   # specifies the iter domain
      ...
    ```
    
    With this PR, the two steps can be merged into a single line:
    
    ```python
    with T.launch_thread("blockIdx.x", 128) as bx:
      ...
    ```
---
 include/tvm/script/ir_builder/tir/ir.h            |  8 +++
 python/tvm/script/ir_builder/tir/frame.py         |  4 +-
 python/tvm/script/ir_builder/tir/ir.py            | 13 ++--
 src/script/ir_builder/tir/ir.cc                   | 17 ++++-
 src/script/printer/tir/stmt.cc                    | 75 +++++++++++++++++------
 tests/python/unittest/test_tvmscript_roundtrip.py | 11 ++++
 6 files changed, 101 insertions(+), 27 deletions(-)

diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h
index 2b89d0e736..8d8b0b42ba 100644
--- a/include/tvm/script/ir_builder/tir/ir.h
+++ b/include/tvm/script/ir_builder/tir/ir.h
@@ -390,6 +390,14 @@ DeclBufferFrame DeclBuffer(Array<PrimExpr> shape, DataType 
dtype, String buffer_
  */
 LaunchThreadFrame LaunchThread(Var var, PrimExpr extent);
 
+/*!
+ * \brief Launch a new thread.
+ * \param thread_tag The thread type tag.
+ * \param extent The extent of environment thread.
+ * \return The result LaunchThreadFrame.
+ */
+LaunchThreadFrame LaunchThread(String thread_tag, PrimExpr extent);
+
 /*!
  * \brief Bind a var to thread env.
  * \param thread_tag The thread type tag.
diff --git a/python/tvm/script/ir_builder/tir/frame.py b/python/tvm/script/ir_builder/tir/frame.py
index 3e453f2e51..b2229d503b 100644
--- a/python/tvm/script/ir_builder/tir/frame.py
+++ b/python/tvm/script/ir_builder/tir/frame.py
@@ -115,4 +115,6 @@ class DeclBufferFrame(TIRFrame):
 
 @_register_object("script.ir_builder.tir.LaunchThreadFrame")
 class LaunchThreadFrame(TIRFrame):
-    ...
+    def __enter__(self) -> Var:
+        super().__enter__()
+        return self.iter_var.var
diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py
index 62a0aa8f32..e88597732c 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -31,7 +31,7 @@ import numpy as np  # type: ignore
 from tvm import tir
 from tvm.ir import Range, Type
 from tvm.ir.base import deprecated
-from tvm.runtime import convert, ndarray
+from tvm.runtime import String, convert, ndarray
 from tvm.target import Target
 
 # pylint: disable=unused-import
@@ -1185,14 +1185,14 @@ def decl_buffer(
 
 
 def launch_thread(
-    iter_var: IterVar,  # pylint: disable=redefined-outer-name
+    thread: Union[IterVar, str],  # pylint: disable=redefined-outer-name
     extent: PrimExpr,
 ) -> frame.LaunchThreadFrame:
     """Launch a thread.
 
     Parameters
     ----------
-    iter_var : IterVar
+    thread : Union[IterVar, str]
         The iteration variable.
 
     extent : PrimExpr
@@ -1213,11 +1213,14 @@ def launch_thread(
     T.launch_thread(brow, 1)
 
     """
-    return _ffi_api.LaunchThread(iter_var, extent)  # type: 
ignore[attr-defined] # pylint: disable=no-member
+
+    if isinstance(thread, str):
+        thread = String(thread)
+    return _ffi_api.LaunchThread(thread, extent)  # type: ignore[attr-defined] 
# pylint: disable=no-member
 
 
 def env_thread(thread_tag: str) -> IterVar:
-    """Bind a var to thread env"
+    """Bind a var to thread env
 
     Parameters
     ----------
diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc
index a54f3d926f..487265bff2 100644
--- a/src/script/ir_builder/tir/ir.cc
+++ b/src/script/ir_builder/tir/ir.cc
@@ -427,6 +427,10 @@ LaunchThreadFrame LaunchThread(Var var, PrimExpr extent) {
   return LaunchThreadFrame(n);
 }
 
+LaunchThreadFrame LaunchThread(String thread_tag, PrimExpr extent) {
+  return LaunchThread(EnvThread(thread_tag), extent);
+}
+
 RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope,
                      PrimExpr condition) {
   ObjectPtr<RealizeFrameNode> n = make_object<RealizeFrameNode>();
@@ -658,7 +662,18 @@ 
TVM_REGISTER_GLOBAL("script.ir_builder.tir.If").set_body_typed(If);
 TVM_REGISTER_GLOBAL("script.ir_builder.tir.Then").set_body_typed(Then);
 TVM_REGISTER_GLOBAL("script.ir_builder.tir.Else").set_body_typed(Else);
 
TVM_REGISTER_GLOBAL("script.ir_builder.tir.DeclBuffer").set_body_typed(DeclBuffer);
-TVM_REGISTER_GLOBAL("script.ir_builder.tir.LaunchThread").set_body_typed(LaunchThread);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.LaunchThread")
+    .set_body_typed([](ObjectRef thread_tag_or_var, PrimExpr extent) {
+      if (const auto* var = thread_tag_or_var.as<tvm::tir::VarNode>()) {
+        return LaunchThread(GetRef<tvm::tir::Var>(var), extent);
+      } else if (const auto* str = thread_tag_or_var.as<StringObj>()) {
+        return LaunchThread(GetRef<String>(str), extent);
+      } else {
+        LOG(FATAL) << "ValueError: Unexpected type for TIR LaunchThread: "
+                   << thread_tag_or_var->GetTypeKey();
+        throw;
+      }
+    });
 
TVM_REGISTER_GLOBAL("script.ir_builder.tir.EnvThread").set_body_typed(EnvThread);
 
 
TVM_REGISTER_GLOBAL("script.ir_builder.tir.BufferStore").set_body_typed(BufferStore);
diff --git a/src/script/printer/tir/stmt.cc b/src/script/printer/tir/stmt.cc
index 92ad41edc9..591d1e3bc1 100644
--- a/src/script/printer/tir/stmt.cc
+++ b/src/script/printer/tir/stmt.cc
@@ -45,6 +45,19 @@ bool AllowConciseScoping(const IRDocsifier& d) {
   LOG(FATAL) << "NotImplementedError: fragment printing";
 }
 
+bool IsAncestorOfAllVarUse(const tir::Stmt& node, const ObjectRef& var, const 
IRDocsifier& d) {
+  if (!d->common_prefix.count(var.get())) {
+    return false;
+  }
+  const std::vector<const Object*>& path = d->common_prefix.at(var.get());
+  for (auto it = path.rbegin(); it != path.rend(); ++it) {
+    if (*it == node.get()) {
+      return true;
+    }
+  }
+  return false;
+}
+
 TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
     .set_dispatch<tir::Evaluate>("", [](tir::Evaluate eval, ObjectPath p, 
IRDocsifier d) -> Doc {
       ExprDoc value = d->AsDoc<ExprDoc>(eval->value, p->Attr("value"));
@@ -322,6 +335,39 @@ ExprDoc DocsifyBufferRealize(const tir::BufferRealizeNode* 
stmt, Optional<ExprDo
   return TIR(d, "realize")->Call(args, kwargs_keys, kwargs_values);
 }
 
+void InsertEnvThread(const tir::IterVar& iter_var, const ObjectPath& 
iter_var_p,
+                     const IRDocsifier& d) {
+  Frame f = FindLowestVarDef(iter_var->var, d).value();
+  DefineVar(iter_var->var, f, d);
+  ExprDoc rhs = TIR(d, "env_thread")
+                    ->Call({LiteralDoc::Str(iter_var->thread_tag,  //
+                                            iter_var_p->Attr("thread_tag"))});
+  ExprDoc lhs = d->AsDoc<ExprDoc>(iter_var->var, iter_var_p->Attr("var"));
+  f->stmts.push_back(AssignDoc(lhs, rhs, NullOpt));
+}
+
+ExprDoc DocsifyLaunchThread(const tir::AttrStmt& attr_stmt, const ObjectPath& 
attr_stmt_p,
+                            Optional<tir::Var>* define_var, const IRDocsifier& 
d) {
+  tir::IterVar iter_var = Downcast<tir::IterVar>(attr_stmt->node);
+  ObjectPath iter_var_p = attr_stmt_p->Attr("node");
+
+  ExprDoc var_doc{nullptr};
+  if (d->IsVarDefined(iter_var->var)) {
+    var_doc = d->AsDoc<ExprDoc>(iter_var->var, iter_var_p->Attr("var"));
+  } else if (IsAncestorOfAllVarUse(attr_stmt, iter_var->var, d)) {
+    var_doc = LiteralDoc::Str(iter_var->thread_tag, 
iter_var_p->Attr("thread_tag"));
+    *define_var = iter_var->var;
+  } else {
+    InsertEnvThread(iter_var, iter_var_p, d);
+    var_doc = d->AsDoc<ExprDoc>(iter_var->var, iter_var_p->Attr("var"));
+  }
+  return TIR(d, "launch_thread")
+      ->Call({
+          var_doc,
+          d->AsDoc<ExprDoc>(attr_stmt->value, attr_stmt_p->Attr("value")),
+      });
+}
+
 TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
     .set_dispatch<tir::BufferRealize>(  //
         "", [](tir::BufferRealize stmt, ObjectPath p, IRDocsifier d) -> Doc {
@@ -336,7 +382,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
     .set_dispatch<tir::AttrStmt>(  //
         "", [](tir::AttrStmt stmt, ObjectPath stmt_p, IRDocsifier d) -> Doc {
           bool concise = AllowConciseScoping(d);
+          Optional<ExprDoc> lhs = NullOpt;
           Optional<ExprDoc> rhs = NullOpt;
+          Optional<tir::Var> define_var = NullOpt;
           tir::Stmt body = stmt->body;
           ObjectPath body_p = stmt_p->Attr("body");
           if (stmt->attr_key == "realize_scope") {
@@ -347,29 +395,13 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
                     /*value=*/d->AsDoc<ExprDoc>(stmt->value, 
stmt_p->Attr("value")),
                     /*p=*/stmt_p->Attr("body"), d);
                 body = realize->body;
-                body_p = body_p->Attr("body");
+                body_p = stmt_p->Attr("body")->Attr("body");
               }
             }
           }
           if (stmt->attr_key == "thread_extent" || stmt->attr_key == 
"virtual_thread") {
-            if (const auto* iter_var = stmt->node.as<tir::IterVarNode>()) {
-              if (!d->IsVarDefined(iter_var->var)) {
-                // `DefineVar` is not used here because a more specific name 
is desirable
-                ObjectPath iter_var_p = stmt_p->Attr("node");
-                Frame f = FindLowestVarDef(iter_var->var, d).value();
-                DefineVar(iter_var->var, f, d);
-                f->stmts.push_back(
-                    AssignDoc(d->AsDoc<ExprDoc>(iter_var->var, 
iter_var_p->Attr("var")),
-                              TIR(d, "env_thread")
-                                  ->Call({LiteralDoc::Str(iter_var->thread_tag,
-                                                          
iter_var_p->Attr("thread_tag"))}),  //
-                              NullOpt));
-              }
-              rhs = TIR(d, "launch_thread")
-                        ->Call({
-                            d->AsDoc<ExprDoc>(iter_var->var, 
stmt_p->Attr("node")),
-                            d->AsDoc<ExprDoc>(stmt->value, 
stmt_p->Attr("value")),
-                        });
+            if (stmt->node->IsInstance<tir::IterVarNode>()) {
+              rhs = DocsifyLaunchThread(stmt, stmt_p, &define_var, d);
             }
           }
           if (!rhs.defined()) {
@@ -380,8 +412,11 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
             });
           }
           With<TIRFrame> f(d, stmt);
+          if (define_var.defined()) {
+            lhs = DefineVar(define_var.value(), *f, d);
+          }
           AsDocBody(body, body_p, f->get(), d);
-          return DoConciseScoping(NullOpt, rhs.value(), &(*f)->stmts, concise);
+          return DoConciseScoping(lhs, rhs.value(), &(*f)->stmts, concise);
         });
 
 TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py
index c91b733751..464e8fd342 100644
--- a/tests/python/unittest/test_tvmscript_roundtrip.py
+++ b/tests/python/unittest/test_tvmscript_roundtrip.py
@@ -171,6 +171,16 @@ def opt_gemm_lower():
     return Module
 
 
+def launch_env_thread():
+    @T.prim_func
+    def main(inputs: T.Buffer((64, 2, 4), "float32")) -> None:
+        bx = T.launch_thread("blockIdx.x", 64)
+        for i, j in T.grid(2, 4):
+            T.evaluate(inputs[bx, i, j])
+
+    return main
+
+
 def opt_gemm_mod_host():
     @tvm.script.ir_module
     class Module:
@@ -3563,6 +3573,7 @@ def let_stmt_value():
 
 
 ir_generator = tvm.testing.parameter(
+    launch_env_thread,
     opt_gemm_normalize,
     opt_gemm_lower,
     opt_gemm_mod_host,

Reply via email to