ChaiBapchya closed pull request #12677: [MXNET-995] Constant Initializer for NDArray
URL: https://github.com/apache/incubator-mxnet/pull/12677
This is a PR merged from a forked repository.
Because GitHub does not show the original diff of a merged pull request from a fork,
the diff is reproduced below for the sake of provenance:
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index de2ad692adf..2e39dc6f988 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -869,7 +869,7 @@ def _sync_copyfrom(self, source_array):
source_array = np.ascontiguousarray(source_array, dtype=self.dtype)
if source_array.shape != self.shape:
raise ValueError('Shape inconsistent: expected %s vs got %s'%(
- str(self.shape), str(source_array.shape)))
+ str(source_array.shape), str(self.shape)))
check_call(_LIB.MXNDArraySyncCopyFromCPU(
self.handle,
source_array.ctypes.data_as(ctypes.c_void_p),
@@ -2479,6 +2479,8 @@ def array(source_array, ctx=None, dtype=None):
if isinstance(source_array, NDArray):
dtype = source_array.dtype if dtype is None else dtype
else:
+ if isinstance(source_array, (float, int)):
+ source_array = [float(source_array)]
dtype = mx_real_t if dtype is None else dtype
if not isinstance(source_array, np.ndarray):
try:
diff --git a/tests/python/unittest/test_ndarray.py
b/tests/python/unittest/test_ndarray.py
index a1c178f8234..2e3c17141d7 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -1471,6 +1471,24 @@ def test_dlpack():
mx.test_utils.assert_almost_equal(a_np, d_np)
mx.test_utils.assert_almost_equal(a_np, e_np)
+@with_seed()
+def test_ndarray_constant_init():
+ # a=mx.nd.array([1])
+ a=mx.nd.array(9)
+ assert(isinstance(a,mx.nd.NDArray))
+
+@with_seed()
+def test_symbol_constant_init():
+ # a=mx.nd.array([1])
+ a = mx.sym.Variable('a')
+ b = mx.sym.Variable('b')
+ c = a * b
+ d = c.eval(a=mx.nd.array(2),b=mx.nd.array(3))
+ assert(isinstance(d[0],mx.nd.NDArray))
+ e = c.bind(ctx=mx.cpu(),args=[mx.nd.array(2),mx.nd.array(3)])
+ f = e.forward()
+ assert(isinstance(f[0],mx.nd.NDArray))
+
if __name__ == '__main__':
import nose
nose.runmodule()
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
With regards,
Apache Git Services