This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 9fab56c4c1 [TVMScript] Use op attribute to control whether to print 
dtype in TVMScript (#14111)
9fab56c4c1 is described below

commit 9fab56c4c1b2115ee52b93250597b688ca682d23
Author: LiangW <[email protected]>
AuthorDate: Sun Feb 26 03:57:27 2023 +0800

    [TVMScript] Use op attribute to control whether to print dtype in TVMScript 
(#14111)
    
    This PR adds an op attribute `TScriptDtypePrintLocation` and modifies the
    TVMScript printer's dtype-printing logic for builtin op calls to check this
    attribute, so that user-defined operators can specify how their dtype
    arguments are printed by registering this attribute instead of appending
    members to `dtype_first_arg`/`dtype_last_arg`.
---
 include/tvm/tir/op_attr_types.h | 21 +++++++++++++
 src/script/printer/tir/expr.cc  | 31 ++++++-------------
 src/tir/ir/stmt.cc              |  4 ++-
 src/tir/op/builtin.cc           | 67 ++++++++++++++++++++++++++++++-----------
 4 files changed, 82 insertions(+), 41 deletions(-)

diff --git a/include/tvm/tir/op_attr_types.h b/include/tvm/tir/op_attr_types.h
index 858d89c2d5..b2a644f954 100644
--- a/include/tvm/tir/op_attr_types.h
+++ b/include/tvm/tir/op_attr_types.h
@@ -61,6 +61,27 @@ using FLegalize = 
runtime::TypedPackedFunc<PrimExpr(PrimExpr)>;
  */
 using TScriptPrinterName = String;
 
+/*!
+ * \brief Specifies where the TVMScript printer prints a call's dtype as an argument.
+ *        If not specified, the dtype is not printed.
+ */
+enum class ScriptDtypePrintLocation : int {
+  /*!
+   * \brief Do not print dtype as an argument.
+   */
+  kNone = 0,
+  /*!
+   * \brief Print dtype as the first argument.
+   */
+  kFirst = 1,
+  /*!
+   * \brief Print dtype as the last argument.
+   */
+  kLast = 2,
+};
+
+using TScriptDtypePrintLocation = Integer;  // Holds a ScriptDtypePrintLocation value as an Integer.
+
 /*!
  * \brief The effect type of the call.
  */
diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tir/expr.cc
index d860eeb2a7..f1435c4870 100644
--- a/src/script/printer/tir/expr.cc
+++ b/src/script/printer/tir/expr.cc
@@ -222,26 +222,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
     .set_dispatch<tir::Call>("", [](tir::Call call, ObjectPath call_p, 
IRDocsifier d) -> Doc {
       static const OpAttrMap<tir::TScriptPrinterName>& op_names =
           Op::GetAttrMap<tir::TScriptPrinterName>("TScriptPrinterName");
-      static const std::unordered_set<const Object*> dtype_first_arg = {
-          tir::builtin::reinterpret().get(),
-          tir::builtin::call_extern().get(),
-          tir::builtin::call_llvm_intrin().get(),       //
-          tir::builtin::call_llvm_pure_intrin().get(),  //
-          tir::builtin::call_pure_extern().get(),       //
-          tir::builtin::ptx_mma().get(),
-          tir::builtin::ptx_mma_sp().get(),
-          tir::builtin::ptx_ldmatrix().get(),
-          tir::builtin::ptx_cp_async().get(),
-          tir::builtin::mma_store().get(),
-          tir::builtin::mma_fill().get(),
-          tir::builtin::vectorlow().get(),
-          tir::builtin::vectorhigh().get(),
-          tir::builtin::vectorcombine().get(),
-          Op::Get("tir.type_annotation").get(),
-      };
-      static const std::unordered_set<const Object*> dtype_last_arg = {
-          tir::builtin::tvm_struct_get().get(),
-      };
+      static const OpAttrMap<tir::TScriptDtypePrintLocation> dtype_locations =
+          
Op::GetAttrMap<tir::TScriptDtypePrintLocation>("TScriptDtypePrintLocation");
+      tir::ScriptDtypePrintLocation dtype_print_location = 
tir::ScriptDtypePrintLocation::kNone;
       ExprDoc prefix{nullptr};
       if (const auto* op = call->op.as<OpNode>()) {
         String name = op_names.get(GetRef<Op>(op), op->name);
@@ -249,6 +232,10 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
           LOG(WARNING) << "No TScriptPrinterName attribute for " << op->name;
         }
         prefix = TIR(d, name);
+        if (dtype_locations.count(GetRef<Op>(op))) {
+          dtype_print_location = static_cast<tir::ScriptDtypePrintLocation>(
+              dtype_locations[GetRef<Op>(op)].IntValue());
+        }
       } else if (const auto* gv = call->op.as<GlobalVarNode>()) {
         prefix = LiteralDoc::Str(gv->name_hint, call_p->Attr("op"));
       } else {
@@ -257,13 +244,13 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
       Array<ExprDoc> args;
       int n_args = call->args.size();
       args.reserve(n_args + 1);
-      if (dtype_first_arg.count(call->op.get())) {
+      if (dtype_print_location == tir::ScriptDtypePrintLocation::kFirst) {
         args.push_back(LiteralDoc::DataType(call->dtype, 
call_p->Attr("dtype")));
       }
       for (int i = 0; i < n_args; ++i) {
         args.push_back(d->AsDoc<ExprDoc>(call->args[i], 
call_p->Attr("args")->ArrayIndex(i)));
       }
-      if (dtype_last_arg.count(call->op.get())) {
+      if (dtype_print_location == tir::ScriptDtypePrintLocation::kLast) {
         args.push_back(LiteralDoc::DataType(call->dtype, 
call_p->Attr("dtype")));
       }
       return prefix->Call(args);
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index 1652786cb5..fd2a98554d 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -692,7 +692,9 @@ PrimExpr TypeAnnotation(DataType dtype, Span span) {
 }
 
 TVM_TIR_REGISTER_OP("type_annotation")
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
Integer(CallEffectKind::kPure));
+    .set_attr<TCallEffectKind>("TCallEffectKind", 
Integer(CallEffectKind::kPure))
+    .set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
+                                         
Integer(ScriptDtypePrintLocation::kFirst));
 
 }  // namespace tir
 }  // namespace tvm
diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc
index 680202751f..e240b7b701 100644
--- a/src/tir/op/builtin.cc
+++ b/src/tir/op/builtin.cc
@@ -40,6 +40,8 @@ namespace builtin {
 
 TIR_DEFINE_BUILTIN_FUNC(reinterpret)
     .set_attr<TCallEffectKind>("TCallEffectKind", 
Integer(CallEffectKind::kPure))
+    .set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
+                                         
Integer(ScriptDtypePrintLocation::kFirst))
     .set_num_inputs(1);
 
 TIR_DEFINE_BUILTIN_FUNC(ret)
@@ -120,16 +122,24 @@ TIR_DEFINE_BUILTIN_FUNC(fma)
     .set_attr<TVectorizable>("TVectorizable", true);
 
 TIR_DEFINE_BUILTIN_FUNC(call_extern)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
Integer(CallEffectKind::kOpaque));
+    .set_attr<TCallEffectKind>("TCallEffectKind", 
Integer(CallEffectKind::kOpaque))
+    .set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
+                                         
Integer(ScriptDtypePrintLocation::kFirst));
 
 TIR_DEFINE_BUILTIN_FUNC(call_pure_extern)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
