The reason my barebones attention got a different answer than the paper's chunked attention was that I hadn't included the division by the square root of the feature count, a step I had intended to return to but never did. With that division included, the outputs are the same; the script is attached. I'm not sure why I skipped it in the first place.
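To convince myself the scaling was the whole difference, here is a minimal single-head sketch (names and shapes are mine, not the paper's) showing that toggling the 1/sqrt(features) factor alone changes the softmax output:

```python
import jax, jax.numpy as jnp

def plain_attention(q, k, v, scale=True):
    # single-head attention; `scale` toggles the 1/sqrt(features) factor
    logits = q @ k.T
    if scale:
        logits = logits / jnp.sqrt(k.shape[-1])
    return jax.nn.softmax(logits, axis=-1) @ v

q, k, v = jax.random.normal(jax.random.PRNGKey(0), (3, 8, 16))
scaled = plain_attention(q, k, v)
unscaled = plain_attention(q, k, v, scale=False)
print(jnp.abs(scaled - unscaled).max())  # nonzero: the scaling changes the result
```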
Next I'm comparing the output of huggingface's PerceiverSelfAttention class with my script and with the chunked attention. Its output is different, maybe due to an additional post-processing step. It does include the square-root denominator.
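I haven't confirmed what huggingface wraps around the attention core, but a common transformer pattern is a layer norm before the q/k/v projections and a dense output projection after; either alone would make the raw outputs differ from a bare attention call. A hypothetical sketch (the layout and names here are mine, not HF's API):

```python
import jax, jax.numpy as jnp

def layer_norm(x, eps=1e-12):
    # normalize the feature axis, as transformer blocks do before attention
    m = x.mean(axis=-1, keepdims=True)
    v = x.var(axis=-1, keepdims=True)
    return (x - m) / jnp.sqrt(v + eps)

def bare_attention(x):
    # self-attention with identity q/k/v projections, for illustration only
    logits = x @ x.T / jnp.sqrt(x.shape[-1])
    return jax.nn.softmax(logits, axis=-1) @ x

def wrapped_attention(x, w_out):
    # hypothetical wrapper: layer norm in front, dense projection behind
    return bare_attention(layer_norm(x)) @ w_out

x = jax.random.normal(jax.random.PRNGKey(0), (8, 16))
w_out = jax.random.normal(jax.random.PRNGKey(1), (16, 16)) * 0.1
print(jnp.abs(wrapped_attention(x, w_out) - bare_attention(x)).max())
```

If the difference comes from steps like these, loading the module's projection weights into the standalone script (or zeroing them out in the module) should make the two agree.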
```python
import jax, jax.numpy as jnp

def attention(query, key, value, precision=jax.lax.Precision.HIGHEST):
    """Memory-efficient multi-head dot product attention."""
    num_q, num_heads, q_features = query.shape
    # query is [queries, heads, features]
    # key is [keyvalues, heads, features]
    # value is [keyvalues, heads, features]

    # 1. weights = dot(query, key), scaled by 1/sqrt(features)
    attn_weights = jnp.einsum('qhf,khf->qhk',
                              query / jnp.sqrt(key.shape[-1]), key,
                              precision=precision)
    # weights shape is now [queries, heads, keys]

    # 2. softmax of the weights across the keys.
    # softmax can be calculated as exp(a) / sum(exp(a)),
    # where a has its maximum subtracted for numerical stability.
    # the example code takes the max across keys, since each
    # query and head is a separate distribution.
    max_weights = jnp.max(attn_weights, axis=-1, keepdims=True)
    exp_weights = jnp.exp(attn_weights - max_weights)
    exp_weights /= jnp.sum(exp_weights, axis=-1, keepdims=True)

    # 3. dot of the weights with the values, summing out the keyvalue
    # dim, a separate result for each query
    attn_out = jnp.einsum('qhk,khf->qhf', exp_weights, value,
                          precision=precision)
    return attn_out

if __name__ == '__main__':
    queries, keys, values = jax.random.normal(jax.random.PRNGKey(0),
                                              (3, 64, 8, 16))
    out = attention(queries, keys, values)
    print(out)
```
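The hand-rolled stable softmax in step 2 can be cross-checked against jax.nn.softmax on the same logits. A self-contained sketch (default precision rather than HIGHEST, which is fine for a float32 tolerance check):

```python
import jax, jax.numpy as jnp

def attention_manual(q, k, v):
    # stable softmax written out by hand, as in the script above
    w = jnp.einsum('qhf,khf->qhk', q / jnp.sqrt(k.shape[-1]), k)
    w = jnp.exp(w - jnp.max(w, axis=-1, keepdims=True))
    w = w / jnp.sum(w, axis=-1, keepdims=True)
    return jnp.einsum('qhk,khf->qhf', w, v)

def attention_reference(q, k, v):
    # same computation via jax.nn.softmax
    w = jnp.einsum('qhf,khf->qhk', q, k) / jnp.sqrt(k.shape[-1])
    return jnp.einsum('qhk,khf->qhf', jax.nn.softmax(w, axis=-1), v)

q, k, v = jax.random.normal(jax.random.PRNGKey(0), (3, 64, 8, 16))
print(jnp.abs(attention_manual(q, k, v) - attention_reference(q, k, v)).max())
```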