yuxiangw opened a new issue #12307: 'ndarray.batch_dot' does not handle general dot product, e.g., those with tensors
URL: https://github.com/apache/incubator-mxnet/issues/12307

A short example:

```
from mxnet import nd

batchsz = 2
# (batch, 3, 4, 5) x (batch, 5): raises MXNetError
W = nd.random_normal(shape=(batchsz,3,4,5))
U = nd.random_normal(shape=(batchsz,5))
Y = nd.batch_dot(W,U)
```

In fact, a batched matrix-vector product fails too:

```
# (batch, 4, 5) x (batch, 5): also raises MXNetError
W = nd.random_normal(shape=(batchsz,4,5))
U = nd.random_normal(shape=(batchsz,5))
Y = nd.batch_dot(W,U)
```

The following works:

```
# (batch, 4, 5) x (batch, 5, 1): works, Y has shape (batch, 4, 1)
W = nd.random_normal(shape=(batchsz,4,5))
U = nd.random_normal(shape=(batchsz,5,1))
Y = nd.batch_dot(W,U)
```

There really shouldn't be any reason why this only handles batched matrix-matrix multiplication and not tensor-vector or matrix-vector products.

------------------------------
Error message:

```
---------------------------------------------------------------------------
MXNetError                                Traceback (most recent call last)
<ipython-input-7-654c124195f4> in <module>()
      3 W = nd.random_normal(shape=(batchsz,3,4,5))
      4 U = nd.random_normal(shape=(batchsz,5))
----> 5 Y = nd.batch_dot(W,U)

/usr/local/lib/python3.6/site-packages/mxnet/ndarray/register.py in batch_dot(lhs, rhs, transpose_a, transpose_b, out, name, **kwargs)

/usr/local/lib/python3.6/site-packages/mxnet/_ctypes/ndarray.py in _imperative_invoke(handle, ndargs, keys, vals, out)
     90         c_str_array(keys),
     91         c_str_array([str(s) for s in vals]),
---> 92         ctypes.byref(out_stypes)))
     93 
     94     if original_output is not None:

/usr/local/lib/python3.6/site-packages/mxnet/base.py in check_call(ret)
    147     """
    148     if ret != 0:
--> 149         raise MXNetError(py_str(_LIB.MXGetLastError()))
    150 
    151 

MXNetError: [00:15:08] src/operator/tensor/./dot-inl.h:1265: batch_dot currently only support 3D*3D array[2,3,4,5] v.s. [2,5]

Stack trace returned 9 entries:
[bt] (0) 0 libmxnet.so 0x000000010e4494b4 libmxnet.so + 21684
[bt] (1) 1 libmxnet.so 0x000000010e44926f libmxnet.so + 21103
[bt] (2) 2 libmxnet.so 0x000000010e448ee9 libmxnet.so + 20201
[bt] (3) 3 libmxnet.so 0x000000010e90f7db libmxnet.so + 5027803
[bt] (4) 4 libmxnet.so 0x000000010f5547fa MXNDListFree + 431770
[bt] (5) 5 libmxnet.so 0x000000010f553459 MXNDListFree + 426745
[bt] (6) 6 libmxnet.so 0x000000010f4c478a MXCustomFunctionRecord + 20250
[bt] (7) 7 libmxnet.so 0x000000010f4c5990 MXImperativeInvokeEx + 176
[bt] (8) 8 _ctypes.cpython-36m-darwin.so 0x000000010d39f427 ffi_call_unix64 + 79
```
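For reference, here is a minimal sketch of the semantics the issue asks for, written in NumPy rather than MXNet so it can run without an MXNet install. `batch_dot_general` is a hypothetical helper name, not an MXNet API, and `np.matmul` stands in for `nd.batch_dot`. The idea is the same workaround the report shows working: flatten each `lhs[b]` into a matrix, turn each `rhs[b]` into a `(k, 1)` column, do a batched 3D x 3D product, then reshape the result back. Because it only needs the 3D x 3D case, the same reshape pattern is assumed (untested here) to carry over to `nd.batch_dot` via `NDArray.reshape`.

```python
import numpy as np


def batch_dot_general(lhs, rhs):
    """Hypothetical helper (not an MXNet API): for each batch index b,
    contract the last axis of lhs[b] with the single axis of rhs[b].

    lhs: shape (batch, *outer, k)
    rhs: shape (batch, k)
    returns: shape (batch, *outer)
    """
    batch, k = rhs.shape
    if lhs.shape[0] != batch or lhs.shape[-1] != k:
        raise ValueError(f"incompatible shapes {lhs.shape} vs {rhs.shape}")
    outer = lhs.shape[1:-1]
    # Flatten the outer axes so each lhs[b] becomes an (m, k) matrix ...
    lhs3 = lhs.reshape(batch, -1, k)
    # ... and turn each rhs[b] into a (k, 1) column vector.
    rhs3 = rhs.reshape(batch, k, 1)
    # Batched 3D x 3D product: (batch, m, k) @ (batch, k, 1) -> (batch, m, 1)
    out = np.matmul(lhs3, rhs3)
    # Drop the trailing singleton axis and restore the outer shape.
    return out.reshape((batch,) + outer)


rng = np.random.default_rng(0)
batchsz = 2

# Tensor-vector case from the report: (2, 3, 4, 5) x (2, 5) -> (2, 3, 4)
W = rng.standard_normal((batchsz, 3, 4, 5))
U = rng.standard_normal((batchsz, 5))
Y = batch_dot_general(W, U)
print(Y.shape)  # (2, 3, 4)

# Matrix-vector case: (2, 4, 5) x (2, 5) -> (2, 4)
W2 = rng.standard_normal((batchsz, 4, 5))
Y2 = batch_dot_general(W2, U)
print(Y2.shape)  # (2, 4)
```

The reshape route is chosen over a general contraction so that the only batched kernel it relies on is the 3D x 3D product; the result should match `np.einsum('bijk,bk->bij', W, U)` for the tensor-vector case.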
