The reason my barebones attention got a different answer than the
paper's chunked attention was that I hadn't included the division by
the square root of the feature count, which I had intended to return
to but never did (I'm not sure why I left it out). With that scaling
included, the outputs are the same; the script is attached below.

Next I'm comparing the output of Hugging Face's PerceiverSelfAttention
class with my script and with the chunked attention. Its output is
different, maybe due to an additional post-processing step? It also
includes the square root denominator.
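
If it does turn out to be a post-processing step, the kind of thing I'd
expect is an extra dense projection applied after the heads are merged.
This is only a guess at what such a step might look like; the function
and weight names below are placeholders of mine, nothing taken from the
Hugging Face module:

def project_output(attn_out, w_out, b_out):
    # attn_out is [queries, heads, features]; a hypothetical output
    # projection merges the heads and applies a dense layer.
    # w_out: [heads * features, out_features], b_out: [out_features]
    num_q, num_heads, v_features = attn_out.shape
    merged = attn_out.reshape(num_q, num_heads * v_features)
    return merged @ w_out + b_out

The barebones script:
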
import jax, jax.numpy as jnp

def attention(query, key, value, precision=jax.lax.Precision.HIGHEST):
    """Memory-efficient multi-head dot product attention."""
    num_q, num_heads, q_features = query.shape

    # query is queries, heads, features
    # key is keyvalues, heads, features
    # value is keyvalues, heads, features

    # 1. weights = dot(query, key)
    attn_weights = jnp.einsum('qhf,khf->qhk', query / jnp.sqrt(key.shape[-1]), key, precision=precision)
    # weights shape is now [queries, heads, keys]

    # 2. softmax of the weights across the keys
    # softmax can be calculated as exp(a) / sum(exp(a))
    #   where a has its maximum value subtracted for numerical stability.
    #   the example code uses the max across keys, presumably because
    #   each query and head is normalized separately.
    max_weights = jnp.max(attn_weights, axis=-1, keepdims=True)
    exp_weights = jnp.exp(attn_weights - max_weights)
    exp_weights /= jnp.sum(exp_weights, axis=-1, keepdims=True)


    # 3. dot of the weights with the values, losing the keyvalue dim,
    #    a separate result for each query
    attn_out = jnp.einsum('qhk,khf->qhf', exp_weights, value, precision=precision)

    return attn_out

if __name__ == '__main__':
    # three arrays of shape [64 queries/keyvalues, 8 heads, 16 features]
    queries, keys, values = jax.random.normal(jax.random.PRNGKey(0), (3, 64, 8, 16))
    out = attention(queries, keys, values)
    print(out)
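
For reference, here's roughly what the chunked attention I'm comparing
against looks like. This is my own sketch of a chunked attention with
the usual max-rescaling trick, in the spirit of the paper's approach
rather than its actual code; the chunk size and the choice to chunk
only the keys/values (not the queries) are simplifications of mine:

def chunked_attention(query, key, value, chunk_size=16,
                      precision=jax.lax.Precision.HIGHEST):
    """Sketch: process keys/values in chunks, keeping a running softmax."""
    query = query / jnp.sqrt(key.shape[-1])
    num_q, num_heads, _ = query.shape

    # running accumulators: unnormalized output, softmax denominator, running max
    acc = jnp.zeros((num_q, num_heads, value.shape[-1]))
    denom = jnp.zeros((num_q, num_heads, 1))
    running_max = jnp.full((num_q, num_heads, 1), -jnp.inf)

    for start in range(0, key.shape[0], chunk_size):
        k = key[start:start + chunk_size]
        v = value[start:start + chunk_size]
        weights = jnp.einsum('qhf,khf->qhk', query, k, precision=precision)
        new_max = jnp.maximum(running_max, jnp.max(weights, axis=-1, keepdims=True))
        correction = jnp.exp(running_max - new_max)  # rescale earlier chunks
        exp_weights = jnp.exp(weights - new_max)
        acc = acc * correction + jnp.einsum('qhk,khf->qhf', exp_weights, v, precision=precision)
        denom = denom * correction + jnp.sum(exp_weights, axis=-1, keepdims=True)
        running_max = new_max

    return acc / denom

With the test arrays from the script, chunked_attention(queries, keys,
values) should agree with attention(queries, keys, values) to within
floating-point tolerance (e.g. jnp.allclose with a small atol).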
