This is an automated email from the ASF dual-hosted git repository.
marcoabreu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new a7ea976 Static alloc for hybridblock (#11320)
a7ea976 is described below
commit a7ea97651edcb87668541efe36b35a87b7b2b50e
Author: Eric Junyuan Xie <[email protected]>
AuthorDate: Wed Jun 20 14:08:24 2018 -0700
Static alloc for hybridblock (#11320)
* Revert "Revert "Static alloc for hybridblock (#11313)" (#11318)"
This reverts commit 9b8eb567273d9521c639ea783e1117fe09cb3d64.
* fix
---
include/mxnet/c_api.h | 5 -
include/mxnet/imperative.h | 89 ----
include/mxnet/ndarray.h | 8 +
include/mxnet/op_attr_types.h | 33 +-
python/mxnet/_ctypes/ndarray.py | 16 +-
python/mxnet/gluon/block.py | 70 +--
src/c_api/c_api_ndarray.cc | 26 +-
src/engine/threaded_engine.cc | 3 +-
src/executor/attach_op_execs_pass.cc | 165 +++----
src/executor/attach_op_resource_pass.cc | 16 +-
src/executor/exec_pass.h | 28 +-
src/executor/graph_executor.cc | 2 +-
src/imperative/cached_op.cc | 754 ++++++++++++++++++++++++++------
src/imperative/cached_op.h | 174 ++++++++
src/imperative/imperative.cc | 90 +---
src/imperative/imperative_utils.cc | 120 +++++
src/imperative/imperative_utils.h | 256 +++++++++--
tests/python/unittest/test_gluon.py | 67 ++-
18 files changed, 1399 insertions(+), 523 deletions(-)
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 55c26bc..4dd858a 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -987,11 +987,6 @@ MXNET_DLL int MXCreateCachedOpEx(SymbolHandle handle,
int num_flags,
const char** keys,
const char** vals,
- int num_inputs,
- const char** input_names,
- int num_params,
- const char** param_names,
- NDArrayHandle* params,
CachedOpHandle *out);
/*!
* \brief free cached operator
diff --git a/include/mxnet/imperative.h b/include/mxnet/imperative.h
index 758ce85..7ea60df 100644
--- a/include/mxnet/imperative.h
+++ b/include/mxnet/imperative.h
@@ -35,23 +35,6 @@
#include "./ndarray.h"
namespace mxnet {
-/*! \brief CachedOp Parameters */
-struct CachedOpConfig : public dmlc::Parameter<CachedOpConfig> {
- uint32_t inline_limit;
- uint32_t forward_bulk_size;
- uint32_t backward_bulk_size;
- DMLC_DECLARE_PARAMETER(CachedOpConfig) {
- DMLC_DECLARE_FIELD(inline_limit)
- .set_default(2)
- .describe("Maximum number of operators that can be inlined.");
- DMLC_DECLARE_FIELD(forward_bulk_size)
- .set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))
- .describe("Segment size of bulk execution during forward pass.");
- DMLC_DECLARE_FIELD(backward_bulk_size)
- .set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))
- .describe("Segment size of bulk execution during backward pass.");
- }
-};
/*! \brief runtime functions for NDArray */
class Imperative {
public:
@@ -94,67 +77,6 @@ class Imperative {
&& info.out_grads.size() == 1;
}
};
- class CachedOp {
- public:
- CachedOp(
- const nnvm::Symbol& sym,
- const std::vector<std::pair<std::string, std::string> >& flags,
- const std::vector<std::string> arg_names,
- const std::unordered_map<std::string, std::vector<NDArray> >& params);
- uint32_t num_inputs() {
- return fwd_graph_.indexed_graph().input_nodes().size();
- }
- uint32_t num_outputs() {
- return fwd_graph_.outputs.size();
- }
- uint32_t num_backward_inputs() {
- return bwd_ograd_dep_.size() + bwd_in_dep_.size() + bwd_out_dep_.size();
- }
- std::vector<bool>& save_inputs() {
- return save_inputs_;
- }
- std::vector<bool>& save_outputs() {
- return save_outputs_;
- }
- const std::unordered_set<uint32_t>& mutable_input_nodes() {
- return fwd_graph_.indexed_graph().mutable_input_nodes();
- }
- nnvm::Graph GetForwardGraph(const bool recording,
- const std::vector<NDArray*>& inputs);
- nnvm::Graph GetBackwardGraph(const OpStatePtr& state,
- const std::vector<OpReqType>& reqs,
- const std::vector<NDArray*>& inputs);
- std::vector<nnvm::NodeEntry> Gradient(const nnvm::NodePtr& node,
- const std::vector<nnvm::NodeEntry>& ograds);
- void Forward(const std::shared_ptr<CachedOp>& op_ptr,
- const std::vector<NDArray*>& args,
- const std::vector<NDArray*>& outputs);
- void Backward(const bool retain_graph,
- const OpStatePtr& state,
- const std::vector<NDArray*>& inputs,
- const std::vector<OpReqType>& reqs,
- const std::vector<NDArray*>& outputs);
-
- private:
- struct CachedOpState {
- std::vector<NDArray> buff;
- std::vector<OpStatePtr> states;
- };
- std::mutex mutex_;
- CachedOpConfig config_;
- nnvm::Graph fwd_graph_;
- nnvm::Graph grad_graph_;
- nnvm::Graph full_graph_;
- std::unordered_map<Context, std::vector<NDArray> > params_;
- bool inlining_;
- std::vector<nnvm::NodeEntry> ograd_entries_;
- std::vector<bool> curr_grad_req_;
- std::vector<uint32_t> bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_;
- std::vector<uint32_t> fwd_args_idx_;
- std::vector<uint32_t> fwd_params_idx_;
- std::vector<uint32_t> bwd_input_eid_;
- std::vector<bool> save_inputs_, save_outputs_;
- };
/*! \brief whether operator recording is on. */
bool is_training() const {
return is_train_;
@@ -222,15 +144,6 @@ class Imperative {
uint32_t num_inputs, uint32_t num_outputs,
std::vector<bool> *p_save_inputs,
std::vector<bool> *p_save_outputs);
- void RunGraph(
- const bool retain_graph,
- const nnvm::IndexedGraph& idx,
- const std::vector<NDArray*> arrays,
- size_t node_start, size_t node_end,
- std::vector<OpReqType>&& array_reqs,
- std::vector<uint32_t>&& ref_count,
- std::vector<OpStatePtr> *p_states,
- const DispatchModeVector& dispatch_modes);
/*! \brief indicate whether is training. */
#if DMLC_CXX11_THREAD_LOCAL
static thread_local bool is_train_;
@@ -247,7 +160,5 @@ class Imperative {
int backward_bulk_size_{0};
};
-using CachedOpPtr = std::shared_ptr<Imperative::CachedOp>;
-
} // namespace mxnet
#endif // MXNET_IMPERATIVE_H_
diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index e243eb7..ae96fd8 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -155,6 +155,14 @@ class NDArray {
return byte_offset_ > 0 || shape() != ptr_->storage_shape;
}
+ /* \brief Check whether the two arrays are the same array */
+ inline bool IsSame(const NDArray& other) {
+ return ptr_ == other.ptr_ &&
+ shape_ == other.shape_ &&
+ byte_offset_ == other.byte_offset_ &&
+ dtype_ == other.dtype_;
+ }
+
/*!
* \return the shape of current NDArray.
*/
diff --git a/include/mxnet/op_attr_types.h b/include/mxnet/op_attr_types.h
index 3969d84..f4694ef 100644
--- a/include/mxnet/op_attr_types.h
+++ b/include/mxnet/op_attr_types.h
@@ -126,25 +126,36 @@ class OpStatePtr {
template<typename T, typename... Args>
static OpStatePtr Create(Args&&... args) {
OpStatePtr ret;
- ret.ptr_ = std::make_shared<OpState>();
- ret.ptr_->var_ = Engine::Get()->NewVariable();
- ret.ptr_->state_.construct<T>(std::forward<Args>(args)...);
+ auto state = new T(std::forward<Args>(args)...);
+ auto var = Engine::Get()->NewVariable();
+ ret.ptr_.reset(
+ new OpState(var, state),
+ [](OpState* p) {
+ Engine::Get()->DeleteVariable([](RunContext s) {}, Context::CPU(), p->var);
+ delete reinterpret_cast<T*>(p->state);
+ delete p;
+ });
return ret;
}
/* \brief Get engine variable associated with this state */
engine::VarHandle get_var() const {
- return ptr_->var_;
+ return ptr_->var;
}
/* \brief Get state of type T */
template<typename T>
T& get_state() const {
- return dmlc::get<T>(ptr_->state_);
+ return *reinterpret_cast<T*>(ptr_->state);
}
/* \brief clear state */
void reset() {
ptr_.reset();
}
+ /* \brief checks whether the managed object is managed only by the current OpStatePtr instance */
+ bool unique() const {
+ return ptr_.unique();
+ }
/* \brief Whether state is empty */
explicit operator bool() const {
return ptr_ ? true : false;
@@ -153,16 +164,12 @@ class OpStatePtr {
private:
/* \brief state structure */
struct OpState {
- OpState() {}
+ engine::VarHandle var;
+ void* state;
+
+ OpState(engine::VarHandle var_, void* state_) : var(var_), state(state_) {}
OpState(const OpState& other) = delete;
OpState& operator=(const OpState& other) = delete;
-
- ~OpState() {
- Engine::Get()->DeleteVariable([](RunContext s) {}, Context::CPU(), var_);
- }
-
- engine::VarHandle var_;
- dmlc::any state_;
};
/* \brief shared pointer to state */
std::shared_ptr<OpState> ptr_;
diff --git a/python/mxnet/_ctypes/ndarray.py b/python/mxnet/_ctypes/ndarray.py
index d2cae0c..f324545 100644
--- a/python/mxnet/_ctypes/ndarray.py
+++ b/python/mxnet/_ctypes/ndarray.py
@@ -105,28 +105,14 @@ def _imperative_invoke(handle, ndargs, keys, vals, out):
class CachedOp(object):
"""Cached operator handle."""
__slots__ = ["handle"]
- def __init__(self, sym, flags=(), inputs=None, params=None):
+ def __init__(self, sym, flags=()):
self.handle = CachedOpHandle()
- param_names = []
- param_arrays = []
- if inputs is None:
- assert params is None, "When inputs is None params must also be None."
- inputs = sym.list_inputs()
- elif params is not None:
- for name, arrs in params.items():
- param_arrays.extend(arrs)
- param_names.extend([name] * len(arrs))
check_call(_LIB.MXCreateCachedOpEx(
sym.handle,
len(flags),
c_str_array([key for key, _ in flags]),
c_str_array([str(val) for _, val in flags]),
- len(inputs),
- c_str_array(inputs),
- len(param_names),
- c_str_array(param_names),
- c_handle_array(param_arrays),
ctypes.byref(self.handle)))
def __del__(self):
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 3b97c05..0845669 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -502,8 +502,12 @@ class Block(object):
----------
active : bool, default True
Whether to turn hybrid on or off.
- **kwargs : string
- Additional flags for hybridized operator.
+ static_alloc : bool, default False
+ Statically allocate memory to improve speed. Memory usage may increase.
+ static_shape : bool, default False
+ Optimize for invariant input shapes between iterations. Must also
+ set static_alloc to True. Change of input shapes is still allowed
+ but slower.
"""
for cld in self._children.values():
cld.hybridize(active, **kwargs)
@@ -696,7 +700,7 @@ class HybridBlock(Block):
self._out_format = None
self._in_format = None
self._active = False
- self._flags = {}
+ self._flags = []
def __setattr__(self, name, value):
"""Registers parameters."""
@@ -723,39 +727,43 @@ class HybridBlock(Block):
return self._cached_graph
def _build_cache(self, *args):
- inputs, out = self._get_graph(*args)
- input_names = [i.name for i in inputs]
-
+ data, out = self._get_graph(*args)
+ data_names = {data.name : i for i, data in enumerate(data)}
params = self.collect_params()
+ input_names = out.list_inputs()
+
param_names = set(params.keys())
- expected_names = set(out.list_inputs())
+ expected_names = set(input_names)
for name in expected_names:
- assert name in param_names or name in input_names, \
+ assert name in param_names or name in data_names, \
"Unknown input to HybridBlock: %s"%name
- used_input_names = [i for i in input_names if i in expected_names]
- if len(used_input_names) != len(input_names):
- unused = ', '.join(['%d-th'%i for i, name in enumerate(input_names)
+ used_data_names = [i for i in data_names if i in expected_names]
+ if len(used_data_names) != len(data_names):
+ unused = ', '.join(['%d-th'%i for name, i in data_names.items()
if name not in expected_names])
warnings.warn("The %s input to HybridBlock is not used by any "
"computation. Is this intended?"%unused,
stacklevel=4)
- used_param_names = set(i for i in param_names if i in expected_names)
+ used_param_names = [i for i in param_names if i in expected_names]
if len(used_param_names) != len(param_names):
- unused = ', '.join(list(param_names - used_param_names))
+ unused = ', '.join(list(param_names - set(used_param_names)))
warnings.warn("Parameter %s is not used by any computation. "
"Is this intended?"%unused, stacklevel=4)
- used_params = {k: params[k] for k in used_param_names}
- try:
- param_dict = {k: v.list_data() for k, v in used_params.items()}
- except DeferredInitializationError:
- self._deferred_infer_shape(*args)
- for i in used_params.values():
- i._finish_deferred_init()
- param_dict = {k: v.list_data() for k, v in used_params.items()}
-
- self._cached_op = ndarray.CachedOp(out, self._flags, input_names, param_dict)
+ data_indices = []
+ param_indices = []
+ self._cached_op_args = []
+ for i, name in enumerate(input_names):
+ if name in data_names:
+ data_indices.append(i)
+ self._cached_op_args.append((True, data_names[name]))
+ else:
+ param_indices.append(i)
+ self._cached_op_args.append((False, params[name]))
+ flags = [('data_indices', data_indices), ('param_indices', param_indices)] + \
+ self._flags
+ self._cached_op = ndarray.CachedOp(out, flags)
def _deferred_infer_shape(self, *args):
try:
@@ -771,7 +779,19 @@ class HybridBlock(Block):
args, fmt = _flatten(args, "input")
assert fmt == self._in_format, "Invalid input format"
- out = self._cached_op(*args)
+ try:
+ cargs = [args[i] if is_arg else i.data()
+ for is_arg, i in self._cached_op_args]
+ except DeferredInitializationError:
+ self._deferred_infer_shape(*args)
+ cargs = []
+ for is_arg, i in self._cached_op_args:
+ if is_arg:
+ cargs.append(args[i])
+ else:
+ i._finish_deferred_init()
+ cargs.append(i.data())
+ out = self._cached_op(*cargs)
if isinstance(out, NDArray):
out = [out]
return _regroup(out, self._out_format)[0]
@@ -792,7 +812,7 @@ class HybridBlock(Block):
def hybridize(self, active=True, **kwargs):
self._active = active
- self._flags = kwargs.items()
+ self._flags = list(kwargs.items())
self._clear_cached_op()
if active and self._forward_hooks or self._forward_pre_hooks:
warnings.warn('"{}" is being hybridized while still having forward hook/pre-hook. '
diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc
index 9aabe04..34bd4b2 100644
--- a/src/c_api/c_api_ndarray.cc
+++ b/src/c_api/c_api_ndarray.cc
@@ -36,6 +36,7 @@
#include "../common/utils.h"
#include "../common/exec_utils.h"
#include "../imperative/imperative_utils.h"
+#include "../imperative/cached_op.h"
using namespace mxnet;
@@ -160,12 +161,8 @@ int MXCreateCachedOp(SymbolHandle handle,
std::vector<std::string> input_names;
input_names.reserve(inputs.size());
for (const auto& i : inputs) input_names.push_back(i->attrs.name);
- *out = new std::shared_ptr<Imperative::CachedOp>(
- new Imperative::CachedOp(
- *sym,
- std::vector<std::pair<std::string, std::string> >(),
- input_names,
- std::unordered_map<std::string, std::vector<NDArray> >()));
+ *out = new CachedOpPtr(new CachedOp(
+ *sym, std::vector<std::pair<std::string, std::string> >()));
API_END();
}
@@ -173,11 +170,6 @@ int MXCreateCachedOpEx(SymbolHandle handle,
int num_flags,
const char** keys,
const char** vals,
- int num_args,
- const char** arg_names,
- int num_params,
- const char** param_names,
- NDArrayHandle* params,
CachedOpHandle *out) {
nnvm::Symbol* sym = static_cast<nnvm::Symbol*>(handle);
@@ -186,17 +178,7 @@ int MXCreateCachedOpEx(SymbolHandle handle,
for (int i = 0; i < num_flags; ++i) {
flags.push_back({keys[i], vals[i]});
}
- std::vector<std::string> args;
- for (int i = 0; i < num_args; ++i) {
- args.push_back(arg_names[i]);
- }
- std::unordered_map<std::string, std::vector<NDArray> > param_dict;
- for (int i = 0; i < num_params; ++i) {
- param_dict[param_names[i]].emplace_back(
- *reinterpret_cast<NDArray*>(params[i]));
- }
- *out = new std::shared_ptr<Imperative::CachedOp>(
- new Imperative::CachedOp(*sym, flags, args, param_dict));
+ *out = new CachedOpPtr(new CachedOp(*sym, flags));
API_END();
}
diff --git a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc
index dc0436e..e70cc19 100644
--- a/src/engine/threaded_engine.cc
+++ b/src/engine/threaded_engine.cc
@@ -278,6 +278,8 @@ void ThreadedEngine::DeleteOperator(OprHandle op) {
}
void ThreadedEngine::Push(OprHandle op, Context exec_ctx, int priority, bool profiling) {
+ BulkFlush();
+
ThreadedOpr* threaded_opr = ThreadedOpr::CastFromBase(op);
OprBlock* opr_block = OprBlock::New();
opr_block->opr = threaded_opr;
@@ -323,7 +325,6 @@ void ThreadedEngine::PushAsync(AsyncFn fn, Context exec_ctx,
<< device_count_;
}
#endif
- BulkFlush();
ThreadedOpr *opr = NewOperator(std::move(fn), const_vars, mutable_vars,
prop, opr_name, wait);
opr->temporary = true;
const bool profiling = profiler_->IsProfiling(profiler::Profiler::kImperative);
diff --git a/src/executor/attach_op_execs_pass.cc b/src/executor/attach_op_execs_pass.cc
index 697e486..72919d9 100644
--- a/src/executor/attach_op_execs_pass.cc
+++ b/src/executor/attach_op_execs_pass.cc
@@ -134,6 +134,10 @@ class StatefulComputeExecutor : public StorageFallbackOpExecutor {
return state_.get_var();
}
+ OpStatePtr state() const override {
+ return state_;
+ }
+
explicit StatefulComputeExecutor(const OpStatePtr& state,
const FStatefulCompute& fcompute,
ExecType exec_type,
@@ -142,7 +146,6 @@ class StatefulComputeExecutor : public StorageFallbackOpExecutor {
state_(state), fcompute_(fcompute), exec_type_(exec_type) {}
private:
- friend Graph AttachOpExecs(Graph g);
OpStatePtr state_;
FStatefulCompute fcompute_;
ExecType exec_type_;
@@ -170,13 +173,16 @@ class StatefulComputeExExecutor : public OpExecutor {
return state_.get_var();
}
+ OpStatePtr state() const override {
+ return state_;
+ }
+
explicit StatefulComputeExExecutor(const OpStatePtr& state,
const FStatefulComputeEx& fcompute,
ExecType exec_type)
: state_(state), fcompute_(fcompute), exec_type_(exec_type) {}
private:
- friend Graph AttachOpExecs(Graph g);
OpStatePtr state_;
FStatefulComputeEx fcompute_;
ExecType exec_type_;
@@ -241,16 +247,15 @@ class FComputeExExecutor : public OpExecutor {
ExecType exec_type_;
};
-// pass to attach operator executors
-Graph AttachOpExecs(Graph g) {
+void CreateOpExecs(const Graph& g, OpExecVector* p_ret, size_t i) {
using nnvm::DTypeVector;
using nnvm::ShapeVector;
using nnvm::FMutateInputs;
- auto& fcreate_op_state = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState");
- auto& fmutate_inputs = nnvm::Op::GetAttr<FMutateInputs>("FMutateInputs");
- auto& fexec_type = nnvm::Op::GetAttr<FExecType>("FExecType");
- auto& is_layer_backward = nnvm::Op::GetAttr<bool>("TIsLayerOpBackward");
+ static auto& fcreate_op_state = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState");
+ static auto& fmutate_inputs = nnvm::Op::GetAttr<FMutateInputs>("FMutateInputs");
+ static auto& fexec_type = nnvm::Op::GetAttr<FExecType>("FExecType");
+ static auto& is_layer_backward = nnvm::Op::GetAttr<bool>("TIsLayerOpBackward");
const auto& vdtype = g.GetAttr<DTypeVector>("dtype");
const auto& vshape = g.GetAttr<ShapeVector>("shape");
@@ -259,81 +264,87 @@ Graph AttachOpExecs(Graph g) {
// get the graph
const auto& idx = g.indexed_graph();
- std::vector<std::shared_ptr<OpExecutor> > ret(idx.num_nodes());
+ OpExecVector& ret = *p_ret;
// initialize the nodes
- for (size_t i = 0; i < idx.num_nodes(); ++i) {
- const auto& inode = idx[i];
- if (inode.source->is_variable()) continue;
- const nnvm::Op *op = inode.source->op();
- ExecType exec_type = ExecType::kSync;
- std::vector<uint32_t> mutate_index;
- if (fmutate_inputs.count(op)) {
- mutate_index = fmutate_inputs[op](inode.source->attrs);
- }
- if (fexec_type.count(op)) {
- exec_type = fexec_type[op](inode.source->attrs);
+ const auto& inode = idx[i];
+ if (inode.source->is_variable()) return;
+ const nnvm::Op *op = inode.source->op();
+ ExecType exec_type = ExecType::kSync;
+ std::vector<uint32_t> mutate_index;
+ if (fmutate_inputs.count(op)) {
+ mutate_index = fmutate_inputs[op](inode.source->attrs);
+ }
+ if (fexec_type.count(op)) {
+ exec_type = fexec_type[op](inode.source->attrs);
+ }
+ CHECK(dispatch_modes[i] != DispatchMode::kUndefined);
+ if (fcreate_op_state.count(op)) {
+ std::vector<TShape> ishape;
+ std::vector<int> itype;
+ for (const auto& e : inode.inputs) {
+ ishape.emplace_back(vshape[idx.entry_id(e)]);
+ itype.emplace_back(vdtype[idx.entry_id(e)]);
}
- CHECK(dispatch_modes[i] != DispatchMode::kUndefined);
- if (fcreate_op_state.count(op)) {
- std::vector<TShape> ishape;
- std::vector<int> itype;
- for (const auto& e : inode.inputs) {
- ishape.emplace_back(vshape[idx.entry_id(e)]);
- itype.emplace_back(vdtype[idx.entry_id(e)]);
- }
- OpStatePtr state = fcreate_op_state[op](
- inode.source->attrs, vctx[i], ishape, itype);
- FStatefulComputeEx fcompute_ex = common::GetFCompute<FStatefulComputeEx>(
- op, "FStatefulComputeEx", vctx[i]);
- // FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx
- if (fcompute_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) {
- ret[i] = std::make_shared<StatefulComputeExExecutor>(state, fcompute_ex, exec_type);
- } else {
- FStatefulCompute fcompute = common::GetFCompute<FStatefulCompute>(
- op, "FStatefulCompute", vctx[i]);
- CHECK(fcompute != nullptr)
- << "One of FStatefulCompute and FStatefulComputeEx must be registered "
- << "for stateful operator " << op->name;
- ret[i] = std::make_shared<StatefulComputeExecutor>(state, fcompute,
- exec_type, mutate_index);
- }
- } else if (is_layer_backward.get(op, false)) {
- CHECK_GE(inode.control_deps.size(), 1);
- uint32_t fwd_id = inode.control_deps[0];
- CHECK(vctx[fwd_id] == vctx[i]);
- CHECK(ret[fwd_id] != nullptr);
- FStatefulComputeEx fcompute_ex = common::GetFCompute<FStatefulComputeEx>(
- op, "FStatefulComputeEx", vctx[i]);
- // FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx
- if (fcompute_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) {
- ret[i] = std::make_shared<StatefulComputeExExecutor>(
- dynamic_cast<StatefulComputeExExecutor*>(ret[fwd_id].get())->state_,
- fcompute_ex, exec_type);
- } else {
- FStatefulCompute fcompute = common::GetFCompute<FStatefulCompute>(
- op, "FStatefulCompute", vctx[i]);
- CHECK(fcompute != nullptr)
- << "One of FStatefulCompute and FStatefulComputeEx must be registered "
- << "for stateful operator " << op->name;
- ret[i] = std::make_shared<StatefulComputeExecutor>(
- dynamic_cast<StatefulComputeExecutor*>(ret[fwd_id].get())->state_,
- fcompute, exec_type, mutate_index);
- }
+ OpStatePtr state = fcreate_op_state[op](
+ inode.source->attrs, vctx[i], ishape, itype);
+ FStatefulComputeEx fcompute_ex = common::GetFCompute<FStatefulComputeEx>(
+ op, "FStatefulComputeEx", vctx[i]);
+ // FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx
+ if (fcompute_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) {
+ ret[i] = std::make_shared<StatefulComputeExExecutor>(state, fcompute_ex, exec_type);
} else {
- FCompute fcompute = common::GetFCompute<FCompute>(op, "FCompute", vctx[i]);
- FComputeEx fcomp_ex = common::GetFCompute<FComputeEx>(op, "FComputeEx", vctx[i]);
- if (fcomp_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) {
- ret[i] = std::make_shared<FComputeExExecutor>(
- inode.source->attrs, fcomp_ex, exec_type);
- } else if (fcompute != nullptr) {
- ret[i] = std::make_shared<FComputeExecutor>(
- inode.source->attrs, fcompute, exec_type, mutate_index);
- } else {
- LOG(INFO) << "Neither FCompute nor FComputeEx registered " << op->name;
- }
+ FStatefulCompute fcompute = common::GetFCompute<FStatefulCompute>(
+ op, "FStatefulCompute", vctx[i]);
+ CHECK(fcompute != nullptr)
+ << "One of FStatefulCompute and FStatefulComputeEx must be registered "
+ << "for stateful operator " << op->name;
+ ret[i] = std::make_shared<StatefulComputeExecutor>(state, fcompute,
+ exec_type, mutate_index);
+ }
+ } else if (is_layer_backward.get(op, false)) {
+ CHECK_GE(inode.control_deps.size(), 1);
+ uint32_t fwd_id = inode.control_deps[0];
+ CHECK(vctx[fwd_id] == vctx[i]);
+ CHECK(ret[fwd_id] != nullptr);
+ FStatefulComputeEx fcompute_ex = common::GetFCompute<FStatefulComputeEx>(
+ op, "FStatefulComputeEx", vctx[i]);
+ // FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx
+ if (fcompute_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) {
+ ret[i] = std::make_shared<StatefulComputeExExecutor>(
+ ret[fwd_id].get()->state(), fcompute_ex, exec_type);
+ } else {
+ FStatefulCompute fcompute = common::GetFCompute<FStatefulCompute>(
+ op, "FStatefulCompute", vctx[i]);
+ CHECK(fcompute != nullptr)
+ << "One of FStatefulCompute and FStatefulComputeEx must be registered "
+ << "for stateful operator " << op->name;
+ ret[i] = std::make_shared<StatefulComputeExecutor>(
+ ret[fwd_id].get()->state(), fcompute, exec_type, mutate_index);
}
+ } else {
+ FCompute fcompute = common::GetFCompute<FCompute>(op, "FCompute", vctx[i]);
+ FComputeEx fcomp_ex = common::GetFCompute<FComputeEx>(op, "FComputeEx", vctx[i]);
+ if (fcomp_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) {
+ ret[i] = std::make_shared<FComputeExExecutor>(
+ inode.source->attrs, fcomp_ex, exec_type);
+ } else if (fcompute != nullptr) {
+ ret[i] = std::make_shared<FComputeExecutor>(
+ inode.source->attrs, fcompute, exec_type, mutate_index);
+ } else {
+ LOG(INFO) << "Neither FCompute nor FComputeEx registered " << op->name;
+ }
+ }
+}
+
+
+// pass to attach operator executors
+Graph AttachOpExecs(Graph g) {
+ const auto& idx = g.indexed_graph();
+ OpExecVector ret(idx.num_nodes());
+ for (size_t i = 0; i < idx.num_nodes(); ++i) {
+ CreateOpExecs(g, &ret, i);
}
g.attrs["op_execs"] = std::make_shared<nnvm::any>(ret);
return g;
diff --git a/src/executor/attach_op_resource_pass.cc b/src/executor/attach_op_resource_pass.cc
index 6818662..56122cd 100644
--- a/src/executor/attach_op_resource_pass.cc
+++ b/src/executor/attach_op_resource_pass.cc
@@ -30,12 +30,15 @@
namespace mxnet {
namespace exec {
-Graph AttachOpResources(Graph g) {
+void AttachOpResources(
+ const Graph& g,
+ const OpExecVector& op_execs,
+ size_t start_nid,
+ size_t end_nid) {
static auto& fresource = nnvm::Op::GetAttr<FResourceRequest>("FResourceRequest");
static auto& fresource_ex = nnvm::Op::GetAttr<FResourceRequestEx>("FResourceRequestEx");
- auto& op_execs = nnvm::get<OpExecVector>(*g.attrs.at("op_execs"));
const auto& vctx = g.GetAttr<ContextVector>("context");
const auto& vdispatch = g.GetAttr<DispatchModeVector>("dispatch_mode");
const auto& dev_masks = g.GetAttr<DevMaskVector>("dev_mask");
@@ -43,7 +46,7 @@ Graph AttachOpResources(Graph g) {
// Use global resource pool for each executor for now.
std::map<Context, Resource> cached_temp;
// Resource allocation
- for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
+ for (uint32_t nid = start_nid; nid < end_nid; ++nid) {
const auto& inode = idx[nid];
if (inode.source->is_variable()) continue;
const Context &ctx = vctx[nid];
@@ -84,7 +87,12 @@ Graph AttachOpResources(Graph g) {
requested.push_back(ResourceManager::Get()->Request(ctx, ResourceRequest::kTempSpace));
}
}
- return g;
}
+
+void AttachOpResources(const Graph& g) {
+ const auto& op_execs = g.GetAttr<OpExecVector>("op_execs");
+ AttachOpResources(g, op_execs, 0, g.indexed_graph().num_nodes());
+}
+
} // namespace exec
} // namespace mxnet
diff --git a/src/executor/exec_pass.h b/src/executor/exec_pass.h
index 99b1b16..26a2491 100644
--- a/src/executor/exec_pass.h
+++ b/src/executor/exec_pass.h
@@ -82,6 +82,10 @@ class OpExecutor {
virtual engine::VarHandle var() const {
return nullptr;
}
+ /*! \return return operator state */
+ virtual OpStatePtr state() const {
+ return OpStatePtr();
+ }
};
/*!
@@ -103,6 +107,14 @@ using ContextVector = std::vector<Context>;
using DevMaskVector = std::vector<int>;
/*!
+ * \brief create OpExecutor for a node in graph
+ *
+ * \param g input graph
+ * \param p_ret OpExecVector for input and output
+ * \param i the id of the node
+ */
+void CreateOpExecs(const Graph& g, OpExecVector* p_ret, size_t i);
+/*!
* \brief Attach OpExecutor to the graph attributes.
*
* \param g input graph
@@ -115,12 +127,20 @@ Graph AttachOpExecs(Graph g);
* \brief Attach Resource to the OpExecVector of the graph.
*
* \param g input graph need to contain op_exec attribute.
+ */
+void AttachOpResources(const Graph& g);
+/*!
+ * \brief Attach Resource to the OpExecVector
*
- * \return graph with new attribute "op_exec" of type OpExecVector
- * The fields on the OpExecVector are not yet been setup.
+ * \param g input graph
+ * \param op_execs OpExecutor vector
+ * \param start_nid starting node id
+ * \param end_nid end node id
*/
-Graph AttachOpResources(Graph g);
-
+void AttachOpResources(const Graph& g,
+ const OpExecVector& op_execs,
+ size_t start_nid,
+ size_t end_nid);
/*!
* \brief Discover chance of inplace addto operators.
* i.e. z = plus(z, source_op), and encourage it to become z += source_op.
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index e28867d..831b5f9 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -912,7 +912,7 @@ void GraphExecutor::FinishInitGraph(nnvm::Symbol symbol,
}
g = AttachOpExecs(g);
- g = AttachOpResources(g);
+ AttachOpResources(g);
graph_ = std::move(g);
if (shared_exec != nullptr) {
diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index 140b5a5..c0e5e83 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -19,16 +19,78 @@
#include <unordered_set>
#include <iostream>
#include "./imperative_utils.h"
+#include "./cached_op.h"
+#include "../executor/exec_pass.h"
+#include "../profiler/profiler.h"
+
namespace mxnet {
DMLC_REGISTER_PARAMETER(CachedOpConfig);
-Imperative::CachedOp::CachedOp(
+struct CachedOp::GraphInfo {
+ nnvm::Graph fwd_graph;
+ nnvm::Graph full_graph;
+ std::vector<OpReqType> bwd_output_reqs;
+ std::vector<uint32_t> bwd_input_eid;
+};
+
+struct CachedOp::DynamicRuntime {
+ GraphInfo info;
+ std::vector<NDArray> buff;
+ std::vector<OpStatePtr> op_states;
+};
+
+struct CachedOp::CachedOpState {
+ CachedOpState(const Context& context_,
+ const nnvm::Graph& fwd_graph_,
+ const nnvm::Graph& full_graph_) {
+ context = context_;
+ info.fwd_graph = fwd_graph_;
+ info.full_graph = full_graph_;
+
+ size_t max_nodes = info.full_graph.indexed_graph().num_nodes();
+ size_t max_entries = info.full_graph.indexed_graph().num_node_entries();
+ info.fwd_graph.attrs["context"] = std::make_shared<dmlc::any>(
+ std::vector<Context>(info.fwd_graph.indexed_graph().num_nodes(), context));
+ info.full_graph.attrs["context"] = std::make_shared<dmlc::any>(
+ std::vector<Context>(max_nodes, context));
+
+ buff.resize(max_entries);
+ arrays.resize(max_entries);
+ array_reqs.resize(max_entries);
+ dynamic_entries.resize(max_entries, false);
+ op_states.resize(max_nodes);
+ execs.resize(max_nodes);
+ opr_segs.resize(max_nodes);
+ }
+
+ std::mutex mutex;
+ Context context;
+ GraphInfo info;
+
+ bool recording = false;
+ bool fwd_alloc = false;
+ bool bwd_alloc = false;
+ bool fwd_exec_init = false;
+ bool bwd_exec_init = false;
+
+ std::vector<NDArray> buff;
+ std::vector<NDArray*> arrays;
+ std::vector<OpReqType> array_reqs;
+
+ std::vector<OpStatePtr> op_states;
+ std::vector<std::shared_ptr<exec::OpExecutor> > execs;
+ std::vector<imperative::EngineOprSeg> opr_segs;
+
+ std::vector<bool> dynamic_entries;
+ std::multimap<size_t, NDArray> fwd_reuse_pool;
+ std::multimap<size_t, NDArray> bwd_reuse_pool;
+};
+
+CachedOp::CachedOp(
const nnvm::Symbol& sym,
- const std::vector<std::pair<std::string, std::string> >& flags,
- const std::vector<std::string> arg_names,
- const std::unordered_map<std::string, std::vector<NDArray> >& params) {
+ const std::vector<std::pair<std::string, std::string> >& flags) {
using namespace nnvm;
using namespace imperative;
static const std::vector<const Op*> zero_ops{Op::Get("zeros_like"), Op::Get("_zeros")};
@@ -36,6 +98,10 @@ Imperative::CachedOp::CachedOp(
config_.Init(flags);
+ if (config_.static_shape) {
+ CHECK(config_.static_alloc) << "static_alloc must be True when static_shape is True";
+ }
+
// construct forward graph
{
NodeEntryMap<int> dedup_out;
@@ -68,34 +134,22 @@ Imperative::CachedOp::CachedOp(
fwd_graph_.attrs["forward_ref_count"] =
std::make_shared<dmlc::any>(std::move(ref_count));
- inlining_ = (idx.num_nodes() - idx.input_nodes().size()) <=
config_.inline_limit;
+ inlining_ = !config_.static_alloc &&
+ (idx.num_nodes() - idx.input_nodes().size()) <= config_.inline_limit;
}
// Set params
{
const auto& idx = fwd_graph_.indexed_graph();
- std::unordered_map<std::string, size_t> arg_name_to_id;
- for (size_t i = 0; i < idx.input_nodes().size(); ++i) {
- const auto& name = idx[idx.input_nodes()[i]].source->attrs.name;
- auto iter = params.find(name);
- if (iter == params.end()) {
- arg_name_to_id[name] = i;
- continue;
- }
- fwd_params_idx_.push_back(i);
- for (const auto& param : iter->second) {
- params_[param.ctx()].emplace_back(param);
+ if (config_.data_indices.ndim() || config_.param_indices.ndim()) {
+ CHECK_EQ(config_.data_indices.ndim() + config_.param_indices.ndim(),
+ idx.input_nodes().size());
+ } else {
+ std::vector<uint32_t> tmp;
+ for (size_t i = 0; i < idx.input_nodes().size(); ++i) {
+ tmp.push_back(i);
}
- }
-
- CHECK_EQ(arg_name_to_id.size(), arg_names.size())
- << "CachedOp expects " << arg_name_to_id.size()
- << " inputs, given " << arg_names.size();
-
- for (const auto& name : arg_names) {
- auto iter = arg_name_to_id.find(name);
- CHECK(iter != arg_name_to_id.end()) << "Unexpected input name " << name;
- fwd_args_idx_.push_back(iter->second);
+ config_.data_indices.assign(tmp.begin(), tmp.end());
}
}
@@ -107,9 +161,14 @@ Imperative::CachedOp::CachedOp(
}
std::vector<NodeEntry> xs;
- std::vector<NodePtr> args = sym.ListInputs(Symbol::kReadOnlyArgs);
- xs.reserve(args.size());
- for (const auto& i : args) xs.emplace_back(NodeEntry{i, 0, 0});
+ const auto& idx = fwd_graph_.indexed_graph();
+ for (size_t i = 0; i < idx.input_nodes().size(); ++i) {
+ auto nid = idx.input_nodes()[i];
+ if (idx.mutable_input_nodes().count(nid)) continue;
+ fwd_input_to_grad_output_[i] = xs.size();
+ xs.emplace_back(NodeEntry{idx[nid].weak_ref.lock(), 0, 0});
+ }
+
CHECK_GT(xs.size(), 0)
<< "There are no inputs in computation graph that require gradients.";
@@ -125,7 +184,7 @@ Imperative::CachedOp::CachedOp(
size_t num_forward_entries = fwd_graph_.indexed_graph().num_node_entries();
full_graph_.outputs = fwd_graph_.outputs;
- curr_grad_req_ = std::vector<bool>(grad_graph_.outputs.size(), true);
+ bwd_output_reqs_ = std::vector<OpReqType>(grad_graph_.outputs.size(),
kWriteTo);
for (const auto& i : grad_graph_.outputs)
full_graph_.outputs.emplace_back(i);
const auto& idx = full_graph_.indexed_graph();
@@ -169,7 +228,10 @@ Imperative::CachedOp::CachedOp(
}
}
-std::vector<nnvm::NodeEntry> Imperative::CachedOp::Gradient(
+CachedOp::~CachedOp() {
+}
+
+std::vector<nnvm::NodeEntry> CachedOp::Gradient(
const nnvm::NodePtr& node,
const std::vector<nnvm::NodeEntry>& ograds) {
using namespace nnvm;
@@ -206,13 +268,15 @@ std::vector<nnvm::NodeEntry>
Imperative::CachedOp::Gradient(
return ret;
}
-nnvm::Graph Imperative::CachedOp::GetForwardGraph(
- const bool recording, const std::vector<NDArray*>& inputs) {
+
+bool CachedOp::SetForwardGraph(
+ GraphInfo* info,
+ const bool recording,
+ const std::vector<NDArray*>& inputs) {
using namespace nnvm;
using namespace imperative;
- std::lock_guard<std::mutex> lock(mutex_);
CHECK_EQ(inputs.size(), num_inputs());
- nnvm::Graph& g = fwd_graph_;
+ nnvm::Graph& g = info->fwd_graph;
ShapeVector shape_inputs;
DTypeVector dtype_inputs;
@@ -237,18 +301,22 @@ nnvm::Graph Imperative::CachedOp::GetForwardGraph(
g.attrs.erase("forward_mem_plan");
g.attrs.erase("full_mem_plan");
} else if (g.attrs.count(recording ? "full_mem_plan" : "forward_mem_plan")) {
- return g;
+ return true;
}
const auto& idx = g.indexed_graph();
StorageVector storage(idx.num_node_entries(), exec::kBadStorageID);
- for (const auto i : idx.input_nodes()) storage[idx.entry_id(i, 0)] =
exec::kExternalStorageID;
const auto& stypes = g.GetAttr<StorageTypeVector>("storage_type");
CHECK_EQ(stypes.size(), storage.size());
for (size_t i = 0; i < stypes.size(); i++) {
- if (stypes[i] != kDefaultStorage)
- storage[i] = exec::kDynamicStorageID;
+ if (stypes[i] != kDefaultStorage) storage[i] = exec::kDynamicStorageID;
+ }
+ for (const auto i : idx.input_nodes()) {
+ storage[idx.entry_id(i, 0)] = exec::kExternalStorageID;
+ }
+ for (size_t i = 0; i < idx.outputs().size(); ++i) {
+ storage[idx.entry_id(idx.outputs()[i])] = exec::kExternalStorageID;
}
auto mem_plan = PlanMemory(
@@ -257,51 +325,50 @@ nnvm::Graph Imperative::CachedOp::GetForwardGraph(
g.attrs[recording ? "full_mem_plan" : "forward_mem_plan"] =
std::make_shared<dmlc::any>(std::move(mem_plan));
- return g;
+ return false;
}
-nnvm::Graph Imperative::CachedOp::GetBackwardGraph(
- const OpStatePtr& op_state,
+bool CachedOp::SetBackwardGraph(
+ GraphInfo* info,
const std::vector<OpReqType>& reqs,
- const std::vector<NDArray*>& inputs) {
+ const std::vector<NDArray*>& inputs,
+ bool detect_inplace_addto) {
using namespace nnvm;
using namespace imperative;
std::lock_guard<std::mutex> lock(mutex_);
- nnvm::Graph& g = full_graph_;
- auto& state = op_state.get_state<CachedOpState>();
- bool req_match = true;
- for (size_t i = 0; i < reqs.size(); ++i) {
- if (curr_grad_req_[i] != (reqs[i] != kNullOp)) {
- curr_grad_req_[i] = reqs[i] != kNullOp;
- req_match = false;
- }
- }
- if (!req_match) {
+ Context default_ctx = inputs[0]->ctx();
+ nnvm::Graph& g = info->full_graph;
+
+ if (info->bwd_output_reqs != reqs) {
+ info->bwd_output_reqs = reqs;
+ info->bwd_input_eid.clear();
g = nnvm::Graph();
g.outputs = fwd_graph_.outputs;
for (size_t i = 0; i < grad_graph_.outputs.size(); ++i) {
- if (curr_grad_req_[i]) g.outputs.emplace_back(grad_graph_.outputs[i]);
+ if (info->bwd_output_reqs[i] == kNullOp) continue;
+ g.outputs.emplace_back(grad_graph_.outputs[i]);
}
- bwd_input_eid_.clear();
+ g.attrs["context"] = std::make_shared<dmlc::any>(
+ std::vector<Context>(g.indexed_graph().num_nodes(), default_ctx));
}
const auto& idx = g.indexed_graph();
- if (bwd_input_eid_.size() != inputs.size()) {
- bwd_input_eid_.clear();
+ if (info->bwd_input_eid.size() != inputs.size()) {
+ info->bwd_input_eid.clear();
for (const auto& i : bwd_ograd_dep_) {
auto eid = idx.entry_id(ograd_entries_[i]);
- bwd_input_eid_.push_back(eid);
+ info->bwd_input_eid.push_back(eid);
}
for (const auto& i : bwd_in_dep_) {
auto eid = idx.entry_id(idx.input_nodes()[i], 0);
- bwd_input_eid_.push_back(eid);
+ info->bwd_input_eid.push_back(eid);
}
for (const auto& i : bwd_out_dep_) {
auto eid = idx.entry_id(idx.outputs()[i]);
- bwd_input_eid_.push_back(eid);
+ info->bwd_input_eid.push_back(eid);
}
- CHECK_EQ(inputs.size(), bwd_input_eid_.size());
+ CHECK_EQ(inputs.size(), info->bwd_input_eid.size());
}
size_t num_forward_nodes = fwd_graph_.indexed_graph().num_nodes();
@@ -312,25 +379,22 @@ nnvm::Graph Imperative::CachedOp::GetBackwardGraph(
for (size_t i = num_forward_nodes; i < idx.num_nodes(); ++i) {
for (const auto& j : idx[i].inputs) ++ref_count[idx.entry_id(j)];
}
- for (size_t i = 0; i < inputs.size(); ++i) ++ref_count[bwd_input_eid_[i]];
+ for (size_t i = 0; i < inputs.size(); ++i)
++ref_count[info->bwd_input_eid[i]];
for (const auto& i : idx.outputs()) ++ref_count[idx.entry_id(i)];
g.attrs["backward_ref_count"] =
std::make_shared<dmlc::any>(std::move(ref_count));
}
- ShapeVector shapes(idx.num_node_entries(), TShape());
- DTypeVector dtypes(idx.num_node_entries(), -1);
- StorageTypeVector stypes(idx.num_node_entries(), -1);
-
- for (size_t i = 0; i < num_forward_entries; ++i) {
- shapes[i] = state.buff[i].shape();
- dtypes[i] = state.buff[i].dtype();
- stypes[i] = state.buff[i].storage_type();
- }
+ auto shapes = info->fwd_graph.GetAttr<ShapeVector>("shape");
+ shapes.resize(idx.num_node_entries(), TShape());
+ auto dtypes = info->fwd_graph.GetAttr<DTypeVector>("dtype");
+ dtypes.resize(idx.num_node_entries(), -1);
+ auto stypes = info->fwd_graph.GetAttr<StorageTypeVector>("storage_type");
+ stypes.resize(idx.num_node_entries(), -1);
for (size_t i = 0; i < inputs.size(); ++i) {
- shapes[bwd_input_eid_[i]] = inputs[i]->shape();
- dtypes[bwd_input_eid_[i]] = inputs[i]->dtype();
- stypes[bwd_input_eid_[i]] = inputs[i]->storage_type();
+ shapes[info->bwd_input_eid[i]] = inputs[i]->shape();
+ dtypes[info->bwd_input_eid[i]] = inputs[i]->dtype();
+ stypes[info->bwd_input_eid[i]] = inputs[i]->storage_type();
}
std::pair<uint32_t, uint32_t> node_range, entry_range;
@@ -342,79 +406,353 @@ nnvm::Graph Imperative::CachedOp::GetBackwardGraph(
node_range, entry_range);
match &= CheckAndInferType(&g, std::move(dtypes), false,
node_range, entry_range);
- exec::DevMaskVector dev_mask(idx.num_nodes(), inputs[0]->ctx().dev_mask());
+ exec::DevMaskVector dev_mask(idx.num_nodes(), default_ctx.dev_mask());
match &= CheckAndInferStorageType(&g, std::move(dev_mask), std::move(stypes),
false, node_range, entry_range);
if (!match) {
g.attrs.erase("backward_mem_plan");
} else if (g.attrs.count("backward_mem_plan")) {
- return g;
+ return true;
}
StorageVector storage(idx.num_node_entries(), exec::kBadStorageID);
+ const auto& bwd_stypes = g.GetAttr<StorageTypeVector>("storage_type");
+ for (size_t i = 0; i < bwd_stypes.size(); i++) {
+ if (bwd_stypes[i] != kDefaultStorage) storage[i] = exec::kDynamicStorageID;
+ }
for (size_t i = 0; i < num_forward_entries; ++i) storage[i] =
exec::kExternalStorageID;
for (const auto i : idx.input_nodes()) storage[idx.entry_id(i, 0)] =
exec::kExternalStorageID;
for (const auto i : idx.outputs()) storage[idx.entry_id(i)] =
exec::kExternalStorageID;
- for (size_t i = 0; i < stypes.size(); i++) {
- if (stypes[i] != kDefaultStorage)
- storage[i] = exec::kDynamicStorageID;
- }
auto mem_plan = PlanMemory(
&g, std::move(storage), g.GetAttr<std::vector<uint32_t>
>("backward_ref_count"),
- {num_forward_nodes, idx.num_nodes()}, {num_forward_entries,
idx.num_node_entries()});
+ {num_forward_nodes, idx.num_nodes()},
+ {num_forward_entries, idx.num_node_entries()},
+ detect_inplace_addto);
g.attrs["backward_mem_plan"] =
std::make_shared<dmlc::any>(std::move(mem_plan));
- return g;
+ return false;
}
-void Imperative::CachedOp::Forward(
- const std::shared_ptr<CachedOp>& op_ptr,
- const std::vector<NDArray*>& args,
- const std::vector<NDArray*>& outputs) {
+OpStatePtr CachedOp::GetCachedOpState(
+ const Context& ctx) {
+ std::lock_guard<std::mutex> lock(mutex_);
+ for (const auto& i : cached_op_states_[ctx]) {
+ // only create one state per device when not using static memory
+ if (!config_.static_alloc || i.unique()) {
+ return i;
+ }
+ }
+ auto state_ptr = OpStatePtr::Create<CachedOpState>(ctx, fwd_graph_,
full_graph_);
+
+ cached_op_states_[ctx].push_back(state_ptr);
+ return state_ptr;
+}
+
+void CachedOp::StaticAllocMemory(
+ const OpStatePtr& state_ptr,
+ bool recording,
+ bool keep_fwd) {
using namespace nnvm;
using namespace imperative;
- static const auto cached_op = nnvm::Op::Get("_CachedOp");
- CHECK_EQ(args.size(), fwd_args_idx_.size())
- << "CachedOp requires " << fwd_args_idx_.size()
- << " inputs but got " << args.size();
+ auto& state = state_ptr.get_state<CachedOpState>();
+ const auto& default_ctx = state.context;
+ nnvm::Graph& g = keep_fwd ? state.info.full_graph : state.info.fwd_graph;
+ const auto& idx = g.indexed_graph();
+ const auto& vstorage_inplace = g.GetAttr<std::vector<int>
>("storage_inplace_index");
+ const auto& mem_plan = g.GetAttr<MemoryPlanVector>(
+ keep_fwd ? "backward_mem_plan" : (recording ? "full_mem_plan" :
"forward_mem_plan"));
+ std::vector<int> addto_entry;
+ if (g.attrs.count("addto_entry")) {
+ addto_entry = g.GetAttr<std::vector<int> >("addto_entry");
+ }
+ size_t start_eid =
+ keep_fwd ? state.info.fwd_graph.indexed_graph().num_node_entries() : 0;
+ size_t end_eid = idx.num_node_entries();
+
+ if (!keep_fwd) state.fwd_alloc = false;
+ state.bwd_alloc = false;
+ for (size_t i = start_eid; i < state.buff.size(); ++i) {
+ state.buff[i] = NDArray();
+ state.arrays[i] = &state.buff[i];
+ state.array_reqs[i] = kNullOp;
+ state.dynamic_entries[i] = false;
+ }
+
+ for (auto i : idx.input_nodes()) {
+ auto eid = idx.entry_id(i, 0);
+ if (eid >= start_eid) state.dynamic_entries[eid] = true;
+ }
+ for (auto i : idx.outputs()) {
+ auto eid = idx.entry_id(i);
+ if (eid >= start_eid) state.dynamic_entries[eid] = true;
+ }
- Context default_ctx = args[0]->ctx();
+ for (size_t i = start_eid; i < end_eid; ++i) {
+ if (addto_entry.size() && addto_entry[i]) {
+ state.array_reqs[i] = kAddTo;
+ } else if (vstorage_inplace[i] >= 0) {
+ state.array_reqs[i] = kWriteInplace;
+ } else if (vstorage_inplace[i] == -2) {
+ // -2 indicate that the entry is never referenced.
+ state.array_reqs[i] = kNullOp;
+ } else {
+ state.array_reqs[i] = kWriteTo;
+ }
+ }
+ auto& reuse_pool = keep_fwd ? state.bwd_reuse_pool : state.fwd_reuse_pool;
+ reuse_pool = imperative::AllocateMemory(
+ g, idx, default_ctx, start_eid, end_eid, mem_plan,
+ state.arrays, &state.array_reqs, std::move(reuse_pool));
- std::vector<NDArray*> inputs(num_inputs());
- for (index_t i = 0; i < fwd_args_idx_.size(); ++i) {
- inputs[fwd_args_idx_[i]] = args[i];
+ state.recording = recording;
+ if (keep_fwd) {
+ state.bwd_alloc = true;
+ } else {
+ state.fwd_alloc = true;
}
- if (fwd_params_idx_.size()) {
- CHECK(params_.find(default_ctx) != params_.end())
- << "CachedOp is not initialized on context " << default_ctx;
+}
+
+void CachedOp::StaticInitExec(
+ const OpStatePtr& state_ptr,
+ bool recording,
+ bool keep_fwd) {
+ using namespace nnvm;
+ using namespace imperative;
+
+ auto& state = state_ptr.get_state<CachedOpState>();
+ const auto& default_ctx = state.context;
+ nnvm::Graph& g = keep_fwd ? state.info.full_graph : state.info.fwd_graph;
+ const auto& idx = g.indexed_graph();
+ std::vector<int> skip_plus_node;
+ if (g.attrs.count("skip_plus_node")) {
+ skip_plus_node = g.GetAttr<std::vector<int> >("skip_plus_node");
+ }
+ size_t start_nid =
+ keep_fwd ? state.info.fwd_graph.indexed_graph().num_nodes() : 0;
+ size_t end_nid = idx.num_nodes();
+
+ if (!keep_fwd) state.fwd_exec_init = false;
+ state.bwd_exec_init = false;
- for (size_t i = 0; i < fwd_params_idx_.size(); ++i) {
- inputs[fwd_params_idx_[i]] = ¶ms_[default_ctx][i];
+ for (size_t i = start_nid; i < state.execs.size(); ++i) {
+ state.execs[i].reset();
+ state.opr_segs[i] = EngineOprSeg();
+ }
+
+ if (!config_.static_shape) {
+ for (size_t i = start_nid; i < end_nid; ++i) {
+ state.opr_segs[i].next_nid = i + 1;
+ state.opr_segs[i].skip = skip_plus_node.size() && skip_plus_node[i];
+ }
+ } else {
+ for (size_t i = start_nid; i < end_nid; ++i) {
+ exec::CreateOpExecs(g, &state.execs, i);
}
+ exec::AttachOpResources(g, state.execs, start_nid, end_nid);
+
+ for (size_t i = start_nid; i < end_nid; ++i) {
+ bool skip = idx[i].source->is_variable();
+ for (size_t j = 0; !skip && j < idx[i].inputs.size(); ++j) {
+ skip = state.dynamic_entries[idx.entry_id(idx[i].inputs[j])];
+ }
+ for (size_t j = 0; !skip && j < idx[i].source->num_outputs(); ++j) {
+ skip = state.dynamic_entries[idx.entry_id(i, j)];
+ }
+ if (skip) continue;
+ SetupOpExec(g, i, state.execs[i], state.arrays, state.array_reqs);
+ }
+
+ size_t bulk_size = idx.num_nodes();
+ std::unordered_set<uint32_t> excludes;
+ if (recording || keep_fwd) {
+ bulk_size = keep_fwd ? config_.backward_bulk_size :
config_.forward_bulk_size;
+ for (const auto& i : idx.outputs()) excludes.insert(idx.entry_id(i));
+ for (const auto& i : idx.input_nodes()) excludes.insert(idx.entry_id(i,
0));
+ }
+
+ CreateEngineOpSeg(idx, default_ctx, start_nid, end_nid, bulk_size,
excludes,
+ state.execs, skip_plus_node, &state.opr_segs);
}
- // Initialize
+ if (keep_fwd) {
+ state.bwd_exec_init = true;
+ } else {
+ state.fwd_exec_init = true;
+ }
+}
+
+void CachedOp::StaticRunOps(
+ const Context& default_ctx,
+ const nnvm::Graph& g,
+ const OpStatePtr& state_ptr,
+ size_t start_nid,
+ size_t end_nid) {
+ static auto& createop = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState");
+ static auto& is_layer_backward = Op::GetAttr<bool>("TIsLayerOpBackward");
+
+ bool profiling = profiler::Profiler::Get()->GetState() ==
profiler::Profiler::kRunning;
+ bool is_training = Imperative::Get()->is_training();
+ auto& state = state_ptr.get_state<CachedOpState>();
+ const auto& idx = g.indexed_graph();
+ const auto& dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");
+ const auto& op_execs = state.execs;
+
+ std::vector<NDArray*> ndinputs, ndoutputs;
+ nnvm::ShapeVector arg_shapes;
+ nnvm::DTypeVector arg_dtypes;
+ std::vector<OpReqType> req;
+
+ for (size_t i = start_nid; config_.static_shape && i < end_nid; ++i) {
+ if (op_execs[i]) op_execs[i]->op_ctx.is_train = is_training;
+ }
+
+ for (size_t i = start_nid; i < end_nid; i = state.opr_segs[i].next_nid) {
+ const auto& opr_seg = state.opr_segs[i];
+ if (opr_seg.skip) continue;
+ if (opr_seg.opr != nullptr) {
+ Engine::Get()->Push(opr_seg.opr.get(), default_ctx, 0, profiling);
+ } else {
+ const nnvm::IndexedGraph::Node& node = idx[i];
+ if (node.source->is_variable()) continue;
+ auto num_outputs = node.source->num_outputs();
+ ndinputs.clear();
+ ndinputs.reserve(node.inputs.size());
+ for (const auto& j : node.inputs) {
+ ndinputs.emplace_back(state.arrays[idx.entry_id(j)]);
+ CHECK(!ndinputs.back()->is_none());
+ }
+ ndoutputs.clear();
+ ndoutputs.reserve(num_outputs);
+ req.clear();
+ req.reserve(num_outputs);
+ for (size_t j = 0; j < num_outputs; ++j) {
+ size_t eid = idx.entry_id(i, j);
+ ndoutputs.emplace_back(state.arrays[eid]);
+ req.push_back(state.array_reqs[eid]);
+ CHECK(req.back() == kNullOp || !ndoutputs.back()->is_none());
+ }
+ const DispatchMode dispatch_mode = dispatch_modes[i];
+ if (createop.count(node.source->op())) {
+ arg_shapes.clear();
+ arg_dtypes.clear();
+ arg_shapes.reserve(ndinputs.size());
+ arg_dtypes.reserve(ndinputs.size());
+ for (size_t i = 0; i < ndinputs.size(); ++i) {
+ arg_shapes.emplace_back(ndinputs[i]->shape());
+ arg_dtypes.emplace_back(ndinputs[i]->dtype());
+ }
+ state.op_states[i] = createop[node.source->op()](
+ node.source->attrs, default_ctx, arg_shapes, arg_dtypes);
+ Imperative::Get()->InvokeOp(
+ default_ctx, node.source->attrs, ndinputs, ndoutputs, req,
+ dispatch_mode, state.op_states[i]);
+ } else if (is_layer_backward.get(node.source->op(), false)) {
+ nnvm::Node* fwd_node = node.source->control_deps[0].get();
+ auto fwd_node_id = idx.node_id(fwd_node);
+ Imperative::Get()->InvokeOp(
+ default_ctx, node.source->attrs, ndinputs, ndoutputs,
+ req, dispatch_mode, state.op_states[fwd_node_id]);
+ } else {
+ Imperative::Get()->InvokeOp(
+ default_ctx, node.source->attrs, ndinputs, ndoutputs, req,
+ dispatch_mode);
+ }
+ }
+ }
+}
+
+OpStatePtr CachedOp::StaticForward(
+ const Context& default_ctx,
+ const std::vector<NDArray*>& inputs,
+ const std::vector<NDArray*>& outputs) {
+ using namespace nnvm;
+ using namespace imperative;
+
bool recording = Imperative::Get()->is_recording();
- nnvm::Graph g = GetForwardGraph(recording, inputs);
+ auto state_ptr = GetCachedOpState(default_ctx);
+ auto& state = state_ptr.get_state<CachedOpState>();
+ std::lock_guard<std::mutex> lock(state.mutex);
+
+ bool match = SetForwardGraph(&state.info, recording, inputs);
+ match = match && state.recording == recording;
+
+ nnvm::Graph& g = state.info.fwd_graph;
const auto& idx = g.indexed_graph();
- size_t num_inputs = idx.input_nodes().size();
+ if (!state.fwd_alloc || !match) {
+ StaticAllocMemory(state_ptr, recording, false);
+ }
- for (size_t i = 0; i < inputs.size(); ++i) {
- CHECK_EQ(inputs[i]->ctx(), default_ctx)
- << "CachedOp requires all inputs to live on the same context. But "
- << idx[idx.input_nodes()[0]].source->attrs.name << " is on " <<
default_ctx
- << " while " << idx[idx.input_nodes()[i]].source->attrs.name << " is
on "
- << inputs[i]->ctx();
+ if (config_.static_shape) {
+ for (auto i : config_.param_indices) {
+ auto nid = idx.input_nodes()[i];
+ if (!state.arrays[idx.entry_id(nid, 0)]->IsSame(*inputs[i])) {
+ match = false;
+ auto ptr = &state.buff[idx.entry_id(nid, 0)];
+ CHECK_EQ(state.arrays[idx.entry_id(nid, 0)], ptr);
+ *state.arrays[idx.entry_id(nid, 0)] = *inputs[i];
+ state.dynamic_entries[idx.entry_id(nid, 0)] = false;
+ }
+ }
+ for (auto i : config_.data_indices) {
+ auto eid = idx.entry_id(idx.input_nodes()[i], 0);
+ state.arrays[eid] = inputs[i];
+ }
+ } else {
+ for (size_t i = 0; i < num_inputs(); ++i) {
+ auto nid = idx.input_nodes()[i];
+ state.arrays[idx.entry_id(nid, 0)] = inputs[i];
+ }
}
- auto op_state_ptr = OpStatePtr::Create<CachedOpState>();
- auto& cached_op_state = op_state_ptr.get_state<CachedOpState>();
- auto& buff = cached_op_state.buff;
- auto& states = cached_op_state.states;
+ if (!state.fwd_exec_init || !match) {
+ StaticInitExec(state_ptr, recording, false);
+ }
+
+ const auto& dtypes = g.GetAttr<DTypeVector>("dtype");
+ const auto& shapes = g.GetAttr<ShapeVector>("shape");
+ const auto& stypes = g.GetAttr<StorageTypeVector>("storage_type");
+
+ for (size_t i = 0; i < outputs.size(); ++i) {
+ auto eid = idx.entry_id(idx.outputs()[i]);
+ state.arrays[eid] = outputs[i];
+ if (!outputs[i]->is_none()) continue;
+ *outputs[i] = NDArray(static_cast<NDArrayStorageType>(stypes[eid]),
+ shapes[eid], default_ctx, true, dtypes[eid]);
+ }
+
+ StaticRunOps(default_ctx, g, state_ptr, 0, idx.num_nodes());
+
+ return recording ? state_ptr : OpStatePtr();
+}
+
+
+OpStatePtr CachedOp::DynamicForward(
+ const Context& default_ctx,
+ const std::vector<NDArray*>& inputs,
+ const std::vector<NDArray*>& outputs) {
+ using namespace nnvm;
+ using namespace imperative;
+
+ // Initialize
+ bool recording = Imperative::Get()->is_recording();
+ auto op_state = OpStatePtr::Create<DynamicRuntime>();
+ auto& runtime = op_state.get_state<DynamicRuntime>();
+ {
+ auto state_ptr = GetCachedOpState(default_ctx);
+ auto& state = state_ptr.get_state<CachedOpState>();
+ std::lock_guard<std::mutex> lock(state.mutex);
+ SetForwardGraph(&state.info, recording, inputs);
+ runtime.info.fwd_graph = state.info.fwd_graph;
+ }
+ nnvm::Graph& g = runtime.info.fwd_graph;
+ const auto& idx = g.indexed_graph();
+ size_t num_inputs = idx.input_nodes().size();
+ auto& buff = runtime.buff;
+ auto& states = runtime.op_states;
// Allocate entries
states.resize(idx.num_nodes());
@@ -446,57 +784,98 @@ void Imperative::CachedOp::Forward(
AllocateMemory(g, idx, default_ctx, 0, idx.num_node_entries(),
mem_plan, arrays, &array_reqs);
+ const auto& dtypes = g.GetAttr<DTypeVector>("dtype");
+ const auto& shapes = g.GetAttr<ShapeVector>("shape");
+ const auto& stypes = g.GetAttr<StorageTypeVector>("storage_type");
+
+ for (size_t i = 0; i < outputs.size(); ++i) {
+ auto eid = idx.entry_id(idx.outputs()[i]);
+ arrays[eid] = outputs[i];
+ if (!outputs[i]->is_none()) continue;
+ *outputs[i] = NDArray(static_cast<NDArrayStorageType>(stypes[eid]),
+ shapes[eid], default_ctx, true, dtypes[eid]);
+ }
+
const auto& dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");
if (recording && !inlining_) Imperative::Get()->set_is_recording(false);
- int prev_bulk_size = Engine::Get()->set_bulk_size(config_.forward_bulk_size);
- Imperative::Get()->RunGraph(
- false, idx, arrays, 0, idx.num_nodes(), std::move(array_reqs),
- std::move(ref_count), &states, dispatch_modes);
+ RunGraph(false, idx, arrays, 0, idx.num_nodes(), std::move(array_reqs),
+ std::move(ref_count), &states, dispatch_modes);
- Engine::Get()->set_bulk_size(prev_bulk_size);
Imperative::Get()->set_is_recording(recording);
- for (size_t i = 0; i < idx.num_node_entries(); ++i) {
- if (arrays[i] == &buff[i]) continue;
- buff[i].shape_ = arrays[i]->shape_;
- buff[i].dtype_ = arrays[i]->dtype_;
- buff[i].storage_type_ = arrays[i]->storage_type_;
+ return op_state;
+}
+
+void CachedOp::Forward(
+ const std::shared_ptr<CachedOp>& op_ptr,
+ const std::vector<NDArray*>& inputs,
+ const std::vector<NDArray*>& outputs) {
+ static const auto cached_op = nnvm::Op::Get("_CachedOp");
+
+ CHECK_EQ(inputs.size(), num_inputs());
+
+ Context default_ctx = inputs[0]->ctx();
+
+ const auto& idx = fwd_graph_.indexed_graph();
+ for (size_t i = 0; i < inputs.size(); ++i) {
+ CHECK_EQ(inputs[i]->ctx(), default_ctx)
+ << "CachedOp requires all inputs to live on the same context. But "
+ << idx[idx.input_nodes()[0]].source->attrs.name
+ << " is on " << default_ctx << " while "
+ << idx[idx.input_nodes()[i]].source->attrs.name
+ << " is on " << inputs[i]->ctx();
}
- if (recording && !inlining_) {
+ int prev_bulk_size = Engine::Get()->set_bulk_size(config_.forward_bulk_size);
+
+ OpStatePtr op_state;
+ if (config_.static_alloc) {
+ op_state = StaticForward(default_ctx, inputs, outputs);
+ } else {
+ op_state = DynamicForward(default_ctx, inputs, outputs);
+ }
+
+ Engine::Get()->set_bulk_size(prev_bulk_size);
+
+ if (Imperative::Get()->is_recording() && !inlining_) {
nnvm::NodeAttrs attrs;
attrs.op = cached_op;
attrs.name = "_cachedop";
attrs.parsed = op_ptr;
Imperative::Get()->RecordOp(
- std::move(attrs), inputs, outputs, op_state_ptr,
+ std::move(attrs), inputs, outputs, op_state,
&save_inputs(), &save_outputs());
}
}
-void Imperative::CachedOp::Backward(
+void CachedOp::DynamicBackward(
const bool retain_graph,
- const OpStatePtr& state,
+ const OpStatePtr& op_state,
const std::vector<NDArray*>& inputs,
const std::vector<OpReqType>& reqs,
const std::vector<NDArray*>& outputs) {
using namespace nnvm;
using namespace imperative;
- CHECK(!Imperative::Get()->is_recording())
- << "CachedOp does not support higher order gradients. "
- << "If you want to do backward with create_graph=True please "
- << "do not use hybridize.";
// Initialize
- nnvm::Graph g = GetBackwardGraph(state, reqs, inputs);
+ Context default_ctx = outputs[0]->ctx();
+ auto& runtime = op_state.get_state<DynamicRuntime>();
+ {
+ auto state_ptr = GetCachedOpState(default_ctx);
+ auto& state = state_ptr.get_state<CachedOpState>();
+ std::lock_guard<std::mutex> lock(state.mutex);
+ state.info.fwd_graph = runtime.info.fwd_graph;
+ SetBackwardGraph(&state.info, reqs, inputs);
+ runtime.info.full_graph = state.info.full_graph;
+ runtime.info.bwd_input_eid = state.info.bwd_input_eid;
+ }
+ nnvm::Graph& g = runtime.info.full_graph;
const auto& idx = g.indexed_graph();
-
- auto& cached_op_state = state.get_state<CachedOpState>();
- auto& buff = cached_op_state.buff;
- auto& states = cached_op_state.states;
+ auto& buff = runtime.buff;
+ auto& states = runtime.op_states;
size_t num_forward_outputs = fwd_graph_.outputs.size();
size_t num_forward_nodes = fwd_graph_.indexed_graph().num_nodes();
@@ -506,7 +885,7 @@ void Imperative::CachedOp::Backward(
arrays.reserve(buff.size());
for (size_t i = 0; i < buff.size(); ++i) arrays.push_back(&buff[i]);
for (size_t i = 0; i < inputs.size(); ++i) {
- arrays[bwd_input_eid_[i]] = inputs[i];
+ arrays[runtime.info.bwd_input_eid[i]] = inputs[i];
}
for (size_t i = 0, j = num_forward_outputs; i < reqs.size(); ++i) {
if (reqs[i] == kNullOp) continue;
@@ -530,20 +909,14 @@ void Imperative::CachedOp::Backward(
if (ref_count[i] == 0) array_reqs[i] = kNullOp;
}
- Context default_ctx = outputs[0]->ctx();
const auto& mem_plan = g.GetAttr<MemoryPlanVector >("backward_mem_plan");
AllocateMemory(g, idx, default_ctx, num_forward_entries,
idx.num_node_entries(),
mem_plan, arrays, &array_reqs);
const auto& dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");
- int prev_bulk_size =
Engine::Get()->set_bulk_size(config_.backward_bulk_size);
-
- Imperative::Get()->RunGraph(
- retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(),
- std::move(array_reqs), std::move(ref_count), &states, dispatch_modes);
-
- Engine::Get()->set_bulk_size(prev_bulk_size);
+ RunGraph(retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(),
+ std::move(array_reqs), std::move(ref_count), &states,
dispatch_modes);
if (retain_graph) {
buff.resize(num_forward_entries);
@@ -553,6 +926,99 @@ void Imperative::CachedOp::Backward(
}
}
+void CachedOp::StaticBackward(
+ const bool retain_graph,
+ const OpStatePtr& state_ptr,
+ const std::vector<NDArray*>& inputs,
+ const std::vector<OpReqType>& reqs,
+ const std::vector<NDArray*>& outputs) {
+ using namespace nnvm;
+ using namespace imperative;
+
+ Context default_ctx = outputs[0]->ctx();
+
+ auto& state = state_ptr.get_state<CachedOpState>();
+ std::lock_guard<std::mutex> lock(state.mutex);
+
+ bool match = SetBackwardGraph(&state.info, reqs, inputs, true);
+
+ nnvm::Graph& g = state.info.full_graph;
+ const auto& idx = g.indexed_graph();
+ auto num_forward_nodes = state.info.fwd_graph.indexed_graph().num_nodes();
+
+ if (!state.bwd_alloc || !match) {
+ StaticAllocMemory(state_ptr, true, true);
+ }
+
+ if (config_.static_shape) {
+ for (auto i : config_.param_indices) {
+ const auto iter = fwd_input_to_grad_output_.find(i);
+ if (iter == fwd_input_to_grad_output_.end()) continue;
+ auto entry = grad_graph_.outputs[iter->second];
+ if (!idx.exist(entry.node.get())) continue;
+ auto eid = idx.entry_id(entry);
+ if (!state.arrays[eid]->IsSame(*outputs[iter->second]) ||
+ !(state.array_reqs[eid] == reqs[iter->second])) {
+ match = false;
+ state.array_reqs[eid] = reqs[iter->second];
+ *state.arrays[eid] = *outputs[iter->second];
+ state.dynamic_entries[eid] = false;
+ }
+ }
+ for (auto i : config_.data_indices) {
+ const auto iter = fwd_input_to_grad_output_.find(i);
+ if (iter == fwd_input_to_grad_output_.end()) continue;
+ auto entry = grad_graph_.outputs[iter->second];
+ if (!idx.exist(entry.node.get())) continue;
+ auto eid = idx.entry_id(entry);
+ state.array_reqs[eid] = reqs[iter->second];
+ state.arrays[eid] = outputs[iter->second];
+ }
+ } else {
+ for (size_t i = 0; i < grad_graph_.outputs.size(); ++i) {
+ auto entry = grad_graph_.outputs[i];
+ if (!idx.exist(entry.node.get())) continue;
+ auto eid = idx.entry_id(entry);
+ state.array_reqs[eid] = reqs[i];
+ state.arrays[eid] = outputs[i];
+ }
+ }
+
+ if (!state.bwd_exec_init || !match) {
+ StaticInitExec(state_ptr, true, true);
+ }
+
+ for (size_t i = 0; i < state.info.bwd_input_eid.size(); ++i) {
+ auto eid = state.info.bwd_input_eid[i];
+ if (state.dynamic_entries[eid]) state.arrays[eid] = inputs[i];
+ }
+
+ StaticRunOps(default_ctx, g, state_ptr, num_forward_nodes, idx.num_nodes());
+}
+
+void CachedOp::Backward(
+ const bool retain_graph,
+ const OpStatePtr& state,
+ const std::vector<NDArray*>& inputs,
+ const std::vector<OpReqType>& reqs,
+ const std::vector<NDArray*>& outputs) {
+ using namespace imperative;
+ CHECK(!Imperative::Get()->is_recording())
+ << "CachedOp does not support higher order gradients. "
+ << "If you want to do backward with create_graph=True please "
+ << "do not use hybridize.";
+
+ int prev_bulk_size =
Engine::Get()->set_bulk_size(config_.backward_bulk_size);
+
+ if (config_.static_alloc) {
+ StaticBackward(retain_graph, state, inputs, reqs, outputs);
+ } else {
+ DynamicBackward(retain_graph, state, inputs, reqs, outputs);
+ }
+
+ Engine::Get()->set_bulk_size(prev_bulk_size);
+}
+
NNVM_REGISTER_OP(_CachedOp)
.set_num_inputs([](const NodeAttrs& attrs) {
diff --git a/src/imperative/cached_op.h b/src/imperative/cached_op.h
new file mode 100644
index 0000000..60a40c5
--- /dev/null
+++ b/src/imperative/cached_op.h
@@ -0,0 +1,174 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef MXNET_IMPERATIVE_CACHED_OP_H_
+#define MXNET_IMPERATIVE_CACHED_OP_H_
+
+#include <mxnet/imperative.h>
+#include <vector>
+#include <atomic>
+#include <utility>
+#include <string>
+#include <memory>
+#include <mutex>
+#include <unordered_map>
+#include <unordered_set>
+
+namespace mxnet {
+/*!
+ * \brief CachedOp Parameters.
+ *
+ * Option set of a CachedOp, declared through dmlc::Parameter. The default
+ * value and description of every field are given in
+ * DMLC_DECLARE_PARAMETER below.
+ */
+struct CachedOpConfig : public dmlc::Parameter<CachedOpConfig> {
+ // Maximum number of operators that can be inlined (default 2).
+ uint32_t inline_limit;
+ // Segment size of bulk execution during the forward pass.
+ uint32_t forward_bulk_size;
+ // Segment size of bulk execution during the backward pass.
+ uint32_t backward_bulk_size;
+ // Statically allocate memory (default false).
+ bool static_alloc;
+ // Optimize for invariant input shapes; requires static_alloc (default false).
+ bool static_shape;
+ // Position of argument variables (default: empty tuple).
+ nnvm::Tuple<uint32_t> data_indices;
+ // Position of parameters (default: empty tuple).
+ nnvm::Tuple<uint32_t> param_indices;
+ DMLC_DECLARE_PARAMETER(CachedOpConfig) {
+ DMLC_DECLARE_FIELD(static_alloc)
+ .set_default(false)
+ .describe("Statically allocate memory to improve speed. "
+ "Memory usage may increase.");
+ DMLC_DECLARE_FIELD(static_shape)
+ .set_default(false)
+ .describe("Optimize for invariant input shapes between iterations. "
+ "Must also set static_alloc to True. "
+ "Change of input shapes is still allowed but slower.");
+ DMLC_DECLARE_FIELD(inline_limit)
+ .set_default(2)
+ .describe("Maximum number of operators that can be inlined.");
+ // Both bulk sizes take their default from the environment variable
+ // MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN, falling back to 15.
+ DMLC_DECLARE_FIELD(forward_bulk_size)
+ .set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))
+ .describe("Segment size of bulk execution during forward pass.");
+ DMLC_DECLARE_FIELD(backward_bulk_size)
+ .set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))
+ .describe("Segment size of bulk execution during backward pass.");
+ DMLC_DECLARE_FIELD(data_indices)
+ .set_default(nnvm::Tuple<uint32_t>())
+ .describe("Position of argument variables.");
+ DMLC_DECLARE_FIELD(param_indices)
+ .set_default(nnvm::Tuple<uint32_t>())
+ .describe("Position of parameters.");
+ }
+};
+
+/*!
+ * \brief Operator built from a symbol that keeps the forward graph, the
+ * gradient graph and the combined graph (fwd_graph_, grad_graph_,
+ * full_graph_) and exposes Forward/Backward entry points.
+ *
+ * Each pass has a dynamic and a static implementation; Backward picks one
+ * from config_.static_alloc.
+ */
+class CachedOp {
+ public:
+ /*!
+ * \brief Create a cached op from a symbol.
+ * \param sym symbol describing the computation.
+ * \param flags (key, value) string pairs; presumably parsed into
+ * CachedOpConfig - confirm in cached_op.cc.
+ */
+ CachedOp(
+ const nnvm::Symbol& sym,
+ const std::vector<std::pair<std::string, std::string> >& flags);
+ ~CachedOp();
+ /*! \brief Number of input nodes of the forward graph. */
+ uint32_t num_inputs() const {
+ return fwd_graph_.indexed_graph().input_nodes().size();
+ }
+ /*! \brief Number of outputs of the forward graph. */
+ uint32_t num_outputs() const {
+ return fwd_graph_.outputs.size();
+ }
+ /*!
+ * \brief Total number of backward inputs: the sum of the sizes of the
+ * ograd, input and output dependency lists.
+ */
+ uint32_t num_backward_inputs() const {
+ return bwd_ograd_dep_.size() + bwd_in_dep_.size() + bwd_out_dep_.size();
+ }
+ /*! \brief Mutable access to save_inputs_ (one flag per forward input, presumably). */
+ std::vector<bool>& save_inputs() {
+ return save_inputs_;
+ }
+ /*! \brief Mutable access to save_outputs_ (one flag per forward output, presumably). */
+ std::vector<bool>& save_outputs() {
+ return save_outputs_;
+ }
+ /*! \brief Mutable input nodes as reported by the forward indexed graph. */
+ const std::unordered_set<uint32_t>& mutable_input_nodes() const {
+ return fwd_graph_.indexed_graph().mutable_input_nodes();
+ }
+ /*!
+ * \brief Produce gradient node entries for node given its head gradients.
+ * \param node the node to differentiate.
+ * \param ograds head gradient entries.
+ */
+ std::vector<nnvm::NodeEntry> Gradient(
+ const nnvm::NodePtr& node,
+ const std::vector<nnvm::NodeEntry>& ograds);
+ /*!
+ * \brief Run the forward pass.
+ * \param op_ptr shared pointer to a CachedOp (presumably this one - confirm).
+ * \param inputs input arrays.
+ * \param outputs output arrays.
+ */
+ void Forward(
+ const std::shared_ptr<CachedOp>& op_ptr,
+ const std::vector<NDArray*>& inputs,
+ const std::vector<NDArray*>& outputs);
+ /*!
+ * \brief Run the backward pass; must not be called while recording.
+ * \param retain_graph forwarded to the static/dynamic implementation.
+ * \param state operator state forwarded to the implementation.
+ * \param inputs backward input arrays.
+ * \param reqs write request for each entry of outputs.
+ * \param outputs arrays receiving the backward results.
+ */
+ void Backward(
+ const bool retain_graph,
+ const OpStatePtr& state,
+ const std::vector<NDArray*>& inputs,
+ const std::vector<OpReqType>& reqs,
+ const std::vector<NDArray*>& outputs);
+
+ private:
+ // Defined in the implementation file.
+ struct GraphInfo;
+ struct DynamicRuntime;
+ struct CachedOpState;
+
+ // Returns an operator state for the given context.
+ // NOTE(review): presumably reuses entries of cached_op_states_ - confirm.
+ OpStatePtr GetCachedOpState(const Context& ctx);
+ // Prepare the forward graph stored in info for the given inputs.
+ // NOTE(review): meaning of the bool result is not visible here - confirm.
+ bool SetForwardGraph(
+ GraphInfo* info,
+ const bool recording,
+ const std::vector<NDArray*>& inputs);
+ // Prepare the backward graph stored in info for the given reqs and inputs.
+ bool SetBackwardGraph(
+ GraphInfo* info,
+ const std::vector<OpReqType>& reqs,
+ const std::vector<NDArray*>& inputs,
+ bool detect_inplace_addto = false);
+ // Forward/backward implementations used when static_alloc is off.
+ OpStatePtr DynamicForward(
+ const Context& default_ctx,
+ const std::vector<NDArray*>& inputs,
+ const std::vector<NDArray*>& outputs);
+ void DynamicBackward(
+ const bool retain_graph,
+ const OpStatePtr& op_state,
+ const std::vector<NDArray*>& inputs,
+ const std::vector<OpReqType>& reqs,
+ const std::vector<NDArray*>& outputs);
+ // Helpers of the static_alloc code path.
+ void StaticAllocMemory(
+ const OpStatePtr& state_ptr,
+ bool recording,
+ bool keep_fwd);
+ void StaticInitExec(
+ const OpStatePtr& state_ptr,
+ bool recording,
+ bool keep_fwd);
+ // Runs the nodes in the range start_nid..end_nid of g.
+ void StaticRunOps(
+ const Context& default_ctx,
+ const nnvm::Graph& g,
+ const OpStatePtr& state_ptr,
+ size_t start_nid,
+ size_t end_nid);
+ // Forward/backward implementations used when static_alloc is on.
+ OpStatePtr StaticForward(
+ const Context& default_ctx,
+ const std::vector<NDArray*>& inputs,
+ const std::vector<NDArray*>& outputs);
+ void StaticBackward(
+ const bool retain_graph,
+ const OpStatePtr& state_ptr,
+ const std::vector<NDArray*>& inputs,
+ const std::vector<OpReqType>& reqs,
+ const std::vector<NDArray*>& outputs);
+
+ CachedOpConfig config_;
+ nnvm::Graph fwd_graph_;
+ nnvm::Graph grad_graph_;
+ nnvm::Graph full_graph_;
+ bool inlining_;
+ std::vector<nnvm::NodeEntry> ograd_entries_;
+ std::vector<uint32_t> bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_;
+ // Maps a forward input index to an index into grad_graph_.outputs.
+ std::unordered_map<uint32_t, uint32_t> fwd_input_to_grad_output_;
+ std::vector<bool> save_inputs_, save_outputs_;
+ std::vector<OpReqType> bwd_output_reqs_;
+
+ // NOTE(review): presumably guards cached_op_states_ - confirm in cached_op.cc.
+ std::mutex mutex_;
+ // Operator states kept per context.
+ std::unordered_map<Context, std::vector<OpStatePtr> > cached_op_states_;
+};
+
+/*! \brief Shared-ownership pointer to a CachedOp. */
+using CachedOpPtr = std::shared_ptr<CachedOp>;
+
+} // namespace mxnet
+#endif // MXNET_IMPERATIVE_CACHED_OP_H_
diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc
index 7caf305..e165425 100644
--- a/src/imperative/imperative.cc
+++ b/src/imperative/imperative.cc
@@ -19,6 +19,7 @@
#include <unordered_set>
#include <iostream>
#include "./imperative_utils.h"
+#include "./cached_op.h"
namespace mxnet {
#if DMLC_CXX11_THREAD_LOCAL
@@ -266,95 +267,6 @@ void Imperative::RecordOp(
}
}
-void Imperative::RunGraph(
- const bool retain_graph,
- const nnvm::IndexedGraph& idx,
- const std::vector<NDArray*> arrays,
- size_t node_start, size_t node_end,
- std::vector<OpReqType>&& array_reqs,
- std::vector<uint32_t>&& ref_count,
- std::vector<OpStatePtr> *p_states,
- const DispatchModeVector &dispatch_modes) {
- using namespace nnvm;
- using namespace imperative;
- static auto& createop = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState");
- static auto& is_layer_backward = Op::GetAttr<bool>("TIsLayerOpBackward");
- static const auto bwd_cached_op = Op::Get("_backward_CachedOp");
-
- std::vector<OpStatePtr>& states = *p_states;
- bool recording = is_recording();
-
- std::vector<NDArray*> ndinputs, ndoutputs;
- ShapeVector arg_shapes;
- DTypeVector arg_dtypes;
- std::vector<OpReqType> req;
-
- for (size_t i = node_start; i < node_end; ++i) {
- const nnvm::IndexedGraph::Node& node = idx[i];
- if (node.source->op() == nullptr) continue;
- auto num_outputs = node.source->num_outputs();
- ndinputs.clear();
- ndinputs.reserve(node.inputs.size());
- for (const auto& j : node.inputs) {
- ndinputs.emplace_back(arrays[idx.entry_id(j)]);
- CHECK(!ndinputs.back()->is_none()) << idx[j.node_id].source->attrs.name
<< " " << j.index;
- }
- ndoutputs.clear();
- ndoutputs.reserve(num_outputs);
- req.clear();
- req.reserve(num_outputs);
- for (size_t j = 0; j < num_outputs; ++j) {
- size_t eid = idx.entry_id(i, j);
- ndoutputs.emplace_back(arrays[eid]);
- req.push_back(array_reqs[eid]);
- CHECK(!ndoutputs.back()->is_none());
- }
- const Context& ctx = ndoutputs[0]->ctx();
- const DispatchMode dispatch_mode = dispatch_modes[i];
- if (node.source->op() == bwd_cached_op) {
- const auto& cached_op =
dmlc::get<CachedOpPtr>(node.source->attrs.parsed);
- nnvm::Node* fwd_node = node.source->control_deps[0].get();
- auto fwd_node_id = idx.node_id(fwd_node);
- cached_op->Backward(retain_graph, states[fwd_node_id], ndinputs, req,
ndoutputs);
- } else if (createop.count(node.source->op())) {
- arg_shapes.clear();
- arg_dtypes.clear();
- arg_shapes.reserve(ndinputs.size());
- arg_dtypes.reserve(ndinputs.size());
- for (size_t i = 0; i < ndinputs.size(); ++i) {
- arg_shapes.emplace_back(ndinputs[i]->shape());
- arg_dtypes.emplace_back(ndinputs[i]->dtype());
- }
- states[i] = createop[node.source->op()](
- node.source->attrs, ctx, arg_shapes, arg_dtypes);
- InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs, req,
dispatch_mode, states[i]);
- if (recording) RecordOp(NodeAttrs(node.source->attrs), ndinputs,
ndoutputs, states[i]);
- } else if (is_layer_backward.get(node.source->op(), false)) {
- nnvm::Node* fwd_node = node.source->control_deps[0].get();
- auto fwd_node_id = idx.node_id(fwd_node);
- InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs,
- req, dispatch_mode, states[fwd_node_id]);
- if (recording) {
- RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs,
states[fwd_node_id]);
- }
- } else {
- InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs, req,
dispatch_mode);
- if (recording) RecordOp(NodeAttrs(node.source->attrs), ndinputs,
ndoutputs);
- }
-
- for (const auto& j : node.inputs) {
- size_t eid = idx.entry_id(j);
- --ref_count[eid];
- if (ref_count[eid] == 0) arrays[eid]->ptr_.reset();
- }
- for (size_t j = 0; j < ndoutputs.size(); ++j) {
- size_t eid = idx.entry_id(i, j);
- if (ref_count[eid] == 0) arrays[eid]->ptr_.reset();
- }
- }
-}
-
-
std::vector<NDArray*> Imperative::Backward(
const std::vector<NDArray*>& outputs,
const std::vector<NDArray*>& ograds,
diff --git a/src/imperative/imperative_utils.cc
b/src/imperative/imperative_utils.cc
new file mode 100644
index 0000000..464aefc
--- /dev/null
+++ b/src/imperative/imperative_utils.cc
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "./imperative_utils.h"
+#include "./cached_op.h"
+
+namespace mxnet {
+namespace imperative {
+/*!
+ * \brief Execute the operator nodes with ids in [node_start, node_end) of an
+ * indexed graph, one node at a time, in index order.
+ *
+ * For every non-variable node the input/output arrays and write requests are
+ * gathered by entry id, the node is invoked, and the reference counts of its
+ * inputs are decremented; arrays whose count reaches zero are reset to an
+ * empty NDArray.
+ *
+ * \param retain_graph forwarded to CachedOp::Backward for _backward_CachedOp nodes.
+ * \param idx indexed graph to run.
+ * \param arrays one array pointer per node entry, indexed by entry id.
+ * \param node_start first node id to run.
+ * \param node_end one past the last node id to run.
+ * \param array_reqs one write request per node entry, indexed by entry id.
+ * \param ref_count remaining consumer count per node entry; modified in place.
+ * \param p_states per-node operator states; filled for stateful operators.
+ * \param dispatch_modes dispatch mode per node, indexed by node id.
+ */
+void RunGraph(
+ const bool retain_graph,
+ const nnvm::IndexedGraph& idx,
+ const std::vector<NDArray*> arrays,
+ size_t node_start, size_t node_end,
+ std::vector<OpReqType>&& array_reqs,
+ std::vector<uint32_t>&& ref_count,
+ std::vector<OpStatePtr> *p_states,
+ const DispatchModeVector &dispatch_modes) {
+ using namespace nnvm;
+ using namespace imperative;
+ static auto& createop = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState");
+ static auto& is_layer_backward = Op::GetAttr<bool>("TIsLayerOpBackward");
+ static const auto bwd_cached_op = Op::Get("_backward_CachedOp");
+
+ const auto imp = Imperative::Get();
+
+ std::vector<OpStatePtr>& states = *p_states;
+ bool recording = imp->is_recording();
+
+ // Scratch containers reused across nodes to avoid reallocating per node.
+ std::vector<NDArray*> ndinputs, ndoutputs;
+ ShapeVector arg_shapes;
+ DTypeVector arg_dtypes;
+ std::vector<OpReqType> req;
+
+ for (size_t i = node_start; i < node_end; ++i) {
+ const nnvm::IndexedGraph::Node& node = idx[i];
+ // Variable nodes have no operator and nothing to run.
+ if (node.source->op() == nullptr) continue;
+ auto num_outputs = node.source->num_outputs();
+ ndinputs.clear();
+ ndinputs.reserve(node.inputs.size());
+ for (const auto& j : node.inputs) {
+ ndinputs.emplace_back(arrays[idx.entry_id(j)]);
+ CHECK(!ndinputs.back()->is_none()) << idx[j.node_id].source->attrs.name
<< " " << j.index;
+ }
+ ndoutputs.clear();
+ ndoutputs.reserve(num_outputs);
+ req.clear();
+ req.reserve(num_outputs);
+ for (size_t j = 0; j < num_outputs; ++j) {
+ size_t eid = idx.entry_id(i, j);
+ ndoutputs.emplace_back(arrays[eid]);
+ req.push_back(array_reqs[eid]);
+ // An output may be unallocated only when nothing is written to it.
+ CHECK(array_reqs[eid] == kNullOp || !ndoutputs.back()->is_none());
+ }
+ // The node runs on the context of its first output.
+ const Context& ctx = ndoutputs[0]->ctx();
+ const DispatchMode dispatch_mode = dispatch_modes[i];
+ if (node.source->op() == bwd_cached_op) {
+ // Backward of a nested CachedOp: reuse the state of its forward node,
+ // which is found through the first control dependency.
+ const auto& cached_op =
dmlc::get<CachedOpPtr>(node.source->attrs.parsed);
+ nnvm::Node* fwd_node = node.source->control_deps[0].get();
+ auto fwd_node_id = idx.node_id(fwd_node);
+ cached_op->Backward(retain_graph, states[fwd_node_id], ndinputs, req,
ndoutputs);
+ } else if (createop.count(node.source->op())) {
+ // Stateful operator: create its state from the input shapes/dtypes.
+ arg_shapes.clear();
+ arg_dtypes.clear();
+ arg_shapes.reserve(ndinputs.size());
+ arg_dtypes.reserve(ndinputs.size());
+ // This inner `i` shadows the node index only inside this loop;
+ // states[i] below refers to the node index again.
+ for (size_t i = 0; i < ndinputs.size(); ++i) {
+ arg_shapes.emplace_back(ndinputs[i]->shape());
+ arg_dtypes.emplace_back(ndinputs[i]->dtype());
+ }
+ states[i] = createop[node.source->op()](
+ node.source->attrs, ctx, arg_shapes, arg_dtypes);
+ imp->InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs, req,
dispatch_mode, states[i]);
+ if (recording) {
+ imp->RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs,
states[i]);
+ }
+ } else if (is_layer_backward.get(node.source->op(), false)) {
+ // Backward of a stateful operator: run with the forward node's state.
+ nnvm::Node* fwd_node = node.source->control_deps[0].get();
+ auto fwd_node_id = idx.node_id(fwd_node);
+ imp->InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs,
+ req, dispatch_mode, states[fwd_node_id]);
+ if (recording) {
+ imp->RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs,
states[fwd_node_id]);
+ }
+ } else {
+ // Stateless operator.
+ imp->InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs, req,
dispatch_mode);
+ if (recording) {
+ imp->RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs);
+ }
+ }
+
+ // Release inputs that have no remaining consumers.
+ for (const auto& j : node.inputs) {
+ size_t eid = idx.entry_id(j);
+ --ref_count[eid];
+ if (ref_count[eid] == 0) *arrays[eid] = NDArray();
+ }
+ // Release outputs that nobody consumes.
+ for (size_t j = 0; j < ndoutputs.size(); ++j) {
+ size_t eid = idx.entry_id(i, j);
+ if (ref_count[eid] == 0) *arrays[eid] = NDArray();
+ }
+ }
+}
+
+} // namespace imperative
+} // namespace mxnet
diff --git a/src/imperative/imperative_utils.h
b/src/imperative/imperative_utils.h
index 06b7e05..726531d 100644
--- a/src/imperative/imperative_utils.h
+++ b/src/imperative/imperative_utils.h
@@ -23,6 +23,7 @@
#include <utility>
#include <algorithm>
#include <vector>
+#include <map>
#include <string>
#include "../executor/graph_executor.h"
#include "../executor/exec_pass.h"
@@ -38,11 +39,24 @@ namespace mxnet {
namespace imperative {
+/*!
+ * \brief Memory plan of a single node entry, as filled in by PlanMemory.
+ *
+ * storage_id: storage id the PlanMemory pass assigned to the entry.
+ * root: index of the entry that owns the buffer for this storage id; equal
+ * to the entry's own index for the first entry using the id and for entries
+ * with a negative id.
+ * size: buffer size in bytes; only meaningful on the root entry, where it is
+ * the maximum over all entries sharing the storage id.
+ * inplace: true when a non-root entry has a non-negative storage inplace index.
+ */
struct MemoryPlanInfo {
- uint32_t sid;
+ int storage_id;
+ uint32_t root;
size_t size;
bool inplace;
};
+/*!
+ * \brief Deleter functor that releases an engine operator through
+ * Engine::DeleteOperator; used as the deleter of a std::unique_ptr.
+ */
+struct EngineOprDeleter {
+ void operator()(engine::Opr* handle) {
+ Engine::Get()->DeleteOperator(handle);
+ }
+};
+
+/*!
+ * \brief One segment of consecutive graph nodes, stored at the index of the
+ * segment's first node (see CreateEngineOpSeg).
+ */
+struct EngineOprSeg {
+ // Set to true by CreateEngineOpSeg when the segment contains no executor.
+ bool skip;
+ // Node id at which the following segment starts.
+ size_t next_nid;
+ // Engine operator covering the segment; may be null. Deleted through
+ // EngineOprDeleter when the segment is destroyed or reassigned.
+ std::unique_ptr<engine::Opr, EngineOprDeleter> opr;
+};
+
using MemoryPlanVector = std::vector<MemoryPlanInfo>;
inline Context GetContext(const nnvm::NodeAttrs& attrs,
@@ -715,10 +729,12 @@ inline std::vector<Context> PlaceDevice(const
nnvm::IndexedGraph& idx) {
inline MemoryPlanVector PlanMemory(
- nnvm::Graph* p_g, nnvm::StorageVector&& storage,
+ nnvm::Graph* p_g,
+ nnvm::StorageVector&& storage,
const std::vector<uint32_t>& ref_count,
const std::pair<uint32_t, uint32_t>& node_range = {0, 0},
- const std::pair<uint32_t, uint32_t>& entry_range = {0, 0}) {
+ const std::pair<uint32_t, uint32_t>& entry_range = {0, 0},
+ bool detect_inplace_addto = false) {
using namespace nnvm;
nnvm::Graph& g = *p_g;
const auto& idx = g.indexed_graph();
@@ -728,31 +744,31 @@ inline MemoryPlanVector PlanMemory(
g.attrs["ref_count"] = std::make_shared<dmlc::any>(ref_count);
g.attrs["storage"] = std::make_shared<dmlc::any>(std::move(storage));
g = nnvm::ApplyPass(g, "PlanMemory");
+ if (detect_inplace_addto) g = exec::DetectInplaceAddTo(g);
const auto& dtypes = g.GetAttr<DTypeVector>("dtype");
const auto& shapes = g.GetAttr<ShapeVector>("shape");
- const auto& stypes = g.GetAttr<StorageTypeVector>("storage_type");
- auto storage_ids = g.MoveCopyAttr<StorageVector>("storage_id");
- auto storage_inplace = g.MoveCopyAttr<std::vector<int>
>("storage_inplace_index");
+ const auto& storage_inplace = g.GetAttr<std::vector<int>
>("storage_inplace_index");
+ const auto& storage_ids = g.GetAttr<StorageVector>("storage_id");
uint32_t entry_start = entry_range.first;
uint32_t entry_end =
entry_range.second > entry_start ? entry_range.second :
idx.num_node_entries();
MemoryPlanVector mem_plan(idx.num_node_entries());
- std::unordered_map<int, uint32_t> sid_to_loc;
+ std::unordered_map<int, uint32_t> sid_to_root;
for (uint32_t i = entry_start; i < entry_end; ++i) {
- if (stypes[i] != kDefaultStorage) continue;
if (storage_ids[i] < 0) {
- mem_plan[i] = {i, mshadow::mshadow_sizeof(dtypes[i]) * shapes[i].Size(),
false};
- } else if (!sid_to_loc.count(storage_ids[i])) {
+ mem_plan[i] = {storage_ids[i], i, 0, false};
+ } else if (!sid_to_root.count(storage_ids[i])) {
CHECK_LT(storage_inplace[i], 0);
- sid_to_loc[storage_ids[i]] = i;
- mem_plan[i].sid = i;
- mem_plan[i].size = mshadow::mshadow_sizeof(dtypes[i]) * shapes[i].Size();
+ sid_to_root[storage_ids[i]] = i;
+ mem_plan[i] = {storage_ids[i], i,
+ mshadow::mshadow_sizeof(dtypes[i]) * shapes[i].Size(),
+ false};
} else {
- uint32_t loc = sid_to_loc[storage_ids[i]];
- mem_plan[i] = {loc, 0, storage_inplace[i] >= 0};
- mem_plan[loc].size = std::max(mem_plan[loc].size,
+ uint32_t root = sid_to_root[storage_ids[i]];
+ mem_plan[i] = {storage_ids[i], root, 0, storage_inplace[i] >= 0};
+ mem_plan[root].size = std::max(mem_plan[root].size,
mshadow::mshadow_sizeof(dtypes[i]) * shapes[i].Size());
}
}
@@ -761,39 +777,213 @@ inline MemoryPlanVector PlanMemory(
}
-inline void AllocateMemory(const nnvm::Graph& g,
- const nnvm::IndexedGraph& idx,
- const Context& default_ctx,
- const uint32_t entry_start, const uint32_t entry_end,
- const MemoryPlanVector& mem_plan,
- const std::vector<NDArray*>& arrays,
- std::vector<OpReqType> *array_reqs) {
+/*!
+ * \brief Allocate the arrays of the entries in [entry_start, entry_end)
+ * according to a memory plan, reusing buffers from pool when possible.
+ *
+ * \param g graph carrying the "dtype", "shape" and "storage_type" attributes.
+ * \param idx indexed graph of g (not read by the visible body).
+ * \param default_ctx context on which new arrays are created.
+ * \param entry_start first entry id to process.
+ * \param entry_end one past the last entry id to process.
+ * \param mem_plan memory plan, indexed by entry id.
+ * \param arrays one array pointer per entry; the processed ones must be none.
+ * \param array_reqs write requests; kWriteTo is turned into kWriteInplace
+ * for entries planned as inplace.
+ * \param pool existing buffers keyed by size in bytes; buffers that get
+ * reused are removed from it.
+ * \return the buffers backing the root entries (reused and newly created),
+ * keyed by size in bytes.
+ */
+inline std::multimap<size_t, NDArray> AllocateMemory(
+ const nnvm::Graph& g,
+ const nnvm::IndexedGraph& idx,
+ const Context& default_ctx,
+ const uint32_t entry_start, const uint32_t entry_end,
+ const MemoryPlanVector& mem_plan,
+ const std::vector<NDArray*>& arrays,
+ std::vector<OpReqType> *array_reqs,
+ std::multimap<size_t, NDArray>&& pool = std::multimap<size_t, NDArray>()) {
using namespace nnvm;
const auto& dtypes = g.GetAttr<DTypeVector>("dtype");
const auto& shapes = g.GetAttr<ShapeVector>("shape");
const auto& stypes = g.GetAttr<StorageTypeVector>("storage_type");
+ std::multimap<size_t, NDArray> new_pool;
+
for (uint32_t i = entry_start; i < entry_end; ++i) {
- if (!arrays[i]->is_none()) continue;
- if (stypes[i] == kDefaultStorage) {
- if (mem_plan[i].sid == i) {
- CHECK_GT(mem_plan[i].size, 0);
+ // Entries marked with the external storage id are left untouched.
+ if (mem_plan[i].storage_id == exec::kExternalStorageID) continue;
+ CHECK(arrays[i]->is_none());
+ // Entries marked dynamic get their own array instead of a planned buffer.
+ // NOTE(review): the `true` argument is presumably delay_alloc - confirm.
+ if (mem_plan[i].storage_id == exec::kDynamicStorageID) {
+ *arrays[i] = NDArray(static_cast<NDArrayStorageType>(stypes[i]),
+ shapes[i], default_ctx, true, dtypes[i]);
+ continue;
+ }
+ CHECK_EQ(stypes[i], kDefaultStorage);
+ if (mem_plan[i].root == i) {
+ CHECK_GT(mem_plan[i].size, 0);
+ // Smallest pooled buffer whose key is at least the required size.
+ auto iter = pool.lower_bound(mem_plan[i].size);
+ if (iter != pool.end()) {
+ *arrays[i] = iter->second.AsArray(shapes[i], dtypes[i]);
+ // Move the reused buffer (under its original key) to the new pool.
+ new_pool.insert(*iter);
+ pool.erase(iter);
+ } else {
NDArray buff(TShape({static_cast<nnvm::dim_t>(mem_plan[i].size)}),
default_ctx, true, mshadow::kUint8);
*arrays[i] = buff.AsArray(shapes[i], dtypes[i]);
+ new_pool.insert({mem_plan[i].size, buff});
+ }
+ } else {
+ // Non-root entry: view the root entry's array with this entry's
+ // shape and dtype.
+ CHECK_GE(mem_plan[mem_plan[i].root].storage_id, 0);
+ *arrays[i] = arrays[mem_plan[i].root]->AsArray(shapes[i], dtypes[i]);
+ if (mem_plan[i].inplace && array_reqs->at(i) == kWriteTo) {
+ array_reqs->at(i) = kWriteInplace;
+ }
+ }
+ }
+
+ return new_pool;
+}
+
+/*!
+ * \brief Attach the arrays and write requests of one node to its executor
+ * and call exec->Setup().
+ *
+ * The executor must not have input or output arrays attached yet, and every
+ * array the node reads or writes must already be allocated (not none).
+ *
+ * \param g graph whose indexed graph maps node entries to entry ids.
+ * \param nid id of the node to set up.
+ * \param exec executor of that node.
+ * \param arrays one array pointer per node entry, indexed by entry id.
+ * \param array_reqs one write request per node entry, indexed by entry id.
+ */
+inline void SetupOpExec(
+ const nnvm::Graph& g,
+ size_t nid,
+ const std::shared_ptr<exec::OpExecutor>& exec,
+ // Both vectors cover all entries of the graph while this function handles
+ // a single node, so they are taken by reference to avoid copying them on
+ // every call.
+ const std::vector<NDArray*>& arrays,
+ const std::vector<OpReqType>& array_reqs) {
+ const auto& idx = g.indexed_graph();
+ const auto& inode = idx[nid];
+ CHECK_EQ(exec->in_array.size(), 0U);
+ CHECK_EQ(exec->out_array.size(), 0U);
+ for (const auto& e : inode.inputs) {
+ CHECK(!arrays[idx.entry_id(e)]->is_none()) << inode.source->attrs.name;
+ exec->in_array.push_back(*arrays[idx.entry_id(e)]);
+ }
+ for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) {
+ uint32_t eid = idx.entry_id(nid, index);
+ CHECK(!arrays[eid]->is_none()) << inode.source->attrs.name;
+ exec->out_array.push_back(*arrays[eid]);
+ exec->req.push_back(array_reqs[eid]);
+ }
+
+ exec->Setup();
+}
+
+/*!
+ * \brief Create a single engine operator that runs the given executors one
+ * after another.
+ *
+ * More than one executor is only accepted when all of them are synchronous.
+ *
+ * \param default_ctx context used to decide whether the operator runs on GPU.
+ * \param execs executors to bundle; must not be empty and each must have at
+ * least one output array.
+ * \return handle created by Engine::NewOperator; the caller owns it (the
+ * callers visible in this file wrap it in a unique_ptr with EngineOprDeleter).
+ */
+inline Engine::OprHandle CreateEngineOp(
+ const Context& default_ctx,
+ const std::vector<std::shared_ptr<exec::OpExecutor> >& execs) {
+ CHECK_GT(execs.size(), 0);
+ std::vector<Engine::VarHandle> use_vars, mutate_vars;
+
+ for (const auto& exec : execs) {
+ CHECK_GT(exec->out_array.size(), 0);
+ CHECK(execs.size() == 1 || exec->exec_type() == ExecType::kSync);
+
+ // Collect the engine variables: inputs are read, while requested
+ // resources, outputs and the executor's own variable are mutated.
+ for (const auto& nd : exec->in_array) {
+ use_vars.push_back(nd.var());
+ }
+ for (auto& r : exec->op_ctx.requested) {
+ mutate_vars.push_back(r.var);
+ }
+ for (auto& nd : exec->out_array) {
+ mutate_vars.push_back(nd.var());
+ }
+ if (exec->var() != nullptr) {
+ mutate_vars.push_back(exec->var());
+ }
+ }
+
+ // dedup vars
+ Engine::Get()->DeduplicateVarHandle(&use_vars, &mutate_vars);
+ bool is_gpu = default_ctx.dev_mask() == gpu::kDevMask;
+ // Only a lone executor can be asynchronous (see the CHECK above).
+ bool is_async = execs.size() > 1 ? false : execs[0]->exec_type() ==
ExecType::kAsync;
+
+ // execs is captured by copy, so the closure holds its own shared_ptr
+ // references to the executors.
+ auto exec_fun = [execs, is_async, is_gpu] (
+ RunContext ctx, Engine::CallbackOnComplete on_complete) {
+ if (is_async) {
+ // Hand the completion callback to the async executor.
+ execs[0]->op_ctx.async_on_complete = on_complete;
+ }
+ for (const auto& exec : execs) exec->Run(ctx, is_gpu);
+ // Call on_complete here only for non-async ops; an async op received the
+ // callback above (presumably it invokes it itself when done).
+ if (!is_async) {
+ if (is_gpu) {
+ #if MXNET_USE_CUDA
+ // Wait GPU kernel to finish.
+ ctx.get_stream<gpu>()->Wait();
+ #else
+ LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
+ #endif
+ }
+ on_complete();
+ }
+ };
+
+ return Engine::Get()->NewOperator(
+ exec_fun, use_vars, mutate_vars, FnProperty::kNormal);
+}
+
+/*!
+ * \brief Partition the nodes in [start_nid, end_nid) into segments and create
+ * one engine operator per segment.
+ *
+ * A segment is closed before a node that is not synchronous, has no output
+ * arrays attached, touches an excluded entry, or when the segment already
+ * holds bulk_size executors. Each segment is recorded in opr_segs at the
+ * index of its first node id.
+ *
+ * \param idx indexed graph the node ids refer to.
+ * \param default_ctx context passed on to CreateEngineOp.
+ * \param start_nid first node id to process.
+ * \param end_nid one past the last node id to process.
+ * \param bulk_size maximum number of executors per segment.
+ * \param excludes entry ids whose producing/consuming nodes end a segment.
+ * \param execs executors indexed by node id.
+ * \param skip_plus_node per-node flags; nodes with a non-zero flag are
+ * skipped. May be empty, in which case no node is skipped.
+ * \param opr_segs output, indexed by node id.
+ */
+inline void CreateEngineOpSeg(
+ const nnvm::IndexedGraph& idx,
+ // default_ctx and skip_plus_node are only read, so they are taken by
+ // reference instead of being copied on every call.
+ const Context& default_ctx,
+ const size_t start_nid,
+ const size_t end_nid,
+ const size_t bulk_size,
+ const std::unordered_set<uint32_t>& excludes,
+ const std::vector<std::shared_ptr<exec::OpExecutor> >& execs,
+ const std::vector<int>& skip_plus_node,
+ std::vector<EngineOprSeg> *opr_segs) {
+ size_t seg_start = start_nid;
+ std::vector<std::shared_ptr<exec::OpExecutor> > seg_execs;
+ for (size_t nid = start_nid; nid < end_nid; ++nid) {
+ const auto& node = idx[nid];
+ if (node.source->is_variable()) continue;
+ if (skip_plus_node.size() && skip_plus_node[nid]) continue;
+ auto& exec = execs[nid];
+ bool is_async = exec->exec_type() != ExecType::kSync;
+ bool valid = exec->out_array.size() > 0;
+
+ // Stop at async nodes and invalid node (due to input/output is not allocated)
+ bool stop = is_async || !valid || seg_execs.size() >= bulk_size;
+ for (size_t i = 0; i < node.inputs.size() && !stop; ++i) {
+ if (excludes.count(idx.entry_id(node.inputs[i]))) stop = true;
+ }
+ auto num_outputs = node.source->num_outputs();
+ for (size_t i = 0; i < num_outputs && !stop; ++i) {
+ if (excludes.count(idx.entry_id(nid, i))) stop = true;
+ }
+
+ // Create opr segment for previous nodes.
+ if (stop && nid > seg_start) {
+ auto& seg = (*opr_segs)[seg_start];
+ if (seg_execs.size()) {
+ seg = EngineOprSeg{false, nid};
+ seg.opr.reset(CreateEngineOp(default_ctx, seg_execs));
} else {
- *arrays[i] = arrays[mem_plan[i].sid]->AsArray(shapes[i], dtypes[i]);
- if (mem_plan[i].inplace && array_reqs->at(i) == kWriteTo) {
- array_reqs->at(i) = kWriteInplace;
- }
+ seg = EngineOprSeg{true, nid, nullptr};
}
+ seg_start = nid;
+ seg_execs.clear();
+ }
+
+ seg_execs.push_back(exec);
+
+ // Non-sync and invalid nodes always form a segment of their own; an
+ // invalid node gets no engine operator.
+ auto& seg = (*opr_segs)[nid];
+ if (is_async) {
+ seg = EngineOprSeg{false, nid + 1};
+ seg.opr.reset(CreateEngineOp(default_ctx, seg_execs));
+ seg_execs.clear();
+ seg_start = nid + 1;
+ } else if (!valid) {
+ seg = EngineOprSeg{false, nid + 1, nullptr};
+ seg_execs.clear();
+ seg_start = nid + 1;
+ }
+ }
+ // The last segment
+ if (end_nid > seg_start) {
+ auto& seg = (*opr_segs)[seg_start];
+ if (seg_execs.size()) {
+ seg = EngineOprSeg{false, end_nid};
+ seg.opr.reset(CreateEngineOp(default_ctx, seg_execs));
} else {
- *arrays[i] = NDArray(static_cast<NDArrayStorageType>(stypes[i]),
- shapes[i], default_ctx, true, dtypes[i]);
+ seg = EngineOprSeg{true, end_nid, nullptr};
}
}
}
+
+/*!
+ * \brief Run the operator nodes with ids in [node_start, node_end) of idx.
+ *
+ * arrays and array_reqs are indexed by entry id, p_states and dispatch_modes
+ * by node id. ref_count holds the remaining consumer count per entry and is
+ * decremented while running; arrays whose count reaches zero are reset.
+ * retain_graph is forwarded to CachedOp::Backward for nested cached ops.
+ */
+void RunGraph(const bool retain_graph,
+ const nnvm::IndexedGraph& idx,
+ const std::vector<NDArray*> arrays,
+ size_t node_start, size_t node_end,
+ std::vector<OpReqType>&& array_reqs,
+ std::vector<uint32_t>&& ref_count,
+ std::vector<OpStatePtr> *p_states,
+ const DispatchModeVector &dispatch_modes);
+
} // namespace imperative
} // namespace mxnet
diff --git a/tests/python/unittest/test_gluon.py
b/tests/python/unittest/test_gluon.py
index e540657..6fafb36 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -22,6 +22,7 @@ from mxnet.test_utils import assert_almost_equal
from mxnet.ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID
from common import setup_module, with_seed, assertRaises, teardown
import numpy as np
+from numpy.testing import assert_array_equal
from nose.tools import raises, assert_raises
from copy import deepcopy
import warnings
@@ -1132,7 +1133,6 @@ def test_hybrid_multi_context():
net.hybridize()
net(mx.nd.zeros((1, 3, 32, 32), ctx=mx.cpu(0))).asnumpy()
-
@with_seed()
def test_zero_grad():
data = mx.nd.random.uniform(shape=(3,3))
@@ -1145,6 +1145,60 @@ def test_zero_grad():
grad = net.collect_params()['test_zero_grad_weight'].grad()
assert_almost_equal(grad.asnumpy(), grad.asnumpy() * 0)
+# Helper: check that a ResNet hybridized with `kwargs` produces the same
+# outputs and parameter gradients as the same pretrained, non-hybridized model.
+def check_hybrid_static_memory(**kwargs):
+ x = mx.nd.random.uniform(shape=(2, 3, 32, 32))
+ x.attach_grad()
+
+ # net1 is the reference; net2 is hybridized with the options under test.
+ net1 = gluon.model_zoo.vision.get_resnet(
+ 1, 18, pretrained=True, prefix='net_',
ctx=mx.context.current_context())
+ net2 = gluon.model_zoo.vision.get_resnet(
+ 1, 18, pretrained=True, prefix='net_',
ctx=mx.context.current_context())
+ net2.hybridize(**kwargs)
+ # Run both networks once before the recorded passes below.
+ net1(x)
+ net2(x)
+
+ # Calls the network twice under autograd.record so that the same cached
+ # graph is used more than once in a single recorded computation.
+ def test(net, x):
+ with mx.autograd.record():
+ y = net(x) + net(x)
+ y.backward()
+
+ grads = {k: v.grad() for k, v in net.collect_params().items() if
v.grad_req != 'null'}
+
+ return y, grads
+
+ y1, grads1 = test(net1, x)
+ y2, grads2 = test(net2, x)
+
+ assert_almost_equal(y1.asnumpy(), y2.asnumpy(), rtol=1e-3, atol=1e-5)
+ for key in grads1:
+ assert_almost_equal(grads1[key].asnumpy(), grads2[key].asnumpy(),
rtol=1e-3, atol=1e-5)
+
+# Uses random inputs, so it is seeded like the neighbouring tests.
+@with_seed()
+def test_hybrid_static_memory():
+ # Default hybridize, static_alloc, and static_alloc + static_shape.
+ check_hybrid_static_memory()
+ check_hybrid_static_memory(static_alloc=True)
+ check_hybrid_static_memory(static_alloc=True, static_shape=True)
+
+# Helper: run a hybridized ResNet with and without recording, then repeat
+# with a different batch size, to exercise switching between modes and
+# input shapes. Only checks that everything runs to completion.
+def check_hybrid_static_memory_switching(**kwargs):
+ net = gluon.model_zoo.vision.get_resnet(
+ 1, 18, pretrained=True, ctx=mx.context.current_context())
+ net.hybridize(**kwargs)
+
+ # First input shape: batch size 4.
+ x = mx.nd.random.uniform(shape=(4, 3, 32, 32))
+ net(x)
+ with mx.autograd.record():
+ y = net(x)
+ y.backward()
+ # Second input shape: batch size 2.
+ x = mx.nd.random.uniform(shape=(2, 3, 32, 32))
+ net(x)
+ with mx.autograd.record():
+ y = net(x)
+ y.backward()
+ # Block until all queued work has finished so that failures surface here.
+ mx.nd.waitall()
+
+# Uses random inputs, so it is seeded like the neighbouring tests.
+@with_seed()
+def test_hybrid_static_memory_switching():
+ # Default hybridize, static_alloc, and static_alloc + static_shape.
+ check_hybrid_static_memory_switching()
+ check_hybrid_static_memory_switching(static_alloc=True)
+ check_hybrid_static_memory_switching(static_alloc=True, static_shape=True)
@with_seed()
def test_hook():
@@ -1239,6 +1293,17 @@ def test_legacy_save_params():
model.load_params('test.params', ctx=mx.cpu())
+# Uses random inputs, so it is seeded like the neighbouring tests.
+@with_seed()
+def test_hybrid_static_memory_recording():
+ # Runs a static_alloc hybridized ResNet under autograd.record(True) and
+ # calls it again; only checks that the calls complete.
+ net = gluon.model_zoo.vision.get_resnet(
+ 1, 18, pretrained=True, ctx=mx.context.current_context())
+ net.hybridize(static_alloc=True)
+
+ x = mx.nd.random.uniform(shape=(1, 3, 32, 32))
+ with mx.autograd.record(True):
+ net(x)
+ net(x)
+
+
if __name__ == '__main__':
import nose
nose.runmodule()