Lunderberg commented on code in PR #17075:
URL: https://github.com/apache/tvm/pull/17075#discussion_r1633589835
##########
src/relax/transform/fuse_tir.cc:
##########
@@ -362,6 +362,99 @@ class BlockNameDeduplicator : public tir::StmtMutator {
namespace relax {
+static Array<Integer> GetInplaceOutputIndices(const Array<Integer>&
inplace_indices,
+ int num_inputs) {
+ Array<Integer> ret;
+ int last_idx = num_inputs;
+ for (auto idx : inplace_indices) {
+ int i = idx.IntValue();
+ if (i >= 0) {
+ ret.push_back(Integer(i));
+ } else {
+ ret.push_back(Integer(last_idx));
+ last_idx++;
+ }
+ }
+
+ return ret;
+}
+
+class RelaxToTIRVarMapCollector : public ExprVisitor {
+ void CollectVarMapping(const CallNode* call, const Expr& lhs_var, bool
in_place = false) {
+ GlobalVar gv = Downcast<GlobalVar>(call->args[0]);
+ tir::PrimFunc prim_func_ = Downcast<tir::PrimFunc>(mod_->Lookup(gv));
+ const auto& buffer_map = prim_func_->buffer_map;
+ const auto& tir_args = prim_func_->params;
+
+ const auto& relax_args = Downcast<Tuple>(call->args[1])->fields;
+
+ Array<Expr> relax_results;
+ if (lhs_var->IsInstance<TupleNode>()) {
+ relax_results = Downcast<Tuple>(lhs_var)->fields;
+ } else {
+ CHECK(lhs_var->IsInstance<VarNode>()) << "The lhs_var is expected to be
either tuple or var";
+ relax_results = {Downcast<Var>(lhs_var)};
+ }
+
+ size_t num_inputs = relax_args.size();
+ size_t num_outputs = relax_results.size();
+
+ Array<Integer> output_idxs;
+ if (in_place) {
+ const auto* attrs = call->attrs.as<CallTIRInplaceAttrs>();
+ CHECK(attrs) << "Must have CallTIRInplaceAttrs for an in-place call";
+ output_idxs = GetInplaceOutputIndices(attrs->inplace_indices,
num_inputs);
+ } else {
+ for (size_t i = num_inputs; i < num_inputs + num_outputs; i++) {
+ output_idxs.push_back(i);
+ }
+ }
+ for (size_t i = 0; i < tir_args.size(); ++i) {
+ const auto& tir_var = Downcast<tir::Var>(tir_args[i]);
+ if (i < num_inputs) {
+ const auto& relax_var = Downcast<Var>(relax_args[i]);
Review Comment:
This `Downcast<Var>` is not guaranteed to work. While the normalizer will
pull most `relax.Var` instances out to their own variable binding, `R.const`
arguments may still appear inline.
##########
src/relax/transform/fuse_tir.cc:
##########
@@ -362,6 +362,99 @@ class BlockNameDeduplicator : public tir::StmtMutator {
namespace relax {
+static Array<Integer> GetInplaceOutputIndices(const Array<Integer>&
inplace_indices,
+ int num_inputs) {
+ Array<Integer> ret;
+ int last_idx = num_inputs;
+ for (auto idx : inplace_indices) {
+ int i = idx.IntValue();
+ if (i >= 0) {
+ ret.push_back(Integer(i));
+ } else {
+ ret.push_back(Integer(last_idx));
+ last_idx++;
+ }
+ }
+
+ return ret;
+}
+
+class RelaxToTIRVarMapCollector : public ExprVisitor {
+ void CollectVarMapping(const CallNode* call, const Expr& lhs_var, bool
in_place = false) {
+ GlobalVar gv = Downcast<GlobalVar>(call->args[0]);
+ tir::PrimFunc prim_func_ = Downcast<tir::PrimFunc>(mod_->Lookup(gv));
+ const auto& buffer_map = prim_func_->buffer_map;
+ const auto& tir_args = prim_func_->params;
+
+ const auto& relax_args = Downcast<Tuple>(call->args[1])->fields;
+
+ Array<Expr> relax_results;
+ if (lhs_var->IsInstance<TupleNode>()) {
+ relax_results = Downcast<Tuple>(lhs_var)->fields;
+ } else {
+ CHECK(lhs_var->IsInstance<VarNode>()) << "The lhs_var is expected to be
either tuple or var";
+ relax_results = {Downcast<Var>(lhs_var)};
+ }
+
+ size_t num_inputs = relax_args.size();
+ size_t num_outputs = relax_results.size();
+
+ Array<Integer> output_idxs;
+ if (in_place) {
+ const auto* attrs = call->attrs.as<CallTIRInplaceAttrs>();
+ CHECK(attrs) << "Must have CallTIRInplaceAttrs for an in-place call";
+ output_idxs = GetInplaceOutputIndices(attrs->inplace_indices,
num_inputs);
+ } else {
+ for (size_t i = num_inputs; i < num_inputs + num_outputs; i++) {
+ output_idxs.push_back(i);
+ }
+ }
+ for (size_t i = 0; i < tir_args.size(); ++i) {
+ const auto& tir_var = Downcast<tir::Var>(tir_args[i]);
+ if (i < num_inputs) {
+ const auto& relax_var = Downcast<Var>(relax_args[i]);
+ relax_to_tir_var_map_.Set(relax_var, buffer_map[tir_var]);
+ }
+ if (auto it = std::find(output_idxs.begin(), output_idxs.end(), i); it
!= output_idxs.end()) {
+ int result_idx = it - output_idxs.begin();
+ const auto& inplace_out_var = Downcast<Var>(relax_results[result_idx]);
+ relax_to_tir_var_map_.Set(inplace_out_var, buffer_map[tir_var]);
+ }
+ }
+ }
+
+ public:
+ explicit RelaxToTIRVarMapCollector(const IRModule& mod) : mod_(mod) {}
+ static Map<Var, tir::Buffer> Collect(const IRModule& mod, const Function&
func) {
+ RelaxToTIRVarMapCollector visitor(mod);
+ visitor(func->body);
+ return visitor.relax_to_tir_var_map_;
+ }
+ void VisitBinding_(const VarBindingNode* binding) final {
+ const auto& lhs_var = binding->var;
Review Comment:
As written, this would only collect the mapping for relax variables whose
binding occurs in the outermost `SeqExpr`. Nested `SeqExpr` may occur if
`binding->value` is a `relax::If` node, where each branch then contains a
`SeqExpr`.
To resolve this, I'd recommend either adding
`ExprVisitor::VisitBinding_(binding);` or `VisitExpr(binding->value)` to this
method.
##########
tests/python/relax/test_transform_fuse_tir.py:
##########
@@ -2314,5 +2314,88 @@ def take(
_check(Before, Before)
+def test_fuse_with_axis_separators():
+ @I.ir_module
+ class Before:
+ @T.prim_func(private=True)
+ def add(a: T.handle, b: T.handle, c: T.handle):
+ A = T.match_buffer(a, [T.int64(16), T.int64(32)], "float32",
axis_separators=[1])
+ B = T.match_buffer(b, [T.int64(16), T.int64(32)], "float32",
axis_separators=[1])
+ C = T.match_buffer(c, [T.int64(16), T.int64(32)], "float32",
axis_separators=[1])
+
+ for iters in T.grid(T.int64(16), T.int64(32)):
+ with T.block("compute"):
+ i, j = T.axis.remap("SS", iters)
+ C[i, j] = A[i, j] + B[i, j]
+
+ @R.function(private=True)
+ def fused_function(
+ x: R.Tensor([T.int64(16), T.int64(32)], "float32"),
+ y: R.Tensor([T.int64(16), T.int64(32)], "float32"),
+ z: R.Tensor([T.int64(16), T.int64(32)], "float32"),
+ ) -> R.Tensor([T.int64(16), T.int64(32)], dtype="float32"):
+ R.func_attr({"Primitive": 1})
+ cls = Before
+ with R.dataflow():
+ w = R.call_tir(
+ cls.add, [x, y], out_sinfo=R.Tensor([T.int64(16),
T.int64(32)], "float32")
Review Comment:
Can we add a test case for incompatible usage of a single Relax var? As
currently written, we could have a single Relax variable that is used in two
separate `R.call_tir` statements, where the function being called imposes
different restrictions on it. For example, if `x` were used in `cls.add1`,
which requires `axis_separators=[1]`, and `cls.add2`, which requires
`axis_separators=[]`. We should be able to identify this case and raise an
error when it occurs.
(Ideally, that should never happen, but this would be the last point at
which we'd have enough information to catch this failure mode at compile-time.)
##########
src/relax/transform/fuse_tir.cc:
##########
@@ -362,6 +362,99 @@ class BlockNameDeduplicator : public tir::StmtMutator {
namespace relax {
+static Array<Integer> GetInplaceOutputIndices(const Array<Integer>&
inplace_indices,
+ int num_inputs) {
+ Array<Integer> ret;
+ int last_idx = num_inputs;
+ for (auto idx : inplace_indices) {
+ int i = idx.IntValue();
+ if (i >= 0) {
+ ret.push_back(Integer(i));
+ } else {
+ ret.push_back(Integer(last_idx));
+ last_idx++;
+ }
+ }
+
+ return ret;
+}
+
+class RelaxToTIRVarMapCollector : public ExprVisitor {
+ void CollectVarMapping(const CallNode* call, const Expr& lhs_var, bool
in_place = false) {
+ GlobalVar gv = Downcast<GlobalVar>(call->args[0]);
+ tir::PrimFunc prim_func_ = Downcast<tir::PrimFunc>(mod_->Lookup(gv));
+ const auto& buffer_map = prim_func_->buffer_map;
+ const auto& tir_args = prim_func_->params;
+
+ const auto& relax_args = Downcast<Tuple>(call->args[1])->fields;
+
+ Array<Expr> relax_results;
+ if (lhs_var->IsInstance<TupleNode>()) {
+ relax_results = Downcast<Tuple>(lhs_var)->fields;
+ } else {
+ CHECK(lhs_var->IsInstance<VarNode>()) << "The lhs_var is expected to be
either tuple or var";
+ relax_results = {Downcast<Var>(lhs_var)};
+ }
+
+ size_t num_inputs = relax_args.size();
+ size_t num_outputs = relax_results.size();
+
+ Array<Integer> output_idxs;
+ if (in_place) {
+ const auto* attrs = call->attrs.as<CallTIRInplaceAttrs>();
+ CHECK(attrs) << "Must have CallTIRInplaceAttrs for an in-place call";
+ output_idxs = GetInplaceOutputIndices(attrs->inplace_indices,
num_inputs);
+ } else {
+ for (size_t i = num_inputs; i < num_inputs + num_outputs; i++) {
+ output_idxs.push_back(i);
+ }
+ }
+ for (size_t i = 0; i < tir_args.size(); ++i) {
+ const auto& tir_var = Downcast<tir::Var>(tir_args[i]);
+ if (i < num_inputs) {
+ const auto& relax_var = Downcast<Var>(relax_args[i]);
+ relax_to_tir_var_map_.Set(relax_var, buffer_map[tir_var]);
Review Comment:
The `buffer_map` does not necessarily contain an entry for `tir_var`. For
example, the `relax_var` could have `PrimStructInfo` to pass a primitive scalar
to the TIR function. Even if `relax_var` has `TensorStructInfo`, the TIR
function may treat the `DLTensor*` as an opaque pointer, passing it to a
`PackedFunc` without having an entry in the `buffer_map`.
The best way to handle these cases is to wrap this line in a `if(auto
tir_buffer = buffer_map.Get(tir_var))` conditional, and then use
`tir_buffer.value()` inside the conditional instead of `buffer_map[tir_var]`.
##########
src/relax/transform/fuse_tir.cc:
##########
@@ -362,6 +362,99 @@ class BlockNameDeduplicator : public tir::StmtMutator {
namespace relax {
+static Array<Integer> GetInplaceOutputIndices(const Array<Integer>&
inplace_indices,
+ int num_inputs) {
+ Array<Integer> ret;
+ int last_idx = num_inputs;
+ for (auto idx : inplace_indices) {
+ int i = idx.IntValue();
+ if (i >= 0) {
+ ret.push_back(Integer(i));
+ } else {
+ ret.push_back(Integer(last_idx));
Review Comment:
Nitpick: This should only apply for `i==-1`. For any other negative value,
we should raise an error.
##########
src/relax/transform/fuse_tir.cc:
##########
@@ -362,6 +362,99 @@ class BlockNameDeduplicator : public tir::StmtMutator {
namespace relax {
+static Array<Integer> GetInplaceOutputIndices(const Array<Integer>&
inplace_indices,
+ int num_inputs) {
+ Array<Integer> ret;
+ int last_idx = num_inputs;
+ for (auto idx : inplace_indices) {
+ int i = idx.IntValue();
+ if (i >= 0) {
+ ret.push_back(Integer(i));
+ } else {
+ ret.push_back(Integer(last_idx));
+ last_idx++;
+ }
+ }
+
+ return ret;
+}
+
+class RelaxToTIRVarMapCollector : public ExprVisitor {
+ void CollectVarMapping(const CallNode* call, const Expr& lhs_var, bool
in_place = false) {
+ GlobalVar gv = Downcast<GlobalVar>(call->args[0]);
+ tir::PrimFunc prim_func_ = Downcast<tir::PrimFunc>(mod_->Lookup(gv));
+ const auto& buffer_map = prim_func_->buffer_map;
+ const auto& tir_args = prim_func_->params;
+
+ const auto& relax_args = Downcast<Tuple>(call->args[1])->fields;
+
+ Array<Expr> relax_results;
+ if (lhs_var->IsInstance<TupleNode>()) {
+ relax_results = Downcast<Tuple>(lhs_var)->fields;
+ } else {
+ CHECK(lhs_var->IsInstance<VarNode>()) << "The lhs_var is expected to be
either tuple or var";
+ relax_results = {Downcast<Var>(lhs_var)};
+ }
+
+ size_t num_inputs = relax_args.size();
+ size_t num_outputs = relax_results.size();
+
+ Array<Integer> output_idxs;
+ if (in_place) {
+ const auto* attrs = call->attrs.as<CallTIRInplaceAttrs>();
+ CHECK(attrs) << "Must have CallTIRInplaceAttrs for an in-place call";
+ output_idxs = GetInplaceOutputIndices(attrs->inplace_indices,
num_inputs);
+ } else {
+ for (size_t i = num_inputs; i < num_inputs + num_outputs; i++) {
+ output_idxs.push_back(i);
+ }
+ }
+ for (size_t i = 0; i < tir_args.size(); ++i) {
+ const auto& tir_var = Downcast<tir::Var>(tir_args[i]);
+ if (i < num_inputs) {
+ const auto& relax_var = Downcast<Var>(relax_args[i]);
+ relax_to_tir_var_map_.Set(relax_var, buffer_map[tir_var]);
+ }
+ if (auto it = std::find(output_idxs.begin(), output_idxs.end(), i); it
!= output_idxs.end()) {
+ int result_idx = it - output_idxs.begin();
+ const auto& inplace_out_var = Downcast<Var>(relax_results[result_idx]);
+ relax_to_tir_var_map_.Set(inplace_out_var, buffer_map[tir_var]);
+ }
+ }
+ }
+
+ public:
+ explicit RelaxToTIRVarMapCollector(const IRModule& mod) : mod_(mod) {}
+ static Map<Var, tir::Buffer> Collect(const IRModule& mod, const Function&
func) {
+ RelaxToTIRVarMapCollector visitor(mod);
+ visitor(func->body);
+ return visitor.relax_to_tir_var_map_;
+ }
+ void VisitBinding_(const VarBindingNode* binding) final {
+ const auto& lhs_var = binding->var;
+ const auto& value = binding->value;
+ if (const CallNode* call = value.as<CallNode>()) {
+ static const Op& call_tir_op_ = Op::Get("relax.call_tir");
+ static const Op& call_tir_inplace_op_ =
Op::Get("relax.call_tir_inplace");
+
+ ICHECK(call->op == call_tir_op_ || call->op == call_tir_inplace_op_)
+ << "Only call_tir and call_tir_inplace are supported in primitive
function, but got: "
+ << GetRef<Expr>(call);
+ if (call->op == call_tir_inplace_op_) {
+ CollectVarMapping(call, lhs_var, /*in_place*/ true);
+ } else {
+ CollectVarMapping(call, lhs_var);
+ }
+ }
+ }
+
+ private:
+ /*! \brief The IRModule */
+ const IRModule& mod_;
+ // size_t call_num_inputs_ = -1;
+ Map<Var, tir::Buffer> relax_to_tir_var_map_;
Review Comment:
This data structure assumes that there is a 1:1 mapping from `relax::Var` to
`tir::Buffer` across the entire fused function. This would produce incorrect
results for cases where the same tensor is used as multiple arguments (e.g.
`R.add(A, A)`), or where the same tensor is used as an argument to more than
one function (e.g. the tensor `A` corresponds to two different TIR buffers in
the sequence `mean = R.mean(A); norm = R.sqrt(mean); A_norm = R.divide(A,
norm)`).
##########
src/relax/transform/fuse_tir.cc:
##########
@@ -362,6 +362,99 @@ class BlockNameDeduplicator : public tir::StmtMutator {
namespace relax {
+static Array<Integer> GetInplaceOutputIndices(const Array<Integer>&
inplace_indices,
+ int num_inputs) {
+ Array<Integer> ret;
+ int last_idx = num_inputs;
+ for (auto idx : inplace_indices) {
+ int i = idx.IntValue();
+ if (i >= 0) {
+ ret.push_back(Integer(i));
+ } else {
+ ret.push_back(Integer(last_idx));
+ last_idx++;
+ }
+ }
+
+ return ret;
+}
+
+class RelaxToTIRVarMapCollector : public ExprVisitor {
+ void CollectVarMapping(const CallNode* call, const Expr& lhs_var, bool
in_place = false) {
Review Comment:
Nitpick: For readability, having the public-facing interface at the top of
the class makes it easier to find the entry point. Can the `static Map<Var,
tir::Buffer> Collect` function be moved to the top of
`RelaxToTIRVarMapCollector`?
##########
src/relax/transform/fuse_tir.cc:
##########
@@ -391,10 +484,15 @@ class FusedTIRConstructor : public ExprVisitor {
: mod_(mod), func_name_(func_name) {}
void VisitExpr_(const FunctionNode* func) final {
+ auto relax_to_tir_var_map = RelaxToTIRVarMapCollector::Collect(mod_,
GetRef<Function>(func));
std::vector<Variant<tir::Var, tir::Buffer>> prim_func_params;
for (const Var& relax_param : func->params) {
size_t size_before = prim_func_params.size();
- CollectPrimFuncParams(relax_param, &prim_func_params);
+ if (relax_to_tir_var_map.count(relax_param)) {
Review Comment:
Nitpick: Instead of having the conditional, both branches could be written
as `CollectPrimFuncParams(relax_param, &prim_func_params,
relax_to_tir_var_map.Get(relax_param))`. The `Map::Get` method returns an
`Optional<tir::Buffer>`, which is `NullOpt` when the key is absent from the map.
##########
src/relax/transform/fuse_tir.cc:
##########
@@ -362,6 +362,99 @@ class BlockNameDeduplicator : public tir::StmtMutator {
namespace relax {
+static Array<Integer> GetInplaceOutputIndices(const Array<Integer>&
inplace_indices,
+ int num_inputs) {
+ Array<Integer> ret;
+ int last_idx = num_inputs;
+ for (auto idx : inplace_indices) {
+ int i = idx.IntValue();
+ if (i >= 0) {
+ ret.push_back(Integer(i));
+ } else {
+ ret.push_back(Integer(last_idx));
+ last_idx++;
+ }
+ }
+
+ return ret;
+}
+
+class RelaxToTIRVarMapCollector : public ExprVisitor {
+ void CollectVarMapping(const CallNode* call, const Expr& lhs_var, bool
in_place = false) {
+ GlobalVar gv = Downcast<GlobalVar>(call->args[0]);
+ tir::PrimFunc prim_func_ = Downcast<tir::PrimFunc>(mod_->Lookup(gv));
+ const auto& buffer_map = prim_func_->buffer_map;
+ const auto& tir_args = prim_func_->params;
+
+ const auto& relax_args = Downcast<Tuple>(call->args[1])->fields;
+
+ Array<Expr> relax_results;
+ if (lhs_var->IsInstance<TupleNode>()) {
+ relax_results = Downcast<Tuple>(lhs_var)->fields;
+ } else {
+ CHECK(lhs_var->IsInstance<VarNode>()) << "The lhs_var is expected to be
either tuple or var";
+ relax_results = {Downcast<Var>(lhs_var)};
+ }
+
+ size_t num_inputs = relax_args.size();
+ size_t num_outputs = relax_results.size();
+
+ Array<Integer> output_idxs;
+ if (in_place) {
+ const auto* attrs = call->attrs.as<CallTIRInplaceAttrs>();
+ CHECK(attrs) << "Must have CallTIRInplaceAttrs for an in-place call";
+ output_idxs = GetInplaceOutputIndices(attrs->inplace_indices,
num_inputs);
+ } else {
+ for (size_t i = num_inputs; i < num_inputs + num_outputs; i++) {
+ output_idxs.push_back(i);
+ }
+ }
+ for (size_t i = 0; i < tir_args.size(); ++i) {
+ const auto& tir_var = Downcast<tir::Var>(tir_args[i]);
Review Comment:
This `Downcast<tir::Var>` is unnecessary, because `prim_func->params` is
already an array of `tir::Var`.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]