This is an automated email from the ASF dual-hosted git repository.
lunderberg pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 9f4d0fa5d3 [Unity][Transform] Allow static Relax arguments to dynamic PrimFunc (#15883)
9f4d0fa5d3 is described below
commit 9f4d0fa5d37dfed3193dae29177c033435ef7130
Author: Eric Lunderberg <[email protected]>
AuthorDate: Mon Oct 9 08:14:25 2023 -0500
[Unity][Transform] Allow static Relax arguments to dynamic PrimFunc (#15883)
* [Unity][Transform] Allow static Relax arguments to dynamic PrimFunc
Prior to this commit, the `relax.transform.FuseTIR` transform required
that the shape arguments passed into a `PrimFunc` be structurally
equivalent to the shapes of the parameters, and that any replacement
for a symbolic `tir.Var` itself be a symbolic `tir.Var` in the fused
function.
This commit updates the `SymbolicMatcher` to instead extract a
`Map<tir::Var, PrimExpr>`. As a result, a Relax tensor with a
statically-known shape can be passed into a TIR PrimFunc with a
dynamic shape. The resulting fused TIR function is written in terms
of the statically-known shape and no longer contains the symbolic
variable.
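As a rough sketch of the newly-supported pattern (illustrative names
only; the new `test_symbolic_var_called_with_static_shape` unit test
below is the authoritative version):

    import tvm
    from tvm import relax
    from tvm.script import ir as I, relax as R, tir as T

    @I.ir_module
    class Module:
        @T.prim_func(private=True)
        def sum_1d(x_handle: T.handle, out: T.Buffer([T.int64(1)], "float32")):
            # Dynamic PrimFunc: the buffer shape uses the symbolic var `n`.
            n = T.int64()
            x = T.match_buffer(x_handle, [n], "float32")
            for i in range(n):
                with T.block("sum"):
                    vi = T.axis.remap("R", [i])
                    with T.init():
                        out[0] = 0.0
                    out[0] = out[0] + x[vi]

        @R.function(private=True)
        def fused(x: R.Tensor([64], dtype="float32")) -> R.Tensor([1], dtype="float32"):
            # A statically-shaped Relax tensor is passed to the dynamic PrimFunc.
            R.func_attr({"Primitive": 1})
            cls = Module
            with R.dataflow():
                gv = R.call_tir(cls.sum_1d, [x], out_sinfo=R.Tensor([1], dtype="float32"))
                R.output(gv)
            return gv

        @R.function
        def main(x: R.Tensor([64], dtype="float32")) -> R.Tensor([1], dtype="float32"):
            cls = Module
            with R.dataflow():
                gv = cls.fused(x)
                R.output(gv)
            return gv

    # Previously, FuseTIR raised "Failed to match PrimExpr" when matching
    # the static extent 64 against the symbolic var `n`; it now records
    # `n -> T.int64(64)`, so the fused PrimFunc contains only static shapes.
    mod = relax.transform.FuseTIR()(Module)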
---
src/relax/transform/fuse_tir.cc | 84 +++---
tests/python/relax/test_transform_fuse_tir.py | 397 ++++++++++++++++++++++++++
2 files changed, 440 insertions(+), 41 deletions(-)
diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc
index 98fce9215f..2fb3f1d8ce 100644
--- a/src/relax/transform/fuse_tir.cc
+++ b/src/relax/transform/fuse_tir.cc
@@ -39,31 +39,37 @@ namespace tir {
*/
class SymbolicMatcher : ExprFunctor<bool(const PrimExpr& n, const PrimExpr& other)> {
public:
- explicit SymbolicMatcher(Map<tir::Var, tir::Var>* var_remap) : var_remap_(var_remap) {}
+ explicit SymbolicMatcher(Map<tir::Var, PrimExpr>* var_remap) : var_remap_(var_remap) {}
- void Match(const Array<PrimExpr>& lhs, const Array<PrimExpr>& rhs) {
- CHECK_EQ(lhs.size(), rhs.size());
- for (size_t i = 0; i < lhs.size(); ++i) {
- Match(lhs[i], rhs[i]);
+ void Match(const Array<PrimExpr>& params, const Array<PrimExpr>& args) {
+ CHECK_EQ(params.size(), args.size());
+ for (size_t i = 0; i < params.size(); ++i) {
+ Match(params[i], args[i]);
}
}
- void Match(const PrimExpr& lhs, const PrimExpr& rhs) {
- if (!VisitExpr(lhs, rhs)) {
- LOG(FATAL) << "Failed to match PrimExpr " << lhs << " with " << rhs;
+ void Match(const PrimExpr& param, const PrimExpr& arg) {
+ if (!VisitExpr(param, arg)) {
+ LOG(FATAL) << "Failed to match PrimExpr " << param << " with " << arg;
}
}
private:
- bool VisitExpr(const PrimExpr& n, const PrimExpr& other) {
- bool matched = n.same_as(other) || ((n->type_index() == other->type_index()) &&
- n.dtype().code() == other.dtype().code());
- return matched && ExprFunctor::VisitExpr(n, other);
+ bool VisitExpr(const PrimExpr& node, const PrimExpr& other) {
+ if (node.same_as(other)) {
+ return true;
+ } else if (node.dtype().code() != other.dtype().code()) {
+ return false;
+ } else {
+ return ExprFunctor::VisitExpr(node, other);
+ }
}
#define TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(OpName) \
bool VisitExpr_(const OpName* op, const PrimExpr& other) { \
const auto* rhs = other.as<OpName>(); \
- ICHECK(rhs); \
+ if (!rhs) { \
+ return false; \
+ } \
return VisitExpr(op->a, rhs->a) && VisitExpr(op->b, rhs->b); \
}
@@ -87,34 +93,35 @@ class SymbolicMatcher : ExprFunctor<bool(const PrimExpr& n, const PrimExpr& othe
bool VisitExpr_(const IntImmNode* op, const PrimExpr& other) {
const auto* rhs = other.as<IntImmNode>();
- return op->value == rhs->value;
+ return rhs && (op->value == rhs->value);
}
bool VisitExpr_(const FloatImmNode* op, const PrimExpr& other) {
const auto* rhs = other.as<FloatImmNode>();
- return op->value == rhs->value;
+ return rhs && (op->value == rhs->value);
}
bool VisitExpr_(const CastNode* op, const PrimExpr& other) {
const auto* rhs = other.as<CastNode>();
- return VisitExpr(op->value, rhs->value);
+ return rhs && VisitExpr(op->value, rhs->value);
}
- bool VisitExpr_(const VarNode* op, const PrimExpr& other) {
- const auto* rhs = other.as<VarNode>();
+ bool VisitExpr_(const VarNode* op, const PrimExpr& rhs) {
auto lhs = GetRef<Var>(op);
- if (lhs.same_as(other)) return true;
- if (op->dtype.code() != rhs->dtype.code()) return false;
- auto it = var_remap_->find(lhs);
- if (it == var_remap_->end()) {
- var_remap_->Set(lhs, GetRef<Var>(rhs));
+
+ if (lhs.same_as(rhs)) {
return true;
+ } else if (op->dtype.code() != rhs->dtype.code()) {
+ return false;
+ } else if (auto it = var_remap_->find(lhs); it != var_remap_->end()) {
+ return VisitExpr((*it).second, rhs);
} else {
- return (*it).second.same_as(other);
+ var_remap_->Set(lhs, rhs);
+ return true;
}
}
- Map<tir::Var, tir::Var>* var_remap_;
+ Map<tir::Var, PrimExpr>* var_remap_;
};
/*!
@@ -123,7 +130,7 @@ class SymbolicMatcher : ExprFunctor<bool(const PrimExpr& n, const PrimExpr& othe
class FuseTIRBufferSubstitutor : private StmtExprMutator {
public:
explicit FuseTIRBufferSubstitutor(const Map<Buffer, Buffer>& buffer_map,
- const Map<Var, Var>& var_map) {
+ const Map<Var, PrimExpr>& var_map) {
buffer_remap_ = buffer_map;
var_remap_ = var_map;
for (const auto& [src, tgt] : buffer_map) {
@@ -156,8 +163,7 @@ class FuseTIRBufferSubstitutor : private StmtExprMutator {
private:
PrimExpr VisitExpr_(const VarNode* _op) final {
- auto it = var_remap_.find(GetRef<Var>(_op));
- if (it != var_remap_.end()) {
+ if (auto it = var_remap_.find(GetRef<Var>(_op)); it != var_remap_.end()) {
return (*it).second;
} else {
return GetRef<PrimExpr>(_op);
@@ -246,7 +252,7 @@ class FuseTIRBufferSubstitutor : private StmtExprMutator {
/*! \brief Mapping from src buffer to tgt buffer. */
Map<tir::Buffer, tir::Buffer> buffer_remap_;
/*! \brief Mapping from src tir var to tgt var. */
- Map<tir::Var, tir::Var> var_remap_;
+ Map<tir::Var, PrimExpr> var_remap_;
Array<tir::BufferRegion> UnionAccessRegion(const Array<BufferRegion>& regions) const {
// For now we only allow Buffer accesses to the same elements.
@@ -474,6 +480,7 @@ class FusedTIRConstructor : public ExprVisitor {
// Step 5. Map input arguments to buffer
MapInputBuffer(prim_func, call->args[1]);
const Array<Array<PrimExpr>>& output_buffer_shapes = GetCallTIROutputShapes(call);
+
AllocateIntermediateBuffer(GetRef<Expr>(call), prim_func, output_buffer_shapes);
// Step 6. Update tir_vars
@@ -481,17 +488,12 @@ class FusedTIRConstructor : public ExprVisitor {
ICHECK(call->args.size() == 3);
const Expr& tir_vars = call->args[2];
if (const auto* shape_expr = tir_vars.as<ShapeExprNode>()) {
- const Array<tir::Var> vars = shape_expr->values.Map([](const PrimExpr& expr) {
- if (!expr->IsInstance<tir::VarNode>()) {
- LOG(FATAL) << "Expected a single var, but got: " << expr;
- }
- return Downcast<tir::Var>(expr);
- });
+ const auto& args = shape_expr->values;
size_t num_params = prim_func->params.size();
- ICHECK_GE(num_params, vars.size());
- for (size_t i = 0; i < vars.size(); ++i) {
- const tir::Var& param = prim_func->params[num_params - vars.size() + i];
- func_info_.symbolic_var_matcher.Match(param, vars[i]);
+ ICHECK_GE(num_params, args.size());
+ for (size_t i = 0; i < args.size(); ++i) {
+ const tir::Var& param = prim_func->params[num_params - args.size() + i];
+ func_info_.symbolic_var_matcher.Match(param, args[i]);
}
} else {
LOG(FATAL) << "TIR vars should be a shape expr, but got: " <<
tir_vars->GetTypeKey();
@@ -805,8 +807,8 @@ class FusedTIRConstructor : public ExprVisitor {
* function
*/
Map<tir::Buffer, tir::Buffer> buffer_subst_map;
- /*! \brief The map from symbolic var to its corresponding var in the fused function */
- Map<tir::Var, tir::Var> symbolic_var_remap;
+ /*! \brief The map from symbolic var to its value in the fused function */
+ Map<tir::Var, PrimExpr> symbolic_var_remap;
/*! \brief The `buffer_map` in the fused function*/
Map<tir::Var, tir::Buffer> buffer_map;
/*! \brief The output buffers in the function buffer_map*/
diff --git a/tests/python/relax/test_transform_fuse_tir.py b/tests/python/relax/test_transform_fuse_tir.py
index 6932b1c89d..556b673e61 100644
--- a/tests/python/relax/test_transform_fuse_tir.py
+++ b/tests/python/relax/test_transform_fuse_tir.py
@@ -1206,5 +1206,402 @@ def test_extern_func():
_check(mod, mod)
+def test_symbolic_var_in_buffer_shape():
+ """A PrimFunc may have dynamic buffer shapes
+
+ Symbolic variables in a PrimFunc may be present in the buffer
+ shape without a corresponding parameter. These symbolic variables
+ are inferred from the buffer's shape. (Or, at runtime, they are
+ typically determined from the DLTensor's known shape.)
+ """
+
+ @I.ir_module
+ class Before:
+ @T.prim_func(private=True)
+ def foo(
+ X_handle: T.handle,
+ Y: T.Buffer((T.int64(2048), T.int64(128)), "float32"),
+ rotary_handle: T.handle,
+ m: T.int64,
+ ):
+ sequence_length = T.int64()
+
+ X = T.match_buffer(
+ X_handle, [T.int64(1), sequence_length, T.int64(32), T.int64(128)], "float32"
+ )
+ rotary = T.match_buffer(
+ rotary_handle, [T.int64(1), sequence_length, T.int64(32), T.int64(128)], "float32"
+ )
+
+ for i0, i1, i2, i3 in T.grid(T.int64(1), sequence_length, T.int64(32), T.int64(128)):
+ with T.block("rotary"):
+ v0, v1, v2, v3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+ rotary[v0, v1, v2, v3] = Y[m + v1 - 1, v3] * X[v0, v1, v2, v3]
+
+ @R.function
+ def fused(
+ x: R.Tensor((1, "sequence_length", 32, 128), dtype="float32"),
+ y: R.Tensor((2048, 128), dtype="float32"),
+ len: R.Shape(["m"]),
+ ) -> R.Tensor((1, "sequence_length", 32, 128), dtype="float32"):
+ R.func_attr({"Primitive": 1})
+ sequence_length = T.int64()
+ m = T.int64()
+ cls = Before
+ with R.dataflow():
+ lv1 = R.emit_te(topi.add, x, x)
+ gv = R.call_tir(
+ cls.foo,
+ [lv1, y],
+ out_sinfo=R.Tensor((1, sequence_length, 32, 128), dtype="float32"),
+ tir_vars=R.shape([m]),
+ )
+ R.output(gv)
+ return gv
+
+ @R.function
+ def main(
+ x: R.Tensor((1, "sequence_length", 32, 128), dtype="float32"),
+ y: R.Tensor((2048, 128), dtype="float32"),
+ len: R.Shape(["m"]),
+ ) -> R.Tensor((1, "sequence_length", 32, 128), dtype="float32"):
+ cls = Before
+ with R.dataflow():
+ gv = cls.fused(x, y, len)
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func(private=True)
+ def fused(
+ X_handle: T.handle,
+ Y: T.Buffer((T.int64(2048), T.int64(128)), "float32"),
+ rotary_handle: T.handle,
+ m: T.int64,
+ ):
+ T.func_attr({"tir.noalias": T.bool(True)})
+
+ sequence_length = T.int64()
+
+ X = T.match_buffer(
+ X_handle, [T.int64(1), sequence_length, T.int64(32), T.int64(128)], "float32"
+ )
+ rotary = T.match_buffer(
+ rotary_handle, [T.int64(1), sequence_length, T.int64(32), T.int64(128)], "float32"
+ )
+
+ T_add = T.alloc_buffer((T.int64(1), sequence_length, T.int64(32), T.int64(128)))
+ for ax0, ax1, ax2, ax3 in T.grid(
+ T.int64(1), sequence_length, T.int64(32), T.int64(128)
+ ):
+ with T.block("T_add"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+ T_add[v_ax0, v_ax1, v_ax2, v_ax3] = (
+ X[v_ax0, v_ax1, v_ax2, v_ax3] + X[v_ax0, v_ax1, v_ax2, v_ax3]
+ )
+ for i0, i1, i2, i3 in T.grid(T.int64(1), sequence_length, T.int64(32), T.int64(128)):
+ with T.block("rotary"):
+ v0, v1, v2, v3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+ rotary[v0, v1, v2, v3] = Y[m + v1 - T.int64(1), v3] * T_add[v0, v1, v2, v3]
+
+ @R.function
+ def main(
+ x: R.Tensor((1, "sequence_length", 32, 128), dtype="float32"),
+ y: R.Tensor((2048, 128), dtype="float32"),
+ len: R.Shape(["m"]),
+ ) -> R.Tensor((1, "sequence_length", 32, 128), dtype="float32"):
+ sequence_length = T.int64()
+ m = T.int64()
+ cls = Expected
+ with R.dataflow():
+ gv = R.call_tir(
+ cls.fused,
+ (x, y),
+ out_sinfo=R.Tensor([1, sequence_length, 32, 128], "float32"),
+ tir_vars=R.shape([m]),
+ )
+ R.output(gv)
+ return gv
+
+ _check(Before, Expected)
+
+
+def test_symbolic_var_called_with_static_shape():
+ """A dynamic PrimFunc may be called with a static shape"""
+
+ @I.ir_module
+ class Before:
+ @T.prim_func(private=True)
+ def sum_1d(
+ X_handle: T.handle,
+ Y: T.Buffer([T.int64(1)], "float32"),
+ ):
+ num_elements = T.int64()
+
+ X = T.match_buffer(X_handle, [num_elements], "float32")
+
+ for i in range(num_elements):
+ with T.block("sum"):
+ vi = T.axis.remap("R", [i])
+ with T.init():
+ Y[0] = 0.0
+ Y[0] = Y[0] + X[vi]
+
+ @R.function(private=True)
+ def fused(
+ x: R.Tensor([64], dtype="float32"),
+ ) -> R.Tensor([1], dtype="float32"):
+ R.func_attr({"Primitive": 1})
+ cls = Before
+ with R.dataflow():
+ gv = R.call_tir(
+ cls.sum_1d,
+ [x],
+ out_sinfo=R.Tensor([1], dtype="float32"),
+ )
+ R.output(gv)
+ return gv
+
+ @R.function
+ def main(
+ x: R.Tensor([64], dtype="float32"),
+ ) -> R.Tensor([1], dtype="float32"):
+ cls = Before
+ with R.dataflow():
+ gv = cls.fused(x)
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func(private=True)
+ def fused(
+ X: T.Buffer([T.int64(64)], "float32"),
+ Y: T.Buffer([T.int64(1)], "float32"),
+ ):
+ T.func_attr({"tir.noalias": True})
+
+ for i in range(T.int64(64)):
+ with T.block("sum"):
+ vi = T.axis.remap("R", [i])
+ with T.init():
+ Y[0] = 0.0
+ Y[0] = Y[0] + X[vi]
+
+ @R.function
+ def main(
+ x: R.Tensor([64], dtype="float32"),
+ ) -> R.Tensor([1], dtype="float32"):
+ cls = Expected
+ with R.dataflow():
+ gv = R.call_tir(cls.fused, (x,), out_sinfo=R.Tensor((1,), dtype="float32"))
+ R.output(gv)
+ return gv
+
+ _check(Before, Expected)
+
+
+def test_symbolic_var_called_with_multiple_static_shapes():
+ """A dynamic PrimFunc may be called with different shapes each time"""
+
+ @I.ir_module
+ class Before:
+ @T.prim_func(private=True)
+ def sum_1d(
+ X_handle: T.handle,
+ Sum: T.Buffer([T.int64(1)], "float32"),
+ ):
+ num_elements = T.int64()
+
+ X = T.match_buffer(X_handle, [num_elements], "float32")
+
+ for i in range(num_elements):
+ with T.block("sum"):
+ vi = T.axis.remap("R", [i])
+ with T.init():
+ Sum[0] = 0.0
+ Sum[0] = Sum[0] + X[vi]
+
+ @T.prim_func(private=True)
+ def sum_scalar(
+ X: T.Buffer([T.int64(1)], "float32"),
+ Y: T.Buffer([T.int64(1)], "float32"),
+ Sum: T.Buffer([T.int64(1)], "float32"),
+ ):
+ for i in range(T.int64(1)):
+ with T.block("Out"):
+ vi = T.axis.remap("S", [i])
+ Sum[vi] = X[vi] + Y[vi]
+
+ @R.function(private=True)
+ def fused(
+ x: R.Tensor([64], dtype="float32"),
+ y: R.Tensor([16], dtype="float32"),
+ ) -> R.Tensor([1], dtype="float32"):
+ R.func_attr({"Primitive": 1})
+ cls = Before
+ with R.dataflow():
+ x_sum = R.call_tir(
+ cls.sum_1d,
+ [x],
+ out_sinfo=R.Tensor([1], dtype="float32"),
+ )
+ y_sum = R.call_tir(
+ cls.sum_1d,
+ [y],
+ out_sinfo=R.Tensor([1], dtype="float32"),
+ )
+ gv = R.call_tir(
+ cls.sum_scalar,
+ [x_sum, y_sum],
+ out_sinfo=R.Tensor([1], dtype="float32"),
+ )
+ R.output(gv)
+ return gv
+
+ @R.function
+ def main(
+ x: R.Tensor([64], dtype="float32"),
+ y: R.Tensor([16], dtype="float32"),
+ ) -> R.Tensor([1], dtype="float32"):
+ cls = Before
+ with R.dataflow():
+ gv = cls.fused(x, y)
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func(private=True)
+ def fused(
+ X: T.Buffer([T.int64(64)], "float32"),
+ Y: T.Buffer([T.int64(16)], "float32"),
+ Out: T.Buffer([T.int64(1)], "float32"),
+ ):
+ T.func_attr({"tir.noalias": True})
+
+ XSum = T.alloc_buffer([T.int64(1)], "float32")
+ YSum = T.alloc_buffer([T.int64(1)], "float32")
+
+ for i in range(T.int64(64)):
+ with T.block("XSum"):
+ vi = T.axis.remap("R", [i])
+ with T.init():
+ XSum[0] = 0.0
+ XSum[0] = XSum[0] + X[vi]
+
+ for i in range(T.int64(16)):
+ with T.block("YSum"):
+ vi = T.axis.remap("R", [i])
+ with T.init():
+ YSum[0] = 0.0
+ YSum[0] = YSum[0] + Y[vi]
+
+ for i in range(T.int64(1)):
+ with T.block("Out"):
+ vi = T.axis.remap("S", [i])
+ Out[vi] = XSum[vi] + YSum[vi]
+
+ @R.function
+ def main(
+ x: R.Tensor([64], dtype="float32"),
+ y: R.Tensor([16], dtype="float32"),
+ ) -> R.Tensor([1], dtype="float32"):
+ cls = Expected
+ with R.dataflow():
+ gv = R.call_tir(cls.fused, (x, y), out_sinfo=R.Tensor((1,), dtype="float32"))
+ R.output(gv)
+ return gv
+
+ _check(Before, Expected)
+
+
+def test_symbolic_var_called_with_static_argument():
+ """A dynamic PrimFunc may accept a static argument
+
+ The `tir_vars` parameter in `R.call_tir` contains definitions for
+ all TIR variables explicitly listed in the function signature, and
+ contains the TIR expression to be passed as the argument for each parameter.
+
+ This test is identical to the earlier test named
+ "test_symbolic_var_called_with_static_shape", except for the
+ explicit parameter in `sum_1d`.
+ """
+
+ @I.ir_module
+ class Before:
+ @T.prim_func(private=True)
+ def sum_1d(
+ X_handle: T.handle,
+ Y: T.Buffer([T.int64(1)], "float32"),
+ num_elements: T.int64,
+ ):
+
+ X = T.match_buffer(X_handle, [num_elements], "float32")
+
+ for i in range(num_elements):
+ with T.block("sum"):
+ vi = T.axis.remap("R", [i])
+ with T.init():
+ Y[0] = 0.0
+ Y[0] = Y[0] + X[vi]
+
+ @R.function(private=True)
+ def fused(
+ x: R.Tensor([64], dtype="float32"),
+ ) -> R.Tensor([1], dtype="float32"):
+ R.func_attr({"Primitive": 1})
+ cls = Before
+ with R.dataflow():
+ gv = R.call_tir(
+ cls.sum_1d,
+ [x],
+ out_sinfo=R.Tensor([1], dtype="float32"),
+ tir_vars=R.shape([64]),
+ )
+ R.output(gv)
+ return gv
+
+ @R.function
+ def main(
+ x: R.Tensor([64], dtype="float32"),
+ ) -> R.Tensor([1], dtype="float32"):
+ cls = Before
+ with R.dataflow():
+ gv = cls.fused(x)
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func(private=True)
+ def fused(
+ X: T.Buffer([T.int64(64)], "float32"),
+ Y: T.Buffer([T.int64(1)], "float32"),
+ ):
+ T.func_attr({"tir.noalias": True})
+
+ for i in range(T.int64(64)):
+ with T.block("sum"):
+ vi = T.axis.remap("R", [i])
+ with T.init():
+ Y[0] = 0.0
+ Y[0] = Y[0] + X[vi]
+
+ @R.function
+ def main(
+ x: R.Tensor([64], dtype="float32"),
+ ) -> R.Tensor([1], dtype="float32"):
+ cls = Expected
+ with R.dataflow():
+ gv = R.call_tir(cls.fused, (x,), out_sinfo=R.Tensor((1,), dtype="float32"))
+ R.output(gv)
+ return gv
+
+ _check(Before, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()