This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new a92258ec55 [Unity] Support Simple Dynamic-Shape-Aware in FuseTIR
(#14515)
a92258ec55 is described below
commit a92258ec559982e12db4fb2fc3d08bceef5d6ec9
Author: Siyuan Feng <[email protected]>
AuthorDate: Thu Apr 6 21:04:37 2023 +0800
[Unity] Support Simple Dynamic-Shape-Aware in FuseTIR (#14515)
In the previous PR, https://github.com/apache/tvm/pull/14396, we enabled
dynamic-shape-aware fusion in the FuseOps pass. In this PR, we add the
corresponding support to the FuseTIR pass for simple cases.
This PR also fixes a minor compilation warning in `json_runtime.h`
---
src/relax/transform/fuse_tir.cc | 260 +++++++++++++++++++++-----
src/runtime/contrib/json/json_runtime.h | 2 +-
tests/python/relax/test_transform_fuse_tir.py | 95 ++++++++++
3 files changed, 305 insertions(+), 52 deletions(-)
diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc
index b695c5f6c7..5ddda93705 100644
--- a/src/relax/transform/fuse_tir.cc
+++ b/src/relax/transform/fuse_tir.cc
@@ -31,27 +31,131 @@ namespace tir {
// TODO(Siyuan): move it to somewhere under tir folder
/*!
- * \brief Substitute a given source buffer with a given target buffer in
statements or expressions.
+ * \brief Match symbolic vars according to the given PrimExpr, and update the
var_remap.
+ * Will throw errors if there is a mismatch.
*/
-class FuseTIRBufferSubstitor : private StmtExprMutator {
+class SymbolicMatcher : ExprFunctor<bool(const PrimExpr& n, const PrimExpr&
other)> {
public:
- static Stmt Substitute(const Map<Buffer, Buffer>& buffer_map, Stmt stmt) {
- return FuseTIRBufferSubstitor(buffer_map)(std::move(stmt));
+ void Match(const Array<PrimExpr>& lhs, const Array<PrimExpr>& rhs) {
+ CHECK_EQ(lhs.size(), rhs.size());
+ for (size_t i = 0; i < lhs.size(); ++i) {
+ Match(lhs[i], rhs[i]);
+ }
}
+ void Match(const PrimExpr& lhs, const PrimExpr& rhs) {
+ if (!VisitExpr(lhs, rhs)) {
+ LOG(FATAL) << "Failed to match PrimExpr " << lhs << " with " << rhs;
+ }
+ }
+
+ Map<tir::Var, tir::Var> var_remap;
private:
- explicit FuseTIRBufferSubstitor(const Map<Buffer, Buffer>& buffer_map) {
+ bool VisitExpr(const PrimExpr& n, const PrimExpr& other) {
+ bool matched = n.same_as(other) || ((n->type_index() ==
other->type_index()) &&
+ n.dtype().code() ==
other.dtype().code());
+ return matched && ExprFunctor::VisitExpr(n, other);
+ }
+
+#define TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(OpName) \
+ bool VisitExpr_(const OpName* op, const PrimExpr& other) { \
+ const auto* rhs = other.as<OpName>(); \
+ ICHECK(rhs); \
+ return VisitExpr(op->a, rhs->a) && VisitExpr(op->b, rhs->b); \
+ }
+
+ TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(AddNode);
+ TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(SubNode);
+ TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(MulNode);
+ TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(DivNode);
+ TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(ModNode);
+ TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(EQNode);
+ TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(NENode);
+ TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(LTNode);
+ TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(LENode);
+ TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(GTNode);
+ TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(GENode);
+ TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(AndNode);
+ TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(OrNode);
+ TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(MinNode);
+ TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(MaxNode);
+ TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(FloorDivNode);
+ TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(FloorModNode);
+
+ bool VisitExpr_(const IntImmNode* op, const PrimExpr& other) {
+ const auto* rhs = other.as<IntImmNode>();
+ return op->value == rhs->value;
+ }
+
+ bool VisitExpr_(const FloatImmNode* op, const PrimExpr& other) {
+ const auto* rhs = other.as<FloatImmNode>();
+ return op->value == rhs->value;
+ }
+
+ bool VisitExpr_(const CastNode* op, const PrimExpr& other) {
+ const auto* rhs = other.as<CastNode>();
+ return VisitExpr(op->value, rhs->value);
+ }
+
+ bool VisitExpr_(const VarNode* op, const PrimExpr& other) {
+ const auto* rhs = other.as<VarNode>();
+ auto lhs = GetRef<Var>(op);
+ if (lhs.same_as(other)) return true;
+ if (op->dtype.code() != rhs->dtype.code()) return false;
+ auto it = var_remap.find(lhs);
+ if (it == var_remap.end()) {
+ var_remap.Set(lhs, GetRef<Var>(rhs));
+ return true;
+ } else {
+ return (*it).second.same_as(other);
+ }
+ }
+};
+
+/*!
+ * \brief Substitute a given source buffer with a given target buffer in
statements or expressions.
+ */
+class FuseTIRBufferSubstitor : private StmtExprMutator {
+ public:
+ explicit FuseTIRBufferSubstitor(const Map<Buffer, Buffer>& buffer_map,
+ const Map<Var, Var>& var_map) {
+ buffer_remap_ = buffer_map;
+ var_remap_ = var_map;
for (const auto& kv : buffer_map) {
const Buffer& src = kv.first;
const Buffer& tgt = kv.second;
- buffer_var_map_[src->data.get()] = tgt;
+ var_remap_.Set(src->data, tgt->data);
}
}
+ Stmt Substitute(Stmt stmt) { return this->VisitStmt(std::move(stmt)); }
+
+ Buffer SubstituteAllocatedBuffer(Buffer buffer) {
+ ICHECK(buffer_remap_.find(buffer) == buffer_remap_.end());
+ Array<PrimExpr> shape =
+ MutateArray(buffer->shape, [this](const PrimExpr& expr) { return
this->VisitExpr(expr); });
+ Array<PrimExpr> strides = MutateArray(
+ buffer->strides, [this](const PrimExpr& expr) { return
this->VisitExpr(expr); });
+ PrimExpr elem_offset = this->VisitExpr(buffer->elem_offset);
+ if (shape.same_as(buffer->shape) && strides.same_as(buffer->strides) &&
+ elem_offset.same_as(buffer->elem_offset)) {
+ return buffer;
+ } else {
+ auto n = make_object<BufferNode>(*buffer.get());
+ n->shape = std::move(shape);
+ n->strides = std::move(strides);
+ n->elem_offset = std::move(elem_offset);
+ Buffer new_buffer(n);
+ this->buffer_remap_.Set(buffer, new_buffer);
+ return new_buffer;
+ }
+ }
+
+ private:
PrimExpr VisitExpr_(const VarNode* _op) final {
- auto it = buffer_var_map_.find(_op);
- if (it != buffer_var_map_.end()) {
- return it->second->data;
+ auto it = var_remap_.find(GetRef<Var>(_op));
+ if (it != var_remap_.end()) {
+ return (*it).second;
} else {
return GetRef<PrimExpr>(_op);
}
@@ -59,25 +163,25 @@ class FuseTIRBufferSubstitor : private StmtExprMutator {
PrimExpr VisitExpr_(const BufferLoadNode* _op) final {
BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(_op));
- auto it = buffer_var_map_.find(load->buffer->data.get());
- if (it != buffer_var_map_.end()) {
+ const Buffer& buffer = SubstituteBuffer(load->buffer);
+ if (buffer.same_as(load->buffer)) {
+ return std::move(load);
+ } else {
auto n = make_object<BufferLoadNode>(*load.get());
- n->buffer = it->second;
+ n->buffer = buffer;
return BufferLoad(n);
- } else {
- return std::move(load);
}
}
Stmt VisitStmt_(const BufferStoreNode* _op) final {
BufferStore store =
Downcast<BufferStore>(StmtExprMutator::VisitStmt_(_op));
- auto it = buffer_var_map_.find(store->buffer->data.get());
- if (it != buffer_var_map_.end()) {
- auto n = CopyOnWrite(store.get());
- n->buffer = it->second;
- return BufferStore(n);
- } else {
+ const Buffer& buffer = SubstituteBuffer(store->buffer);
+ if (buffer.same_as(store->buffer)) {
return std::move(store);
+ } else {
+ auto n = make_object<BufferStoreNode>(*store.get());
+ n->buffer = buffer;
+ return BufferStore(n);
}
}
@@ -85,21 +189,25 @@ class FuseTIRBufferSubstitor : private StmtExprMutator {
Block block = Downcast<Block>(StmtMutator::VisitStmt_(_op));
// Define the mutation functions.
+
auto f_mutate_match_buffers = [this](const MatchBufferRegion&
match_buffer) {
- const Buffer& src_buffer = match_buffer->source->buffer;
- auto it = buffer_var_map_.find(src_buffer->data.get());
- if (it != buffer_var_map_.end()) {
- return MatchBufferRegion(match_buffer->buffer,
- BufferRegion(it->second,
match_buffer->source->region));
- } else {
+ const Buffer& src_buffer =
SubstituteBuffer(match_buffer->source->buffer);
+ const Buffer& tgt_buffer =
SubstituteAllocatedBuffer(match_buffer->buffer);
+ if (src_buffer.same_as(match_buffer->source->buffer) &&
+ tgt_buffer.same_as(match_buffer->buffer)) {
return match_buffer;
+ } else {
+ auto n = make_object<MatchBufferRegionNode>(*match_buffer.get());
+ n->buffer = tgt_buffer;
+ n->source = BufferRegion(src_buffer, match_buffer->source->region);
+ return MatchBufferRegion(n);
}
};
auto f_mutate_read_write_region = [this](const BufferRegion&
buffer_region) {
- auto it = buffer_var_map_.find(buffer_region->buffer->data.get());
- return it == buffer_var_map_.end() ? buffer_region
- : BufferRegion(it->second,
buffer_region->region);
+ auto it = buffer_remap_.find(buffer_region->buffer);
+ return it == buffer_remap_.end() ? buffer_region
+ : BufferRegion((*it).second,
buffer_region->region);
};
// Step 1. Mutate `match_buffers`.
@@ -108,26 +216,34 @@ class FuseTIRBufferSubstitor : private StmtExprMutator {
// Step 2. Mutate the read/write region.
Array<BufferRegion> reads = MutateArray(block->reads,
f_mutate_read_write_region);
Array<BufferRegion> writes = MutateArray(block->writes,
f_mutate_read_write_region);
+ // Step 3. Mutate the Allocate Buffers.
+ Array<Buffer> alloc_buffers = MutateArray(block->alloc_buffers,
[this](const Buffer& buffer) {
+ return SubstituteAllocatedBuffer(buffer);
+ });
reads = UnionAccessRegion(reads);
writes = UnionAccessRegion(writes);
if (reads.same_as(block->reads) && //
writes.same_as(block->writes) && //
- match_buffers.same_as(block->match_buffers)) {
+ match_buffers.same_as(block->match_buffers) &&
+ alloc_buffers.same_as(block->alloc_buffers)) {
return std::move(block);
} else {
auto n = CopyOnWrite(block.get());
n->reads = std::move(reads);
n->writes = std::move(writes);
n->match_buffers = std::move(match_buffers);
+ n->alloc_buffers = std::move(alloc_buffers);
return Block(n);
}
}
private:
- /*! \brief Mapping from src buffer.data to tgt buffer. */
- std::unordered_map<const tir::VarNode*, tir::Buffer> buffer_var_map_;
+ /*! \brief Mapping from src buffer to tgt buffer. */
+ Map<tir::Buffer, tir::Buffer> buffer_remap_;
+ /*! \brief Mapping from src tir var to tgt var. */
+ Map<tir::Var, tir::Var> var_remap_;
/*! \brief The structural equality checker */
StructuralEqual structural_equal_;
@@ -155,6 +271,15 @@ class FuseTIRBufferSubstitor : private StmtExprMutator {
return ret;
}
}
+
+ inline Buffer SubstituteBuffer(const Buffer& buffer) const {
+ auto it = buffer_remap_.find(buffer);
+ if (it != buffer_remap_.end()) {
+ return (*it).second;
+ } else {
+ return buffer;
+ }
+ }
};
/*! \brief A mutator which detect block name duplication and deduplicate the
names. */
@@ -298,8 +423,8 @@ class FusedTIRConstructor : public ExprVisitor {
// Step 5. Map input arguments to buffer
MapInputBuffer(prim_func, call->args[1]);
- size_t num_output_buffers = GetCallTIROutputSize(call);
- AllocateIntermediateBuffer(GetRef<Expr>(call), prim_func,
num_output_buffers);
+ const Array<Array<PrimExpr>>& output_buffer_shapes =
GetCallTIROutputShapes(call);
+ AllocateIntermediateBuffer(GetRef<Expr>(call), prim_func,
output_buffer_shapes);
// Update fused func name
func_info_.global_name += "_" + gv->name_hint;
}
@@ -343,14 +468,32 @@ class FusedTIRConstructor : public ExprVisitor {
* \brief Get the number of outputs for a call_tir node.
* \return The number of outputs.
*/
- static size_t GetCallTIROutputSize(const CallNode* call) {
+ static Array<Array<PrimExpr>> GetCallTIROutputShapes(const CallNode* call) {
static const Op& call_tir_op_ = Op::Get("relax.call_tir");
ICHECK(call->op.same_as(call_tir_op_));
ICHECK_EQ(call->sinfo_args.size(), 1);
+ auto get_tensor_shape = [](const TensorStructInfoNode* sinfo) {
+ const auto* shape_expr = sinfo->shape.as<ShapeExprNode>();
+ CHECK(shape_expr) << "FuseTIR expects all parameters are Tensors with
symbolic shape.";
+ return shape_expr->values;
+ };
if (const auto* tuple_sinfo =
call->sinfo_args[0].as<TupleStructInfoNode>()) {
- return tuple_sinfo->fields.size();
+ Array<Array<PrimExpr>> shapes;
+ for (const StructInfo& field : tuple_sinfo->fields) {
+ const auto* tensor_sinfo = field.as<TensorStructInfoNode>();
+ CHECK(tensor_sinfo) << "CallTIR sinfo_args are expected to be
TensorStructInfo or Tuple of "
+ "TensorStructInfo, but got "
+ << call->sinfo_args[0];
+ shapes.push_back(get_tensor_shape(tensor_sinfo));
+ }
+ return shapes;
+ } else if (const auto* tensor_sinfo =
call->sinfo_args[0].as<TensorStructInfoNode>()) {
+ return {get_tensor_shape(tensor_sinfo)};
} else {
- return 1;
+ CHECK(tensor_sinfo) << "CallTIR sinfo_args are expected to be
TensorStructInfo or Tuple of "
+ "TensorStructInfo, but got "
+ << call->sinfo_args[0];
+ throw;
}
}
@@ -365,17 +508,14 @@ class FusedTIRConstructor : public ExprVisitor {
for (const tir::Buffer& target_buffer : (*it).second) {
ICHECK_LT(buffer_idx, buffers.size());
const tir::Buffer& buffer = buffers[buffer_idx];
- // TODO(relax-team): Add support for symbolic shape fusion
- for (const PrimExpr& shape_expr : buffer->shape) {
- ICHECK(shape_expr.as<IntImmNode>()) << "Only support constant
shape fusion for now";
- }
+ func_info_.symbolic_var_matcher.Match(buffer->shape,
target_buffer->shape);
func_info_.buffer_subst_map.Set(buffer, target_buffer);
buffer_idx++;
}
}
}
}
- // Make sure every buffers are maped.
+ // Make sure every buffers are mapped.
ICHECK_EQ(buffer_idx, buffers.size());
}
@@ -408,18 +548,30 @@ class FusedTIRConstructor : public ExprVisitor {
* intermediate results.
* \param expr The relax Expr, which can be binding vars or binding values.
* \param func The old TIR PrimFunc
- * \param output_size The number of output params. All output params are at
the end of param list.
+ * \param output_shapes The shape of output params.
*/
- void AllocateIntermediateBuffer(const Expr& expr, const tir::PrimFunc& func,
size_t output_size) {
+ void AllocateIntermediateBuffer(const Expr& expr, const tir::PrimFunc& func,
+ const Array<Array<PrimExpr>>& output_shapes)
{
size_t n = func->params.size();
+ size_t output_size = output_shapes.size();
ICHECK_GE(n, output_size);
// Allocate intermediate buffer
Array<tir::Buffer> alloc_buffers;
for (size_t i = 0; i < output_size; ++i) {
const tir::Var& param = func->params[n - output_size + i];
const tir::Buffer& buffer = func->buffer_map.at(param);
- func_info_.alloc_buffers.push_back(buffer);
- alloc_buffers.push_back(buffer);
+
+ // Update buffer with new symbolic shape according to the sinfo
+ auto n = make_object<tir::BufferNode>(*buffer.get());
+ n->shape = output_shapes[i];
+ n->name = param->name_hint + "_intermediate";
+ tir::Buffer new_buffer(n);
+ func_info_.alloc_buffers.push_back(new_buffer);
+ alloc_buffers.push_back(new_buffer);
+
+ // Match the shape of the output buffer with the shape
+ func_info_.symbolic_var_matcher.Match(buffer->shape, n->shape);
+ func_info_.buffer_subst_map.Set(buffer, new_buffer);
}
// Update expr2buffers
func_info_.expr2buffers.Set(expr, alloc_buffers);
@@ -438,7 +590,7 @@ class FusedTIRConstructor : public ExprVisitor {
Array<tir::Var> params;
Array<tir::Buffer> buffers;
if (const auto* tensor = struct_info.as<TensorStructInfoNode>()) {
- // Case 1. the relax param is a DynTensor, we directly create a tir var
and buffer
+ // Case 1. the relax param is a Tensor, we directly create a tir var and
buffer
const auto* shape_expr = tensor->shape.as<ShapeExprNode>();
ICHECK(shape_expr) << "FuseTIR expects all parameters are Tensors with
symbolic shape.";
@@ -452,7 +604,7 @@ class FusedTIRConstructor : public ExprVisitor {
params.push_back(std::move(param));
buffers.push_back(std::move(buffer));
} else if (const auto* tuple = struct_info.as<TupleStructInfoNode>()) {
- // Case 2. the relax param is a Tuple, we recursively visit each field
until it's a DynTensor
+ // Case 2. the relax param is a Tuple, we recursively visit each field
until it's a Tensor
// Enable postfix
if (index == -1) index = 0;
for (size_t i = 0; i < tuple->fields.size(); ++i) {
@@ -478,21 +630,25 @@ class FusedTIRConstructor : public ExprVisitor {
tir::PrimFunc ConstructFunc() {
Map<String, ObjectRef> attr_map;
attr_map.Set("tir.noalias", tir::const_true());
+ tir::FuseTIRBufferSubstitor substitor(func_info_.buffer_subst_map,
+
func_info_.symbolic_var_matcher.var_remap);
ICHECK(func_info_.global_name != "fused");
// Remove output buffers from func_info_.alloc_buffers
Array<tir::Buffer> alloc_buffers;
for (const tir::Buffer& buf : func_info_.alloc_buffers) {
if (func_info_.output_buffers.count(buf.get()) == 0) {
- alloc_buffers.push_back(buf);
+ alloc_buffers.push_back(substitor.SubstituteAllocatedBuffer(buf));
}
}
tir::Stmt body =
tir::BlockNameDeduplicator()(tir::SeqStmt::Flatten(func_info_.bodies));
- body =
tir::FuseTIRBufferSubstitor::Substitute(func_info_.buffer_subst_map, body);
+
+ body = substitor.Substitute(body);
body = tir::Block({}, {}, {}, "root", std::move(body), NullOpt,
alloc_buffers);
body = tir::BlockRealize({}, Bool(true), Downcast<tir::Block>(body));
tir::PrimFunc func(func_info_.params, body, VoidType(),
func_info_.buffer_map,
DictAttrs(attr_map));
- return func;
+ // Renew function defs to prevent using the same symbolic vars in
different functions
+ return tir::RenewDefs(func);
}
/*! \brief Get DynTensor numbers from recursive Tuples. */
@@ -539,6 +695,8 @@ class FusedTIRConstructor : public ExprVisitor {
std::unordered_set<const tir::BufferNode*> output_buffers;
/*! \brief The name of the fused function */
std::string global_name = "fused";
+ /*! \brief The map from symbolic var to its corresponding var in the fused
function */
+ tir::SymbolicMatcher symbolic_var_matcher;
};
/*! \brief The IRModule */
diff --git a/src/runtime/contrib/json/json_runtime.h
b/src/runtime/contrib/json/json_runtime.h
index 51ce2cffd7..5409078e85 100644
--- a/src/runtime/contrib/json/json_runtime.h
+++ b/src/runtime/contrib/json/json_runtime.h
@@ -59,7 +59,7 @@ class JSONRuntimeBase : public ModuleNode {
const char* type_key() const override { return "json"; } // May be
overridden
/*! \brief Get the property of the runtime module .*/
- int GetPropertyMask() const {
+ int GetPropertyMask() const override {
return ModulePropertyMask::kBinarySerializable |
ModulePropertyMask::kRunnable;
}
diff --git a/tests/python/relax/test_transform_fuse_tir.py
b/tests/python/relax/test_transform_fuse_tir.py
index 356e28d6e9..bdbd9be966 100644
--- a/tests/python/relax/test_transform_fuse_tir.py
+++ b/tests/python/relax/test_transform_fuse_tir.py
@@ -698,5 +698,100 @@ def test_skip_call_dps_packed():
_check(Module, Module)
+def test_symbolic_shape_aware_fuse():
+ @I.ir_module
+ class Before:
+ @R.function
+ def fused_add_exp_squeeze(
+ x: R.Tensor(["n", "m"], "float32"), p0: R.Tensor([], "float32")
+ ) -> R.Tensor(["n", "m"], dtype="float32"):
+ R.func_attr({"Primitive": 1})
+ with R.dataflow():
+ lv0 = R.emit_te(topi.add, x, p0)
+ lv1 = R.emit_te(topi.exp, lv0)
+ gv = R.emit_te(topi.squeeze, lv1)
+ R.output(gv)
+ return gv
+
+ @R.function
+ def main(x: R.Tensor(["n", "m"], "float32")) -> R.Tensor(["n", "m"],
dtype="float32"):
+ cls = Before
+ with R.dataflow():
+ gv = cls.fused_add_exp_squeeze(x, R.const(1, "float32"))
+ R.output(gv)
+ return gv
+
+ def fused_add_exp_squeeze(x, p0):
+ return topi.squeeze(topi.exp(topi.add(x, p0)))
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor(["n", "m"], "float32")) -> R.Tensor(["n", "m"],
dtype="float32"):
+ with R.dataflow():
+ gv = R.emit_te(fused_add_exp_squeeze, x, R.const(1, "float32"))
+ R.output(gv)
+ return gv
+
+ _check(Before, Expected)
+
+
+def test_symbolic_shape_aware_fuse_with_allocation():
+ def te_mean(x, axis):
+ return topi.divide(topi.sum(x, axis, keepdims=True), 4096)
+
+ @I.ir_module
+ class Before:
+ @R.function
+ def fused_mean_add_tir_sqrt_divide_multiply(
+ x: R.Tensor((1, "n", 4096), dtype="float32"),
+ y: R.Tensor((1, "n", 4096), dtype="float32"),
+ rms_norm_weight: R.Tensor((4096,), dtype="float32"),
+ ) -> R.Tensor((1, "n", 4096), dtype="float32"):
+ R.func_attr({"Primitive": 1})
+ with R.dataflow():
+ lv0 = R.emit_te(te_mean, x, axis=2)
+ lv1 = R.emit_te(topi.add, lv0, lv0)
+ lv2 = R.emit_te(topi.sqrt, lv1)
+ lv3 = R.emit_te(topi.divide, y, lv2)
+ gv = R.emit_te(topi.multiply, rms_norm_weight, lv3)
+ R.output(gv)
+ return gv
+
+ @R.function
+ def main(
+ x: R.Tensor((1, "n", 4096), dtype="float32"),
+ y: R.Tensor((1, "n", 4096), dtype="float32"),
+ rms_norm_weight: R.Tensor((4096,), dtype="float32"),
+ ) -> R.Tensor((1, "n", 4096), dtype="float32"):
+ cls = Before
+ with R.dataflow():
+ gv = cls.fused_mean_add_tir_sqrt_divide_multiply(x, y,
rms_norm_weight)
+ R.output(gv)
+ return gv
+
+ def fused_mean_add_tir_sqrt_divide_multiply(x, y, rms_norm_weight):
+ lv0 = te_mean(x, axis=2)
+ lv1 = topi.add(lv0, lv0)
+ lv2 = topi.sqrt(lv1)
+ lv3 = topi.divide(y, lv2)
+ return topi.multiply(rms_norm_weight, lv3)
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((1, "n", 4096), dtype="float32"),
+ y: R.Tensor((1, "n", 4096), dtype="float32"),
+ rms_norm_weight: R.Tensor((4096,), dtype="float32"),
+ ) -> R.Tensor((1, "n", 4096), dtype="float32"):
+ with R.dataflow():
+ gv = R.emit_te(fused_mean_add_tir_sqrt_divide_multiply, x, y,
rms_norm_weight)
+ R.output(gv)
+ return gv
+
+ _check(Before, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()