Integer(CallEffectKind::kPure));
+    .set_attr<TCallEffectKind>("TCallEffectKind", 
Integer(CallEffectKind::kPure))
+    .set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
+                                         
Integer(ScriptDtypePrintLocation::kFirst));
 
 TIR_DEFINE_BUILTIN_FUNC(call_llvm_intrin)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
Integer(CallEffectKind::kOpaque));
+    .set_attr<TCallEffectKind>("TCallEffectKind", 
Integer(CallEffectKind::kOpaque))
+    .set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
+                                         
Integer(ScriptDtypePrintLocation::kFirst));
 
 TIR_DEFINE_BUILTIN_FUNC(call_llvm_pure_intrin)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
Integer(CallEffectKind::kPure));
+    .set_attr<TCallEffectKind>("TCallEffectKind", 
Integer(CallEffectKind::kPure))
+    .set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
+                                         
Integer(ScriptDtypePrintLocation::kFirst));
 
 TIR_DEFINE_BUILTIN_FUNC(call_spirv_pure_glsl450)
     .set_attr<TCallEffectKind>("TCallEffectKind", 
Integer(CallEffectKind::kPure));
@@ -154,7 +164,9 @@ 
TIR_DEFINE_BUILTIN_FUNC(tvm_tuple).set_attr<TCallEffectKind>("TCallEffectKind",
 
 TIR_DEFINE_BUILTIN_FUNC(tvm_struct_get)
     .set_num_inputs(3)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
Integer(CallEffectKind::kReadState));
+    .set_attr<TCallEffectKind>("TCallEffectKind", 
Integer(CallEffectKind::kReadState))
+    .set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
+                                         
Integer(ScriptDtypePrintLocation::kLast));
 
 TIR_DEFINE_BUILTIN_FUNC(tvm_struct_set)
     .set_num_inputs(4)
@@ -249,19 +261,28 @@ TIR_DEFINE_BUILTIN_FUNC(tvm_fill_fragment)
 TIR_DEFINE_BUILTIN_FUNC(tvm_store_matrix_sync)
     .set_attr<TCallEffectKind>("TCallEffectKind", 
Integer(CallEffectKind::kOpaque));
 
