This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 1afbf20129 [Unity][BYOC] Unify the interface between 
`FuseOpsByPattern(.., annotate_codegen=True)` and `MergeCompositeFunctions()` 
(#15491)
1afbf20129 is described below

commit 1afbf20129647a35d108152fc6789bc1d029cda5
Author: Sunghyun Park <[email protected]>
AuthorDate: Tue Aug 8 01:53:17 2023 -0700

    [Unity][BYOC] Unify the interface between `FuseOpsByPattern(.., 
annotate_codegen=True)` and `MergeCompositeFunctions()` (#15491)
    
    * feat: consistent interface
    
    * fix comment
    
    * fix
    
    * bugfix
    
    * fix
---
 include/tvm/relax/utils.h                          |   5 +
 src/relax/transform/fuse_ops.cc                    |  14 +-
 src/relax/transform/merge_composite_functions.cc   |  22 +-
 src/relax/transform/to_non_dataflow.cc             |   1 +
 src/relax/transform/utils.h                        |   7 +
 tests/python/relax/test_codegen_cutlass.py         |   8 +-
 .../relax/test_transform_fuse_ops_by_pattern.py    | 106 +++-
 .../test_transform_merge_composite_functions.py    | 649 ++++++++++-----------
 8 files changed, 443 insertions(+), 369 deletions(-)

diff --git a/include/tvm/relax/utils.h b/include/tvm/relax/utils.h
index a1f587e14e..1a6d5d4a52 100644
--- a/include/tvm/relax/utils.h
+++ b/include/tvm/relax/utils.h
@@ -105,6 +105,11 @@ TVM_DLL bool IsImpureCall(const Call& call);
  */
 TVM_DLL Function CopyWithNewVars(Function func);
 
+/*!
+ * \brief Transform all dataflow structures in the expression into their non-dataflow counterparts.
+ */
+Expr ToNonDataflow(const Expr& e);
+
 }  // namespace relax
 }  // namespace tvm
 
diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index f4e7e00f30..2c042db4fd 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -1199,6 +1199,7 @@ class CompositeFunctionAnnotator : public ExprMutator {
         auto gsymbol = gvar->name_hint + "_" + codegen_name;
         new_func = WithAttrs(new_func,
                              {{attr::kCodegen, codegen_name}, 
{tvm::attr::kGlobalSymbol, gsymbol}});
+        new_func = WithoutAttr(std::move(new_func), 
tvm::relax::attr::kPrimitive);
         builder_->GetContextIRModule()->Remove(GetRef<GlobalVar>(gvar));
         auto new_gvar = builder_->AddFunction(new_func, gsymbol);
         gvar_map_[gvar] = new_gvar;
@@ -1209,8 +1210,10 @@ class CompositeFunctionAnnotator : public ExprMutator {
   }
 
   Expr VisitExpr_(const FunctionNode* func_node) final {
-    auto f_inner = ExprMutator::VisitExpr_(func_node);
+    Function f_inner = Downcast<Function>(ExprMutator::VisitExpr_(func_node));
     auto composite_name = func_node->GetAttr<String>(attr::kComposite);
+
+    f_inner = WithoutAttr(std::move(f_inner), tvm::relax::attr::kPrimitive);
     ICHECK(composite_name);
 
     Array<Var> param_vars;
@@ -1224,17 +1227,10 @@ class CompositeFunctionAnnotator : public ExprMutator {
 
     // pure if the inner func is pure (no need to force purity if it's forced 
for the inner func)
     return Function(param_vars, Call(f_inner, params), 
func_node->ret_struct_info,
-                    Downcast<Function>(f_inner)->is_pure);
+                    f_inner->is_pure);
   }
 
  private:
-  String GetCodegenName(const std::string& composite_name) {
-    auto delim_pos = composite_name.find(".");
-    ICHECK(delim_pos != std::string::npos) << "The pattern name for a 
composite function should "
-                                              "start with a compiler name 
followed by period.";
-    return composite_name.substr(0, delim_pos);
-  }
-
   /*! \brief A map from old global vars to their replacements. */
   std::unordered_map<const GlobalVarNode*, GlobalVar> gvar_map_;
 };
diff --git a/src/relax/transform/merge_composite_functions.cc 
b/src/relax/transform/merge_composite_functions.cc
index 0bc92ba923..365b34c601 100644
--- a/src/relax/transform/merge_composite_functions.cc
+++ b/src/relax/transform/merge_composite_functions.cc
@@ -39,7 +39,7 @@
  * Correct partitioning:
  *
  *     O         O
- *    / \       /            \
+ *    / \       /                  \
  *   O   X --> O    +     +    X
  *    \ /             \ /
  *     O               O
@@ -167,13 +167,6 @@ class CompositeGroupsBuilder : public 
MemoizedExprTranslator<Group*> {
   }
 
  private:
