andrei5055 opened a new pull request #18570:
URL: https://github.com/apache/incubator-mxnet/pull/18570
## Description ##
During training, MXNet reinitializes the data stored in `NDArray` instances
by repeatedly calling `NDArray` constructors and assigning the newly
constructed objects to existing ones.
First, this is inefficient, because
- (a) it usually involves a call to some `NDArray` constructor AND to the
`NDArray` copy constructor
- (b) after each such call, the `NDArray` destructor is called.
Second, many of these calls happen "under the hood" and are inefficient.
For instance, the usage of
```cpp
std::vector<NDArray> inputs, outputs
```
in lambda expressions generates more `NDArray` constructor calls than are
actually necessary:
- 3 times more for `PushFCompute(...)` and
- 1.5 times more for `PushFComputeEx(...)`, `PushOperator(...)`.
Third, it's hard to debug.
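The second point can be reproduced in isolation. The following is a minimal sketch, not MXNet code: `Tracked` is a hypothetical stand-in for `NDArray` that only counts copy-constructor calls, and both helper functions are invented for illustration. It compares capturing a vector of objects by value in a lambda with capturing a vector of pointers.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for NDArray: it only counts copy-constructor calls.
struct Tracked {
  static int copies;
  Tracked() = default;
  Tracked(const Tracked&) { ++copies; }
};
int Tracked::copies = 0;

// Capturing a vector of objects by value copy-constructs every element.
int CopiesWithValueCapture(std::size_t n) {
  std::vector<Tracked> inputs(n);
  Tracked::copies = 0;  // count only the copies made by the capture
  auto fn = [inputs]() { return inputs.size(); };
  assert(fn() == n);
  return Tracked::copies;
}

// Capturing a vector of pointers by value copies only the pointers.
int CopiesWithPointerCapture(std::size_t n) {
  std::vector<Tracked> storage(n);
  std::vector<Tracked*> inputs;
  for (auto& t : storage) inputs.push_back(&t);
  Tracked::copies = 0;  // count only the copies made by the capture
  auto fn = [inputs]() { return inputs.size(); };
  assert(fn() == n);
  return Tracked::copies;
}
```

In this sketch, `CopiesWithValueCapture(3)` reports three copy-constructor calls and `CopiesWithPointerCapture(3)` reports none. The trade-off is that the pointed-to objects must outlive the lambda, which the sketch guarantees only because everything lives in one function.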
To fix this problem, we propose using
- the newly implemented `NDArray::Init(...)` and `NDArray::ReInit(...)` methods
- `std::vector<NDArray *> inputs, outputs;` instead of `std::vector<NDArray> inputs, outputs;`
Our experiments show that these changes reduce the number of `NDArray`
constructor calls by approximately 80%.
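To illustrate why in-place reinitialization avoids constructor and destructor calls, here is a minimal sketch under stated assumptions. `Array` is a hypothetical, simplified type, not MXNet's `NDArray`, and its `ReInit` takes only a shape. The real `NDArray::ReInit(...)` has a different signature and does considerably more work.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical, simplified array type; it is NOT MXNet's NDArray.
struct Array {
  static int ctor_calls;
  static int dtor_calls;
  std::vector<int> shape;

  Array() { ++ctor_calls; }
  explicit Array(std::vector<int> s) : shape(std::move(s)) { ++ctor_calls; }
  Array(const Array& other) : shape(other.shape) { ++ctor_calls; }
  Array& operator=(const Array&) = default;
  ~Array() { ++dtor_calls; }

  // Resets the state of an existing object without building a temporary.
  void ReInit(std::vector<int> s) { shape = std::move(s); }
};
int Array::ctor_calls = 0;
int Array::dtor_calls = 0;

struct Counts {
  int ctors;
  int dtors;
};

// Old pattern: construct a temporary, assign it, then destroy it.
Counts ResetByAssignment() {
  Array a;
  Array::ctor_calls = Array::dtor_calls = 0;  // ignore the setup above
  a = Array(std::vector<int>{2, 3});
  assert(a.shape.size() == 2);
  return {Array::ctor_calls, Array::dtor_calls};
}

// Proposed pattern: reinitialize the existing object in place.
Counts ResetInPlace() {
  Array a;
  Array::ctor_calls = Array::dtor_calls = 0;  // ignore the setup above
  a.ReInit(std::vector<int>{2, 3});
  assert(a.shape.size() == 2);
  return {Array::ctor_calls, Array::dtor_calls};
}
```

In this sketch, assignment costs one constructor call and one destructor call per reset, while the in-place variant costs neither. Both leave the object in the same state.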
## Checklist ##
### Essentials ###
Please feel free to remove inapplicable items for your PR.
- [x] Changes are complete (i.e. I finished coding on this PR)
- [x] All changes have test coverage:
- [x] Code is well-documented:
- [x] To the best of my knowledge, examples are either not affected by this
change, or have been fixed to be compatible with this change
### Changes ###
- [x] The usage of
```cpp
std::vector<NDArray> inputs, outputs
```
in `PushFCompute(...)`, `PushFComputeEx(...)`, `PushOperator(...)`
was replaced by
```cpp
std::vector<NDArray *> inputs, outputs
```
- [x] The usage of macros similar to the following ones, which are defined
differently for different values of `MXNET_USE_MKLDNN`, simplifies the
code
```cpp
#if MXNET_USE_MKLDNN == 1
#define INVALIDATE_OUTPUTS(outputs, req) InvalidateOutputs(&outputs, req)
#define INVALIDATE_OUTPUTS_COND(cond, outputs, req) \
  if (cond) INVALIDATE_OUTPUTS(outputs, req)
// add for mkldnn OP + no mkldnn OP
#define CREATE_DEFAULT_INPUTS(cond, attrs, func_call) \
if (cond) { \
const auto is_mkldnn = Op::GetAttr<bool>("TIsMKLDNN"); \
if (!is_mkldnn.get(attrs.op, false)) func_call; \
}
#else
#define INVALIDATE_OUTPUTS(outputs, ...) // empty macros
#define INVALIDATE_OUTPUTS_COND(outputs, ...) // empty macro
#define CREATE_DEFAULT_INPUTS(input, ...) // empty macro
#endif
```
- [x] The implementation of the constructor:
```cpp
NDArray(const NDArrayStorageType stype, const mxnet::TShape &shape, Context ctx,
        bool delay_alloc = true, int dtype = mshadow::default_type_flag,
        std::vector<int> aux_types = {}, mxnet::ShapeVector aux_shapes = {},
        mxnet::TShape storage_shape = mxnet::TShape(mshadow::Shape1(0)));
```
was moved into the newly implemented:
```cpp
void NDArray::ReInit(const NDArrayStorageType stype, const mxnet::TShape &shape,
                     Context ctx, int dtype, bool delay_alloc,
                     const std::vector<int> *pAux_types,
                     const mxnet::ShapeVector *pAux_shapes,
                     const mxnet::TShape *pStorage_shapes)
```
## Comments ##