yuxiangw opened a new issue #12307: 'ndarray.batch_dot' does not handle general dot products, e.g., those with tensors
URL: https://github.com/apache/incubator-mxnet/issues/12307
 
 
   A short example:
   
   ```
   from mxnet import nd
   batchsz = 2
   W = nd.random_normal(shape=(batchsz,3,4,5))
   U = nd.random_normal(shape=(batchsz,5))
   Y = nd.batch_dot(W,U)
   ```
   In fact, even a batched matrix-vector product fails:
   ```
   W = nd.random_normal(shape=(batchsz,4,5))
   U = nd.random_normal(shape=(batchsz,5))
   Y = nd.batch_dot(W,U)
   ```
   
   The following, however, works:
   ```
   W = nd.random_normal(shape=(batchsz,4,5))
   U = nd.random_normal(shape=(batchsz,5,1))
   Y = nd.batch_dot(W,U)
   ```
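Until `batch_dot` accepts a 2D right-hand side, the working call above suggests a workaround for the matrix-vector case: add a trailing unit axis to the vector, run the 3D x 3D batched product, then drop that axis again. The sketch below shows only the shape bookkeeping. It is written in NumPy because `np.matmul` on two 3D arrays computes the same per-sample product as `nd.batch_dot` with its default arguments; the helper name `batch_matvec` is hypothetical and not part of MXNet.

```python
import numpy as np

def batch_matvec(W, u):
    """Batched matrix-vector product: W (b, m, k), u (b, k) -> (b, m).

    Hypothetical helper: lifts u to (b, k, 1) so the call becomes a plain
    3D x 3D batched matrix product, then removes the unit axis.
    """
    if W.ndim != 3 or u.ndim != 2 or W.shape[0] != u.shape[0] or W.shape[2] != u.shape[1]:
        raise ValueError(f"incompatible shapes {W.shape} and {u.shape}")
    u_col = u[:, :, np.newaxis]   # (b, k) -> (b, k, 1)
    y = np.matmul(W, u_col)       # (b, m, k) @ (b, k, 1) -> (b, m, 1)
    return y[:, :, 0]             # (b, m, 1) -> (b, m)

batchsz = 2
rng = np.random.default_rng(0)
W = rng.standard_normal((batchsz, 4, 5))
U = rng.standard_normal((batchsz, 5))
Y = batch_matvec(W, U)
print(Y.shape)  # (2, 4)
```

In MXNet the corresponding steps would presumably be `nd.expand_dims(U, axis=2)` before `nd.batch_dot` and `Y.reshape((batchsz, 4))` after it; that translation is an untested assumption.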
   There is no fundamental reason why `batch_dot` should support only batched matrix-matrix multiplication and not batched products involving higher-order tensors or vectors.
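The same trick extends to the tensor case from the first example: collapse every middle axis of `W` into one, collapse any trailing axes of `U` into one (a single unit axis for a vector), run one 3D x 3D batched product, and reshape the result back. This is a NumPy sketch under stated assumptions: the function name `general_batch_dot` is hypothetical, and it assumes the intended semantics is contracting the last axis of `W` with axis 1 of `U`, separately for each sample along axis 0.

```python
import math
import numpy as np

def general_batch_dot(W, U):
    """Hypothetical generalisation of batch_dot (not an MXNet API).

    W: (b, *left, k) with any number of middle axes.
    U: (b, k) or (b, k, *right).
    Contracts the last axis of W with axis 1 of U, independently per sample.
    """
    if W.ndim < 2 or U.ndim < 2:
        raise ValueError("both inputs need a batch axis and a contraction axis")
    b, k = W.shape[0], W.shape[-1]
    if U.shape[0] != b or U.shape[1] != k:
        raise ValueError(f"incompatible shapes {W.shape} and {U.shape}")
    left, right = W.shape[1:-1], U.shape[2:]
    W3 = W.reshape(b, math.prod(left), k)    # (b, L, k): middle axes flattened
    U3 = U.reshape(b, k, math.prod(right))   # (b, k, R): R == 1 for a vector
    Y3 = np.matmul(W3, U3)                   # one 3D x 3D batched product
    return Y3.reshape((b, *left, *right))    # restore the original axes

batchsz = 2
rng = np.random.default_rng(0)
W = rng.standard_normal((batchsz, 3, 4, 5))
U = rng.standard_normal((batchsz, 5))
Y = general_batch_dot(W, U)
print(Y.shape)  # (2, 3, 4)
```

`np.einsum('bijk,bk->bij', W, U)` gives an independent reference for that contraction. The reshape-then-`batch_dot` pattern should carry over to MXNet NDArrays as well, though that is an assumption rather than something the issue confirms.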
   
   ------------------------------
   Error message:
   ```
   ---------------------------------------------------------------------------
   MXNetError                                Traceback (most recent call last)
   <ipython-input-7-654c124195f4> in <module>()
         3 W = nd.random_normal(shape=(batchsz,3,4,5))
         4 U = nd.random_normal(shape=(batchsz,5))
   ----> 5 Y = nd.batch_dot(W,U)
   
   /usr/local/lib/python3.6/site-packages/mxnet/ndarray/register.py in batch_dot(lhs, rhs, transpose_a, transpose_b, out, name, **kwargs)
   
   /usr/local/lib/python3.6/site-packages/mxnet/_ctypes/ndarray.py in _imperative_invoke(handle, ndargs, keys, vals, out)
        90         c_str_array(keys),
        91         c_str_array([str(s) for s in vals]),
   ---> 92         ctypes.byref(out_stypes)))
        93 
        94     if original_output is not None:
   
   /usr/local/lib/python3.6/site-packages/mxnet/base.py in check_call(ret)
       147     """
       148     if ret != 0:
   --> 149         raise MXNetError(py_str(_LIB.MXGetLastError()))
       150 
       151 
   
   MXNetError: [00:15:08] src/operator/tensor/./dot-inl.h:1265: batch_dot currently only support 3D*3D array[2,3,4,5] v.s. [2,5]
   
   Stack trace returned 9 entries:
   [bt] (0) 0   libmxnet.so                         0x000000010e4494b4 libmxnet.so + 21684
   [bt] (1) 1   libmxnet.so                         0x000000010e44926f libmxnet.so + 21103
   [bt] (2) 2   libmxnet.so                         0x000000010e448ee9 libmxnet.so + 20201
   [bt] (3) 3   libmxnet.so                         0x000000010e90f7db libmxnet.so + 5027803
   [bt] (4) 4   libmxnet.so                         0x000000010f5547fa MXNDListFree + 431770
   [bt] (5) 5   libmxnet.so                         0x000000010f553459 MXNDListFree + 426745
   [bt] (6) 6   libmxnet.so                         0x000000010f4c478a MXCustomFunctionRecord + 20250
   [bt] (7) 7   libmxnet.so                         0x000000010f4c5990 MXImperativeInvokeEx + 176
   [bt] (8) 8   _ctypes.cpython-36m-darwin.so       0x000000010d39f427 ffi_call_unix64 + 79
   ```

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.