Ubospica opened a new pull request, #15231:
URL: https://github.com/apache/tvm/pull/15231

   This PR adds support for registering TE gradient functions. When the
Gradient pass encounters a call_tir node in the forward code, it calls the TE
gradient function registered for that node.
   
   The current workflow is as follows:
   ```
   # register te gradient function
   @register_te_gradient("f_mul_grad")
   def f_mul_grad(output_grad: te.Tensor, src1: te.Tensor, src2: te.Tensor, k: int):
       # returns a list of te tensors, representing gradients w.r.t. src1, src2
       # k is a constant parameter
       ...
   
   # irmodule definition
   @I.ir_module
   class Module:
       @T.prim_func
       def f_mul(A, B, result):
           ...
   
       @R.function
       def main(a, b):
           cls = Module
           with R.dataflow():
               lv = R.call_tir(cls.f_mul, (a, b), te_grad_name="f_mul_grad", te_grad_kwargs={"k": 1})
               gv = R.output(lv)
           return gv
   ```
   
   Note that this PR adds a new attribute struct to the call_tir node:
   ```
   struct CallTIRAttrs : public tvm::AttrsNode<CallTIRAttrs> {
     Optional<String> te_grad_name;
     Map<String, ObjectRef> te_grad_kwargs;
   
     TVM_DECLARE_ATTRS(CallTIRAttrs, "relax.attrs.CallTIRAttrs") {
       TVM_ATTR_FIELD(te_grad_name)
           .describe("The name of the te gradient function associated with this call_tir node.");
       TVM_ATTR_FIELD(te_grad_kwargs)
           .describe("The keyword arguments passed to the te gradient function.");
     }
   };  // struct CallTIRAttrs
   ```
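
As a rough illustration of how a consumer might read these two fields, the sketch below mirrors the struct as a Python dataclass. `CallTIRAttrsSketch` and `describe_gradient_binding` are hypothetical names, the default values are assumptions, and how the Gradient pass treats a call_tir node without `te_grad_name` is not stated in this PR description.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional


@dataclass
class CallTIRAttrsSketch:
    """Illustrative Python mirror of the two attribute fields (not TVM code)."""
    te_grad_name: Optional[str] = None  # assumed default: no gradient attached
    te_grad_kwargs: Dict[str, Any] = field(default_factory=dict)


def describe_gradient_binding(attrs: CallTIRAttrsSketch) -> str:
    # The optional name decides whether a gradient is attached at all.
    if attrs.te_grad_name is None:
        return "no TE gradient attached"
    kwargs = ", ".join(f"{k}={v!r}" for k, v in sorted(attrs.te_grad_kwargs.items()))
    return f"{attrs.te_grad_name}({kwargs})"


print(describe_gradient_binding(CallTIRAttrsSketch("f_mul_grad", {"k": 1})))
print(describe_gradient_binding(CallTIRAttrsSketch()))
```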

