attention_u2

Attention layers.

Module Contents

Classes

ScaledDotProductAttentionU2

Calculate the attention weights.

MultiHeadAttentionU2

Multi-head attention consists of four parts:

class attention_u2.ScaledDotProductAttentionU2(unidirectional=False, look_ahead=0)

Bases: tensorflow.keras.layers.Layer

Calculate the attention weights. q, k, and v must have matching leading dimensions, and k and v must have matching penultimate dimensions, i.e. seq_len_k = seq_len_v. The mask has different shapes depending on its type (padding or look-ahead), but it must be broadcastable for addition.

Parameters
  • q – query shape == (…, seq_len_q, depth)

  • k – key shape == (…, seq_len_k, depth)

  • v – value shape == (…, seq_len_v, depth_v)

  • mask – Float tensor with shape broadcastable to (…, seq_len_q, seq_len_k). Defaults to None.

Returns

output, attention_weights

call(q, k, v, mask)

Compute scaled dot-product attention over q, k, and v, returning output and attention_weights.
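
A minimal sketch of the computation performed in call, assuming the standard scaled dot-product attention formulation; the handling of the unidirectional and look_ahead options is not shown here:

    import tensorflow as tf

    # Illustrative only: plain scaled dot-product attention, not necessarily
    # the exact implementation of ScaledDotProductAttentionU2.
    def scaled_dot_product_attention(q, k, v, mask=None):
        # (..., seq_len_q, seq_len_k)
        matmul_qk = tf.matmul(q, k, transpose_b=True)

        # Scale by sqrt(depth) to keep the softmax logits in a stable range.
        dk = tf.cast(tf.shape(k)[-1], tf.float32)
        scaled_logits = matmul_qk / tf.math.sqrt(dk)

        # Masked positions receive a large negative logit and thus ~0 weight.
        if mask is not None:
            scaled_logits += (mask * -1e9)

        # Softmax over the key axis; weights along seq_len_k sum to 1.
        attention_weights = tf.nn.softmax(scaled_logits, axis=-1)

        # (..., seq_len_q, depth_v)
        output = tf.matmul(attention_weights, v)
        return output, attention_weights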

class attention_u2.MultiHeadAttentionU2(d_model, num_heads, unidirectional=False, look_ahead=0)

Bases: tensorflow.keras.layers.Layer

Multi-head attention consists of four parts:

  • Linear layers and split into heads.

  • Scaled dot-product attention.

  • Concatenation of heads.

  • Final linear layer.

Each multi-head attention block gets three inputs: Q (query), K (key), and V (value). These are put through linear (Dense) layers and split into multiple heads. The scaled dot-product attention defined above is applied to each head (broadcast together for efficiency). An appropriate mask must be used in the attention step. The attention output for each head is then concatenated (using tf.transpose and tf.reshape) and put through a final Dense layer.

Instead of using a single attention head, Q, K, and V are split into multiple heads because this allows the model to jointly attend to information at different positions from different representational spaces. After the split, each head has a reduced dimensionality, so the total computation cost is the same as that of single-head attention with full dimensionality.

split_heads(x, batch_size)

Split the last dimension into (num_heads, depth).

Transpose the result such that the shape is (batch_size, num_heads, seq_len, depth).
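
A minimal sketch of this reshape/transpose, assuming depth == d_model // num_heads:

    # Hypothetical standalone helper mirroring split_heads; num_heads and
    # depth are assumed attributes of the layer.
    def split_heads(x, batch_size, num_heads, depth):
        # (batch_size, seq_len, num_heads, depth)
        x = tf.reshape(x, (batch_size, -1, num_heads, depth))
        # (batch_size, num_heads, seq_len, depth)
        return tf.transpose(x, perm=[0, 2, 1, 3])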

call(v, k, q, mask)

Apply multi-head attention to the value, key, and query inputs.
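
A hypothetical usage sketch for self-attention; the shapes follow the descriptions above, and the return value of (output, attention_weights) is an assumption mirroring the scaled dot-product layer:

    import tensorflow as tf
    from attention_u2 import MultiHeadAttentionU2

    mha = MultiHeadAttentionU2(d_model=512, num_heads=8)
    y = tf.random.uniform((1, 60, 512))      # (batch_size, seq_len, d_model)

    # Self-attention: q, k, and v are the same tensor; no mask.
    out, attn = mha(y, k=y, q=y, mask=None)
    # out:  (1, 60, 512)    -- (batch_size, seq_len_q, d_model)
    # attn: (1, 8, 60, 60)  -- (batch_size, num_heads, seq_len_q, seq_len_k)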