Notes on Training NanoChat

A practical walkthrough for building a tiny ChatGPT

ChatGPT is now part of everyday life, and Karpathy’s NanoChat (Karpathy, 2025) offers a minimal codebase for learning how a small ChatGPT-style model is trained end to end. After studying NanoChat, I collected the key techniques that helped me most. These brief notes are for myself and for anyone looking for a simple starting point for building a tiny ChatGPT.

nanochat: Model, Engine, Dataloader, Tokenizer

gpt.py: Inside the Lightweight GPT Architecture

MQA + RoPE NanoChat uses multi-query attention (MQA) [Link] to shrink KV-cache memory during inference: queries keep n_head heads while keys and values share a smaller n_kv_head. It also applies rotary positional embeddings (RoPE), which encode positions as rotations of the query/key vectors and tend to generalize better to longer contexts.

@dataclass
class GPTConfig:
    ...
    # n_kv_head == n_head: MHA
    # n_kv_head <  n_head: MQA
    n_head: int = 6 # number of query heads
    n_kv_head: int = 6 # number of key/value heads (MQA)

def apply_rotary_emb(x, cos, sin):
    assert x.ndim == 4  # multihead attention
    d = x.shape[3] // 2
    x1, x2 = x[..., :d], x[..., d:] # split the last dim into two halves
    y1 = x1 * cos + x2 * sin # rotate pairs of dims
    y2 = x1 * (-sin) + x2 * cos
    out = torch.cat([y1, y2], 3) # re-assemble
    out = out.to(x.dtype) # ensure input/output dtypes match
    return out

class CausalSelfAttention(nn.Module):
    def __init__(self, config, layer_idx):
        ...
        self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
        self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
    
    def forward(self, x, cos_sin, kv_cache):
        B, T, C = x.size()

        # Project the input to get queries, keys, and values
        q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
        k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
        v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)

        # Apply Rotary Embeddings to queries and keys to get relative positional encoding
        cos, sin = cos_sin
        q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) # QK rotary embedding
        q, k = norm(q), norm(k) # QK norm
        # make head be batch dim, i.e. (B, T, H, D) -> (B, H, T, D)
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) 
        if kv_cache is not None:
            k, v = kv_cache.insert_kv(self.layer_idx, k, v)
        # causal masking during training; with a KV cache at inference, the new token attends to the full prefix
        y = F.scaled_dot_product_attention(q, k, v, is_causal=is_train, enable_gqa=enable_gqa) 
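
To make the KV-cache saving from the MQA paragraph concrete, here is a rough back-of-the-envelope estimate; the layer/head numbers below are illustrative, not NanoChat's actual config:

# Per cached token, K and V are stored for every layer, so
# bytes = 2 * n_layer * n_kv_head * head_dim * bytes_per_element.
n_layer, n_head, head_dim, bytes_per_elem = 20, 16, 64, 2  # illustrative numbers (bf16)
for n_kv_head in (n_head, n_head // 4, 1):                 # MHA, GQA, pure MQA
    kv_bytes = 2 * n_layer * n_kv_head * head_dim * bytes_per_elem
    print(f"n_kv_head={n_kv_head:2d}: {kv_bytes:,} bytes per cached token")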

AdamW + Muon NanoChat adopts a hybrid optimization strategy: AdamW for the embedding and unembedding (lm_head) parameters, and the SVD-inspired Muon optimizer (which approximately orthogonalizes each matrix update) for the 2D matrix parameters inside the Transformer blocks, giving smoother and more stable training.

class Block(nn.Module):
    def __init__(self, config, layer_idx):
        super().__init__()
        self.attn = CausalSelfAttention(config, layer_idx)
        self.mlp = MLP(config)

    def forward(self, x, cos_sin, kv_cache):
        x = x + self.attn(norm(x), cos_sin, kv_cache)
        x = x + self.mlp(norm(x))
        return x

class GPT(nn.Module):
    def __init__(self, config):
        ...
        self.transformer = nn.ModuleDict({
            "wte": nn.Embedding(config.vocab_size, config.n_embd),
            "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
        })
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) # note: the output dimension is vocab_size
    
    def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0):
        ...
        matrix_params = list(self.transformer.h.parameters())
        embedding_params = list(self.transformer.wte.parameters())
        lm_head_params = list(self.lm_head.parameters())
        # Scale the AdamW LRs by ∝1/√dmodel (the defaults were tuned for a 768-dim model)
        model_dim = self.config.n_embd
        dmodel_lr_scale = (model_dim / 768) ** -0.5
        adam_groups = [
            dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
            dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
        ]
        ...
        adamw_optimizer = AdamWFactory(adam_groups, **adamw_kwargs)
        ...
        muon_optimizer = MuonFactory(matrix_params, **muon_kwargs)
        optimizers = [adamw_optimizer, muon_optimizer]
        ...
        return optimizers
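
For intuition about the Muon half of this setup, here is a minimal sketch of the core idea as I understand it (not NanoChat's exact implementation): the momentum-averaged gradient of each 2D weight matrix is approximately orthogonalized with a Newton–Schulz iteration before the update is applied.

import torch

def newton_schulz_orthogonalize(G, steps=5, eps=1e-7):
    # Approximately map G's singular values toward 1 (an "approximate msign"),
    # using the quintic Newton-Schulz iteration popularized by Muon.
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G.bfloat16()
    if G.size(-2) > G.size(-1):
        X = X.mT
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps)
    for _ in range(steps):
        A = X @ X.mT
        X = a * X + (b * A + c * A @ A) @ X
    if G.size(-2) > G.size(-1):
        X = X.mT
    return X.to(G.dtype)

# One (simplified) Muon step for a matrix parameter p with gradient p.grad:
#   buf = momentum * buf + p.grad
#   p  -= lr * newton_schulz_orthogonalize(buf) * max(1, p.size(-2) / p.size(-1)) ** 0.5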

engine.py: Running the Chat Engine

The Engine class handles autoregressive text generation with efficient KV-cache reuse and multi-sample decoding. It also supports tool-augmented generation — detecting <|python_start|> and <|python_end|> code blocks, safely evaluating expressions via a built-in calculator, and injecting the computed results back into the output stream.

class RowState:
    # Per-row state tracking during generation
    def __init__(self, current_tokens=None):
        ...
        # Why do we need forced_tokens: allow to inject non-model outputs into the token stream 
        # For a prompt: <|python_start|>2 + 3 * (4 - 1)<|python_end|>
        # The engine detects this Python block, evaluates it → 11,
        # and encodes that result as tokens:
        self.forced_tokens = deque() # Queue of tokens to force inject 
        ...
        self.python_expr_tokens = [] # Tokens of the current python expression

class Engine:
    @torch.inference_mode() # disables gradient tracking and reduces memory + compute overhead
    def generate(self, tokens, num_samples=1, max_tokens=None, temperature=1.0, top_k=None, seed=42):
        ...
        # Get the special tokens we need to coordinate the tool use state machine
        get_special = lambda s: self.tokenizer.encode_special(s)
        python_start = get_special("<|python_start|>") # <|.|> ensures just one token
        python_end = get_special("<|python_end|>")
        output_start = get_special("<|output_start|>")
        output_end = get_special("<|output_end|>")
        assistant_end = get_special("<|assistant_end|>") # if sampled, ends row
        bos = self.tokenizer.get_bos_token_id() # if sampled, ends row

        # 1) Run a batch 1 prefill of the prompt tokens
        m = self.model.config
        kv_model_kwargs = {"num_heads": m.n_kv_head, "head_dim": m.n_embd // m.n_head, "num_layers": m.n_layer}
        kv_cache_prefill = KVCache( 
            batch_size=1,
            seq_len=len(tokens),
            **kv_model_kwargs,
        ) # only need one KV cache since there is only one prompt
        ids = torch.tensor([tokens], dtype=torch.long, device=device)
        logits = self.model.forward(ids, kv_cache=kv_cache_prefill)
        logits = logits[:, -1, :]
        next_ids = sample_next_token(logits, rng, temperature, top_k)  # (B, 1)
        sampled_tokens = next_ids[:, 0].tolist() # one sampled token id per row

        # 2) Replicate the KV cache for each sample/row
        kv_length_hint = (len(tokens) + max_tokens) if max_tokens is not None else self.model.config.sequence_len
        kv_cache_decode = KVCache(
            batch_size=num_samples,
            seq_len=kv_length_hint, # allocate for the full generation up front to avoid re-initializing every time
            **kv_model_kwargs,
        )
        kv_cache_decode.prefill(kv_cache_prefill)
        del kv_cache_prefill # no need to keep this memory around

        # 3) Initialize states for each sample
        row_states = [RowState(tokens.copy()) for _ in range(num_samples)] # each row gets its own copy of the prompt tokens

        # 4) Main generation loop
        num_generated = 0
        first_iteration = True
        while True:
            # Stop condition: max tokens reached or all rows are completed
            ...

            # Get sampled tokens - either from prefill or from forward pass
            if not first_iteration:
                ...
                # Forward the model and get the next token for each row
                logits = self.model.forward(ids, kv_cache=kv_cache_decode)  # (B, T, vocab_size)
                logits = logits[:, -1, :]  # (B, vocab_size) at last time step
                next_ids = sample_next_token(logits, rng, temperature, top_k)  # (B, 1)
                sampled_tokens = next_ids[:, 0].tolist()

            # Process each row: choose the next token, update state, optional tool use
            token_column = [] # contains the next token id along each row
            for i, state in enumerate(row_states):
                # Select the next token in this row
                is_forced = len(state.forced_tokens) > 0 # tokens waiting to be forced in deque?
                ...
                next_token = state.forced_tokens.popleft() if is_forced else sampled_tokens[i]
                token_column.append(next_token)
                # Update the state of this row to include the next token
                state.current_tokens.append(next_token)
                # On <|assistant_end|> or <|bos|>, mark the row as completed
                if next_token == assistant_end or next_token == bos:
                    state.completed = True
                # Handle tool logic
                if next_token == python_start:
                    state.in_python_block = True
                    state.python_expr_tokens = []
                elif next_token == python_end and state.in_python_block:
                    state.in_python_block = False
                    if state.python_expr_tokens:
                        expr = self.tokenizer.decode(state.python_expr_tokens)
                        result = use_calculator(expr) # avoid __import__('os').system('rm -rf /') ...
                        if result is not None:
                            result_tokens = self.tokenizer.encode(str(result))
                            state.forced_tokens.append(output_start)
                            state.forced_tokens.extend(result_tokens)
                            state.forced_tokens.append(output_end)
                    state.python_expr_tokens = []
                elif state.in_python_block:
                    state.python_expr_tokens.append(next_token)
            ...
            ids = torch.tensor(token_column, dtype=torch.long, device=device).unsqueeze(1)
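
A hypothetical usage sketch (the engine, tokenizer, and conversation objects are assumed to already exist), reusing the generate_batch call shape that appears later in the GSM8K evaluation code:

# Render a chat prompt up to the assistant turn, sample 2 completions, decode them.
tokens = tokenizer.render_for_completion(conversation)
sequences, masks = engine.generate_batch(
    tokens, num_samples=2, max_tokens=128, temperature=0.8, top_k=50,
)
completions = [tokenizer.decode(seq[len(tokens):]) for seq in sequences]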

tokenizer.py: Keeping the Vocabulary Compact

English has roughly 170k common words; once you add inflected forms, punctuation, slang, numbers, and code identifiers, the number of unique word forms runs into the millions, which would require a huge embedding matrix. BPE tokenizers instead reuse subwords (like comput and ation) to cover that space compactly, typically with 30k–100k tokens in total.

Here’s a quick view of how BPE learns. Take the tiny dataset hello hello and world hello. The tokenizer begins at the character level: h e l l o, and so on. It then repeatedly merges the most frequent adjacent pair of symbols: first l l → ll (3 occurrences), then e ll → ell, then h ell → hell, and so on. After a few rounds, hello becomes a single token, and the same process turns world into a token. In the end, BPE discovers hello and world as its core learned units.
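
As a sanity check of this intuition, here is a tiny, self-contained BPE trainer for that toy corpus (illustrative only; tie-breaking between equally frequent pairs may differ from the walkthrough above):

from collections import Counter

words = Counter("hello hello and world hello".split())   # toy corpus
vocab = {w: list(w) for w in words}                       # start at the character level

def count_pairs():
    pairs = Counter()
    for w, freq in words.items():
        syms = vocab[w]
        for pair in zip(syms, syms[1:]):
            pairs[pair] += freq
    return pairs

for _ in range(6):                                        # a few merge rounds
    pairs = count_pairs()
    if not pairs:
        break
    (a, b), _ = pairs.most_common(1)[0]                   # most frequent adjacent pair
    for w, syms in vocab.items():                         # apply the merge everywhere
        merged, i = [], 0
        while i < len(syms):
            if i + 1 < len(syms) and (syms[i], syms[i + 1]) == (a, b):
                merged.append(a + b)
                i += 2
            else:
                merged.append(syms[i])
                i += 1
        vocab[w] = merged
    print("merged", a + b, "->", vocab)                   # 'hello' eventually collapses to one symbol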

Here is an OpenAI link that shows how the tokenizer works.

SPLIT_PATTERN = (
    r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| """
    r"""?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
)
...
@classmethod
def train_from_iterator(cls, text_iterator, vocab_size):
    # train from an iterator of text
    tokenizer = HFTokenizer(BPE(
        byte_fallback=True, # needed!
        unk_token=None,
        fuse_unk=False,
    ))
    tokenizer.normalizer = None
    gpt4_split_regex = Regex(SPLIT_PATTERN) # huggingface demands that you wrap it in Regex!!
    tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
        pre_tokenizers.Split(pattern=gpt4_split_regex, behavior="isolated", invert=False),
        pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False)
    ])
    # Decoder: ByteLevel (it pairs together with the ByteLevel pre-tokenizer)
    tokenizer.decoder = decoders.ByteLevel()
    ...
    trainer = BpeTrainer(
        vocab_size=vocab_size,
        show_progress=True,
        min_frequency=0, # no minimum frequency
        initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
        special_tokens=SPECIAL_TOKENS,
    )
    # Kick off the training
    tokenizer.train_from_iterator(text_iterator, trainer)
    return cls(tokenizer)

dataloader.py: From Parquet Files to Training Batches

# 1) Parquet files (indexed by pq_idx)
#     pq_idx = 0  →  part-00000.parquet
#     pq_idx = 1  →  part-00001.parquet
#     pq_idx = 2  →  part-00002.parquet   ← example below

# 2) Zoom into part-00002.parquet (pq_idx = 2)
#     ┌── part-00002.parquet ────────────────────────────────┐
#     │ RowGroup 0 (rg_idx = 0) → text: [t0,  ..., t9]       │
#     │ RowGroup 1 (rg_idx = 1) → text: [t10, ..., t19]      │
#     │ RowGroup 2 (rg_idx = 2) → text: [t20, ..., t29]      │
#     │ RowGroup 3 (rg_idx = 3) → text: [t30, ..., t39]      │
#     └──────────────────────────────────────────────────────┘
#     (Each RowGroup is just a list of text lines.)

# 3) Distributed training (world_size = 2)
#     GPU0 (rank = 0) → picks RowGroups: 0, 2, ...
#     GPU1 (rank = 1) → picks RowGroups: 1, 3, ...

#     Example for GPU0:
#         rg_idx = 0
#         batch = [t0, t1, t2, ..., t9]       # 10 text lines


# 4) Split one RowGroup into small doc_batches (for tokenizer)
#     Suppose tokenizer_batch_size = 4.
#     Then batch (10 lines) becomes:
#         batch[0:4]  → doc_batch_0
#         batch[4:8]  → doc_batch_1
#         batch[8:10] → doc_batch_2
#     Each doc_batch is sent to:
#         tokenizer.encode(doc_batch)

def tokenizing_distributed_data_loader_with_state(
    B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda", resume_state_dict=None
):
    ...
    ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
    def document_batches():
        parquet_paths = list_parquet_files()
        ...
        resume_pq_idx = resume_state_dict["pq_idx"] if resume_state_dict is not None else 0
        resume_rg_idx = resume_state_dict["rg_idx"] if resume_state_dict is not None else None
        pq_idx = resume_pq_idx # we kick off parquet files at the resume index (or by default just 0)
        while True: # iterate infinitely (multi-epoch)
            while pq_idx < len(parquet_paths): # iterate over all parquet files
                filepath = parquet_paths[pq_idx]
                pf = pq.ParquetFile(filepath)
                # Start from resume point if resuming on same file, otherwise from DDP rank
                # I know this state resumption is a little bit tricky and a little bit hacky... sigh.
                if resume_rg_idx is not None:
                    base_idx = resume_rg_idx // ddp_world_size # in units of ddp_world_size
                    base_idx += 1 # advance by 1 so that we definitely don't repeat data after resuming
                    rg_idx = base_idx * ddp_world_size + ddp_rank
                    resume_rg_idx = None # set to None as we only want to do this a single time
                else:
                    rg_idx = ddp_rank
                while rg_idx < pf.num_row_groups:
                    rg = pf.read_row_group(rg_idx)
                    batch = rg.column('text').to_pylist() # each batch is a parquet group, e.g. 1024 rows
                    # the tokenizer encode might want to go in even smaller batches, e.g. 128 rows
                    for i in range(0, len(batch), tokenizer_batch_size):
                        yield batch[i:i+tokenizer_batch_size], (pq_idx, rg_idx)
                    rg_idx += ddp_world_size # advance to the next row group (in DDP)
                pq_idx += 1 # advance to the next parquet file

This code streams raw text from Parquet files and tokenizes it into one long, continuous token sequence. The loader then takes a chunk of B * T + 1 tokens, uses scratch[:-1] as the inputs and scratch[1:] (the same tokens shifted by one position) as the targets.

def tokenizing_distributed_data_loader_with_state():
    ...
    batches = document_batches()
    needed_tokens = B * T + 1 # +1 is because we also need the target at the last token
    ...
    token_buffer = deque() # we stream tokens on the right and pop from the left
    while True:
        # Accumulate enough tokens for one iteration before yielding.
        while len(token_buffer) < needed_tokens:
            doc_batch, (pq_idx, rg_idx) = next(batches)
            token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
            for tokens in token_lists:
                token_buffer.extend(tokens)
        # Move tokens from the deque into the scratch buffer
        tokens = [token_buffer.popleft() for _ in range(needed_tokens)]
        ...
        scratch = torch.tensor(tokens, dtype=torch.long, pin_memory=use_cuda_optimizations) 
        # Create the inputs/targets as 1D tensors
        inputs_cpu = scratch[:-1]
        targets_cpu = scratch[1:]
        # Reshape to 2D and move to GPU async
        inputs = inputs_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations)
        targets = targets_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations)
        state_dict = {"pq_idx": pq_idx, "rg_idx": rg_idx} # allows approximate resumption of training
        yield inputs, targets, state_dict
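
A hypothetical usage sketch of the loader, using the signature from the snippet above; the final assert just checks the one-token shift between inputs and targets within each row:

import torch

# Build the streaming loader and pull one batch (B rows of T tokens each).
loader = tokenizing_distributed_data_loader_with_state(B=32, T=2048, split="train")
inputs, targets, state = next(loader)
assert inputs.shape == targets.shape == (32, 2048)
assert torch.equal(inputs[:, 1:], targets[:, :-1])   # targets are inputs shifted by one
print(state)                                         # e.g. {"pq_idx": 0, "rg_idx": 0}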

execution.py: Sandboxed Code Execution

def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[int], result_dict):
    """Execute code in a subprocess with safety guards. Results are written to result_dict."""
    ...

def execute_code(
    code: str,
    timeout: float = 5.0, # 5 seconds default
    maximum_memory_bytes: Optional[int] = 256 * 1024 * 1024, # 256MB default
) -> ExecutionResult:
    """
    Execute Python code in a sandboxed environment.

    Example:
        >>> result = execute_code("print('hello world')")
        >>> result.success
        True
        >>> result.stdout
        'hello world\\n'
    """

    manager = multiprocessing.Manager()
    result_dict = manager.dict()

    p = multiprocessing.Process(
        target=_unsafe_execute,
        args=(code, timeout, maximum_memory_bytes, result_dict)
    )
    p.start()
    p.join(timeout=timeout + 1)

    if p.is_alive():
        p.kill()
        return ExecutionResult(
            success=False,
            stdout="",
            stderr="",
            error="Execution timed out (process killed)",
            timeout=True,
            memory_exceeded=False,
        )
    ...
    return ExecutionResult(
        success=result_dict["success"],
        stdout=result_dict["stdout"],
        stderr=result_dict["stderr"],
        error=result_dict["error"],
        timeout=result_dict["timeout"],
        memory_exceeded=result_dict["memory_exceeded"],
    )

core_eval.py: Task Evaluation Logic

# 1) For each data example (item)
#     Optionally sample few-shot examples (excluding this item).

# 2) Render prompts (depends on task_type)
#     - MC:
#         [Context] + Delimiter + Choice A
#         [Context] + Delimiter + Choice B
#         ...
#     - Schema:
#         Context_1 + Delimiter + Continuation
#         Context_2 + Delimiter + Continuation
#         ...
#     - LM:
#         prompt_without = Context + Delimiter
#         prompt_with    = Context + Delimiter + Continuation

# 3) Tokenize prompts and locate the continuation span
#     - MC:    continuation = tokens after common prefix
#     - Schema: continuation = tokens in common suffix
#     - LM:     continuation = tokens added by prompt_with

#     Produce:
#         tokens, start_idx[i], end_idx[i]

# 4) Batch + pad sequences
#     input_ids = stack_sequences(tokens)
#     If model has max_seq_len → crop from left and shift indices.

# 5) Forward model
#     losses, predictions = forward_model(model, input_ids)
#     losses:      per-token cross entropy
#     predictions: argmax tokens

# 6) Decide correctness
#     - LM:
#         Compare predicted continuation tokens with ground truth.
#     - MC / Schema:
#         For each option:
#             score_i = mean(losses[i, si-1 : ei-1])
#         Pick lowest score; compare with gold.

# 7) Distributed evaluation
#     Each rank handles a stride of examples.
#     dist.all_reduce to combine results.
#     Final accuracy = mean(correct).

loss_eval.py: Byte-Normalized Loss

The loss computes bits-per-byte (BPB) by summing token NLLs and normalizing by the total raw-text bytes represented by valid targets. It excludes special tokens and ignore_index entries to keep the metric tokenizer-independent.
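
A minimal sketch of that computation, assuming the model returns per-token NLL in nats (as in the loss_reduction='none' call used elsewhere in these notes) and that token_bytes is a precomputed tensor mapping each token id to the UTF-8 byte length of its decoded text (0 for special tokens):

import math
import torch

@torch.no_grad()
def evaluate_bpb(model, loader, token_bytes, ignore_index=-1):
    # Sum per-token NLL (nats) over valid targets, then normalize by the raw
    # bytes those targets represent and convert nats -> bits.
    nll_sum, byte_sum = 0.0, 0
    for x, y in loader:
        loss = model(x, y, loss_reduction='none').view(-1)     # per-token NLL
        yf = y.view(-1)
        valid = (yf != ignore_index) & (token_bytes[yf.clamp(min=0)] > 0)
        nll_sum += loss[valid].sum().item()
        byte_sum += token_bytes[yf[valid]].sum().item()
    return nll_sum / (math.log(2) * max(byte_sum, 1))          # bits per byte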

scripts: Training, Mid-train, SFT, RL, and Eval Pipeline

base_train.py: Base Pretrain Script

Scaling laws: model parameters vs. FLOPs (Hoffmann et al., 2022).

Budget Planning

batch_per_gpu = 64 sequences
seq_len = 2048
world_size = 8 GPUs
grad_accum_steps = 1

By Chinchilla scaling laws (Hoffmann et al., 2022), to train a 1B-parameter model we can choose

Model: 1B params
Data:  ~20B tokens # scaling law: ~20 tokens per parameter
Compute: ~10^20 FLOPs # ≈ 6 × params × tokens

Then tokens per training step:

\[\begin{align} &\text{Training compute per token: } \mathrm{FLOPs/token} \approx 6 \cdot N_{\text{params}} \approx 6 \times 10^9 \notag \\ &\text{Tokens per step: } 8 \times 64 \times 2048 \approx 10^6 \notag \\ &\text{Compute per step: } \mathrm{FLOPs/token} \times \text{tokens per step} \approx 6 \times 10^{15} \notag \\ &\text{Peak throughput of } 8\times\text{H100: } \approx 8 \times 10^{15} \ \mathrm{FLOP/s} \notag \\ &\text{MFU (if one step takes } \approx 1\,\mathrm{s}\text{): } \frac{6 \times 10^{15}}{8 \times 10^{15}} \approx 75\%. \notag \end{align}\]

The num_iterations can be set to

\[\begin{align} \text{steps} = \frac{20{,}000{,}000{,}000}{1{,}048{,}576} \approx 19{,}100. \notag \end{align}\]
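
A quick arithmetic check of these numbers (assuming ~6N FLOPs per token for an N-parameter model):

n_params        = 1e9
tokens_target   = 20 * n_params              # Chinchilla: ~20 tokens per parameter
tokens_per_step = 8 * 64 * 2048              # world_size * batch_per_gpu * seq_len
num_iterations  = round(tokens_target / tokens_per_step)
total_flops     = 6 * n_params * tokens_target
print(num_iterations, f"{total_flops:.1e}")  # ~19,073 steps, ~1.2e+20 FLOPs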

A more comprehensive cost evaluation can be seen in [Link].

The peak performance of NVIDIA GPUs is listed in the datasheet at [Link].

FLOPS are not runtime (CS 336, Tatsu)

| Class | FLOPs | Runtime |
|---|---|---|
| Tensor contraction | 99.5% | 61.0% |
| Stat. normalization | 0.17% | 25.5% |
| Element-wise | 0.03% | 13.5% |

if num_iterations > 0:
    print0(f"Using user-provided number of iterations: {num_iterations:,}")
elif target_flops > 0:
    # calculate the number of iterations from the target flops
    num_iterations = round(target_flops / (num_flops_per_token * total_batch_size))
    print0(f"Calculated number of iterations from target FLOPs: {num_iterations:,}")
elif target_param_data_ratio > 0:
    # calculate the number of iterations from the target param data ratio
    target_tokens = target_param_data_ratio * num_params
    num_iterations = target_tokens // total_batch_size
    print0(f"Calculated number of iterations from target data:param ratio: {num_iterations:,}")
else:
    raise ValueError("No training horizon specified")
total_tokens = total_batch_size * num_iterations
print0(f"Total number of training tokens: {total_tokens:,}")
print0(f"Tokens : Params ratio: {total_batch_size * num_iterations / num_params:.2f}") # Chinchilla is ~20
print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")

Hyperparameters

Learning rate schedule

1. Gently warm up the learning rate;
2. keep it flat at its maximum for most of training;
3. linearly decay it to a small final value for stable convergence (see the sketch below).
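
A minimal sketch of such a warmup/stable/decay multiplier (the fraction constants here are assumptions, not NanoChat's exact values):

def get_lr_multiplier(it, num_iterations, warmup_frac=0.02, decay_frac=0.2, final_frac=0.1):
    # Piecewise schedule: linear warmup, flat plateau, then linear decay
    # to final_frac of the peak LR over the last decay_frac of training.
    warmup_iters = max(1, int(warmup_frac * num_iterations))
    decay_start  = int((1 - decay_frac) * num_iterations)
    if it < warmup_iters:
        return (it + 1) / warmup_iters
    if it < decay_start:
        return 1.0
    progress = (it - decay_start) / max(1, num_iterations - decay_start)
    return 1.0 - (1.0 - final_frac) * progress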

Momentum schedule The Muon momentum ramps from 0.85 to 0.95 over the first 300 steps to stabilize early training:

def get_muon_momentum(it):
    frac = min(it / 300, 1)
    momentum = (1 - frac) * 0.85 + frac * 0.95
    return momentum

The main training loop runs the full training process:

Use torch.autocast with bfloat16 to speed up training and reduce memory usage.

CPU loads the next batch while GPU computes the current one:

with autocast_ctx:
    loss = model(x, y)
loss.backward()

# prefetch the next batch on the CPU while the GPU is still busy with this step
x, y, dataloader_state_dict = next(train_loader)

Periodically, the loop also:

1. evaluates validation bpb;
2. computes the CORE metric;
3. samples model outputs for an eyeball sanity check;
4. saves checkpoints.

Each checkpoint contains:

1. model weights;
2. optimizer states;
3. the data loader state;
4. loop state (loss, time, etc.), so training can be fully resumed.

After every step, it synchronizes the GPU, logs key stats (loss, throughput, MFU, FLOPs, etc.), and continues until all planned iterations are complete.

mid_train.py: Skill-Oriented LM Refinement (Next-Token Style)

train_dataset = TaskMixture([
    SmolTalk(split="train"), # 460K rows of general conversations
    MMLU(subset="auxiliary_train", split="train"), # 100K rows of multiple choice problems drawn from ARC, MC_TEST, OBQA, RACE
    GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use
    CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations
    CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these
    SimpleSpelling(size=200000, split="train"), # 200K rows of Simple Spelling (e.g. spell the word 'apple')
    SpellingBee(size=80000, split="train"), # 80K rows of Spelling Bee (e.g. how many 'r' are in 'strawberry'?)
]) # total: 460K + 100K + 8K + 200K + 80K = 848K rows

Techniques used in the mid-training loop:

Use torch.autocast with bfloat16 to speed up training and reduce memory usage.

Learning rate: flat for the first 80% of progress, then linear decay to 0, with init_lr_frac = 1.0 (mid-training starts at the full base learning rate).

CPU loads the next batch while GPU computes the current one:

with autocast_ctx:
    loss = model(x, y)
loss.backward()

# Prefetch the next batch while GPU is still busy
x, y = next(train_loader)

Because mid-training uses a task mixture with variable-length conversations, each GPU may exhaust its shard of the dataset at a different time. Without dist.all_reduce(last_step_tensor, op=MAX), one GPU would stop (because it hit the end of its data) while the others keep training, deadlocking the whole distributed job. The MAX reduction propagates the earliest stop signal so that all GPUs exit the loop together (see the sketch below).
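
A minimal sketch of that coordinated-stop pattern (function and variable names are mine, not NanoChat's):

import torch
import torch.distributed as dist

def should_stop(local_data_exhausted: bool, device, ddp: bool) -> bool:
    # MAX-reduce a per-rank "I'm done" flag: if ANY rank ran out of data,
    # every rank sees 1 and breaks out of the loop on the same step,
    # so no rank is left waiting inside a collective.
    flag = torch.tensor(1 if local_data_exhausted else 0, device=device)
    if ddp:
        dist.all_reduce(flag, op=dist.ReduceOp.MAX)
    return bool(flag.item())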

chat_sft.py: Supervised Fine-Tuning (SFT) for Chat

Pretrain/mid-train data:

Input:  "The capital of France is Pa"
Target: "r"

SFT data

Input:  "<user> What is the capital of France?"
Output: "<assistant> The capital of France is Paris."
Target tokens = assistant’s answer only.

Data mixture is smaller and more “exam/chat” flavored:

train_ds = TaskMixture([
    ARC("ARC-Easy", "train"),
    ARC("ARC-Challenge", "train"),
    GSM8K("main", "train"),
    SmolTalk("train", stop=10_000),
    CustomJSON(identity_conversations),
    SimpleSpelling(size=300),
    SpellingBee(size=300),
])

Techniques used in the SFT loop:

Learning rate: starts at init_lr_frac = 0.02 of the base LR and decays linearly to 0 over all SFT steps.

Only assistant tokens contribute to the loss; other tokens are masked to -1:

pad_token_id = tokenizer.encode_special("<|assistant_end|>")

ids, mask = tokenizer.render_conversation(doc)
...
row_targets = ids_tensor[1:]
mask_tensor = torch.tensor(mask[1:], dtype=torch.long)
row_targets[mask_tensor == 0] = -1  # ignore non-assistant positions

ncols = max(len(ids) for ids, mask in batch) - 1
inputs = torch.full((nrows, ncols), pad_token_id, dtype=torch.long)
targets = torch.full((nrows, ncols), -1, dtype=torch.long)

# fill row-by-row, shorter rows remain padded
inputs[i, :n-1] = ids_tensor[:-1]
targets[i, :n-1] = row_targets
num_tokens += (train_targets >= 0).sum()
if ddp:
    dist.all_reduce(num_tokens, op=dist.ReduceOp.SUM)
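
Putting those pieces together, a minimal self-contained sketch of the collation step (the function name is mine; batch is a list of (ids, mask) pairs as returned by tokenizer.render_conversation):

import torch

def collate_sft(batch, pad_token_id):
    # Pad each conversation to the longest one; inputs are padded with
    # pad_token_id, targets with -1 so padded and non-assistant positions
    # are ignored by the cross-entropy loss.
    nrows = len(batch)
    ncols = max(len(ids) for ids, mask in batch) - 1
    inputs = torch.full((nrows, ncols), pad_token_id, dtype=torch.long)
    targets = torch.full((nrows, ncols), -1, dtype=torch.long)
    for i, (ids, mask) in enumerate(batch):
        n = len(ids)
        ids_tensor = torch.tensor(ids, dtype=torch.long)
        mask_tensor = torch.tensor(mask[1:], dtype=torch.long)
        row_targets = ids_tensor[1:].clone()
        row_targets[mask_tensor == 0] = -1   # ignore non-assistant positions
        inputs[i, :n-1] = ids_tensor[:-1]
        targets[i, :n-1] = row_targets
    return inputs, targets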

chat_rl.py: Reinforcement Learning on GSM8K via “GRPO” (REINFORCE)

The core logic implements a lightweight GRPO objective with no trust region and no PPO ratio/clip.

\[\begin{align} \max_{\theta} \;\; \mathbb{E}_{\mathbf{a} \sim \pi_\theta}\Bigg[ \sum_{t=1}^{T} \mathbf{1}_{\{\text{mask}_t=1\}}\, A_t \,\log \pi_\theta(a_t \mid s_t) \Bigg], \notag \end{align}\]

where \(A_t = R - \mu\) is the token-level advantage, \(a_t\) is the generated token at step \(t\), and \(s_t\) is its preceding prefix. The mask is still needed because the objective is computed only over the generated answer tokens.
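
A toy example of the mean-baseline advantage for one prompt with four rollouts:

import torch

rewards = torch.tensor([1.0, 0.0, 0.0, 1.0])   # two correct, two incorrect rollouts
advantages = rewards - rewards.mean()           # tensor([ 0.5000, -0.5000, -0.5000,  0.5000])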

def get_input_target_reward_advantage_batch():
    ...
    for sampling_step in range(num_sampling_steps):
        ...
        generated_token_sequences.extend(generated_token_sequences_batch)
        masks.extend(masks_batch)
    # Calculate the rewards for each sample
    rewards = []
    for sample_tokens in generated_token_sequences:
        generated_tokens = sample_tokens[prefix_length:]
        generated_text = tokenizer.decode(generated_tokens)
        reward = train_task.reward(conversation, generated_text)
        rewards.append(reward)
    # Pad the sequences so that their lengths (in time) match
    padded_generated_token_sequences = ...
    padded_masks = ...
    # Stack up the sequences and masks into PyTorch tensors
    ids = torch.tensor(padded_generated_token_sequences, dtype=torch.long, device=device)
    mask_ids = torch.tensor(padded_masks, dtype=torch.long, device=device)
    # Generate autoregressive inputs and targets to the Transformer
    inputs = ids[:, :-1]
    targets = ids[:, 1:].clone() # clone to avoid in-place modification:
    targets[mask_ids[:, 1:] == 0] = -1 # <-- inplace modification right here. -1 is the ignore index
    rewards = torch.tensor(rewards, dtype=torch.float, device=device)
    # Calculate the advantages by simply subtracting the mean (instead of z-score (x-mu)/sigma)
    mu = rewards.mean()
    advantages = rewards - mu
    # yield inputs/targets as (B, T) of ids and rewards as (B,) of floats
    yield generated_token_sequences, inputs, targets, rewards, advantages

Evaluate a model on GSM8K by computing pass@k style stats:

def run_gsm8k_eval(..., num_samples=1, max_completion_tokens=256, temperature=0.0, top_k=50):
    for idx in range(ddp_rank, max_examples, ddp_world_size):
        conversation = task[idx]
        '''
        U: What is 2+3?
        A: 2+3=5.
        U: What is 2+7?
        '''
        tokens = tokenizer.render_for_completion(conversation)
        '''
        <user> What is 2+3?
        <assistant> 2+3=5.
        <user> What is 2+7?
        <assistant>    ← model will generate here
        '''
        prefix_length = len(tokens)
        ...
        generated_token_sequences, masks = engine.generate_batch(
            tokens,
            num_samples=num_samples,
            max_tokens=max_completion_tokens,
            temperature=temperature,
            top_k=top_k
        )
        # Check each sample for correctness
        outcomes = []
        for sample_tokens in generated_token_sequences:
            generated_tokens = sample_tokens[prefix_length:]
            generated_text = tokenizer.decode(generated_tokens)
            is_correct = task.evaluate(conversation, generated_text)
            outcomes.append({"is_correct": is_correct})
        # A bit bloated because I wanted to do more complex logging at one point.
        record = {"idx": idx, "outcomes": outcomes}
        yield record
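
A small sketch of how such records could be aggregated into pass@k (the helper name is assumed):

def pass_at_k(records, k=1):
    # A problem counts as solved at k if any of its first k samples is correct.
    solved = [any(o["is_correct"] for o in r["outcomes"][:k]) for r in records]
    return sum(solved) / max(1, len(solved))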

The core training loop is as follows:

examples_per_rank = examples_per_step // ddp_world_size # per GPU
batch_iterator = get_batch()
for step in range(num_steps):    
    ...
    # Forward/Backward on rollouts over multiple examples in the dataset
    for example_step in range(examples_per_rank):
        sequences_all, inputs_all, targets_all, rewards_all, advantages_all = next(batch_iterator)
        model.train() # ensure the model is in train mode
        ...
        for pass_idx in range(num_passes):
            # Pluck out the batch for this pass
            b0, b1 = pass_idx * device_batch_size, (pass_idx + 1) * device_batch_size
            inputs = inputs_all[b0:b1]
            targets = targets_all[b0:b1]
            rewards = rewards_all[b0:b1]
            advantages = advantages_all[b0:b1]
            # Calculate log probabilities. Note that the loss calculates NLL = -logp, so we negate
            with autocast_ctx:
                logp = -model(inputs, targets, loss_reduction='none').view_as(inputs) # (B, T)
            # Calculate the PG objective. ignore_index=-1 ensures that invalid tokens have loss 0.
            pg_obj = (logp * advantages.unsqueeze(-1)).sum()
            # normalize by the number of valid tokens, number of passes, and examples_per_rank
            num_valid = (targets >= 0).sum().clamp(min=1)
            pg_obj = pg_obj / (num_valid * num_passes * examples_per_rank)
            loss = -pg_obj
            loss.backward()
    ...
    for opt in optimizers: # then step the optimizers
        opt.step()

Generate several completions per input using engine.generate_batch(num_samples=device_batch_size, top_k=top_k, ...).

seed = hash((step, example_idx, sampling_step)) & 0x7FFFFFFF # positive half of int32
with autocast_ctx:
    ...

Prompt tokens and padding are masked out:

targets[mask_ids[:, 1:] == 0] = -1

The mean reward is averaged across ranks (e.g. for logging):

dist.all_reduce(mean_reward_tensor, op=dist.ReduceOp.AVG)

Comparison: Pretrain vs Mid-Train vs SFT vs RL

| Category | Pretrain | Mid-Train | SFT | RL (GRPO) |
|---|---|---|---|---|
| Script | base_train.py | mid_train.py | chat_sft.py | chat_rl.py |
| Data | Massive corpus | SmolTalk, GSM8K, ... | CustomJSON, GSM8K, ... | Prompts + rollouts |
| Rows | 20B tokens (for a 1B model) | 848K | 23K | – |
| Target | Next tokens | Next tokens | Masked next tokens (assistant only) | Generated answer tokens |
| LR | Warm up + flat (most) + linear decay to small | Flat (most) + linear decay to 0 | Start from small | Start from mid-small |
| Masking | – | – | Mask non-assistant | Mask prompt & padding |
| Sampling | – | – | – | ✓ (multiple rollouts per prompt) |
| Rewards | – | – | – | ✓ (task reward) |
| Sequence | Fixed | Fixed | Variable | Variable rollouts |
| Areas | General knowledge | Math, QA, skills | Instruction following, chat | Correctness, logic |

Future Work

NanoChat does not include the following components:




Citation

@misc{deng2025nanochat,
  title        = {{Practical Notes on Training NanoChat}},
  author       = {Wei Deng},
  journal      = {waynedw.github.io},
  year         = {2025},
  howpublished = {\url{https://weideng.org/posts/nanochat}}
}
  1. Karpathy, A. (2025). nanochat. https://github.com/karpathy/nanochat.
  2. Hoffmann, J., Borgeaud, S., Mensch, A., et al. (2022). Training Compute-Optimal Large Language Models. Advances in Neural Information Processing Systems (NeurIPS).