-  String GetCodegenName(const std::string& composite_name) {
-    auto delim_pos = composite_name.find(".");
-    ICHECK(delim_pos != std::string::npos) << "The pattern name for a 
composite function should "
-                                              "start with a compiler name 
followed by period.";
-    return composite_name.substr(0, delim_pos);
-  }
-
   Optional<String> GetCodegenName(const Expr& callee) {
     auto const* gvar = callee.as<GlobalVarNode>();
     if (!gvar) {
@@ -186,7 +179,7 @@ class CompositeGroupsBuilder : public 
MemoizedExprTranslator<Group*> {
       return NullOpt;
     }
 
-    return GetCodegenName(composite_name_opt.value());
+    return relax::GetCodegenName(composite_name_opt.value());
   }
 
   Optional<String> GetCodegenName(Group* group) {
@@ -298,7 +291,7 @@ class CompositeInliner : public ExprMutator {
 
   Function Run(Function func) {
     inlined_functions_ = Map<Function, Function>();
-    auto new_body = VisitExpr(func->body);
+    auto new_body = VisitExpr(ToNonDataflow(func->body));
     auto new_func = Function(func->params, new_body, func->ret_struct_info, 
func->is_pure,
                              func->attrs, func->span);
     return new_func;
@@ -308,10 +301,11 @@ class CompositeInliner : public ExprMutator {
     if (call->op->IsInstance<GlobalVarNode>()) {
       auto gvar = Downcast<GlobalVar>(call->op);
       auto func = Downcast<Function>(mod_->Lookup(gvar));
-
       if (func->GetAttr<String>(attr::kComposite)) {
         if (!inlined_functions_.count(func)) {
-          inlined_functions_.Set(func, CopyWithNewVars(func));
+          auto new_func = CopyWithNewVars(func);
+          new_func = WithoutAttr(new_func, tvm::relax::attr::kPrimitive);
+          inlined_functions_.Set(func, new_func);
         }
         return Call(inlined_functions_[func], call->args);
       }
@@ -340,11 +334,13 @@ IRModule MergeCompositeFunctions(IRModule mod) {
     if (func->GetAttr<String>(attr::kCodegen)) {
       auto new_func = inliner.Run(Downcast<Function>(func));
       new_func = WithAttr(new_func, tvm::attr::kGlobalSymbol, gvar->name_hint);
+      new_func = WithoutAttr(std::move(new_func), 
tvm::relax::attr::kPrimitive);
       to_update.emplace_back(gvar, new_func);
     }
   }
+
   for (const auto& [gvar, func] : to_update) {
-    new_mod->Update(gvar, func);
+    new_mod->Update(gvar, Downcast<Function>(func));
   }
   // TODO(@tvm-team): Implicit pass dependency. Revisit when we have a better 
way to handle this.
   return DeadCodeElimination(new_mod, {"main"});
diff --git a/src/relax/transform/to_non_dataflow.cc 
b/src/relax/transform/to_non_dataflow.cc
index db2e9d7ee5..5c790e9a73 100644
--- a/src/relax/transform/to_non_dataflow.cc
+++ b/src/relax/transform/to_non_dataflow.cc
@@ -24,6 +24,7 @@
 #include <tvm/relax/struct_info.h>
 #include <tvm/relax/transform.h>
 #include <tvm/relax/type.h>
+#include <tvm/relax/utils.h>
 #include <tvm/tir/op.h>
 
 namespace tvm {
diff --git a/src/relax/transform/utils.h b/src/relax/transform/utils.h
index 489a36a5a4..3d40a565bd 100644
--- a/src/relax/transform/utils.h
+++ b/src/relax/transform/utils.h
@@ -381,6 +381,13 @@ inline Array<Integer> GetOrderedPositiveAxes(const 
Array<Integer>& axes, int ndi
   return support::AsArray<int64_t, Integer>(ret);
 }
 
+inline String GetCodegenName(const std::string& composite_name) {
+  auto delim_pos = composite_name.find(".");
+  ICHECK(delim_pos != std::string::npos) << "The pattern name for a composite 
function should "
+                                            "start with a compiler name 
followed by period.";
+  return composite_name.substr(0, delim_pos);
+}
+
 }  // namespace relax
 }  // namespace tvm
 
diff --git a/tests/python/relax/test_codegen_cutlass.py 
b/tests/python/relax/test_codegen_cutlass.py
index d19189ff34..952036584f 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -1220,7 +1220,6 @@ def test_attention_rewrite_fp16():
                 R.func_attr(
                     {
                         "Composite": "cutlass.attention_bias",
-                        "Primitive": 1,
                         "WorkspaceSize": T.int64(65536),
                     }
                 )
@@ -1737,6 +1736,13 @@ def test_rms_norm():
     dtype = "float16"
     mod = partition_for_cutlass(Module)
 
+    # TODO(@tvm-team): This is a temporary patch. Currently, the remaining packed 
function triggers an error since it is not scheduled.
+    # This is because RunCodegen does not support PrimFunc well yet,
+    # i.e., it does not remove the global symbol of the PrimFunc, which is no 
longer used,
+    # and thus the following DCE cannot remove it. Revisit when resolved.
+    with tvm.target.Target("cuda"):
+        mod = tvm.tir.transform.DefaultGPUSchedule()(mod)
+
     mod = relax.transform.RunCodegen()(mod)
 
     inp = np.random.randn(*data_shape).astype(dtype)
diff --git a/tests/python/relax/test_transform_fuse_ops_by_pattern.py 
b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
index 1352a52674..d95ed68c61 100644
--- a/tests/python/relax/test_transform_fuse_ops_by_pattern.py
+++ b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
@@ -78,7 +78,7 @@ class Conv2dReLU_composite_annotated:
             data2: R.Tensor((1, 64, 56, 56), dtype="float32"),
             weight12: R.Tensor((64, 64, 3, 3), dtype="float32"),
         ) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
-            R.func_attr({"Primitive": 1, "Composite": "dnnl.conv2d_relu"})
+            R.func_attr({"Composite": "dnnl.conv2d_relu"})
             with R.dataflow():
                 lv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(
                     data2,
@@ -394,7 +394,7 @@ class Conv2dx2_partitioned:
             data_1: R.Tensor((16, 32, 32, 16), dtype="float16"),
             weight1_1: R.Tensor((16, 3, 3, 16), dtype="float16"),
         ) -> R.Tensor((16, 32, 32, 16), dtype="float16"):
-            R.func_attr({"Composite": "cutlass.conv2d", "Primitive": 1})
+            R.func_attr({"Composite": "cutlass.conv2d"})
             with R.dataflow():
                 gv_2: R.Tensor((16, 32, 32, 16), dtype="float16") = 
R.nn.conv2d(
                     data_1,
@@ -529,6 +529,108 @@ def test_annotate_codegen():
     )
 
 
+def test_compare_with_merge_composite_path():
+    x = relax.Var("x", relax.TensorStructInfo([10, 10], "float32"))
+    y = relax.Var("y", relax.TensorStructInfo([10, 10], "float32"))
+    bb = relax.BlockBuilder()
+    with bb.function("main", [x, y]):
+        with bb.dataflow():
+            lv0 = bb.emit(relax.op.multiply(x, y))
+            gv = bb.emit_output(lv0)
+        bb.emit_func_output(gv)
+    mod = bb.get()
+    mod = relax.transform.FoldDataflowBlockOutput()(mod)
+
+    # Currently, we have two paths for BYOC.
+    # Path1. [FuseOpsByPattern(patterns, annotate_codegen=True), RunCodegen()]
+    # Path2. [FuseOpsByPattern(patterns, annotate_codegen=False), 
MergeCompositeFunctions(), RunCodegen()]
+    # For consistency, both paths should expose the same interface to RunCodegen().
+    # As each path uses a different naming convention due to differences in 
the algorithm,
+    # we compare each path against its own expected form rather than directly applying 
a structural equality check between the two paths.
+    patterns = [("cutlass.multiply", is_op("relax.multiply")(wildcard(), 
wildcard()))]
+    mod1 = relax.transform.FuseOpsByPattern(patterns, bind_constants=True, 
annotate_codegen=True)(
+        mod
+    )
+    assert tvm.relax.analysis.well_formed(mod1)
+
+    @I.ir_module
+    class Expected1:
+        @R.function
+        def fused_relax_multiply_cutlass(
+            x: R.Tensor((10, 10), dtype="float32"), y: R.Tensor((10, 10), 
dtype="float32")
+        ) -> R.Tensor((10, 10), dtype="float32"):
+            R.func_attr({"Codegen": "cutlass"})
+            # from tvm.script import relax as R
+
+            @R.function
+            def gv(
+                x_1: R.Tensor((10, 10), dtype="float32"),
+                y_1: R.Tensor((10, 10), dtype="float32"),
+            ) -> R.Tensor((10, 10), dtype="float32"):
+                R.func_attr({"Composite": "cutlass.multiply"})
+                with R.dataflow():
+                    gv_1: R.Tensor((10, 10), dtype="float32") = 
R.multiply(x_1, y_1)
+                    R.output(gv_1)
+                return gv_1
+
+            gv1: R.Tensor((10, 10), dtype="float32") = gv(x, y)
+            return gv1
+
+        @R.function
+        def main(
+            x: R.Tensor((10, 10), dtype="float32"), y: R.Tensor((10, 10), 
dtype="float32")
+        ) -> R.Tensor((10, 10), dtype="float32"):
+            cls = Expected1
+            with R.dataflow():
+                gv: R.Tensor((10, 10), dtype="float32") = 
cls.fused_relax_multiply_cutlass(x, y)
+                R.output(gv)
+            return gv
+
+    tvm.ir.assert_structural_equal(mod1, Expected1)
+
+    mod2 = relax.transform.FuseOpsByPattern(patterns, bind_constants=True, 
annotate_codegen=False)(
+        mod
+    )
+    mod2 = relax.transform.MergeCompositeFunctions()(mod2)
+    assert tvm.relax.analysis.well_formed(mod2)
+
+    @I.ir_module
+    class Expected2:
+        @R.function
+        def fused_relax_multiply1(
+            x: R.Tensor((10, 10), dtype="float32"), y: R.Tensor((10, 10), 
dtype="float32")
+        ) -> R.Tensor((10, 10), dtype="float32"):
+            R.func_attr({"Codegen": "cutlass"})
+            # from tvm.script import relax as R
+
+            @R.function
+            def gv(
+                x_1: R.Tensor((10, 10), dtype="float32"),
+                y_1: R.Tensor((10, 10), dtype="float32"),
+            ) -> R.Tensor((10, 10), dtype="float32"):
+                R.func_attr({"Composite": "cutlass.multiply"})
+                with R.dataflow():
+                    gv_1: R.Tensor((10, 10), dtype="float32") = 
R.multiply(x_1, y_1)
+                    R.output(gv_1)
+
+                return gv_1
+
+            gv_1: R.Tensor((10, 10), dtype="float32") = gv(x, y)
+            return gv_1
+
+        @R.function
+        def main(
+            x: R.Tensor((10, 10), dtype="float32"), y: R.Tensor((10, 10), 
dtype="float32")
+        ) -> R.Tensor((10, 10), dtype="float32"):
+            cls = Expected2
+            with R.dataflow():
+                gv: R.Tensor((10, 10), dtype="float32") = 
cls.fused_relax_multiply1(x, y)
+                R.output(gv)
+            return gv
+
+    tvm.ir.assert_structural_equal(mod2, Expected2)
+
+
 def test_multiple_entries_multiple_calls_same_extern():
     pat = make_fused_bias_activation_pattern("relax.nn.conv2d", 
with_bias=False, activation=None)
     check(Conv2dx2, [("cutlass.conv2d", pat)], Conv2dx2_partitioned, 
annotate_codegen=True)
diff --git a/tests/python/relax/test_transform_merge_composite_functions.py 
b/tests/python/relax/test_transform_merge_composite_functions.py
index d552266131..d56e1db564 100644
--- a/tests/python/relax/test_transform_merge_composite_functions.py
+++ b/tests/python/relax/test_transform_merge_composite_functions.py
@@ -100,49 +100,46 @@ class Conv2dReLUx2_merged:
     ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
         R.func_attr(
             {
-                "Primitive": 1,
                 "Codegen": "dnnl",
                 "global_symbol": 
"fused_relax_nn_conv2d_relax_nn_relu_relax_nn_conv2d_relax_nn_relu1",
             }
         )
-        with R.dataflow():
 
-            @R.function
-            def lv(
-                data11: R.Tensor((1, 64, 56, 56), dtype="float32"),
-                weight111: R.Tensor((64, 64, 3, 3), dtype="float32"),
-            ) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
-                R.func_attr({"Composite": "dnnl.conv2d_relu", "Primitive": 1})
-                with R.dataflow():
-                    lv1: R.Tensor((1, 64, 56, 56), dtype="float32") = 
R.nn.conv2d(
-                        data11,
-                        weight111,
-                        padding=[1, 1, 1, 1],
-                    )
-                    gv1: R.Tensor((1, 64, 56, 56), dtype="float32") = 
R.nn.relu(lv1)
-                    R.output(gv1)
-                return gv1
+        @R.function
+        def lv(
+            data11: R.Tensor((1, 64, 56, 56), dtype="float32"),
+            weight111: R.Tensor((64, 64, 3, 3), dtype="float32"),
+        ) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
+            R.func_attr({"Composite": "dnnl.conv2d_relu"})
+            with R.dataflow():
+                lv1: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(
+                    data11,
+                    weight111,
+                    padding=[1, 1, 1, 1],
+                )
+                gv1: R.Tensor((1, 64, 56, 56), dtype="float32") = 
R.nn.relu(lv1)
+                R.output(gv1)
+            return gv1
 
-            lv2: R.Tensor((1, 64, 56, 56), dtype="float32") = lv(data1, 
weight11)
+        lv2: R.Tensor((1, 64, 56, 56), dtype="float32") = lv(data1, weight11)
 
-            @R.function
-            def lv11(
-                conv1: R.Tensor((1, 64, 56, 56), dtype="float32"),
-                weight211: R.Tensor((64, 64, 3, 3), dtype="float32"),
-            ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
-                R.func_attr({"Composite": "dnnl.conv2d_relu", "Primitive": 1})
-                with R.dataflow():
-                    lv21: R.Tensor((1, 64, 54, 54), dtype="float32") = 
R.nn.conv2d(
-                        conv1,
-                        weight211,
-                        padding=[0, 0, 0, 0],
-                    )
-                    gv2: R.Tensor((1, 64, 54, 54), dtype="float32") = 
R.nn.relu(lv21)
-                    R.output(gv2)
-                return gv2
+        @R.function
+        def lv11(
+            conv1: R.Tensor((1, 64, 56, 56), dtype="float32"),
+            weight211: R.Tensor((64, 64, 3, 3), dtype="float32"),
+        ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+            R.func_attr({"Composite": "dnnl.conv2d_relu"})
+            with R.dataflow():
+                lv21: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.conv2d(
+                    conv1,
+                    weight211,
+                    padding=[0, 0, 0, 0],
+                )
+                gv2: R.Tensor((1, 64, 54, 54), dtype="float32") = 
R.nn.relu(lv21)
+                R.output(gv2)
+            return gv2
 
-            gv3: R.Tensor((1, 64, 54, 54), dtype="float32") = lv11(lv2, 
weight21)
-            R.output(gv3)
+        gv3: R.Tensor((1, 64, 54, 54), dtype="float32") = lv11(lv2, weight21)
         return gv3
 
 
@@ -222,82 +219,79 @@ class Diamond_merged:
         R.func_attr(
             {
                 "Codegen": "compiler_A",
-                "Primitive": 1,
                 "global_symbol": 
"fused_relax_nn_conv2d_relax_nn_relu_relax_nn_gelu_relax_add",
             }
         )
-        # block 0
-        with R.dataflow():
 
-            @R.function
-            def lv(
-                data1: R.Tensor((1, 64, 56, 56), dtype="float32"),
-                weight1: R.Tensor((64, 64, 3, 3), dtype="float32"),
-            ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
-                # function attr dict
-                R.func_attr({"Composite": "compiler_A.conv2d", "Primitive": 1})
-                # block 0
-                with R.dataflow():
-                    gv4: R.Tensor((1, 64, 54, 54), dtype="float32") = 
R.nn.conv2d(
-                        data1,
-                        weight1,
-                        strides=[1, 1],
-                        padding=[0, 0, 0, 0],
-                        dilation=[1, 1],
-                        groups=1,
-                        data_layout="NCHW",
-                        kernel_layout="OIHW",
-                        out_layout="NCHW",
-                        out_dtype="",
-                    )
-                    R.output(gv4)
-                return gv4
+        # block 0
+        @R.function
+        def lv(
+            data1: R.Tensor((1, 64, 56, 56), dtype="float32"),
+            weight1: R.Tensor((64, 64, 3, 3), dtype="float32"),
+        ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+            # function attr dict
+            R.func_attr({"Composite": "compiler_A.conv2d"})
+            # block 0
+            with R.dataflow():
+                gv4: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.conv2d(
+                    data1,
+                    weight1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NCHW",
+                    kernel_layout="OIHW",
+                    out_layout="NCHW",
+                    out_dtype="",
+                )
+                R.output(gv4)
+            return gv4
 
-            lv2: R.Tensor((1, 64, 54, 54), dtype="float32") = lv(data, weight)
+        lv2: R.Tensor((1, 64, 54, 54), dtype="float32") = lv(data, weight)
 
-            @R.function
-            def lv1(
-                lv11: R.Tensor((1, 64, 54, 54), dtype="float32")
-            ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
-                # function attr dict
-                R.func_attr({"Composite": "compiler_A.relu", "Primitive": 1})
-                # block 0
-                with R.dataflow():
-                    gv1: R.Tensor((1, 64, 54, 54), dtype="float32") = 
R.nn.relu(lv11)
-                    R.output(gv1)
-                return gv1
+        @R.function
+        def lv1(
+            lv11: R.Tensor((1, 64, 54, 54), dtype="float32")
+        ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+            # function attr dict
+            R.func_attr({"Composite": "compiler_A.relu"})
+            # block 0
+            with R.dataflow():
+                gv1: R.Tensor((1, 64, 54, 54), dtype="float32") = 
R.nn.relu(lv11)
+                R.output(gv1)
+            return gv1
 
-            lv3: R.Tensor((1, 64, 54, 54), dtype="float32") = lv1(lv2)
+        lv3: R.Tensor((1, 64, 54, 54), dtype="float32") = lv1(lv2)
 
-            @R.function
-            def lv21(
-                lv4: R.Tensor((1, 64, 54, 54), dtype="float32")
-            ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
-                # function attr dict
-                R.func_attr({"Composite": "compiler_A.gelu", "Primitive": 1})
-                # block 0
-                with R.dataflow():
-                    gv: R.Tensor((1, 64, 54, 54), dtype="float32") = 
R.nn.gelu(lv4)
-                    R.output(gv)
-                return gv
+        @R.function
+        def lv21(
+            lv4: R.Tensor((1, 64, 54, 54), dtype="float32")
+        ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+            # function attr dict
+            R.func_attr({"Composite": "compiler_A.gelu"})
+            # block 0
+            with R.dataflow():
+                gv: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.gelu(lv4)
+                R.output(gv)
+            return gv
 
-            lv41: R.Tensor((1, 64, 54, 54), dtype="float32") = lv21(lv2)
+        lv41: R.Tensor((1, 64, 54, 54), dtype="float32") = lv21(lv2)
 
-            @R.function
-            def lv31(
-                lv5: R.Tensor((1, 64, 54, 54), dtype="float32"),
-                gelu1: R.Tensor((1, 64, 54, 54), dtype="float32"),
-            ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
-                # function attr dict
-                R.func_attr({"Composite": "compiler_A.add", "Primitive": 1})
-                # block 0
-                with R.dataflow():
-                    gv3: R.Tensor((1, 64, 54, 54), dtype="float32") = 
R.add(lv5, gelu1)
-                    R.output(gv3)
-                return gv3
+        @R.function
+        def lv31(
+            lv5: R.Tensor((1, 64, 54, 54), dtype="float32"),
+            gelu1: R.Tensor((1, 64, 54, 54), dtype="float32"),
+        ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+            # function attr dict
+            R.func_attr({"Composite": "compiler_A.add"})
+            # block 0
+            with R.dataflow():
+                gv3: R.Tensor((1, 64, 54, 54), dtype="float32") = R.add(lv5, 
gelu1)
+                R.output(gv3)
+            return gv3
 
-            gv2: R.Tensor((1, 64, 54, 54), dtype="float32") = lv31(lv3, lv41)
-            R.output(gv2)
+        gv2: R.Tensor((1, 64, 54, 54), dtype="float32") = lv31(lv3, lv41)
         return gv2
 
     @R.function
@@ -408,65 +402,59 @@ class Diamond_cyclic_dep_merged:
     ):
         R.func_attr(
             {
-                "Primitive": 1,
                 "Codegen": "compiler_A",
                 "global_symbol": "fused_relax_nn_conv2d_relax_nn_relu",
             }
         )
-        with R.dataflow():
 
-            @R.function
-            def lv(
-                data1: R.Tensor((1, 64, 56, 56), dtype="float32"),
-                weight1: R.Tensor((64, 64, 3, 3), dtype="float32"),
-            ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
-                R.func_attr({"Composite": "compiler_A.conv2d", "Primitive": 1})
-                with R.dataflow():
-                    gv4: R.Tensor((1, 64, 54, 54), dtype="float32") = 
R.nn.conv2d(
-                        data1,
-                        weight1,
-                        padding=[0, 0, 0, 0],
-                    )
-                    R.output(gv4)
-                return gv4
+        @R.function
+        def lv(
+            data1: R.Tensor((1, 64, 56, 56), dtype="float32"),
+            weight1: R.Tensor((64, 64, 3, 3), dtype="float32"),
+        ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+            R.func_attr({"Composite": "compiler_A.conv2d"})
+            with R.dataflow():
+                gv4: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.conv2d(
+                    data1,
+                    weight1,
+                    padding=[0, 0, 0, 0],
+                )
+                R.output(gv4)
+            return gv4
 
-            gv: R.Tensor((1, 64, 54, 54), dtype="float32") = lv(data, weight)
+        gv: R.Tensor((1, 64, 54, 54), dtype="float32") = lv(data, weight)
 
-            @R.function
-            def lv1(
-                lv11: R.Tensor((1, 64, 54, 54), dtype="float32")
-            ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
-                R.func_attr({"Composite": "compiler_A.relu", "Primitive": 1})
-                with R.dataflow():
-                    gv1: R.Tensor((1, 64, 54, 54), dtype="float32") = 
R.nn.relu(lv11)
-                    R.output(gv1)
-                return gv1
+        @R.function
+        def lv1(
+            lv11: R.Tensor((1, 64, 54, 54), dtype="float32")
+        ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+            R.func_attr({"Composite": "compiler_A.relu"})
+            with R.dataflow():
+                gv1: R.Tensor((1, 64, 54, 54), dtype="float32") = 
R.nn.relu(lv11)
+                R.output(gv1)
+            return gv1
+
+        gv11: R.Tensor((1, 64, 54, 54), dtype="float32") = lv1(gv)
 
-            gv11: R.Tensor((1, 64, 54, 54), dtype="float32") = lv1(gv)
-            R.output(gv, gv11)
         return (gv, gv11)
 
     @R.function
     def fused_relax_nn_gelu1(
         lv2: R.Tensor((1, 64, 54, 54), dtype="float32")
     ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
-        R.func_attr(
-            {"Primitive": 1, "Codegen": "compiler_B", "global_symbol": 
"fused_relax_nn_gelu1"}
-        )
-        with R.dataflow():
+        R.func_attr({"Codegen": "compiler_B", "global_symbol": 
"fused_relax_nn_gelu1"})
 
-            @R.function
-            def lv21(
-                lv3: R.Tensor((1, 64, 54, 54), dtype="float32")
-            ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
-                R.func_attr({"Composite": "compiler_B.gelu", "Primitive": 1})
-                with R.dataflow():
-                    gv2: R.Tensor((1, 64, 54, 54), dtype="float32") = 
R.nn.gelu(lv3)
-                    R.output(gv2)
-                return gv2
+        @R.function
+        def lv21(
+            lv3: R.Tensor((1, 64, 54, 54), dtype="float32")
+        ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+            R.func_attr({"Composite": "compiler_B.gelu"})
+            with R.dataflow():
+                gv2: R.Tensor((1, 64, 54, 54), dtype="float32") = 
R.nn.gelu(lv3)
+                R.output(gv2)
+            return gv2
 
-            gv3: R.Tensor((1, 64, 54, 54), dtype="float32") = lv21(lv2)
-            R.output(gv3)
+        gv3: R.Tensor((1, 64, 54, 54), dtype="float32") = lv21(lv2)
         return gv3
 
     @R.function
@@ -474,22 +462,20 @@ class Diamond_cyclic_dep_merged:
         lv32: R.Tensor((1, 64, 54, 54), dtype="float32"),
         lv41: R.Tensor((1, 64, 54, 54), dtype="float32"),
     ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
-        R.func_attr({"Primitive": 1, "Codegen": "compiler_A", "global_symbol": 
"fused_relax_add1"})
-        with R.dataflow():
+        R.func_attr({"Codegen": "compiler_A", "global_symbol": 
"fused_relax_add1"})
 
-            @R.function
-            def lv33(
-                lv5: R.Tensor((1, 64, 54, 54), dtype="float32"),
-                gelu1: R.Tensor((1, 64, 54, 54), dtype="float32"),
-            ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
-                R.func_attr({"Composite": "compiler_A.add", "Primitive": 1})
-                with R.dataflow():
-                    gv31: R.Tensor((1, 64, 54, 54), dtype="float32") = 
R.add(lv5, gelu1)
-                    R.output(gv31)
-                return gv31
+        @R.function
+        def lv33(
+            lv5: R.Tensor((1, 64, 54, 54), dtype="float32"),
+            gelu1: R.Tensor((1, 64, 54, 54), dtype="float32"),
+        ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+            R.func_attr({"Composite": "compiler_A.add"})
+            with R.dataflow():
+                gv31: R.Tensor((1, 64, 54, 54), dtype="float32") = R.add(lv5, 
gelu1)
+                R.output(gv31)
+            return gv31
 
-            gv6: R.Tensor((1, 64, 54, 54), dtype="float32") = lv33(lv32, lv41)
-            R.output(gv6)
+        gv6: R.Tensor((1, 64, 54, 54), dtype="float32") = lv33(lv32, lv41)
         return gv6
 
 
@@ -550,53 +536,50 @@ class MultipleProducers_merged:
         R.func_attr(
             {
                 "Codegen": "compiler_A",
-                "Primitive": 1,
                 "global_symbol": 
"fused_relax_nn_relu_relax_nn_gelu_relax_nn_relu_relax_nn_gelu_relax_add",
             }
         )
-        # block 0
-        with R.dataflow():
 
-            @R.function
-            def lv(x11: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), 
dtype="float32"):
-                # function attr dict
-                R.func_attr({"Composite": "compiler_A.relu", "Primitive": 1})
-                # block 0
-                with R.dataflow():
-                    gv2: R.Tensor((10,), dtype="float32") = R.nn.relu(x11)
-                    R.output(gv2)
-                return gv2
+        # block 0
+        @R.function
+        def lv(x11: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), 
dtype="float32"):
+            # function attr dict
+            R.func_attr({"Composite": "compiler_A.relu"})
+            # block 0
+            with R.dataflow():
+                gv2: R.Tensor((10,), dtype="float32") = R.nn.relu(x11)
+                R.output(gv2)
+            return gv2
 
-            lv1: R.Tensor((10,), dtype="float32") = lv(x1)
+        lv1: R.Tensor((10,), dtype="float32") = lv(x1)
 
-            @R.function
-            def lv11(x21: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), 
dtype="float32"):
-                # function attr dict
-                R.func_attr({"Composite": "compiler_A.gelu", "Primitive": 1})
-                # block 0
-                with R.dataflow():
-                    gv3: R.Tensor((10,), dtype="float32") = R.nn.gelu(x21)
-                    R.output(gv3)
-                return gv3
+        @R.function
+        def lv11(x21: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), 
dtype="float32"):
+            # function attr dict
+            R.func_attr({"Composite": "compiler_A.gelu"})
+            # block 0
+            with R.dataflow():
+                gv3: R.Tensor((10,), dtype="float32") = R.nn.gelu(x21)
+                R.output(gv3)
+            return gv3
 
-            lv2: R.Tensor((10,), dtype="float32") = lv11(x2)
-            lv3: R.Tensor((10,), dtype="float32") = lv(lv1)
-            lv4: R.Tensor((10,), dtype="float32") = lv11(lv2)
+        lv2: R.Tensor((10,), dtype="float32") = lv11(x2)
+        lv3: R.Tensor((10,), dtype="float32") = lv(lv1)
+        lv4: R.Tensor((10,), dtype="float32") = lv11(lv2)
 
-            @R.function
-            def lv21(
-                lv5: R.Tensor((10,), dtype="float32"), gelu1: R.Tensor((10,), 
dtype="float32")
-            ) -> R.Tensor((10,), dtype="float32"):
-                # function attr dict
-                R.func_attr({"Composite": "compiler_A.add", "Primitive": 1})
-                # block 0
-                with R.dataflow():
-                    gv: R.Tensor((10,), dtype="float32") = R.add(lv5, gelu1)
-                    R.output(gv)
-                return gv
+        @R.function
+        def lv21(
+            lv5: R.Tensor((10,), dtype="float32"), gelu1: R.Tensor((10,), 
dtype="float32")
+        ) -> R.Tensor((10,), dtype="float32"):
+            # function attr dict
+            R.func_attr({"Composite": "compiler_A.add"})
+            # block 0
+            with R.dataflow():
+                gv: R.Tensor((10,), dtype="float32") = R.add(lv5, gelu1)
+                R.output(gv)
+            return gv
 
-            gv1: R.Tensor((10,), dtype="float32") = lv21(lv3, lv4)
-            R.output(gv1)
+        gv1: R.Tensor((10,), dtype="float32") = lv21(lv3, lv4)
         return gv1
 
     @R.function
@@ -675,24 +658,20 @@ class MultipleProducersCyclic_merged:
         x11: R.Tensor((10,), dtype="float32")
     ) -> R.Tensor((10,), dtype="float32"):
         # function attr dict
-        R.func_attr(
-            {"Codegen": "compiler_A", "Primitive": 1, "global_symbol": 
"fused_relax_nn_relu1"}
-        )
-        # block 0
-        with R.dataflow():
+        R.func_attr({"Codegen": "compiler_A", "global_symbol": 
"fused_relax_nn_relu1"})
 
-            @R.function
-            def lv1(x111: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), 
dtype="float32"):
-                # function attr dict
-                R.func_attr({"Composite": "compiler_A.relu", "Primitive": 1})
-                # block 0
-                with R.dataflow():
-                    gv2: R.Tensor((10,), dtype="float32") = R.nn.relu(x111)
-                    R.output(gv2)
-                return gv2
+        # block 0
+        @R.function
+        def lv1(x111: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), 
dtype="float32"):
+            # function attr dict
+            R.func_attr({"Composite": "compiler_A.relu"})
+            # block 0
+            with R.dataflow():
+                gv2: R.Tensor((10,), dtype="float32") = R.nn.relu(x111)
+                R.output(gv2)
+            return gv2
 
-            gv1: R.Tensor((10,), dtype="float32") = lv1(x11)
-            R.output(gv1)
+        gv1: R.Tensor((10,), dtype="float32") = lv1(x11)
         return gv1
 
     @R.function
@@ -703,39 +682,37 @@ class MultipleProducersCyclic_merged:
         R.func_attr(
             {
                 "Codegen": "compiler_A",
-                "Primitive": 1,
                 "global_symbol": "fused_relax_nn_gelu_relax_add",
             }
         )
         # block 0
-        with R.dataflow():
 
-            @R.function
-            def lv12(x21: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), 
dtype="float32"):
-                # function attr dict
-                R.func_attr({"Composite": "compiler_A.gelu", "Primitive": 1})
-                # block 0
-                with R.dataflow():
-                    gv3: R.Tensor((10,), dtype="float32") = R.nn.gelu(x21)
-                    R.output(gv3)
-                return gv3
+        @R.function
+        def lv12(x21: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), 
dtype="float32"):
+            # function attr dict
+            R.func_attr({"Composite": "compiler_A.gelu"})
+            # block 0
+            with R.dataflow():
+                gv3: R.Tensor((10,), dtype="float32") = R.nn.gelu(x21)
+                R.output(gv3)
+            return gv3
 
-            lv3: R.Tensor((10,), dtype="float32") = lv12(lv21)
+        lv3: R.Tensor((10,), dtype="float32") = lv12(lv21)
 
-            @R.function
-            def lv22(
-                lv4: R.Tensor((10,), dtype="float32"), gelu1: R.Tensor((10,), 
dtype="float32")
-            ) -> R.Tensor((10,), dtype="float32"):
-                # function attr dict
-                R.func_attr({"Composite": "compiler_A.add", "Primitive": 1})
-                # block 0
-                with R.dataflow():
-                    gv4: R.Tensor((10,), dtype="float32") = R.add(lv4, gelu1)
-                    R.output(gv4)
-                return gv4
+        @R.function
+        def lv22(
+            lv4: R.Tensor((10,), dtype="float32"), gelu1: R.Tensor((10,), 
dtype="float32")
+        ) -> R.Tensor((10,), dtype="float32"):
+            # function attr dict
+            R.func_attr({"Composite": "compiler_A.add"})
+            # block 0
+            with R.dataflow():
+                gv4: R.Tensor((10,), dtype="float32") = R.add(lv4, gelu1)
+                R.output(gv4)
+            return gv4
+
+        gv5: R.Tensor((10,), dtype="float32") = lv22(lv11, lv3)
 
-            gv5: R.Tensor((10,), dtype="float32") = lv22(lv11, lv3)
-            R.output(gv5)
         return gv5
 
 
@@ -800,36 +777,33 @@ class MergeCompilerRegionsExampleRef:
     ) -> R.Tuple(R.Tensor((10,), dtype="float32"), R.Tensor((10,), 
dtype="float32")):
         R.func_attr(
             {
-                "Primitive": 1,
                 "Codegen": "compiler_A",
                 "global_symbol": "fused_relax_add_relax_add_relax_nn_relu",
             }
         )
-        with R.dataflow():
 
-            @R.function
-            def lv1(
-                x11: R.Tensor((10,), dtype="float32"), x21: R.Tensor((10,), 
dtype="float32")
-            ) -> R.Tensor((10,), dtype="float32"):
-                R.func_attr({"Primitive": 1, "Composite": "compiler_A.add"})
-                with R.dataflow():
-                    gv: R.Tensor((10,), dtype="float32") = R.add(x11, x21)
-                    R.output(gv)
-                return gv
+        @R.function
+        def lv1(
+            x11: R.Tensor((10,), dtype="float32"), x21: R.Tensor((10,), 
dtype="float32")
+        ) -> R.Tensor((10,), dtype="float32"):
+            R.func_attr({"Composite": "compiler_A.add"})
+            with R.dataflow():
+                gv: R.Tensor((10,), dtype="float32") = R.add(x11, x21)
+                R.output(gv)
+            return gv
 
-            lv2: R.Tensor((10,), dtype="float32") = lv1(x1, x2)
-            gv1: R.Tensor((10,), dtype="float32") = lv1(lv2, lv)
+        lv2: R.Tensor((10,), dtype="float32") = lv1(x1, x2)
+        gv1: R.Tensor((10,), dtype="float32") = lv1(lv2, lv)
 
-            @R.function
-            def lv11(add2: R.Tensor((10,), dtype="float32")) -> 
R.Tensor((10,), dtype="float32"):
-                R.func_attr({"Primitive": 1, "Composite": "compiler_A.relu"})
-                with R.dataflow():
-                    gv2: R.Tensor((10,), dtype="float32") = R.nn.relu(add2)
-                    R.output(gv2)
-                return gv2
+        @R.function
+        def lv11(add2: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), 
dtype="float32"):
+            R.func_attr({"Composite": "compiler_A.relu"})
+            with R.dataflow():
+                gv2: R.Tensor((10,), dtype="float32") = R.nn.relu(add2)
+                R.output(gv2)
+            return gv2
 
-            gv11: R.Tensor((10,), dtype="float32") = lv11(gv1)
-            R.output(gv1, gv11)
+        gv11: R.Tensor((10,), dtype="float32") = lv11(gv1)
         return (gv1, gv11)
 
     @R.function
@@ -838,56 +812,50 @@ class MergeCompilerRegionsExampleRef:
     ) -> R.Tensor((10,), dtype="float32"):
         R.func_attr(
             {
-                "Primitive": 1,
                 "Codegen": "compiler_A",
                 "global_symbol": "fused_relax_add_relax_nn_relu",
             }
         )
-        with R.dataflow():
 
-            @R.function
-            def lv21(
-                x11: R.Tensor((10,), dtype="float32"), x21: R.Tensor((10,), 
dtype="float32")
-            ) -> R.Tensor((10,), dtype="float32"):
-                R.func_attr({"Primitive": 1, "Composite": "compiler_A.add"})
-                with R.dataflow():
-                    gv: R.Tensor((10,), dtype="float32") = R.add(x11, x21)
-                    R.output(gv)
-                return gv
+        @R.function
+        def lv21(
+            x11: R.Tensor((10,), dtype="float32"), x21: R.Tensor((10,), 
dtype="float32")
+        ) -> R.Tensor((10,), dtype="float32"):
+            R.func_attr({"Composite": "compiler_A.add"})
+            with R.dataflow():
+                gv: R.Tensor((10,), dtype="float32") = R.add(x11, x21)
+                R.output(gv)
+            return gv
 
-            lv22: R.Tensor((10,), dtype="float32") = lv21(lv12, lv3)
+        lv22: R.Tensor((10,), dtype="float32") = lv21(lv12, lv3)
 
-            @R.function
-            def lv31(add2: R.Tensor((10,), dtype="float32")) -> 
R.Tensor((10,), dtype="float32"):
-                R.func_attr({"Primitive": 1, "Composite": "compiler_A.relu"})
-                with R.dataflow():
-                    gv2: R.Tensor((10,), dtype="float32") = R.nn.relu(add2)
-                    R.output(gv2)
-                return gv2
+        @R.function
+        def lv31(add2: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), 
dtype="float32"):
+            R.func_attr({"Composite": "compiler_A.relu"})
+            with R.dataflow():
+                gv2: R.Tensor((10,), dtype="float32") = R.nn.relu(add2)
+                R.output(gv2)
+            return gv2
+
+        gv3: R.Tensor((10,), dtype="float32") = lv31(lv22)
 
-            gv3: R.Tensor((10,), dtype="float32") = lv31(lv22)
-            R.output(gv3)
         return gv3
 
     @R.function
     def fused_relax_nn_gelu1(
         x3: R.Tensor((10,), dtype="float32")
     ) -> R.Tensor((10,), dtype="float32"):
-        R.func_attr(
-            {"Primitive": 1, "Codegen": "compiler_B", "global_symbol": 
"fused_relax_nn_gelu1"}
-        )
-        with R.dataflow():
+        R.func_attr({"Codegen": "compiler_B", "global_symbol": 
"fused_relax_nn_gelu1"})
 
-            @R.function
-            def lv4(x31: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), 
dtype="float32"):
-                R.func_attr({"Primitive": 1, "Composite": "compiler_B.gelu"})
-                with R.dataflow():
-                    gv4: R.Tensor((10,), dtype="float32") = R.nn.gelu(x31)
-                    R.output(gv4)
-                return gv4
+        @R.function
+        def lv4(x31: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), 
dtype="float32"):
+            R.func_attr({"Composite": "compiler_B.gelu"})
+            with R.dataflow():
+                gv4: R.Tensor((10,), dtype="float32") = R.nn.gelu(x31)
+                R.output(gv4)
+            return gv4
 
-            gv5: R.Tensor((10,), dtype="float32") = lv4(x3)
-            R.output(gv5)
+        gv5: R.Tensor((10,), dtype="float32") = lv4(x3)
         return gv5
 
     @R.function
@@ -961,28 +929,24 @@ class ModuleWithNonComposite_ref:
         data1: R.Tensor((1, 64, 56, 56), dtype="float32"),
         weight1: R.Tensor((64, 64, 3, 3), dtype="float32"),
     ) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
-        R.func_attr(
-            {"Codegen": "tensorrt", "Primitive": 1, "global_symbol": 
"fused_relax_nn_conv2d1"}
-        )
-        with R.dataflow():
+        R.func_attr({"Codegen": "tensorrt", "global_symbol": 
"fused_relax_nn_conv2d1"})
 
-            @R.function
-            def lv1(
-                data2: R.Tensor((1, 64, 56, 56), dtype="float32"),
-                weight2: R.Tensor((64, 64, 3, 3), dtype="float32"),
-            ) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
-                R.func_attr({"Composite": "tensorrt.conv2d", "Primitive": 1})
-                with R.dataflow():
-                    gv: R.Tensor((1, 64, 56, 56), dtype="float32") = 
R.nn.conv2d(
-                        data2,
-                        weight2,
-                        padding=[1, 1, 1, 1],
-                    )
-                    R.output(gv)
-                return gv
+        @R.function
+        def lv1(
+            data2: R.Tensor((1, 64, 56, 56), dtype="float32"),
+            weight2: R.Tensor((64, 64, 3, 3), dtype="float32"),
+        ) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
+            R.func_attr({"Composite": "tensorrt.conv2d"})
+            with R.dataflow():
+                gv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(
+                    data2,
+                    weight2,
+                    padding=[1, 1, 1, 1],
+                )
+                R.output(gv)
+            return gv
 
-            gv1: R.Tensor((1, 64, 56, 56), dtype="float32") = lv1(data1, 
weight1)
-            R.output(gv1)
+        gv1: R.Tensor((1, 64, 56, 56), dtype="float32") = lv1(data1, weight1)
         return gv1
 
 
@@ -1120,41 +1084,38 @@ def test_reshape():
             R.func_attr(
                 {
                     "Codegen": "tensorrt",
-                    "Primitive": 1,
                     "global_symbol": "fused_relax_reshape_relax_matmul",
                 }
             )
-            with R.dataflow():
-                # from tvm.script import relax as R
-
-                @R.function
-                def lv_1(
-                    inp_0_1: R.Tensor((1, 1, 28, 28), dtype="float32"), 
param_0_1: R.Shape([1, 784])
-                ) -> R.Tensor((1, 784), dtype="float32"):
-                    R.func_attr({"Composite": "tensorrt.reshape", "Primitive": 
1})
-                    with R.dataflow():
-                        gv: R.Tensor((1, 784), dtype="float32") = 
R.reshape(inp_0_1, param_0_1)
-                        R.output(gv)
-                    return gv
-
-                lv_1: R.Tensor((1, 784), dtype="float32") = lv_1(inp_0, 
param_0)
-
-                @R.function
-                def lv1_1_1(
-                    lv_2: R.Tensor((1, 784), dtype="float32"),
-                    lv1_2: R.Tensor((784, 512), dtype="float32"),
-                ) -> R.Tensor((1, 512), dtype="float32"):
-                    R.func_attr({"Composite": "tensorrt.matmul", "Primitive": 
1})
-                    with R.dataflow():
-                        gv: R.Tensor((1, 512), dtype="float32") = R.matmul(
-                            lv_2, lv1_2, out_dtype="float32"
-                        )
-                        R.output(gv)
-                    return gv
-
-                lv_2: R.Tensor((1, 512), dtype="float32") = lv1_1_1(lv_1, lv1)
-                gv: R.Tensor((1, 512), dtype="float32") = lv_2
-                R.output(gv)
+            # from tvm.script import relax as R
+
+            @R.function
+            def lv_1(
+                inp_0_1: R.Tensor((1, 1, 28, 28), dtype="float32"), param_0_1: 
R.Shape([1, 784])
+            ) -> R.Tensor((1, 784), dtype="float32"):
+                R.func_attr({"Composite": "tensorrt.reshape"})
+                with R.dataflow():
+                    gv: R.Tensor((1, 784), dtype="float32") = 
R.reshape(inp_0_1, param_0_1)
+                    R.output(gv)
+                return gv
+
+            lv_1: R.Tensor((1, 784), dtype="float32") = lv_1(inp_0, param_0)
+
+            @R.function
+            def lv1_1_1(
+                lv_2: R.Tensor((1, 784), dtype="float32"),
+                lv1_2: R.Tensor((784, 512), dtype="float32"),
+            ) -> R.Tensor((1, 512), dtype="float32"):
+                R.func_attr({"Composite": "tensorrt.matmul"})
+                with R.dataflow():
+                    gv: R.Tensor((1, 512), dtype="float32") = R.matmul(
+                        lv_2, lv1_2, out_dtype="float32"
+                    )
+                    R.output(gv)
+                return gv
+
+            lv_2: R.Tensor((1, 512), dtype="float32") = lv1_1_1(lv_1, lv1)
+            gv: R.Tensor((1, 512), dtype="float32") = lv_2
             return gv
 
         @R.function

Reply via email to