This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 704d218  move exec.reshape to backend (#10882)
704d218 is described below

commit 704d218bfdbe5778173930bd39ed0ee4d730cfe5
Author: Ziyue Huang <[email protected]>
AuthorDate: Thu May 24 08:51:58 2018 +0800

    move exec.reshape to backend (#10882)
    
    * move exec.reshape to backend
    
    * fix lint
    
    * fix lint
    
    * fix Symbol._get_ndarray_inputs
    
    * update
    
    * update
    
    * move Reshape as a member function of Executor
    
    * address comments
---
 include/mxnet/c_api.h                  |  41 +++++++++++
 include/mxnet/executor.h               |  23 ++++++
 python/mxnet/executor.py               | 129 ++++++++++++++++++---------------
 python/mxnet/symbol/symbol.py          |   7 +-
 src/c_api/c_api_executor.cc            |  87 ++++++++++++++++++++++
 src/executor/graph_executor.cc         | 111 ++++++++++++++++++++++++++++
 src/executor/graph_executor.h          |  10 +++
 tests/python/unittest/test_executor.py |   8 ++
 8 files changed, 355 insertions(+), 61 deletions(-)

diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 06e39bf..be47c3c 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -1654,6 +1654,54 @@ MXNET_DLL int MXExecutorSimpleBind(SymbolHandle symbol_handle,
                                    NDArrayHandle** aux_states,
                                    ExecutorHandle shared_exec_handle,
                                    ExecutorHandle* out);
+
+/*!
+ * \brief Return a new executor with the same symbol and shared memory,
+ * but different input/output shapes.
+ *
+ * \param partial_shaping Whether to allow changing the shape of unspecified arguments.
+ * \param allow_up_sizing Whether to allow allocating NDArrays larger than the original ones.
+ * \param dev_type device type of default context
+ * \param dev_id device id of default context
+ * \param num_map_keys size of group2ctx map
+ * \param map_keys keys of group2ctx map
+ * \param map_dev_types device type of group2ctx map
+ * \param map_dev_ids device id of group2ctx map
+ * \param num_provided_arg_shapes number of arguments whose new shape is provided
+ * \param provided_arg_shape_names names of the arguments whose new shape is provided
+ * \param provided_arg_shape_data dimensions of all provided shapes, concatenated
+ * \param provided_arg_shape_idx offsets into provided_arg_shape_data (length
+ *        num_provided_arg_shapes + 1); the i-th provided shape is the range
+ *        [provided_arg_shape_idx[i], provided_arg_shape_idx[i+1])
+ * \param num_in_args output: length of in_args
+ * \param in_args output: in args array
+ * \param arg_grads output: arg grads handle array, same length as in_args;
+ *        an entry is NULL when the argument has no gradient
+ * \param num_aux_states output: length of auxiliary states
+ * \param aux_states output: auxiliary states array
+ * \param shared_exec executor to reshape; memory is shared with it
+ * \param out output executor handle
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXExecutorReshape(int partial_shaping,
+                                int allow_up_sizing,
+                                int dev_type,
+                                int dev_id,
+                                mx_uint num_map_keys,
+                                const char** map_keys,
+                                const int* map_dev_types,
+                                const int* map_dev_ids,
+                                const mx_uint num_provided_arg_shapes,
+                                const char** provided_arg_shape_names,
+                                const mx_uint* provided_arg_shape_data,
+                                const mx_uint* provided_arg_shape_idx,
+                                mx_uint* num_in_args,
+                                NDArrayHandle** in_args,
+                                NDArrayHandle** arg_grads,
+                                mx_uint* num_aux_states,
+                                NDArrayHandle** aux_states,
+                                ExecutorHandle shared_exec,
+                                ExecutorHandle *out);
 /*!
  * \brief set a call back to notify the completion of operation
  */
diff --git a/include/mxnet/executor.h b/include/mxnet/executor.h
index d749100..842653f 100644
--- a/include/mxnet/executor.h
+++ b/include/mxnet/executor.h
@@ -104,6 +104,31 @@ class Executor {
    */
   virtual const std::unordered_map<std::string, NDArray>& aux_state_map() const = 0;
   /*!
+   * \brief Return a new executor with the same symbol and shared memory,
+   *  but different input/output shapes.
+   *  The returned executor shares state with this one and cannot be used in parallel with it.
+   *
+   * \param partial_shaping Whether to allow changing the shape of unspecified arguments.
+   * \param allow_up_sizing Whether to allow allocating NDArrays larger than the original ones.
+   * \param default_ctx the default context of binding.
+   * \param ctx_map Context mapping group to context.
+   * \param provided_arg_shapes New shapes of arguments, keyed by argument name.
+   * \param in_args Output: NDArrays used as the input arguments of the new executor.
+   * \param arg_grads Output: gradient NDArrays of the input arguments, one per argument;
+   *        a none NDArray marks an argument without gradient.
+   * \param aux_states Output: NDArrays used as the auxiliary states of the new executor.
+   * \return a new executor; the caller takes ownership of it.
+   */
+  virtual Executor* Reshape(const bool partial_shaping,
+                            const bool allow_up_sizing,
+                            const Context& default_ctx,
+                            const std::map<std::string, Context>& ctx_map,
+                            const std::unordered_map<std::string, TShape>&
+                              provided_arg_shapes,
+                            std::vector<NDArray>* in_args,
+                            std::vector<NDArray>* arg_grads,
+                            std::vector<NDArray>* aux_states) = 0;
+  /*!
    * \brief Create an operator by bind symbol with context and arguments.
    *  If user do not want to compute the gradients of i-th argument, grad_req_type[i] can be kNullOp.
    *
diff --git a/python/mxnet/executor.py b/python/mxnet/executor.py
index 579e6d3..c0272c5 100644
--- a/python/mxnet/executor.py
+++ b/python/mxnet/executor.py
@@ -20,15 +20,15 @@
 """Symbolic Executor component of MXNet."""
 from __future__ import absolute_import
 
+from array import array as py_array
 import ctypes
 import copy
 import numpy as np
 from .base import _LIB
-from .base import mx_uint, NDArrayHandle, ExecutorHandle
-from .base import check_call, c_handle_array, py_str
+from .base import mx_uint, NDArrayHandle, ExecutorHandle, py_str
+from .base import check_call, c_handle_array, c_array_buf, c_str_array
 from .ndarray import NDArray
 from .ndarray import _ndarray_cls
-from . import ndarray as nd
 
 # those functions are not used here, we just import them to keep backward compatibility
 # in case the end user calls them, as they originally lives here
@@ -399,62 +399,73 @@ class Executor(object):
         >>> texec.reshape(allow_up_sizing=True, **new_shape)
         """
         # pylint: disable=too-many-branches
-        arg_shapes, _, aux_shapes = self._symbol.infer_shape(**kwargs)
-        if arg_shapes is None:
-            raise ValueError("Insufficient argument shapes provided.")
-
-        new_arg_dict = {}
-        new_grad_dict = {}
-        for i, name in enumerate(self._symbol.list_arguments()):
-            new_shape = arg_shapes[i]
-            arr = self.arg_arrays[i]
-            darr = None if self.grad_arrays is None else self.grad_arrays[i]
-            if partial_shaping or name in kwargs or new_shape == arr.shape:
-                if np.prod(new_shape) > np.prod(arr.shape):
-                    assert allow_up_sizing, "New shape of arg:%s larger than original. "%name + \
-                        "First making a big executor and then down sizing it " + \
-                        "is more efficient than the reverse." + \
-                        "If you really want to up size, set allow_up_sizing=True " + \
-                        "to enable allocation of new arrays."
-                    new_arg_dict[name] = nd.empty(new_shape, ctx=arr.context, dtype=arr.dtype)
-                    if darr is not None:
-                        new_grad_dict[name] = nd.empty(new_shape, ctx=darr.context, dtype=arr.dtype)
-                else:
-                    new_arg_dict[name] = arr.reshape(new_shape)
-                    if darr is not None:
-                        new_grad_dict[name] = darr.reshape(new_shape)
-            else:
-                raise AssertionError("Shape of unspecified array arg:%s changed. "%name + \
-                    "This can cause the new executor to not share parameters " + \
-                    "with the old one. Please check for error in network." +\
-                    "If this is intended, set partial_shaping=True to suppress this warning.")
-
-        new_aux_dict = {}
-        for name, new_shape, arr in zip(self._symbol.list_auxiliary_states(),
-                                        aux_shapes, self.aux_arrays):
-            if partial_shaping or new_shape == arr.shape:
-                if np.prod(new_shape) > np.prod(arr.shape):
-                    assert allow_up_sizing, "New shape of arg:%s larger than original. "%name + \
-                        "First making a big executor and then down sizing it " + \
-                        "is more efficient than the reverse." + \
-                        "If you really want to up size, set allow_up_sizing=True " + \
-                        "to enable allocation of new arrays."
-                    new_aux_dict[name] = nd.empty(new_shape, ctx=arr.context, dtype=arr.dtype)
-                else:
-                    new_aux_dict[name] = arr.reshape(new_shape)
-            else:
-                raise AssertionError("Shape of unspecified array aux:%s changed. "%name + \
-                    "This can cause the new executor to not share parameters " + \
-                    "with the old one. Please check for error in network." +\
-                    "If this is intended, set partial_shaping=True to suppress this warning.")
-
-        return self._symbol.bind(self._ctx,
-                                 args=new_arg_dict,
-                                 args_grad=new_grad_dict,
-                                 grad_req=self._grad_req,
-                                 aux_states=new_aux_dict,
-                                 group2ctx=self._group2ctx,
-                                 shared_exec=self)
+        provided_arg_shape_data = []  # dimensions of all provided shapes, concatenated
+        # offsets into provided_arg_shape_data: the i-th provided shape is
+        # provided_arg_shape_data[provided_arg_shape_idx[i]:provided_arg_shape_idx[i+1]]
+        provided_arg_shape_idx = [0]
+        provided_arg_shape_names = []  # provided argument names
+        for k, v in kwargs.items():
+            if isinstance(v, tuple):  # non-tuple values are not treated as shapes
+                provided_arg_shape_names.append(k)
+                provided_arg_shape_data.extend(v)
+                provided_arg_shape_idx.append(len(provided_arg_shape_data))
+
+        ctx_map_keys = []
+        ctx_map_dev_types = []
+        ctx_map_dev_ids = []
+
+        if self._group2ctx:
+            for key, val in self._group2ctx.items():
+                ctx_map_keys.append(key)
+                ctx_map_dev_types.append(val.device_typeid)
+                ctx_map_dev_ids.append(val.device_id)
+
+        handle = ExecutorHandle()
+        shared_handle = self.handle
+
+        num_in_args = ctypes.c_uint()
+        in_arg_handles = ctypes.POINTER(NDArrayHandle)()
+        arg_grad_handles = ctypes.POINTER(NDArrayHandle)()
+        num_aux_states = ctypes.c_uint()
+        aux_state_handles = ctypes.POINTER(NDArrayHandle)()
+
+        check_call(_LIB.MXExecutorReshape(ctypes.c_int(int(partial_shaping)),
+                                          ctypes.c_int(int(allow_up_sizing)),
+                                          ctypes.c_int(self._ctx.device_typeid),
+                                          ctypes.c_int(self._ctx.device_id),
+                                          mx_uint(len(ctx_map_keys)),
+                                          c_str_array(ctx_map_keys),
+                                          c_array_buf(ctypes.c_int,
+                                                      py_array('i', ctx_map_dev_types)),
+                                          c_array_buf(ctypes.c_int,
+                                                      py_array('i', ctx_map_dev_ids)),
+                                          mx_uint(len(provided_arg_shape_names)),
+                                          c_str_array(provided_arg_shape_names),
+                                          c_array_buf(mx_uint,
+                                                      py_array('I', provided_arg_shape_data)),
+                                          c_array_buf(mx_uint,
+                                                      py_array('I', provided_arg_shape_idx)),
+                                          ctypes.byref(num_in_args),
+                                          ctypes.byref(in_arg_handles),
+                                          ctypes.byref(arg_grad_handles),
+                                          ctypes.byref(num_aux_states),
+                                          ctypes.byref(aux_state_handles),
+                                          shared_handle,
+                                          ctypes.byref(handle)))
+
+        arg_arrays = [_ndarray_cls(NDArrayHandle(in_arg_handles[i]))
+                      for i in range(num_in_args.value)]
+        grad_arrays = [_ndarray_cls(NDArrayHandle(arg_grad_handles[i]))
+                       if arg_grad_handles[i] is not None
+                       else None for i in range(num_in_args.value)]
+        aux_arrays = [_ndarray_cls(NDArrayHandle(aux_state_handles[i]))
+                      for i in range(num_aux_states.value)]
+
+        executor = Executor(handle, self._symbol, self._ctx, self._grad_req, self._group2ctx)
+        executor.arg_arrays = arg_arrays
+        executor.grad_arrays = grad_arrays
+        executor.aux_arrays = aux_arrays
+        return executor
 
     def debug_str(self):
         """Get a debug string about internal execution plan.
diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index b113ddc..fc1a71c 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -1259,9 +1259,12 @@ class Symbol(SymbolBase):
             if len(args) != len(arg_names):
                 raise ValueError('Length of %s does not match the number of arguments' % arg_key)
             for narr in args:
-                if not isinstance(narr, NDArray):
+                if narr is None and allow_missing:  # missing entries become NULL handles
+                    arg_handles.append(None)
+                elif not isinstance(narr, NDArray):
                     raise TypeError('Only accept list of NDArrays or dict of str to NDArray')
-                arg_handles.append(narr.handle)
+                else:
+                    arg_handles.append(narr.handle)
             arg_arrays = args
         elif isinstance(args, dict):
             for name in arg_names:
diff --git a/src/c_api/c_api_executor.cc b/src/c_api/c_api_executor.cc
index 40df491..09bc239 100644
--- a/src/c_api/c_api_executor.cc
+++ b/src/c_api/c_api_executor.cc
@@ -510,6 +510,99 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle,
   API_END();
 }
 
+/*!
+ * \brief C wrapper around Executor::Reshape of `shared_exec`.
+ *  The handle arrays written to in_args/arg_grads/aux_states point into the
+ *  calling thread's ret_handles buffer, which this function clears on entry.
+ */
+int MXExecutorReshape(int partial_shaping,
+                      int allow_up_sizing,
+                      int dev_type,
+                      int dev_id,
+                      mx_uint num_map_keys,
+                      const char** map_keys,
+                      const int* map_dev_types,
+                      const int* map_dev_ids,
+                      const mx_uint num_provided_arg_shapes,
+                      const char** provided_arg_shape_names,
+                      const mx_uint* provided_arg_shape_data,
+                      const mx_uint* provided_arg_shape_idx,
+                      mx_uint* num_in_args,
+                      NDArrayHandle** in_args,
+                      NDArrayHandle** arg_grads,
+                      mx_uint* num_aux_states,
+                      NDArrayHandle** aux_states,
+                      ExecutorHandle shared_exec,
+                      ExecutorHandle *out) {
+  MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
+  // Declared before API_BEGIN so API_END_HANDLE_ERROR can free it if a later step fails.
+  Executor* new_exec = nullptr;
+  API_BEGIN();
+  // create shape map for in_args and aux_states
+  std::unordered_map<std::string, TShape> kwargs(num_provided_arg_shapes);
+  for (mx_uint i = 0; i < num_provided_arg_shapes; ++i) {
+    // the i-th shape spans [idx[i], idx[i+1]) of provided_arg_shape_data
+    auto p = kwargs.emplace(provided_arg_shape_names[i],
+        TShape(provided_arg_shape_data+provided_arg_shape_idx[i],
+          provided_arg_shape_data+provided_arg_shape_idx[i+1]));
+    CHECK(p.second) << "Duplicate shapes are provided for argument "
+      << provided_arg_shape_names[i] << " in reshape of executor";
+  }
+
+  Context ctx = Context::Create(static_cast<Context::DeviceType>(dev_type), dev_id);
+  std::map<std::string, Context> ctx_map;
+  for (mx_uint i = 0; i < num_map_keys; ++i) {
+    ctx_map[std::string(map_keys[i])] = Context::Create(
+        static_cast<Context::DeviceType>(map_dev_types[i]), map_dev_ids[i]);
+  }
+  std::vector<NDArray> in_arg_vec;
+  std::vector<NDArray> arg_grad_vec;
+  std::vector<NDArray> aux_state_vec;
+
+  Executor* exec = static_cast<Executor*>(shared_exec);
+  new_exec = exec->Reshape(partial_shaping, allow_up_sizing, ctx, ctx_map, kwargs,
+                           &in_arg_vec, &arg_grad_vec, &aux_state_vec);
+
+  ret->ret_handles.clear();
+  // Reserving the full size up front means the push_backs below never reallocate,
+  // so pointers handed out into ret_handles stay valid.
+  ret->ret_handles.reserve(in_arg_vec.size()+arg_grad_vec.size()+aux_state_vec.size());
+
+  size_t nd_idx = 0;
+  for (const auto& nd : in_arg_vec) {
+    if (nd.is_none()) {
+      LOG(FATAL) << "Input argument NDArray cannot be un-allocated";
+    }
+    ret->ret_handles.push_back(new NDArray(nd));
+  }
+  *num_in_args = in_arg_vec.size();
+  *in_args = in_arg_vec.empty() ? nullptr : &(ret->ret_handles[nd_idx]);
+  nd_idx = ret->ret_handles.size();
+
+  // arguments without a gradient array are reported as NULL handles
+  for (const auto& nd : arg_grad_vec) {
+    if (nd.is_none()) {
+      ret->ret_handles.push_back(nullptr);
+    } else {
+      ret->ret_handles.push_back(new NDArray(nd));
+    }
+  }
+  *arg_grads = arg_grad_vec.empty() ? nullptr : &(ret->ret_handles[nd_idx]);
+  nd_idx = ret->ret_handles.size();
+
+  for (const auto& nd : aux_state_vec) {
+    if (nd.is_none()) {
+      LOG(FATAL) << "Auxiliary argument NDArray cannot be un-allocated";
+    }
+    ret->ret_handles.push_back(new NDArray(nd));
+  }
+  *num_aux_states = aux_state_vec.size();
+  *aux_states = aux_state_vec.empty() ? nullptr : &(ret->ret_handles[nd_idx]);
+
+  *out = new_exec;
+  API_END_HANDLE_ERROR(delete new_exec);
+}
+
 int MXExecutorSetMonitorCallback(ExecutorHandle handle,
                                  ExecutorMonitorCallback callback,
                                  void* callback_handle) {
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index 7a15f6c..e28867d 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -1044,6 +1044,117 @@ void GraphExecutor::Init(nnvm::Symbol symbol,
 }
 
 /*!
+ * \brief Return a new executor with the same symbol and shared memory,
+ * but different input/output shapes.
+ * For runtime reshaping, variable length sequences, etc.
+ * The returned executor shares state with the current one,
+ * and cannot be used in parallel with it.
+ */
+Executor* GraphExecutor::Reshape(const bool partial_shaping,
+                                 const bool allow_up_sizing,
+                                 const Context& default_ctx,
+                                 const std::map<std::string, Context>& ctx_map,
+                                 const std::unordered_map<std::string, TShape>&
+                                   provided_arg_shapes,
+                                 std::vector<NDArray>* in_args,
+                                 std::vector<NDArray>* arg_grads,
+                                 std::vector<NDArray>* aux_states) {
+  nnvm::Graph g;
+  g.outputs = std::vector<nnvm::NodeEntry>(graph_.outputs.begin(),
+    graph_.outputs.begin() + num_forward_outputs_);
+  nnvm::Symbol symbol;
+  symbol.outputs = g.outputs;
+  const nnvm::IndexedGraph& idx = g.indexed_graph();
+  nnvm::ShapeVector arg_shapes(idx.input_nodes().size(), TShape());
+  for (size_t i = 0; i < num_forward_inputs_; ++i) {  // seed inference with given shapes
+    const uint32_t nid = idx.input_nodes().at(i);
+    const std::string& name = idx[nid].source->attrs.name;
+    auto it = provided_arg_shapes.find(name);
+    if (provided_arg_shapes.end() != it) {
+      arg_shapes[i] = it->second;
+    }
+  }
+  g = InferShape(std::move(g), std::move(arg_shapes), "__shape__");
+  if (g.GetAttr<size_t>("shape_num_unknown_nodes") != 0U) {
+    HandleInferShapeError(num_forward_inputs_, g.indexed_graph(),
+                          g.GetAttr<nnvm::ShapeVector>("shape"));
+  }
+  const nnvm::ShapeVector& shape_vec = g.GetAttr<nnvm::ShapeVector>("shape");
+  std::vector<OpReqType> grad_req_types;
+  size_t grad_top = 0;
+  const size_t num_args = in_arg_map_.size();
+  const size_t num_aux = aux_state_map_.size();
+  in_args->reserve(num_args);
+  grad_req_types.reserve(num_args);
+  arg_grads->reserve(num_args);
+  aux_states->reserve(num_aux);
+  for (uint32_t nid : idx.input_nodes()) {
+    const std::string& name = idx[nid].source->attrs.name;
+    const TShape& new_shape = shape_vec[idx.entry_id(nid, 0)];
+    if (idx.mutable_input_nodes().count(nid) == 0) {
+      NDArray& arr = in_arg_map_.at(name);
+      auto it = arg_grad_map_.find(name);
+      if (partial_shaping || provided_arg_shapes.count(name) || new_shape == arr.shape()) {
+        if (new_shape.Size() > arr.shape().Size()) {  // too big to reuse: allocate anew
+          CHECK(allow_up_sizing) << "New shape of arg: " << name << " is larger than original. "
+            << "First making a big executor and then down sizing it "
+            << "is more efficient than the reverse. "
+            << "If you really want to up size, set allow_up_sizing=True "
+            << "to enable allocation of new arrays.";
+          in_args->emplace_back(new_shape, arr.ctx(), false, arr.dtype());
+          if (it != arg_grad_map_.end()) {
+            NDArray& darr = it->second;
+            arg_grads->emplace_back(new_shape, darr.ctx(), false, darr.dtype());
+            grad_req_types.push_back(grad_store_.at(grad_top++).first);
+          } else {
+            arg_grads->emplace_back();
+            grad_req_types.push_back(kNullOp);
+          }
+        } else {
+          in_args->push_back(arr.Reshape(new_shape));  // reuses arr's memory
+          if (it != arg_grad_map_.end()) {
+            NDArray& darr = it->second;
+            arg_grads->push_back(darr.Reshape(new_shape));
+            grad_req_types.push_back(grad_store_.at(grad_top++).first);
+          } else {
+            arg_grads->emplace_back();
+            grad_req_types.push_back(kNullOp);
+          }
+        }
+      } else {
+        LOG(FATAL) << "Shape of unspecified arg: " << name << " changed. "
+          << "This can cause the new executor to not share parameters "
+          << "with the old one. Please check for error in network. "
+          << "If this is intended, set partial_shaping=True to suppress this warning.";
+      }
+    } else {
+      NDArray& arr = aux_state_map_.at(name);
+      if (partial_shaping || new_shape == arr.shape()) {
+        if (new_shape.Size() > arr.shape().Size()) {
+          CHECK(allow_up_sizing) << "New shape of aux: " << name << " is larger than original. "
+            << "First making a big executor and then down sizing it "
+            << "is more efficient than the reverse. "
+            << "If you really want to up size, set allow_up_sizing=True "
+            << "to enable allocation of new arrays.";
+          aux_states->emplace_back(new_shape, arr.ctx(), false, arr.dtype());
+        } else {
+          aux_states->push_back(arr.Reshape(new_shape));
+        }
+      } else {
+        LOG(FATAL) << "Shape of unspecified aux: " << name << " changed. "
+          << "This can cause the new executor to not share parameters "
+          << "with the old one. Please check for error in network. "
+          << "If this is intended, set partial_shaping=True to suppress this warning.";
+      }
+    }
+  }
+  std::unique_ptr<GraphExecutor> exec(new GraphExecutor());  // freed if Init throws
+  exec->Init(symbol, default_ctx, ctx_map,
+             *in_args, *arg_grads, grad_req_types, *aux_states,
+             this);
+  return exec.release();
+}
+/*!
  * \brief This function is triggered by both simple_bind
  * and bind flows.
  * Setup backward graph, create device and context
diff --git a/src/executor/graph_executor.h b/src/executor/graph_executor.h
index bcde41d..24f9889 100644
--- a/src/executor/graph_executor.h
+++ b/src/executor/graph_executor.h
@@ -107,6 +107,17 @@ class GraphExecutor : public Executor {
             const nnvm::NodeEntryMap<NDArray>& feed_dict
               = nnvm::NodeEntryMap<NDArray>());
 
+  /*! \brief Implements Executor::Reshape; returns a newly allocated GraphExecutor. */
+  Executor* Reshape(const bool partial_shaping,
+                    const bool allow_up_sizing,
+                    const Context& default_ctx,
+                    const std::map<std::string, Context>& ctx_map,
+                    const std::unordered_map<std::string, TShape>&
+                      provided_arg_shapes,
+                    std::vector<NDArray>* in_args,
+                    std::vector<NDArray>* arg_grads,
+                    std::vector<NDArray>* aux_states) override;
+
  protected:
   friend class mxnet::Imperative;
   // Information about operational node
diff --git a/tests/python/unittest/test_executor.py b/tests/python/unittest/test_executor.py
index 45b9a09..05e71b4 100644
--- a/tests/python/unittest/test_executor.py
+++ b/tests/python/unittest/test_executor.py
@@ -160,6 +160,18 @@ def test_reshape():
     exe.forward(is_train=False)
     assert np.all(exe.outputs[0].asnumpy() == 4)
 
+    # test sharing ndarray depending on new_shape
+    new_exe = exe.reshape(allow_up_sizing=True, x=(6,4))
+    assert new_exe.arg_arrays[0].shape == (6, 4)
+    # data ndarray is not shared between exe and new_exe
+    new_exe.arg_arrays[0][:] = 0
+    assert np.all(exe.arg_arrays[0].asnumpy() == 1)
+    # weight ndarray is shared between exe and new_exe:
+    # a write through new_exe must be visible through exe
+    assert np.all(new_exe.arg_arrays[1].asnumpy() == 1)
+    new_exe.arg_arrays[1][:] = 2
+    assert np.all(exe.arg_arrays[1].asnumpy() == 2)
+
 
 if __name__ == "__main__":
     import nose

-- 
To stop receiving notification emails like this one, please contact
[email protected].

Reply via email to