SiriusNEO opened a new pull request, #14527:
URL: https://github.com/apache/tvm/pull/14527
### Intro
This PR registers gradient functions for many high-level Relax operators.
As in Relay, the gradient function is registered as an attribute
`FPrimalGradient` (`OpAttr`) on the corresponding Relax operator, but the
function signature differs from Relay's:
```c++
using FPrimalGradient = runtime::TypedPackedFunc<tvm::Array<Expr>(
const Var& orig_var, const Call& orig_call, const Var& output_grad,
const BlockBuilder& ctx)>;
```
- `orig_call` is the original call expression we want to differentiate.
- `output_grad` is the incoming gradient with respect to the call's output.
- `orig_var` is the var bound to the call result (`y`). It is passed in so
gradient functions can reuse the forward result and save recomputation.
- `ctx` is the block builder context. It is not used yet, but we expect it to
be useful for dynamic-shape cases and whenever a gradient function needs to
emit bindings or perform normalization.
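As a conceptual illustration of what such a gradient function computes (plain Python/NumPy, not the actual TVM C++ API; the names `multiply_grad` and `orig_args` are hypothetical):

```python
import numpy as np

def multiply_grad(orig_args, output_grad):
    """Conceptual analogue of an FPrimalGradient for z = a * b.

    Given the incoming gradient dL/dz, return the partial adjoints
    [dL/da, dL/db] -- one entry per input of the original call,
    mirroring the Array<Expr> a real gradient function returns.
    """
    a, b = orig_args
    return [output_grad * b, output_grad * a]

a = np.array([1.0, 2.0])
b = np.array([3.0, 4.0])
# With an all-ones incoming gradient, the adjoints are just b and a.
grad_a, grad_b = multiply_grad((a, b), np.ones(2))
```

The real implementation builds Relax expressions through the `BlockBuilder` rather than computing values eagerly, but the data flow is the same.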
For some complicated gradients, we introduce high-level backward operators and
put them under the namespace `op.grad.xxx`. All gradient functions are
numerically tested. For more details please check Part 2
of [this
document](https://github.com/mlc-ai/mlc-training/blob/main/tutorial/Reverse_Mode_Automatic_Differentiation_in_Relax.md).
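The numerical testing mentioned above typically compares the analytic gradient against a central finite difference. A minimal self-contained sketch of that idea (an illustration, not the test harness used in this PR):

```python
import numpy as np

def numeric_grad(f, x, eps=1e-6):
    """Central finite-difference approximation of the gradient of a
    scalar-valued function f at x, computed elementwise."""
    g = np.zeros_like(x)
    for i in range(x.size):
        x_plus = x.copy(); x_plus.flat[i] += eps
        x_minus = x.copy(); x_minus.flat[i] -= eps
        g.flat[i] = (f(x_plus) - f(x_minus)) / (2 * eps)
    return g

x = np.array([0.5, 1.0, 2.0])
analytic = np.cos(x)  # d/dx of sin(x)
numeric = numeric_grad(lambda v: np.sin(v).sum(), x)
assert np.allclose(analytic, numeric, atol=1e-4)
```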
### Others
This PR also fixes two small operator issues:
- `CumsumAttrs` was not declared on the Python side.
- A small bug in the implementation of legalizing the `variance` op.
Co-authored-by: Yixin Dong <[email protected]>