Position Encodings in Attention: From Absolute to Rotary
Why We Need Position Encodings
The attention operation is permutation equivariant: if you shuffle the input tokens, the outputs are simply shuffled the same way, so attention by itself doesn't care about the order of the input tokens.
Let's see an example of this:
import torch

def basic_attention(query, key, value):
    # Compute the scaled dot-product attention scores
    attention_scores = torch.matmul(query, key.transpose(-2, -1))
    attention_scores = attention_scores / torch.sqrt(torch.tensor(query.size(-1), dtype=torch.float32))
    attention_weights = torch.softmax(attention_scores, dim=-1)
    # Weighted sum of the values
    output = torch.matmul(attention_weights, value)
    return output
Now, say we had two sentences:
# These two sequences clearly mean different things
seq1 = ["I", "bought", "an", "apple", "watch"]
seq2 = ["watch", "I", "bought", "an", "apple"]
and we grabbed embeddings for each sentence.
embeddings1 = get_embeddings(seq1)
embeddings2 = get_embeddings(seq2)
Let's say we want to compute the attention scores between the word apple and the rest of the sequence.
apple_emb1 = embeddings1[3]
apple_emb2 = embeddings2[4]
# Compute the attention scores using apple as the query embedding
atten1 = basic_attention(apple_emb1, embeddings1, embeddings1)
atten2 = basic_attention(apple_emb2, embeddings2, embeddings2)
Finally, the two attention outputs come out identical, so attention alone cannot tell the two sentences apart:
torch.allclose(atten1, atten2)  # True
This is a problem because we want our model to understand the nuances of the position of words and how they change the meaning of a sentence. To overcome this issue, we need to add positional encodings to our embeddings.
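To make this concrete, here is a self-contained version of the experiment above; since get_embeddings isn't defined here, a random per-word lookup stands in for a real embedding table:

import torch

torch.manual_seed(0)
d_model = 16
vocab = {word: torch.randn(d_model) for word in ["I", "bought", "an", "apple", "watch"]}

def get_embeddings(seq):
    # Stand-in embedding lookup: one fixed random vector per word
    return torch.stack([vocab[word] for word in seq])

emb1 = get_embeddings(["I", "bought", "an", "apple", "watch"])
emb2 = get_embeddings(["watch", "I", "bought", "an", "apple"])

atten1 = basic_attention(emb1[3], emb1, emb1)   # "apple" as the query in sentence 1
atten2 = basic_attention(emb2[4], emb2, emb2)   # "apple" as the query in sentence 2
print(torch.allclose(atten1, atten2))           # True: word order made no difference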
Absolute Position Encoding in Attention
Absolute position encodings were first introduced in the original Transformer paper ("Attention Is All You Need"). They work by adding fixed positional information to each token embedding before it enters the attention layers.
The most common implementation uses sinusoidal functions:
import math
import torch
import torch.nn as nn

class AbsolutePositionAttention(nn.Module):
    def __init__(self, d_model, n_heads, max_seq_len=1000):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        # Position encoding with configurable max length
        self.max_seq_len = max_seq_len
        self.pos_encoding = nn.Parameter(
            torch.zeros(max_seq_len, d_model),
            requires_grad=False  # Fixed encodings, no learning
        )
        self._init_pos_encoding()
        # Attention projections
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)

    def _init_pos_encoding(self):
        position = torch.arange(self.max_seq_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2) *
            (-math.log(10000.0) / self.d_model)
        )
        # Create sinusoidal pattern, written into the fixed table in place
        self.pos_encoding[:, 0::2] = torch.sin(position * div_term)
        self.pos_encoding[:, 1::2] = torch.cos(position * div_term)

    def forward(self, x, mask=None):
        B, L, D = x.shape
        # Ensure sequence length doesn't exceed maximum
        assert L <= self.max_seq_len, f"Sequence length {L} exceeds maximum {self.max_seq_len}"
        # Add position encodings; (L, D) broadcasts over the batch dimension
        x = x + self.pos_encoding[:L]
        # Project to Q, K, V and split into heads: (B, n_heads, L, d_head)
        q = self.q_proj(x).view(B, L, self.n_heads, self.d_head).transpose(1, 2)
        k = self.k_proj(x).view(B, L, self.n_heads, self.d_head).transpose(1, 2)
        v = self.v_proj(x).view(B, L, self.n_heads, self.d_head).transpose(1, 2)
        # Attention
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_head)
        attn = torch.softmax(scores, dim=-1)
        out = torch.matmul(attn, v)
        # Merge heads back to (B, L, d_model)
        return out.transpose(1, 2).reshape(B, L, D)
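As a quick sanity check of the layer above (sizes are illustrative):

attn = AbsolutePositionAttention(d_model=512, n_heads=8, max_seq_len=1000)
x = torch.randn(2, 16, 512)   # (batch, seq_len, d_model)
print(attn(x).shape)          # torch.Size([2, 16, 512])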
Why Sinusoidal?
The sinusoidal position encoding has several important properties:
- Unique Position Representation: Each position gets a unique encoding vector
- Linear Relative Offsets: For any fixed offset, the encoding at position pos + k is a linear transformation (a rotation of each sine/cosine pair) of the encoding at position pos, which makes relative positions easy to recover in principle (see the quick check after this list)
- Pattern Continuity: The sinusoidal pattern continues smoothly across positions, making it theoretically possible to extend to longer sequences
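To make the second property concrete, here is a quick numerical check (a standalone sketch, not part of the attention class) that for a fixed offset each sine/cosine pair of PE(pos + offset) is a rotation of the corresponding pair of PE(pos):

import math
import torch

def sinusoidal_pe(max_len, d_model):
    # Same construction as _init_pos_encoding, returned along with the frequencies
    position = torch.arange(max_len).unsqueeze(1).float()
    theta = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(position * theta)
    pe[:, 1::2] = torch.cos(position * theta)
    return pe, theta

d_model, offset = 8, 5
pe, theta = sinusoidal_pe(64, d_model)
angle = offset * theta   # one rotation angle per frequency, depending only on the offset

for pos in (0, 10, 37):
    sin_p, cos_p = pe[pos, 0::2], pe[pos, 1::2]
    # Rotate each (sin, cos) pair by `angle` and compare with the encoding at pos + offset
    assert torch.allclose(sin_p * torch.cos(angle) + cos_p * torch.sin(angle),
                          pe[pos + offset, 0::2], atol=1e-5)
    assert torch.allclose(cos_p * torch.cos(angle) - sin_p * torch.sin(angle),
                          pe[pos + offset, 1::2], atol=1e-5)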
However, there's an important caveat about extrapolation:
class AbsolutePositionAttention(nn.Module):
    def __init__(self, d_model, n_heads, max_seq_len=1000):
        # ... initialization code ...
        # Once initialized, we're locked to this maximum length
        self.pos_encoding = nn.Parameter(
            torch.zeros(max_seq_len, d_model),
            requires_grad=False
        )
While the sinusoidal pattern itself could mathematically extend indefinitely, the implementation requires choosing a fixed maximum length at initialization. To handle longer sequences, you would need to either:
- Reinitialize the model with a larger max_seq_len
- Implement dynamic position encoding computation (at the cost of efficiency):
def compute_position_encoding(position, d_model):
    # Compute a single position's encoding on the fly instead of storing a table
    div_term = torch.exp(
        torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
    )
    pe = torch.zeros(d_model)
    pe[0::2] = torch.sin(position * div_term)
    pe[1::2] = torch.cos(position * div_term)
    return pe
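For example, calling the on-the-fly version (with an illustrative d_model of 512) works for positions far beyond any pre-allocated table:

pe = compute_position_encoding(position=5000, d_model=512)
print(pe.shape)   # torch.Size([512])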
This limitation is one of the reasons why newer position encoding methods like RoPE (Rotary Position Embedding) have gained popularity, as they can more naturally handle variable sequence lengths.
Implementation Details
Here's how the position encodings are initialized:
def _init_pos_encoding(self):
    position = torch.arange(self.max_seq_len).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, self.d_model, 2) *
        (-math.log(10000.0) / self.d_model)
    )
    # Create sinusoidal pattern, written into the fixed table in place
    # (no explicit batch dimension is needed: (L, d_model) broadcasts over the batch in forward)
    self.pos_encoding[:, 0::2] = torch.sin(position * div_term)
    self.pos_encoding[:, 1::2] = torch.cos(position * div_term)
Limitations
- Fixed Maximum Length: The model has a hard limit on sequence length, requiring careful initialization:
# Must be set during initialization
max_seq_len = 1000 # Common default
pos_encoding = nn.Parameter(torch.zeros(max_seq_len, d_model))
- Position Information Decay: As signals pass through multiple attention layers, the absolute position information can become diluted. This is particularly problematic in deep networks:
# Position information weakens through layers
layer1_out = attention1(x + pos_encoding)
layer2_out = attention2(layer1_out) # Original position signal is weaker
layer3_out = attention3(layer2_out) # Even weaker...
- No Explicit Relative Positioning: The model must learn to compute relative positions indirectly. Consider these two sequences:
seq1 = ["The", "cat", "sat"] # Position encodings: [0, 1, 2]
seq2 = ["Yesterday", "the", "cat", "sat"] # Position encodings: [0, 1, 2, 3]
# "cat" and "sat" have different absolute positions in seq1 vs seq2
# but their relative position (adjacent) is the same
- Memory Inefficiency: Requires storing the full position encoding matrix, which scales linearly with maximum sequence length:
# Memory usage grows linearly with the maximum sequence length
memory_size = max_seq_len * d_model * 4   # bytes for float32, e.g. 1000 * 512 * 4 bytes ≈ 2 MB
- Discrete Positions: Cannot easily handle continuous positions or irregular spacing, making it unsuitable for tasks like:
- Time series with irregular sampling
- Spatial positions in 3D space
- Document layouts with hierarchical structure
Transitioning to Relative Position
The limitations of absolute position encodings point to a fundamental issue: they focus on where tokens are in a sequence rather than how tokens relate to each other. Consider these common scenarios:
# Scenario 1: Same meaning, different positions
seq1 = ["The", "cat", "sat"]
seq2 = ["Yesterday", "the", "cat", "sat"]
# Scenario 2: Translation
en = ["I", "love", "cats"] # Positions: 0, 1, 2
fr = ["J'", "aime", "les", "chats"] # Positions: 0, 1, 2, 3
# Scenario 3: Long-range dependencies
text = ["Although", ...<20 words>..., "therefore"] # Related words far apart
In all these cases, absolute positions fail to capture what we really care about: the relationships between tokens. This insight led to the development of relative position encodings, which offer a few key advantages:
- Dynamic Relationships: Instead of asking "what position is this token at?", relative position asks "how far is this token from others?" This better captures linguistic structure.
- Length Generalization: Models can handle sequences longer than those seen during training because relationships are position-agnostic:
# The relationship between "cat" and "sat" is always "next to"
short = ["The", "cat", "sat"]
long = ["Yesterday", "when", "I", "was", "walking", "the", "cat", "sat"]
In the next section, we'll explore how relative position encodings implement these ideas, addressing the limitations we've discussed while introducing some powerful new capabilities...
Understanding Relative Position Encodings
Relative position encodings shift how we think about position in sequences. Instead of asking "what position is this token at?", they ask "how far is this token from other tokens?"
Let's see why this matters with some examples:
# Example 1: Translation between languages
en = ["The", "cat", "sat"] # Positions: 0, 1, 2
fr = ["Le", "chat", "s'assit"] # Positions: 0, 1, 2
# The relationship between words remains the same despite different absolute positions
# "cat" is one token away from "sat" in English
# "chat" is one token away from "s'assit" in French
# Example 2: Same phrase in different contexts
text1 = ["The", "big", "dog", "chased", "the", "cat"]
text2 = ["Yesterday,", "the", "big", "dog", "chased", "the", "cat"]
# The relationship between "dog" and "chased" remains the same
# regardless of where the phrase appears in the sentence
Implementation Deep Dive
Let's break down the key components of relative position attention:
def _get_relative_positions(self, length):
    # Create a matrix of relative distances between all positions
    range_vec = torch.arange(length)
    range_mat = range_vec.unsqueeze(0) - range_vec.unsqueeze(1)
    # Clamp to our maximum relative distance
    range_mat = torch.clamp(
        range_mat,
        -self.max_relative_position,
        self.max_relative_position
    )
    # Shift into [0, 2 * max_relative_position] so it can index an embedding table
    return range_mat + self.max_relative_position
This creates a matrix whose entry (i, j) is the clamped, shifted offset between query position i and key position j (the raw offset is j - i). For example, with length=4:
# Raw relative position matrix (before clamping and shifting)
# [[ 0,  1,  2,  3],
#  [-1,  0,  1,  2],
#  [-2, -1,  0,  1],
#  [-3, -2, -1,  0]]
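To see the clamping and shifting in action, here is a standalone version of the helper (using an illustrative max_relative_position of 2 so the clamping is visible at length 4):

import torch

def get_relative_positions(length, max_relative_position):
    range_vec = torch.arange(length)
    range_mat = range_vec.unsqueeze(0) - range_vec.unsqueeze(1)
    range_mat = torch.clamp(range_mat, -max_relative_position, max_relative_position)
    return range_mat + max_relative_position   # shift into [0, 2 * max_relative_position]

print(get_relative_positions(4, max_relative_position=2))
# tensor([[2, 3, 4, 4],
#         [1, 2, 3, 4],
#         [0, 1, 2, 3],
#         [0, 0, 1, 2]])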
The attention computation then combines content-based and position-based scores:
# Content-based attention (standard dot product); q and k have shape (B, n_heads, L, d_head)
content_scores = torch.matmul(q, k.transpose(-2, -1))               # (B, n_heads, L, L)
# Position-based attention: look up one embedding per (query, key) relative offset
rel_pos_embeddings = self.rel_embeddings[rel_pos_indices]           # (L, L, d_head)
relative_scores = torch.matmul(
    q.permute(0, 2, 1, 3),                                          # (B, L, n_heads, d_head)
    rel_pos_embeddings.transpose(-2, -1)                            # (L, d_head, L)
).permute(0, 2, 1, 3)                                               # back to (B, n_heads, L, L)
# Combine both signals
scores = (content_scores + relative_scores) / math.sqrt(self.d_head)
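Putting these pieces together, here is a minimal, self-contained sketch of a relative-position attention layer (class and parameter names are illustrative; the position term uses einsum rather than the permute/matmul trick above, purely for readability):

import math
import torch
import torch.nn as nn

class RelativePositionAttention(nn.Module):
    def __init__(self, d_model, n_heads, max_relative_position=32):
        super().__init__()
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.max_relative_position = max_relative_position
        # One learned embedding per clamped relative offset in [-max, +max]
        self.rel_embeddings = nn.Parameter(
            torch.randn(2 * max_relative_position + 1, self.d_head) * 0.02
        )
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)

    def _get_relative_positions(self, length):
        range_vec = torch.arange(length)
        range_mat = range_vec.unsqueeze(0) - range_vec.unsqueeze(1)
        range_mat = torch.clamp(range_mat, -self.max_relative_position, self.max_relative_position)
        return range_mat + self.max_relative_position

    def forward(self, x):
        B, L, D = x.shape
        q = self.q_proj(x).view(B, L, self.n_heads, self.d_head).transpose(1, 2)
        k = self.k_proj(x).view(B, L, self.n_heads, self.d_head).transpose(1, 2)
        v = self.v_proj(x).view(B, L, self.n_heads, self.d_head).transpose(1, 2)

        content_scores = torch.matmul(q, k.transpose(-2, -1))          # (B, H, L, L)
        rel = self.rel_embeddings[self._get_relative_positions(L)]     # (L, L, d_head)
        relative_scores = torch.einsum("bhld,lmd->bhlm", q, rel)       # (B, H, L, L)

        scores = (content_scores + relative_scores) / math.sqrt(self.d_head)
        attn = torch.softmax(scores, dim=-1)
        out = torch.matmul(attn, v)                                    # (B, H, L, d_head)
        return out.transpose(1, 2).reshape(B, L, D)

x = torch.randn(2, 10, 64)
print(RelativePositionAttention(d_model=64, n_heads=4)(x).shape)   # torch.Size([2, 10, 64])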
Benefits for LLM Training
- Better Generalization
# Model can learn patterns independent of absolute position
pattern1 = ["despite", "X", "however", "Y"] # At start of text
pattern2 = ["Z", "despite", "X", "however", "Y"] # In middle of text
# Same relative relationships preserved
- Handling Variable-Length Input
# No fixed maximum sequence length
# Relationships are defined by relative distance, not absolute position
max_relative_position = 32 # Only care about distances up to ±32 tokens
- Translation Invariance
# Same patterns can be recognized regardless of position
text1 = ["The", "cat", "sat"]
text2 = ["Yesterday,", "the", "cat", "sat"]
# Model learns "sat" follows "cat" regardless of starting position
Limitations
Relative position encodings, while theoretically powerful, come with several significant practical limitations.
They are slow!
- Computational Overhead
- Computing the relative position matrix requires O(n²) operations for sequence length n
- Each attention layer needs to perform these calculations independently
# For each attention layer:
rel_pos = torch.arange(length).unsqueeze(0) - torch.arange(length).unsqueeze(1)
# Shape: [length, length] - Quadratic memory growth
- KV Cache Complications
- The KV cache optimization becomes less effective because relative positions change as the sequence grows
- Each new token requires recomputing positions relative to all previous tokens
# With absolute positions, cached keys/values stay fixed
cached_k = previous_k_with_absolute_positions
new_k = compute_new_k_with_absolute_positions()
# With relative positions, the position-dependent term must be revisited for each new token
for i in range(prev_length, new_length):
    relative_pos = compute_relative_positions(i, range(new_length))
    k_with_relative = update_cached_k_with_new_positions(cached_k, relative_pos)
- Memory Intensity
- Storing relative position embeddings for all possible token pairs requires significant memory
- Memory usage grows quadratically with sequence length
# Memory required for relative position embeddings
memory_size = max_seq_len * max_seq_len * d_model * 4   # bytes for float32
# Example: 1000 tokens * 1000 tokens * 512 dims * 4 bytes = 2GB
- Training Instability
- The large number of learned relative position embeddings can lead to optimization challenges
- Models may struggle to learn consistent position representations across different sequence lengths
- Implementation Complexity
- Requires careful attention to numerical stability and edge cases
- More complex to implement efficiently compared to absolute or rotary encodings
# Need to handle edge cases like padding and masking
rel_pos_masked = rel_pos.masked_fill(
    mask.unsqueeze(1).expand(-1, length, -1),
    0
)
These limitations help explain why many modern architectures have moved towards alternatives like RoPE, which combine the strengths of both absolute and relative position encodings.
Rotary Position Encodings (RoPE)
We made it!
Rotary Position Embedding (RoPE) advances position encoding by rotating vectors in complex space. Instead of adding position information to the embeddings or learning relative embeddings, RoPE applies a position-dependent rotation to the query and key vectors that elegantly preserves their relative positions through the attention operation.
How RoPE Works
The core idea is to encode position by rotating vectors in complex space. For each position m, RoPE applies a rotation matrix R(m) to the query and key vectors:
def _get_rotation_matrix(self, seq_len):
    # Generate frequency bands for different dimensions
    theta = 1.0 / (self.base ** (torch.arange(0, self.d_head, 2).float() / self.d_head))
    # Position indices
    pos = torch.arange(seq_len).float()
    # Compute rotation angles: one per (position, frequency) pair -> (seq_len, d_head / 2)
    freqs = torch.outer(pos, theta)
    # Duplicate along the last dim so the tables broadcast against the full head dimension
    # (dimension i is paired with dimension i + d_head / 2 and shares the same angle)
    freqs_cos = torch.cos(freqs).repeat(1, 2)  # Real component, (seq_len, d_head)
    freqs_sin = torch.sin(freqs).repeat(1, 2)  # Imaginary component, (seq_len, d_head)
    return freqs_cos, freqs_sin
The rotation is applied through complex multiplication, which is implemented by splitting vectors in half and applying trigonometric functions:
def _apply_rotary_pos_emb(self, x, freqs_cos, freqs_sin):
    # Complex multiplication: (a + bi)(cos θ + i sin θ), applied pair-wise to x
    return (x * freqs_cos) + (self._rotate_half(x) * freqs_sin)
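The _rotate_half helper referenced above isn't shown in the snippet; a common definition (matching the pairing of dimension i with i + d_head/2 used here) is a small sketch like this:

def _rotate_half(self, x):
    # (x1, x2) -> (-x2, x1): the "multiply by i" half of the complex rotation
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)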
RoPE's Benefits
- Relative Position Preservation: The dot product between two rotated vectors naturally encodes their relative position (verified numerically in the sketch after this list):
# For positions m and n:
# <R(m)q, R(n)k> = <q, R(n - m)k>
# The attention score depends on q, k, and only the relative offset (m - n)
- Efficient Computation: Unlike relative position encodings, RoPE doesn't require explicit pairwise computations:
# Apply rotation once per position
q_rot = self._apply_rotary_pos_emb(q, freqs_cos, freqs_sin)
k_rot = self._apply_rotary_pos_emb(k, freqs_cos, freqs_sin)
# Standard attention computation
scores = torch.matmul(q_rot, k_rot.transpose(-2, -1))
- Better Length Extrapolation: RoPE's mathematical properties allow it to generalize better to sequences longer than those seen during training.
- KV Cache Friendly: Since rotations are position-dependent but content-independent, they work well with KV caching at inference time:
# Can cache k_rot, v directly
cached_k_rot = self._apply_rotary_pos_emb(k, freqs_cos[:cached_len], freqs_sin[:cached_len])
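To tie these pieces together, here is a small, self-contained sketch (function names are illustrative, not from a library) that applies the rotation to standalone vectors and checks numerically that the score between a query at position m and a key at position n depends only on the offset m - n:

import torch

def rope_freqs(seq_len, d_head, base=10000.0):
    # Same frequency schedule as _get_rotation_matrix, duplicated to the full head dim
    theta = 1.0 / (base ** (torch.arange(0, d_head, 2).float() / d_head))
    freqs = torch.outer(torch.arange(seq_len).float(), theta)   # (seq_len, d_head / 2)
    return torch.cos(freqs).repeat(1, 2), torch.sin(freqs).repeat(1, 2)

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(x, cos, sin):
    return (x * cos) + (rotate_half(x) * sin)

d_head, seq_len = 64, 128
cos, sin = rope_freqs(seq_len, d_head)
q, k = torch.randn(d_head), torch.randn(d_head)

# Same offset (m - n = 6) at two very different absolute positions
score_a = apply_rope(q, cos[10], sin[10]) @ apply_rope(k, cos[4], sin[4])
score_b = apply_rope(q, cos[90], sin[90]) @ apply_rope(k, cos[84], sin[84])
print(torch.allclose(score_a, score_b, atol=1e-4))   # True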
This combination of theoretical elegance and practical benefits has made RoPE increasingly popular in modern transformer architectures.
Recap and Conclusion
Let's compare the three main approaches to position encoding in attention mechanisms:
Method | Key Idea | Advantages | Limitations |
---|---|---|---|
Absolute | Add a fixed sinusoidal encoding of each token's position to its embedding | Simple and cheap; no changes inside the attention computation | Hard maximum length; position signal weakens through deep stacks; no explicit relative information |
Relative | Learn embeddings for the (clamped) distance between every pair of tokens | Captures token-to-token relationships; generalizes better across sequence lengths | Quadratic compute and memory; complicates KV caching; harder to train and implement |
RoPE | Rotate query and key vectors by a position-dependent angle so attention scores depend only on relative offsets | Efficient; KV-cache friendly; better length extrapolation | Applied to Q and K inside every attention layer rather than added once at the input |
The evolution from absolute to relative to rotary position encodings reflects our growing understanding of how transformers process sequential information:
- Absolute encodings were a first attempt to inject position information, but their rigid nature limited model capabilities.
- Relative encodings recognized that relationships between tokens matter more than absolute positions, but came with significant computational costs.
- RoPE emerged as an elegant solution that combines the best of both worlds: the efficiency of absolute encodings with the relationship-preserving properties of relative encodings.
RoPE has become the dominant choice in modern LLMs for several compelling reasons:
- Inference Efficiency: Its compatibility with KV caching makes it particularly well-suited for deployment in production systems.
- Length Generalization: Models can handle sequences longer than their training length more reliably.
- Mathematical Elegance: The rotation-based approach provides a theoretically sound way to encode position information that naturally preserves relative positions through attention operations.
As we continue to scale language models and push for better performance, RoPE's balanced approach to handling positional information remains a crucial component in state-of-the-art architectures.