khaotik opened a new pull request, #21044:
URL: https://github.com/apache/incubator-mxnet/pull/21044
## Description ##
Fixes a bug where `SymbolBlock` does not copy user-given parameter attributes,
such as `lr_mult`, from the variables of the input symbol. Minimal reproducer:
```python
#!/usr/bin/env python
import numpy as np
import mxnet as mx
import mxnet.symbol as mxs
DTYPE = np.float32
LR_MULT = 0.555
WD_MULT = 0.444
s_x = mxs.var('x', shape=(1,256,), dtype=DTYPE)
s_w = mxs.var('W', shape=(256,192), lr_mult=LR_MULT, wd_mult=WD_MULT,
              dtype=DTYPE)
s_b = mxs.var('b', shape=(1,192,), dtype=DTYPE, init=mx.init.Zero())
s_y = mxs.linalg.gemm(s_x, s_w, s_b)
fn = mx.gluon.SymbolBlock([s_y], [s_x])
fn.initialize()
fn.forward(mx.nd.random_uniform(-1., 1., shape=(1,256), dtype=DTYPE))
param_di = fn.collect_params()
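# Without this fix, SymbolBlock ignores the lr_mult/wd_mult/init set on the vars,
# so these assertions fail.
assert param_di['W'].lr_mult == LR_MULT
assert param_di['W'].wd_mult == WD_MULT
assert not param_di['b'].data().asnumpy().any()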
```
## Checklist ##
### Essentials ###
- [x] PR's title starts with a category (e.g. [BUGFIX], [MODEL], [TUTORIAL], [FEATURE], [DOC], etc.)
- [ ] Changes are complete (i.e. I finished coding on this PR)
- [x] All changes have test coverage
- [x] Code is well-documented
### Changes ###
- [x] `SymbolBlock` now copies the user-given `wd_mult`, `lr_mult`, and `init`
attributes from the input symbol.
- [ ] Copy the `stype` attribute (see Comments)
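For reviewers, the attribute mapping involved can be sketched in plain Python. This is a hypothetical helper (`param_kwargs_from_attrs` is not an MXNet function, and this is not the code in this PR). It assumes the attribute dict has the shape returned by MXNet 1.x `Symbol.attr_dict()`, where `mxs.var(..., lr_mult=..., wd_mult=..., init=...)` stores the values as strings under `__lr_mult__`, `__wd_mult__`, and `__init__`.

```python
# Hypothetical helper for illustration; not MXNet's API and not this PR's code.
# Assumes `attr_dict` has the shape of MXNet 1.x `Symbol.attr_dict()`:
#   {node_name: {"__lr_mult__": "0.555", "__wd_mult__": "0.444", ...}}


def param_kwargs_from_attrs(attr_dict, input_names):
    """Collect lr_mult / wd_mult / init for each input variable.

    Only names in `input_names` are kept, because attr_dict() may also
    contain entries for operator nodes, not just variables.
    """
    result = {}
    for name in input_names:
        attrs = attr_dict.get(name, {})
        kwargs = {}
        if "__lr_mult__" in attrs:
            # Symbol attributes are strings, so convert back to float.
            kwargs["lr_mult"] = float(attrs["__lr_mult__"])
        if "__wd_mult__" in attrs:
            kwargs["wd_mult"] = float(attrs["__wd_mult__"])
        if "__init__" in attrs:
            # Passed through unchanged as the serialized initializer string.
            kwargs["init"] = attrs["__init__"]
        result[name] = kwargs
    return result


# A subset of what `mxs.var` records for the reproducer above; the "__init__"
# value is assumed to be the serialized form of mx.init.Zero().
attrs = {
    "W": {"__lr_mult__": "0.555", "__wd_mult__": "0.444"},
    "b": {"__init__": '["zero", {}]'},
}
print(param_kwargs_from_attrs(attrs, ["x", "W", "b"]))
```

Iterating over the input names rather than over every `attr_dict()` entry keeps operator nodes out of the result and gives every input an entry, even when it carries no user-given attributes (like `x` here).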
## Comments ##
- The `stype` attribute is still not copied, because there is no
`MXSymbolInferStorageType` C API. I think a complete fix would require some
changes to the C API.
--
This is an automated message from the Apache Git Service.