Notes on Training NanoChat
A practical walkthrough for building a tiny ChatGPT
ChatGPT is now part of everyday life, and Karpathy’s nanochat (Karpathy, Oct 2025) offers a minimal codebase for learning how a small ChatGPT-style model is trained end to end. After studying NanoChat, I gathered the key techniques that helped me. These brief notes are for myself and for anyone looking for a simple starting point to build a tiny ChatGPT.
nanochat: Model, Engine, Dataloader, Tokenizer
gpt.py: Inside the Lightweight GPT Architecture
MQA + RoPE NanoChat employs multi-query attention (MQA) [Link] to significantly reduce KV-cache memory during inference. It also uses rotary positional embeddings (RoPE), which encode relative positions by rotating pairs of query/key dimensions and help the model generalize to longer contexts.
@dataclass
class GPTConfig:
...
# n_kv_head == n_head: MHA
# n_kv_head < n_head: MQA
n_head: int = 6 # number of query heads
n_kv_head: int = 6 # number of key/value heads (MQA)
def apply_rotary_emb(x, cos, sin):
assert x.ndim == 4 # multihead attention
d = x.shape[3] // 2
    x1, x2 = x[..., :d], x[..., d:] # split the last dim into two halves
y1 = x1 * cos + x2 * sin # rotate pairs of dims
y2 = x1 * (-sin) + x2 * cos
out = torch.cat([y1, y2], 3) # re-assemble
out = out.to(x.dtype) # ensure input/output dtypes match
return out
class CausalSelfAttention(nn.Module):
def __init__(self, config, layer_idx):
...
self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
def forward(self, x, cos_sin, kv_cache):
B, T, C = x.size()
# Project the input to get queries, keys, and values
q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
# Apply Rotary Embeddings to queries and keys to get relative positional encoding
cos, sin = cos_sin
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) # QK rotary embedding
q, k = norm(q), norm(k) # QK norm
# make head be batch dim, i.e. (B, T, H, D) -> (B, H, T, D)
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
if kv_cache is not None:
k, v = kv_cache.insert_kv(self.layer_idx, k, v)
# set causal=False during inference
y = F.scaled_dot_product_attention(q, k, v, is_causal=is_train, enable_gqa=enable_gqa)
AdamW + Muon NanoChat adopts a hybrid optimization strategy: AdamW handles the embedding and output (lm_head) parameters, while the Muon optimizer, which orthogonalizes each matrix update via Newton–Schulz iteration (an SVD-like step), handles the 2D weight matrices inside the Transformer blocks for smoother, more stable training.
class Block(nn.Module):
def __init__(self, config, layer_idx):
super().__init__()
self.attn = CausalSelfAttention(config, layer_idx)
self.mlp = MLP(config)
def forward(self, x, cos_sin, kv_cache):
x = x + self.attn(norm(x), cos_sin, kv_cache)
x = x + self.mlp(norm(x))
return x
class GPT(nn.Module):
def __init__(self, config):
...
self.transformer = nn.ModuleDict({
"wte": nn.Embedding(config.vocab_size, config.n_embd),
"h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
})
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) # pay attention, output is vocab_size
def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0):
...
matrix_params = list(self.transformer.h.parameters())
embedding_params = list(self.transformer.wte.parameters())
lm_head_params = list(self.lm_head.parameters())
# Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model)
dmodel_lr_scale = (model_dim / 768) ** -0.5
adam_groups = [
dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
]
...
adamw_optimizer = AdamWFactory(adam_groups, **adamw_kwargs)
...
muon_optimizer = MuonFactory(matrix_params, **muon_kwargs)
optimizers = [adamw_optimizer, muon_optimizer]
...
return optimizers
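For intuition, here is a minimal sketch of a Muon-style update: momentum-average the gradient, orthogonalize it with a Newton–Schulz iteration, then apply it. The coefficients follow the Muon write-up, but the plain `param.add_` scaling and the helper names are illustrative, not nanochat's exact implementation.

```python
import torch

@torch.no_grad()
def orthogonalize(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    # Newton-Schulz iteration that pushes the singular values of G toward 1,
    # i.e. approximately replaces G by the nearest orthogonal matrix (like an SVD-based whitening).
    a, b, c = 3.4445, -4.7750, 2.0315   # quintic iteration coefficients from the Muon write-up
    X = G / (G.norm() + 1e-7)           # normalize so the iteration converges
    transpose = X.size(0) > X.size(1)
    if transpose:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X
    return X.T if transpose else X

@torch.no_grad()
def muon_step(param: torch.Tensor, buf: torch.Tensor, lr: float = 0.02, momentum: float = 0.95):
    # buf is the running momentum buffer kept for this 2D parameter.
    buf.mul_(momentum).add_(param.grad)
    param.add_(orthogonalize(buf), alpha=-lr)  # real Muon also rescales by a shape-dependent factor
```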
engine.py: Running the Chat Engine
The Engine class handles autoregressive text generation with efficient KV-cache reuse and multi-sample decoding. It also supports tool-augmented generation — detecting <|python_start|> and <|python_end|> code blocks, safely evaluating expressions via a built-in calculator, and injecting the computed results back into the output stream.
class RowState:
# Per-row state tracking during generation
def __init__(self, current_tokens=None):
...
        # Why do we need forced_tokens? It lets us inject non-model outputs into the token stream.
# For a prompt: <|python_start|>2 + 3 * (4 - 1)<|python_end|>
# The engine detects this Python block, evaluates it → 11,
# and encodes that result as tokens:
self.forced_tokens = deque() # Queue of tokens to force inject
...
self.python_expr_tokens = [] # Tokens of the current python expression
class Engine:
@torch.inference_mode() # disables gradient tracking and reduces memory + compute overhead
def generate(self, tokens, num_samples=1, max_tokens=None, temperature=1.0, top_k=None, seed=42):
...
# Get the special tokens we need to coordinate the tool use state machine
get_special = lambda s: self.tokenizer.encode_special(s)
python_start = get_special("<|python_start|>") # <|.|> ensures just one token
python_end = get_special("<|python_end|>")
output_start = get_special("<|output_start|>")
output_end = get_special("<|output_end|>")
assistant_end = get_special("<|assistant_end|>") # if sampled, ends row
bos = self.tokenizer.get_bos_token_id() # if sampled, ends row
# 1) Run a batch 1 prefill of the prompt tokens
m = self.model.config
kv_model_kwargs = {"num_heads": m.n_kv_head, "head_dim": m.n_embd // m.n_head, "num_layers": m.n_layer}
kv_cache_prefill = KVCache(
batch_size=1,
seq_len=len(tokens),
**kv_model_kwargs,
) # only need one KV cache since there is only one prompt
ids = torch.tensor([tokens], dtype=torch.long, device=device)
logits = self.model.forward(ids, kv_cache=kv_cache_prefill)
logits = logits[:, -1, :]
next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
sampled_tokens = next_ids[:, 0].tolist() # tokens if batch > 1 otherwise token
# 2) Replicate the KV cache for each sample/row
kv_length_hint = (len(tokens) + max_tokens) if max_tokens is not None else self.model.config.sequence_len
kv_cache_decode = KVCache(
batch_size=num_samples,
seq_len=kv_length_hint, # use max_tokens instead of len(tokens)+1 to avoid re-initialize everytime
**kv_model_kwargs,
)
kv_cache_decode.prefill(kv_cache_prefill)
del kv_cache_prefill # no need to keep this memory around
# 3) Initialize states for each sample
row_states = [RowState(tokens.copy()) for _ in range(num_samples)] # different memory addresses
# 4) Main generation loop
num_generated = 0
first_iteration = True
while True:
# Stop condition: max tokens reached or all rows are completed
...
# Get sampled tokens - either from prefill or from forward pass
if not first_iteration:
...
# Forward the model and get the next token for each row
logits = self.model.forward(ids, kv_cache=kv_cache_decode) # (B, T, vocab_size)
logits = logits[:, -1, :] # (B, vocab_size) at last time step
next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
sampled_tokens = next_ids[:, 0].tolist()
# Process each row: choose the next token, update state, optional tool use
token_column = [] # contains the next token id along each row
for i, state in enumerate(row_states):
# Select the next token in this row
is_forced = len(state.forced_tokens) > 0 # tokens waiting to be forced in deque?
...
next_token = state.forced_tokens.popleft() if is_forced else sampled_tokens[i]
token_column.append(next_token)
# Update the state of this row to include the next token
state.current_tokens.append(next_token)
# On <|assistant_end|> or <|bos|>, mark the row as completed
if next_token == assistant_end or next_token == bos:
state.completed = True
# Handle tool logic
if next_token == python_start:
state.in_python_block = True
state.python_expr_tokens = []
elif next_token == python_end and state.in_python_block:
state.in_python_block = False
if state.python_expr_tokens:
expr = self.tokenizer.decode(state.python_expr_tokens)
result = use_calculator(expr) # avoid __import__('os').system('rm -rf /') ...
if result is not None:
result_tokens = self.tokenizer.encode(str(result))
state.forced_tokens.append(output_start)
state.forced_tokens.extend(result_tokens)
state.forced_tokens.append(output_end)
state.python_expr_tokens = []
elif state.in_python_block:
state.python_expr_tokens.append(next_token)
...
ids = torch.tensor(token_column, dtype=torch.long, device=device).unsqueeze(1)
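For reference, the calculator path only evaluates plain arithmetic. A minimal sketch of the idea (the real use_calculator in nanochat differs in its details):

```python
def use_calculator(expr: str):
    # Refuse anything beyond plain arithmetic characters, so names/attributes can never be reached.
    if any(ch not in "0123456789*+-/(). " for ch in expr):
        return None
    if "**" in expr:  # disallow exponentiation to avoid pathological computations
        return None
    try:
        return eval(expr, {"__builtins__": None}, {})
    except Exception:
        return None
```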
tokenizer.py: Reducing the Tokenizer Space
English has roughly 170k words in common use; adding inflected variants, punctuation, slang, numbers, and code tokens pushes a word-level vocabulary into the millions, which would require a huge embedding matrix. BPE tokenizers instead reuse subwords (like comput, ation) to cover the space compactly, typically with 30k–100k tokens in total.
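For a rough sense of scale: with an embedding width of 1,280, a 1M-word vocabulary would require a 1M × 1280 embedding matrix (about 1.3B parameters), whereas a 65K-token BPE vocabulary needs only about 84M.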
Here’s a quick view of how BPE learns. Take the tiny dataset hello hello and world hello. The tokenizer begins at the character level: h e l l o, etc. It then repeatedly merges the most frequent adjacent pair of symbols—first l l → ll (3 times), then e ll → ell, then h ell → hell, and so on. After a few rounds, hello becomes a single token, and the same process turns world into a token. In the end, BPE discovers hello and world as its core learned units.
Here is an OpenAI link that shows how the tokenizer works.
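To make the merge procedure concrete, here is a toy BPE trainer over that example. It is illustrative only; nanochat's tokenizer uses the HuggingFace BPE trainer shown below.

```python
from collections import Counter

def train_bpe(corpus: str, num_merges: int):
    # Start from character-level symbols for each whitespace-separated word.
    words = [list(w) for w in corpus.split()]
    merges = []
    for _ in range(num_merges):
        # Count every adjacent symbol pair across the corpus.
        pairs = Counter()
        for w in words:
            for a, b in zip(w, w[1:]):
                pairs[(a, b)] += 1
        if not pairs:
            break
        (a, b), _ = pairs.most_common(1)[0]
        merges.append((a, b))
        # Apply the merge everywhere.
        for w in words:
            i = 0
            while i < len(w) - 1:
                if w[i] == a and w[i + 1] == b:
                    w[i:i + 2] = [a + b]
                else:
                    i += 1
    return merges, words

merges, words = train_bpe("hello hello and world hello", num_merges=6)
print(merges)  # e.g. pairs like ('l', 'l'), ('e', 'll'), ('h', 'ell'), ... get merged
print(words)   # 'hello' collapses to a single token after a few merges
```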
SPLIT_PATTERN = (
r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| """
r"""?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
)
...
@classmethod
def train_from_iterator(cls, text_iterator, vocab_size):
# train from an iterator of text
tokenizer = HFTokenizer(BPE(
byte_fallback=True, # needed!
unk_token=None,
fuse_unk=False,
))
tokenizer.normalizer = None
gpt4_split_regex = Regex(SPLIT_PATTERN) # huggingface demands that you wrap it in Regex!!
tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
pre_tokenizers.Split(pattern=gpt4_split_regex, behavior="isolated", invert=False),
pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False)
])
# Decoder: ByteLevel (it pairs together with the ByteLevel pre-tokenizer)
tokenizer.decoder = decoders.ByteLevel()
...
trainer = BpeTrainer(
vocab_size=vocab_size,
show_progress=True,
min_frequency=0, # no minimum frequency
initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
special_tokens=SPECIAL_TOKENS,
)
# Kick off the training
tokenizer.train_from_iterator(text_iterator, trainer)
return cls(tokenizer)
dataloader.py: From Parquet Files to Tokenizer Batches
# 1) Parquet files (indexed by pq_idx)
# pq_idx = 0 → part-00000.parquet
# pq_idx = 1 → part-00001.parquet
# pq_idx = 2 → part-00002.parquet ← example below
# 2) Zoom into part-00002.parquet (pq_idx = 2)
# ┌── part-00002.parquet ────────────────────────────────┐
# │ RowGroup 0 (rg_idx = 0) → text: [t0, ..., t9] │
# │ RowGroup 1 (rg_idx = 1) → text: [t10, ..., t19] │
# │ RowGroup 2 (rg_idx = 2) → text: [t20, ..., t29] │
# │ RowGroup 3 (rg_idx = 3) → text: [t30, ..., t39] │
# └──────────────────────────────────────────────────────┘
# (Each RowGroup is just a list of text lines.)
# 3) Distributed training (world_size = 2)
# GPU0 (rank = 0) → picks RowGroups: 0, 2, ...
# GPU1 (rank = 1) → picks RowGroups: 1, 3, ...
# Example for GPU0:
# rg_idx = 0
# batch = [t0, t1, t2, ..., t9] # 10 text lines
# 4) Split one RowGroup into small doc_batches (for tokenizer)
# Suppose tokenizer_batch_size = 4.
# Then batch (10 lines) becomes:
# batch[0:4] → doc_batch_0
# batch[4:8] → doc_batch_1
# batch[8:10] → doc_batch_2
# Each doc_batch is sent to:
# tokenizer.encode(doc_batch)
def tokenizing_distributed_data_loader_with_state(
B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda", resume_state_dict=None
):
...
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
def document_batches():
parquet_paths = list_parquet_files()
...
resume_pq_idx = resume_state_dict["pq_idx"] if resume_state_dict is not None else 0
resume_rg_idx = resume_state_dict["rg_idx"] if resume_state_dict is not None else None
pq_idx = resume_pq_idx # we kick off parquet files at the resume index (or by default just 0)
while True: # iterate infinitely (multi-epoch)
while pq_idx < len(parquet_paths): # iterate over all parquet files
filepath = parquet_paths[pq_idx]
pf = pq.ParquetFile(filepath)
# Start from resume point if resuming on same file, otherwise from DDP rank
# I know this state resumption is a little bit tricky and a little bit hacky... sigh.
if resume_rg_idx is not None:
base_idx = resume_rg_idx // ddp_world_size # in units of ddp_world_size
base_idx += 1 # advance by 1 so that we definitely don't repeat data after resuming
rg_idx = base_idx * ddp_world_size + ddp_rank
resume_rg_idx = None # set to None as we only want to do this a single time
else:
rg_idx = ddp_rank
while rg_idx < pf.num_row_groups:
rg = pf.read_row_group(rg_idx)
batch = rg.column('text').to_pylist() # each batch is a parquet group, e.g. 1024 rows
# the tokenizer encode might want to go in even smaller batches, e.g. 128 rows
for i in range(0, len(batch), tokenizer_batch_size):
yield batch[i:i+tokenizer_batch_size], (pq_idx, rg_idx)
rg_idx += ddp_world_size # advance to the next row group (in DDP)
pq_idx += 1 # advance to the next parquet file
This code streams raw text from Parquet files and tokenizes it into one long, continuous token stream. Each step takes a chunk of B*T + 1 tokens, uses scratch[:-1] as the inputs, and scratch[1:] (the same tokens shifted one step) as the targets.
def tokenizing_distributed_data_loader_with_state():
...
batches = document_batches()
needed_tokens = B * T + 1 # +1 is because we also need the target at the last token
...
token_buffer = deque() # we stream tokens on the right and pop from the left
while True:
# Accumulate enough tokens for one iteration before yielding.
while len(token_buffer) < needed_tokens:
doc_batch, (pq_idx, rg_idx) = next(batches)
token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
for tokens in token_lists:
token_buffer.extend(tokens)
# Move tokens from the deque into the scratch buffer
tokens = [token_buffer.popleft() for _ in range(needed_tokens)]
...
scratch = torch.tensor(tokens, dtype=torch.long, pin_memory=use_cuda_optimizations)
# Create the inputs/targets as 1D tensors
inputs_cpu = scratch[:-1]
targets_cpu = scratch[1:]
# Reshape to 2D and move to GPU async
inputs = inputs_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations)
targets = targets_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations)
state_dict = {"pq_idx": pq_idx, "rg_idx": rg_idx} # for approximately resume training
yield inputs, targets, state_dict
execution.py: Sandboxed Code Execution
def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[int], result_dict):
"""Execute code in a subprocess with safety guards. Results are written to result_dict."""
...
def execute_code(
code: str,
timeout: float = 5.0, # 5 seconds default
maximum_memory_bytes: Optional[int] = 256 * 1024 * 1024, # 256MB default
) -> ExecutionResult:
"""
Execute Python code in a sandboxed environment.
Example:
>>> result = execute_code("print('hello world')")
>>> result.success
True
>>> result.stdout
'hello world\\n'
"""
manager = multiprocessing.Manager()
result_dict = manager.dict()
p = multiprocessing.Process(
target=_unsafe_execute,
args=(code, timeout, maximum_memory_bytes, result_dict)
)
p.start()
p.join(timeout=timeout + 1)
if p.is_alive():
p.kill()
return ExecutionResult(
success=False,
stdout="",
stderr="",
error="Execution timed out (process killed)",
timeout=True,
memory_exceeded=False,
)
...
return ExecutionResult(
success=result_dict["success"],
stdout=result_dict["stdout"],
stderr=result_dict["stderr"],
error=result_dict["error"],
timeout=result_dict["timeout"],
memory_exceeded=result_dict["memory_exceeded"],
)
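One plausible sketch of the guard-and-capture logic that runs inside the child process. This is an assumption for illustration, not nanochat's actual _unsafe_execute, which is more involved.

```python
import contextlib
import io
import resource

def _guarded_execute(code: str, maximum_memory_bytes, result_dict) -> None:
    # Runs inside the child process: cap memory, capture stdout, report via result_dict.
    if maximum_memory_bytes is not None:
        # Limit the child's address space so runaway allocations fail instead of eating the host.
        resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes))
    stdout = io.StringIO()
    try:
        with contextlib.redirect_stdout(stdout):
            exec(code, {"__name__": "__main__"})
        result_dict.update({"success": True, "stdout": stdout.getvalue(), "stderr": "",
                            "error": "", "timeout": False, "memory_exceeded": False})
    except MemoryError:
        result_dict.update({"success": False, "stdout": stdout.getvalue(), "stderr": "",
                            "error": "MemoryError", "timeout": False, "memory_exceeded": True})
    except Exception as e:
        result_dict.update({"success": False, "stdout": stdout.getvalue(), "stderr": "",
                            "error": repr(e), "timeout": False, "memory_exceeded": False})
```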
core_eval.py: Task Evaluation Logic
# 1) For each data example (item)
# Optionally sample few-shot examples (excluding this item).
# 2) Render prompts (depends on task_type)
# - MC:
# [Context] + Delimiter + Choice A
# [Context] + Delimiter + Choice B
# ...
# - Schema:
# Context_1 + Delimiter + Continuation
# Context_2 + Delimiter + Continuation
# ...
# - LM:
# prompt_without = Context + Delimiter
# prompt_with = Context + Delimiter + Continuation
# 3) Tokenize prompts and locate the continuation span
# - MC: continuation = tokens after common prefix
# - Schema: continuation = tokens in common suffix
# - LM: continuation = tokens added by prompt_with
# Produce:
# tokens, start_idx[i], end_idx[i]
# 4) Batch + pad sequences
# input_ids = stack_sequences(tokens)
# If model has max_seq_len → crop from left and shift indices.
# 5) Forward model
# losses, predictions = forward_model(model, input_ids)
# losses: per-token cross entropy
# predictions: argmax tokens
# 6) Decide correctness
# - LM:
# Compare predicted continuation tokens with ground truth.
# - MC / Schema:
# For each option:
# score_i = mean(losses[i, si-1 : ei-1])
# Pick lowest score; compare with gold.
# 7) Distributed evaluation
# Each rank handles a stride of examples.
# dist.all_reduce to combine results.
# Final accuracy = mean(correct).
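As a sketch of step 6 for MC/Schema tasks (variable names are assumptions, not the exact core_eval code):

```python
import torch

def pick_option(losses: torch.Tensor, start_idx: list, end_idx: list) -> int:
    # losses: (num_options, T) per-token cross-entropy for each rendered option.
    # Score each option by the mean loss over its continuation span (shifted by one,
    # since the loss at position t predicts token t+1), then pick the lowest score.
    scores = [losses[i, start_idx[i] - 1 : end_idx[i] - 1].mean() for i in range(losses.size(0))]
    return int(torch.stack(scores).argmin())
```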
loss_eval.py: Byte-Normalized Loss
The loss computes bits-per-byte (BPB) by summing token NLLs and normalizing by the total raw-text bytes represented by valid targets. It excludes special tokens and ignore_index entries to keep the metric tokenizer-independent.
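A minimal sketch of the BPB computation under these rules (names are assumptions, not the exact loss_eval code):

```python
import math

def bits_per_byte(token_nlls, target_ids, token_byte_lens, ignore_index=-1, special_ids=frozenset()):
    # token_nlls: per-token negative log-likelihoods (in nats) at each target position
    # token_byte_lens[t]: number of raw UTF-8 bytes that token id t stands for
    nll_sum, byte_sum = 0.0, 0
    for nll, tid in zip(token_nlls, target_ids):
        if tid == ignore_index or tid in special_ids:
            continue  # skip ignored and special tokens so the metric stays tokenizer-independent
        nll_sum += nll
        byte_sum += token_byte_lens[tid]
    return (nll_sum / math.log(2)) / max(byte_sum, 1)  # nats -> bits, then per raw byte
```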
scripts: Training, Mid-train, SFT, RL, and Eval Pipeline
base_train.py: Base Pretrain Script
Budget Planning
By Chinchilla scaling laws (Hoffmann et al., 2022), to train a 1B-parameter model we can choose:
Model: 1B params
Data: ~20B tokens # scaling law
Compute: ~10^20 FLOPs # scaling law - figure
world_size = 8 GPUs
batch_per_gpu = 64 sequences
seq_len = 2048
grad_accum_steps = 1
Then, per training step:
\[\begin{align} &\text{Training compute per token: } \mathrm{FLOPs/token} \approx 6 \times 10^9 \notag \\ &\text{Tokens per step: } 8 \times 64 \times 2048 \approx 10^6 \notag \\ &\text{Compute per step: } \mathrm{FLOPs/token} \times \text{tokens per step} \approx 6 \times 10^{15} \notag \\ &\text{Peak FLOPs/s of 8 H100s: } \approx 8 \times 10^{15} \notag \\ &\text{MFU (at roughly 1 s per step): } \frac{6 \times 10^{15}}{8 \times 10^{15}} \approx 75\%. \notag \end{align}\]
The num_iterations can then be set to
\[\begin{align} \text{steps} = \frac{20{,}000{,}000{,}000}{1{,}048{,}576} \approx 19{,}100. \notag \end{align}\]
A more comprehensive cost evaluation can be seen in [Link].
The peak performance of NVIDIA GPUs is listed in the datasheet at [Link].
FLOPs are not runtime (CS 336, Tatsu Hashimoto)
| Class | FLOPs share | Runtime share |
|---|---|---|
| Tensor contraction | 99.5% | 61.0% |
| Stat. normalization | 0.17% | 25.5% |
| Element-wise | 0.03% | 13.5% |
if num_iterations > 0:
print0(f"Using user-provided number of iterations: {num_iterations:,}")
elif target_flops > 0:
# calculate the number of iterations from the target flops
num_iterations = round(target_flops / (num_flops_per_token * total_batch_size))
print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}")
elif target_param_data_ratio > 0:
# calculate the number of iterations from the target param data ratio
target_tokens = target_param_data_ratio * num_params
num_iterations = target_tokens // total_batch_size
print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}")
else:
raise ValueError("No training horizon specified")
total_tokens = total_batch_size * num_iterations
print0(f"Total number of training tokens: {total_tokens:,}")
print0(f"Tokens : Params ratio: {total_batch_size * num_iterations / num_params:.2f}") # Chinchilla is ~20
print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
Hyperparameters
Learning rate schedule
1. gently warm up the learning rate
2. keep it flat at its maximum for most of training
3. linearly decay it to a small final value for stable convergence.
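A minimal multiplier schedule in this spirit; the warmup/decay fractions here are illustrative, not nanochat's exact settings:

```python
def get_lr_multiplier(it: int, num_iterations: int,
                      warmup_frac: float = 0.01, decay_frac: float = 0.2,
                      final_frac: float = 0.0) -> float:
    # 1) linear warmup, 2) flat at the maximum, 3) linear decay to a small final value
    warmup_iters = max(int(warmup_frac * num_iterations), 1)
    decay_start = int((1 - decay_frac) * num_iterations)
    if it < warmup_iters:
        return (it + 1) / warmup_iters
    if it < decay_start:
        return 1.0
    progress = (it - decay_start) / max(num_iterations - decay_start, 1)
    return 1.0 - progress * (1.0 - final_frac)
```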
Momentum schedule. The Muon momentum warms up from 0.85 to 0.95 over the first 300 steps to stabilize early training:
def get_muon_momentum(it):
frac = min(it / 300, 1)
momentum = (1 - frac) * 0.85 + frac * 0.95
return momentum
The main training loop runs the full training process:
1. Gradient Accumulation
2. Automatic Mixed Precision (AMP): use torch.autocast with bfloat16 to speed up training and reduce memory usage.
3. Learning-Rate Scheduling + Muon Momentum Warmup
4. Gradient Clipping
5. Next-Batch Prefetching: the CPU loads the next batch while the GPU computes the current one:
with autocast_ctx:
loss = model(x, y)
loss.backward()
x, y, dataloader_state_dict = next(train_loader)
6. Distributed Synchronization (DDP all_reduce + synchronize)
7. torch.compile
8. Evaluation Hooks:
   1. evaluate validation bpb;
   2. compute the CORE metric;
   3. sample model outputs for an eyeball sanity check;
   4. save checkpoints.
9. EMA: keep a smoothed training loss using an exponential moving average:
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item()
10. Checkpointing with Full Training State, saving:
1. model weights
2. optimizer states
3. data loader state
4. loop state (loss, time, etc.) to fully resume training.
After every step, it calls synchronize(), logs key stats (loss, throughput, MFU, FLOPs, etc.), and continues until all planned iterations are completed.
mid_train.py: Skill-Oriented LM Refinement (Next-Token Style)
train_dataset = TaskMixture([
SmolTalk(split="train"), # 460K rows of general conversations
MMLU(subset="auxiliary_train", split="train"), # 100K rows of multiple choice problems drawn from ARC, MC_TEST, OBQA, RACE
GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use
CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations
CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these
SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple')
SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
]) # total: 460K + 100K + 8K + 200K + 80K = 848K rows
Techniques used in the mid-training loop:
1. Gradient Accumulation
2. Automatic Mixed Precision (AMP): use torch.autocast with bfloat16 to speed up training and reduce memory usage.
3. Learning-Rate Scheduling + Muon Momentum Warmup: flat for the first 80% of progress, then linear decay to 0, with init_lr_frac = 1.0.
4. Gradient Clipping
5. Next-Batch Prefetching: the CPU loads the next batch while the GPU computes the current one:
with autocast_ctx:
loss = model(x, y)
loss.backward()
# Prefetch the next batch while GPU is still busy
x, y = next(train_loader)
6. Distributed Synchronization (DDP all_reduce + synchronize): because mid-training uses a task mixture with variable-length conversations, each GPU may reach the end of its portion of the dataset at a different time. Without dist.all_reduce(last_step_tensor, MAX), one GPU might stop (because it hit the dataset end) while the others keep training, deadlocking the entire distributed job. This reduction broadcasts the earliest stop signal so all GPUs exit the loop together (see the sketch after this list).
7. torch.compile
8. Multi-Task Training Mixture
9. EMA: keep a smoothed training loss using an exponential moving average:
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item()
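A minimal sketch of the stop-flag reduction mentioned in item 6 (a hypothetical helper, not the exact mid_train code):

```python
import torch
import torch.distributed as dist

def all_ranks_should_stop(local_done: bool, device: torch.device) -> bool:
    # MAX-reduce the per-rank "done" flag: if ANY rank has exhausted its data shard,
    # every rank sees 1 and exits the training loop together, avoiding a DDP deadlock.
    flag = torch.tensor(1 if local_done else 0, device=device)
    dist.all_reduce(flag, op=dist.ReduceOp.MAX)
    return bool(flag.item())
```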
chat_sft.py: Supervised Fine-Tuning (SFT) for Chat
Pretrain/mid-train data:
Input: "The capital of France is Pa"
Target: "r"
SFT data
Input: "<user> What is the capital of France?"
Output: "<assistant> The capital of France is Paris."
Target tokens = assistant’s answer only.
Data mixture is smaller and more “exam/chat” flavored:
train_ds = TaskMixture([
ARC("ARC-Easy", "train"),
ARC("ARC-Challenge", "train"),
GSM8K("main", "train"),
SmolTalk("train", stop=10_000),
CustomJSON(identity_conversations),
SimpleSpelling(size=300),
SpellingBee(size=300),
])
Techniques used in the SFT loop:
1. Gradient Accumulation
2. Automatic Mixed Precision (AMP)
3. Learning-Rate Scheduling + Small SFT LR: init_lr_frac = 0.02, then linear decay to 0 over all SFT steps.
4. Assistant-Only Supervision (Masked Targets): only assistant tokens contribute to the loss; all other positions are masked to -1:
pad_token_id = tokenizer.encode_special("<|assistant_end|>")
ids, mask = tokenizer.render_conversation(doc)
...
row_targets = ids_tensor[1:]
mask_tensor = torch.tensor(mask[1:], dtype=torch.long)
row_targets[mask_tensor == 0] = -1 # ignore non-assistant positions
5. Variable-Length Conversation Batching with Padding:
ncols = max(len(ids) for ids, mask in batch) - 1
inputs = torch.full((nrows, ncols), pad_token_id, dtype=torch.long)
targets = torch.full((nrows, ncols), -1, dtype=torch.long)
# fill row-by-row, shorter rows remain padded
inputs[i, :n-1] = ids_tensor[:-1]
targets[i, :n-1] = row_targets
6. Distributed Training + Aggregation (DDP all_reduce):
num_tokens += (train_targets >= 0).sum()
if ddp:
dist.all_reduce(num_tokens, op=dist.ReduceOp.SUM)
7. Periodic Validation on Chat Data
8. Inline Task Metrics (MMLU, ARC) via the Engine
chat_rl.py: Reinforcement learning on GSM8K via “GRPO” (REINFORCE)
The core logic implements a lightweight GRPO objective with no trust region and no PPO ratio/clip.
\[\begin{align} \max_{\theta} \;\; \mathbb{E}_{\mathbf{a} \sim \pi_\theta}\Bigg[ \sum_{t=1}^{T} \mathbf{1}_{\{\text{mask}_t=1\}}\, A_t \,\log \pi_\theta(a_t \mid s_t) \Bigg], \notag \end{align}\]
where \(A_t = R - \mu\) is the token-level advantage, \(a_t\) is the generated token at step \(t\), and \(s_t\) is its preceding prefix. The mask is still needed because only the generated answer tokens (not the prompt) are scored.
def get_input_target_reward_advantage_batch():
...
for sampling_step in range(num_sampling_steps):
...
generated_token_sequences.extend(generated_token_sequences_batch)
masks.extend(masks_batch)
# Calculate the rewards for each sample
rewards = []
for sample_tokens in generated_token_sequences:
generated_tokens = sample_tokens[prefix_length:]
generated_text = tokenizer.decode(generated_tokens)
reward = train_task.reward(conversation, generated_text)
rewards.append(reward)
# Pad the sequences so that their lengths (in time) match
padded_generated_token_sequences = ...
padded_masks = ...
# Stack up the sequences and masks into PyTorch tensors
ids = torch.tensor(padded_generated_token_sequences, dtype=torch.long, device=device)
mask_ids = torch.tensor(padded_masks, dtype=torch.long, device=device)
# Generate autoregressive inputs and targets to the Transformer
inputs = ids[:, :-1]
targets = ids[:, 1:].clone() # clone to avoid in-place modification:
targets[mask_ids[:, 1:] == 0] = -1 # <-- inplace modification right here. -1 is the ignore index
rewards = torch.tensor(rewards, dtype=torch.float, device=device)
# Calculate the advantages by simply subtracting the mean (instead of z-score (x-mu)/sigma)
mu = rewards.mean()
advantages = rewards - mu
# yield inputs/targets as (B, T) of ids and rewards as (B,) of floats
yield generated_token_sequences, inputs, targets, rewards, advantages
Evaluate a model on GSM8K by computing pass@k style stats:
def run_gsm8k_eval(..., num_samples=1, max_completion_tokens=256, temperature=0.0, top_k=50):
for idx in range(ddp_rank, max_examples, ddp_world_size):
conversation = task[idx]
'''
U: What is 2+3?
A: 2+3=5.
U: What is 2+7?
'''
tokens = tokenizer.render_for_completion(conversation)
'''
<user> What is 2+3?
<assistant> 2+3=5.
<user> What is 2+7?
<assistant> ← model will generate here
'''
prefix_length = len(tokens)
...
generated_token_sequences, masks = engine.generate_batch(
tokens,
num_samples=num_samples,
max_tokens=max_completion_tokens,
temperature=temperature,
top_k=top_k
)
# Check each sample for correctness
outcomes = []
for sample_tokens in generated_token_sequences:
generated_tokens = sample_tokens[prefix_length:]
generated_text = tokenizer.decode(generated_tokens)
is_correct = task.evaluate(conversation, generated_text)
outcomes.append({"is_correct": is_correct})
# A bit bloated because I wanted to do more complex logging at one point.
record = {"idx": idx, "outcomes": outcomes}
yield record
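From the yielded records, a pass@k-style number can be computed with a small helper like this (a sketch with assumed names; it is the simple empirical "any of the first k samples is correct" rate, not the unbiased estimator):

```python
def pass_at_k(records, k: int) -> float:
    # An example counts as solved if any of its first k sampled completions is correct.
    solved = sum(any(o["is_correct"] for o in r["outcomes"][:k]) for r in records)
    return solved / max(len(records), 1)
```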
The core training loop is as follows:
examples_per_rank = examples_per_step // ddp_world_size # per GPU
batch_iterator = get_batch()
for step in range(num_steps):
...
# Forward/Backward on rollouts over multiple examples in the dataset
for example_step in range(examples_per_rank):
sequences_all, inputs_all, targets_all, rewards_all, advantages_all = next(batch_iterator)
model.train() # ensure the model is in train mode
...
for pass_idx in range(num_passes):
# Pluck out the batch for this pass
b0, b1 = pass_idx * device_batch_size, (pass_idx + 1) * device_batch_size
inputs = inputs_all[b0:b1]
targets = targets_all[b0:b1]
rewards = rewards_all[b0:b1]
advantages = advantages_all[b0:b1]
# Calculate log probabilities. Note that the loss calculates NLL = -logp, so we negate
with autocast_ctx:
logp = -model(inputs, targets, loss_reduction='none').view_as(inputs) # (B, T)
# Calculate the PG objective. ignore_index=-1 ensures that invalid tokens have loss 0.
pg_obj = (logp * advantages.unsqueeze(-1)).sum()
# normalize by the number of valid tokens, number of passes, and examples_per_rank
num_valid = (targets >= 0).sum().clamp(min=1)
pg_obj = pg_obj / (num_valid * num_passes * examples_per_rank)
loss = -pg_obj
loss.backward()
...
for opt in optimizers: # then step the optimizers
opt.step()
1. Rollout with Multiple Samples per Prompt: generate several completions per input using engine.generate_batch(num_samples=device_batch_size, top_k=top_k, ...).
2. Reproducible Per-Rollout RNG Seeds, created by hashing:
seed = hash((step, example_idx, sampling_step)) & 0x7FFFFFFF # positive half of int32
3. Automatic Mixed Precision (AMP)
with autocast_ctx:
...
4. Masked LM-Style Targets with Ignore Index: prompt tokens and padding are masked out:
targets[mask_ids[:, 1:] == 0] = -1
5. Distributed Aggregation of Metrics (DDP all_reduce):
dist.all_reduce(mean_reward_tensor, AVG)
6. Variable-Length Sequence Padding
7. Evaluator for Pass@k Metrics
Comparison: Pretrain vs Mid-Train vs SFT vs RL
| Category | Pretrain | Mid-Train | SFT | RL (GRPO) |
|---|---|---|---|---|
| Script | base_train.py | mid_train.py | chat_sft.py | chat_rl.py |
| Data | Massive corpus | SmolTalk, GSM8K, ... | CustomJSON, GSM8K, ... | Prompts + rollouts |
| Rows | 20B tokens for 1B model | 848K | 23K | ✗ |
| Target | Next tokens (MT) | MT | Masked MT | Generated answer tokens |
| LR | Warm up + flat (most) + linear decay to small | Flat (most) + linear decay to 0 | Start from small | Start from mid-small |
| Masking | ✗ | ✗ | Mask non-assistant | Mask prompt & padding |
| Sampling | ✗ | ✗ | ✗ | ✓ |
| Rewards | ✗ | ✗ | ✗ | ✓ |
| Sequence | Fixed | Fixed | Variable | Variable rollouts |
| Areas | General Knowledge | Math, QA, skills | Instruction following, chat | Correctness, logic |
Future Work
NanoChat does not include the following components:
- LoRA
- MoE
- FlashAttention
Citation
@misc{deng2025nanochat,
title ={{Practical Notes on Training NanoChat}},
author ={Wei Deng},
journal ={waynedw.github.io},
year ={2025},
howpublished = {\url{https://weideng.org/posts/nanochat}}
}
- Karpathy, A. (Oct 2025). nanochat. https://github.com/karpathy/nanochat.
- Hoffmann, J., Borgeaud, S., Mensch, A., et al. (2022). Training Compute-Optimal Large Language Models. Advances in Neural Information Processing Systems (NeurIPS).