Ubospica opened a new pull request, #15231:
URL: https://github.com/apache/tvm/pull/15231
This PR adds support for registering TE gradient functions. When the Gradient
pass encounters a call_tir node in the forward code, it calls the TE gradient
function associated with that node.
The current workflow is as follows:
```
# register te gradient function
@register_te_gradient("f_mul_grad")
def f_mul_grad(output_grad: te.Tensor, src1: te.Tensor, src2: te.Tensor, k: int):
    # returns a list of te tensors, representing gradients w.r.t. src1, src2
    # k is a constant parameter
    ...

# irmodule definition
@I.ir_module
class Module:
    @T.prim_func
    def f_mul(A, B, result):
        ...

    @R.function
    def main(a, b):
        cls = Module
        with R.dataflow():
            lv = R.call_tir(cls.f_mul, (a, b), te_grad_name="f_mul_grad",
                            te_grad_kwargs={"k": 1})
            gv = R.output(lv)
        return gv
```
Note that this PR adds two attributes, te_grad_name and te_grad_kwargs, to the call_tir node:
```
struct CallTIRAttrs : public tvm::AttrsNode<CallTIRAttrs> {
  Optional<String> te_grad_name;
  Map<String, ObjectRef> te_grad_kwargs;

  TVM_DECLARE_ATTRS(CallTIRAttrs, "relax.attrs.CallTIRAttrs") {
    TVM_ATTR_FIELD(te_grad_name)
        .describe("The name of the te gradient function associated with this call_tir node.");
    TVM_ATTR_FIELD(te_grad_kwargs)
        .describe("The keyword arguments passed to the te gradient function.");
  }
};  // struct CallTIRAttrs
```