quic-sanirudh commented on code in PR #17075:
URL: https://github.com/apache/tvm/pull/17075#discussion_r1635077626


##########
src/relax/transform/fuse_tir.cc:
##########
@@ -362,6 +362,99 @@ class BlockNameDeduplicator : public tir::StmtMutator {
 
 namespace relax {
 
+/*!
+ * \brief Translate the per-output entries of `inplace_indices` into positions
+ *        in the callee's parameter list (inputs first, then any outputs that
+ *        get a buffer of their own).
+ * \param inplace_indices One entry per output. A non-negative entry is used
+ *        unchanged as the position (i.e. the output reuses that input's slot);
+ *        a negative entry marks an output that needs its own new slot.
+ * \param num_inputs Number of input arguments passed to the call.
+ * \return One position per entry of `inplace_indices`, in the same order.
+ *         Outputs with their own slot are numbered consecutively starting at
+ *         `num_inputs`.
+ */
+static Array<Integer> GetInplaceOutputIndices(const Array<Integer>& inplace_indices,
+                                              int num_inputs) {
+  Array<Integer> positions;
+  // Slots for outputs that do not reuse an input come after all the inputs.
+  int next_fresh_slot = num_inputs;
+  for (const Integer& entry : inplace_indices) {
+    const int input_idx = entry.IntValue();
+    const bool reuses_input = input_idx >= 0;
+    // Only advance the fresh-slot counter when a new slot is actually handed out.
+    positions.push_back(Integer(reuses_input ? input_idx : next_fresh_slot++));
+  }
+  return positions;
+}
+
+class RelaxToTIRVarMapCollector : public ExprVisitor {
+  void CollectVarMapping(const CallNode* call, const Expr& lhs_var, bool 
in_place = false) {
+    GlobalVar gv = Downcast<GlobalVar>(call->args[0]);
+    tir::PrimFunc prim_func_ = Downcast<tir::PrimFunc>(mod_->Lookup(gv));
+    const auto& buffer_map = prim_func_->buffer_map;
+    const auto& tir_args = prim_func_->params;
+
+    const auto& relax_args = Downcast<Tuple>(call->args[1])->fields;
+
+    Array<Expr> relax_results;
+    if (lhs_var->IsInstance<TupleNode>()) {
+      relax_results = Downcast<Tuple>(lhs_var)->fields;
+    } else {
+      CHECK(lhs_var->IsInstance<VarNode>()) << "The lhs_var is expected to be 
either tuple or var";
+      relax_results = {Downcast<Var>(lhs_var)};
+    }
+
+    size_t num_inputs = relax_args.size();
+    size_t num_outputs = relax_results.size();
+
+    Array<Integer> output_idxs;
+    if (in_place) {
+      const auto* attrs = call->attrs.as<CallTIRInplaceAttrs>();
+      CHECK(attrs) << "Must have CallTIRInplaceAttrs for an in-place call";
+      output_idxs = GetInplaceOutputIndices(attrs->inplace_indices, 
num_inputs);
+    } else {
+      for (size_t i = num_inputs; i < num_inputs + num_outputs; i++) {
+        output_idxs.push_back(i);
+      }
+    }
+    for (size_t i = 0; i < tir_args.size(); ++i) {
+      const auto& tir_var = Downcast<tir::Var>(tir_args[i]);
+      if (i < num_inputs) {
+        const auto& relax_var = Downcast<Var>(relax_args[i]);
+        relax_to_tir_var_map_.Set(relax_var, buffer_map[tir_var]);
+      }
+      if (auto it = std::find(output_idxs.begin(), output_idxs.end(), i); it 
!= output_idxs.end()) {
+        int result_idx = it - output_idxs.begin();
+        const auto& inplace_out_var = Downcast<Var>(relax_results[result_idx]);
+        relax_to_tir_var_map_.Set(inplace_out_var, buffer_map[tir_var]);
+      }
+    }
+  }
+
+ public:
+  explicit RelaxToTIRVarMapCollector(const IRModule& mod) : mod_(mod) {}
+  static Map<Var, tir::Buffer> Collect(const IRModule& mod, const Function& 
func) {
+    RelaxToTIRVarMapCollector visitor(mod);
+    visitor(func->body);
+    return visitor.relax_to_tir_var_map_;
+  }
+  void VisitBinding_(const VarBindingNode* binding) final {
+    const auto& lhs_var = binding->var;

Review Comment:
   I've added the visit, thanks.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to