Project Overview
This project is a from-scratch, NumPy-only Transformer Decoder with an emphasis on correctness, explainability, and production-minded details like KV caching, deterministic masking, and a training loop you can actually read. The code is modular (each layer lives in its own file), instrumented with validation blocks, and designed so you can lift any part (e.g., attention, FFN, PE) into other projects.
At a glance:
- Character-level `Tokenizer` and configurable `Decoder` stack
- Sinusoidal `PositionalEncoding` with offset support for the KV cache
- Autoregressive, masked `MultiHeadAttention` with a clear backward pass
- Position-wise `FeedForwardNetwork` and `LayerNormalization`
- Training loop with `Adam` and `CrossEntropyLoss`
- A `generate` utility that supports temperature sampling, stop tokens, and the KV cache
Problem & Motivation
| Pain Point | Effect |
|---|---|
| Black-box training stacks | Hard to reason about correctness or performance trade-offs |
| No KV cache in simple reference impls | Slow, O(T^2) per-step generation |
| Hidden masking rules | Subtle causality bugs and non-reproducible behavior |
| Opaque codebases | Difficult to teach, modify, or benchmark |
I wanted a compact but serious codebase that is both a learning tool and a research playground. The implementation is intentionally verbose where it matters (shapes, caches, gradients) and minimal where it doesn’t (I/O, frameworks). Each module is validated in its own `__main__` block and uses the same configuration surface, so you can swap dimensions and re-run sanity checks fast.
System Architecture
High-level flow for training/inference:
- Token IDs → `InputEmbedding` → add `PositionalEncoding` → N× `DecoderBlock` → `Linear` projection → logits
- Each `DecoderBlock` = masked self-`MultiHeadAttention` + residual + `LayerNormalization`, then `FeedForwardNetwork` + residual + `LayerNormalization`.
- During generation, a KV cache stores past K/V per layer; we feed only the new token each step and apply the correct PE offset.
Key modules:
- `src/tokenizer.py`: simple, inspectable character-level tokenizer
- `src/positional_encoding.py`: sinusoidal PE with offset support
- `src/multi_head_attention.py`: masked MHA, KV cache, explicit backward
- `src/feed_forward_network.py`: two-layer MLP with ReLU and optional dropout
- `src/layer_normalization.py`: stable LN with cached stats and gradients
- `src/decoder_block.py`: one decoder block (MHA → Add&Norm → FFN → Add&Norm)
- `src/decoder.py`: stacks blocks, handles PE offset, returns logits
- `src/training_loop.py`: `CrossEntropyLoss`, `Adam`, batching, and `train`
- `src/generation.py`: temperature sampling, stop tokens, and the KV cache loop
Design Choices
- Explicit shapes and caches: every forward path stores what the backward needs.
- Masking and causality are obvious: a standalone `create_causal_mask` and unit checks on the upper triangle.
- KV cache correctness: PE offsets are computed from the current token index; SDPA masking is disabled at 1-token steps because causality is enforced by the step ordering.
- Framework independence: pure NumPy for clarity and portability.
- Config-driven: a single `config.py` defines dimensions and toggles like `use_kv_cache`.
Example: centralized configuration surface with safe overrides and a KV cache flag:
"""src/config.py"""
def get_config(overrides: dict = None) -> dict:
# ...
config = {
"vocab_size": vocab_size,
"d_model": d_model,
"n_layers": n_layers,
"n_heads": n_heads,
"d_ff": d_ff,
"dropout_rate": dropout_rate,
"max_seq_len": max_seq_len,
"use_kv_cache": use_kv_cache,
"epsilon": epsilon,
}
# recalculates d_head if overridden
# ...
return config
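For example, a quick override looks like this (an illustrative call, not a snippet from the repo; the override keys simply mirror the dict above):

"""illustrative usage (hypothetical overrides)"""
cfg = get_config({"n_heads": 8, "d_model": 256, "use_kv_cache": True})
assert cfg["d_model"] % cfg["n_heads"] == 0  # head size must divide evenly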
Technical Deep Dive
Tokenization
Character-level, deterministic, and trivial to inspect.
"""src/tokenizer.py"""
class Tokenizer:
def encode(self, text: str) -> list[int]:
return [self.char_to_idx[char] for char in text if char in self.char_to_idx]
def decode(self, tokens: list[int]) -> str:
return "".join([self.idx_to_char[idx] for idx in tokens if idx in self.idx_to_char])
Positional Encoding with offsets (KV cache aware)
We precompute PE up to `max_seq_len` and add the correct slice with an offset during generation.
"""src/positional_encoding.py"""
def forward(self, x: np.ndarray, offset: int = 0) -> np.ndarray:
input_seq_len = x.shape[-2]
effective_len = offset + input_seq_len
if effective_len > self.max_seq_len:
raise ValueError(
f"Effective sequence length (offset {offset} + input_seq_len {input_seq_len} = {effective_len}) "
f"exceeds max_seq_len {self.max_seq_len}."
)
positional_encodings_slice = self.pe[offset : effective_len, :]
return x + positional_encodings_slice
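The `self.pe` table itself is the standard sinusoidal construction. A compact way to precompute it (a sketch with example sizes, assuming `d_model` is even):

"""illustrative sketch of the precomputed table"""
import numpy as np

max_seq_len, d_model = 512, 64                                 # example sizes
position = np.arange(max_seq_len)[:, None]                     # (max_seq_len, 1)
div_term = np.exp(np.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
pe = np.zeros((max_seq_len, d_model))
pe[:, 0::2] = np.sin(position * div_term)                      # even dimensions
pe[:, 1::2] = np.cos(position * div_term)                      # odd dimensions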
This is paired with an explicit offset in the decoder when a cache is active:
"""src/decoder.py"""
pe_offset = current_token_idx if kv_cache_list is not None and seq_len == 1 else 0
x = self.positional_encoding.forward(x, offset=pe_offset)
Multi-Head Attention and causal masking
Autoregressive causality is handled by a clear mask. For single-token cached steps, we skip masking inside SDPA (the step ordering enforces causality).
"""src/multi_head_attention.py"""
def create_causal_mask(seq_len: int) -> np.ndarray:
mask = np.triu(np.ones((1, 1, seq_len, seq_len)), k=1).astype(bool)
return mask
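Inside scaled dot-product attention, the `True` entries (the upper triangle, i.e., future positions) are pushed to a large negative score before the softmax, so token t can only attend to positions ≤ t. A sketch of that application, assuming scores of shape (batch, heads, T, T):

"""illustrative sketch of mask application"""
import numpy as np

scores = np.random.randn(1, 1, 3, 3)             # (batch, heads, T, T) raw attention scores
mask = create_causal_mask(3)                     # defined above; True above the diagonal
scores = np.where(mask, -1e9, scores)            # blocked (future) positions get a large negative score
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)   # each row now sums to 1 over positions <= t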
KV cache concatenates past K/V with current K/V and updates in-place:
"""src/multi_head_attention.py"""
if kv_cache is not None:
    if 'k' in kv_cache and kv_cache['k'] is not None:
        past_K, past_V = kv_cache['k'], kv_cache['v']             # K/V accumulated on earlier steps
        K_combined = np.concatenate((past_K, K_curr_split), axis=2)  # concatenate along the sequence axis
        V_combined = np.concatenate((past_V, V_curr_split), axis=2)
    else:
        K_combined = K_curr_split
        V_combined = V_curr_split
    kv_cache['k'] = K_combined
    kv_cache['v'] = V_combined
    sdpa_mask = None  # single-token steps need no mask; step ordering enforces causality
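At a cached step the query covers only the new token, while keys and values cover the whole prefix, which is why no mask is required. A shape-level sketch (names and sizes chosen for illustration):

"""illustrative sketch of a cached single-token step"""
import numpy as np

B, H, d_head, t_past = 1, 4, 16, 7
Q = np.random.randn(B, H, 1, d_head)             # query for the new token only
K = np.random.randn(B, H, t_past + 1, d_head)    # cached keys plus the current key
V = np.random.randn(B, H, t_past + 1, d_head)

scores = Q @ K.transpose(0, 1, 3, 2) / np.sqrt(d_head)   # (B, H, 1, t_past + 1)
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)
out = weights @ V                                # (B, H, 1, d_head); no mask needed because
                                                 # every cached position is already in the past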
Feed-Forward Network
He-initialized two-layer MLP with optional dropout.
"""src/feed_forward_network.py"""
self.linear1_output_cache = np.dot(x, self.W1) + self.b1
activated_output = self._relu(self.linear1_output_cache)
if self.dropout_rate > 0.0 and self.training_mode:
self.dropout_mask_cache = (np.random.rand(*activated_output.shape) > self.dropout_rate) / (1.0 - self.dropout_rate)
dropped_output = activated_output * self.dropout_mask_cache
else:
dropped_output = activated_output
output = np.dot(dropped_output, self.W2) + self.b2
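The “He-initialized” part refers to the weight setup. A sketch of how W1/W2 might be initialized under that scheme (shapes and names assumed, matched to the ReLU nonlinearity):

"""illustrative sketch of He initialization"""
import numpy as np

d_model, d_ff = 64, 256                                        # example sizes
W1 = np.random.randn(d_model, d_ff) * np.sqrt(2.0 / d_model)   # He init for the ReLU layer
b1 = np.zeros(d_ff)
W2 = np.random.randn(d_ff, d_model) * np.sqrt(2.0 / d_ff)
b2 = np.zeros(d_model)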
Layer Normalization
Stable, per-token normalization with cached stats and a clear backward.
"""src/layer_normalization.py"""
self.mean_cache = np.mean(x, axis=-1, keepdims=True)
self.variance_cache = np.var(x, axis=-1, keepdims=True)
self.normalized_input_cache = (x - self.mean_cache) / np.sqrt(self.variance_cache + self.epsilon)
output = self.gamma * self.normalized_input_cache + self.beta
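The backward pass follows directly from those cached statistics. A sketch of the standard LayerNorm gradient, reusing the cache names above with `grad_output` as the upstream gradient:

"""illustrative sketch of the backward pass"""
N = grad_output.shape[-1]                        # size of the normalized (last) dimension
inv_std = 1.0 / np.sqrt(self.variance_cache + self.epsilon)
x_hat = self.normalized_input_cache

grad_gamma = np.sum(grad_output * x_hat, axis=tuple(range(grad_output.ndim - 1)))
grad_beta = np.sum(grad_output, axis=tuple(range(grad_output.ndim - 1)))

grad_x_hat = grad_output * self.gamma
grad_x = (inv_std / N) * (
    N * grad_x_hat
    - np.sum(grad_x_hat, axis=-1, keepdims=True)
    - x_hat * np.sum(grad_x_hat * x_hat, axis=-1, keepdims=True)
)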
Loss and Optimizer
Cross-entropy and Adam are implemented directly for transparency.
"""src/training_loop.py"""
# CrossEntropy backward (flat)
grad_logits = self.probs.copy()
grad_logits[np.arange(len(self.target)), self.target] -= 1
grad_logits /= len(self.target)
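The forward direction is the usual softmax plus negative log-likelihood over flattened logits and targets; a sketch consistent with the backward above (variable names assumed):

"""illustrative sketch of the forward pass"""
# logits: (batch*seq_len, vocab_size), target: (batch*seq_len,)
shifted = logits - np.max(logits, axis=-1, keepdims=True)      # numerical stability
self.probs = np.exp(shifted) / np.sum(np.exp(shifted), axis=-1, keepdims=True)
self.target = target
loss = -np.mean(np.log(self.probs[np.arange(len(target)), target] + 1e-12))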
"""src/training_loop.py"""
# Adam update
self.m[name] = self.beta1 * self.m[name] + (1 - self.beta1) * grad_value
self.v[name] = self.beta2 * self.v[name] + (1 - self.beta2) * (grad_value ** 2)
m_hat = self.m[name] / (1 - self.beta1 ** self.t)
v_hat = self.v[name] / (1 - self.beta2 ** self.t)
param_to_update -= self.learning_rate * m_hat / (np.sqrt(v_hat) + self.epsilon)
Generation loop with KV cache and temperature
The generator feeds one token at a time when using the KV cache and applies temperature-scaled sampling (or greedy when temperature ≤ 0).
"""src/generation.py"""
if kv_cache_list is not None:
input_tokens_step = np.array([[current_sequence_tokens[-1]]], dtype=np.int32)
logits = model.forward(input_tokens_step, mask=None, kv_cache_list=kv_cache_list, current_token_idx=current_input_pos)
else:
input_tokens_full = np.array([current_sequence_tokens], dtype=np.int32)
causal_mask_full = dec.Decoder.create_causal_mask(seq_len_full)
logits = model.forward(input_tokens_full, mask=causal_mask_full, current_token_idx=0)
next_token_logits = logits[0, -1, :]
if temperature <= 0.0:
next_token_id = np.argmax(next_token_logits)
else:
scaled_logits = next_token_logits / temperature
probs = np.exp(scaled_logits - np.max(scaled_logits))
probs = probs / np.sum(probs)
next_token_id = np.random.choice(len(probs), p=probs)
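After sampling, the loop appends the token, advances the position, and checks the optional stop token. A sketch of that tail, reusing the names above (the stop-token variable name is assumed):

"""illustrative sketch of the loop tail"""
current_sequence_tokens.append(int(next_token_id))
current_input_pos += 1
if stop_token_id is not None and next_token_id == stop_token_id:
    break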
Training Pipeline
The training loop is intentionally straightforward: create causal masks, forward pass, compute cross-entropy, backprop through all submodules, and apply Adam. Batching uses a sliding window over the tokenized corpus.
"""src/training_loop.py"""
def train(model: dec.Decoder, tokenizer_instance: tkn.Tokenizer, corpus: str,
epochs: int, batch_size: int, learning_rate: float,
seq_len: int, print_every: int = 100):
# ... prepare_batch → forward → loss → backward → optimizer.update()
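A sketch of what the sliding-window batching step might produce (the exact `prepare_batch` signature is assumed): inputs are windows of `seq_len` tokens and targets are the same windows shifted right by one.

"""illustrative sketch of sliding-window batching"""
import numpy as np

def prepare_batch(token_ids: list[int], batch_size: int, seq_len: int):
    # sample batch_size window starts; targets are the inputs shifted by one token
    starts = np.random.randint(0, len(token_ids) - seq_len - 1, size=batch_size)
    inputs = np.stack([np.array(token_ids[s : s + seq_len]) for s in starts])
    targets = np.stack([np.array(token_ids[s + 1 : s + seq_len + 1]) for s in starts])
    return inputs.astype(np.int32), targets.astype(np.int32)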
Practical defaults live in `src/config.py`, and each module has a `__main__` block with shape and sanity checks to help you iterate quickly.
Inference & Performance
The `generate` utility supports:
- KV cache enabled/disabled
- Temperature sampling and greedy decoding
- Optional stop token
- Optional return of attention weights per step
The repository also includes a simple timing harness that compares cached vs non-cached generation using identical weights.
"""src/generation.py"""
start_time_nc = time.time()
output_no_cache = generate(model_perf_non_cached, tokenizer_instance, perf_prompt, perf_gen_len, temperature=0.0, use_kv_cache=False)
time_no_cache = time.time() - start_time_nc
start_time_c = time.time()
output_cache = generate(model_perf_cached, tokenizer_instance, perf_prompt, perf_gen_len, temperature=0.0, use_kv_cache=True)
time_cache = time.time() - start_time_c
Expect the cached path to cut per-step compute from O(T^2) (re-running attention over the entire prefix at every step) to O(T) (attending only over the cached keys), a gap that widens as the generated sequence grows.
Why These Decisions?
- NumPy first: you learn more when you see every matmul, transpose, and broadcast.
- Separate, tested modules: makes debugging faster and reuse easier.
- KV cache and PE offsets: necessary to make small models “feel” fast enough in notebooks and demos.
- Clear masking: debugging causality errors costs more time than writing one good helper.
What’s Next
- Swappable tokenization (BPE/WordPiece) while retaining the same decoder core
- Rotary or ALiBi positional encodings alongside sinusoidal PE
- Mixed precision and simple JIT backends for speedups
- Checkpointing and dataset streaming
If you read the code and can see exactly how it works, that’s the point. Fork it, break it, and make it yours.