Extend context windows of transformer models using RoPE, YaRN, ALiBi, and position interpolation techniques. Use when processing long documents (32k-128k+ tokens), extending pre-trained models beyond original context limits, or implementing efficient positional encodings. Covers rotary embeddings, attention biases, interpolation methods, and extrapolation strategies for LLMs.
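Use Long Context techniques when you need to process long documents (32k-128k+ tokens), extend a pre-trained model beyond its original context limit, or implement efficient positional encodings.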
Key Techniques: RoPE (Rotary Position Embeddings), YaRN, ALiBi (Attention with Linear Biases), Position Interpolation
Papers: RoFormer (arXiv 2104.09864), YaRN (arXiv 2309.00071), ALiBi (arXiv 2108.12409), Position Interpolation (arXiv 2306.15595)
# HuggingFace Transformers (includes RoPE, YaRN support)
pip install transformers torch
# For custom implementations
pip install einops # Tensor operations
pip install rotary-embedding-torch # Standalone RoPE
# Optional: FlashAttention for efficiency.
# --no-build-isolation makes pip build flash-attn against the torch installed
# above instead of inside a fresh, isolated build environment.
pip install flash-attn --no-build-isolation
# (The paged-attention serving example further down also imports vllm,
# which is not installed by the commands above.)
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
class RotaryEmbedding(nn.Module):
    """Rotary Position Embeddings (RoPE).

    Precomputes the inverse frequencies ``1 / base**(2j / dim)`` for
    ``j = 0 .. dim/2 - 1`` and, on each call, returns the ``(cos, sin)``
    tables that ``apply_rotary_pos_emb`` multiplies into queries and keys.

    Args:
        dim: Per-head feature dimension to rotate. Must be a positive even
            integer because features are rotated in pairs.
        max_seq_len: Nominal maximum sequence length. Stored on the module
            for reference; ``forward`` accepts any ``seq_len``.
        base: RoPE base (often called ``theta``).
        scaling_factor: Linear position-interpolation factor. Positions are
            divided by this value before computing angles, so e.g. ``4.0``
            maps a 4x longer sequence onto the original position range.
            ``1.0`` (default) is plain RoPE.

    Raises:
        ValueError: If ``dim`` is not a positive even integer or
            ``scaling_factor`` is not positive.
    """

    def __init__(self, dim, max_seq_len=8192, base=10000, scaling_factor=1.0):
        super().__init__()
        if dim <= 0 or dim % 2 != 0:
            # An odd dim would yield dim + 1 features after the cat() below and
            # fail later with a confusing broadcast error.
            raise ValueError(f"dim must be a positive even integer, got {dim}")
        if scaling_factor <= 0:
            raise ValueError(f"scaling_factor must be positive, got {scaling_factor}")
        # Compute inverse frequencies
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.max_seq_len = max_seq_len
        self.scaling_factor = scaling_factor

    def forward(self, seq_len, device=None, dtype=None):
        """Return ``(cos, sin)``, each of shape ``(seq_len, dim)``.

        Args:
            seq_len: Number of positions (0 .. seq_len - 1) to generate.
            device: Device for the returned tables. Defaults to the device of
                the ``inv_freq`` buffer. The buffer is copied to this device for
                the computation, so a module left on CPU can still produce
                tables on another device.
            dtype: dtype of the returned tables. Defaults to the buffer's dtype
                (float32 unless the module was cast, e.g. with ``.half()``).
        """
        # Always compute angles in float32. Half-precision formats cannot
        # represent integer positions above 2048 (fp16) or 256 (bf16) exactly,
        # which would corrupt long-context positions.
        inv_freq = self.inv_freq.float()
        if device is not None:
            inv_freq = inv_freq.to(device)
        out_dtype = self.inv_freq.dtype if dtype is None else dtype

        # Position indices, optionally interpolated (t / scaling_factor)
        t = torch.arange(seq_len, device=inv_freq.device, dtype=torch.float32)
        t = t / self.scaling_factor
        freqs = torch.outer(t, inv_freq)  # (seq_len, dim/2)
        # Duplicate so column i and column i + dim/2 share an angle, matching
        # the half-split pairing used by rotate_half.
        emb = torch.cat((freqs, freqs), dim=-1)  # (seq_len, dim)
        return emb.cos().to(out_dtype), emb.sin().to(out_dtype)
def rotate_half(x):
    """Rotate the two halves of the last dimension: ``[x1, x2] -> [-x2, x1]``.

    Paired with cos/sin tables laid out as ``cat((freqs, freqs))``, this pairs
    feature ``i`` with feature ``i + dim/2`` so that
    ``x * cos + rotate_half(x) * sin`` rotates every pair by its angle.
    """
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin):
    """Apply rotary position embeddings to queries and keys.

    Args:
        q, k: Query/key tensors, e.g. ``(batch, heads, seq_len, head_dim)``.
        cos, sin: Tables broadcastable against ``q``/``k`` (for example the
            ``(seq_len, head_dim)`` pair produced by ``RotaryEmbedding``).

    Returns:
        ``(q_embed, k_embed)`` with the same shapes, devices and dtypes as
        ``q`` and ``k``.
    """
    # Cast the tables to each input's dtype/device. Otherwise fp16/bf16 q, k
    # are silently promoted to float32 by fp32 tables, and attention later
    # fails because v is still half precision.
    q_cos = cos.to(device=q.device, dtype=q.dtype)
    q_sin = sin.to(device=q.device, dtype=q.dtype)
    k_cos = cos.to(device=k.device, dtype=k.dtype)
    k_sin = sin.to(device=k.device, dtype=k.dtype)
    q_embed = (q * q_cos) + (rotate_half(q) * q_sin)
    k_embed = (k * k_cos) + (rotate_half(k) * k_sin)
    return q_embed, k_embed
# Usage
device = "cuda" if torch.cuda.is_available() else "cpu"
# Move the module so its inv_freq buffer lives on the same device as the
# positions it is multiplied with.
rope = RotaryEmbedding(dim=64, max_seq_len=8192).to(device)
cos, sin = rope(seq_len=2048, device=device)

# In attention layer: query/key are (batch, heads, seq_len, head_dim).
# Random tensors stand in for the projected activations in this example.
query = torch.randn(1, 8, 2048, 64, device=device)
key = torch.randn(1, 8, 2048, 64, device=device)
q_rotated, k_rotated = apply_rotary_pos_emb(query, key, cos, sin)
def get_alibi_slopes(num_heads):
    """Return the ALiBi slope for each attention head, as a list of floats.

    For a power-of-two head count ``n`` the slopes are the geometric sequence
    ``start * start**i`` (``i = 0 .. n-1``) with ``start = 2**(-8/n)``. For
    example, 8 heads give ``1/2, 1/4, ..., 1/256``. For other head counts, the
    slopes of the closest lower power of two are extended with every other
    slope of the next power of two.

    Raises:
        ValueError: If ``num_heads`` is less than 1.
    """
    if num_heads < 1:
        raise ValueError(f"num_heads must be a positive integer, got {num_heads}")

    def get_slopes_power_of_2(n):
        # 2 ** -(log2(n) - 3) == 8 / n, so start == 2 ** (-8 / n)
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
        ratio = start
        return [start * (ratio ** i) for i in range(n)]

    # Exact integer power-of-two test (no reliance on float log2 rounding)
    if (num_heads & (num_heads - 1)) == 0:
        return get_slopes_power_of_2(num_heads)

    # Closest power of 2 below num_heads
    closest_power = 1 << (num_heads.bit_length() - 1)
    slopes = get_slopes_power_of_2(closest_power)
    # Fill the remaining heads with interleaved slopes from the next power of 2
    extra = get_slopes_power_of_2(2 * closest_power)
    slopes.extend(extra[0::2][: num_heads - closest_power])
    return slopes


def create_alibi_bias(seq_len, num_heads, device=None):
    """Create the ALiBi attention bias ``bias[h, i, j] = -m_h * |i - j|``.

    Every off-diagonal entry is a non-positive penalty that grows linearly
    with distance. For past positions (``j <= i``) this matches the causal
    ALiBi bias ``m_h * (j - i)``. Future positions get a penalty rather than a
    reward, so the bias stays sensible even if no causal mask is applied.

    Args:
        seq_len: Sequence length (query and key positions).
        num_heads: Number of attention heads.
        device: Optional device on which to build the tensor.

    Returns:
        Tensor of shape ``(num_heads, seq_len, seq_len)``.
    """
    positions = torch.arange(seq_len, device=device)
    distance = (positions[None, :] - positions[:, None]).abs()  # |j - i|
    slopes = torch.tensor(get_alibi_slopes(num_heads), device=device)
    return -slopes[:, None, None] * distance[None, :, :]
# Usage in attention
device = "cuda" if torch.cuda.is_available() else "cpu"
num_heads = 8
seq_len = 2048
alibi_bias = create_alibi_bias(seq_len, num_heads).to(device)

# attn_scores shape: (batch, num_heads, seq_len, seq_len). A random tensor
# stands in for q @ k^T / sqrt(head_dim) in this example.
attn_scores = torch.randn(1, num_heads, seq_len, seq_len, device=device)
# Add bias to attention scores
attn_scores = attn_scores + alibi_bias
# For causal (decoder-only) attention, future positions must also be masked
causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1).bool()
attn_scores = attn_scores.masked_fill(causal_mask, float("-inf"))
attn_weights = torch.softmax(attn_scores, dim=-1)
from transformers import AutoConfig, LlamaForCausalLM, LlamaTokenizer

model_name = "meta-llama/Llama-2-7b-hf"
# Llama-2 was pre-trained with a 4096-token context (config.max_position_embeddings)
config = AutoConfig.from_pretrained(model_name)
target_context = 32768
factor = target_context / config.max_position_embeddings  # 32768 / 4096 = 8.0

# rope_scaling must be set on the config BEFORE the model is built: the rotary
# embedding modules read it at construction time, so editing model.config after
# from_pretrained() does not change an already-loaded model.
#   "linear"  = Position Interpolation: positions are divided by `factor`
#               (the RoPE base frequency is unchanged).
#   "dynamic" = dynamic NTK scaling: the RoPE base is adjusted only once the
#               input exceeds the original context length.
scaling_type = "linear"  # or "dynamic"
config.rope_scaling = {"type": scaling_type, "factor": factor}
model = LlamaForCausalLM.from_pretrained(model_name, config=config)

# Linear interpolation should be followed by a short fine-tune on long
# documents (on the order of 1000 steps); without it quality degrades.
# Dynamic scaling is typically used without fine-tuning.
How it works:
Mathematical formulation:
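$q_m = (W_q x_m)\, e^{i m \theta}$
$k_n = (W_k x_n)\, e^{i n \theta}$
where $\theta_j = \text{base}^{-2j/d}$ for $j \in [0, d/2)$, applied to each pair of feature dimensions viewed as a complex number. Because $\mathrm{Re}[q_m \overline{k_n}] = \mathrm{Re}[(W_q x_m)\overline{(W_k x_n)}\, e^{i(m-n)\theta}]$, the attention score depends on position only through the relative offset $m - n$.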
Advantages:
Key innovation:
Parameters:
# YaRN configuration (parameter names follow the YaRN reference implementation)
yarn_config = dict(
    scale=16,  # Context extension factor
    original_max_position=2048,  # Context length the model was pre-trained with
    extrapolation_factor=1.0,  # NTK parameter
    attn_factor=1.0,  # Attention (temperature) scaling
    beta_fast=32,  # High-frequency boundary of the interpolation ramp
    beta_slow=1,  # Low-frequency boundary of the interpolation ramp
)
Performance:
Core idea:
Formula:
attention_bias[i, j] = -m * |i - j|
where m = slope for each attention head
Advantages:
Technique:
Formula:
# Original: position indices [0, 1, 2, ..., L-1]
# Extended to 2L tokens: indices [0, 1, ..., 2L-1] are rescaled to
# [0, 0.5, 1.0, ..., L-0.5], i.e. back inside the trained range (2× extension)
scaled_position[i] = i / extension_factor
Results:
| Method | Max Context | Training Needed | Memory | Extrapolation | Best For |
|---|---|---|---|---|---|
| RoPE | 8k-32k | Full pre-training | Moderate | Good | New models |
| YaRN | 32k-128k | Minimal (10× efficient) | Moderate | Excellent | Extending existing models |
| ALiBi | Unlimited | Full pre-training | Low (-11%) | Excellent | Training from scratch |
| Position Interpolation | 32k+ | Minimal (1k steps) | Moderate | Poor (by design) | Quick extension |
from transformers import AutoModelForCausalLM, AutoConfig

model_name = "mistralai/Mistral-7B-v0.1"
config = AutoConfig.from_pretrained(model_name)

# Candidate RoPE scaling settings. Pick ONE and set it on the config BEFORE
# loading: the model reads config.rope_scaling when it builds its rotary
# embeddings, so later edits to the config do not affect a loaded model.
rope_scaling_options = {
    # RoPE with YaRN scaling
    "yarn": {
        "type": "yarn",
        "factor": 8.0,
        "original_max_position_embeddings": 8192,
        "attention_factor": 1.0,
    },
    # Position interpolation (simpler)
    "linear": {"type": "linear", "factor": 4.0},
    # Dynamic scaling (adjusts based on input length)
    "dynamic": {"type": "dynamic", "factor": 8.0},
}
config.rope_scaling = rope_scaling_options["yarn"]

# from_pretrained loads the trained weights with the modified config;
# from_config would build a randomly initialised model instead.
model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
class LongContextAttention(nn.Module):
    """Multi-head self-attention with rotary position embeddings (RoPE).

    Args:
        hidden_size: Model width. Must be divisible by ``num_heads``.
        num_heads: Number of attention heads; ``head_dim = hidden_size // num_heads``.
        max_seq_len: Forwarded to ``RotaryEmbedding``.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``hidden_size``.
    """

    def __init__(self, hidden_size, num_heads, max_seq_len=32768):
        super().__init__()
        if num_heads <= 0 or hidden_size % num_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        # Q, K, V and output projections
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.o_proj = nn.Linear(hidden_size, hidden_size)

        # RoPE rotates each head's features, so it is sized by head_dim
        self.rotary_emb = RotaryEmbedding(
            dim=self.head_dim,
            max_seq_len=max_seq_len,
        )

    def _split_heads(self, x, batch_size, seq_len):
        """Reshape ``(batch, seq, hidden)`` to ``(batch, heads, seq, head_dim)``."""
        return x.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, hidden_states, attention_mask=None, is_causal=False):
        """Run attention over ``hidden_states`` of shape ``(batch, seq_len, hidden_size)``.

        Args:
            hidden_states: Input activations.
            attention_mask: Optional mask passed to
                ``F.scaled_dot_product_attention`` as ``attn_mask``.
            is_causal: If True, apply a causal mask (decoder-only LMs). Do not
                combine with ``attention_mask``; SDPA may reject both together.

        Returns:
            Tensor of shape ``(batch, seq_len, hidden_size)``.
        """
        batch_size, seq_len, _ = hidden_states.shape

        # Project to Q, K, V and split into heads
        q = self._split_heads(self.q_proj(hidden_states), batch_size, seq_len)
        k = self._split_heads(self.k_proj(hidden_states), batch_size, seq_len)
        v = self._split_heads(self.v_proj(hidden_states), batch_size, seq_len)

        # Apply RoPE. Cast the tables to q's dtype so half-precision q/k are
        # not promoted to float32, which would mismatch v inside SDPA.
        cos, sin = self.rotary_emb(seq_len, device=hidden_states.device)
        cos, sin = cos.to(q.dtype), sin.to(q.dtype)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        attn_output = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attention_mask, is_causal=is_causal
        )

        # Merge heads and project back to hidden_size
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, -1)
        return self.o_proj(attn_output)
from transformers import AutoConfig, AutoModelForCausalLM, Trainer, TrainingArguments

model_name = "meta-llama/Llama-2-7b-hf"
target_context = 32768

# Extend the config BEFORE building the model: the rotary embeddings are
# constructed from config.rope_scaling at load time, so changing model.config
# on an already-loaded model has no effect.
config = AutoConfig.from_pretrained(model_name)
factor = target_context / config.max_position_embeddings  # Llama-2: 32768 / 4096 = 8.0
config.rope_scaling = {"type": "linear", "factor": factor}
config.max_position_embeddings = target_context
model = AutoModelForCausalLM.from_pretrained(model_name, config=config)

# Training args (minimal steps needed)
training_args = TrainingArguments(
    output_dir="./llama-32k",
    num_train_epochs=1,  # ignored because max_steps is set
    max_steps=1000,  # Only 1000 steps!
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    learning_rate=2e-5,
    warmup_steps=100,
    logging_steps=10,
    save_steps=500,
)

# Train on long documents.
# long_document_dataset is a placeholder: a tokenized dataset of ~32k-token
# sequences that you must define before running this.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=long_document_dataset,
)
trainer.train()
# Clone YaRN implementation
git clone https://github.com/jquesnelle/yarn
cd yarn
# Fine-tune LLaMA with YaRN
# NOTE(review): the script path and flag names below are not verifiable from
# this document; confirm them against the repository's README before running.
# NOTE(review): --scale 16 with --max_length 32768 assumes a 2048-token base
# context. Llama-2-7b is usually quoted at 4096 tokens, where 32768 would need
# a scale of 8. Confirm which is intended.
python scripts/train.py \
--model meta-llama/Llama-2-7b-hf \
--scale 16 \
--rope_theta 10000 \
--max_length 32768 \
--batch_size 1 \
--gradient_accumulation 16 \
--steps 400 \
--learning_rate 2e-5
# Recommended method per scenario.
# Note: "Position Interpolation" and "Linear RoPE Scaling" are the same
# mechanism (positions divided by a constant factor; Hugging Face
# rope_scaling type "linear"). They are listed separately only because one
# refers to the paper's fine-tuning recipe and the other to the built-in
# config switch.
method_by_scenario = {
    # NEW models (training from scratch)
    "new_model": "ALiBi",  # Best extrapolation, lowest memory
    # EXTENDING existing RoPE models
    "extend_rope_model": "YaRN",  # Most efficient extension (10× less data)
    # QUICK extension with minimal compute
    "quick_extension": "Position Interpolation",  # ~1000 fine-tuning steps
    # MODERATE extension with good efficiency
    "moderate_extension": "Linear RoPE Scaling",  # Built-in, simple
}
use_method = method_by_scenario["extend_rope_model"]
# Scaling-factor presets, assuming an 8k-token base context
scaling_presets = {
    "conservative": 2.0,  # 8k → 16k: safer, better quality
    "moderate": 4.0,  # 8k → 32k: good balance
    "aggressive_64k": 8.0,  # 8k → 64k: requires more fine-tuning
    "aggressive_128k": 16.0,  # 8k → 128k
}
scaling_factor = scaling_presets["aggressive_128k"]
# Rule of thumb: larger factors need proportionally more fine-tuning steps
steps_needed = scaling_factor * 100  # Rough estimate
# ✅ Good: Long documents matching target length
# NOTE: long_doc_* / short_doc_* are placeholders for your own document text;
# they are not defined in this document.
train_data = [
    {"text": long_doc_32k_tokens},  # Full 32k
    {"text": long_doc_24k_tokens},  # Varied lengths
    {"text": long_doc_16k_tokens},
]
# ❌ Bad: Short documents (won't learn long context)
train_data = [
    {"text": short_doc_2k_tokens},
]
# Use datasets like:
# - PG-19 (books, long texts)
# - arXiv papers
# - Long-form conversations
# - GitHub repositories (concatenated files)
# ❌ Bad: Applying position interpolation without fine-tuning
model.config.rope_scaling = {"type": "linear", "factor": 16.0}
# Model will perform poorly without fine-tuning!
# ✅ Good: Fine-tune after scaling
# NOTE(review): with Hugging Face models the rotary embeddings are typically
# built from the config when the model is loaded. Confirm this assignment
# actually takes effect, or set rope_scaling on the config passed to
# from_pretrained() before loading.
model.config.rope_scaling = {"type": "linear", "factor": 16.0}
# fine_tune is a placeholder for your training loop (e.g. a Trainer run)
fine_tune(model, long_documents, steps=1000)
# ❌ Bad: Too aggressive scaling without data
# scale_to_1M_tokens is not defined in this document; treat it as pseudo-code
scale_to_1M_tokens() # Won't work without massive fine-tuning
# ✅ Good: Incremental scaling
# 8k → 16k → 32k → 64k (fine-tune at each step)
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "togethercomputer/LLaMA-2-7B-32K"  # 32k context

# Load long-context model
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Process long document. The repeated "..." is placeholder text: replace it
# with your document and check the real token count via
# inputs["input_ids"].shape[1].
long_text = "..." * 30000
# With device_map="auto", send inputs to the model's (first) device rather
# than hard-coding 'cuda'.
inputs = tokenizer(long_text, return_tensors="pt", truncation=False).to(model.device)
prompt_length = inputs["input_ids"].shape[1]

# Generate. temperature only takes effect when sampling is enabled.
outputs = model.generate(
    **inputs,
    max_new_tokens=512,
    do_sample=True,
    temperature=0.7,
)
# generate() returns prompt + continuation; decode only the new tokens
response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
# Use Flash Attention 2 (needs the flash-attn package and fp16/bf16 weights)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    attn_implementation="flash_attention_2",  # 2-3× faster
    torch_dtype=torch.float16,
)
# Use gradient checkpointing for fine-tuning. Enable it on the model you
# actually train: loading a new model afterwards would discard the setting.
model.gradient_checkpointing_enable()
# Serve with paged attention (vLLM)
from vllm import LLM

engine_kwargs = dict(
    model="togethercomputer/LLaMA-2-7B-32K",
    max_model_len=32768,  # 32k-token context window
    gpu_memory_utilization=0.9,
)
llm = LLM(**engine_kwargs)
- references/rope.md - Detailed RoPE implementation and theory
- references/extension_methods.md - YaRN, ALiBi, Position Interpolation comparisons
- references/fine_tuning.md - Complete fine-tuning guide for context extension