Practical Notes on Training NanoChat
A practical walkthrough for building a tiny ChatGPT
ChatGPT is now part of everyday life, and Karpathy’s NanoChat [Github] offers a minimal codebase for learning how a small ChatGPT-style model is trained end to end. After working through NanoChat, I collected the key techniques that helped me understand it. These brief notes are for myself and for anyone looking for a simple starting point to build a tiny ChatGPT.
Inside the Lightweight GPT Architecture: gpt.py
MQA + RoPE NanoChat employs multi-query attention (MQA) [Link] to significantly reduce KV cache memory usage during inference. It further incorporates rotary positional embeddings (RoPE) to enhance long-context generalization and enable extrapolation to unseen sequence lengths.
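To get a feel for the savings, here is a back-of-the-envelope sketch (my own illustration with made-up layer counts, not NanoChat code): the KV cache scales linearly with n_kv_head, so shrinking the number of KV heads shrinks the cache by the same factor.
def kv_cache_bytes(n_layer, n_kv_head, head_dim, seq_len, bytes_per_elem=2):
    # K and V tensors, one pair per layer, stored in bf16 (2 bytes per element)
    return 2 * n_layer * n_kv_head * head_dim * seq_len * bytes_per_elem

mha = kv_cache_bytes(n_layer=12, n_kv_head=6, head_dim=128, seq_len=4096)  # all 6 KV heads
mqa = kv_cache_bytes(n_layer=12, n_kv_head=1, head_dim=128, seq_len=4096)  # one shared KV head
print(f"{mha / 2**20:.0f} MiB vs {mqa / 2**20:.0f} MiB")  # 144 MiB vs 24 MiB: a 6x smaller cache per sequence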
@dataclass
class GPTConfig:
...
# n_kv_head == n_head: MHA
# n_kv_head < n_head: MQA
n_head: int = 6 # number of query heads
n_kv_head: int = 6 # number of key/value heads (MQA)
def apply_rotary_emb(x, cos, sin):
assert x.ndim == 4 # multihead attention
d = x.shape[3] // 2
x1, x2 = x[..., :d], x[..., d:] # split the last dim into two halves
y1 = x1 * cos + x2 * sin # rotate pairs of dims
y2 = x1 * (-sin) + x2 * cos
out = torch.cat([y1, y2], 3) # re-assemble
out = out.to(x.dtype) # ensure input/output dtypes match
return out
class CausalSelfAttention(nn.Module):
def __init__(self, config, layer_idx):
...
self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
def forward(self, x, cos_sin, kv_cache):
B, T, C = x.size()
# Project the input to get queries, keys, and values
q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
# Apply Rotary Embeddings to queries and keys to get relative positional encoding
cos, sin = cos_sin
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) # QK rotary embedding
q, k = norm(q), norm(k) # QK norm
# make head be batch dim, i.e. (B, T, H, D) -> (B, H, T, D)
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
if kv_cache is not None:
k, v = kv_cache.insert_kv(self.layer_idx, k, v)
# causal masking only during training; with a KV cache the new query attends to all cached positions (enable_gqa broadcasts the shared KV heads across query heads)
y = F.scaled_dot_product_attention(q, k, v, is_causal=is_train, enable_gqa=enable_gqa)
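The forward pass above receives cos_sin from the caller. Here is a minimal sketch of how such tables can be precomputed (a hypothetical helper using the standard RoPE frequency schedule, shaped (1, T, 1, head_dim/2) so it broadcasts against the (B, T, H, D/2) halves in apply_rotary_emb; not NanoChat's exact code):
import torch

def precompute_rotary_cos_sin(seq_len, head_dim, base=10000.0, device="cpu"):
    # one rotation frequency per pair of channels
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
    t = torch.arange(seq_len, device=device).float()
    freqs = torch.outer(t, inv_freq)        # (T, head_dim/2)
    cos = freqs.cos()[None, :, None, :]     # (1, T, 1, head_dim/2)
    sin = freqs.sin()[None, :, None, :]
    return cos, sin

cos_sin = precompute_rotary_cos_sin(seq_len=2048, head_dim=128)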
AdamW + Muon NanoChat adopts a hybrid optimization strategy: AdamW for the embedding and unembedding (lm_head) parameters, and the Muon optimizer, which orthogonalizes each matrix update via Newton-Schulz iterations (an efficient stand-in for an SVD-based projection), for the 2D weight matrices inside the Transformer blocks, giving smoother and more stable training.
class Block(nn.Module):
def __init__(self, config, layer_idx):
super().__init__()
self.attn = CausalSelfAttention(config, layer_idx)
self.mlp = MLP(config)
def forward(self, x, cos_sin, kv_cache):
x = x + self.attn(norm(x), cos_sin, kv_cache)
x = x + self.mlp(norm(x))
return x
class GPT(nn.Module):
def __init__(self, config):
...
self.transformer = nn.ModuleDict({
"wte": nn.Embedding(config.vocab_size, config.n_embd),
"h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
})
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) # note: the output dimension is vocab_size, not n_embd
def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0):
...
matrix_params = list(self.transformer.h.parameters())
embedding_params = list(self.transformer.wte.parameters())
lm_head_params = list(self.lm_head.parameters())
# Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model)
dmodel_lr_scale = (model_dim / 768) ** -0.5
adam_groups = [
dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
]
...
adamw_optimizer = AdamWFactory(adam_groups, **adamw_kwargs)
...
muon_optimizer = MuonFactory(matrix_params, **muon_kwargs)
optimizers = [adamw_optimizer, muon_optimizer]
...
return optimizers
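For intuition about what MuonFactory is wrapping, here is a minimal sketch of the Newton-Schulz orthogonalization at the heart of Muon (the standard quintic iteration from the Muon reference implementation, not NanoChat's exact distributed code): it pushes each 2D update toward the nearest semi-orthogonal matrix using only matmuls.
import torch

def newton_schulz_orthogonalize(G, steps=5, eps=1e-7):
    # Approximates U @ V^T from the SVD of G, i.e. sets all singular values to ~1
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G.bfloat16()
    transposed = G.size(-2) > G.size(-1)
    if transposed:
        X = X.mT  # iterate on the wide orientation
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps)  # keep the norm small enough for the iteration to converge
    for _ in range(steps):
        A = X @ X.mT
        B = b * A + c * A @ A
        X = a * X + B @ X
    if transposed:
        X = X.mT
    return X.to(G.dtype)

grad = torch.randn(768, 3072)               # stand-in for a momentum-averaged gradient of one weight matrix
update = newton_schulz_orthogonalize(grad)  # Muon then applies lr * update (with scaling) to the weight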
Running the Chat Engine: engine.py
The Engine class handles autoregressive text generation with efficient KV-cache reuse and multi-sample decoding. It also supports tool-augmented generation — detecting <|python_start|> and <|python_end|> code blocks, safely evaluating expressions via a built-in calculator, and injecting the computed results back into the output stream.
class RowState:
# Per-row state tracking during generation
def __init__(self, current_tokens=None):
...
# Why do we need forced_tokens? It lets us inject non-model outputs (e.g. tool results) into the token stream
# For a prompt: <|python_start|>2 + 3 * (4 - 1)<|python_end|>
# The engine detects this Python block, evaluates it → 11,
# and encodes that result as tokens:
self.forced_tokens = deque() # Queue of tokens to force inject
...
self.python_expr_tokens = [] # Tokens of the current python expression
class Engine:
@torch.inference_mode() # disables gradient tracking and reduces memory + compute overhead
def generate(self, tokens, num_samples=1, max_tokens=None, temperature=1.0, top_k=None, seed=42):
...
# Get the special tokens we need to coordinate the tool use state machine
get_special = lambda s: self.tokenizer.encode_special(s)
python_start = get_special("<|python_start|>") # each <|...|> special string maps to exactly one token id
python_end = get_special("<|python_end|>")
output_start = get_special("<|output_start|>")
output_end = get_special("<|output_end|>")
assistant_end = get_special("<|assistant_end|>") # if sampled, ends row
bos = self.tokenizer.get_bos_token_id() # if sampled, ends row
# 1) Run a batch 1 prefill of the prompt tokens
m = self.model.config
kv_model_kwargs = {"num_heads": m.n_kv_head, "head_dim": m.n_embd // m.n_head, "num_layers": m.n_layer}
kv_cache_prefill = KVCache(
batch_size=1,
seq_len=len(tokens),
**kv_model_kwargs,
) # only need one KV cache since there is only one prompt
ids = torch.tensor([tokens], dtype=torch.long, device=device)
logits = self.model.forward(ids, kv_cache=kv_cache_prefill)
logits = logits[:, -1, :]
next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
sampled_tokens = next_ids[:, 0].tolist() # one sampled token id per row (a single row during prefill)
# 2) Replicate the KV cache for each sample/row
kv_length_hint = (len(tokens) + max_tokens) if max_tokens is not None else self.model.config.sequence_len
kv_cache_decode = KVCache(
batch_size=num_samples,
seq_len=kv_length_hint, # size the cache for the whole generation up front so it isn't re-allocated every step
**kv_model_kwargs,
)
kv_cache_decode.prefill(kv_cache_prefill)
del kv_cache_prefill # no need to keep this memory around
# 3) Initialize states for each sample
row_states = [RowState(tokens.copy()) for _ in range(num_samples)] # different memory addresses
# 4) Main generation loop
num_generated = 0
first_iteration = True
while True:
# Stop condition: max tokens reached or all rows are completed
...
# Get sampled tokens - either from prefill or from forward pass
if not first_iteration:
...
# Forward the model and get the next token for each row
logits = self.model.forward(ids, kv_cache=kv_cache_decode) # (B, T, vocab_size)
logits = logits[:, -1, :] # (B, vocab_size) at last time step
next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
sampled_tokens = next_ids[:, 0].tolist()
# Process each row: choose the next token, update state, optional tool use
token_column = [] # contains the next token id along each row
for i, state in enumerate(row_states):
# Select the next token in this row
is_forced = len(state.forced_tokens) > 0 # tokens waiting to be forced in deque?
...
next_token = state.forced_tokens.popleft() if is_forced else sampled_tokens[i]
token_column.append(next_token)
# Update the state of this row to include the next token
state.current_tokens.append(next_token)
# On <|assistant_end|> or <|bos|>, mark the row as completed
if next_token == assistant_end or next_token == bos:
state.completed = True
# Handle tool logic
if next_token == python_start:
state.in_python_block = True
state.python_expr_tokens = []
elif next_token == python_end and state.in_python_block:
state.in_python_block = False
if state.python_expr_tokens:
expr = self.tokenizer.decode(state.python_expr_tokens)
result = use_calculator(expr) # avoid __import__('os').system('rm -rf /') ...
if result is not None:
result_tokens = self.tokenizer.encode(str(result))
state.forced_tokens.append(output_start)
state.forced_tokens.extend(result_tokens)
state.forced_tokens.append(output_end)
state.python_expr_tokens = []
elif state.in_python_block:
state.python_expr_tokens.append(next_token)
...
ids = torch.tensor(token_column, dtype=torch.long, device=device).unsqueeze(1)
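The use_calculator call above is what keeps the tool loop safe. Here is a minimal sketch of such a guard (my own approximation, not NanoChat's implementation): anything beyond plain arithmetic returns None, so nothing gets injected back into the stream.
import re

def use_calculator_sketch(expr: str):
    if not re.fullmatch(r"[0-9+\-*/(). ]+", expr):   # whitelist: digits, arithmetic operators, parens, spaces
        return None
    if "**" in expr:                                  # no huge exponentiations
        return None
    try:
        return eval(expr, {"__builtins__": {}}, {})   # no builtins available inside the expression
    except Exception:
        return None

print(use_calculator_sketch("2 + 3 * (4 - 1)"))                       # 11
print(use_calculator_sketch("__import__('os').system('rm -rf /')"))  # None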
Reducing the Token Space: tokenizer.py
English has roughly 170k common words, and once you add inflected forms, punctuation, slang, numbers, and code tokens, the count of unique "words" runs into the millions, which would require a huge embedding matrix. BPE tokenizers instead reuse subwords (like comput + ation) to cover the space compactly, typically with 30k–100k tokens in total.
Here’s a quick view of how BPE learns. Take the tiny dataset hello hello and world hello. The tokenizer starts at the character level: h e l l o, and so on. It then repeatedly merges the most frequent adjacent pair of symbols: first l l → ll (3 occurrences), then e ll → ell, then h ell → hell, and so on. After a few rounds, hello becomes a single token, and the same process eventually turns world into one. In the end, BPE discovers hello and world as its core learned units.
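A toy version of that merge loop makes the idea concrete (illustration only; ties between equally frequent pairs are broken arbitrarily here, so the exact merge order can differ from the walkthrough above):
from collections import Counter

corpus = "hello hello and world hello".split()
words = Counter(tuple(w) for w in corpus)  # each word as a sequence of symbols, with its frequency

def merge_once(words):
    # count adjacent symbol pairs, weighted by word frequency
    pair_counts = Counter()
    for symbols, freq in words.items():
        for pair in zip(symbols, symbols[1:]):
            pair_counts[pair] += freq
    if not pair_counts:
        return words, None
    best = max(pair_counts, key=pair_counts.get)  # most frequent adjacent pair
    merged = Counter()
    for symbols, freq in words.items():
        out, i = [], 0
        while i < len(symbols):
            if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == best:
                out.append(symbols[i] + symbols[i + 1])
                i += 2
            else:
                out.append(symbols[i])
                i += 1
        merged[tuple(out)] += freq
    return merged, best

for _ in range(6):
    words, best = merge_once(words)
    if best is None:
        break
    print("merged", best)
# after four merges 'hello' is a single symbol; a few more rounds and 'world' follows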
SPLIT_PATTERN = (
r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| """
r"""?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
)
...
@classmethod
def train_from_iterator(cls, text_iterator, vocab_size):
# train from an iterator of text
tokenizer = HFTokenizer(BPE(
byte_fallback=True, # needed!
unk_token=None,
fuse_unk=False,
))
tokenizer.normalizer = None
gpt4_split_regex = Regex(SPLIT_PATTERN) # huggingface demands that you wrap it in Regex!!
tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
pre_tokenizers.Split(pattern=gpt4_split_regex, behavior="isolated", invert=False),
pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False)
])
# Decoder: ByteLevel (it pairs together with the ByteLevel pre-tokenizer)
tokenizer.decoder = decoders.ByteLevel()
...
trainer = BpeTrainer(
vocab_size=vocab_size,
show_progress=True,
min_frequency=0, # no minimum frequency
initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
special_tokens=SPECIAL_TOKENS,
)
# Kick off the training
tokenizer.train_from_iterator(text_iterator, trainer)
return cls(tokenizer)
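To see what the split pattern alone does before any merges are learned, you can run it with the regex package (a standalone illustration; the real pipeline applies the same pattern through the pre-tokenizer shown above):
import regex  # the third-party `regex` package, which understands \p{L} / \p{N}

SPLIT_PATTERN = (
    r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| """
    r"""?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
)

print(regex.findall(SPLIT_PATTERN, "hello world 12345!"))
# ['hello', ' world', ' ', '12', '34', '5', '!']  (numbers get chunked at most two digits at a time)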
DataLoader
# 1) Parquet files (indexed by pq_idx)
# pq_idx = 0 → part-00000.parquet
# pq_idx = 1 → part-00001.parquet
# pq_idx = 2 → part-00002.parquet ← example below
# 2) Zoom into part-00002.parquet (pq_idx = 2)
# ┌── part-00002.parquet ────────────────────────────────┐
# │ RowGroup 0 (rg_idx = 0) → text: [t0, ..., t9] │
# │ RowGroup 1 (rg_idx = 1) → text: [t10, ..., t19] │
# │ RowGroup 2 (rg_idx = 2) → text: [t20, ..., t29] │
# │ RowGroup 3 (rg_idx = 3) → text: [t30, ..., t39] │
# └──────────────────────────────────────────────────────┘
# (Each RowGroup is just a list of text lines.)
# 3) Distributed training (world_size = 2)
# GPU0 (rank = 0) → picks RowGroups: 0, 2, ...
# GPU1 (rank = 1) → picks RowGroups: 1, 3, ...
# Example for GPU0:
# rg_idx = 0
# batch = [t0, t1, t2, ..., t9] # 10 text lines
# 4) Split one RowGroup into small doc_batches (for tokenizer)
# Suppose tokenizer_batch_size = 4.
# Then batch (10 lines) becomes:
# batch[0:4] → doc_batch_0
# batch[4:8] → doc_batch_1
# batch[8:10] → doc_batch_2
# Each doc_batch is sent to:
# tokenizer.encode(doc_batch)
def tokenizing_distributed_data_loader_with_state(
B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda", resume_state_dict=None
):
...
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
def document_batches():
parquet_paths = list_parquet_files()
...
resume_pq_idx = resume_state_dict["pq_idx"] if resume_state_dict is not None else 0
resume_rg_idx = resume_state_dict["rg_idx"] if resume_state_dict is not None else None
pq_idx = resume_pq_idx # we kick off parquet files at the resume index (or by default just 0)
while True: # iterate infinitely (multi-epoch)
while pq_idx < len(parquet_paths): # iterate over all parquet files
filepath = parquet_paths[pq_idx]
pf = pq.ParquetFile(filepath)
# Start from resume point if resuming on same file, otherwise from DDP rank
# I know this state resumption is a little bit tricky and a little bit hacky... sigh.
if resume_rg_idx is not None:
base_idx = resume_rg_idx // ddp_world_size # in units of ddp_world_size
base_idx += 1 # advance by 1 so that we definitely don't repeat data after resuming
rg_idx = base_idx * ddp_world_size + ddp_rank
resume_rg_idx = None # set to None as we only want to do this a single time
else:
rg_idx = ddp_rank
while rg_idx < pf.num_row_groups:
rg = pf.read_row_group(rg_idx)
batch = rg.column('text').to_pylist() # each batch is a parquet group, e.g. 1024 rows
# the tokenizer encode might want to go in even smaller batches, e.g. 128 rows
for i in range(0, len(batch), tokenizer_batch_size):
yield batch[i:i+tokenizer_batch_size], (pq_idx, rg_idx)
rg_idx += ddp_world_size # advance to the next row group (in DDP)
pq_idx += 1 # advance to the next parquet file
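The rg_idx += ddp_world_size stride is what keeps ranks from reading the same data. A tiny standalone illustration (not NanoChat code) of how the row groups of one file get partitioned across GPUs:
num_row_groups, world_size = 10, 2
for rank in range(world_size):
    row_groups = list(range(rank, num_row_groups, world_size))
    print(f"rank {rank} reads row groups {row_groups}")
# rank 0 reads row groups [0, 2, 4, 6, 8]
# rank 1 reads row groups [1, 3, 5, 7, 9]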
This code streams raw text from the Parquet files and tokenizes it into one long, continuous token stream. Each step then takes a chunk of B*T + 1 tokens, uses scratch[:-1] as the inputs and scratch[1:] (the same tokens shifted by one position) as the targets.
def tokenizing_distributed_data_loader_with_state():
...
batches = document_batches()
needed_tokens = B * T + 1 # +1 is because we also need the target at the last token
...
token_buffer = deque() # we stream tokens on the right and pop from the left
while True:
# Accumulate enough tokens for one iteration before yielding.
while len(token_buffer) < needed_tokens:
doc_batch, (pq_idx, rg_idx) = next(batches)
token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
for tokens in token_lists:
token_buffer.extend(tokens)
# Move tokens from the deque into the scratch buffer
tokens = [token_buffer.popleft() for _ in range(needed_tokens)]
...
scratch = torch.tensor(tokens, dtype=torch.long, pin_memory=use_cuda_optimizations)
# Create the inputs/targets as 1D tensors
inputs_cpu = scratch[:-1]
targets_cpu = scratch[1:]
# Reshape to 2D and move to GPU async
inputs = inputs_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations)
targets = targets_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations)
state_dict = {"pq_idx": pq_idx, "rg_idx": rg_idx} # in case we wish to approximately resume training
yield inputs, targets, state_dict
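As a standalone sanity check of the shift-by-one labeling (illustration only, with token ids replaced by a simple arange):
import torch

B, T = 2, 4
scratch = torch.arange(B * T + 1)     # pretend token ids 0..8
inputs = scratch[:-1].view(B, T)      # [[0, 1, 2, 3], [4, 5, 6, 7]]
targets = scratch[1:].view(B, T)      # [[1, 2, 3, 4], [5, 6, 7, 8]]
# targets[b, t] is exactly the token that follows inputs[b, t] in the original stream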
TBD