insop opened a new pull request #7230:
URL: https://github.com/apache/tvm/pull/7230


   Adding _npi_advanced_indexing_multiple as discussed in 
https://github.com/apache/tvm/issues/7186
   
   #### I still need to find the proper `mx.sym.np` symbol and would like to ask the reviewers for help finding it, so that the test case is valid
   
   @sxjscience , @junrushao1994 
   
   
   #### Details:
   `_npi_advanced_indexing_multiple` shows up in the BART model. It is triggered when we call `a[idx1, idx2]`. Also see the MXNet-side implementation.
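
For reviewers less familiar with this indexing form, here is a minimal NumPy sketch of what `a[idx1, idx2]` computes on a 2-D array. It is only a reference for the semantics, not the converter added in this PR; the helper name `adv_index_pair` is hypothetical.

```python
import numpy as np


def adv_index_pair(a, idx1, idx2):
    """Loop-based reference for a[idx1, idx2] on a 2-D array (hypothetical helper)."""
    # Both index operands are broadcast to a common shape first.
    i1, i2 = np.broadcast_arrays(np.asarray(idx1), np.asarray(idx2))
    out = np.empty(i1.shape, dtype=a.dtype)
    # Each output element pairs one row index with one column index.
    for pos in np.ndindex(i1.shape):
        out[pos] = a[i1[pos], i2[pos]]
    return out


a = np.arange(12).reshape(3, 4)
rows = np.array([0, 1, 2])
result = adv_index_pair(a, rows, 1)

# The loop agrees with NumPy's built-in advanced indexing.
assert np.array_equal(result, a[rows, 1])
```

The scalar column index is broadcast against `rows`, so the result has one element per row index.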
   
   
   - Example from [BART](https://raw.githubusercontent.com/dmlc/gluon-nlp/master/src/gluonnlp/models/bart.py)
   ```
               batch_indices = mx.npx.arange_like(sequence, axis=0).astype(mx.np.int32)
               outputs = sequence[batch_indices, valid_length - 1]
   ```
   
   - Standalone example
       ```
       import mxnet as mx
       from mxnet import use_np
       from mxnet.gluon import nn
       from mxnet import np, npx

       sequence = np.array([[1, 2, 11, 12], [3, 4, 23, 24], [5, 6, 35, 36]])
       # sequence:
       # array([[ 1.,  2., 11., 12.],
       #        [ 3.,  4., 23., 24.],
       #        [ 5.,  6., 35., 36.]])

       batch_indices = mx.npx.arange_like(sequence, axis=0).astype(mx.np.int32)
       # batch_indices: array([0, 1, 2], dtype=int32)

       valid_length = 2

       outputs = sequence[batch_indices, valid_length - 1]
       # outputs: array([2., 4., 6.])
       ```
   
   - PyTorch advanced indexing example
   ```
   import torch

   In [44]: a = torch.randn(5, 7, dtype=torch.double)

   In [45]: a
   Out[45]:
   tensor([[-1.2230,  0.7823,  0.6655, -0.8564, -0.2611, -0.0423, -0.6728],
           [ 1.6607,  0.9779, -0.2754, -0.7090, -0.3243,  2.2017, -1.7534],
           [-1.9319,  0.5544,  2.0244, -0.8144, -0.2657,  0.7849, -0.4825],
           [ 0.0085,  1.0663,  0.1695, -0.3458, -0.4960,  1.2339,  0.6244],
           [ 0.5265, -2.0689, -0.4739,  0.5544,  0.8612,  0.2270, -2.0888]],
          dtype=torch.float64)

   In [46]: t = a[(0,1,2,3,4),1]

   In [47]: t
   Out[47]: tensor([ 0.7823,  0.9779,  0.5544,  1.0663, -2.0689], dtype=torch.float64)

   In [48]: t = a[(0,1,2,3,3),1]

   In [49]: t
   Out[49]: tensor([0.7823, 0.9779, 0.5544, 1.0663, 1.0663], dtype=torch.float64)
   ```
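
The examples above all gather one element per (row, column) pair. As a rough sketch of one way to express that with a single flat gather, the NumPy code below turns each pair into a flat offset and calls `np.take`. This is an illustration only and not necessarily how the converter in this PR is written. The helper name `gather_via_flat_index` is hypothetical, the per-sample `valid_length` array is an assumption about how BART passes it, and negative indices are not handled.

```python
import numpy as np

# NumPy stand-in for the MXNet `sequence` array from the standalone example.
sequence = np.array(
    [[1, 2, 11, 12], [3, 4, 23, 24], [5, 6, 35, 36]], dtype=np.float32
)
batch_indices = np.arange(sequence.shape[0], dtype=np.int32)


def gather_via_flat_index(data, rows, cols):
    """Emulate data[rows, cols] on a 2-D array with one flat take (hypothetical helper)."""
    rows, cols = np.broadcast_arrays(np.asarray(rows), np.asarray(cols))
    # Row-major offset of element (r, c) in the flattened array.
    flat_index = rows * data.shape[1] + cols
    return np.take(data.reshape(-1), flat_index)


# Scalar valid_length, as in the standalone example.
scalar_case = gather_via_flat_index(sequence, batch_indices, 2 - 1)

# Per-sample valid_length (assumed shape: one entry per batch row).
valid_length = np.array([2, 4, 1], dtype=np.int32)
per_sample = gather_via_flat_index(sequence, batch_indices, valid_length - 1)

# Repeated row indices return repeated elements, as in the PyTorch session.
repeated = gather_via_flat_index(sequence, (0, 1, 2, 2), 1)

# Each result matches NumPy's built-in advanced indexing.
assert np.array_equal(scalar_case, sequence[batch_indices, 1])
assert np.array_equal(per_sample, sequence[batch_indices, valid_length - 1])
assert np.array_equal(repeated, sequence[(0, 1, 2, 2), 1])
```

The scalar case reproduces the `[2., 4., 6.]` result from the standalone example.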
   
   