-TIR_DEFINE_BUILTIN_FUNC(ptx_mma).set_attr<TCallEffectKind>("TCallEffectKind",
-                                                           
Integer(CallEffectKind::kOpaque));
+TIR_DEFINE_BUILTIN_FUNC(ptx_mma)
+    .set_attr<TCallEffectKind>("TCallEffectKind", 
Integer(CallEffectKind::kOpaque))
+    .set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
+                                         
Integer(ScriptDtypePrintLocation::kFirst));
+
 TIR_DEFINE_BUILTIN_FUNC(ptx_ldg32).set_num_inputs(4).set_attr<TCallEffectKind>(
     "TCallEffectKind", Integer(CallEffectKind::kPure));
 
 TIR_DEFINE_BUILTIN_FUNC(ptx_mma_sp)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
Integer(CallEffectKind::kOpaque));
+    .set_attr<TCallEffectKind>("TCallEffectKind", 
Integer(CallEffectKind::kOpaque))
+    .set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
+                                         
Integer(ScriptDtypePrintLocation::kFirst));
 
 TIR_DEFINE_BUILTIN_FUNC(ptx_ldmatrix)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
Integer(CallEffectKind::kOpaque));
+    .set_attr<TCallEffectKind>("TCallEffectKind", 
Integer(CallEffectKind::kOpaque))
+    .set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
+                                         
Integer(ScriptDtypePrintLocation::kFirst));
 
 TIR_DEFINE_BUILTIN_FUNC(ptx_cp_async)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
Integer(CallEffectKind::kOpaque));
+    .set_attr<TCallEffectKind>("TCallEffectKind", 
Integer(CallEffectKind::kOpaque))
+    .set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
+                                         
Integer(ScriptDtypePrintLocation::kFirst));
 
 TIR_DEFINE_BUILTIN_FUNC(ptx_commit_group)
     .set_attr<TCallEffectKind>("TCallEffectKind", 
Integer(CallEffectKind::kOpaque));
@@ -269,20 +290,30 @@ TIR_DEFINE_BUILTIN_FUNC(ptx_commit_group)
 TIR_DEFINE_BUILTIN_FUNC(ptx_wait_group)
     .set_attr<TCallEffectKind>("TCallEffectKind", 
Integer(CallEffectKind::kOpaque));
 
-TIR_DEFINE_BUILTIN_FUNC(mma_store).set_attr<TCallEffectKind>("TCallEffectKind",
-                                                             
Integer(CallEffectKind::kOpaque));
+TIR_DEFINE_BUILTIN_FUNC(mma_store)
+    .set_attr<TCallEffectKind>("TCallEffectKind", 
Integer(CallEffectKind::kOpaque))
+    .set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
+                                         
Integer(ScriptDtypePrintLocation::kFirst));
 
-TIR_DEFINE_BUILTIN_FUNC(mma_fill).set_attr<TCallEffectKind>("TCallEffectKind",
-                                                            
Integer(CallEffectKind::kOpaque));
+TIR_DEFINE_BUILTIN_FUNC(mma_fill)
+    .set_attr<TCallEffectKind>("TCallEffectKind", 
Integer(CallEffectKind::kOpaque))
+    .set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
+                                         
Integer(ScriptDtypePrintLocation::kFirst));
 
 TIR_DEFINE_BUILTIN_FUNC(vectorhigh)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
Integer(CallEffectKind::kPure));
+    .set_attr<TCallEffectKind>("TCallEffectKind", 
Integer(CallEffectKind::kPure))
+    .set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
+                                         
Integer(ScriptDtypePrintLocation::kFirst));
 
-TIR_DEFINE_BUILTIN_FUNC(vectorlow).set_attr<TCallEffectKind>("TCallEffectKind",
-                                                             
Integer(CallEffectKind::kPure));
+TIR_DEFINE_BUILTIN_FUNC(vectorlow)
+    .set_attr<TCallEffectKind>("TCallEffectKind", 
Integer(CallEffectKind::kPure))
+    .set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
+                                         
Integer(ScriptDtypePrintLocation::kFirst));
 
 TIR_DEFINE_BUILTIN_FUNC(vectorcombine)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
Integer(CallEffectKind::kPure));
+    .set_attr<TCallEffectKind>("TCallEffectKind", 
Integer(CallEffectKind::kPure))
+    .set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
+                                         
Integer(ScriptDtypePrintLocation::kFirst));
 
 TIR_DEFINE_BUILTIN_FUNC(atomic_add)
     .set_attr<TCallEffectKind>("TCallEffectKind", 
Integer(CallEffectKind::kOpaque));

Reply via email to