This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 9f4d0fa5d3 [Unity][Transform] Allow static Relax arguments to dynamic PrimFunc (#15883)
9f4d0fa5d3 is described below

commit 9f4d0fa5d37dfed3193dae29177c033435ef7130
Author: Eric Lunderberg <[email protected]>
AuthorDate: Mon Oct 9 08:14:25 2023 -0500

    [Unity][Transform] Allow static Relax arguments to dynamic PrimFunc (#15883)
    
    * [Unity][Transform] Allow static Relax arguments to dynamic PrimFunc
    
    Prior to this commit, the `relax.transform.FuseTIR` transform required
    that the shape arguments passed into a `PrimFunc` be structurally
    equivalent to the shapes of the parameters, and that any symbolic
    `tir.Var` be replaced only by another symbolic `tir.Var` in the fused
    function.
    
    This commit updates the `SymbolicMatcher` to instead extract a
    `Map<tir::Var, PrimExpr>`.  As a result, a Relax tensor with a
    statically-known shape can be passed into a TIR PrimFunc with a
    dynamic shape.  The resulting fused TIR function is expressed in
    terms of the statically-known shape and no longer contains the
    symbolic variable.
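
    A minimal sketch of the new behavior (the module and function names
    below are hypothetical; the APIs are the same ones exercised by the
    tests added in this commit):

        from tvm import relax
        from tvm.script import ir as I, relax as R, tir as T

        @I.ir_module
        class Module:
            @T.prim_func(private=True)
            def double(x_handle: T.handle, y_handle: T.handle):
                # Symbolic extent, with no explicit parameter defining it.
                n = T.int64()
                X = T.match_buffer(x_handle, [n], "float32")
                Y = T.match_buffer(y_handle, [n], "float32")
                for i in range(n):
                    with T.block("double"):
                        vi = T.axis.remap("S", [i])
                        Y[vi] = X[vi] + X[vi]

            @R.function(private=True)
            def fused(
                x: R.Tensor([64], dtype="float32"),
            ) -> R.Tensor([64], dtype="float32"):
                R.func_attr({"Primitive": 1})
                cls = Module
                with R.dataflow():
                    # The static shape [64] is matched against the dynamic
                    # buffer shape [n], recording the binding n -> 64.
                    gv = R.call_tir(
                        cls.double, [x], out_sinfo=R.Tensor([64], dtype="float32")
                    )
                    R.output(gv)
                return gv

            @R.function
            def main(
                x: R.Tensor([64], dtype="float32"),
            ) -> R.Tensor([64], dtype="float32"):
                cls = Module
                with R.dataflow():
                    gv = cls.fused(x)
                    R.output(gv)
                return gv

        # The fused PrimFunc is fully static: `n` is replaced by 64.
        mod = relax.transform.FuseTIR()(Module)

    Previously, a module like this would have failed to fuse, since the
    static extent 64 is not a `tir.Var`.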
---
 src/relax/transform/fuse_tir.cc               |  84 +++---
 tests/python/relax/test_transform_fuse_tir.py | 397 ++++++++++++++++++++++++++
 2 files changed, 440 insertions(+), 41 deletions(-)

diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc
index 98fce9215f..2fb3f1d8ce 100644
--- a/src/relax/transform/fuse_tir.cc
+++ b/src/relax/transform/fuse_tir.cc
@@ -39,31 +39,37 @@ namespace tir {
  */
 class SymbolicMatcher : ExprFunctor<bool(const PrimExpr& n, const PrimExpr& other)> {
  public:
-  explicit SymbolicMatcher(Map<tir::Var, tir::Var>* var_remap) : var_remap_(var_remap) {}
+  explicit SymbolicMatcher(Map<tir::Var, PrimExpr>* var_remap) : var_remap_(var_remap) {}
 
-  void Match(const Array<PrimExpr>& lhs, const Array<PrimExpr>& rhs) {
-    CHECK_EQ(lhs.size(), rhs.size());
-    for (size_t i = 0; i < lhs.size(); ++i) {
-      Match(lhs[i], rhs[i]);
+  void Match(const Array<PrimExpr>& params, const Array<PrimExpr>& args) {
+    CHECK_EQ(params.size(), args.size());
+    for (size_t i = 0; i < params.size(); ++i) {
+      Match(params[i], args[i]);
     }
   }
-  void Match(const PrimExpr& lhs, const PrimExpr& rhs) {
-    if (!VisitExpr(lhs, rhs)) {
-      LOG(FATAL) << "Failed to match PrimExpr " << lhs << " with " << rhs;
+  void Match(const PrimExpr& param, const PrimExpr& arg) {
+    if (!VisitExpr(param, arg)) {
+      LOG(FATAL) << "Failed to match PrimExpr " << param << " with " << arg;
     }
   }
 
  private:
-  bool VisitExpr(const PrimExpr& n, const PrimExpr& other) {
-    bool matched = n.same_as(other) || ((n->type_index() == other->type_index()) &&
-                                        n.dtype().code() == other.dtype().code());
-    return matched && ExprFunctor::VisitExpr(n, other);
+  bool VisitExpr(const PrimExpr& node, const PrimExpr& other) {
+    if (node.same_as(other)) {
+      return true;
+    } else if (node.dtype().code() != other.dtype().code()) {
+      return false;
+    } else {
+      return ExprFunctor::VisitExpr(node, other);
+    }
   }
 
 #define TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(OpName)               \
   bool VisitExpr_(const OpName* op, const PrimExpr& other) {     \
     const auto* rhs = other.as<OpName>();                        \
-    ICHECK(rhs);                                                 \
+    if (!rhs) {                                                  \
+      return false;                                              \
+    }                                                            \
     return VisitExpr(op->a, rhs->a) && VisitExpr(op->b, rhs->b); \
   }
 
@@ -87,34 +93,35 @@ class SymbolicMatcher : ExprFunctor<bool(const PrimExpr& n, const PrimExpr& othe
 
   bool VisitExpr_(const IntImmNode* op, const PrimExpr& other) {
     const auto* rhs = other.as<IntImmNode>();
-    return op->value == rhs->value;
+    return rhs && (op->value == rhs->value);
   }
 
   bool VisitExpr_(const FloatImmNode* op, const PrimExpr& other) {
     const auto* rhs = other.as<FloatImmNode>();
-    return op->value == rhs->value;
+    return rhs && (op->value == rhs->value);
   }
 
   bool VisitExpr_(const CastNode* op, const PrimExpr& other) {
     const auto* rhs = other.as<CastNode>();
-    return VisitExpr(op->value, rhs->value);
+    return rhs && VisitExpr(op->value, rhs->value);
   }
 
-  bool VisitExpr_(const VarNode* op, const PrimExpr& other) {
-    const auto* rhs = other.as<VarNode>();
+  bool VisitExpr_(const VarNode* op, const PrimExpr& rhs) {
     auto lhs = GetRef<Var>(op);
-    if (lhs.same_as(other)) return true;
-    if (op->dtype.code() != rhs->dtype.code()) return false;
-    auto it = var_remap_->find(lhs);
-    if (it == var_remap_->end()) {
-      var_remap_->Set(lhs, GetRef<Var>(rhs));
+
+    if (lhs.same_as(rhs)) {
       return true;
+    } else if (op->dtype.code() != rhs->dtype.code()) {
+      return false;
+    } else if (auto it = var_remap_->find(lhs); it != var_remap_->end()) {
+      return VisitExpr((*it).second, rhs);
     } else {
-      return (*it).second.same_as(other);
+      var_remap_->Set(lhs, rhs);
+      return true;
     }
   }
 
-  Map<tir::Var, tir::Var>* var_remap_;
+  Map<tir::Var, PrimExpr>* var_remap_;
 };
 
 /*!
@@ -123,7 +130,7 @@ class SymbolicMatcher : ExprFunctor<bool(const PrimExpr& n, const PrimExpr& othe
 class FuseTIRBufferSubstitutor : private StmtExprMutator {
  public:
   explicit FuseTIRBufferSubstitutor(const Map<Buffer, Buffer>& buffer_map,
-                                    const Map<Var, Var>& var_map) {
+                                    const Map<Var, PrimExpr>& var_map) {
     buffer_remap_ = buffer_map;
     var_remap_ = var_map;
     for (const auto& [src, tgt] : buffer_map) {
@@ -156,8 +163,7 @@ class FuseTIRBufferSubstitutor : private StmtExprMutator {
 
  private:
   PrimExpr VisitExpr_(const VarNode* _op) final {
-    auto it = var_remap_.find(GetRef<Var>(_op));
-    if (it != var_remap_.end()) {
+    if (auto it = var_remap_.find(GetRef<Var>(_op)); it != var_remap_.end()) {
       return (*it).second;
     } else {
       return GetRef<PrimExpr>(_op);
@@ -246,7 +252,7 @@ class FuseTIRBufferSubstitutor : private StmtExprMutator {
   /*! \brief Mapping from src buffer to tgt buffer. */
   Map<tir::Buffer, tir::Buffer> buffer_remap_;
   /*! \brief Mapping from src tir var to tgt var. */
-  Map<tir::Var, tir::Var> var_remap_;
+  Map<tir::Var, PrimExpr> var_remap_;
 
   Array<tir::BufferRegion> UnionAccessRegion(const Array<BufferRegion>& regions) const {
     // For now we only allow Buffer access the same elements.
@@ -474,6 +480,7 @@ class FusedTIRConstructor : public ExprVisitor {
     // Step 5. Map input arguments to buffer
     MapInputBuffer(prim_func, call->args[1]);
     const Array<Array<PrimExpr>>& output_buffer_shapes = GetCallTIROutputShapes(call);
+
+    AllocateIntermediateBuffer(GetRef<Expr>(call), prim_func, output_buffer_shapes);
 
     // Step 6. Update tir_vars
@@ -481,17 +488,12 @@ class FusedTIRConstructor : public ExprVisitor {
       ICHECK(call->args.size() == 3);
       const Expr& tir_vars = call->args[2];
       if (const auto* shape_expr = tir_vars.as<ShapeExprNode>()) {
-        const Array<tir::Var> vars = shape_expr->values.Map([](const PrimExpr& expr) {
-          if (!expr->IsInstance<tir::VarNode>()) {
-            LOG(FATAL) << "Expected a single var, but got: " << expr;
-          }
-          return Downcast<tir::Var>(expr);
-        });
+        const auto& args = shape_expr->values;
         size_t num_params = prim_func->params.size();
-        ICHECK_GE(num_params, vars.size());
-        for (size_t i = 0; i < vars.size(); ++i) {
-          const tir::Var& param = prim_func->params[num_params - vars.size() + i];
-          func_info_.symbolic_var_matcher.Match(param, vars[i]);
+        ICHECK_GE(num_params, args.size());
+        for (size_t i = 0; i < args.size(); ++i) {
+          const tir::Var& param = prim_func->params[num_params - args.size() + i];
+          func_info_.symbolic_var_matcher.Match(param, args[i]);
         }
       } else {
         LOG(FATAL) << "TIR vars should be a shape expr, but got: " << 
tir_vars->GetTypeKey();
@@ -805,8 +807,8 @@ class FusedTIRConstructor : public ExprVisitor {
      * function
      */
     Map<tir::Buffer, tir::Buffer> buffer_subst_map;
-    /*! \brief The map from symbolic var to its corresponding var in the fused function */
-    Map<tir::Var, tir::Var> symbolic_var_remap;
+    /*! \brief The map from symbolic var to its value in the fused function */
+    Map<tir::Var, PrimExpr> symbolic_var_remap;
     /*! \brief The `buffer_map` in the fused function*/
     Map<tir::Var, tir::Buffer> buffer_map;
     /*! \brief The output buffers in the function buffer_map*/
diff --git a/tests/python/relax/test_transform_fuse_tir.py b/tests/python/relax/test_transform_fuse_tir.py
index 6932b1c89d..556b673e61 100644
--- a/tests/python/relax/test_transform_fuse_tir.py
+++ b/tests/python/relax/test_transform_fuse_tir.py
@@ -1206,5 +1206,402 @@ def test_extern_func():
     _check(mod, mod)
 
 
+def test_symbolic_var_in_buffer_shape():
+    """A PrimFunc may have dynamic buffer shapes
+
+    Symbolic variables in a PrimFunc may be present in the buffer
+    shape without a corresponding parameter.  These symbolic variables
+    are inferred from the buffer's shape.  (Or, at runtime, they are
+    typically determined from the DLTensor's known shape.)
+    """
+
+    @I.ir_module
+    class Before:
+        @T.prim_func(private=True)
+        def foo(
+            X_handle: T.handle,
+            Y: T.Buffer((T.int64(2048), T.int64(128)), "float32"),
+            rotary_handle: T.handle,
+            m: T.int64,
+        ):
+            sequence_length = T.int64()
+
+            X = T.match_buffer(
+                X_handle, [T.int64(1), sequence_length, T.int64(32), T.int64(128)], "float32"
+            )
+            rotary = T.match_buffer(
+                rotary_handle, [T.int64(1), sequence_length, T.int64(32), T.int64(128)], "float32"
+            )
+
+            for i0, i1, i2, i3 in T.grid(T.int64(1), sequence_length, T.int64(32), T.int64(128)):
+                with T.block("rotary"):
+                    v0, v1, v2, v3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+                    rotary[v0, v1, v2, v3] = Y[m + v1 - 1, v3] * X[v0, v1, v2, v3]
+
+        @R.function
+        def fused(
+            x: R.Tensor((1, "sequence_length", 32, 128), dtype="float32"),
+            y: R.Tensor((2048, 128), dtype="float32"),
+            len: R.Shape(["m"]),
+        ) -> R.Tensor((1, "sequence_length", 32, 128), dtype="float32"):
+            R.func_attr({"Primitive": 1})
+            sequence_length = T.int64()
+            m = T.int64()
+            cls = Before
+            with R.dataflow():
+                lv1 = R.emit_te(topi.add, x, x)
+                gv = R.call_tir(
+                    cls.foo,
+                    [lv1, y],
+                    out_sinfo=R.Tensor((1, sequence_length, 32, 128), dtype="float32"),
+                    tir_vars=R.shape([m]),
+                )
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main(
+            x: R.Tensor((1, "sequence_length", 32, 128), dtype="float32"),
+            y: R.Tensor((2048, 128), dtype="float32"),
+            len: R.Shape(["m"]),
+        ) -> R.Tensor((1, "sequence_length", 32, 128), dtype="float32"):
+            cls = Before
+            with R.dataflow():
+                gv = cls.fused(x, y, len)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func(private=True)
+        def fused(
+            X_handle: T.handle,
+            Y: T.Buffer((T.int64(2048), T.int64(128)), "float32"),
+            rotary_handle: T.handle,
+            m: T.int64,
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+
+            sequence_length = T.int64()
+
+            X = T.match_buffer(
+                X_handle, [T.int64(1), sequence_length, T.int64(32), T.int64(128)], "float32"
+            )
+            rotary = T.match_buffer(
+                rotary_handle, [T.int64(1), sequence_length, T.int64(32), T.int64(128)], "float32"
+            )
+
+            T_add = T.alloc_buffer((T.int64(1), sequence_length, T.int64(32), T.int64(128)))
+            for ax0, ax1, ax2, ax3 in T.grid(
+                T.int64(1), sequence_length, T.int64(32), T.int64(128)
+            ):
+                with T.block("T_add"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    T_add[v_ax0, v_ax1, v_ax2, v_ax3] = (
+                        X[v_ax0, v_ax1, v_ax2, v_ax3] + X[v_ax0, v_ax1, v_ax2, v_ax3]
+                    )
+            for i0, i1, i2, i3 in T.grid(T.int64(1), sequence_length, T.int64(32), T.int64(128)):
+                with T.block("rotary"):
+                    v0, v1, v2, v3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+                    rotary[v0, v1, v2, v3] = Y[m + v1 - T.int64(1), v3] * T_add[v0, v1, v2, v3]
+
+        @R.function
+        def main(
+            x: R.Tensor((1, "sequence_length", 32, 128), dtype="float32"),
+            y: R.Tensor((2048, 128), dtype="float32"),
+            len: R.Shape(["m"]),
+        ) -> R.Tensor((1, "sequence_length", 32, 128), dtype="float32"):
+            sequence_length = T.int64()
+            m = T.int64()
+            cls = Expected
+            with R.dataflow():
+                gv = R.call_tir(
+                    cls.fused,
+                    (x, y),
+                    out_sinfo=R.Tensor([1, sequence_length, 32, 128], "float32"),
+                    tir_vars=R.shape([m]),
+                )
+                R.output(gv)
+            return gv
+
+    _check(Before, Expected)
+
+
+def test_symbolic_var_called_with_static_shape():
+    """A dynamic PrimFunc may be called with a static shape"""
+
+    @I.ir_module
+    class Before:
+        @T.prim_func(private=True)
+        def sum_1d(
+            X_handle: T.handle,
+            Y: T.Buffer([T.int64(1)], "float32"),
+        ):
+            num_elements = T.int64()
+
+            X = T.match_buffer(X_handle, [num_elements], "float32")
+
+            for i in range(num_elements):
+                with T.block("sum"):
+                    vi = T.axis.remap("R", [i])
+                    with T.init():
+                        Y[0] = 0.0
+                    Y[0] = Y[0] + X[vi]
+
+        @R.function(private=True)
+        def fused(
+            x: R.Tensor([64], dtype="float32"),
+        ) -> R.Tensor([1], dtype="float32"):
+            R.func_attr({"Primitive": 1})
+            cls = Before
+            with R.dataflow():
+                gv = R.call_tir(
+                    cls.sum_1d,
+                    [x],
+                    out_sinfo=R.Tensor([1], dtype="float32"),
+                )
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main(
+            x: R.Tensor([64], dtype="float32"),
+        ) -> R.Tensor([1], dtype="float32"):
+            cls = Before
+            with R.dataflow():
+                gv = cls.fused(x)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func(private=True)
+        def fused(
+            X: T.Buffer([T.int64(64)], "float32"),
+            Y: T.Buffer([T.int64(1)], "float32"),
+        ):
+            T.func_attr({"tir.noalias": True})
+
+            for i in range(T.int64(64)):
+                with T.block("sum"):
+                    vi = T.axis.remap("R", [i])
+                    with T.init():
+                        Y[0] = 0.0
+                    Y[0] = Y[0] + X[vi]
+
+        @R.function
+        def main(
+            x: R.Tensor([64], dtype="float32"),
+        ) -> R.Tensor([1], dtype="float32"):
+            cls = Expected
+            with R.dataflow():
+                gv = R.call_tir(cls.fused, (x,), out_sinfo=R.Tensor((1,), dtype="float32"))
+                R.output(gv)
+            return gv
+
+    _check(Before, Expected)
+
+
+def test_symbolic_var_called_with_multiple_static_shapes():
+    """A dynamic PrimFunc may be called with different shapes each time"""
+
+    @I.ir_module
+    class Before:
+        @T.prim_func(private=True)
+        def sum_1d(
+            X_handle: T.handle,
+            Sum: T.Buffer([T.int64(1)], "float32"),
+        ):
+            num_elements = T.int64()
+
+            X = T.match_buffer(X_handle, [num_elements], "float32")
+
+            for i in range(num_elements):
+                with T.block("sum"):
+                    vi = T.axis.remap("R", [i])
+                    with T.init():
+                        Sum[0] = 0.0
+                    Sum[0] = Sum[0] + X[vi]
+
+        @T.prim_func(private=True)
+        def sum_scalar(
+            X: T.Buffer([T.int64(1)], "float32"),
+            Y: T.Buffer([T.int64(1)], "float32"),
+            Sum: T.Buffer([T.int64(1)], "float32"),
+        ):
+            for i in range(T.int64(1)):
+                with T.block("Out"):
+                    vi = T.axis.remap("S", [i])
+                    Sum[vi] = X[vi] + Y[vi]
+
+        @R.function(private=True)
+        def fused(
+            x: R.Tensor([64], dtype="float32"),
+            y: R.Tensor([16], dtype="float32"),
+        ) -> R.Tensor([1], dtype="float32"):
+            R.func_attr({"Primitive": 1})
+            cls = Before
+            with R.dataflow():
+                x_sum = R.call_tir(
+                    cls.sum_1d,
+                    [x],
+                    out_sinfo=R.Tensor([1], dtype="float32"),
+                )
+                y_sum = R.call_tir(
+                    cls.sum_1d,
+                    [y],
+                    out_sinfo=R.Tensor([1], dtype="float32"),
+                )
+                gv = R.call_tir(
+                    cls.sum_scalar,
+                    [x_sum, y_sum],
+                    out_sinfo=R.Tensor([1], dtype="float32"),
+                )
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main(
+            x: R.Tensor([64], dtype="float32"),
+            y: R.Tensor([16], dtype="float32"),
+        ) -> R.Tensor([1], dtype="float32"):
+            cls = Before
+            with R.dataflow():
+                gv = cls.fused(x, y)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func(private=True)
+        def fused(
+            X: T.Buffer([T.int64(64)], "float32"),
+            Y: T.Buffer([T.int64(16)], "float32"),
+            Out: T.Buffer([T.int64(1)], "float32"),
+        ):
+            T.func_attr({"tir.noalias": True})
+
+            XSum = T.alloc_buffer([T.int64(1)], "float32")
+            YSum = T.alloc_buffer([T.int64(1)], "float32")
+
+            for i in range(T.int64(64)):
+                with T.block("XSum"):
+                    vi = T.axis.remap("R", [i])
+                    with T.init():
+                        XSum[0] = 0.0
+                    XSum[0] = XSum[0] + X[vi]
+
+            for i in range(T.int64(16)):
+                with T.block("YSum"):
+                    vi = T.axis.remap("R", [i])
+                    with T.init():
+                        YSum[0] = 0.0
+                    YSum[0] = YSum[0] + Y[vi]
+
+            for i in range(T.int64(1)):
+                with T.block("Out"):
+                    vi = T.axis.remap("S", [i])
+                    Out[vi] = XSum[vi] + YSum[vi]
+
+        @R.function
+        def main(
+            x: R.Tensor([64], dtype="float32"),
+            y: R.Tensor([16], dtype="float32"),
+        ) -> R.Tensor([1], dtype="float32"):
+            cls = Expected
+            with R.dataflow():
+                gv = R.call_tir(cls.fused, (x, y), out_sinfo=R.Tensor((1,), dtype="float32"))
+                R.output(gv)
+            return gv
+
+    _check(Before, Expected)
+
+
+def test_symbolic_var_called_with_static_argument():
+    """A dynamic PrimFunc may accept a static argument
+
+    The `tir_vars` parameter in `R.call_tir` contains definitions for
+    all TIR variables explicitly listed in the function signature,
+    providing the TIR expression to be passed as the argument for
+    each parameter.
+
+    This test is identical to the earlier test named
+    "test_symbolic_var_called_with_static_shape", except that `sum_1d`
+    accepts the symbolic extent `num_elements` as an explicit
+    parameter.
+    """
+
+    @I.ir_module
+    class Before:
+        @T.prim_func(private=True)
+        def sum_1d(
+            X_handle: T.handle,
+            Y: T.Buffer([T.int64(1)], "float32"),
+            num_elements: T.int64,
+        ):
+
+            X = T.match_buffer(X_handle, [num_elements], "float32")
+
+            for i in range(num_elements):
+                with T.block("sum"):
+                    vi = T.axis.remap("R", [i])
+                    with T.init():
+                        Y[0] = 0.0
+                    Y[0] = Y[0] + X[vi]
+
+        @R.function(private=True)
+        def fused(
+            x: R.Tensor([64], dtype="float32"),
+        ) -> R.Tensor([1], dtype="float32"):
+            R.func_attr({"Primitive": 1})
+            cls = Before
+            with R.dataflow():
+                gv = R.call_tir(
+                    cls.sum_1d,
+                    [x],
+                    out_sinfo=R.Tensor([1], dtype="float32"),
+                    tir_vars=R.shape([64]),
+                )
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main(
+            x: R.Tensor([64], dtype="float32"),
+        ) -> R.Tensor([1], dtype="float32"):
+            cls = Before
+            with R.dataflow():
+                gv = cls.fused(x)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func(private=True)
+        def fused(
+            X: T.Buffer([T.int64(64)], "float32"),
+            Y: T.Buffer([T.int64(1)], "float32"),
+        ):
+            T.func_attr({"tir.noalias": True})
+
+            for i in range(T.int64(64)):
+                with T.block("sum"):
+                    vi = T.axis.remap("R", [i])
+                    with T.init():
+                        Y[0] = 0.0
+                    Y[0] = Y[0] + X[vi]
+
+        @R.function
+        def main(
+            x: R.Tensor([64], dtype="float32"),
+        ) -> R.Tensor([1], dtype="float32"):
+            cls = Expected
+            with R.dataflow():
+                gv = R.call_tir(cls.fused, (x,), out_sinfo=R.Tensor((1,), dtype="float32"))
+                R.output(gv)
+            return gv
+
+    _check(Before, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

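The new behavior can be exercised locally (assuming a built checkout of the
unity branch) by running the tests added in this commit:

    pytest tests/python/relax/test_transform_fuse_tir.py -k symbolic_var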