MoisesHer opened a new issue #19019:
URL: https://github.com/apache/incubator-mxnet/issues/19019
Automatic mixed precision (AMP) in MXNet is not reusing weights when these
are applied recursively, instead, it creates a new copy each time, i.e. when a
network iterates over the same layer, with same weights, AMP creates a new copy
of the weights at each iteration.
This leads to heavier memory consumption when using AMP, producing an
out-of-memory error on the GPU if the number of iterations is large enough.
### Error Message
```
Traceback (most recent call last):
File "decoder.py", line 114, in <module>
model.forward_backward(batch, ctx)
File "decoder.py", line 71, in forward_backward
output = self._myRNN(batch, ctx)
File "/opt/mxnet/python/mxnet/gluon/block.py", line 693, in __call__
out = self.forward(*args)
File "decoder.py", line 52, in forward
mx.nd.waitall()
File "/opt/mxnet/python/mxnet/ndarray/ndarray.py", line 200, in waitall
check_call(_LIB.MXNDArrayWaitAll())
File "/opt/mxnet/python/mxnet/base.py", line 255, in check_call
raise MXNetError(py_str(_LIB.MXGetLastError()))
mxnet.base.MXNetError: [17:47:42]
../src/storage/./pooled_storage_manager.h:188: Memory allocation failed out of
memory
Stack trace:
[bt] (0)
/usr/local/lib/libmxnet.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x50)
[0x7f96ada930b0]
[bt] (1)
/usr/local/lib/libmxnet.so(mxnet::storage::PooledStorageManager<mxnet::storage::RoundMultiple,
mxnet::storage::UnorderedMapContainer>::Alloc(mxnet::Storage::Handle*)+0x26e)
[0x7f96b06c0dde]
[bt] (2)
/usr/local/lib/libmxnet.so(mxnet::storage::StorageImpl::Alloc(mxnet::Storage::Handle*)+0x65)
[0x7f96b06b91a5]
[bt] (3) /usr/local/lib/libmxnet.so(mxnet::NDArray::CheckAndAlloc()
const+0x1cf) [0x7f96adaab5af]
[bt] (4) /usr/local/lib/libmxnet.so(+0x10e4708) [0x7f96adc9b708]
[bt] (5)
/usr/local/lib/libmxnet.so(mxnet::imperative::PushFCompute(std::function<void
(nnvm::NodeAttrs const&, mxnet::OpContext const&, std::vector<mxnet::TBlob,
std::allocator<mxnet::TBlob> > const&, std::vector<mxnet::OpReqType,
std::allocator<mxnet::OpReqType> > const&, std::vector<mxnet::TBlob,
std::allocator<mxnet::TBlob> > const&)> const&, nnvm::Op const*,
nnvm::NodeAttrs const&, mxnet::Context const&, std::vector<mxnet::engine::Var*,
std::allocator<mxnet::engine::Var*> > const&, std::vector<mxnet::engine::Var*,
std::allocator<mxnet::engine::Var*> > const&, std::vector<mxnet::Resource,
std::allocator<mxnet::Resource> > const&, std::vector<mxnet::NDArray*,
std::allocator<mxnet::NDArray*> > const&, std::vector<mxnet::NDArray*,
std::allocator<mxnet::NDArray*> > const&, std::vector<unsigned int,
std::allocator<unsigned int> > const&, std::vector<mxnet::OpReqType,
std::allocator<mxnet::OpReqType> >
const&)::{lambda(mxnet::RunContext)#1}::operator()(mxnet::RunContext) const+0x1
74) [0x7f96adcb0434]
[bt] (6) /usr/local/lib/libmxnet.so(std::_Function_handler<void
(mxnet::RunContext), mxnet::imperative::PushFCompute(std::function<void
(nnvm::NodeAttrs const&, mxnet::OpContext const&, std::vector<mxnet::TBlob,
std::allocator<mxnet::TBlob> > const&, std::vector<mxnet::OpReqType,
std::allocator<mxnet::OpReqType> > const&, std::vector<mxnet::TBlob,
std::allocator<mxnet::TBlob> > const&)> const&, nnvm::Op const*,
nnvm::NodeAttrs const&, mxnet::Context const&, std::vector<mxnet::engine::Var*,
std::allocator<mxnet::engine::Var*> > const&, std::vector<mxnet::engine::Var*,
std::allocator<mxnet::engine::Var*> > const&, std::vector<mxnet::Resource,
std::allocator<mxnet::Resource> > const&, std::vector<mxnet::NDArray*,
std::allocator<mxnet::NDArray*> > const&, std::vector<mxnet::NDArray*,
std::allocator<mxnet::NDArray*> > const&, std::vector<unsigned int,
std::allocator<unsigned int> > const&, std::vector<mxnet::OpReqType,
std::allocator<mxnet::OpReqType> > const&)::{lambda(mxnet::RunCon
text)#1}>::_M_invoke(std::_Any_data const&, mxnet::RunContext&&)+0x17)
[0x7f96adcb0907]
[bt] (7) /usr/local/lib/libmxnet.so(+0xff47f5) [0x7f96adbab7f5]
[bt] (8)
/usr/local/lib/libmxnet.so(mxnet::engine::ThreadedEngine::ExecuteOprBlock(mxnet::RunContext,
mxnet::engine::OprBlock*)+0x43d) [0x7f96adbb8a0d]
```
## To Reproduce
(If you developed your own code, please provide a short script that
reproduces the error. For existing examples, please provide link.)
```
import mxnet as mx
import numpy as np
from mxnet import gluon
from mxnet import autograd
from mxnet.gluon import HybridBlock, Block, nn, rnn, loss
from mxnet import symbol, ndarray
from mxnet.ndarray import NDArray
from mxnet.contrib import amp
import pynvml as nvml
import sys
batch_size = 32
dim_cell = 1024
recursive_steps = 300
nvml.nvmlInit()
handle = nvml.nvmlDeviceGetHandleByIndex(0)
class Cell(HybridBlock):
def __init__(self, **kwargs):
super(Cell, self).__init__(**kwargs)
self.nlayers = 8
self.cell_dim = dim_cell
with self.name_scope():
self._cell = rnn.HybridSequentialRNNCell()
for i in range(self.nlayers):
self._cell.add(rnn.LSTMCell(hidden_size=self.cell_dim))
def hybrid_forward(self, F, lstm_inpt, lstm_state):
lstm_out, lstm_state = self._cell(lstm_inpt, lstm_state)
return lstm_out, lstm_state
def begin_state(self, batch_size, ctx):
lstm_state = self._cell.begin_state(
batch_size=batch_size, func=mx.nd.zeros, ctx=ctx)
return (lstm_state)
class MyRNN(Block):
def __init__(self, **kwargs):
super(MyRNN, self).__init__(**kwargs)
with self.name_scope():
self._cell = Cell()
def forward(self, inp, ctx):
lstm_state = self._cell.begin_state(batch_size, ctx=ctx)
_steps = recursive_steps
prev_out = inp
for i in range(_steps):
print('---- step:', i)
new_out, new_state = self._cell.forward(prev_out, lstm_state)
prev_out = new_out
lstm_state = new_state
mx.nd.waitall()
info = nvml.nvmlDeviceGetMemoryInfo(handle)
free_mem = info.free
free_mem = free_mem / (1024*1024)
print(f"Free_mem(MB): {free_mem}")
sys.stdout.flush()
return new_out
class Model(Block):
def __init__(self,
**kwargs):
super(Model, self).__init__(**kwargs)
with self.name_scope():
self._myRNN = MyRNN()
self._l1_loss = loss.L1Loss()
def forward_backward(self, batch, ctx):
with autograd.record():
output = self._myRNN(batch, ctx)
actual = output * 0.02
loss = self._l1_loss(output, actual)
loss.backward()
if __name__ == "__main__":
AMP = True
ctx = mx.gpu()
model = Model()
model.initialize(mx.init.Normal(sigma=1.), ctx=ctx)
print(model.collect_params())
lr = 0.001
beta1 = 0.5
wd = 0.1
lr_decay_step = 1000
lr_decay_factor = 0.9
lr_stop_factor = 1e-9
clip_gradient = 1.0
lr_scheduler = mx.lr_scheduler.FactorScheduler(
step=lr_decay_step,
factor=lr_decay_factor,
stop_factor_lr=lr_stop_factor)
optimizer = mx.optimizer.Adam(learning_rate=lr,
clip_gradient=clip_gradient,
beta1=beta1,
lr_scheduler=lr_scheduler,
wd=wd)
trainer = gluon.Trainer(model.collect_params(),
optimizer, None)
if AMP:
amp.init()
amp.init_trainer(trainer)
max_steps = 1
step = 0
while step < max_steps:
# batch with random data
batch = mx.nd.uniform(shape=(batch_size, dim_cell), ctx=ctx)
batch.attach_grad()
model.forward_backward(batch, ctx)
trainer.step(batch_size)
step += 1
print('DONE')
```
### Steps to reproduce
(Paste the commands you ran that produced the error.)
1. python <script_name.py>
## Environment
Ubuntu OS, [MXNet 1.5 / MXNet 1.6 / MXNet 1.7] (similar issue in all
versions)
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]