ZiyueHuang edited a comment on pull request #19387:
URL: https://github.com/apache/incubator-mxnet/pull/19387#issuecomment-714034023


   I've considered applying a workflow similar to CNN's (im2col --> gemm --> 
col2im); however, I realized that this would cost `num_head_units` times more 
memory (64x for BERT), since we would need a packing tensor of shape roughly 
(..., seq_length, w, num_head_units).
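   To make the memory blow-up concrete, here is a minimal sketch comparing the im2col-style packing tensor against the attention-score tensor of shape (..., seq_length, w). The dimensions (`batch`, `seq_length`, `w`) are hypothetical illustrative values, not taken from the PR; only `num_head_units = 64` comes from the BERT discussion above.

```python
# Hypothetical dimensions for illustration; num_head_units = 64 as in BERT.
batch, seq_length, w, num_head_units = 1, 512, 128, 64

# Attention-score tensor: (batch, seq_length, w)
score_elems = batch * seq_length * w

# im2col-style packing tensor: (batch, seq_length, w, num_head_units)
packed_elems = batch * seq_length * w * num_head_units

# The packing tensor is num_head_units times larger than the score tensor.
print(packed_elems // score_elems)  # -> 64
```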
   
   We can consider applying existing GEMM optimization techniques (blocking, 
etc.). We could also rewrite the for-loop in the kernel as a parallel 
reduction; I haven't implemented that yet because the loop only has 64 
iterations for BERT, so I don't think the gain would be significant.
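   For reference, the loop-vs-reduction idea can be sketched as below. This is a plain-Python illustration of the pairwise (tree) reduction pattern, not the PR's actual kernel code; in a real GPU kernel the inner loop of each halving step would run across threads, giving log2(64) = 6 steps instead of 64 sequential iterations.

```python
def sequential_sum(vals):
    # Current approach: a plain for-loop over num_head_units (64 for BERT).
    acc = 0.0
    for v in vals:
        acc += v
    return acc

def tree_reduce_sum(vals):
    # Parallel-reduction pattern: repeatedly halve the active range.
    # Each inner loop's iterations are independent, so they can run in
    # parallel; assumes len(vals) is a power of two (true for 64).
    vals = list(vals)
    stride = len(vals) // 2
    while stride > 0:
        for i in range(stride):
            vals[i] += vals[i + stride]
        stride //= 2
    return vals[0]
```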
   
   TVM is a good option, and I think it should not be difficult for someone 
familiar with TVM to port this kernel to a TVM expression, since the 
computation flow is the same.
   
   
   


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]
