Practical Notes on Training NanoChat

A practical walkthrough for building a tiny ChatGPT

ChatGPT is now part of everyday life, and Karpathy’s NanoChat [Github] offers a minimal codebase for learning how a small ChatGPT-style model is trained end to end. After studying NanoChat, I collected the key techniques that helped me most. These brief notes are for myself and for anyone looking for a simple starting point to build a tiny ChatGPT.


Inside the Lightweight GPT Architecture: gpt.py

MQA + RoPE: NanoChat employs multi-query attention (MQA) [Link], sharing a small set of key/value heads across all query heads, to significantly reduce KV-cache memory during inference. It also applies rotary positional embeddings (RoPE) to the queries and keys, which improves long-context generalization and extrapolation to sequence lengths unseen during training.

@dataclass
class GPTConfig:
    ...
    # n_kv_head == n_head: MHA
    # n_kv_head <  n_head: grouped-query attention (MQA when n_kv_head == 1)
    n_head: int = 6 # number of query heads
    n_kv_head: int = 6 # number of key/value heads (MQA)

def apply_rotary_emb(x, cos, sin):
    assert x.ndim == 4  # multihead attention
    d = x.shape[3] // 2
    x1, x2 = x[..., :d], x[..., d:] # split the last dim into two halves
    y1 = x1 * cos + x2 * sin # rotate pairs of dims
    y2 = x1 * (-sin) + x2 * cos
    out = torch.cat([y1, y2], 3) # re-assemble
    out = out.to(x.dtype) # ensure input/output dtypes match
    return out

class CausalSelfAttention(nn.Module):
    def __init__(self, config, layer_idx):
        ...
        self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
        self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
    
    def forward(self, x, cos_sin, kv_cache):
        B, T, C = x.size()

        # Project the input to get queries, keys, and values
        q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
        k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
        v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)

        # Apply Rotary Embeddings to queries and keys to get relative positional encoding
        cos, sin = cos_sin
        q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) # QK rotary embedding
        q, k = norm(q), norm(k) # QK norm
        # make head be batch dim, i.e. (B, T, H, D) -> (B, H, T, D)
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) 
        if kv_cache is not None:
            k, v = kv_cache.insert_kv(self.layer_idx, k, v)
        # is_causal is True for training/prefill; during single-token decoding with a KV cache,
        # causal masking is unnecessary since the query attends to all cached positions.
        # enable_gqa lets SDPA broadcast the fewer KV heads across the query heads.
        y = F.scaled_dot_product_attention(q, k, v, is_causal=is_train, enable_gqa=enable_gqa)
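The cos/sin tables passed into apply_rotary_emb are precomputed once from the head dimension and the maximum sequence length. Here is a minimal sketch of that precomputation; the helper name precompute_rotary and the base frequency of 10000 are assumptions, not necessarily NanoChat's exact choices.

import torch

def precompute_rotary(head_dim, seq_len, base=10000.0, device="cpu"):
    # one inverse frequency per rotated pair of dimensions
    half = head_dim // 2
    inv_freq = 1.0 / (base ** (torch.arange(half, device=device).float() / half))
    t = torch.arange(seq_len, device=device).float()
    freqs = torch.outer(t, inv_freq)      # (T, half)
    cos, sin = freqs.cos(), freqs.sin()
    # reshape to (1, T, 1, half) so they broadcast against q/k of shape (B, T, H, D)
    return cos[None, :, None, :], sin[None, :, None, :]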

AdamW + Muon: NanoChat adopts a hybrid optimization strategy, using AdamW for the embedding and output (lm_head) parameters, and the SVD-inspired Muon optimizer, which orthogonalizes each update via Newton-Schulz iterations, for the 2D weight matrices of the Transformer blocks, giving smoother and more stable training.

class Block(nn.Module):
    def __init__(self, config, layer_idx):
        super().__init__()
        self.attn = CausalSelfAttention(config, layer_idx)
        self.mlp = MLP(config)

    def forward(self, x, cos_sin, kv_cache):
        x = x + self.attn(norm(x), cos_sin, kv_cache)
        x = x + self.mlp(norm(x))
        return x

class GPT(nn.Module):
    def __init__(self, config):
        ...
        self.transformer = nn.ModuleDict({
            "wte": nn.Embedding(config.vocab_size, config.n_embd),
            "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
        })
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) # note: the output dimension is vocab_size (logits over the vocabulary)
    
    def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0):
        ...
        matrix_params = list(self.transformer.h.parameters())
        embedding_params = list(self.transformer.wte.parameters())
        lm_head_params = list(self.lm_head.parameters())
        # Scale the LRs of the AdamW groups by ∝ 1/√dmodel (the defaults were tuned for a 768-dim model)
        model_dim = self.config.n_embd
        dmodel_lr_scale = (model_dim / 768) ** -0.5
        adam_groups = [
            dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
            dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
        ]
        ...
        adamw_optimizer = AdamWFactory(adam_groups, **adamw_kwargs)
        ...
        muon_optimizer = MuonFactory(matrix_params, **muon_kwargs)
        optimizers = [adamw_optimizer, muon_optimizer]
        ...
        return optimizers
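For intuition about Muon: its core step orthogonalizes each 2D update with a Newton-Schulz iteration, which approximately replaces the update by the U Vᵀ factor of its SVD so that all directions are scaled equally. A minimal sketch of that iteration (coefficients follow the public Muon reference implementation; this is not NanoChat's exact code):

import torch

def zeropower_via_newtonschulz(G, steps=5):
    # approximately map a 2D matrix G to the nearest orthogonal matrix (~ U V^T of its SVD)
    a, b, c = 3.4445, -4.7750, 2.0315     # quintic iteration coefficients
    X = G / (G.norm() + 1e-7)             # normalize so the iteration converges
    transposed = X.size(0) > X.size(1)
    if transposed:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X
    return X.T if transposed else X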

Running the Chat Engine: engine.py

The Engine class handles autoregressive text generation with efficient KV-cache reuse and multi-sample decoding. It also supports tool-augmented generation — detecting <|python_start|> and <|python_end|> code blocks, safely evaluating expressions via a built-in calculator, and injecting the computed results back into the output stream.

class RowState:
    # Per-row state tracking during generation
    def __init__(self, current_tokens=None):
        ...
        # Why forced_tokens? It lets the engine inject non-model outputs into the token stream.
        # For a prompt like <|python_start|>2 + 3 * (4 - 1)<|python_end|>, the engine detects
        # the Python block, evaluates it to 11, and encodes that result as tokens:
        self.forced_tokens = deque() # Queue of tokens to force inject 
        ...
        self.python_expr_tokens = [] # Tokens of the current python expression

class Engine:
    @torch.inference_mode() # disables gradient tracking and reduces memory + compute overhead
    def generate(self, tokens, num_samples=1, max_tokens=None, temperature=1.0, top_k=None, seed=42):
        ...
        # Get the special tokens we need to coordinate the tool use state machine
        get_special = lambda s: self.tokenizer.encode_special(s)
        python_start = get_special("<|python_start|>") # each <|...|> special encodes to a single token id
        python_end = get_special("<|python_end|>")
        output_start = get_special("<|output_start|>")
        output_end = get_special("<|output_end|>")
        assistant_end = get_special("<|assistant_end|>") # if sampled, ends row
        bos = self.tokenizer.get_bos_token_id() # if sampled, ends row

        # 1) Run a batch 1 prefill of the prompt tokens
        m = self.model.config
        kv_model_kwargs = {"num_heads": m.n_kv_head, "head_dim": m.n_embd // m.n_head, "num_layers": m.n_layer}
        kv_cache_prefill = KVCache( 
            batch_size=1,
            seq_len=len(tokens),
            **kv_model_kwargs,
        ) # only need one KV cache since there is only one prompt
        ids = torch.tensor([tokens], dtype=torch.long, device=device)
        logits = self.model.forward(ids, kv_cache=kv_cache_prefill)
        logits = logits[:, -1, :]
        next_ids = sample_next_token(logits, rng, temperature, top_k)  # (B, 1)
        sampled_tokens = next_ids[:, 0].tolist() # list with one sampled token per row (just one here, since prefill uses batch size 1)

        # 2) Replicate the KV cache for each sample/row
        kv_length_hint = (len(tokens) + max_tokens) if max_tokens is not None else self.model.config.sequence_len
        kv_cache_decode = KVCache(
            batch_size=num_samples,
            seq_len=kv_length_hint, # size the cache for the full generation up front instead of re-initializing it every step
            **kv_model_kwargs,
        )
        kv_cache_decode.prefill(kv_cache_prefill)
        del kv_cache_prefill # no need to keep this memory around

        # 3) Initialize states for each sample
        row_states = [RowState(tokens.copy()) for _ in range(num_samples)] # tokens.copy() so each row owns an independent list

        # 4) Main generation loop
        num_generated = 0
        first_iteration = True
        while True:
            # Stop condition: max tokens reached or all rows are completed
            ...

            # Get sampled tokens - either from prefill or from forward pass
            if not first_iteration:
                ...
                # Forward the model and get the next token for each row
                logits = self.model.forward(ids, kv_cache=kv_cache_decode)  # (B, T, vocab_size)
                logits = logits[:, -1, :]  # (B, vocab_size) at last time step
                next_ids = sample_next_token(logits, rng, temperature, top_k)  # (B, 1)
                sampled_tokens = next_ids[:, 0].tolist()

            # Process each row: choose the next token, update state, optional tool use
            token_column = [] # contains the next token id along each row
            for i, state in enumerate(row_states):
                # Select the next token in this row
                is_forced = len(state.forced_tokens) > 0 # tokens waiting to be forced in deque?
                ...
                next_token = state.forced_tokens.popleft() if is_forced else sampled_tokens[i]
                token_column.append(next_token)
                # Update the state of this row to include the next token
                state.current_tokens.append(next_token)
                # On <|assistant_end|> or <|bos|>, mark the row as completed
                if next_token == assistant_end or next_token == bos:
                    state.completed = True
                # Handle tool logic
                if next_token == python_start:
                    state.in_python_block = True
                    state.python_expr_tokens = []
                elif next_token == python_end and state.in_python_block:
                    state.in_python_block = False
                    if state.python_expr_tokens:
                        expr = self.tokenizer.decode(state.python_expr_tokens)
                        result = use_calculator(expr) # avoid __import__('os').system('rm -rf /') ...
                        if result is not None:
                            result_tokens = self.tokenizer.encode(str(result))
                            state.forced_tokens.append(output_start)
                            state.forced_tokens.extend(result_tokens)
                            state.forced_tokens.append(output_end)
                    state.python_expr_tokens = []
                elif state.in_python_block:
                    state.python_expr_tokens.append(next_token)
            ...
            ids = torch.tensor(token_column, dtype=torch.long, device=device).unsqueeze(1)
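The use_calculator call above is the safety boundary for tool use: only plain arithmetic should ever reach eval. A minimal sketch of such a guard (not NanoChat's exact implementation):

def use_calculator(expr):
    # reject anything that is not plain arithmetic before evaluating it
    if any(ch not in "0123456789+-*/(). " for ch in expr):
        return None
    if "**" in expr:
        return None  # avoid huge exponentiations
    try:
        return eval(expr, {"__builtins__": None}, {})
    except Exception:
        return None

# use_calculator("2 + 3 * (4 - 1)") -> 11; use_calculator("__import__('os')...") -> None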

Reducing the Tokenizer Space: tokenizer.py

English has roughly 170k words in common use, and adding inflected variants, punctuation, slang, numbers, and code pushes the number of unique "words" into the millions, which would require an enormous embedding matrix. BPE tokenizers instead reuse subwords (like comput and ation) to cover that space compactly, typically with 30k–100k tokens in total.

Here’s a quick view of how BPE learns. Take the tiny dataset hello hello and world hello. The tokenizer begins at the character level (h e l l o, and so on), then repeatedly merges the most frequent adjacent pair, for example l l → ll (3 occurrences), then e ll → ell, then h ell → hell, and so on. After a few rounds, hello becomes a single token, and the same process eventually turns world into one too. In the end, BPE discovers hello and world as its core learned units. The toy sketch below runs this merge loop in a few lines of Python.
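Here is a toy version of that merge loop, for illustration only; note that several pairs tie at 3 occurrences initially, so the actual merge order may differ from the narrative above.

from collections import Counter

words = "hello hello and world hello".split()
seqs = [list(w) for w in words]                # start at the character level

for step in range(6):
    pair_counts = Counter()
    for seq in seqs:
        pair_counts.update(zip(seq, seq[1:]))  # count adjacent symbol pairs
    if not pair_counts:
        break
    (a, b), freq = pair_counts.most_common(1)[0]
    merged = a + b
    for seq in seqs:                           # apply the merge everywhere
        i = 0
        while i < len(seq) - 1:
            if seq[i] == a and seq[i + 1] == b:
                seq[i:i + 2] = [merged]
            else:
                i += 1
    print(f"step {step}: merge {a!r}+{b!r} -> {merged!r} (count {freq})")

The real tokenizer does this at the byte level with HuggingFace's BPE trainer and a GPT-4-style split pattern: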

SPLIT_PATTERN = (
    r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| """
    r"""?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
)
...
@classmethod
def train_from_iterator(cls, text_iterator, vocab_size):
    # train from an iterator of text
    tokenizer = HFTokenizer(BPE(
        byte_fallback=True, # needed!
        unk_token=None,
        fuse_unk=False,
    ))
    tokenizer.normalizer = None
    gpt4_split_regex = Regex(SPLIT_PATTERN) # huggingface demands that you wrap it in Regex!!
    tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
        pre_tokenizers.Split(pattern=gpt4_split_regex, behavior="isolated", invert=False),
        pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False)
    ])
    # Decoder: ByteLevel (it pairs together with the ByteLevel pre-tokenizer)
    tokenizer.decoder = decoders.ByteLevel()
    ...
    trainer = BpeTrainer(
        vocab_size=vocab_size,
        show_progress=True,
        min_frequency=0, # no minimum frequency
        initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
        special_tokens=SPECIAL_TOKENS,
    )
    # Kick off the training
    tokenizer.train_from_iterator(text_iterator, trainer)
    return cls(tokenizer)
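A hypothetical usage sketch; the enclosing class name (Tokenizer) and a single-string encode call are assumptions rather than NanoChat's exact API, but they show the intended flow: train on the toy corpus from earlier and inspect a segmentation.

# hypothetical: assumes the classmethod above lives on a class named Tokenizer
tok = Tokenizer.train_from_iterator(["hello hello and world hello"], vocab_size=300)
print(tok.encode("hello world"))  # expect very few tokens once the hello/world merges are learned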

Streaming the Training Data: dataloader.py

# 1) Parquet files (indexed by pq_idx)
#     pq_idx = 0  →  part-00000.parquet
#     pq_idx = 1  →  part-00001.parquet
#     pq_idx = 2  →  part-00002.parquet   ← example below

# 2) Zoom into part-00002.parquet (pq_idx = 2)
#     ┌── part-00002.parquet ────────────────────────────────┐
#     │ RowGroup 0 (rg_idx = 0) → text: [t0,  ..., t9]       │
#     │ RowGroup 1 (rg_idx = 1) → text: [t10, ..., t19]      │
#     │ RowGroup 2 (rg_idx = 2) → text: [t20, ..., t29]      │
#     │ RowGroup 3 (rg_idx = 3) → text: [t30, ..., t39]      │
#     └──────────────────────────────────────────────────────┘
#     (Each RowGroup holds a batch of rows, and each row is one text document.)

# 3) Distributed training (world_size = 2)
#     GPU0 (rank = 0) → picks RowGroups: 0, 2, ...
#     GPU1 (rank = 1) → picks RowGroups: 1, 3, ...

#     Example for GPU0:
#         rg_idx = 0
#         batch = [t0, t1, t2, ..., t9]       # 10 text documents


# 4) Split one RowGroup into small doc_batches (for tokenizer)
#     Suppose tokenizer_batch_size = 4.
#     Then batch (10 lines) becomes:
#         batch[0:4]  → doc_batch_0
#         batch[4:8]  → doc_batch_1
#         batch[8:10] → doc_batch_2
#     Each doc_batch is sent to:
#         tokenizer.encode(doc_batch)
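To see this layout on an actual shard, pyarrow exposes the row-group structure directly (the file name below is just the example from the diagram):

import pyarrow.parquet as pq

pf = pq.ParquetFile("part-00002.parquet")   # pq_idx = 2 in the diagram
print(pf.num_row_groups)                    # e.g. 4 row groups
rg = pf.read_row_group(0)                   # RowGroup 0 as an Arrow table
texts = rg.column("text").to_pylist()       # the text documents in that group
print(len(texts), texts[0][:80])

The data loader below walks exactly this hierarchy, striding over row groups by DDP rank: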

def tokenizing_distributed_data_loader_with_state(
    B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda", resume_state_dict=None
):
    ...
    ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
    def document_batches():
        parquet_paths = list_parquet_files()
        ...
        resume_pq_idx = resume_state_dict["pq_idx"] if resume_state_dict is not None else 0
        resume_rg_idx = resume_state_dict["rg_idx"] if resume_state_dict is not None else None
        pq_idx = resume_pq_idx # we kick off parquet files at the resume index (or by default just 0)
        while True: # iterate infinitely (multi-epoch)
            while pq_idx < len(parquet_paths): # iterate over all parquet files
                filepath = parquet_paths[pq_idx]
                pf = pq.ParquetFile(filepath)
                # Start from resume point if resuming on same file, otherwise from DDP rank
                # I know this state resumption is a little bit tricky and a little bit hacky... sigh.
                if resume_rg_idx is not None:
                    base_idx = resume_rg_idx // ddp_world_size # in units of ddp_world_size
                    base_idx += 1 # advance by 1 so that we definitely don't repeat data after resuming
                    rg_idx = base_idx * ddp_world_size + ddp_rank
                    resume_rg_idx = None # set to None as we only want to do this a single time
                else:
                    rg_idx = ddp_rank
                while rg_idx < pf.num_row_groups:
                    rg = pf.read_row_group(rg_idx)
                    batch = rg.column('text').to_pylist() # each batch is a parquet group, e.g. 1024 rows
                    # the tokenizer encode might want to go in even smaller batches, e.g. 128 rows
                    for i in range(0, len(batch), tokenizer_batch_size):
                        yield batch[i:i+tokenizer_batch_size], (pq_idx, rg_idx)
                    rg_idx += ddp_world_size # advance to the next row group (in DDP)
                pq_idx += 1 # advance to the next parquet file

This code streams raw text from Parquet files and tokenizes it into one long, continuous token stream. Each iteration then takes a chunk of B*T+1 tokens, uses scratch[:-1] as the inputs, and uses scratch[1:] (the same tokens shifted one position) as the targets.

def tokenizing_distributed_data_loader_with_state(
    B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda", resume_state_dict=None
):
    # ... continued from the same function above ...
    ...
    batches = document_batches()
    needed_tokens = B * T + 1 # +1 is because we also need the target at the last token
    ...
    token_buffer = deque() # we stream tokens on the right and pop from the left
    while True:
        # Accumulate enough tokens for one iteration before yielding.
        while len(token_buffer) < needed_tokens:
            doc_batch, (pq_idx, rg_idx) = next(batches)
            token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
            for tokens in token_lists:
                token_buffer.extend(tokens)
        # Move tokens from the deque into the scratch buffer
        tokens = [token_buffer.popleft() for _ in range(needed_tokens)]
        ...
        scratch = torch.tensor(tokens, dtype=torch.long, pin_memory=use_cuda_optimizations) 
        # Create the inputs/targets as 1D tensors
        inputs_cpu = scratch[:-1]
        targets_cpu = scratch[1:]
        # Reshape to 2D and move to GPU async
        inputs = inputs_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations)
        targets = targets_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations)
        state_dict = {"pq_idx": pq_idx, "rg_idx": rg_idx} # in case we wish to approximately resume training
        yield inputs, targets, state_dict
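A hypothetical sketch of how this generator plugs into the pretraining loop; the batch size, sequence length, step count, and the assumption that the model returns a loss when targets are passed are all mine, not NanoChat's exact training script.

loader = tokenizing_distributed_data_loader_with_state(B=32, T=1024, split="train")
optimizers = model.setup_optimizers()
for step in range(10_000):
    inputs, targets, loader_state = next(loader)  # (B, T) token tensors on the GPU
    loss = model(inputs, targets)                 # assumed: returns the LM loss when targets are given
    loss.backward()
    for opt in optimizers:                        # AdamW + Muon from setup_optimizers()
        opt.step()
        opt.zero_grad()
    # loader_state == {"pq_idx": ..., "rg_idx": ...} can be checkpointed to approximately resume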

TBD