mbs-octoml commented on a change in pull request #8597:
URL: https://github.com/apache/tvm/pull/8597#discussion_r684529339



##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -408,27 +462,135 @@ class LowerTensorExpr : public ExprMutator {
     }
 
     tir_call_attrs->metadata.Set("relay_attrs", func->attrs);
+    tir_call_attrs->metadata.Set("all_prim_fn_vars", all_prim_fn_vars);
+
+    if (IsDynamic(func->ret_type)) {
+      // Also lower the dynamic shape function.
+      // Shape function keys use the underlying primitive function as their 'function',
+      // but the generic 'cpu' target as the target since all shape functions run
+      // on the host cpu irrespective of where the primitive runs.
+      // TODO(mbs): Cleanup target handling.
+      Target shape_target("llvm");
+      CCacheKey shape_key(func, shape_target);
+      CachedFunc lowered_shape_func = compiler_->LowerShapeFunc(shape_key);
+      // Capture the shape function's global var and parameter 'states' in call
+      // annotations so the calling convention can be recovered.
+      // TODO(mbs): Capture all this as part of a 'call into TIR' construct once available.
+      // The way the shape function calling convention is derived and passed to call sites
+      // via the 'parameter states' could be improved.
+      tir_call_attrs->metadata.Set("prim_shape_fn_var", lowered_shape_func->prim_fn_var);

Review comment:
       Yeah, I think this should be a struct inside the 'call-tir' op.
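       To illustrate the idea, here is a hypothetical sketch (not the actual TVM API; the type and field names are invented for this comment) of grouping the separate `metadata` entries into one attrs-style struct that a dedicated 'call-tir' op could carry:

       ```cpp
       // Hypothetical sketch only: stand-in for bundling the per-call lowering
       // metadata (currently stored as individual "metadata" map entries) into a
       // single struct attached to a 'call-tir'-style op.
       #include <string>
       #include <vector>

       // Placeholder for TVM's GlobalVar; the real type lives in tvm/ir/expr.h.
       using GlobalVarName = std::string;

       struct CallTirAttrsSketch {
         GlobalVarName prim_fn_var;                    // the lowered primitive function
         std::vector<GlobalVarName> all_prim_fn_vars;  // every prim fn this call references
         // Populated only when the callee's result type is dynamic:
         GlobalVarName prim_shape_fn_var;              // the lowered shape function
         std::vector<int> prim_shape_fn_states;        // per-parameter data/shape 'states'
         size_t prim_shape_fn_num_inputs = 0;
         size_t prim_shape_fn_num_outputs = 0;
       };
       ```

       Bundling these fields would let call sites recover the shape-function calling convention from one place instead of probing several string-keyed metadata entries.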




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
