sxjscience edited a comment on issue #18043:
URL: 
https://github.com/apache/incubator-mxnet/issues/18043#issuecomment-616339427


   Here are the einsum workloads that are related to attention:
   
   Consider an attention cell that maps (query, key, value) to out. We denote the batch size as `B`, the number of heads as `K`, the query length as `L_q`, the memory length as `L_m`, the key dimension as `C_k`, and the value dimension as `C_v`. In the numpy version of GluonNLP, we will support the following layouts for the attention cell:
   
   - layout = 'NKT'
      - query.shape = `(B, K, L_q, C_k)`
      - key.shape = `(B, K, L_m, C_k)`
      - value.shape = `(B, K, L_m, C_v)`
      - out.shape = `(B, L_q, K * C_v)`
   
      We need the following einsums in the implementation:
      - `'bnic,bnjc->bnij'`
      - `'bnij,bnjc->binc'`
   
   - layout = 'NTK'
      - query.shape = `(B, L_q, K, C_k)`
      - key.shape = `(B, L_m, K, C_k)`
      - value.shape = `(B, L_m, K, C_v)`
      - out.shape = `(B, L_q, K * C_v)`
   
      We need the following einsums:
      - `'binc,bjnc->bnij'`
      - `'bnij,bjnc->binc'`
   
   - layout = 'TNK'
      - query.shape = `(L_q, B, K, C_k)`
      - key.shape = `(L_m, B, K, C_k)`
      - value.shape = `(L_m, B, K, C_v)`
      - out.shape = `(L_q, B, K * C_v)`
   
      We need the following einsums:
      - `'ibnc,jbnc->bnij'`
      - `'bnij,jbnc->ibnc'`
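
As a sanity check on the einsum strings above, here is a small NumPy sketch that runs the two einsums for each layout and checks that all three layouts produce the same attention output. It assumes a plain softmax over the memory axis between the two einsums (scaling and masking omitted), which the comment does not spell out. `attention` and `softmax` are hypothetical helpers, not GluonNLP APIs.

```python
import numpy as np


def softmax(x, axis=-1):
    # Numerically stable softmax along `axis`
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def attention(query, key, value, layout):
    """Hypothetical helper: attention using the per-layout einsums listed above."""
    if layout == 'NKT':
        scores = np.einsum('bnic,bnjc->bnij', query, key)           # (B, K, L_q, L_m)
        out = np.einsum('bnij,bnjc->binc', softmax(scores), value)  # (B, L_q, K, C_v)
    elif layout == 'NTK':
        scores = np.einsum('binc,bjnc->bnij', query, key)
        out = np.einsum('bnij,bjnc->binc', softmax(scores), value)  # (B, L_q, K, C_v)
    elif layout == 'TNK':
        scores = np.einsum('ibnc,jbnc->bnij', query, key)
        out = np.einsum('bnij,jbnc->ibnc', softmax(scores), value)  # (L_q, B, K, C_v)
    else:
        raise ValueError(f'unknown layout: {layout}')
    # Merge the head and channel axes: (..., K, C_v) -> (..., K * C_v)
    return out.reshape(out.shape[0], out.shape[1], -1)
```

Only the positions of the `B`, `K` and length axes change between layouts, so the einsum strings differ while the computed values are identical up to a transpose of the output.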
   
   In fact, `out = np.einsum('ibnc,jbnc->bnij', A, B)` can be implemented with a single `cublasGemmStridedBatched` call. For each output slice `out[i, j, :, :]`, where `i` now indexes the batch and `j` the heads, we have
   
   ```
   out[i, j, :, :] = A[:, i, j, :] @ B[:, i, j, :].T
   ```
   
   Each slice is a single GEMM, so the whole loop below can be implemented as one batched GEMM with suitable batch-stride and leading-dimension parameters:
   ```
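   # The batch size B is spelled batch_size here to avoid clashing with operand B
   for i in range(batch_size):
       for j in range(num_heads):
           out[i, j, :, :] = A[:, i, j, :] @ B[:, i, j, :].T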
   ```
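
To make the single-call claim concrete, here is a NumPy sketch (NumPy's batched `matmul` stands in for cuBLAS) that builds both operands as zero-copy strided views of the TNK tensors and then issues one batched matmul. The function name `qk_einsum_via_strided_batched_gemm` and the operand name `Bm` (renamed to avoid clashing with the batch size `B`) are hypothetical. The key observation is that the flattened (batch, head) index has a single constant stride of `C_k` elements, and each slice has a leading dimension of `B * K * C_k`. Those are roughly the batch-stride and leading-dimension arguments `cublasGemmStridedBatched` takes; the exact cuBLAS argument mapping (it is column-major) is not worked out here.

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided


def qk_einsum_via_strided_batched_gemm(A, Bm):
    """Compute np.einsum('ibnc,jbnc->bnij', A, Bm) as ONE batched matmul.

    A:  (L_q, batch, heads, C), TNK layout
    Bm: (L_m, batch, heads, C), TNK layout
    Returns (batch, heads, L_q, L_m).
    """
    A = np.ascontiguousarray(A)
    Bm = np.ascontiguousarray(Bm)
    L_q, batch, heads, C = A.shape
    L_m = Bm.shape[0]
    nb = batch * heads  # flattened batch index p = b * heads + n

    # In a contiguous (L, batch, heads, C) array, stepping n moves C elements
    # and stepping b moves heads * C elements, so p = b * heads + n steps by
    # exactly C elements: one constant batch stride. Rows of a single (L, C)
    # slice are nb * C elements apart (the leading dimension).
    def batched_view(x, L):
        s = x.itemsize
        return as_strided(x, shape=(nb, L, C),
                          strides=(C * s, nb * C * s, s), writeable=False)

    A_view = batched_view(A, L_q)   # (nb, L_q, C), no copy
    B_view = batched_view(Bm, L_m)  # (nb, L_m, C), no copy
    # One batched GEMM over nb matrices: (L_q, C) x (C, L_m)
    out = A_view @ B_view.transpose(0, 2, 1)
    return out.reshape(batch, heads, L_q, L_m)
```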
   This is the same technique used in the `interleaved_matmul` operators, so we should be able to remove `interleaved_matmul` once einsum is accelerated.
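Assuming the same stride argument carries over to the second TNK einsum, `'bnij,jbnc->ibnc'`, the value operand and the output can both be addressed through strided views. The batched result is then written straight into the interleaved `(L_q, B, K * C_v)` buffer without a separate transpose. The sketch below illustrates that idea in NumPy. `value_einsum_via_strided_batched_gemm` is a hypothetical name, and this is not claimed to mirror how MXNet's `interleaved_matmul` operators are implemented.

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided


def value_einsum_via_strided_batched_gemm(attn, V):
    """Compute np.einsum('bnij,jbnc->ibnc', attn, V) with ONE batched matmul,
    storing the result directly in the interleaved TNK output buffer.

    attn: (batch, heads, L_q, L_m) attention weights
    V:    (L_m, batch, heads, C_v), TNK layout
    Returns (L_q, batch, heads * C_v).
    """
    attn = np.ascontiguousarray(attn)
    V = np.ascontiguousarray(V)
    batch, heads, L_q, L_m = attn.shape
    C_v = V.shape[-1]
    nb = batch * heads

    def batched_view(x, L, writeable):
        # Batch stride C_v and leading dimension nb * C_v (in elements)
        s = x.itemsize
        return as_strided(x, shape=(nb, L, C_v),
                          strides=(C_v * s, nb * C_v * s, s),
                          writeable=writeable)

    V_view = batched_view(V, L_m, writeable=False)      # (nb, L_m, C_v)
    out = np.empty((L_q, batch, heads, C_v), dtype=np.result_type(attn, V))
    out_view = batched_view(out, L_q, writeable=True)   # (nb, L_q, C_v), aliases out
    # One batched GEMM: (nb, L_q, L_m) @ (nb, L_m, C_v), written in place
    out_view[...] = attn.reshape(nb, L_q, L_m) @ V_view
    return out.reshape(L_q, batch, heads * C_v)
```

Because `out_view` aliases `out`, each head's `(L_q, C_v)` result lands at its interleaved position, which is the property that would let an accelerated einsum cover this case as well.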
   
   @ptrendx @eric-haibin-lin 

