[ot][spam][crazy][data] transformer model 'attention' improvement

k gmkarl at gmail.com
Tue Jan 25 11:18:19 PST 2022


30:  chunk_values, chunk_weights, chunk_max = jax.lax.map(
31:    chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size))

chunk_values is exp_values, which I think was the local values dotted
with the attention weights after their local max was subtracted and
they were exponentiated.

It looks like chunk_weights holds those exponentiated weights, summed
over each chunk's keys.

And it looks like chunk_max holds those local maximum values.
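
To pin down that reading, here's a minimal sketch of what I imagine
one chunk's summary computes -- the function name and shapes are my
guesses rather than the listing's code, and it skips multiple heads
and any 1/sqrt(d_k) scaling:

    import jax.numpy as jnp

    def summarize_chunk_sketch(query, key_chunk, value_chunk):
        # query: [num_q, d_k], key_chunk: [chunk, d_k],
        # value_chunk: [chunk, d_v]
        scores = query @ key_chunk.T                    # [num_q, chunk]
        local_max = scores.max(axis=-1, keepdims=True)  # per-query max here
        exp_weights = jnp.exp(scores - local_max)       # shift, then exp
        exp_values = exp_weights @ value_chunk          # [num_q, d_v]
        return (exp_values, exp_weights.sum(axis=-1),
                local_max.squeeze(-1))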

32:
33:  global_max = jnp.max(chunk_max, axis=0, keepdims=True)

[Struggling some to continue.  It looks like jax.lax.map stacks each
chunk's outputs along a new leading axis when it recombines them, so
each of these tensors now carries a chunk dimension at axis 0.
chunk_max has its maximum taken along that axis, which would find, for
each query, the overall maximum across all of the split keys and
values.]
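
To check that guess about the recombining: jax.lax.map stacks each
call's outputs along a new leading axis, so every returned tensor
picks up a chunk dimension at axis 0.  A toy example (my own stand-in
function, not the listing's):

    import jax
    import jax.numpy as jnp

    def fake_chunk(start):
        # stand-in for chunk_scanner: one per-query "local max" per chunk
        return start * jnp.ones((4,), jnp.float32)

    chunk_max = jax.lax.map(fake_chunk, xs=jnp.arange(0, 12, 4))
    print(chunk_max.shape)                            # (3, 4): chunk axis first
    print(jnp.max(chunk_max, axis=0, keepdims=True))  # [[8. 8. 8. 8.]]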

34:  max_diffs = jnp.exp(chunk_max - global_max)

This likely calculates the scale factor each chunk needs to become
relative to the global max rather than its own local max.  It comes
out as a scale because the difference sits inside an exponential.
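
A made-up numeric example of that scale: each chunk's factor is at
most 1, and the chunk that already held the global max gets exactly 1:

    import jax.numpy as jnp

    chunk_max = jnp.array([[1.0, 5.0],
                           [4.0, 2.0]])                     # [chunks, num_q]
    global_max = jnp.max(chunk_max, axis=0, keepdims=True)  # [[4., 5.]]
    max_diffs = jnp.exp(chunk_max - global_max)
    print(max_diffs)  # [[0.0498 1.    ]
                      #  [1.     0.0498]]  non-maximal chunks scaled down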

35:  chunk_values *= jnp.expand_dims(max_diffs, axis=-1)

jnp.expand_dims appears to wrap max_diffs in a new trailing size-one
dimension, presumably so the multiplication broadcasts across the
feature dimension of the values.

So line 35 multiplies all the output values by the calculated scale.
This looks like it turns exp(s_i - m_i) into exp(s_i - m_i + m_i -
global_max), i.e. exp(s_i - global_max).
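
A toy check of that identity, with shapes I'm guessing at (one chunk,
two queries, no heads):

    import jax.numpy as jnp

    scores = jnp.array([[2.0, 0.5],
                        [1.0, 3.0]])        # [num_q, chunk]
    values = jnp.array([[1.0, 2.0, 3.0],
                        [4.0, 5.0, 6.0]])   # [chunk, d_v]
    local_max = scores.max(axis=-1)         # [num_q]
    global_max = local_max + 4.0            # pretend another chunk went higher

    chunk_values = jnp.exp(scores - local_max[:, None]) @ values
    max_diffs = jnp.exp(local_max - global_max)
    rescaled = chunk_values * jnp.expand_dims(max_diffs, -1)  # line 35

    # same as shifting by the global max from the start
    direct = jnp.exp(scores - global_max[:, None]) @ values
    print(jnp.allclose(rescaled, direct))   # True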

36:  chunk_weights *= max_diffs

This likely performs the same rescaling.  The chunk_weights were not
dotted with the value vectors, so their shape already matches
max_diffs and no extra dimension is needed.
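
The matching toy check for the weights (single query, made-up
numbers): rescaling the locally shifted sum gives the globally shifted
sum, which is the piece the denominator will need later:

    import jax.numpy as jnp

    scores = jnp.array([2.0, 0.5, 1.5])  # one query's scores in one chunk
    local_max, global_max = 2.0, 6.0     # pretend another chunk reached 6.0

    chunk_weight = jnp.exp(scores - local_max).sum()
    rescaled = chunk_weight * jnp.exp(local_max - global_max)  # line 36
    print(jnp.allclose(rescaled,
                       jnp.exp(scores - global_max).sum()))    # True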

37:
38:  all_values = chunk_values.sum(axis=0)

Here we finally form the summation of the combined values across the
key and value chunks.  Within each chunk they were already dotted with
the values, so summing across chunks possibly combines them as if one
large dot product had been taken over all the keys.
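
A small check that summing per-chunk partial results matches one big
dot product over all the keys (made-up weights and values):

    import jax.numpy as jnp

    weights = jnp.array([0.1, 0.7, 0.2, 0.5])  # exp'd weights for 4 keys
    values = jnp.array([[1.0, 0.0],
                        [0.0, 1.0],
                        [2.0, 2.0],
                        [3.0, 1.0]])           # [num_kv, d_v]

    # two chunks of two keys each: dot each chunk, then sum the partials
    partials = jnp.stack([weights[:2] @ values[:2],
                          weights[2:] @ values[2:]])
    print(jnp.allclose(partials.sum(axis=0), weights @ values))  # True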

39:  all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0)

This appears to do the same for the weights, which I think were
dotted with the original value vectors to make chunk_values.
Additionally a trailing size-one dimension is added to wrap each
value, presumably so the later division can broadcast against the
feature dimension.
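
If that's right, the extra dimension is what lets the denominator
broadcast against the feature dimension in the division on line 40.  A
quick shape check with sizes I've made up:

    import jax.numpy as jnp

    all_values = jnp.ones((8, 4, 64))    # [num_q, heads, d_v]
    chunk_weights = jnp.ones((3, 8, 4))  # [chunks, num_q, heads]
    all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0)
    print(all_weights.shape)                 # (8, 4, 1)
    print((all_values / all_weights).shape)  # (8, 4, 64)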

40:  return all_values / all_weights

Here is where the softmax operation must finally complete.  all_values
and all_weights have been arithmetically shifted inside exp()
operations, to be relative to their maxima.  When the division is
performed, the shift acts like a common scaling value applied to both
the numerator and denominator, so it cancels out: the result is the
same, but computed without the extreme values that would otherwise
overflow or lose precision.  I think!
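
That cancellation is easy to demonstrate: subtracting a maximum inside
the exponentials changes nothing after the division, but it keeps
exp() from overflowing when the raw scores are large.  A tiny
standalone example, not the listing's code:

    import jax.numpy as jnp

    scores = jnp.array([300.0, 301.0, 302.0])  # overflows float32 exp()
    values = jnp.array([[1.0], [2.0], [3.0]])

    naive = (jnp.exp(scores) @ values) / jnp.exp(scores).sum()  # inf/inf
    shifted = jnp.exp(scores - scores.max())
    stable = (shifted @ values) / shifted.sum()  # the shift cancels out
    print(naive, stable)  # [nan] vs roughly [2.575]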

Whoohoo!

Let's quickly glance at where that value comes out after
_query_chunk_attention returns it.  On line 52, it's returned through
chunk_scanner, and then must be stacked with other query chunks back
on line 55:

54:  _, res = jax.lax.scan(
55:    chunk_scanner, init=0, xs=None, length=math.ceil(num_q / query_chunk_size))
56:  return res.reshape(num_q, num_heads, value.shape[-1])

The final call to reshape likely squashes that stack into a single
[queries, heads, features] output tensor.
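
A quick shape check of that guess, with sizes I've invented:
jax.lax.scan stacks the per-query-chunk results along a leading axis,
and the reshape folds that axis back into the query dimension:

    import jax.numpy as jnp

    num_q, query_chunk_size, num_heads, d_v = 8, 4, 2, 16
    res = jnp.ones((num_q // query_chunk_size, query_chunk_size,
                    num_heads, d_v))
    print(res.reshape(num_q, num_heads, d_v).shape)  # (8, 2, 16)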

Maybe rather than connecting more verbiage from the paper it would
make sense to try to run this now, and compare an implementation with
and without chunking, to verify it's correct.
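
As a sketch of that comparison -- assuming the listing's chunked
function is in scope as attention(query, key, value) and takes
[sequence, heads, features] arrays, which I haven't verified --
something like this could check it against an unchunked reference:

    import jax
    import jax.numpy as jnp

    def attention_reference(query, key, value):
        # plain unchunked attention; note no 1/sqrt(d_k) scaling here,
        # which may or may not match the listing
        scores = jnp.einsum('qhd,khd->qhk', query, key)
        weights = jax.nn.softmax(scores, axis=-1)
        return jnp.einsum('qhk,khd->qhd', weights, value)

    rng = jax.random.PRNGKey(0)
    q_rng, k_rng, v_rng = jax.random.split(rng, 3)
    query = jax.random.normal(q_rng, (256, 8, 64))
    key = jax.random.normal(k_rng, (256, 8, 64))
    value = jax.random.normal(v_rng, (256, 8, 64))

    expected = attention_reference(query, key, value)
    # chunked = attention(query, key, value)  # the listing's function
    # print(jnp.allclose(chunked, expected, atol=1e-4))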

