This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 7a83a75 [MXNET-555] Add subgraph storage type inference to CachedOp (#11306)
7a83a75 is described below
commit 7a83a7589e46cae1ec98d43c71d001797ad57b08
Author: Haibin Lin <[email protected]>
AuthorDate: Wed Jun 20 17:31:35 2018 -0700
[MXNET-555] Add subgraph storage type inference to CachedOp (#11306)
* copy paste
* pass unit test
* remove lock
* save all inputs and outputs
* add one more test
* update test
* update backward stype inference
* + fwd inference
---
src/imperative/cached_op.cc | 147 +++++++++++++++++++++++++++++-------
src/imperative/cached_op.h | 24 ++++--
src/imperative/imperative_utils.h | 1 -
src/operator/operator_common.h | 4 +-
tests/python/unittest/test_gluon.py | 57 ++++++++++++++
5 files changed, 197 insertions(+), 36 deletions(-)
diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index c0e5e83..5a3d44c 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -22,6 +22,7 @@
#include "./cached_op.h"
#include "../executor/exec_pass.h"
#include "../profiler/profiler.h"
+#include "../operator/operator_common.h"
namespace mxnet {
@@ -95,7 +96,6 @@ CachedOp::CachedOp(
using namespace imperative;
static const std::vector<const Op*> zero_ops{Op::Get("zeros_like"),
Op::Get("_zeros")};
static const auto _copy = Op::Get("_copy");
-
config_.Init(flags);
if (config_.static_shape) {
@@ -204,26 +204,17 @@ CachedOp::CachedOp(
size_t num_forward_outputs = num_outputs();
for (uint32_t i = 0; i < ograd_entries_.size(); ++i) {
if (!idx.exist(ograd_entries_[i].node.get())) continue;
- auto eid = idx.entry_id(ograd_entries_[i]);
- if (ref_count[eid] > 0) {
- bwd_ograd_dep_.push_back(i);
- }
+ bwd_ograd_dep_.push_back(i);
}
save_inputs_.resize(num_forward_inputs, false);
for (uint32_t i = 0; i < num_forward_inputs; ++i) {
- auto eid = idx.entry_id(idx.input_nodes()[i], 0);
- if (ref_count[eid] > 0) {
- save_inputs_[i] = true;
- bwd_in_dep_.push_back(i);
- }
+ save_inputs_[i] = true;
+ bwd_in_dep_.push_back(i);
}
save_outputs_.resize(idx.outputs().size(), false);
for (uint32_t i = 0; i < num_forward_outputs; ++i) {
- auto eid = idx.entry_id(idx.outputs()[i]);
- if (ref_count[eid] > 0) {
- save_outputs_[i] = true;
- bwd_out_dep_.push_back(i);
- }
+ save_outputs_[i] = true;
+ bwd_out_dep_.push_back(i);
}
}
}
@@ -233,7 +224,7 @@ CachedOp::~CachedOp() {
std::vector<nnvm::NodeEntry> CachedOp::Gradient(
const nnvm::NodePtr& node,
- const std::vector<nnvm::NodeEntry>& ograds) {
+ const std::vector<nnvm::NodeEntry>& ograds) const {
using namespace nnvm;
static const auto _backward_CachedOp = Op::Get("_backward_CachedOp");
static const auto _NoGrad = Op::Get("_NoGradient");
@@ -328,6 +319,27 @@ bool CachedOp::SetForwardGraph(
return false;
}
+// Utility function to set backward input eids
+void SetBackwardInputEid(const std::vector<uint32_t>& bwd_in_dep,
+ const std::vector<uint32_t>& bwd_out_dep,
+ const std::vector<uint32_t>& bwd_ograd_dep,
+ const std::vector<nnvm::NodeEntry>& ograd_entries,
+ const nnvm::IndexedGraph& idx,
+ std::vector<uint32_t> *bwd_input_eid) {
+ for (const auto& i : bwd_ograd_dep) {
+ auto eid = idx.entry_id(ograd_entries[i]);
+ bwd_input_eid->push_back(eid);
+ }
+ for (const auto& i : bwd_in_dep) {
+ auto eid = idx.entry_id(idx.input_nodes()[i], 0);
+ bwd_input_eid->push_back(eid);
+ }
+ for (const auto& i : bwd_out_dep) {
+ auto eid = idx.entry_id(idx.outputs()[i]);
+ bwd_input_eid->push_back(eid);
+ }
+}
+
bool CachedOp::SetBackwardGraph(
GraphInfo* info,
const std::vector<OpReqType>& reqs,
@@ -356,18 +368,8 @@ bool CachedOp::SetBackwardGraph(
if (info->bwd_input_eid.size() != inputs.size()) {
info->bwd_input_eid.clear();
- for (const auto& i : bwd_ograd_dep_) {
- auto eid = idx.entry_id(ograd_entries_[i]);
- info->bwd_input_eid.push_back(eid);
- }
- for (const auto& i : bwd_in_dep_) {
- auto eid = idx.entry_id(idx.input_nodes()[i], 0);
- info->bwd_input_eid.push_back(eid);
- }
- for (const auto& i : bwd_out_dep_) {
- auto eid = idx.entry_id(idx.outputs()[i]);
- info->bwd_input_eid.push_back(eid);
- }
+ SetBackwardInputEid(bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_,
+ ograd_entries_, idx, &info->bwd_input_eid);
CHECK_EQ(inputs.size(), info->bwd_input_eid.size());
}
@@ -1019,6 +1021,79 @@ void CachedOp::Backward(
Engine::Get()->set_bulk_size(prev_bulk_size);
}
+bool CachedOp::ForwardStorageType(const nnvm::NodeAttrs& attrs,
+ const int dev_mask,
+ DispatchMode* dispatch_mode,
+ std::vector<int> *in_attrs,
+ std::vector<int> *out_attrs) {
+ using namespace imperative;
+ nnvm::Graph g(fwd_graph_);
+ const auto& idx = g.indexed_graph();
+ const auto &outputs = idx.outputs();
+
+ // Prepare stypes and contexts based on inputs
+ StorageTypeVector storage_type_inputs;
+ storage_type_inputs.reserve(in_attrs->size());
+ for (size_t i = 0; i < in_attrs->size(); ++i) {
+ storage_type_inputs.emplace_back(in_attrs->at(i));
+ }
+ exec::DevMaskVector dev_masks(idx.num_nodes(), dev_mask);
+
+ // Forward graph storage type inference
+  CheckAndInferStorageType(&g, std::move(dev_masks), std::move(storage_type_inputs), true);
+ // Retrieve result and set outputs
+ const auto& inferred_stypes = g.GetAttr<StorageTypeVector>("storage_type");
+ for (size_t i = 0; i < out_attrs->size(); i++) {
+ const auto eid = idx.entry_id(outputs[i]);
+ STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, i, inferred_stypes[eid]);
+ }
+ DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
+ return true;
+}
+
+bool CachedOp::BackwardStorageType(const nnvm::NodeAttrs& attrs,
+ const int dev_mask,
+ DispatchMode* dispatch_mode,
+ std::vector<int> *in_attrs,
+ std::vector<int> *out_attrs) {
+ using namespace imperative;
+ nnvm::Graph g(full_graph_);
+ const auto& idx = g.indexed_graph();
+ const auto &outputs = idx.outputs();
+ const size_t num_forward_outputs = fwd_graph_.outputs.size();
+ CHECK_EQ(outputs.size(), num_forward_outputs + out_attrs->size());
+
+ // Construct bwd_input_eid
+ std::vector<uint32_t> bwd_input_eid;
+ SetBackwardInputEid(bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_,
+ ograd_entries_, idx, &bwd_input_eid);
+ CHECK_EQ(in_attrs->size(), bwd_input_eid.size());
+
+ // Prepare stypes and contexts based on inputs
+ StorageTypeVector stypes(idx.num_node_entries(), -1);
+ for (size_t i = 0; i < in_attrs->size(); ++i) {
+ stypes[bwd_input_eid[i]] = in_attrs->at(i);
+ }
+  // Some out_attr is known ahead of time (e.g. the grad stype is given by users).
+ // Prepare these to before invoking infer storage on the subgraph
+ for (size_t i = 0; i < out_attrs->size(); i++) {
+ const auto eid = idx.entry_id(outputs[i + num_forward_outputs]);
+ stypes[eid] = out_attrs->at(i);
+ }
+ exec::DevMaskVector dev_masks(idx.num_nodes(), dev_mask);
+
+ // Full graph storage type inference
+ CheckAndInferStorageType(&g, std::move(dev_masks), std::move(stypes), false);
+ // Retrieve result and set outputs
+ const auto& inferred_stypes = g.GetAttr<StorageTypeVector>("storage_type");
+ for (size_t i = 0; i < out_attrs->size(); i++) {
+ const auto eid = idx.entry_id(outputs[i + num_forward_outputs]);
+ STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, i, inferred_stypes[eid]);
+ }
+ DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
+ return true;
+}
+
NNVM_REGISTER_OP(_CachedOp)
.set_num_inputs([](const NodeAttrs& attrs) {
@@ -1029,6 +1104,14 @@ NNVM_REGISTER_OP(_CachedOp)
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
return op->num_outputs();
})
+.set_attr<FInferStorageType>("FInferStorageType", [](const nnvm::NodeAttrs& attrs,
+                                                     const int dev_mask,
+                                                     DispatchMode* dispatch_mode,
+                                                     std::vector<int> *in_attrs,
+                                                     std::vector<int> *out_attrs) {
+  const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
+  return op->ForwardStorageType(attrs, dev_mask, dispatch_mode, in_attrs, out_attrs);
+  })
.set_attr<nnvm::FGradient>("FGradient",
[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(n->attrs.parsed);
@@ -1044,6 +1127,14 @@ NNVM_REGISTER_OP(_backward_CachedOp)
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
return op->num_inputs() - op->mutable_input_nodes().size();
})
+.set_attr<FInferStorageType>("FInferStorageType", [](const nnvm::NodeAttrs& attrs,
+                                                     const int dev_mask,
+                                                     DispatchMode* dispatch_mode,
+                                                     std::vector<int> *in_attrs,
+                                                     std::vector<int> *out_attrs) {
+  const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
+  return op->BackwardStorageType(attrs, dev_mask, dispatch_mode, in_attrs, out_attrs);
+  })
.set_attr<bool>("TIsLayerOpBackward", true)
.set_attr<bool>("TIsBackward", true);
diff --git a/src/imperative/cached_op.h b/src/imperative/cached_op.h
index 60a40c5..6b94c67 100644
--- a/src/imperative/cached_op.h
+++ b/src/imperative/cached_op.h
@@ -71,13 +71,13 @@ class CachedOp {
const nnvm::Symbol& sym,
const std::vector<std::pair<std::string, std::string> >& flags);
~CachedOp();
- uint32_t num_inputs() {
+ uint32_t num_inputs() const {
return fwd_graph_.indexed_graph().input_nodes().size();
}
- uint32_t num_outputs() {
+ uint32_t num_outputs() const {
return fwd_graph_.outputs.size();
}
- uint32_t num_backward_inputs() {
+ uint32_t num_backward_inputs() const {
return bwd_ograd_dep_.size() + bwd_in_dep_.size() + bwd_out_dep_.size();
}
std::vector<bool>& save_inputs() {
@@ -86,12 +86,12 @@ class CachedOp {
std::vector<bool>& save_outputs() {
return save_outputs_;
}
- const std::unordered_set<uint32_t>& mutable_input_nodes() {
+ const std::unordered_set<uint32_t>& mutable_input_nodes() const {
return fwd_graph_.indexed_graph().mutable_input_nodes();
}
std::vector<nnvm::NodeEntry> Gradient(
const nnvm::NodePtr& node,
- const std::vector<nnvm::NodeEntry>& ograds);
+ const std::vector<nnvm::NodeEntry>& ograds) const;
void Forward(
const std::shared_ptr<CachedOp>& op_ptr,
const std::vector<NDArray*>& inputs,
@@ -102,6 +102,20 @@ class CachedOp {
const std::vector<NDArray*>& inputs,
const std::vector<OpReqType>& reqs,
const std::vector<NDArray*>& outputs);
+ // forward storage type inference
+ bool ForwardStorageType(
+ const nnvm::NodeAttrs& attrs,
+ const int dev_mask,
+ DispatchMode* dispatch_mode,
+ std::vector<int> *in_attrs,
+ std::vector<int> *out_attrs);
+ // backward storage type inference
+ bool BackwardStorageType(
+ const nnvm::NodeAttrs& attrs,
+ const int dev_mask,
+ DispatchMode* dispatch_mode,
+ std::vector<int> *in_attrs,
+ std::vector<int> *out_attrs);
private:
struct GraphInfo;
diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h
index 726531d..faff5f1 100644
--- a/src/imperative/imperative_utils.h
+++ b/src/imperative/imperative_utils.h
@@ -668,7 +668,6 @@ inline bool CheckAndInferStorageType(nnvm::Graph* p_g, exec::DevMaskVector&& dev
     g.attrs["storage_type"] = std::make_shared<dmlc::any>(std::move(storage_types));
g = exec::InferStorageType(std::move(g));
}
-
CHECK_EQ(g.GetAttr<size_t>("storage_type_num_unknown_nodes"), 0U);
return false;
}
diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h
index 0a9cd08..02130eb 100644
--- a/src/operator/operator_common.h
+++ b/src/operator/operator_common.h
@@ -256,7 +256,7 @@ inline bool dispatch_mode_assign(DispatchMode *y, const DispatchMode& x) {
*/
#define STORAGE_TYPE_ASSIGN_CHECK(type_array, index, type) \
{ \
- if (!type_assign(&(type_array)[index], type)) { \
+ if (!::mxnet::op::type_assign(&(type_array)[index], type)) { \
std::ostringstream os; \
os << "Storage type inconsistent, Provided = " \
<< common::stype_string((type_array)[index]) << ',' \
@@ -274,7 +274,7 @@ inline bool dispatch_mode_assign(DispatchMode *y, const DispatchMode& x) {
*/
#define DISPATCH_MODE_ASSIGN_CHECK(type_array, index, type) \
{ \
- if (!dispatch_mode_assign(&(type_array)[index], type)) { \
+ if (!::mxnet::op::dispatch_mode_assign(&(type_array)[index], type)) { \
std::ostringstream os; \
os << "Dispatch mode inconsistent, Provided = " \
<< common::dispatch_mode_string((type_array)[index]) << ',' \
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index 6fafb36..cd3cc68 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -1293,6 +1293,63 @@ def test_legacy_save_params():
model.load_params('test.params', ctx=mx.cpu())
+@with_seed()
+def test_sparse_hybrid_block_grad():
+ class Embedding(mx.gluon.HybridBlock):
+ def __init__(self, num_tokens, embedding_size):
+ super(Embedding, self).__init__()
+ self.num_tokens = num_tokens
+
+ with self.name_scope():
+ self.embedding = mx.gluon.nn.Embedding(
+ num_tokens, embedding_size, sparse_grad=True)
+
+ def hybrid_forward(self, F, words):
+ emb = self.embedding(words)
+ return emb + F.ones_like(emb)
+
+ embedding = Embedding(20, 3)
+ embedding.initialize()
+ embedding.hybridize()
+
+ with mx.autograd.record():
+ emb0 = embedding(mx.nd.arange(10)).sum()
+ emb1 = embedding(mx.nd.arange(10)).sum()
+ loss = emb0 + emb1
+ loss.backward()
+ grad = embedding.embedding.weight.grad().asnumpy()
+ assert (grad[:10] == 2).all()
+ assert (grad[10:] == 0).all()
+
+@with_seed()
+def test_sparse_hybrid_block():
+ class Linear(mx.gluon.HybridBlock):
+ def __init__(self, units):
+ super(Linear, self).__init__()
+ with self.name_scope():
+ self.w = self.params.get('w', shape=(units, units))
+
+ def hybrid_forward(self, F, x, w):
+ return F.dot(x, w)
+
+ class SparseBlock(mx.gluon.HybridBlock):
+ def __init__(self, units):
+ super(SparseBlock, self).__init__()
+ with self.name_scope():
+ self.net = Linear(units)
+
+ def hybrid_forward(self, F, x):
+ return self.net(x) * x
+
+ block = SparseBlock(2)
+ block.initialize()
+ block.hybridize()
+ x = mx.nd.ones((2,2)).tostype('csr')
+ with mx.autograd.record():
+ z = block(x) + block(x)
+ z.backward()
+ assert (block.net.w.grad().asnumpy() == 4).all()
+
def test_hybrid_static_memory_recording():
net = gluon.model_zoo.vision.get_resnet(
1, 18, pretrained=True, ctx=mx.context.current_context())