sankalp's blog

Code walkthrough - KV caching in nanochat

We will refer to the code in the nanochat repository. I recommend exploring the code yourself; I will only highlight the relevant sections.

The KVCache class

The KV cache is a 6-dimensional tensor that stores keys and values for all transformer layers:

# nanochat/engine.py:83-93

class KVCache:
    def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers):
        # Shape: (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
        #                     ^ 0=keys, 1=values
        self.kv_shape = (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
        self.kv_cache = None  # Lazy init - allocated on first insert
        self.pos = 0  # Current timestep

The cache is lazily initialized - memory is not allocated until the first forward pass, when we know the actual dtype and device.
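To get a feel for this layout, here's a back-of-envelope memory estimate for the 6-D tensor (the config numbers below are made up, roughly GPT-2-small-sized - not nanochat's actual defaults):

```python
# Hypothetical config (not nanochat's defaults) - just to size the tensor.
num_layers, batch_size, num_heads, seq_len, head_dim = 12, 1, 12, 2048, 64

# (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
elems = num_layers * 2 * batch_size * num_heads * seq_len * head_dim
bytes_bf16 = elems * 2  # bfloat16 = 2 bytes per element
print(f"{elems:,} elements, ~{bytes_bf16 / 2**20:.0f} MiB")  # ~72 MiB
```

The factor of 2 in dim 1 (keys and values) and the per-layer dimension are what make KV caches expensive; seq_len and batch_size scale it linearly from there.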

Inserting K/V into the cache

nanochat's main logic for writing into the KV cache (and allocating its memory) lives in engine.py, in KVCache.insert_kv.

This is called by each attention layer during the forward pass:

# nanochat/engine.py:129-154

def insert_kv(self, layer_idx, k, v):
    # Lazy initialize 
    if self.kv_cache is None:
        self.kv_cache = torch.empty(self.kv_shape, dtype=k.dtype, device=k.device)

    B, H, T_add, D = k.size()
    t0, t1 = self.pos, self.pos + T_add

    # Insert k, v into the cache at current position
    self.kv_cache[layer_idx, 0, :, :, t0:t1] = k  # Keys
    self.kv_cache[layer_idx, 1, :, :, t0:t1] = v  # Values

    key_view = self.kv_cache[layer_idx, 0, :, :, :t1]
    value_view = self.kv_cache[layer_idx, 1, :, :, :t1]

    # Advance pos only after the LAST layer has inserted,
    # so every layer sees the same positions within one forward pass
    if layer_idx == self.kv_cache.size(0) - 1:
        self.pos = t1

    return key_view, value_view

During decode, k/v is passed for just the new token. The "concatenation" happens implicitly: we insert the new K/V at positions t0:t1 and return views covering everything from 0 to t1.
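To see those view semantics end to end, here is a torch-free NumPy sketch of the same insert logic (MiniKVCache is my toy stand-in, not nanochat code):

```python
import numpy as np

# Toy reimplementation of insert_kv semantics: write new K/V at [t0:t1],
# return views of everything up to t1, bump pos only on the last layer.
class MiniKVCache:
    def __init__(self, num_layers, B, H, S, D):
        self.kv = np.zeros((num_layers, 2, B, H, S, D))
        self.pos = 0

    def insert_kv(self, layer_idx, k, v):
        T_add = k.shape[2]
        t0, t1 = self.pos, self.pos + T_add
        self.kv[layer_idx, 0, :, :, t0:t1] = k
        self.kv[layer_idx, 1, :, :, t0:t1] = v
        if layer_idx == self.kv.shape[0] - 1:
            self.pos = t1
        return self.kv[layer_idx, 0, :, :, :t1], self.kv[layer_idx, 1, :, :, :t1]

cache = MiniKVCache(num_layers=2, B=1, H=1, S=8, D=4)
# "prefill": 5 prompt tokens through both layers
for layer in range(2):
    k = v = np.ones((1, 1, 5, 4))
    cache.insert_kv(layer, k, v)
# "decode": 1 new token; the returned keys now cover all 6 positions
k_all, _ = cache.insert_kv(0, np.ones((1, 1, 1, 4)), np.ones((1, 1, 1, 4)))
print(k_all.shape[2])  # 6
print(cache.pos)       # 5 - pos only advances once layer 1 also inserts
```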

Dynamic cache allocation

nanochat uses a C++-vector-style growth policy: if generation exceeds the pre-allocated seq_len, the cache grows in chunks:

# Also in insert_kv - dynamic cache growth:

if t1 > self.kv_cache.size(4):  # need more space?
    t_needed = t1 + 1024                    # add 1024 token buffer
    t_needed = (t_needed + 1023) & ~1023    # round up to nearest 1024

    additional_shape = list(self.kv_cache.shape)
    additional_shape[4] = t_needed - self.kv_cache.size(4)
    additional_cache = torch.empty(additional_shape, dtype=k.dtype, device=k.device)

    self.kv_cache = torch.cat([self.kv_cache, additional_cache], dim=4)
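The bit trick in the rounding line is worth unpacking: (t + 1023) & ~1023 rounds t up to the next multiple of 1024, which works because 1024 is a power of two (~1023 clears the low 10 bits). A quick standalone check:

```python
def round_up_1024(t):
    # ~1023 zeroes the low 10 bits; adding 1023 first makes this round UP, not down
    return (t + 1023) & ~1023

print(round_up_1024(1))     # 1024
print(round_up_1024(1024))  # 1024
print(round_up_1024(1025))  # 2048

# combined with the extra 1024-token buffer from insert_kv:
t1 = 2049
print(round_up_1024(t1 + 1024))  # 4096
```

So a cache that overflows at position 2049 jumps straight to capacity 4096, amortizing the cost of the torch.cat reallocation over many future tokens.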

How attention uses the cache

In CausalSelfAttention.forward, the cache is used like this:

# nanochat/gpt.py:66-84

def forward(self, x, cos_sin, kv_cache):
    B, T, C = x.size()

    # Project to get Q, K, V for current input
    q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
    k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
    v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)

    # ... rotary embeddings, normalization, transpose to (B, n_head, T, head_dim) ...

    # insert new k/v, get cached + new k/v
    if kv_cache is not None:
        k, v = kv_cache.insert_kv(self.layer_idx, k, v)

    Tq = q.size(2)  # number of queries in this forward pass
    Tk = k.size(2)  # number of keys in total (cached + current)

During decode, q has shape (B, H, 1, D) (a single new token), while after insert_kv, k and v have shape (B, H, full_seq_len, D).

The three attention cases

nanochat's attention code handles three distinct cases based on the relationship between query length (Tq) and key length (Tk). The third one is somewhat out of scope for this post.

# nanochat/gpt.py:86-105

# Case 1: Training or full prefill (Tq == Tk)
if kv_cache is None or Tq == Tk:
    # Standard causal attention - each token attends to itself and all previous tokens
    # classic "lower triangular" attention mask
    y = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Case 2: Single-token decode (Tq == 1) - THE HOT PATH
elif Tq == 1:
    # Single query attending to all cached keys
    # No causal mask needed! The lone query is the last position,
    # so it's allowed to see everything before it anyway
    y = F.scaled_dot_product_attention(q, k, v, is_causal=False)

# Case 3: Chunked decode (1 < Tq < Tk)
else:
    # Multiple new queries attending to cached prefix + each other
    # Need a custom mask: full attention to prefix, causal within chunk
    attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device)
    prefix_len = Tk - Tq

    # All queries can see the entire cached prefix
    attn_mask[:, :prefix_len] = True

    # Within the new chunk, apply causal masking
    attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool))

    y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

Some visualizations, made with the help of Claude Opus 4.5:

Case 1: Prefill (Tq=5, Tk=5) - Standard causal mask

     k0  k1  k2  k3  k4
q0 [  1   0   0   0   0 ]
q1 [  1   1   0   0   0 ]
q2 [  1   1   1   0   0 ]
q3 [  1   1   1   1   0 ]
q4 [  1   1   1   1   1 ]

Case 2: Single-token decode (Tq=1, Tk=6) - No mask needed

     k0  k1  k2  k3  k4  k5
q5 [  1   1   1   1   1   1 ]   <- single query sees everything

Case 3: Chunked decode (Tq=3, Tk=8) - Hybrid mask

     k0  k1  k2  k3  k4 | k5  k6  k7
     <-- prefix (5) --> | <-- chunk (3) -->
q5 [  1   1   1   1   1 |  1   0   0 ]
q6 [  1   1   1   1   1 |  1   1   0 ]
q7 [  1   1   1   1   1 |  1   1   1 ]
                        ^
                   prefix_len boundary

Case 2 is the decode hot path - every generated token goes through this. The key optimization: since there's only one query and it's at the last position, we skip mask computation entirely.
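One way to convince yourself the case-3 mask is right: it should equal the last Tq rows of a full Tk × Tk causal mask (the queries are simply the last Tq positions of the sequence). A NumPy check mirroring the torch code above:

```python
import numpy as np

Tq, Tk = 3, 8
prefix_len = Tk - Tq

# build the hybrid mask exactly as in the case-3 branch
attn_mask = np.zeros((Tq, Tk), dtype=bool)
attn_mask[:, :prefix_len] = True                                     # full attention to prefix
attn_mask[:, prefix_len:] = np.tril(np.ones((Tq, Tq), dtype=bool))   # causal within chunk

# the last Tq rows of an ordinary Tk x Tk causal mask
full_causal = np.tril(np.ones((Tk, Tk), dtype=bool))
print(np.array_equal(attn_mask, full_causal[-Tq:]))  # True
```

With Tq = 1 the last row of the causal mask is all ones, which is exactly why case 2 can skip the mask entirely.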

The prefill → decode flow

In Engine.generate, generation happens in two phases:

# nanochat/engine.py:210-232

# PHASE 1: Prefill with batch size 1
kv_cache_prefill = KVCache(
    batch_size=1,
    seq_len=len(tokens),  # exact prompt length
    **kv_model_kwargs,
)
ids = torch.tensor([tokens], dtype=torch.long, device=device)
logits = self.model.forward(ids, kv_cache=kv_cache_prefill)  # fills the cache

# PHASE 2: Create decode cache, copy prefill data
kv_cache_decode = KVCache(
    batch_size=num_samples,  # can expand for parallel generation
    seq_len=kv_length_hint,  # longer to accommodate generated tokens
    **kv_model_kwargs,
)
kv_cache_decode.prefill(kv_cache_prefill)  # copy over the computed K/V
del kv_cache_prefill  # free memory

The prefill() method copies the computed K/V tensors from the prefill cache to the decode cache. It also allows batch expansion - you can prefill once with batch=1, then expand to generate multiple completions in parallel.
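I haven't reproduced prefill() here, but the batch-expansion idea can be sketched like this - a NumPy illustration of the copy, not nanochat's actual implementation:

```python
import numpy as np

# Hypothetical shapes: a batch-1 prefill cache copied into a larger decode cache.
num_layers, H, D = 2, 4, 8
prompt_len, decode_len, num_samples = 5, 16, 3

prefill_kv = np.random.rand(num_layers, 2, 1, H, prompt_len, D)
decode_kv = np.zeros((num_layers, 2, num_samples, H, decode_len, D))

# the size-1 batch dim broadcasts across all num_samples rows of the decode cache
decode_kv[:, :, :, :, :prompt_len, :] = prefill_kv
pos = prompt_len  # decode continues from the end of the prompt

# every sample now starts from the same prompt K/V
print(np.array_equal(decode_kv[0, 0, 2, :, :prompt_len], prefill_kv[0, 0, 0]))  # True
```

The payoff is that the expensive prompt forward pass runs once at batch 1, and only the cheap per-token decode runs at batch num_samples.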