This is an automated email from the ASF dual-hosted git repository.
sanirudh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
     new 5bfca2e7a2 [Transform] Modify FuseTIR pass to propagate buffer attributes (#17075)
5bfca2e7a2 is described below
commit 5bfca2e7a25a357e5b3399ade98461a2678e8fc5
Author: Anirudh Sundar Subramaniam <[email protected]>
AuthorDate: Mon Jun 17 22:01:54 2024 +0530
[Transform] Modify FuseTIR pass to propagate buffer attributes (#17075)
Arguments of a fused TIR PrimFunc generated from a fused Relax function do
not retain all the buffer attributes of their original PrimFuncs, because the
buffers are created from the StructInfo of the Relax vars. This patch collects
a mapping from Relax vars to their corresponding TIR buffers in a fused Relax
function and uses that mapping to propagate buffer attributes such as
`axis_separators` and `storage_scope`.
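
For context, a minimal TVMScript sketch of a PrimFunc whose buffers carry the
attributes being propagated (illustrative only, not part of the patch; the
function name and shapes are made up). Before this patch, FuseTIR rebuilt the
fused function's buffers from the Relax StructInfo alone, so these attributes
were dropped:

    # Illustrative only: buffers declaring axis_separators and a storage scope.
    from tvm.script import tir as T

    @T.prim_func(private=True)
    def copy(a: T.handle, b: T.handle):
        A = T.match_buffer(a, [16, 32], "float32", axis_separators=[1], scope="global")
        B = T.match_buffer(b, [16, 32], "float32", axis_separators=[1], scope="global")
        for i, j in T.grid(16, 32):
            with T.block("copy"):
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A[vi, vj]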
---
src/relax/transform/fuse_tir.cc | 140 ++++++++++++++++++++++----
tests/python/relax/test_transform_fuse_tir.py | 128 +++++++++++++++++++++++
2 files changed, 248 insertions(+), 20 deletions(-)
diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc
index e712b5022a..b203b322ab 100644
--- a/src/relax/transform/fuse_tir.cc
+++ b/src/relax/transform/fuse_tir.cc
@@ -362,6 +362,114 @@ class BlockNameDeduplicator : public tir::StmtMutator {
 namespace relax {
 
+static Array<Integer> GetInplaceOutputIndices(const Array<Integer>& inplace_indices,
+                                              int num_inputs) {
+  Array<Integer> ret;
+  int last_idx = num_inputs;
+  for (auto idx : inplace_indices) {
+    int i = idx.IntValue();
+    if (i >= 0) {
+      ret.push_back(Integer(i));
+    } else {
+      CHECK_EQ(i, -1) << "The only negative index expected in inplace_indices is -1, but got " << i;
+      ret.push_back(Integer(last_idx));
+      last_idx++;
+    }
+  }
+
+  return ret;
+}
+
+class RelaxToTIRVarMapCollector : public ExprVisitor {
+ public:
+  explicit RelaxToTIRVarMapCollector(const IRModule& mod) : mod_(mod) {}
+  static Map<Expr, tir::Buffer> Collect(const IRModule& mod, const Function& func) {
+    RelaxToTIRVarMapCollector visitor(mod);
+    visitor(func->body);
+    return visitor.relax_to_tir_var_map_;
+  }
+
+ private:
+  void VisitBinding_(const VarBindingNode* binding) final {
+    current_var_ = binding->var;
+    ExprVisitor::VisitBinding_(binding);
+  }
+
+  void VisitExpr_(const CallNode* call) {
+    static const Op& call_tir_op_ = Op::Get("relax.call_tir");
+    static const Op& call_tir_inplace_op_ = Op::Get("relax.call_tir_inplace");
+
+    ICHECK(call->op == call_tir_op_ || call->op == call_tir_inplace_op_)
+        << "Only call_tir and call_tir_inplace are supported in primitive function, but got: "
+        << GetRef<Expr>(call);
+    CollectVarMapping(call, current_var_, call->op == call_tir_inplace_op_);
+  }
+
+  void CollectVarMapping(const CallNode* call, const Expr& lhs_var, bool in_place) {
+    GlobalVar gv = Downcast<GlobalVar>(call->args[0]);
+    tir::PrimFunc prim_func_ = Downcast<tir::PrimFunc>(mod_->Lookup(gv));
+    const auto& buffer_map = prim_func_->buffer_map;
+    const auto& tir_args = prim_func_->params;
+
+    const auto& relax_args = Downcast<Tuple>(call->args[1])->fields;
+
+    Array<Expr> relax_results;
+    if (lhs_var->IsInstance<TupleNode>()) {
+      relax_results = Downcast<Tuple>(lhs_var)->fields;
+    } else {
+      CHECK(lhs_var->IsInstance<VarNode>()) << "The lhs_var is expected to be either tuple or var";
+      relax_results = {Downcast<Var>(lhs_var)};
+    }
+
+    size_t num_inputs = relax_args.size();
+    size_t num_outputs = relax_results.size();
+
+    Array<Integer> output_idxs;
+    if (in_place) {
+      const auto* attrs = call->attrs.as<CallTIRInplaceAttrs>();
+      CHECK(attrs) << "Must have CallTIRInplaceAttrs for an in-place call";
+      output_idxs = GetInplaceOutputIndices(attrs->inplace_indices, num_inputs);
+    } else {
+      for (size_t i = num_inputs; i < num_inputs + num_outputs; i++) {
+        output_idxs.push_back(i);
+      }
+    }
+
+    // If the `expr` is already seen (present in the map), validate whether the mapped buffer is
+    // structurally equal to the `new_buf` passed
+    auto ValidateBufferCompatibility = [this](tir::Buffer new_buf, Expr expr) {
+      if (auto it = relax_to_tir_var_map_.find(expr); it != relax_to_tir_var_map_.end()) {
+        ICHECK(StructuralEqual()((*it).second, new_buf))
+            << "Inconsistent buffers " << (*it).second << " and " << new_buf
+            << " mapped to the same relax var: " << expr;
+      }
+    };
+    for (size_t i = 0; i < tir_args.size(); ++i) {
+      const auto& tir_var = tir_args[i];
+      if (auto tir_buffer = buffer_map.Get(tir_var)) {
+        if (i < num_inputs) {
+          const auto& relax_var = relax_args[i];
+          ValidateBufferCompatibility(tir_buffer.value(), relax_var);
+          relax_to_tir_var_map_.Set(relax_var, tir_buffer.value());
+        }
+        if (auto it = std::find(output_idxs.begin(), output_idxs.end(), i);
+            it != output_idxs.end()) {
+          int result_idx = it - output_idxs.begin();
+          const auto& relax_var = relax_results[result_idx];
+          ValidateBufferCompatibility(tir_buffer.value(), relax_var);
+          relax_to_tir_var_map_.Set(relax_var, tir_buffer.value());
+        }
+      }
+    }
+  }
+
+ private:
+  /*! \brief The IRModule */
+  const IRModule& mod_;
+  Map<Expr, tir::Buffer> relax_to_tir_var_map_;
+  Var current_var_;
+};
+
 class FusedTIRConstructor : public ExprVisitor {
  public:
   /*!
@@ -391,10 +499,11 @@ class FusedTIRConstructor : public ExprVisitor {
       : mod_(mod), func_name_(func_name) {}
 
   void VisitExpr_(const FunctionNode* func) final {
+    auto relax_to_tir_var_map = RelaxToTIRVarMapCollector::Collect(mod_, GetRef<Function>(func));
     std::vector<Variant<tir::Var, tir::Buffer>> prim_func_params;
     for (const Var& relax_param : func->params) {
       size_t size_before = prim_func_params.size();
-      CollectPrimFuncParams(relax_param, &prim_func_params);
+      CollectPrimFuncParams(relax_param, &prim_func_params, relax_to_tir_var_map.Get(relax_param));
 
       auto param_buffers = [&]() -> Array<tir::Buffer> {
         Array<tir::Buffer> out;
@@ -676,23 +785,6 @@ class FusedTIRConstructor : public ExprVisitor {
     MapArgsToBuffer(arg_list, buffer_list);
   }
 
-  static Array<Integer> GetInplaceOutputIndices(const Array<Integer>& inplace_indices,
-                                                int num_inputs) {
-    Array<Integer> ret;
-    int last_idx = num_inputs;
-    for (auto idx : inplace_indices) {
-      int i = idx.IntValue();
-      if (i >= 0) {
-        ret.push_back(Integer(i));
-      } else {
-        ret.push_back(Integer(last_idx));
-        last_idx++;
-      }
-    }
-
-    return ret;
-  }
-
   static Array<tir::Var> GetPrimFuncOutputParams(const tir::PrimFunc& func,
                                                  const Array<Integer>& output_indices) {
     size_t n = func->params.size();
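
As an aside, GetInplaceOutputIndices (removed here and hoisted to a free
function near the top of this diff, with an added CHECK) computes the mapping
from in-place outputs to parameter positions: an entry i >= 0 in
inplace_indices means that output aliases input i, while -1 allocates a fresh
output parameter after the inputs. A Python sketch of the same logic
(illustrative, not part of the patch):

    # Illustrative Python rendering of GetInplaceOutputIndices.
    def get_inplace_output_indices(inplace_indices, num_inputs):
        ret = []
        last_idx = num_inputs
        for i in inplace_indices:
            if i >= 0:
                ret.append(i)          # output written in place into input i
            else:
                assert i == -1, i      # -1 is the only negative index allowed
                ret.append(last_idx)   # fresh output slot after the inputs
                last_idx += 1
        return ret

    # Two inputs; the first output aliases input 1, the second is fresh:
    assert get_inplace_output_indices([1, -1], num_inputs=2) == [1, 2]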
@@ -799,7 +891,8 @@ class FusedTIRConstructor : public ExprVisitor {
    * \param out The vector into which to collect the params/buffers
    */
   static void CollectPrimFuncParams(const Var& relax_param,
-                                    std::vector<Variant<tir::Var, tir::Buffer>>* out) {
+                                    std::vector<Variant<tir::Var, tir::Buffer>>* out,
+                                    const tvm::runtime::Optional<tir::Buffer>& tir_buffer_param) {
     auto struct_info = GetStructInfo(relax_param);
     CHECK(!struct_info.as<TupleStructInfoNode>())
@@ -814,7 +907,14 @@ class FusedTIRConstructor : public ExprVisitor {
       const auto* shape_expr = tensor->shape.as<ShapeExprNode>();
       ICHECK(shape_expr) << "FuseTIR expects all Tensor parameters have a known shape.";
       DataType dtype = tensor->dtype;
-      tir::Buffer buffer = tir::decl_buffer(shape_expr->values, dtype, name_hint);
+      tir::Buffer buffer;
+      if (tir_buffer_param.defined()) {
+        buffer =
+            tir::decl_buffer(shape_expr->values, dtype, name_hint, tir_buffer_param.value().scope(),
+                             tir_buffer_param.value()->axis_separators);
+      } else {
+        buffer = tir::decl_buffer(shape_expr->values, dtype, name_hint);
+      }
       out->push_back(std::move(buffer));
     } else if (const auto* prim_value = struct_info.as<PrimStructInfoNode>()) {
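
For reference, the Python API mirrors the C++ overload used above:
tvm.tir.decl_buffer also takes a storage scope and axis_separators. A hedged
sketch (the buffer name and values here are assumptions, not from the patch):

    # Illustrative: declaring a buffer that carries both propagated attributes.
    import tvm

    buf = tvm.tir.decl_buffer(
        (16, 32),
        "float32",
        "fused_param",        # hypothetical name hint
        scope="global",       # propagated storage scope
        axis_separators=[1],  # propagated axis separators
    )
    assert len(buf.axis_separators) == 1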
diff --git a/tests/python/relax/test_transform_fuse_tir.py b/tests/python/relax/test_transform_fuse_tir.py
index 90baeaad04..99e7a5d2b7 100644
--- a/tests/python/relax/test_transform_fuse_tir.py
+++ b/tests/python/relax/test_transform_fuse_tir.py
@@ -15,6 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import pytest
+
 import tvm
 import tvm.testing
 from tvm import relax, topi
@@ -2314,5 +2316,131 @@ def test_private_nonprimitive_func():
     _check(Before, Before)
 
 
+def test_fuse_with_axis_separators():
+    @I.ir_module
+    class Before:
+        @T.prim_func(private=True)
+        def add(a: T.handle, b: T.handle, c: T.handle):
+            A = T.match_buffer(a, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
+            B = T.match_buffer(b, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
+            C = T.match_buffer(c, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
+
+            for iters in T.grid(T.int64(16), T.int64(32)):
+                with T.block("compute"):
+                    i, j = T.axis.remap("SS", iters)
+                    C[i, j] = A[i, j] + B[i, j]
+
+        @R.function(private=True)
+        def fused_function(
+            x: R.Tensor([T.int64(16), T.int64(32)], "float32"),
+            y: R.Tensor([T.int64(16), T.int64(32)], "float32"),
+            z: R.Tensor([T.int64(16), T.int64(32)], "float32"),
+        ) -> R.Tensor([T.int64(16), T.int64(32)], dtype="float32"):
+            R.func_attr({"Primitive": 1})
+            cls = Before
+            with R.dataflow():
+                w = R.call_tir(
+                    cls.add, [x, y], out_sinfo=R.Tensor([T.int64(16), T.int64(32)], "float32")
+                )
+                out = R.call_tir(
+                    cls.add, [w, z], out_sinfo=R.Tensor([T.int64(16), T.int64(32)], "float32")
+                )
+                R.output(out)
+            return out
+
+        @R.function
+        def main(
+            x: R.Tensor([T.int64(16), T.int64(32)], "float32"),
+            y: R.Tensor([T.int64(16), T.int64(32)], "float32"),
+            z: R.Tensor([T.int64(16), T.int64(32)], "float32"),
+        ) -> R.Tensor([T.int64(16), T.int64(32)], dtype="float32"):
+            cls = Before
+            with R.dataflow():
+                gv = cls.fused_function(x, y, z)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func(private=True)
+        def fused_function(x: T.handle, y: T.handle, z: T.handle, c: T.handle):
+            T.func_attr({"tir.noalias": True})
+            X = T.match_buffer(x, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
+            Y = T.match_buffer(y, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
+            Z = T.match_buffer(z, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
+            C = T.match_buffer(c, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
+            Temp = T.alloc_buffer(X.shape, "float32", axis_separators=[1])
+            for iters in T.grid(*X.shape):
+                with T.block("compute_Y"):
+                    i, j = T.axis.remap("SS", iters)
+                    Temp[i, j] = X[i, j] + Y[i, j]
+
+            for iters in T.grid(*X.shape):
+                with T.block("compute_Z"):
+                    i, j = T.axis.remap("SS", iters)
+                    C[i, j] = Temp[i, j] + Z[i, j]
+
+        @R.function
+        def main(
+            x: R.Tensor([T.int64(16), T.int64(32)], "float32"),
+            y: R.Tensor([T.int64(16), T.int64(32)], "float32"),
+            z: R.Tensor([T.int64(16), T.int64(32)], "float32"),
+        ) -> R.Tensor([T.int64(16), T.int64(32)], dtype="float32"):
+            cls = Expected
+            with R.dataflow():
+                gv = R.call_tir(
+                    cls.fused_function,
+                    [x, y, z],
+                    out_sinfo=R.Tensor([T.int64(16), T.int64(32)], "float32"),
+                )
+                R.output(gv)
+            return gv
+
+    _check(Before, Expected)
+
+
+def test_fuse_with_axis_separators_inconsistent_buffer_mapping():
+    @I.ir_module
+    class Before:
+        @T.prim_func(private=True)
+        def mul(a: T.handle, b: T.handle, c: T.handle):
+            A = T.match_buffer(a, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
+            B = T.match_buffer(b, [T.int64(16), T.int64(32)], "float32", axis_separators=[])
+            C = T.match_buffer(c, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
+
+            for iters in T.grid(T.int64(16), T.int64(32)):
+                with T.block("compute"):
+                    i, j = T.axis.remap("SS", iters)
+                    C[i, j] = A[i, j] * B[i, j]
+
+        @R.function(private=True)
+        def fused_function(
+            x: R.Tensor([T.int64(16), T.int64(32)], "float32"),
+        ) -> R.Tensor([T.int64(16), T.int64(32)], dtype="float32"):
+            R.func_attr({"Primitive": 1})
+            cls = Before
+            with R.dataflow():
+                out = R.call_tir(
+                    cls.mul, [x, x], out_sinfo=R.Tensor([T.int64(16), T.int64(32)], "float32")
+                )
+                R.output(out)
+            return out
+
+        @R.function
+        def main(
+            x: R.Tensor([T.int64(16), T.int64(32)], "float32"),
+        ) -> R.Tensor([T.int64(16), T.int64(32)], dtype="float32"):
+            cls = Before
+            with R.dataflow():
+                gv = cls.fused_function(x)
+                R.output(gv)
+            return gv
+
+    with pytest.raises(
+        tvm.TVMError, match=r"Inconsistent buffers.*and.*mapped to the same relax var:.*"
+    ):
+        relax.transform.FuseTIR()(Before)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
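
As a usage note, a minimal sketch of running the pass and inspecting the
propagated attributes (illustrative, not part of the patch; `Before` stands
for a module like the one in the first test above):

    # Illustrative: apply FuseTIR, then print buffer attributes of fused PrimFuncs.
    mod = relax.transform.FuseTIR()(Before)
    for gv, func in mod.functions.items():
        if isinstance(func, tvm.tir.PrimFunc):
            for param, buf in func.buffer_map.items():
                print(gv.name_hint, buf.name, buf.scope(), buf.axis_separators)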