This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 1afbf20129 [Unity][BYOC] Unify the interface between
`FuseOpsByPattern(.., annotate_codegen=True)` and `MergeCompositeFunctions()`
(#15491)
1afbf20129 is described below
commit 1afbf20129647a35d108152fc6789bc1d029cda5
Author: Sunghyun Park <[email protected]>
AuthorDate: Tue Aug 8 01:53:17 2023 -0700
[Unity][BYOC] Unify the interface between `FuseOpsByPattern(..,
annotate_codegen=True)` and `MergeCompositeFunctions()` (#15491)
* feat: consistent interface
* fix comment
* fix
* bugfix
* fix
---
include/tvm/relax/utils.h | 5 +
src/relax/transform/fuse_ops.cc | 14 +-
src/relax/transform/merge_composite_functions.cc | 22 +-
src/relax/transform/to_non_dataflow.cc | 1 +
src/relax/transform/utils.h | 7 +
tests/python/relax/test_codegen_cutlass.py | 8 +-
.../relax/test_transform_fuse_ops_by_pattern.py | 106 +++-
.../test_transform_merge_composite_functions.py | 649 ++++++++++-----------
8 files changed, 443 insertions(+), 369 deletions(-)
diff --git a/include/tvm/relax/utils.h b/include/tvm/relax/utils.h
index a1f587e14e..1a6d5d4a52 100644
--- a/include/tvm/relax/utils.h
+++ b/include/tvm/relax/utils.h
@@ -105,6 +105,11 @@ TVM_DLL bool IsImpureCall(const Call& call);
*/
TVM_DLL Function CopyWithNewVars(Function func);
+/*!
+ * \brief Transform all dataflow structure to non-dataflow version.
+ */
+Expr ToNonDataflow(const Expr& e);
+
} // namespace relax
} // namespace tvm
diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index f4e7e00f30..2c042db4fd 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -1199,6 +1199,7 @@ class CompositeFunctionAnnotator : public ExprMutator {
auto gsymbol = gvar->name_hint + "_" + codegen_name;
new_func = WithAttrs(new_func,
{{attr::kCodegen, codegen_name},
{tvm::attr::kGlobalSymbol, gsymbol}});
+ new_func = WithoutAttr(std::move(new_func),
tvm::relax::attr::kPrimitive);
builder_->GetContextIRModule()->Remove(GetRef<GlobalVar>(gvar));
auto new_gvar = builder_->AddFunction(new_func, gsymbol);
gvar_map_[gvar] = new_gvar;
@@ -1209,8 +1210,10 @@ class CompositeFunctionAnnotator : public ExprMutator {
}
Expr VisitExpr_(const FunctionNode* func_node) final {
- auto f_inner = ExprMutator::VisitExpr_(func_node);
+ Function f_inner = Downcast<Function>(ExprMutator::VisitExpr_(func_node));
auto composite_name = func_node->GetAttr<String>(attr::kComposite);
+
+ f_inner = WithoutAttr(std::move(f_inner), tvm::relax::attr::kPrimitive);
ICHECK(composite_name);
Array<Var> param_vars;
@@ -1224,17 +1227,10 @@ class CompositeFunctionAnnotator : public ExprMutator {
// pure if the inner func is pure (no need to force purity if it's forced
for the inner func)
return Function(param_vars, Call(f_inner, params),
func_node->ret_struct_info,
- Downcast<Function>(f_inner)->is_pure);
+ f_inner->is_pure);
}
private:
- String GetCodegenName(const std::string& composite_name) {
- auto delim_pos = composite_name.find(".");
- ICHECK(delim_pos != std::string::npos) << "The pattern name for a
composite function should "
- "start with a compiler name
followed by period.";
- return composite_name.substr(0, delim_pos);
- }
-
/*! \brief A map from old global vars to their replacements. */
std::unordered_map<const GlobalVarNode*, GlobalVar> gvar_map_;
};
diff --git a/src/relax/transform/merge_composite_functions.cc
b/src/relax/transform/merge_composite_functions.cc
index 0bc92ba923..365b34c601 100644
--- a/src/relax/transform/merge_composite_functions.cc
+++ b/src/relax/transform/merge_composite_functions.cc
@@ -39,7 +39,7 @@
* Correct partitioning:
*
* O O
- * / \ / \
+ * / \ / \
* O X --> O + + X
* \ / \ /
* O O
@@ -167,13 +167,6 @@ class CompositeGroupsBuilder : public
MemoizedExprTranslator<Group*> {
}
private:
- String GetCodegenName(const std::string& composite_name) {
- auto delim_pos = composite_name.find(".");
- ICHECK(delim_pos != std::string::npos) << "The pattern name for a
composite function should "
- "start with a compiler name
followed by period.";
- return composite_name.substr(0, delim_pos);
- }
-
Optional<String> GetCodegenName(const Expr& callee) {
auto const* gvar = callee.as<GlobalVarNode>();
if (!gvar) {
@@ -186,7 +179,7 @@ class CompositeGroupsBuilder : public
MemoizedExprTranslator<Group*> {
return NullOpt;
}
- return GetCodegenName(composite_name_opt.value());
+ return relax::GetCodegenName(composite_name_opt.value());
}
Optional<String> GetCodegenName(Group* group) {
@@ -298,7 +291,7 @@ class CompositeInliner : public ExprMutator {
Function Run(Function func) {
inlined_functions_ = Map<Function, Function>();
- auto new_body = VisitExpr(func->body);
+ auto new_body = VisitExpr(ToNonDataflow(func->body));
auto new_func = Function(func->params, new_body, func->ret_struct_info,
func->is_pure,
func->attrs, func->span);
return new_func;
@@ -308,10 +301,11 @@ class CompositeInliner : public ExprMutator {
if (call->op->IsInstance<GlobalVarNode>()) {
auto gvar = Downcast<GlobalVar>(call->op);
auto func = Downcast<Function>(mod_->Lookup(gvar));
-
if (func->GetAttr<String>(attr::kComposite)) {
if (!inlined_functions_.count(func)) {
- inlined_functions_.Set(func, CopyWithNewVars(func));
+ auto new_func = CopyWithNewVars(func);
+ new_func = WithoutAttr(new_func, tvm::relax::attr::kPrimitive);
+ inlined_functions_.Set(func, new_func);
}
return Call(inlined_functions_[func], call->args);
}
@@ -340,11 +334,13 @@ IRModule MergeCompositeFunctions(IRModule mod) {
if (func->GetAttr<String>(attr::kCodegen)) {
auto new_func = inliner.Run(Downcast<Function>(func));
new_func = WithAttr(new_func, tvm::attr::kGlobalSymbol, gvar->name_hint);
+ new_func = WithoutAttr(std::move(new_func),
tvm::relax::attr::kPrimitive);
to_update.emplace_back(gvar, new_func);
}
}
+
for (const auto& [gvar, func] : to_update) {
- new_mod->Update(gvar, func);
+ new_mod->Update(gvar, Downcast<Function>(func));
}
// TODO(@tvm-team): Implicit pass dependency. Revisit when we have a better
way to handle this.
return DeadCodeElimination(new_mod, {"main"});
diff --git a/src/relax/transform/to_non_dataflow.cc
b/src/relax/transform/to_non_dataflow.cc
index db2e9d7ee5..5c790e9a73 100644
--- a/src/relax/transform/to_non_dataflow.cc
+++ b/src/relax/transform/to_non_dataflow.cc
@@ -24,6 +24,7 @@
#include <tvm/relax/struct_info.h>
#include <tvm/relax/transform.h>
#include <tvm/relax/type.h>
+#include <tvm/relax/utils.h>
#include <tvm/tir/op.h>
namespace tvm {
diff --git a/src/relax/transform/utils.h b/src/relax/transform/utils.h
index 489a36a5a4..3d40a565bd 100644
--- a/src/relax/transform/utils.h
+++ b/src/relax/transform/utils.h
@@ -381,6 +381,13 @@ inline Array<Integer> GetOrderedPositiveAxes(const
Array<Integer>& axes, int ndi
return support::AsArray<int64_t, Integer>(ret);
}
+inline String GetCodegenName(const std::string& composite_name) {
+ auto delim_pos = composite_name.find(".");
+ ICHECK(delim_pos != std::string::npos) << "The pattern name for a composite
function should "
+ "start with a compiler name
followed by period.";
+ return composite_name.substr(0, delim_pos);
+}
+
} // namespace relax
} // namespace tvm
diff --git a/tests/python/relax/test_codegen_cutlass.py
b/tests/python/relax/test_codegen_cutlass.py
index d19189ff34..952036584f 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -1220,7 +1220,6 @@ def test_attention_rewrite_fp16():
R.func_attr(
{
"Composite": "cutlass.attention_bias",
- "Primitive": 1,
"WorkspaceSize": T.int64(65536),
}
)
@@ -1737,6 +1736,13 @@ def test_rms_norm():
dtype = "float16"
mod = partition_for_cutlass(Module)
+ # TODO(@tvm-team): This is a temporary patch. Currently, the remaining
packed function triggers an error since it is not scheduled.
+ # This is because RunCodegen does not support PrimFunc well yet.
+ # i.e., it does remove the global symbol of PrimFunc, which would be no
longer used,
+ # and thus, the following DCE cannot remove this. Revisit when resolved.
+ with tvm.target.Target("cuda"):
+ mod = tvm.tir.transform.DefaultGPUSchedule()(mod)
+
mod = relax.transform.RunCodegen()(mod)
inp = np.random.randn(*data_shape).astype(dtype)
diff --git a/tests/python/relax/test_transform_fuse_ops_by_pattern.py
b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
index 1352a52674..d95ed68c61 100644
--- a/tests/python/relax/test_transform_fuse_ops_by_pattern.py
+++ b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
@@ -78,7 +78,7 @@ class Conv2dReLU_composite_annotated:
data2: R.Tensor((1, 64, 56, 56), dtype="float32"),
weight12: R.Tensor((64, 64, 3, 3), dtype="float32"),
) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
- R.func_attr({"Primitive": 1, "Composite": "dnnl.conv2d_relu"})
+ R.func_attr({"Composite": "dnnl.conv2d_relu"})
with R.dataflow():
lv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(
data2,
@@ -394,7 +394,7 @@ class Conv2dx2_partitioned:
data_1: R.Tensor((16, 32, 32, 16), dtype="float16"),
weight1_1: R.Tensor((16, 3, 3, 16), dtype="float16"),
) -> R.Tensor((16, 32, 32, 16), dtype="float16"):
- R.func_attr({"Composite": "cutlass.conv2d", "Primitive": 1})
+ R.func_attr({"Composite": "cutlass.conv2d"})
with R.dataflow():
gv_2: R.Tensor((16, 32, 32, 16), dtype="float16") =
R.nn.conv2d(
data_1,
@@ -529,6 +529,108 @@ def test_annotate_codegen():
)
+def test_compare_with_merge_composite_path():
+ x = relax.Var("x", relax.TensorStructInfo([10, 10], "float32"))
+ y = relax.Var("y", relax.TensorStructInfo([10, 10], "float32"))
+ bb = relax.BlockBuilder()
+ with bb.function("main", [x, y]):
+ with bb.dataflow():
+ lv0 = bb.emit(relax.op.multiply(x, y))
+ gv = bb.emit_output(lv0)
+ bb.emit_func_output(gv)
+ mod = bb.get()
+ mod = relax.transform.FoldDataflowBlockOutput()(mod)
+
+ # Currently, we have two paths for BYOC.
+ # Path1. [FuseOpsByPattern(patterns, annotate_codegen=True), RunCodegen()]
+ # Path2. [FuseOpsByPattern(patterns, annotate_codegen=False),
MergeCompositeFunctions(), RunCodegen()]
+ # For consistency, both paths should have same interface with RunCodegen().
+ # As each path has different naming convention due to the difference in
the algorithm,
+ # we compare with expected form of each path rather than directly applying
structural equality check between two paths.
+ patterns = [("cutlass.multiply", is_op("relax.multiply")(wildcard(),
wildcard()))]
+ mod1 = relax.transform.FuseOpsByPattern(patterns, bind_constants=True,
annotate_codegen=True)(
+ mod
+ )
+ assert tvm.relax.analysis.well_formed(mod1)
+
+ @I.ir_module
+ class Expected1:
+ @R.function
+ def fused_relax_multiply_cutlass(
+ x: R.Tensor((10, 10), dtype="float32"), y: R.Tensor((10, 10),
dtype="float32")
+ ) -> R.Tensor((10, 10), dtype="float32"):
+ R.func_attr({"Codegen": "cutlass"})
+ # from tvm.script import relax as R
+
+ @R.function
+ def gv(
+ x_1: R.Tensor((10, 10), dtype="float32"),
+ y_1: R.Tensor((10, 10), dtype="float32"),
+ ) -> R.Tensor((10, 10), dtype="float32"):
+ R.func_attr({"Composite": "cutlass.multiply"})
+ with R.dataflow():
+ gv_1: R.Tensor((10, 10), dtype="float32") =
R.multiply(x_1, y_1)
+ R.output(gv_1)
+ return gv_1
+
+ gv1: R.Tensor((10, 10), dtype="float32") = gv(x, y)
+ return gv1
+
+ @R.function
+ def main(
+ x: R.Tensor((10, 10), dtype="float32"), y: R.Tensor((10, 10),
dtype="float32")
+ ) -> R.Tensor((10, 10), dtype="float32"):
+ cls = Expected1
+ with R.dataflow():
+ gv: R.Tensor((10, 10), dtype="float32") =
cls.fused_relax_multiply_cutlass(x, y)
+ R.output(gv)
+ return gv
+
+ tvm.ir.assert_structural_equal(mod1, Expected1)
+
+ mod2 = relax.transform.FuseOpsByPattern(patterns, bind_constants=True,
annotate_codegen=False)(
+ mod
+ )
+ mod2 = relax.transform.MergeCompositeFunctions()(mod2)
+ assert tvm.relax.analysis.well_formed(mod2)
+
+ @I.ir_module
+ class Expected2:
+ @R.function
+ def fused_relax_multiply1(
+ x: R.Tensor((10, 10), dtype="float32"), y: R.Tensor((10, 10),
dtype="float32")
+ ) -> R.Tensor((10, 10), dtype="float32"):
+ R.func_attr({"Codegen": "cutlass"})
+ # from tvm.script import relax as R
+
+ @R.function
+ def gv(
+ x_1: R.Tensor((10, 10), dtype="float32"),
+ y_1: R.Tensor((10, 10), dtype="float32"),
+ ) -> R.Tensor((10, 10), dtype="float32"):
+ R.func_attr({"Composite": "cutlass.multiply"})
+ with R.dataflow():
+ gv_1: R.Tensor((10, 10), dtype="float32") =
R.multiply(x_1, y_1)
+ R.output(gv_1)
+
+ return gv_1
+
+ gv_1: R.Tensor((10, 10), dtype="float32") = gv(x, y)
+ return gv_1
+
+ @R.function
+ def main(
+ x: R.Tensor((10, 10), dtype="float32"), y: R.Tensor((10, 10),
dtype="float32")
+ ) -> R.Tensor((10, 10), dtype="float32"):
+ cls = Expected2
+ with R.dataflow():
+ gv: R.Tensor((10, 10), dtype="float32") =
cls.fused_relax_multiply1(x, y)
+ R.output(gv)
+ return gv
+
+ tvm.ir.assert_structural_equal(mod2, Expected2)
+
+
def test_multiple_entries_multiple_calls_same_extern():
pat = make_fused_bias_activation_pattern("relax.nn.conv2d",
with_bias=False, activation=None)
check(Conv2dx2, [("cutlass.conv2d", pat)], Conv2dx2_partitioned,
annotate_codegen=True)
diff --git a/tests/python/relax/test_transform_merge_composite_functions.py
b/tests/python/relax/test_transform_merge_composite_functions.py
index d552266131..d56e1db564 100644
--- a/tests/python/relax/test_transform_merge_composite_functions.py
+++ b/tests/python/relax/test_transform_merge_composite_functions.py
@@ -100,49 +100,46 @@ class Conv2dReLUx2_merged:
) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
R.func_attr(
{
- "Primitive": 1,
"Codegen": "dnnl",
"global_symbol":
"fused_relax_nn_conv2d_relax_nn_relu_relax_nn_conv2d_relax_nn_relu1",
}
)
- with R.dataflow():
- @R.function
- def lv(
- data11: R.Tensor((1, 64, 56, 56), dtype="float32"),
- weight111: R.Tensor((64, 64, 3, 3), dtype="float32"),
- ) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
- R.func_attr({"Composite": "dnnl.conv2d_relu", "Primitive": 1})
- with R.dataflow():
- lv1: R.Tensor((1, 64, 56, 56), dtype="float32") =
R.nn.conv2d(
- data11,
- weight111,
- padding=[1, 1, 1, 1],
- )
- gv1: R.Tensor((1, 64, 56, 56), dtype="float32") =
R.nn.relu(lv1)
- R.output(gv1)
- return gv1
+ @R.function
+ def lv(
+ data11: R.Tensor((1, 64, 56, 56), dtype="float32"),
+ weight111: R.Tensor((64, 64, 3, 3), dtype="float32"),
+ ) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
+ R.func_attr({"Composite": "dnnl.conv2d_relu"})
+ with R.dataflow():
+ lv1: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(
+ data11,
+ weight111,
+ padding=[1, 1, 1, 1],
+ )
+ gv1: R.Tensor((1, 64, 56, 56), dtype="float32") =
R.nn.relu(lv1)
+ R.output(gv1)
+ return gv1
- lv2: R.Tensor((1, 64, 56, 56), dtype="float32") = lv(data1,
weight11)
+ lv2: R.Tensor((1, 64, 56, 56), dtype="float32") = lv(data1, weight11)
- @R.function
- def lv11(
- conv1: R.Tensor((1, 64, 56, 56), dtype="float32"),
- weight211: R.Tensor((64, 64, 3, 3), dtype="float32"),
- ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
- R.func_attr({"Composite": "dnnl.conv2d_relu", "Primitive": 1})
- with R.dataflow():
- lv21: R.Tensor((1, 64, 54, 54), dtype="float32") =
R.nn.conv2d(
- conv1,
- weight211,
- padding=[0, 0, 0, 0],
- )
- gv2: R.Tensor((1, 64, 54, 54), dtype="float32") =
R.nn.relu(lv21)
- R.output(gv2)
- return gv2
+ @R.function
+ def lv11(
+ conv1: R.Tensor((1, 64, 56, 56), dtype="float32"),
+ weight211: R.Tensor((64, 64, 3, 3), dtype="float32"),
+ ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+ R.func_attr({"Composite": "dnnl.conv2d_relu"})
+ with R.dataflow():
+ lv21: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.conv2d(
+ conv1,
+ weight211,
+ padding=[0, 0, 0, 0],
+ )
+ gv2: R.Tensor((1, 64, 54, 54), dtype="float32") =
R.nn.relu(lv21)
+ R.output(gv2)
+ return gv2
- gv3: R.Tensor((1, 64, 54, 54), dtype="float32") = lv11(lv2,
weight21)
- R.output(gv3)
+ gv3: R.Tensor((1, 64, 54, 54), dtype="float32") = lv11(lv2, weight21)
return gv3
@@ -222,82 +219,79 @@ class Diamond_merged:
R.func_attr(
{
"Codegen": "compiler_A",
- "Primitive": 1,
"global_symbol":
"fused_relax_nn_conv2d_relax_nn_relu_relax_nn_gelu_relax_add",
}
)
- # block 0
- with R.dataflow():
- @R.function
- def lv(
- data1: R.Tensor((1, 64, 56, 56), dtype="float32"),
- weight1: R.Tensor((64, 64, 3, 3), dtype="float32"),
- ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
- # function attr dict
- R.func_attr({"Composite": "compiler_A.conv2d", "Primitive": 1})
- # block 0
- with R.dataflow():
- gv4: R.Tensor((1, 64, 54, 54), dtype="float32") =
R.nn.conv2d(
- data1,
- weight1,
- strides=[1, 1],
- padding=[0, 0, 0, 0],
- dilation=[1, 1],
- groups=1,
- data_layout="NCHW",
- kernel_layout="OIHW",
- out_layout="NCHW",
- out_dtype="",
- )
- R.output(gv4)
- return gv4
+ # block 0
+ @R.function
+ def lv(
+ data1: R.Tensor((1, 64, 56, 56), dtype="float32"),
+ weight1: R.Tensor((64, 64, 3, 3), dtype="float32"),
+ ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+ # function attr dict
+ R.func_attr({"Composite": "compiler_A.conv2d"})
+ # block 0
+ with R.dataflow():
+ gv4: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.conv2d(
+ data1,
+ weight1,
+ strides=[1, 1],
+ padding=[0, 0, 0, 0],
+ dilation=[1, 1],
+ groups=1,
+ data_layout="NCHW",
+ kernel_layout="OIHW",
+ out_layout="NCHW",
+ out_dtype="",
+ )
+ R.output(gv4)
+ return gv4
- lv2: R.Tensor((1, 64, 54, 54), dtype="float32") = lv(data, weight)
+ lv2: R.Tensor((1, 64, 54, 54), dtype="float32") = lv(data, weight)
- @R.function
- def lv1(
- lv11: R.Tensor((1, 64, 54, 54), dtype="float32")
- ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
- # function attr dict
- R.func_attr({"Composite": "compiler_A.relu", "Primitive": 1})
- # block 0
- with R.dataflow():
- gv1: R.Tensor((1, 64, 54, 54), dtype="float32") =
R.nn.relu(lv11)
- R.output(gv1)
- return gv1
+ @R.function
+ def lv1(
+ lv11: R.Tensor((1, 64, 54, 54), dtype="float32")
+ ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+ # function attr dict
+ R.func_attr({"Composite": "compiler_A.relu"})
+ # block 0
+ with R.dataflow():
+ gv1: R.Tensor((1, 64, 54, 54), dtype="float32") =
R.nn.relu(lv11)
+ R.output(gv1)
+ return gv1
- lv3: R.Tensor((1, 64, 54, 54), dtype="float32") = lv1(lv2)
+ lv3: R.Tensor((1, 64, 54, 54), dtype="float32") = lv1(lv2)
- @R.function
- def lv21(
- lv4: R.Tensor((1, 64, 54, 54), dtype="float32")
- ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
- # function attr dict
- R.func_attr({"Composite": "compiler_A.gelu", "Primitive": 1})
- # block 0
- with R.dataflow():
- gv: R.Tensor((1, 64, 54, 54), dtype="float32") =
R.nn.gelu(lv4)
- R.output(gv)
- return gv
+ @R.function
+ def lv21(
+ lv4: R.Tensor((1, 64, 54, 54), dtype="float32")
+ ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+ # function attr dict
+ R.func_attr({"Composite": "compiler_A.gelu"})
+ # block 0
+ with R.dataflow():
+ gv: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.gelu(lv4)
+ R.output(gv)
+ return gv
- lv41: R.Tensor((1, 64, 54, 54), dtype="float32") = lv21(lv2)
+ lv41: R.Tensor((1, 64, 54, 54), dtype="float32") = lv21(lv2)
- @R.function
- def lv31(
- lv5: R.Tensor((1, 64, 54, 54), dtype="float32"),
- gelu1: R.Tensor((1, 64, 54, 54), dtype="float32"),
- ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
- # function attr dict
- R.func_attr({"Composite": "compiler_A.add", "Primitive": 1})
- # block 0
- with R.dataflow():
- gv3: R.Tensor((1, 64, 54, 54), dtype="float32") =
R.add(lv5, gelu1)
- R.output(gv3)
- return gv3
+ @R.function
+ def lv31(
+ lv5: R.Tensor((1, 64, 54, 54), dtype="float32"),
+ gelu1: R.Tensor((1, 64, 54, 54), dtype="float32"),
+ ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+ # function attr dict
+ R.func_attr({"Composite": "compiler_A.add"})
+ # block 0
+ with R.dataflow():
+ gv3: R.Tensor((1, 64, 54, 54), dtype="float32") = R.add(lv5,
gelu1)
+ R.output(gv3)
+ return gv3
- gv2: R.Tensor((1, 64, 54, 54), dtype="float32") = lv31(lv3, lv41)
- R.output(gv2)
+ gv2: R.Tensor((1, 64, 54, 54), dtype="float32") = lv31(lv3, lv41)
return gv2
@R.function
@@ -408,65 +402,59 @@ class Diamond_cyclic_dep_merged:
):
R.func_attr(
{
- "Primitive": 1,
"Codegen": "compiler_A",
"global_symbol": "fused_relax_nn_conv2d_relax_nn_relu",
}
)
- with R.dataflow():
- @R.function
- def lv(
- data1: R.Tensor((1, 64, 56, 56), dtype="float32"),
- weight1: R.Tensor((64, 64, 3, 3), dtype="float32"),
- ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
- R.func_attr({"Composite": "compiler_A.conv2d", "Primitive": 1})
- with R.dataflow():
- gv4: R.Tensor((1, 64, 54, 54), dtype="float32") =
R.nn.conv2d(
- data1,
- weight1,
- padding=[0, 0, 0, 0],
- )
- R.output(gv4)
- return gv4
+ @R.function
+ def lv(
+ data1: R.Tensor((1, 64, 56, 56), dtype="float32"),
+ weight1: R.Tensor((64, 64, 3, 3), dtype="float32"),
+ ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+ R.func_attr({"Composite": "compiler_A.conv2d"})
+ with R.dataflow():
+ gv4: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.conv2d(
+ data1,
+ weight1,
+ padding=[0, 0, 0, 0],
+ )
+ R.output(gv4)
+ return gv4
- gv: R.Tensor((1, 64, 54, 54), dtype="float32") = lv(data, weight)
+ gv: R.Tensor((1, 64, 54, 54), dtype="float32") = lv(data, weight)
- @R.function
- def lv1(
- lv11: R.Tensor((1, 64, 54, 54), dtype="float32")
- ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
- R.func_attr({"Composite": "compiler_A.relu", "Primitive": 1})
- with R.dataflow():
- gv1: R.Tensor((1, 64, 54, 54), dtype="float32") =
R.nn.relu(lv11)
- R.output(gv1)
- return gv1
+ @R.function
+ def lv1(
+ lv11: R.Tensor((1, 64, 54, 54), dtype="float32")
+ ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+ R.func_attr({"Composite": "compiler_A.relu"})
+ with R.dataflow():
+ gv1: R.Tensor((1, 64, 54, 54), dtype="float32") =
R.nn.relu(lv11)
+ R.output(gv1)
+ return gv1
+
+ gv11: R.Tensor((1, 64, 54, 54), dtype="float32") = lv1(gv)
- gv11: R.Tensor((1, 64, 54, 54), dtype="float32") = lv1(gv)
- R.output(gv, gv11)
return (gv, gv11)
@R.function
def fused_relax_nn_gelu1(
lv2: R.Tensor((1, 64, 54, 54), dtype="float32")
) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
- R.func_attr(
- {"Primitive": 1, "Codegen": "compiler_B", "global_symbol":
"fused_relax_nn_gelu1"}
- )
- with R.dataflow():
+ R.func_attr({"Codegen": "compiler_B", "global_symbol":
"fused_relax_nn_gelu1"})
- @R.function
- def lv21(
- lv3: R.Tensor((1, 64, 54, 54), dtype="float32")
- ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
- R.func_attr({"Composite": "compiler_B.gelu", "Primitive": 1})
- with R.dataflow():
- gv2: R.Tensor((1, 64, 54, 54), dtype="float32") =
R.nn.gelu(lv3)
- R.output(gv2)
- return gv2
+ @R.function
+ def lv21(
+ lv3: R.Tensor((1, 64, 54, 54), dtype="float32")
+ ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+ R.func_attr({"Composite": "compiler_B.gelu"})
+ with R.dataflow():
+ gv2: R.Tensor((1, 64, 54, 54), dtype="float32") =
R.nn.gelu(lv3)
+ R.output(gv2)
+ return gv2
- gv3: R.Tensor((1, 64, 54, 54), dtype="float32") = lv21(lv2)
- R.output(gv3)
+ gv3: R.Tensor((1, 64, 54, 54), dtype="float32") = lv21(lv2)
return gv3
@R.function
@@ -474,22 +462,20 @@ class Diamond_cyclic_dep_merged:
lv32: R.Tensor((1, 64, 54, 54), dtype="float32"),
lv41: R.Tensor((1, 64, 54, 54), dtype="float32"),
) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
- R.func_attr({"Primitive": 1, "Codegen": "compiler_A", "global_symbol":
"fused_relax_add1"})
- with R.dataflow():
+ R.func_attr({"Codegen": "compiler_A", "global_symbol":
"fused_relax_add1"})
- @R.function
- def lv33(
- lv5: R.Tensor((1, 64, 54, 54), dtype="float32"),
- gelu1: R.Tensor((1, 64, 54, 54), dtype="float32"),
- ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
- R.func_attr({"Composite": "compiler_A.add", "Primitive": 1})
- with R.dataflow():
- gv31: R.Tensor((1, 64, 54, 54), dtype="float32") =
R.add(lv5, gelu1)
- R.output(gv31)
- return gv31
+ @R.function
+ def lv33(
+ lv5: R.Tensor((1, 64, 54, 54), dtype="float32"),
+ gelu1: R.Tensor((1, 64, 54, 54), dtype="float32"),
+ ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
+ R.func_attr({"Composite": "compiler_A.add"})
+ with R.dataflow():
+ gv31: R.Tensor((1, 64, 54, 54), dtype="float32") = R.add(lv5,
gelu1)
+ R.output(gv31)
+ return gv31
- gv6: R.Tensor((1, 64, 54, 54), dtype="float32") = lv33(lv32, lv41)
- R.output(gv6)
+ gv6: R.Tensor((1, 64, 54, 54), dtype="float32") = lv33(lv32, lv41)
return gv6
@@ -550,53 +536,50 @@ class MultipleProducers_merged:
R.func_attr(
{
"Codegen": "compiler_A",
- "Primitive": 1,
"global_symbol":
"fused_relax_nn_relu_relax_nn_gelu_relax_nn_relu_relax_nn_gelu_relax_add",
}
)
- # block 0
- with R.dataflow():
- @R.function
- def lv(x11: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,),
dtype="float32"):
- # function attr dict
- R.func_attr({"Composite": "compiler_A.relu", "Primitive": 1})
- # block 0
- with R.dataflow():
- gv2: R.Tensor((10,), dtype="float32") = R.nn.relu(x11)
- R.output(gv2)
- return gv2
+ # block 0
+ @R.function
+ def lv(x11: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,),
dtype="float32"):
+ # function attr dict
+ R.func_attr({"Composite": "compiler_A.relu"})
+ # block 0
+ with R.dataflow():
+ gv2: R.Tensor((10,), dtype="float32") = R.nn.relu(x11)
+ R.output(gv2)
+ return gv2
- lv1: R.Tensor((10,), dtype="float32") = lv(x1)
+ lv1: R.Tensor((10,), dtype="float32") = lv(x1)
- @R.function
- def lv11(x21: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,),
dtype="float32"):
- # function attr dict
- R.func_attr({"Composite": "compiler_A.gelu", "Primitive": 1})
- # block 0
- with R.dataflow():
- gv3: R.Tensor((10,), dtype="float32") = R.nn.gelu(x21)
- R.output(gv3)
- return gv3
+ @R.function
+ def lv11(x21: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,),
dtype="float32"):
+ # function attr dict
+ R.func_attr({"Composite": "compiler_A.gelu"})
+ # block 0
+ with R.dataflow():
+ gv3: R.Tensor((10,), dtype="float32") = R.nn.gelu(x21)
+ R.output(gv3)
+ return gv3
- lv2: R.Tensor((10,), dtype="float32") = lv11(x2)
- lv3: R.Tensor((10,), dtype="float32") = lv(lv1)
- lv4: R.Tensor((10,), dtype="float32") = lv11(lv2)
+ lv2: R.Tensor((10,), dtype="float32") = lv11(x2)
+ lv3: R.Tensor((10,), dtype="float32") = lv(lv1)
+ lv4: R.Tensor((10,), dtype="float32") = lv11(lv2)
- @R.function
- def lv21(
- lv5: R.Tensor((10,), dtype="float32"), gelu1: R.Tensor((10,),
dtype="float32")
- ) -> R.Tensor((10,), dtype="float32"):
- # function attr dict
- R.func_attr({"Composite": "compiler_A.add", "Primitive": 1})
- # block 0
- with R.dataflow():
- gv: R.Tensor((10,), dtype="float32") = R.add(lv5, gelu1)
- R.output(gv)
- return gv
+ @R.function
+ def lv21(
+ lv5: R.Tensor((10,), dtype="float32"), gelu1: R.Tensor((10,),
dtype="float32")
+ ) -> R.Tensor((10,), dtype="float32"):
+ # function attr dict
+ R.func_attr({"Composite": "compiler_A.add"})
+ # block 0
+ with R.dataflow():
+ gv: R.Tensor((10,), dtype="float32") = R.add(lv5, gelu1)
+ R.output(gv)
+ return gv
- gv1: R.Tensor((10,), dtype="float32") = lv21(lv3, lv4)
- R.output(gv1)
+ gv1: R.Tensor((10,), dtype="float32") = lv21(lv3, lv4)
return gv1
@R.function
@@ -675,24 +658,20 @@ class MultipleProducersCyclic_merged:
x11: R.Tensor((10,), dtype="float32")
) -> R.Tensor((10,), dtype="float32"):
# function attr dict
- R.func_attr(
- {"Codegen": "compiler_A", "Primitive": 1, "global_symbol":
"fused_relax_nn_relu1"}
- )
- # block 0
- with R.dataflow():
+ R.func_attr({"Codegen": "compiler_A", "global_symbol":
"fused_relax_nn_relu1"})
- @R.function
- def lv1(x111: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,),
dtype="float32"):
- # function attr dict
- R.func_attr({"Composite": "compiler_A.relu", "Primitive": 1})
- # block 0
- with R.dataflow():
- gv2: R.Tensor((10,), dtype="float32") = R.nn.relu(x111)
- R.output(gv2)
- return gv2
+ # block 0
+ @R.function
+ def lv1(x111: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,),
dtype="float32"):
+ # function attr dict
+ R.func_attr({"Composite": "compiler_A.relu"})
+ # block 0
+ with R.dataflow():
+ gv2: R.Tensor((10,), dtype="float32") = R.nn.relu(x111)
+ R.output(gv2)
+ return gv2
- gv1: R.Tensor((10,), dtype="float32") = lv1(x11)
- R.output(gv1)
+ gv1: R.Tensor((10,), dtype="float32") = lv1(x11)
return gv1
@R.function
@@ -703,39 +682,37 @@ class MultipleProducersCyclic_merged:
R.func_attr(
{
"Codegen": "compiler_A",
- "Primitive": 1,
"global_symbol": "fused_relax_nn_gelu_relax_add",
}
)
# block 0
- with R.dataflow():
- @R.function
- def lv12(x21: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,),
dtype="float32"):
- # function attr dict
- R.func_attr({"Composite": "compiler_A.gelu", "Primitive": 1})
- # block 0
- with R.dataflow():
- gv3: R.Tensor((10,), dtype="float32") = R.nn.gelu(x21)
- R.output(gv3)
- return gv3
+ @R.function
+ def lv12(x21: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,),
dtype="float32"):
+ # function attr dict
+ R.func_attr({"Composite": "compiler_A.gelu"})
+ # block 0
+ with R.dataflow():
+ gv3: R.Tensor((10,), dtype="float32") = R.nn.gelu(x21)
+ R.output(gv3)
+ return gv3
- lv3: R.Tensor((10,), dtype="float32") = lv12(lv21)
+ lv3: R.Tensor((10,), dtype="float32") = lv12(lv21)
- @R.function
- def lv22(
- lv4: R.Tensor((10,), dtype="float32"), gelu1: R.Tensor((10,),
dtype="float32")
- ) -> R.Tensor((10,), dtype="float32"):
- # function attr dict
- R.func_attr({"Composite": "compiler_A.add", "Primitive": 1})
- # block 0
- with R.dataflow():
- gv4: R.Tensor((10,), dtype="float32") = R.add(lv4, gelu1)
- R.output(gv4)
- return gv4
+ @R.function
+ def lv22(
+ lv4: R.Tensor((10,), dtype="float32"), gelu1: R.Tensor((10,),
dtype="float32")
+ ) -> R.Tensor((10,), dtype="float32"):
+ # function attr dict
+ R.func_attr({"Composite": "compiler_A.add"})
+ # block 0
+ with R.dataflow():
+ gv4: R.Tensor((10,), dtype="float32") = R.add(lv4, gelu1)
+ R.output(gv4)
+ return gv4
+
+ gv5: R.Tensor((10,), dtype="float32") = lv22(lv11, lv3)
- gv5: R.Tensor((10,), dtype="float32") = lv22(lv11, lv3)
- R.output(gv5)
return gv5
@@ -800,36 +777,33 @@ class MergeCompilerRegionsExampleRef:
) -> R.Tuple(R.Tensor((10,), dtype="float32"), R.Tensor((10,),
dtype="float32")):
R.func_attr(
{
- "Primitive": 1,
"Codegen": "compiler_A",
"global_symbol": "fused_relax_add_relax_add_relax_nn_relu",
}
)
- with R.dataflow():
- @R.function
- def lv1(
- x11: R.Tensor((10,), dtype="float32"), x21: R.Tensor((10,),
dtype="float32")
- ) -> R.Tensor((10,), dtype="float32"):
- R.func_attr({"Primitive": 1, "Composite": "compiler_A.add"})
- with R.dataflow():
- gv: R.Tensor((10,), dtype="float32") = R.add(x11, x21)
- R.output(gv)
- return gv
+ @R.function
+ def lv1(
+ x11: R.Tensor((10,), dtype="float32"), x21: R.Tensor((10,),
dtype="float32")
+ ) -> R.Tensor((10,), dtype="float32"):
+ R.func_attr({"Composite": "compiler_A.add"})
+ with R.dataflow():
+ gv: R.Tensor((10,), dtype="float32") = R.add(x11, x21)
+ R.output(gv)
+ return gv
- lv2: R.Tensor((10,), dtype="float32") = lv1(x1, x2)
- gv1: R.Tensor((10,), dtype="float32") = lv1(lv2, lv)
+ lv2: R.Tensor((10,), dtype="float32") = lv1(x1, x2)
+ gv1: R.Tensor((10,), dtype="float32") = lv1(lv2, lv)
- @R.function
- def lv11(add2: R.Tensor((10,), dtype="float32")) ->
R.Tensor((10,), dtype="float32"):
- R.func_attr({"Primitive": 1, "Composite": "compiler_A.relu"})
- with R.dataflow():
- gv2: R.Tensor((10,), dtype="float32") = R.nn.relu(add2)
- R.output(gv2)
- return gv2
+ @R.function
+ def lv11(add2: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,),
dtype="float32"):
+ R.func_attr({"Composite": "compiler_A.relu"})
+ with R.dataflow():
+ gv2: R.Tensor((10,), dtype="float32") = R.nn.relu(add2)
+ R.output(gv2)
+ return gv2
- gv11: R.Tensor((10,), dtype="float32") = lv11(gv1)
- R.output(gv1, gv11)
+ gv11: R.Tensor((10,), dtype="float32") = lv11(gv1)
return (gv1, gv11)
@R.function
@@ -838,56 +812,50 @@ class MergeCompilerRegionsExampleRef:
) -> R.Tensor((10,), dtype="float32"):
R.func_attr(
{
- "Primitive": 1,
"Codegen": "compiler_A",
"global_symbol": "fused_relax_add_relax_nn_relu",
}
)
- with R.dataflow():
- @R.function
- def lv21(
- x11: R.Tensor((10,), dtype="float32"), x21: R.Tensor((10,),
dtype="float32")
- ) -> R.Tensor((10,), dtype="float32"):
- R.func_attr({"Primitive": 1, "Composite": "compiler_A.add"})
- with R.dataflow():
- gv: R.Tensor((10,), dtype="float32") = R.add(x11, x21)
- R.output(gv)
- return gv
+ @R.function
+ def lv21(
+ x11: R.Tensor((10,), dtype="float32"), x21: R.Tensor((10,),
dtype="float32")
+ ) -> R.Tensor((10,), dtype="float32"):
+ R.func_attr({"Composite": "compiler_A.add"})
+ with R.dataflow():
+ gv: R.Tensor((10,), dtype="float32") = R.add(x11, x21)
+ R.output(gv)
+ return gv
- lv22: R.Tensor((10,), dtype="float32") = lv21(lv12, lv3)
+ lv22: R.Tensor((10,), dtype="float32") = lv21(lv12, lv3)
- @R.function
- def lv31(add2: R.Tensor((10,), dtype="float32")) ->
R.Tensor((10,), dtype="float32"):
- R.func_attr({"Primitive": 1, "Composite": "compiler_A.relu"})
- with R.dataflow():
- gv2: R.Tensor((10,), dtype="float32") = R.nn.relu(add2)
- R.output(gv2)
- return gv2
+ @R.function
+ def lv31(add2: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,),
dtype="float32"):
+ R.func_attr({"Composite": "compiler_A.relu"})
+ with R.dataflow():
+ gv2: R.Tensor((10,), dtype="float32") = R.nn.relu(add2)
+ R.output(gv2)
+ return gv2
+
+ gv3: R.Tensor((10,), dtype="float32") = lv31(lv22)
- gv3: R.Tensor((10,), dtype="float32") = lv31(lv22)
- R.output(gv3)
return gv3
@R.function
def fused_relax_nn_gelu1(
x3: R.Tensor((10,), dtype="float32")
) -> R.Tensor((10,), dtype="float32"):
- R.func_attr(
- {"Primitive": 1, "Codegen": "compiler_B", "global_symbol":
"fused_relax_nn_gelu1"}
- )
- with R.dataflow():
+ R.func_attr({"Codegen": "compiler_B", "global_symbol":
"fused_relax_nn_gelu1"})
- @R.function
- def lv4(x31: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,),
dtype="float32"):
- R.func_attr({"Primitive": 1, "Composite": "compiler_B.gelu"})
- with R.dataflow():
- gv4: R.Tensor((10,), dtype="float32") = R.nn.gelu(x31)
- R.output(gv4)
- return gv4
+ @R.function
+ def lv4(x31: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,),
dtype="float32"):
+ R.func_attr({"Composite": "compiler_B.gelu"})
+ with R.dataflow():
+ gv4: R.Tensor((10,), dtype="float32") = R.nn.gelu(x31)
+ R.output(gv4)
+ return gv4
- gv5: R.Tensor((10,), dtype="float32") = lv4(x3)
- R.output(gv5)
+ gv5: R.Tensor((10,), dtype="float32") = lv4(x3)
return gv5
@R.function
@@ -961,28 +929,24 @@ class ModuleWithNonComposite_ref:
data1: R.Tensor((1, 64, 56, 56), dtype="float32"),
weight1: R.Tensor((64, 64, 3, 3), dtype="float32"),
) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
- R.func_attr(
- {"Codegen": "tensorrt", "Primitive": 1, "global_symbol":
"fused_relax_nn_conv2d1"}
- )
- with R.dataflow():
+ R.func_attr({"Codegen": "tensorrt", "global_symbol":
"fused_relax_nn_conv2d1"})
- @R.function
- def lv1(
- data2: R.Tensor((1, 64, 56, 56), dtype="float32"),
- weight2: R.Tensor((64, 64, 3, 3), dtype="float32"),
- ) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
- R.func_attr({"Composite": "tensorrt.conv2d", "Primitive": 1})
- with R.dataflow():
- gv: R.Tensor((1, 64, 56, 56), dtype="float32") =
R.nn.conv2d(
- data2,
- weight2,
- padding=[1, 1, 1, 1],
- )
- R.output(gv)
- return gv
+ @R.function
+ def lv1(
+ data2: R.Tensor((1, 64, 56, 56), dtype="float32"),
+ weight2: R.Tensor((64, 64, 3, 3), dtype="float32"),
+ ) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
+ R.func_attr({"Composite": "tensorrt.conv2d"})
+ with R.dataflow():
+ gv: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(
+ data2,
+ weight2,
+ padding=[1, 1, 1, 1],
+ )
+ R.output(gv)
+ return gv
- gv1: R.Tensor((1, 64, 56, 56), dtype="float32") = lv1(data1,
weight1)
- R.output(gv1)
+ gv1: R.Tensor((1, 64, 56, 56), dtype="float32") = lv1(data1, weight1)
return gv1
@@ -1120,41 +1084,38 @@ def test_reshape():
R.func_attr(
{
"Codegen": "tensorrt",
- "Primitive": 1,
"global_symbol": "fused_relax_reshape_relax_matmul",
}
)
- with R.dataflow():
- # from tvm.script import relax as R
-
- @R.function
- def lv_1(
- inp_0_1: R.Tensor((1, 1, 28, 28), dtype="float32"),
param_0_1: R.Shape([1, 784])
- ) -> R.Tensor((1, 784), dtype="float32"):
- R.func_attr({"Composite": "tensorrt.reshape", "Primitive":
1})
- with R.dataflow():
- gv: R.Tensor((1, 784), dtype="float32") =
R.reshape(inp_0_1, param_0_1)
- R.output(gv)
- return gv
-
- lv_1: R.Tensor((1, 784), dtype="float32") = lv_1(inp_0,
param_0)
-
- @R.function
- def lv1_1_1(
- lv_2: R.Tensor((1, 784), dtype="float32"),
- lv1_2: R.Tensor((784, 512), dtype="float32"),
- ) -> R.Tensor((1, 512), dtype="float32"):
- R.func_attr({"Composite": "tensorrt.matmul", "Primitive":
1})
- with R.dataflow():
- gv: R.Tensor((1, 512), dtype="float32") = R.matmul(
- lv_2, lv1_2, out_dtype="float32"
- )
- R.output(gv)
- return gv
-
- lv_2: R.Tensor((1, 512), dtype="float32") = lv1_1_1(lv_1, lv1)
- gv: R.Tensor((1, 512), dtype="float32") = lv_2
- R.output(gv)
+ # from tvm.script import relax as R
+
+ @R.function
+ def lv_1(
+ inp_0_1: R.Tensor((1, 1, 28, 28), dtype="float32"), param_0_1:
R.Shape([1, 784])
+ ) -> R.Tensor((1, 784), dtype="float32"):
+ R.func_attr({"Composite": "tensorrt.reshape"})
+ with R.dataflow():
+ gv: R.Tensor((1, 784), dtype="float32") =
R.reshape(inp_0_1, param_0_1)
+ R.output(gv)
+ return gv
+
+ lv_1: R.Tensor((1, 784), dtype="float32") = lv_1(inp_0, param_0)
+
+ @R.function
+ def lv1_1_1(
+ lv_2: R.Tensor((1, 784), dtype="float32"),
+ lv1_2: R.Tensor((784, 512), dtype="float32"),
+ ) -> R.Tensor((1, 512), dtype="float32"):
+ R.func_attr({"Composite": "tensorrt.matmul"})
+ with R.dataflow():
+ gv: R.Tensor((1, 512), dtype="float32") = R.matmul(
+ lv_2, lv1_2, out_dtype="float32"
+ )
+ R.output(gv)
+ return gv
+
+ lv_2: R.Tensor((1, 512), dtype="float32") = lv1_1_1(lv_1, lv1)
+ gv: R.Tensor((1, 512), dtype="float32") = lv_2
return gv
@R.function