This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 2f8c1e8 add inline for cached op & fixed a bug when calling backward on variable (#8701)
2f8c1e8 is described below
commit 2f8c1e83f94e84a25a48d2cd43136030fb3f2d1e
Author: Eric Junyuan Xie <[email protected]>
AuthorDate: Tue Nov 28 10:44:47 2017 -0800
add inline for cached op & fixed a bug when calling backward on variable (#8701)
* add inline for cached op & fixed a bug when calling backward on variable
* fix
* Update test_gluon.py
---
include/mxnet/c_api.h | 11 ++++++++--
include/mxnet/imperative.h | 37 +++++++++++++++++++++++++++----
python/mxnet/_ctypes/ndarray.py | 7 ++++--
python/mxnet/gluon/block.py | 20 ++++++++++++-----
python/mxnet/gluon/nn/basic_layers.py | 6 +++--
src/c_api/c_api_ndarray.cc | 31 ++++++++++++++++----------
src/imperative/cached_op.cc | 41 +++++++++++++++++++++++++++++------
src/imperative/imperative.cc | 16 +++++++++-----
tests/python/unittest/test_gluon.py | 27 ++++++++++++++++++++++-
9 files changed, 157 insertions(+), 39 deletions(-)
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 9815786..faa4535 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -793,8 +793,15 @@ MXNET_DLL int MXAutogradGetSymbol(NDArrayHandle handle,
SymbolHandle *out);
/*!
* \brief create cached operator
*/
-MXNET_DLL int MXCreateCachedOp(SymbolHandle handle,
- CachedOpHandle *out);
+MXNET_DLL int MXCreateCachedOp(SymbolHandle handle, CachedOpHandle *out);
+/*!
+ * \brief create cached operator, accepting flags as key/value string pairs
+ */
+MXNET_DLL int MXCreateCachedOpEx(SymbolHandle handle,
+ int num_params,
+ const char** keys,
+ const char** vals,
+ CachedOpHandle *out);
/*!
* \brief free cached operator
*/
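For reference, a minimal sketch of driving the new entry point directly over ctypes, assuming _LIB, check_call, c_str_array, and CachedOpHandle are importable from mxnet.base as in the Python frontend changes below; the flag value is illustrative:

    import ctypes
    import mxnet as mx
    from mxnet.base import _LIB, check_call, c_str_array, CachedOpHandle

    sym = mx.sym.FullyConnected(data=mx.sym.var('data'), num_hidden=10)
    keys, vals = ['inline_limit'], ['2']   # flags cross the C API as strings
    handle = CachedOpHandle()
    check_call(_LIB.MXCreateCachedOpEx(
        sym.handle, len(keys), c_str_array(keys), c_str_array(vals),
        ctypes.byref(handle)))
    check_call(_LIB.MXFreeCachedOp(handle))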
diff --git a/include/mxnet/imperative.h b/include/mxnet/imperative.h
index 88a9f4d..d605e9d 100644
--- a/include/mxnet/imperative.h
+++ b/include/mxnet/imperative.h
@@ -28,11 +28,30 @@
#include <nnvm/graph.h>
#include <vector>
#include <atomic>
+#include <utility>
+#include <string>
#include <unordered_map>
#include "./ndarray.h"
namespace mxnet {
+/*! \brief CachedOp Parameters */
+struct CachedOpParam : public dmlc::Parameter<CachedOpParam> {
+ uint32_t inline_limit;
+ uint32_t forward_bulk_size;
+ uint32_t backward_bulk_size;
+ DMLC_DECLARE_PARAMETER(CachedOpParam) {
+ DMLC_DECLARE_FIELD(inline_limit)
+ .set_default(2)
+ .describe("Maximum number of operators that can be inlined.");
+ DMLC_DECLARE_FIELD(forward_bulk_size)
+ .set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))
+ .describe("Segment size of bulk execution during forward pass.");
+ DMLC_DECLARE_FIELD(backward_bulk_size)
+ .set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))
+ .describe("Segment size of bulk execution during backward pass.");
+ }
+};
/*! \brief runtime functions for NDArray */
class Imperative {
public:
@@ -77,7 +96,8 @@ class Imperative {
};
class CachedOp {
public:
- explicit CachedOp(const nnvm::Symbol& sym);
+ CachedOp(const nnvm::Symbol& sym,
+ const std::vector<std::pair<std::string, std::string> >& kwargs);
uint32_t num_inputs() {
return fwd_graph_.indexed_graph().input_nodes().size();
}
@@ -103,8 +123,9 @@ class Imperative {
const std::vector<NDArray*>& inputs);
std::vector<nnvm::NodeEntry> Gradient(const nnvm::NodePtr& node,
const std::vector<nnvm::NodeEntry>& ograds);
- OpStatePtr Forward(const std::vector<NDArray*>& inputs,
- const std::vector<NDArray*>& outputs);
+ void Forward(const std::shared_ptr<CachedOp>& op_ptr,
+ const std::vector<NDArray*>& inputs,
+ const std::vector<NDArray*>& outputs);
void Backward(const bool retain_graph,
const OpStatePtr& state,
const std::vector<NDArray*>& inputs,
@@ -117,9 +138,11 @@ class Imperative {
std::vector<OpStatePtr> states;
};
std::mutex mutex_;
+ CachedOpParam param_;
nnvm::Graph fwd_graph_;
nnvm::Graph grad_graph_;
nnvm::Graph full_graph_;
+ bool inlining_;
std::vector<nnvm::NodeEntry> ograd_entries_;
std::vector<bool> curr_grad_req_;
std::vector<uint32_t> bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_;
@@ -182,7 +205,11 @@ class Imperative {
private:
friend class NDArray;
/*! \brief make constructor protected. */
- Imperative() {}
+ Imperative() {
+ if (dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_TRAIN", 1)) {
+      backward_bulk_size_ = dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15);
+ }
+ }
/*! \brief find the input/output ndarrays that are needed for backward */
void GetBackwardDependency(
const nnvm::NodePtr& node,
@@ -210,6 +237,8 @@ class Imperative {
std::atomic<uint64_t> node_count_{0};
/*! \brief variable count used for naming */
std::atomic<uint64_t> variable_count_{0};
+ /*! \brief default backward bulk size */
+ int backward_bulk_size_{0};
};
using CachedOpPtr = std::shared_ptr<Imperative::CachedOp>;
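The three CachedOpParam fields surface as hybridize() flags in the gluon changes below: inline_limit defaults to 2, and both bulk sizes default to MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN (15 when unset). A sketch with illustrative values:

    import mxnet as mx

    net = mx.gluon.nn.Dense(10)
    net.initialize()
    # any of the three fields may be overridden per block
    net.hybridize(inline_limit=4, forward_bulk_size=8, backward_bulk_size=8)
    net(mx.nd.zeros((1, 4)))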
diff --git a/python/mxnet/_ctypes/ndarray.py b/python/mxnet/_ctypes/ndarray.py
index a0c01a6..20ad2bf 100644
--- a/python/mxnet/_ctypes/ndarray.py
+++ b/python/mxnet/_ctypes/ndarray.py
@@ -105,10 +105,13 @@ def _imperative_invoke(handle, ndargs, keys, vals, out):
class CachedOp(object):
"""Cached operator handle."""
__slots__ = ["handle"]
- def __init__(self, sym):
+ def __init__(self, sym, flags=()):
self.handle = CachedOpHandle()
- check_call(_LIB.MXCreateCachedOp(
+ check_call(_LIB.MXCreateCachedOpEx(
sym.handle,
+ len(flags),
+ c_str_array([key for key, _ in flags]),
+ c_str_array([str(val) for _, val in flags]),
ctypes.byref(self.handle)))
def __del__(self):
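With this change a CachedOp can be built with flags as (key, value) pairs; values are stringified before crossing the C API. A sketch using an illustrative symbol and flag:

    import mxnet as mx

    sym = mx.sym.var('a') + 1
    op = mx.nd.CachedOp(sym, flags=[('inline_limit', 2)])
    out = op(mx.nd.ones((2,)))   # one NDArray argument per symbol input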
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 466f87f..37734ac 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -274,7 +274,7 @@ class Block(object):
"""
self.collect_params().initialize(init, ctx, verbose)
- def hybridize(self, active=True):
+ def hybridize(self, active=True, **kwargs):
"""Activates or deactivates :py:class:`HybridBlock` s recursively. Has
no effect on
non-hybrid children.
@@ -282,9 +282,11 @@ class Block(object):
----------
active : bool, default True
Whether to turn hybrid on or off.
+ **kwargs : string
+ Additional flags for the hybridized operator.
"""
for cld in self._children:
- cld.hybridize(active)
+ cld.hybridize(active, **kwargs)
def cast(self, dtype):
"""Cast this Block to use another data type.
@@ -343,6 +345,7 @@ class HybridBlock(Block):
self._out_format = None
self._in_format = None
self._active = False
+ self._flags = {}
def __setattr__(self, name, value):
"""Registers parameters."""
@@ -378,7 +381,7 @@ class HybridBlock(Block):
def _build_cache(self, *args):
inputs, out = self._get_graph(*args)
input_idx = {var.name: i for i, var in enumerate(inputs)}
- self._cached_op = ndarray.CachedOp(out)
+ self._cached_op = ndarray.CachedOp(out, self._flags)
params = dict(self.collect_params().items())
# verify graph inputs
@@ -437,9 +440,11 @@ class HybridBlock(Block):
super(HybridBlock, self).register_child(block)
self._clear_cached_op()
- def hybridize(self, active=True):
+ def hybridize(self, active=True, **kwargs):
self._active = active
- super(HybridBlock, self).hybridize(active)
+ self._flags = kwargs.items()
+ self._clear_cached_op()
+ super(HybridBlock, self).hybridize(active, **kwargs)
def cast(self, dtype):
self._clear_cached_op()
@@ -615,5 +620,10 @@ class SymbolBlock(HybridBlock):
ret._compose(**{k.name: v for k, v in zip(self._cached_graph[0], args)})
return _regroup(list(ret), self._out_format)[0]
+ def _clear_cached_op(self):
+ tmp = self._cached_graph
+ super(SymbolBlock, self)._clear_cached_op()
+ self._cached_graph = tmp
+
def hybrid_forward(self, F, x, *args, **kwargs):
raise NotImplementedError
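Because hybridize() now clears any existing cached op, new flags take effect on the next forward call; re-hybridizing rebuilds the cache. For example:

    import mxnet as mx

    net = mx.gluon.nn.HybridSequential()
    with net.name_scope():
        net.add(mx.gluon.nn.Dense(10))
    net.initialize()
    net.hybridize(inline_limit=0)    # flags propagate to every hybrid child
    net(mx.nd.zeros((1, 4)))
    net.hybridize(inline_limit=3)    # cache is cleared and rebuilt
    net(mx.nd.zeros((1, 4)))

The SymbolBlock override keeps its imported graph across the cache reset, since that graph is the block's definition rather than a derived cache.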
diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py
index c0b4b52..ab5d5e1 100644
--- a/python/mxnet/gluon/nn/basic_layers.py
+++ b/python/mxnet/gluon/nn/basic_layers.py
@@ -68,7 +68,7 @@ class Sequential(Block):
def __len__(self):
return len(self._children)
- def hybridize(self, active=True):
+ def hybridize(self, active=True, **kwargs):
"""Activates or deactivates `HybridBlock`s recursively. Has no effect
on
non-hybrid children.
@@ -76,11 +76,13 @@ class Sequential(Block):
----------
active : bool, default True
Whether to turn hybrid on or off.
+ **kwargs : string
+ Additional flags for the hybridized operator.
"""
if self._children and all(isinstance(c, HybridBlock) for c in self._children):
warnings.warn('All children of this Sequential layer are HybridBlocks. Consider ' \
'using HybridSequential for the best performance.')
- super(Sequential, self).hybridize(active)
+ super(Sequential, self).hybridize(active, **kwargs)
class HybridSequential(HybridBlock):
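Note that a plain Sequential only forwards the flags to its children and still warns when every child is a HybridBlock:

    import mxnet as mx

    net = mx.gluon.nn.Sequential()
    with net.name_scope():
        net.add(mx.gluon.nn.Dense(10))
    net.hybridize(inline_limit=2)    # warns: consider HybridSequential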
diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc
index 2c4a305..51f30e2 100644
--- a/src/c_api/c_api_ndarray.cc
+++ b/src/c_api/c_api_ndarray.cc
@@ -157,7 +157,25 @@ int MXCreateCachedOp(SymbolHandle handle,
API_BEGIN();
*out = new std::shared_ptr<Imperative::CachedOp>(
- new Imperative::CachedOp(*sym));
+ new Imperative::CachedOp(
+ *sym, std::vector<std::pair<std::string, std::string> >()));
+ API_END();
+}
+
+int MXCreateCachedOpEx(SymbolHandle handle,
+ int num_params,
+ const char** keys,
+ const char** vals,
+ CachedOpHandle *out) {
+ nnvm::Symbol* sym = static_cast<nnvm::Symbol*>(handle);
+
+ API_BEGIN();
+ std::vector<std::pair<std::string, std::string> > kwargs;
+ for (int i = 0; i < num_params; ++i) {
+ kwargs.push_back({keys[i], vals[i]});
+ }
+ *out = new std::shared_ptr<Imperative::CachedOp>(
+ new Imperative::CachedOp(*sym, kwargs));
API_END();
}
@@ -198,16 +216,7 @@ int MXInvokeCachedOp(CachedOpHandle handle,
}
}
- OpStatePtr state = op->Forward(ndinputs, ndoutputs);
- if (Imperative::Get()->is_recording()) {
- nnvm::NodeAttrs attrs;
- attrs.op = cached_op;
- attrs.name = "_cachedop";
- attrs.parsed = op;
- Imperative::Get()->RecordOp(
- std::move(attrs), ndinputs, ndoutputs, state,
- &op->save_inputs(), &op->save_outputs());
- }
+ op->Forward(op, ndinputs, ndoutputs);
if (*outputs == nullptr) {
ret->ret_handles.clear();
diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index 5717327..eaa95a5 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -22,12 +22,18 @@
namespace mxnet {
-Imperative::CachedOp::CachedOp(const nnvm::Symbol& sym) {
+DMLC_REGISTER_PARAMETER(CachedOpParam);
+
+Imperative::CachedOp::CachedOp(
+ const nnvm::Symbol& sym,
+ const std::vector<std::pair<std::string, std::string> >& kwargs) {
using namespace nnvm;
using namespace imperative;
static const std::vector<const Op*> zero_ops{Op::Get("zeros_like"),
Op::Get("_zeros")};
static const auto _copy = Op::Get("_copy");
+ param_.Init(kwargs);
+
// construct forward graph
{
NodeEntryMap<int> dedup_out;
@@ -59,6 +65,8 @@ Imperative::CachedOp::CachedOp(const nnvm::Symbol& sym) {
fwd_graph_.attrs["forward_ref_count"] =
std::make_shared<dmlc::any>(std::move(ref_count));
+
+  inlining_ = (idx.num_nodes() - idx.input_nodes().size()) <= param_.inline_limit;
}
// construct backward graph
@@ -321,13 +329,16 @@ nnvm::Graph Imperative::CachedOp::GetBackwardGraph(
return g;
}
-OpStatePtr Imperative::CachedOp::Forward(const std::vector<NDArray*>& inputs,
- const std::vector<NDArray*>& outputs) {
+void Imperative::CachedOp::Forward(
+ const std::shared_ptr<CachedOp>& op_ptr,
+ const std::vector<NDArray*>& inputs,
+ const std::vector<NDArray*>& outputs) {
using namespace nnvm;
using namespace imperative;
+ static const auto cached_op = nnvm::Op::Get("_CachedOp");
- bool recording = Imperative::Get()->set_is_recording(false);
// Initialize
+ bool recording = Imperative::Get()->is_recording();
nnvm::Graph g = GetForwardGraph(recording, inputs);
const auto& idx = g.indexed_graph();
size_t num_inputs = idx.input_nodes().size();
@@ -381,10 +392,16 @@ OpStatePtr Imperative::CachedOp::Forward(const std::vector<NDArray*>& inputs,
const auto& dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");
+ if (recording && !inlining_) Imperative::Get()->set_is_recording(false);
+ int prev_bulk_size = Engine::Get()->set_bulk_size(param_.forward_bulk_size);
+
Imperative::Get()->RunGraph(
false, idx, arrays, 0, idx.num_nodes(), std::move(array_reqs),
std::move(ref_count), &states, dispatch_modes);
+ Engine::Get()->set_bulk_size(prev_bulk_size);
+ Imperative::Get()->set_is_recording(recording);
+
for (size_t i = 0; i < idx.num_node_entries(); ++i) {
if (arrays[i] == &buff[i]) continue;
buff[i].shape_ = arrays[i]->shape_;
@@ -392,9 +409,15 @@ OpStatePtr Imperative::CachedOp::Forward(const std::vector<NDArray*>& inputs,
buff[i].storage_type_ = arrays[i]->storage_type_;
}
- Imperative::Get()->set_is_recording(recording);
-
- return op_state_ptr;
+ if (recording && !inlining_) {
+ nnvm::NodeAttrs attrs;
+ attrs.op = cached_op;
+ attrs.name = "_cachedop";
+ attrs.parsed = op_ptr;
+ Imperative::Get()->RecordOp(
+ std::move(attrs), inputs, outputs, op_state_ptr,
+ &save_inputs(), &save_outputs());
+ }
}
@@ -452,10 +475,14 @@ void Imperative::CachedOp::Backward(
const auto& dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");
+ int prev_bulk_size = Engine::Get()->set_bulk_size(param_.backward_bulk_size);
+
Imperative::Get()->RunGraph(
retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(),
std::move(array_reqs), std::move(ref_count), &states, dispatch_modes);
+ Engine::Get()->set_bulk_size(prev_bulk_size);
+
if (retain_graph) {
buff.resize(num_forward_entries);
} else {
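The inlining decision is observable in the recorded graph: when the forward graph has at most inline_limit operators, its nodes are recorded individually; otherwise a single _CachedOp node stands in for the whole subgraph. A sketch, assuming net is the hybridized three-layer network from test_inline below:

    with mx.autograd.record():
        y = net(mx.nd.zeros((1, 10)))
    # inlined: the individual FullyConnected nodes appear in the trace;
    # not inlined: one "_CachedOp" node replaces them
    print(mx.autograd.get_symbol(y).tojson())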
diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc
index 361b971..fbbaf82 100644
--- a/src/imperative/imperative.cc
+++ b/src/imperative/imperative.cc
@@ -288,8 +288,6 @@ void Imperative::RunGraph(
DTypeVector arg_dtypes;
std::vector<OpReqType> req;
- int prev_bulk_size = Engine::Get()->set_bulk_size(10);
-
for (size_t i = node_start; i < node_end; ++i) {
const nnvm::IndexedGraph::Node& node = idx[i];
if (node.source->op() == nullptr) continue;
@@ -353,8 +351,6 @@ void Imperative::RunGraph(
if (ref_count[eid] == 0) arrays[eid]->ptr_.reset();
}
}
-
- Engine::Get()->set_bulk_size(prev_bulk_size);
}
@@ -367,6 +363,7 @@ std::vector<NDArray*> Imperative::Backward(
using namespace nnvm;
using namespace imperative;
static const std::vector<const Op*> zero_ops{Op::Get("zeros_like"),
Op::Get("_zeros")};
+ static const Op* copy_op = Op::Get("_copy");
// Construct forward graph
Graph graph;
@@ -439,7 +436,14 @@ std::vector<NDArray*> Imperative::Backward(
zero_ops, "_copy");
CHECK_EQ(g_graph.outputs.size(), xs.size());
for (const auto &e : g_graph.outputs) {
- graph.outputs.push_back(e);
+ if (e.node->op() == nullptr) {
+ auto node = Node::Create();
+ node->attrs.op = copy_op;
+ node->inputs.push_back(e);
+ graph.outputs.push_back(NodeEntry{node, 0, 0});
+ } else {
+ graph.outputs.push_back(e);
+ }
}
const auto& idx = graph.indexed_graph();
// get number of nodes used in forward pass
@@ -575,10 +579,12 @@ std::vector<NDArray*> Imperative::Backward(
bool prev_recording = set_is_recording(create_graph);
bool prev_training = set_is_training(is_train);
+ int prev_bulk_size = Engine::Get()->set_bulk_size(backward_bulk_size_);
RunGraph(retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(),
std::move(array_reqs), std::move(ref_count), &states, dispatch_modes);
+ Engine::Get()->set_bulk_size(prev_bulk_size);
set_is_recording(prev_recording);
set_is_training(prev_training);
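The inserted _copy handles gradient-graph outputs that are bare variables, which is exactly what happens when backward is called on a variable itself. A minimal sketch of the case this guards, assuming the default head gradient of ones:

    import mxnet as mx

    x = mx.nd.ones((3,))
    x.attach_grad()
    with mx.autograd.record():
        x.backward()     # gradient of x w.r.t. itself is the head gradient
    print(x.grad)        # expected: all ones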
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index df9f78e..c619056 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -23,6 +23,7 @@ import numpy as np
from nose.tools import raises
from copy import deepcopy
import warnings
+import json
def test_parameter():
@@ -256,7 +257,6 @@ def test_deconv():
# # check_layer_forward(layer, (1, 10, 10, 10, 4))
-
def test_pool():
layers1d = [
nn.MaxPool1D(),
@@ -611,6 +611,31 @@ def test_fill_shape_load():
assert net2[2].weight.shape[1] == 3072, net2[2].weight.shape[1]
+def test_inline():
+ net = mx.gluon.nn.HybridSequential()
+ with net.name_scope():
+ net.add(mx.gluon.nn.Dense(10))
+ net.add(mx.gluon.nn.Dense(10))
+ net.add(mx.gluon.nn.Dense(10))
+
+ net.initialize()
+ net.hybridize(inline_limit=3)
+ with mx.autograd.record():
+ y = net(mx.nd.zeros((1,10)))
+
+ len_1 = len(json.loads(mx.autograd.get_symbol(y).tojson())['nodes'])
+ y.backward()
+
+ net.hybridize(inline_limit=0)
+ with mx.autograd.record():
+ y = net(mx.nd.zeros((1,10)))
+
+ len_2 = len(json.loads(mx.autograd.get_symbol(y).tojson())['nodes'])
+ y.backward()
+
+ assert len_1 == len_2 + 2
+
+
if __name__ == '__main__':
import nose
nose.runmodule()
--
To stop receiving notification emails like this one, please contact
['"[email protected]" <[email protected]>'].