eric-haibin-lin closed pull request #10882: move exec.reshape to backend
URL: https://github.com/apache/incubator-mxnet/pull/10882
This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 9ac90d68c67..940c962ddba 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -1647,6 +1647,47 @@ MXNET_DLL int MXExecutorSimpleBind(SymbolHandle symbol_handle,
NDArrayHandle** aux_states,
ExecutorHandle shared_exec_handle,
ExecutorHandle* out);
+
+/*!
+ * \brief Return a new executor with the same symbol and shared memory,
+ * but different input/output shapes.
+ *
+ * \param partial_shaping Whether to allow changing the shape of unspecified arguments.
+ * \param allow_up_sizing Whether to allow allocating new ndarrays that's larger than the original.
+ * \param dev_type device type of default context
+ * \param dev_id device id of default context
+ * \param num_map_keys size of group2ctx map
+ * \param map_keys keys of group2ctx map
+ * \param map_dev_types device type of group2ctx map
+ * \param map_dev_ids device id of group2ctx map
+ * \param num_in_args length of in_args
+ * \param in_args in args array
+ * \param arg_grads arg grads handle array
+ * \param num_aux_states length of auxiliary states
+ * \param aux_states auxiliary states array
+ * \param shared_exec input executor handle for memory sharing
+ * \param out output executor handle
+ * \return a new executor
+ */
+MXNET_DLL int MXExecutorReshape(int partial_shaping,
+ int allow_up_sizing,
+ int dev_type,
+ int dev_id,
+ mx_uint num_map_keys,
+ const char** map_keys,
+ const int* map_dev_types,
+ const int* map_dev_ids,
+ const mx_uint num_provided_arg_shapes,
+ const char** provided_arg_shape_names,
+ const mx_uint* provided_arg_shape_data,
+ const mx_uint* provided_arg_shape_idx,
+ mx_uint* num_in_args,
+ NDArrayHandle** in_args,
+ NDArrayHandle** arg_grads,
+ mx_uint* num_aux_states,
+ NDArrayHandle** aux_states,
+ ExecutorHandle shared_exec,
+ ExecutorHandle *out);
/*!
* \brief set a call back to notify the completion of operation
*/
diff --git a/include/mxnet/executor.h b/include/mxnet/executor.h
index d749100f5de..842653f8653 100644
--- a/include/mxnet/executor.h
+++ b/include/mxnet/executor.h
@@ -103,6 +103,29 @@ class Executor {
* \return aux state map in the executor.
*/
virtual const std::unordered_map<std::string, NDArray>& aux_state_map() const = 0;
+ /*!
+ * \brief Return a new executor with the same symbol and shared memory,
+ * but different input/output shapes.
+ *
+ * \param partial_shaping Whether to allow changing the shape of unspecified arguments.
+ * \param allow_up_sizing Whether to allow allocating new ndarrays that's larger than the original.
+ * \param default_ctx the default context of binding.
+ * \param ctx_map Context mapping group to context.
+ * \param provided_arg_shapes New shape for arguments.
+ * \param in_args the NDArray that stores the input arguments.
+ * \param arg_grads NDArray that is used to store the gradient output of the input arguments.
+ * \param aux_states NDArray that is used as internal states.
+ * \return a new executor.
+ */
+ virtual Executor* Reshape(const bool partial_shaping,
+ const bool allow_up_sizing,
+ const Context& default_ctx,
+ const std::map<std::string, Context>& ctx_map,
+ const std::unordered_map<std::string, TShape>&
+ provided_arg_shapes,
+ std::vector<NDArray>* in_args,
+ std::vector<NDArray>* arg_grads,
+ std::vector<NDArray>* aux_states) = 0;
/*!
* \brief Create an operator by bind symbol with context and arguments.
* If user do not want to compute the gradients of i-th argument, grad_req_type[i] can be kNullOp.
diff --git a/python/mxnet/executor.py b/python/mxnet/executor.py
index 579e6d3e35b..c0272c5bb43 100644
--- a/python/mxnet/executor.py
+++ b/python/mxnet/executor.py
@@ -20,15 +20,15 @@
"""Symbolic Executor component of MXNet."""
from __future__ import absolute_import
+from array import array as py_array
import ctypes
import copy
import numpy as np
from .base import _LIB
-from .base import mx_uint, NDArrayHandle, ExecutorHandle
-from .base import check_call, c_handle_array, py_str
+from .base import mx_uint, NDArrayHandle, ExecutorHandle, py_str
+from .base import check_call, c_handle_array, c_array_buf, c_str_array
from .ndarray import NDArray
from .ndarray import _ndarray_cls
-from . import ndarray as nd
# those functions are not used here, we just import them to keep backward compatibility
# in case the end user calls them, as they originally lives here
@@ -399,62 +399,73 @@ def reshape(self, partial_shaping=False, allow_up_sizing=False, **kwargs):
>>> texec.reshape(allow_up_sizing=True, **new_shape)
"""
# pylint: disable=too-many-branches
- arg_shapes, _, aux_shapes = self._symbol.infer_shape(**kwargs)
- if arg_shapes is None:
- raise ValueError("Insufficient argument shapes provided.")
-
- new_arg_dict = {}
- new_grad_dict = {}
- for i, name in enumerate(self._symbol.list_arguments()):
- new_shape = arg_shapes[i]
- arr = self.arg_arrays[i]
- darr = None if self.grad_arrays is None else self.grad_arrays[i]
- if partial_shaping or name in kwargs or new_shape == arr.shape:
- if np.prod(new_shape) > np.prod(arr.shape):
- assert allow_up_sizing, "New shape of arg:%s larger than original. "%name + \
- "First making a big executor and then down sizing it " + \
- "is more efficient than the reverse." + \
- "If you really want to up size, set allow_up_sizing=True " + \
- "to enable allocation of new arrays."
- new_arg_dict[name] = nd.empty(new_shape, ctx=arr.context, dtype=arr.dtype)
- if darr is not None:
- new_grad_dict[name] = nd.empty(new_shape, ctx=darr.context, dtype=arr.dtype)
- else:
- new_arg_dict[name] = arr.reshape(new_shape)
- if darr is not None:
- new_grad_dict[name] = darr.reshape(new_shape)
- else:
- raise AssertionError("Shape of unspecified array arg:%s
changed. "%name + \
- "This can cause the new executor to not share parameters "
+ \
- "with the old one. Please check for error in network." +\
- "If this is intended, set partial_shaping=True to suppress
this warning.")
-
- new_aux_dict = {}
- for name, new_shape, arr in zip(self._symbol.list_auxiliary_states(),
- aux_shapes, self.aux_arrays):
- if partial_shaping or new_shape == arr.shape:
- if np.prod(new_shape) > np.prod(arr.shape):
- assert allow_up_sizing, "New shape of arg:%s larger than original. "%name + \
- "First making a big executor and then down sizing it " + \
- "is more efficient than the reverse." + \
- "If you really want to up size, set allow_up_sizing=True " + \
- "to enable allocation of new arrays."
- new_aux_dict[name] = nd.empty(new_shape, ctx=arr.context, dtype=arr.dtype)
- else:
- new_aux_dict[name] = arr.reshape(new_shape)
- else:
- raise AssertionError("Shape of unspecified array aux:%s
changed. "%name + \
- "This can cause the new executor to not share parameters "
+ \
- "with the old one. Please check for error in network." +\
- "If this is intended, set partial_shaping=True to suppress
this warning.")
-
- return self._symbol.bind(self._ctx,
- args=new_arg_dict,
- args_grad=new_grad_dict,
- grad_req=self._grad_req,
- aux_states=new_aux_dict,
- group2ctx=self._group2ctx,
- shared_exec=self)
+ provided_arg_shape_data = [] # shape data
+ # argument shape index in sdata,
+ # e.g. [sdata[indptr[0]], sdata[indptr[1]]) is the shape of the first arg
+ provided_arg_shape_idx = [0]
+ provided_arg_shape_names = [] # provided argument names
+ for k, v in kwargs.items():
+ if isinstance(v, tuple):
+ provided_arg_shape_names.append(k)
+ provided_arg_shape_data.extend(v)
+ provided_arg_shape_idx.append(len(provided_arg_shape_data))
+
+ ctx_map_keys = []
+ ctx_map_dev_types = []
+ ctx_map_dev_ids = []
+
+ if self._group2ctx:
+ for key, val in self._group2ctx.items():
+ ctx_map_keys.append(key)
+ ctx_map_dev_types.append(val.device_typeid)
+ ctx_map_dev_ids.append(val.device_id)
+
+ handle = ExecutorHandle()
+ shared_handle = self.handle
+
+ num_in_args = ctypes.c_uint()
+ in_arg_handles = ctypes.POINTER(NDArrayHandle)()
+ arg_grad_handles = ctypes.POINTER(NDArrayHandle)()
+ num_aux_states = ctypes.c_uint()
+ aux_state_handles = ctypes.POINTER(NDArrayHandle)()
+
+ check_call(_LIB.MXExecutorReshape(ctypes.c_int(int(partial_shaping)),
+ ctypes.c_int(int(allow_up_sizing)),
+ ctypes.c_int(self._ctx.device_typeid),
+ ctypes.c_int(self._ctx.device_id),
+ mx_uint(len(ctx_map_keys)),
+ c_str_array(ctx_map_keys),
+ c_array_buf(ctypes.c_int,
+ py_array('i', ctx_map_dev_types)),
+ c_array_buf(ctypes.c_int,
+ py_array('i', ctx_map_dev_ids)),
+ mx_uint(len(provided_arg_shape_names)),
+ c_str_array(provided_arg_shape_names),
+ c_array_buf(mx_uint,
+ py_array('I', provided_arg_shape_data)),
+ c_array_buf(mx_uint,
+ py_array('I', provided_arg_shape_idx)),
+ ctypes.byref(num_in_args),
+ ctypes.byref(in_arg_handles),
+ ctypes.byref(arg_grad_handles),
+ ctypes.byref(num_aux_states),
+ ctypes.byref(aux_state_handles),
+ shared_handle,
+ ctypes.byref(handle)))
+
+ arg_arrays = [_ndarray_cls(NDArrayHandle(in_arg_handles[i]))
+ for i in range(num_in_args.value)]
+ grad_arrays = [_ndarray_cls(NDArrayHandle(arg_grad_handles[i]))
+ if arg_grad_handles[i] is not None
+ else None for i in range(num_in_args.value)]
+ aux_arrays = [_ndarray_cls(NDArrayHandle(aux_state_handles[i]))
+ for i in range(num_aux_states.value)]
+
+ executor = Executor(handle, self._symbol, self._ctx, self._grad_req, self._group2ctx)
+ executor.arg_arrays = arg_arrays
+ executor.grad_arrays = grad_arrays
+ executor.aux_arrays = aux_arrays
+ return executor
def debug_str(self):
"""Get a debug string about internal execution plan.
diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index 1ab7cf87bf5..732c1a31f6a 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -1259,9 +1259,12 @@ def _get_ndarray_inputs(arg_key, args, arg_names, allow_missing):
if len(args) != len(arg_names):
raise ValueError('Length of %s does not match the number of arguments' % arg_key)
for narr in args:
- if not isinstance(narr, NDArray):
+ if narr is None and allow_missing:
+ arg_handles.append(None)
+ elif not isinstance(narr, NDArray):
raise TypeError('Only accept list of NDArrays or dict of str to NDArray')
- arg_handles.append(narr.handle)
+ else:
+ arg_handles.append(narr.handle)
arg_arrays = args
elif isinstance(args, dict):
for name in arg_names:
diff --git a/src/c_api/c_api_executor.cc b/src/c_api/c_api_executor.cc
index 40df49144fa..09bc23934e5 100644
--- a/src/c_api/c_api_executor.cc
+++ b/src/c_api/c_api_executor.cc
@@ -510,6 +510,93 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle,
API_END();
}
+int MXExecutorReshape(int partial_shaping,
+ int allow_up_sizing,
+ int dev_type,
+ int dev_id,
+ mx_uint num_map_keys,
+ const char** map_keys,
+ const int* map_dev_types,
+ const int* map_dev_ids,
+ const mx_uint num_provided_arg_shapes,
+ const char** provided_arg_shape_names,
+ const mx_uint* provided_arg_shape_data,
+ const mx_uint* provided_arg_shape_idx,
+ mx_uint* num_in_args,
+ NDArrayHandle** in_args,
+ NDArrayHandle** arg_grads,
+ mx_uint* num_aux_states,
+ NDArrayHandle** aux_states,
+ ExecutorHandle shared_exec,
+ ExecutorHandle *out) {
+ MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
+ API_BEGIN();
+ // create shape map for in_args and aux_states
+ std::unordered_map<std::string, TShape> kwargs(num_provided_arg_shapes);
+ for (mx_uint i = 0; i < num_provided_arg_shapes; ++i) {
+ auto p = kwargs.emplace(provided_arg_shape_names[i],
+ TShape(provided_arg_shape_data+provided_arg_shape_idx[i],
+ provided_arg_shape_data+provided_arg_shape_idx[i+1]));
+ CHECK(p.second) << "Duplicate shapes are provided for argument "
+ << provided_arg_shape_names[i] << " in reshape of executor";
+ }
+
+ Context ctx = Context::Create(static_cast<Context::DeviceType>(dev_type), dev_id);
+ std::map<std::string, Context> ctx_map;
+ for (mx_uint i = 0; i < num_map_keys; ++i) {
+ ctx_map[std::string(map_keys[i])] = Context::Create(
+ static_cast<Context::DeviceType>(map_dev_types[i]), map_dev_ids[i]);
+ }
+ std::vector<NDArray> in_arg_vec;
+ std::vector<NDArray> arg_grad_vec;
+ std::vector<NDArray> aux_state_vec;
+
+ Executor* exec = static_cast<Executor*>(shared_exec);
+ *out = exec->Reshape(partial_shaping, allow_up_sizing, ctx, ctx_map, kwargs,
+ &in_arg_vec, &arg_grad_vec, &aux_state_vec);
+
+ ret->ret_handles.clear();
+
ret->ret_handles.reserve(in_arg_vec.size()+arg_grad_vec.size()+aux_state_vec.size());
+
+ size_t nd_idx = 0;
+ for (const auto& nd : in_arg_vec) {
+ if (nd.is_none()) {
+ LOG(FATAL) << "Input argument NDArray cannot be un-allocated";
+ }
+ ret->ret_handles.push_back(new NDArray(nd));
+ }
+ if (in_arg_vec.size() > 0) {
+ *num_in_args = in_arg_vec.size();
+ *in_args = &(ret->ret_handles[nd_idx]);
+ nd_idx = ret->ret_handles.size();
+ }
+
+ for (const auto& nd : arg_grad_vec) {
+ if (nd.is_none()) {
+ ret->ret_handles.push_back(nullptr);
+ } else {
+ ret->ret_handles.push_back(new NDArray(nd));
+ }
+ }
+ if (arg_grad_vec.size() > 0) {
+ *arg_grads = &(ret->ret_handles[nd_idx]);
+ nd_idx = ret->ret_handles.size();
+ }
+
+ for (const auto& nd : aux_state_vec) {
+ if (nd.is_none()) {
+ LOG(FATAL) << "Auxiliary argument NDArray cannot be un-allocated";
+ }
+ ret->ret_handles.push_back(new NDArray(nd));
+ }
+ if (aux_state_vec.size() > 0) {
+ *num_aux_states = aux_state_vec.size();
+ *aux_states = &(ret->ret_handles[nd_idx]);
+ nd_idx = ret->ret_handles.size();
+ }
+ API_END_HANDLE_ERROR(delete out);
+}
+
int MXExecutorSetMonitorCallback(ExecutorHandle handle,
ExecutorMonitorCallback callback,
void* callback_handle) {
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index 7a15f6c931c..e28867d5488 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -1043,6 +1043,117 @@ void GraphExecutor::Init(nnvm::Symbol symbol,
FinishInitGraph(symbol, g, shared_exec, feed_dict);
}
+/*!
+ * \brief Return a new executor with the same symbol and shared memory,
+ * but different input/output shapes.
+ * For runtime reshaping, variable length sequences, etc.
+ * The returned executor shares state with the current one,
+ * and cannot be used in parallel with it.
+ */
+Executor* GraphExecutor::Reshape(const bool partial_shaping,
+ const bool allow_up_sizing,
+ const Context& default_ctx,
+ const std::map<std::string, Context>& ctx_map,
+ const std::unordered_map<std::string, TShape>&
+ provided_arg_shapes,
+ std::vector<NDArray>* in_args,
+ std::vector<NDArray>* arg_grads,
+ std::vector<NDArray>* aux_states) {
+ nnvm::Graph g;
+ g.outputs = std::vector<nnvm::NodeEntry>(graph_.outputs.begin(),
+ graph_.outputs.begin() + num_forward_outputs_);
+ nnvm::Symbol symbol;
+ symbol.outputs = g.outputs;
+ const nnvm::IndexedGraph& idx = g.indexed_graph();
+ nnvm::ShapeVector arg_shapes(idx.input_nodes().size(), TShape());
+ for (size_t i = 0; i < num_forward_inputs_; ++i) {
+ const uint32_t nid = idx.input_nodes().at(i);
+ const std::string& name = idx[nid].source->attrs.name;
+ auto it = provided_arg_shapes.find(name);
+ if (provided_arg_shapes.end() != it) {
+ arg_shapes[i] = it->second;
+ }
+ }
+ g = InferShape(std::move(g), std::move(arg_shapes), "__shape__");
+ if (g.GetAttr<size_t>("shape_num_unknown_nodes") != 0U) {
+ HandleInferShapeError(num_forward_inputs_, g.indexed_graph(),
+ g.GetAttr<nnvm::ShapeVector>("shape"));
+ }
+ const nnvm::ShapeVector& shape_vec = g.GetAttr<nnvm::ShapeVector>("shape");
+ std::vector<OpReqType> grad_req_types;
+ size_t grad_top = 0;
+ const size_t num_args = in_arg_map_.size();
+ const size_t num_aux = aux_state_map_.size();
+ in_args->reserve(num_args);
+ grad_req_types.reserve(num_args);
+ arg_grads->reserve(num_args);
+ aux_states->reserve(num_aux);
+ for (uint32_t nid : idx.input_nodes()) {
+ std::string name = idx[nid].source->attrs.name;
+ const TShape& new_shape = shape_vec[idx.entry_id(nid, 0)];
+ if (idx.mutable_input_nodes().count(nid) == 0) {
+ NDArray& arr = in_arg_map_.at(name);
+ auto it = arg_grad_map_.find(name);
+ if (partial_shaping || provided_arg_shapes.count(name) || new_shape == arr.shape()) {
+ if (new_shape.Size() > arr.shape().Size()) {
+ CHECK(allow_up_sizing) << "New shape of arg: " << name << " is larger than original."
+ << "First making a big executor and then down sizing it "
+ << "is more efficient than the reverse."
+ << "If you really want to up size, set allow_up_sizing=True "
+ << "to enable allocation of new arrays.";
+ in_args->emplace_back(new_shape, arr.ctx(), false, arr.dtype());
+ if (it != arg_grad_map_.end()) {
+ NDArray& darr = it->second;
+ arg_grads->emplace_back(new_shape, darr.ctx(), false, darr.dtype());
+ grad_req_types.push_back(grad_store_.at(grad_top++).first);
+ } else {
+ arg_grads->emplace_back();
+ grad_req_types.push_back(kNullOp);
+ }
+ } else {
+ in_args->push_back(arr.Reshape(new_shape));
+ if (it != arg_grad_map_.end()) {
+ NDArray& darr = it->second;
+ arg_grads->push_back(darr.Reshape(new_shape));
+ grad_req_types.push_back(grad_store_.at(grad_top++).first);
+ } else {
+ arg_grads->emplace_back();
+ grad_req_types.push_back(kNullOp);
+ }
+ }
+ } else {
+ LOG(FATAL) << "Shape of unspecifie arg: " << name << " changed. "
+ << "This can cause the new executor to not share parameters "
+ << "with the old one. Please check for error in network."
+ << "If this is intended, set partial_shaping=True to suppress this
warning.";
+ }
+ } else {
+ NDArray& arr = aux_state_map_.at(name);
+ if (partial_shaping || new_shape == arr.shape()) {
+ if (new_shape.Size() > arr.shape().Size()) {
+ CHECK(allow_up_sizing) << "New shape of arg: " << name << " is larger than original."
+ << "First making a big executor and then down sizing it "
+ << "is more efficient than the reverse."
+ << "If you really want to up size, set allow_up_sizing=True "
+ << "to enable allocation of new arrays.";
+ aux_states->emplace_back(new_shape, arr.ctx(), false, arr.dtype());
+ } else {
+ aux_states->push_back(arr.Reshape(new_shape));
+ }
+ } else {
+ LOG(FATAL) << "Shape of unspecifie arg: " << name << " changed. "
+ << "This can cause the new executor to not share parameters "
+ << "with the old one. Please check for error in network."
+ << "If this is intended, set partial_shaping=True to suppress this
warning.";
+ }
+ }
+ }
+ auto exec = new GraphExecutor();
+ exec->Init(symbol, default_ctx, ctx_map,
+ *in_args, *arg_grads, grad_req_types, *aux_states,
+ this);
+ return exec;
+}
/*!
* \brief This function is triggered by both simple_bind
* and bind flows.
diff --git a/src/executor/graph_executor.h b/src/executor/graph_executor.h
index bcde41d508e..24f98894912 100644
--- a/src/executor/graph_executor.h
+++ b/src/executor/graph_executor.h
@@ -107,6 +107,16 @@ class GraphExecutor : public Executor {
const nnvm::NodeEntryMap<NDArray>& feed_dict
= nnvm::NodeEntryMap<NDArray>());
+ Executor* Reshape(const bool partial_shaping,
+ const bool allow_up_sizing,
+ const Context& default_ctx,
+ const std::map<std::string, Context>& ctx_map,
+ const std::unordered_map<std::string, TShape>&
+ provided_arg_shapes,
+ std::vector<NDArray>* in_args,
+ std::vector<NDArray>* arg_grads,
+ std::vector<NDArray>* aux_states) override;
+
protected:
friend class mxnet::Imperative;
// Information about operational node
diff --git a/tests/python/unittest/test_executor.py b/tests/python/unittest/test_executor.py
index 45b9a099223..05e71b426eb 100644
--- a/tests/python/unittest/test_executor.py
+++ b/tests/python/unittest/test_executor.py
@@ -160,6 +160,14 @@ def test_reshape():
exe.forward(is_train=False)
assert np.all(exe.outputs[0].asnumpy() == 4)
+ # test sharing ndarray depending on new_shape
+ new_exe = exe.reshape(allow_up_sizing=True, x=(6,4))
+ # data ndarray is not shared between exe and new_exe
+ new_exe.arg_arrays[0][:] = 0
+ assert np.all(exe.arg_arrays[0].asnumpy() == 1)
+ # weight ndarray is shared between exe and new_exe
+ assert np.all(new_exe.arg_arrays[1].asnumpy() == 1)
+
if __name__ == "__main__":
import nose
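
For reference, the Python-side call pattern is unchanged by this move; Executor.reshape now just packs the provided shapes and the group2ctx map and delegates to the new MXExecutorReshape C API. Below is a minimal usage sketch mirroring the unit test above; the symbol and shapes are illustrative only, not code from this PR:

import mxnet as mx

x = mx.sym.Variable('x')
y = mx.sym.FullyConnected(x, num_hidden=4)

# bind with an initial batch size of 5
exe = y.simple_bind(mx.cpu(), x=(5, 4), grad_req='null')
for arr in exe.arg_arrays:
    arr[:] = 1

# reshape() now goes through MXExecutorReshape in the backend: the weight and
# bias keep their shapes and share memory with exe, while the up-sized data
# array is freshly allocated (hence allow_up_sizing=True)
new_exe = exe.reshape(allow_up_sizing=True, x=(6, 4))
new_exe.forward(is_train=False)
print(new_exe.outputs[0].shape)  # (6, 4)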
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services