This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new a92258ec55 [Unity] Support Simple Dynamic-Shape-Aware in FuseTIR (#14515)
a92258ec55 is described below

commit a92258ec559982e12db4fb2fc3d08bceef5d6ec9
Author: Siyuan Feng <[email protected]>
AuthorDate: Thu Apr 6 21:04:37 2023 +0800

    [Unity] Support Simple Dynamic-Shape-Aware in FuseTIR (#14515)
    
    In the previous PR (https://github.com/apache/tvm/pull/14396), we enabled
    dynamic-shape-aware fusion in FuseOps. This PR adds the corresponding
    support to the subsequent FuseTIR pass for simple cases.
    
    This PR also fixes a minor compilation warning in `json_runtime.h`.
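
    As a minimal, illustrative sketch of what the pass now handles (adapted
    from the new test `test_symbolic_shape_aware_fuse` below; the module and
    function names here are hypothetical, not part of this commit):

        import tvm
        from tvm import relax, topi
        from tvm.script import ir as I
        from tvm.script import relax as R

        @I.ir_module
        class Before:
            @R.function
            def fused_add_exp(
                x: R.Tensor(["n", "m"], "float32"), p0: R.Tensor([], "float32")
            ) -> R.Tensor(["n", "m"], "float32"):
                R.func_attr({"Primitive": 1})
                with R.dataflow():
                    # Both intermediate and output shapes stay symbolic in n, m.
                    lv0 = R.emit_te(topi.add, x, p0)
                    gv = R.emit_te(topi.exp, lv0)
                    R.output(gv)
                return gv

            @R.function
            def main(x: R.Tensor(["n", "m"], "float32")) -> R.Tensor(["n", "m"], "float32"):
                cls = Before
                with R.dataflow():
                    gv = cls.fused_add_exp(x, R.const(1, "float32"))
                    R.output(gv)
                return gv

        # FuseTIR previously rejected this case ("Only support constant shape
        # fusion for now"); with this change it fuses both TE stages into one
        # PrimFunc whose buffers keep the symbolic shape (n, m).
        After = relax.transform.FuseTIR()(Before)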
---
 src/relax/transform/fuse_tir.cc               | 260 +++++++++++++++++++++-----
 src/runtime/contrib/json/json_runtime.h       |   2 +-
 tests/python/relax/test_transform_fuse_tir.py |  95 ++++++++++
 3 files changed, 305 insertions(+), 52 deletions(-)

diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc
index b695c5f6c7..5ddda93705 100644
--- a/src/relax/transform/fuse_tir.cc
+++ b/src/relax/transform/fuse_tir.cc
@@ -31,27 +31,131 @@ namespace tir {
 
 // TODO(Siyuan): move it to somewhere under tir folder
 /*!
- * \brief Substitute a given source buffer with a given target buffer in statements or expressions.
+ * \brief Match symbolic vars according to the given PrimExpr, and update the var_remap.
+ * Will throw errors if there is a mismatch.
  */
-class FuseTIRBufferSubstitor : private StmtExprMutator {
+class SymbolicMatcher : ExprFunctor<bool(const PrimExpr& n, const PrimExpr& other)> {
  public:
-  static Stmt Substitute(const Map<Buffer, Buffer>& buffer_map, Stmt stmt) {
-    return FuseTIRBufferSubstitor(buffer_map)(std::move(stmt));
+  void Match(const Array<PrimExpr>& lhs, const Array<PrimExpr>& rhs) {
+    CHECK_EQ(lhs.size(), rhs.size());
+    for (size_t i = 0; i < lhs.size(); ++i) {
+      Match(lhs[i], rhs[i]);
+    }
   }
+  void Match(const PrimExpr& lhs, const PrimExpr& rhs) {
+    if (!VisitExpr(lhs, rhs)) {
+      LOG(FATAL) << "Failed to match PrimExpr " << lhs << " with " << rhs;
+    }
+  }
+
+  Map<tir::Var, tir::Var> var_remap;
 
  private:
-  explicit FuseTIRBufferSubstitor(const Map<Buffer, Buffer>& buffer_map) {
+  bool VisitExpr(const PrimExpr& n, const PrimExpr& other) {
+    bool matched = n.same_as(other) || ((n->type_index() == other->type_index()) &&
+                                        n.dtype().code() == other.dtype().code());
+    return matched && ExprFunctor::VisitExpr(n, other);
+  }
+
+#define TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(OpName)               \
+  bool VisitExpr_(const OpName* op, const PrimExpr& other) {     \
+    const auto* rhs = other.as<OpName>();                        \
+    ICHECK(rhs);                                                 \
+    return VisitExpr(op->a, rhs->a) && VisitExpr(op->b, rhs->b); \
+  }
+
+  TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(AddNode);
+  TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(SubNode);
+  TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(MulNode);
+  TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(DivNode);
+  TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(ModNode);
+  TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(EQNode);
+  TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(NENode);
+  TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(LTNode);
+  TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(LENode);
+  TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(GTNode);
+  TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(GENode);
+  TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(AndNode);
+  TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(OrNode);
+  TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(MinNode);
+  TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(MaxNode);
+  TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(FloorDivNode);
+  TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(FloorModNode);
+
+  bool VisitExpr_(const IntImmNode* op, const PrimExpr& other) {
+    const auto* rhs = other.as<IntImmNode>();
+    return op->value == rhs->value;
+  }
+
+  bool VisitExpr_(const FloatImmNode* op, const PrimExpr& other) {
+    const auto* rhs = other.as<FloatImmNode>();
+    return op->value == rhs->value;
+  }
+
+  bool VisitExpr_(const CastNode* op, const PrimExpr& other) {
+    const auto* rhs = other.as<CastNode>();
+    return VisitExpr(op->value, rhs->value);
+  }
+
+  bool VisitExpr_(const VarNode* op, const PrimExpr& other) {
+    const auto* rhs = other.as<VarNode>();
+    auto lhs = GetRef<Var>(op);
+    if (lhs.same_as(other)) return true;
+    if (op->dtype.code() != rhs->dtype.code()) return false;
+    auto it = var_remap.find(lhs);
+    if (it == var_remap.end()) {
+      var_remap.Set(lhs, GetRef<Var>(rhs));
+      return true;
+    } else {
+      return (*it).second.same_as(other);
+    }
+  }
+};
+
+/*!
+ * \brief Substitute a given source buffer with a given target buffer in statements or expressions.
+ */
+class FuseTIRBufferSubstitor : private StmtExprMutator {
+ public:
+  explicit FuseTIRBufferSubstitor(const Map<Buffer, Buffer>& buffer_map,
+                                  const Map<Var, Var>& var_map) {
+    buffer_remap_ = buffer_map;
+    var_remap_ = var_map;
     for (const auto& kv : buffer_map) {
       const Buffer& src = kv.first;
       const Buffer& tgt = kv.second;
-      buffer_var_map_[src->data.get()] = tgt;
+      var_remap_.Set(src->data, tgt->data);
     }
   }
 
+  Stmt Substitute(Stmt stmt) { return this->VisitStmt(std::move(stmt)); }
+
+  Buffer SubstituteAllocatedBuffer(Buffer buffer) {
+    ICHECK(buffer_remap_.find(buffer) == buffer_remap_.end());
+    Array<PrimExpr> shape =
+        MutateArray(buffer->shape, [this](const PrimExpr& expr) { return this->VisitExpr(expr); });
+    Array<PrimExpr> strides = MutateArray(
+        buffer->strides, [this](const PrimExpr& expr) { return this->VisitExpr(expr); });
+    PrimExpr elem_offset = this->VisitExpr(buffer->elem_offset);
+    if (shape.same_as(buffer->shape) && strides.same_as(buffer->strides) &&
+        elem_offset.same_as(buffer->elem_offset)) {
+      return buffer;
+    } else {
+      auto n = make_object<BufferNode>(*buffer.get());
+      n->shape = std::move(shape);
+      n->strides = std::move(strides);
+      n->elem_offset = std::move(elem_offset);
+      Buffer new_buffer(n);
+      this->buffer_remap_.Set(buffer, new_buffer);
+      return new_buffer;
+    }
+  }
+
+ private:
   PrimExpr VisitExpr_(const VarNode* _op) final {
-    auto it = buffer_var_map_.find(_op);
-    if (it != buffer_var_map_.end()) {
-      return it->second->data;
+    auto it = var_remap_.find(GetRef<Var>(_op));
+    if (it != var_remap_.end()) {
+      return (*it).second;
     } else {
       return GetRef<PrimExpr>(_op);
     }
@@ -59,25 +163,25 @@ class FuseTIRBufferSubstitor : private StmtExprMutator {
 
   PrimExpr VisitExpr_(const BufferLoadNode* _op) final {
     BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(_op));
-    auto it = buffer_var_map_.find(load->buffer->data.get());
-    if (it != buffer_var_map_.end()) {
+    const Buffer& buffer = SubstituteBuffer(load->buffer);
+    if (buffer.same_as(load->buffer)) {
+      return std::move(load);
+    } else {
       auto n = make_object<BufferLoadNode>(*load.get());
-      n->buffer = it->second;
+      n->buffer = buffer;
       return BufferLoad(n);
-    } else {
-      return std::move(load);
     }
   }
 
   Stmt VisitStmt_(const BufferStoreNode* _op) final {
     BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(_op));
-    auto it = buffer_var_map_.find(store->buffer->data.get());
-    if (it != buffer_var_map_.end()) {
-      auto n = CopyOnWrite(store.get());
-      n->buffer = it->second;
-      return BufferStore(n);
-    } else {
+    const Buffer& buffer = SubstituteBuffer(store->buffer);
+    if (buffer.same_as(store->buffer)) {
       return std::move(store);
+    } else {
+      auto n = make_object<BufferStoreNode>(*store.get());
+      n->buffer = buffer;
+      return BufferStore(n);
     }
   }
 
@@ -85,21 +189,25 @@ class FuseTIRBufferSubstitor : private StmtExprMutator {
     Block block = Downcast<Block>(StmtMutator::VisitStmt_(_op));
 
     // Define the mutation functions.
+
     auto f_mutate_match_buffers = [this](const MatchBufferRegion& match_buffer) {
-      const Buffer& src_buffer = match_buffer->source->buffer;
-      auto it = buffer_var_map_.find(src_buffer->data.get());
-      if (it != buffer_var_map_.end()) {
-        return MatchBufferRegion(match_buffer->buffer,
-                                 BufferRegion(it->second, match_buffer->source->region));
-      } else {
+      const Buffer& src_buffer = SubstituteBuffer(match_buffer->source->buffer);
+      const Buffer& tgt_buffer = SubstituteAllocatedBuffer(match_buffer->buffer);
+      if (src_buffer.same_as(match_buffer->source->buffer) &&
+          tgt_buffer.same_as(match_buffer->buffer)) {
         return match_buffer;
+      } else {
+        auto n = make_object<MatchBufferRegionNode>(*match_buffer.get());
+        n->buffer = tgt_buffer;
+        n->source = BufferRegion(src_buffer, match_buffer->source->region);
+        return MatchBufferRegion(n);
       }
     };
 
     auto f_mutate_read_write_region = [this](const BufferRegion& buffer_region) {
-      auto it = buffer_var_map_.find(buffer_region->buffer->data.get());
-      return it == buffer_var_map_.end() ? buffer_region
-                                         : BufferRegion(it->second, buffer_region->region);
+      auto it = buffer_remap_.find(buffer_region->buffer);
+      return it == buffer_remap_.end() ? buffer_region
+                                       : BufferRegion((*it).second, buffer_region->region);
     };
 
     // Step 1. Mutate `match_buffers`.
@@ -108,26 +216,34 @@ class FuseTIRBufferSubstitor : private StmtExprMutator {
     // Step 2. Mutate the read/write region.
     Array<BufferRegion> reads = MutateArray(block->reads, f_mutate_read_write_region);
     Array<BufferRegion> writes = MutateArray(block->writes, f_mutate_read_write_region);
+    // Step 3. Mutate the Allocate Buffers.
+    Array<Buffer> alloc_buffers = MutateArray(block->alloc_buffers, [this](const Buffer& buffer) {
+      return SubstituteAllocatedBuffer(buffer);
+    });
 
     reads = UnionAccessRegion(reads);
     writes = UnionAccessRegion(writes);
 
     if (reads.same_as(block->reads) &&    //
         writes.same_as(block->writes) &&  //
-        match_buffers.same_as(block->match_buffers)) {
+        match_buffers.same_as(block->match_buffers) &&
+        alloc_buffers.same_as(block->alloc_buffers)) {
       return std::move(block);
     } else {
       auto n = CopyOnWrite(block.get());
       n->reads = std::move(reads);
       n->writes = std::move(writes);
       n->match_buffers = std::move(match_buffers);
+      n->alloc_buffers = std::move(alloc_buffers);
       return Block(n);
     }
   }
 
  private:
-  /*! \brief Mapping from src buffer.data to tgt buffer. */
-  std::unordered_map<const tir::VarNode*, tir::Buffer> buffer_var_map_;
+  /*! \brief Mapping from src buffer to tgt buffer. */
+  Map<tir::Buffer, tir::Buffer> buffer_remap_;
+  /*! \brief Mapping from src tir var to tgt var. */
+  Map<tir::Var, tir::Var> var_remap_;
   /*! \brief The structural equality checker */
   StructuralEqual structural_equal_;
 
@@ -155,6 +271,15 @@ class FuseTIRBufferSubstitor : private StmtExprMutator {
       return ret;
     }
   }
+
+  inline Buffer SubstituteBuffer(const Buffer& buffer) const {
+    auto it = buffer_remap_.find(buffer);
+    if (it != buffer_remap_.end()) {
+      return (*it).second;
+    } else {
+      return buffer;
+    }
+  }
 };
 
 /*! \brief A mutator which detect block name duplication and deduplicate the names. */
@@ -298,8 +423,8 @@ class FusedTIRConstructor : public ExprVisitor {
 
     // Step 5. Map input arguments to buffer
     MapInputBuffer(prim_func, call->args[1]);
-    size_t num_output_buffers = GetCallTIROutputSize(call);
-    AllocateIntermediateBuffer(GetRef<Expr>(call), prim_func, num_output_buffers);
+    const Array<Array<PrimExpr>>& output_buffer_shapes = GetCallTIROutputShapes(call);
+    AllocateIntermediateBuffer(GetRef<Expr>(call), prim_func, output_buffer_shapes);
     // Update fused func name
     func_info_.global_name += "_" + gv->name_hint;
   }
@@ -343,14 +468,32 @@ class FusedTIRConstructor : public ExprVisitor {
    * \brief Get the number of outputs for a call_tir node.
    * \return The number of outputs.
    */
-  static size_t GetCallTIROutputSize(const CallNode* call) {
+  static Array<Array<PrimExpr>> GetCallTIROutputShapes(const CallNode* call) {
     static const Op& call_tir_op_ = Op::Get("relax.call_tir");
     ICHECK(call->op.same_as(call_tir_op_));
     ICHECK_EQ(call->sinfo_args.size(), 1);
+    auto get_tensor_shape = [](const TensorStructInfoNode* sinfo) {
+      const auto* shape_expr = sinfo->shape.as<ShapeExprNode>();
+      CHECK(shape_expr) << "FuseTIR expects all parameters are Tensors with symbolic shape.";
+      return shape_expr->values;
+    };
     if (const auto* tuple_sinfo = call->sinfo_args[0].as<TupleStructInfoNode>()) {
-      return tuple_sinfo->fields.size();
+      Array<Array<PrimExpr>> shapes;
+      for (const StructInfo& field : tuple_sinfo->fields) {
+        const auto* tensor_sinfo = field.as<TensorStructInfoNode>();
+        CHECK(tensor_sinfo) << "CallTIR sinfo_args are expected to be TensorStructInfo or Tuple of "
+                               "TensorStructInfo, but got "
+                            << call->sinfo_args[0];
+        shapes.push_back(get_tensor_shape(tensor_sinfo));
+      }
+      return shapes;
+    } else if (const auto* tensor_sinfo = call->sinfo_args[0].as<TensorStructInfoNode>()) {
+      return {get_tensor_shape(tensor_sinfo)};
     } else {
-      return 1;
+      CHECK(tensor_sinfo) << "CallTIR sinfo_args are expected to be TensorStructInfo or Tuple of "
+                             "TensorStructInfo, but got "
+                          << call->sinfo_args[0];
+      throw;
     }
   }
 
@@ -365,17 +508,14 @@ class FusedTIRConstructor : public ExprVisitor {
           for (const tir::Buffer& target_buffer : (*it).second) {
             ICHECK_LT(buffer_idx, buffers.size());
             const tir::Buffer& buffer = buffers[buffer_idx];
-            // TODO(relax-team): Add support for symbolic shape fusion
-            for (const PrimExpr& shape_expr : buffer->shape) {
-              ICHECK(shape_expr.as<IntImmNode>()) << "Only support constant shape fusion for now";
-            }
+            func_info_.symbolic_var_matcher.Match(buffer->shape, target_buffer->shape);
             func_info_.buffer_subst_map.Set(buffer, target_buffer);
             buffer_idx++;
           }
         }
       }
     }
-    // Make sure every buffers are maped.
+    // Make sure every buffers are mapped.
     ICHECK_EQ(buffer_idx, buffers.size());
   }
 
@@ -408,18 +548,30 @@ class FusedTIRConstructor : public ExprVisitor {
    * intermediate results.
    * \param expr The relax Expr, which can be binding vars or binding values.
    * \param func The old TIR PrimFunc
-   * \param output_size The number of output params. All output params are at the end of param list.
+   * \param output_shapes The shape of output params.
    */
-  void AllocateIntermediateBuffer(const Expr& expr, const tir::PrimFunc& func, size_t output_size) {
+  void AllocateIntermediateBuffer(const Expr& expr, const tir::PrimFunc& func,
+                                  const Array<Array<PrimExpr>>& output_shapes) {
     size_t n = func->params.size();
+    size_t output_size = output_shapes.size();
     ICHECK_GE(n, output_size);
     // Allocate intermediate buffer
     Array<tir::Buffer> alloc_buffers;
     for (size_t i = 0; i < output_size; ++i) {
       const tir::Var& param = func->params[n - output_size + i];
       const tir::Buffer& buffer = func->buffer_map.at(param);
-      func_info_.alloc_buffers.push_back(buffer);
-      alloc_buffers.push_back(buffer);
+
+      // Update buffer with new symbolic shape according to the sinfo
+      auto n = make_object<tir::BufferNode>(*buffer.get());
+      n->shape = output_shapes[i];
+      n->name = param->name_hint + "_intermediate";
+      tir::Buffer new_buffer(n);
+      func_info_.alloc_buffers.push_back(new_buffer);
+      alloc_buffers.push_back(new_buffer);
+
+      // Match the old buffer's shape against the new symbolic shape from the sinfo
+      func_info_.symbolic_var_matcher.Match(buffer->shape, n->shape);
+      func_info_.buffer_subst_map.Set(buffer, new_buffer);
     }
     // Update expr2buffers
     func_info_.expr2buffers.Set(expr, alloc_buffers);
@@ -438,7 +590,7 @@ class FusedTIRConstructor : public ExprVisitor {
     Array<tir::Var> params;
     Array<tir::Buffer> buffers;
     if (const auto* tensor = struct_info.as<TensorStructInfoNode>()) {
-      // Case 1. the relax param is a DynTensor, we directly create a tir var and buffer
+      // Case 1. the relax param is a Tensor, we directly create a tir var and buffer
       const auto* shape_expr = tensor->shape.as<ShapeExprNode>();
       ICHECK(shape_expr) << "FuseTIR expects all parameters are Tensors with symbolic shape.";
 
@@ -452,7 +604,7 @@ class FusedTIRConstructor : public ExprVisitor {
       params.push_back(std::move(param));
       buffers.push_back(std::move(buffer));
     } else if (const auto* tuple = struct_info.as<TupleStructInfoNode>()) {
-      // Case 2. the relax param is a Tuple, we recursively visit each field until it's a DynTensor
+      // Case 2. the relax param is a Tuple, we recursively visit each field until it's a Tensor
       // Enable postfix
       if (index == -1) index = 0;
       for (size_t i = 0; i < tuple->fields.size(); ++i) {
@@ -478,21 +630,25 @@ class FusedTIRConstructor : public ExprVisitor {
   tir::PrimFunc ConstructFunc() {
     Map<String, ObjectRef> attr_map;
     attr_map.Set("tir.noalias", tir::const_true());
+    tir::FuseTIRBufferSubstitor substitor(func_info_.buffer_subst_map,
+                                          func_info_.symbolic_var_matcher.var_remap);
     ICHECK(func_info_.global_name != "fused");
     // Remove output buffers from func_info_.alloc_buffers
     Array<tir::Buffer> alloc_buffers;
     for (const tir::Buffer& buf : func_info_.alloc_buffers) {
       if (func_info_.output_buffers.count(buf.get()) == 0) {
-        alloc_buffers.push_back(buf);
+        alloc_buffers.push_back(substitor.SubstituteAllocatedBuffer(buf));
       }
     }
     tir::Stmt body = tir::BlockNameDeduplicator()(tir::SeqStmt::Flatten(func_info_.bodies));
-    body = tir::FuseTIRBufferSubstitor::Substitute(func_info_.buffer_subst_map, body);
+
+    body = substitor.Substitute(body);
     body = tir::Block({}, {}, {}, "root", std::move(body), NullOpt, alloc_buffers);
     body = tir::BlockRealize({}, Bool(true), Downcast<tir::Block>(body));
     tir::PrimFunc func(func_info_.params, body, VoidType(), func_info_.buffer_map,
                        DictAttrs(attr_map));
-    return func;
+    // Renew function defs to prevent using the same symbolic vars in different functions
+    return tir::RenewDefs(func);
   }
 
   /*! \brief Get DynTensor numbers from recursive Tuples. */
@@ -539,6 +695,8 @@ class FusedTIRConstructor : public ExprVisitor {
     std::unordered_set<const tir::BufferNode*> output_buffers;
     /*! \brief The name of the fused function */
     std::string global_name = "fused";
+    /*! \brief The map from symbolic var to its corresponding var in the fused function */
+    tir::SymbolicMatcher symbolic_var_matcher;
   };
 
   /*! \brief The IRModule */
diff --git a/src/runtime/contrib/json/json_runtime.h b/src/runtime/contrib/json/json_runtime.h
index 51ce2cffd7..5409078e85 100644
--- a/src/runtime/contrib/json/json_runtime.h
+++ b/src/runtime/contrib/json/json_runtime.h
@@ -59,7 +59,7 @@ class JSONRuntimeBase : public ModuleNode {
   const char* type_key() const override { return "json"; }  // May be overridden
 
   /*! \brief Get the property of the runtime module .*/
-  int GetPropertyMask() const {
+  int GetPropertyMask() const override {
+    return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable;
   }
 
diff --git a/tests/python/relax/test_transform_fuse_tir.py b/tests/python/relax/test_transform_fuse_tir.py
index 356e28d6e9..bdbd9be966 100644
--- a/tests/python/relax/test_transform_fuse_tir.py
+++ b/tests/python/relax/test_transform_fuse_tir.py
@@ -698,5 +698,100 @@ def test_skip_call_dps_packed():
     _check(Module, Module)
 
 
+def test_symbolic_shape_aware_fuse():
+    @I.ir_module
+    class Before:
+        @R.function
+        def fused_add_exp_squeeze(
+            x: R.Tensor(["n", "m"], "float32"), p0: R.Tensor([], "float32")
+        ) -> R.Tensor(["n", "m"], dtype="float32"):
+            R.func_attr({"Primitive": 1})
+            with R.dataflow():
+                lv0 = R.emit_te(topi.add, x, p0)
+                lv1 = R.emit_te(topi.exp, lv0)
+                gv = R.emit_te(topi.squeeze, lv1)
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main(x: R.Tensor(["n", "m"], "float32")) -> R.Tensor(["n", "m"], dtype="float32"):
+            cls = Before
+            with R.dataflow():
+                gv = cls.fused_add_exp_squeeze(x, R.const(1, "float32"))
+                R.output(gv)
+            return gv
+
+    def fused_add_exp_squeeze(x, p0):
+        return topi.squeeze(topi.exp(topi.add(x, p0)))
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor(["n", "m"], "float32")) -> R.Tensor(["n", "m"], dtype="float32"):
+            with R.dataflow():
+                gv = R.emit_te(fused_add_exp_squeeze, x, R.const(1, "float32"))
+                R.output(gv)
+            return gv
+
+    _check(Before, Expected)
+
+
+def test_symbolic_shape_aware_fuse_with_allocation():
+    def te_mean(x, axis):
+        return topi.divide(topi.sum(x, axis, keepdims=True), 4096)
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def fused_mean_add_tir_sqrt_divide_multiply(
+            x: R.Tensor((1, "n", 4096), dtype="float32"),
+            y: R.Tensor((1, "n", 4096), dtype="float32"),
+            rms_norm_weight: R.Tensor((4096,), dtype="float32"),
+        ) -> R.Tensor((1, "n", 4096), dtype="float32"):
+            R.func_attr({"Primitive": 1})
+            with R.dataflow():
+                lv0 = R.emit_te(te_mean, x, axis=2)
+                lv1 = R.emit_te(topi.add, lv0, lv0)
+                lv2 = R.emit_te(topi.sqrt, lv1)
+                lv3 = R.emit_te(topi.divide, y, lv2)
+                gv = R.emit_te(topi.multiply, rms_norm_weight, lv3)
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main(
+            x: R.Tensor((1, "n", 4096), dtype="float32"),
+            y: R.Tensor((1, "n", 4096), dtype="float32"),
+            rms_norm_weight: R.Tensor((4096,), dtype="float32"),
+        ) -> R.Tensor((1, "n", 4096), dtype="float32"):
+            cls = Before
+            with R.dataflow():
+                gv = cls.fused_mean_add_tir_sqrt_divide_multiply(x, y, rms_norm_weight)
+                R.output(gv)
+            return gv
+
+    def fused_mean_add_tir_sqrt_divide_multiply(x, y, rms_norm_weight):
+        lv0 = te_mean(x, axis=2)
+        lv1 = topi.add(lv0, lv0)
+        lv2 = topi.sqrt(lv1)
+        lv3 = topi.divide(y, lv2)
+        return topi.multiply(rms_norm_weight, lv3)
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((1, "n", 4096), dtype="float32"),
+            y: R.Tensor((1, "n", 4096), dtype="float32"),
+            rms_norm_weight: R.Tensor((4096,), dtype="float32"),
+        ) -> R.Tensor((1, "n", 4096), dtype="float32"):
+            with R.dataflow():
+                gv = R.emit_te(fused_mean_add_tir_sqrt_divide_multiply, x, y, rms_norm_weight)
+                R.output(gv)
+            return gv
+
+    _check(Before, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
