Code walkthrough - KV caching in nanochat
We will refer to the code in the nanochat repository. I recommend exploring the code yourself; I will only highlight the relevant sections.
The KVCache class
The KV cache is a 6-dimensional tensor that stores keys and values for all transformer layers:
# nanochat/engine.py:83-93
class KVCache:
    def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers):
        # Shape: (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
        #                     ^ 0=keys, 1=values
        self.kv_shape = (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
        self.kv_cache = None  # Lazy init - allocated on first insert
        self.pos = 0  # Current timestep
- num_layers - each transformer layer has its own K/V
- 2 - one slot for keys, one for values
- batch_size, num_heads, seq_len, head_dim - the standard attention dimensions
The cache is lazily initialized - memory is not allocated until the first forward pass, when we know the actual dtype and device.
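To get a sense of how large this tensor gets, here is a quick back-of-envelope sketch. The config numbers below are illustrative assumptions, not nanochat's actual hyperparameters:

```python
# Hypothetical config (NOT nanochat's real values) to size the KV cache tensor:
# shape = (num_layers, 2, batch_size, num_kv_heads, seq_len, head_dim)
num_layers, batch_size, num_kv_heads, seq_len, head_dim = 20, 1, 8, 2048, 64
bytes_per_elem = 2  # bf16

# Total elements x bytes per element
cache_bytes = num_layers * 2 * batch_size * num_kv_heads * seq_len * head_dim * bytes_per_elem
print(f"{cache_bytes / 1024**2:.0f} MiB")  # prints "80 MiB"
```

The cache grows linearly with seq_len, which is why lazily allocating it (and growing it on demand, as we'll see below) matters.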
Inserting K/V into the cache
nanochat keeps the main logic for inserting into the KV cache and handling its memory allocation in engine.py, inside KVCache.insert_kv.
This is called by each attention layer during the forward pass:
# nanochat/engine.py:129-154
def insert_kv(self, layer_idx, k, v):
    # Lazy initialize
    if self.kv_cache is None:
        self.kv_cache = torch.empty(self.kv_shape, dtype=k.dtype, device=k.device)
    B, H, T_add, D = k.size()
    t0, t1 = self.pos, self.pos + T_add
    # Insert k, v into the cache at current position
    self.kv_cache[layer_idx, 0, :, :, t0:t1] = k  # Keys
    self.kv_cache[layer_idx, 1, :, :, t0:t1] = v  # Values
    key_view = self.kv_cache[layer_idx, 0, :, :, :t1]
    value_view = self.kv_cache[layer_idx, 1, :, :, :t1]
    # Only increment pos after the LAST layer processes
    # to ensure the same cache position during a forward pass
    if layer_idx == self.kv_cache.size(0) - 1:
        self.pos = t1
    return key_view, value_view
During decode, k/v is passed in for just the new token. The concatenation is implicit: we insert the new K/V at positions t0:t1 and return a view of everything from 0 to t1.
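The insert-then-return-a-view pattern can be sketched in a few lines. This is a simplified NumPy stand-in (batch and head dimensions dropped, one combined slot per layer), not the real implementation:

```python
import numpy as np

# Minimal sketch of insert_kv semantics: write new tokens at [t0:t1],
# return a view of [0:t1], and bump pos only after the last layer.
class TinyKVCache:
    def __init__(self, seq_len, head_dim, num_layers):
        self.cache = np.zeros((num_layers, seq_len, head_dim))
        self.pos = 0

    def insert_kv(self, layer_idx, k):
        t0, t1 = self.pos, self.pos + k.shape[0]
        self.cache[layer_idx, t0:t1] = k
        # All layers must see the same pos within one forward pass,
        # so only the last layer advances it
        if layer_idx == self.cache.shape[0] - 1:
            self.pos = t1
        return self.cache[layer_idx, :t1]  # cached prefix + new tokens

cache = TinyKVCache(seq_len=8, head_dim=4, num_layers=2)
# "Prefill": 3 prompt tokens pass through both layers
for layer in range(2):
    view = cache.insert_kv(layer, np.ones((3, 4)))
assert view.shape == (3, 4) and cache.pos == 3
# "Decode": 1 new token -> the returned view now spans 4 positions
for layer in range(2):
    view = cache.insert_kv(layer, np.ones((1, 4)))
assert view.shape == (4, 4) and cache.pos == 4
```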
Dynamic cache allocation
nanochat grows the cache dynamically, much like a C++ vector reallocates: if generation exceeds the pre-allocated seq_len, it allocates extra space and concatenates it on:
# Also in insert_kv - dynamic cache growth:
if t1 > self.kv_cache.size(4):  # need more space?
    t_needed = t1 + 1024  # add 1024 token buffer
    t_needed = (t_needed + 1023) & ~1023  # round up to nearest 1024
    additional_shape = list(self.kv_cache.shape)
    additional_shape[4] = t_needed - self.kv_cache.size(4)
    additional_cache = torch.empty(additional_shape, dtype=k.dtype, device=k.device)
    self.kv_cache = torch.cat([self.kv_cache, additional_cache], dim=4)
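The rounding arithmetic is worth a closer look: `(x + 1023) & ~1023` rounds up to the next multiple of 1024 using bit masking (1024 is a power of two, so `~1023` clears the low 10 bits). A tiny standalone sketch:

```python
# The growth target: add a 1024-token buffer, then round up to a
# multiple of 1024 with the classic power-of-two bit trick.
def grown_seq_len(t1):
    t_needed = t1 + 1024
    return (t_needed + 1023) & ~1023

print(grown_seq_len(1))     # 1025 rounds up  -> 2048
print(grown_seq_len(1024))  # 2048 is already -> 2048
print(grown_seq_len(2049))  # 3073 rounds up  -> 4096
```

So each growth gives at least a 1024-token cushion, keeping reallocations (and the `torch.cat` copies they imply) infrequent.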
How attention uses the cache
In CausalSelfAttention.forward, the cache is used like this:
# nanochat/gpt.py:66-84
def forward(self, x, cos_sin, kv_cache):
    B, T, C = x.size()
    # Project to get Q, K, V for current input
    q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
    k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
    v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
    # ... rotary embeddings and normalization ...
    # insert new k/v, get cached + new k/v
    if kv_cache is not None:
        k, v = kv_cache.insert_kv(self.layer_idx, k, v)
    Tq = q.size(2)  # number of queries in this forward pass
    Tk = k.size(2)  # number of keys in total (cached + current)
During decode: q has shape (B, H, 1, D) (single new token), but after insert_kv, k and v have shape (B, H, full_seq_len, D).
The three attention cases
nanochat's attention code handles three distinct cases based on the relationship between query length (Tq) and key length (Tk). The third is somewhat out of scope for this post.
# nanochat/gpt.py:86-105
# Case 1: Training or full prefill (Tq == Tk)
if kv_cache is None or Tq == Tk:
    # Standard causal attention - each token attends to itself and all previous tokens
    # classic "lower triangular" attention mask
    y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
# Case 2: Single-token decode (Tq == 1) - THE HOT PATH
elif Tq == 1:
    # Single query attending to all cached keys
    # No causal mask needed! The lone query is the last position,
    # so it's allowed to see everything before it anyway
    y = F.scaled_dot_product_attention(q, k, v, is_causal=False)
# Case 3: Chunked decode (1 < Tq < Tk)
else:
    # Multiple new queries attending to cached prefix + each other
    # Need a custom mask: full attention to prefix, causal within chunk
    attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device)
    prefix_len = Tk - Tq
    # All queries can see the entire cached prefix
    attn_mask[:, :prefix_len] = True
    # Within the new chunk, apply causal masking
    attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool))
    y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
Some visualizations, made with the help of Claude Opus 4.5:
Case 1: Prefill (Tq=5, Tk=5) - Standard causal mask
k0 k1 k2 k3 k4
q0 [ 1 0 0 0 0 ]
q1 [ 1 1 0 0 0 ]
q2 [ 1 1 1 0 0 ]
q3 [ 1 1 1 1 0 ]
q4 [ 1 1 1 1 1 ]
Case 2: Single-token decode (Tq=1, Tk=6) - No mask needed
k0 k1 k2 k3 k4 k5
q5 [ 1 1 1 1 1 1 ] <- single query sees everything
Case 3: Chunked decode (Tq=3, Tk=8) - Hybrid mask
k0 k1 k2 k3 k4 | k5 k6 k7
<-- prefix (5) --> | <-- chunk (3) -->
q5 [ 1 1 1 1 1 | 1 0 0 ]
q6 [ 1 1 1 1 1 | 1 1 0 ]
q7 [ 1 1 1 1 1 | 1 1 1 ]
^
prefix_len boundary
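The Case 3 hybrid mask above can be reproduced in a few lines. This sketch uses NumPy as a stand-in for the torch version:

```python
import numpy as np

# Rebuild the chunked-decode mask from the diagram: Tq=3 new queries,
# Tk=8 total keys (5 cached prefix + 3 in the current chunk).
Tq, Tk = 3, 8
prefix_len = Tk - Tq

mask = np.zeros((Tq, Tk), dtype=bool)
mask[:, :prefix_len] = True  # every query sees the whole cached prefix
mask[:, prefix_len:] = np.tril(np.ones((Tq, Tq), dtype=bool))  # causal within chunk

# Rows match the diagram: full prefix, then a lower-triangular chunk
assert mask.astype(int)[0].tolist() == [1, 1, 1, 1, 1, 1, 0, 0]
assert mask.astype(int)[1].tolist() == [1, 1, 1, 1, 1, 1, 1, 0]
assert mask.astype(int)[2].tolist() == [1, 1, 1, 1, 1, 1, 1, 1]
```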
Case 2 is the decode hot path - every generated token goes through this. The key optimization: since there's only one query and it's at the last position, we skip mask computation entirely.
The prefill → decode flow
In Engine.generate, the two-phase process:
# nanochat/engine.py:210-232
# PHASE 1: Prefill with batch size 1
kv_cache_prefill = KVCache(
    batch_size=1,
    seq_len=len(tokens),  # exact prompt length
    **kv_model_kwargs,
)
ids = torch.tensor([tokens], dtype=torch.long, device=device)
logits = self.model.forward(ids, kv_cache=kv_cache_prefill)  # fills the cache
# PHASE 2: Create decode cache, copy prefill data
kv_cache_decode = KVCache(
    batch_size=num_samples,  # can expand for parallel generation
    seq_len=kv_length_hint,  # longer to accommodate generated tokens
    **kv_model_kwargs,
)
kv_cache_decode.prefill(kv_cache_prefill)  # copy over the computed K/V
del kv_cache_prefill  # free memory
The prefill() method copies the computed K/V tensors from the prefill cache to the decode cache. It also allows batch expansion - you can prefill once with batch=1, then expand to generate multiple completions in parallel.
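The copy-with-batch-expansion idea can be sketched with broadcasting. This is a simplified NumPy illustration of the concept, not nanochat's actual prefill() code (head dimension dropped for brevity):

```python
import numpy as np

# Copy a batch-1 prefill cache into a larger decode cache, expanding the
# batch dimension so several samples start from the same prompt K/V.
# Dims: (num_layers, 2, batch, seq, head_dim)
prefill = np.random.rand(2, 2, 1, 5, 4)  # batch=1, seq=5 (prompt length)
decode = np.zeros((2, 2, 3, 16, 4))      # batch=3, seq=16 (room to generate)

T = prefill.shape[3]
decode[:, :, :, :T] = prefill  # broadcasting expands batch 1 -> 3
pos = T                        # decoding resumes at position 5

# Every sample now shares the same prompt K/V
assert np.allclose(decode[:, :, 0, :T], decode[:, :, 2, :T])
```

This is why prefill is done once at batch=1 even when generating multiple completions: the prompt's K/V is identical across samples, so computing it per-sample would be wasted work.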