Understanding GPT-OSS architecture
For me, the release of gpt-oss is the most-awaited OpenAI architecture reveal since their detailed GPT-2 report. OpenAI has stayed at the cutting edge of the state of the art for a long time, which left researchers and practitioners intrigued, always wanting to peek at what architectures they were using internally and how similar or different those would be from the open-source models already out in the open. That mystery, to a good extent, ends today.
Some key observations:
- MoE GPT-2-style Transformer (36 / 24 layers)
- 128 / 32 experts with top-4 routing ⇒ only 5.1 B / 3.6 B active params
- RMSNorm (in float32 for higher precision, only scale params)
- Grouped Query Attention + RoPE (biases still present in the QKV projections)
- 131K context via YaRN and sliding-window attention
- 4-bit MXFP4 quantization packs the 120B model onto 80 GB and the 20B model onto 16 GB
- SwiGLU activation
Architecture at a glance
| Component | Choice | Notes |
|---|---|---|
| Backbone | Decoder-only Transformer | GPT-2 style, pre-norm |
| Normalization | RMSNorm | Scale-only; math in float32 |
| Attention | GQA + RoPE | Grouped Query Attention with rotary embeddings |
| MoE | Top-4 routing | 128/32 experts per FFN; 5.1B/3.6B active params |
| Context | YaRN ~131K | Sliding-window attention |
| Activation | SwiGLU | Modern gated MLP |
| Quantization | MXFP4 (4-bit) | High packing for inference footprint |
| Biases | QKV biases kept | Explicit in projection layers |
RMSNorm
Let's first take a look at the RMSNorm implementation in gpt-oss. I recently published a detailed post on normalization in Transformer-based LLMs (read it here). LayerNorm involves both scaling and shifting; however, since shifting turned out to be dispensable, as described in the RMSNorm paper [2], RMSNorm, which only scales, is enough.
In pre-norm Transformers (norm before every sub-block), the residual connection already carries the mean information forward. What mainly causes training instability is explosive/vanishing vector length. RMSNorm tackles exactly that by normalizing to unit root-mean-square.
# title="RMSNorm"
class RMSNorm(nn.Module):
def __init__(self, num_features: int, eps: float = 1e-05, device: torch.device | None = None):
super().__init__()
self.num_features = num_features
self.eps = eps
self.scale = torch.nn.Parameter(
torch.ones(num_features, device=device, dtype=torch.float32)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
assert x.shape[-1] == self.num_features
t, dtype = x.float(), x.dtype
t = t * torch.rsqrt(torch.mean(t ** 2, dim=-1, keepdim=True) + self.eps)
return (t * self.scale).to(dtype)
Looking at the code, we can see that:
- In forward, `x.float()` makes a temporary 32-bit copy of the activations so the normalization math is done in higher precision, and the result is cast back to whatever dtype the caller was using. `self.scale` is registered as `float32`; doing the normalization in the same dtype avoids precision mismatches when multiplying by the scale vector.
- An epsilon (1e-05, the most commonly used value, no surprise here) is added inside the square root to avoid division by zero, which is standard practice in RMSNorm.
- Finally, the `.to(dtype)` at the end restores the original precision so whatever layer follows sees the dtype it expects.
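To see the dtype round-trip for yourself, here is a quick check using the RMSNorm class above; the bfloat16 input and the tensor shapes are just illustrative choices, not anything from the release:
# title="RMSNorm quick check (illustrative)"
x = torch.randn(2, 8, 64, dtype=torch.bfloat16)   # (batch, seq, features), illustrative shapes
norm = RMSNorm(num_features=64)

y = norm(x)
print(y.dtype)                    # torch.bfloat16: the caller's dtype is restored
print(y.float().pow(2).mean(-1))  # per-token mean square is ~1 (scale starts at ones)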
RoPE: Rotary Embeddings
Since token embeddings do not carry any positional information, the model needs to inject it somewhere. This is where RoPE, short for Rotary Position Embedding, comes in: rather than adding position vectors to the input embeddings, it rotates the query and key vectors by a position-dependent angle.
Though I have written a detailed post on RoPE, I will briefly summarize it here for the sake of completeness. Let's look at the code in the gpt-oss repository for applying the rotary embeddings.
# title="Apply rotary embedding"
def _apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
cos = cos.unsqueeze(-2).to(x.dtype)
sin = sin.unsqueeze(-2).to(x.dtype)
x1, x2 = torch.chunk(x, 2, dim=-1)
o1 = x1 * cos - x2 * sin
o2 = x1 * sin + x2 * cos
return torch.cat((o1, o2), dim=-1)
# title="Precompute RoPE cache"
def precompute_rope_cache(seq_len: int, head_dim: int, base: float = 10000.0, device: torch.device | None = None):
assert head_dim % 2 == 0
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
t = torch.arange(seq_len, device=device).float()
freqs = torch.einsum('i,j->ij', t, inv_freq)
cos = torch.cos(freqs)
sin = torch.sin(freqs)
return cos, sin
- The two unsqueeze(-2) calls insert a singleton head axis so that cos and sin broadcast across the batch and head dimensions; the rotation itself only touches the last (head_dim) dimension, leaving the batch, sequence length, and head dimensions unchanged.
- Though the split of the embedding dimension can be done in alternating pairs or in contiguous halves, the authors have chosen to split it into two contiguous halves, which tends to be the more efficient layout.
If we look at the code and compare it with the rotation matrix below, we can see that the code replicates that matrix exactly.
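For reference, writing $m$ for the token position and $\theta$ for the per-pair frequency (both coming from the precomputed cache; the symbols are just my notation), each pair $(x_1, x_2)$ is transformed as a standard 2D rotation:

$$
\begin{pmatrix} o_1 \\ o_2 \end{pmatrix}
=
\begin{pmatrix}
\cos(m\theta) & -\sin(m\theta) \\
\sin(m\theta) & \cos(m\theta)
\end{pmatrix}
\begin{pmatrix} x_1 \\ x_2 \end{pmatrix}
$$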
Want a gentle primer on positional encodings? See my post on RoPE.
Grouped Query Attention (GQA)
Classic multi-head attention gives each query head its own key/value heads. Multi-query attention (MQA) shares a single KV pair across all heads. Grouped Query Attention is the middle ground: queries have many heads, but keys/values are shared across small groups of heads. This trades a small reduction in expressivity for large memory/bandwidth wins at long context.
- GQA reduces KV cache size roughly by a factor equal to the number of heads per group.
- With RoPE, positional information remains multiplicative and inexpensive at inference.
You can dive deeper in my focused post: GQA and the contrast with MQA.

# title="Attention with GQA + RoPE (excerpt)"
class Attention(nn.Module):
def __init__(self, d_model: int, num_q_heads: int, num_kv_heads: int, head_dim: int, bias: bool = True, device: torch.device | None = None):
super().__init__()
self.num_q_heads = num_q_heads
self.num_kv_heads = num_kv_heads
self.head_dim = head_dim
inner_q = num_q_heads * head_dim
inner_kv = num_kv_heads * head_dim
self.wq = nn.Linear(d_model, inner_q, bias=bias, device=device)
self.wk = nn.Linear(d_model, inner_kv, bias=bias, device=device)
self.wv = nn.Linear(d_model, inner_kv, bias=bias, device=device)
self.wo = nn.Linear(inner_q, d_model, bias=bias, device=device)
def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, attn_mask: torch.Tensor | None = None) -> torch.Tensor:
b, s, _ = x.shape
q = self.wq(x).view(b, s, self.num_q_heads, self.head_dim)
k = self.wk(x).view(b, s, self.num_kv_heads, self.head_dim)
v = self.wv(x).view(b, s, self.num_kv_heads, self.head_dim)
# Apply RoPE to q and k
q = _apply_rotary_emb(q, cos[:s], sin[:s])
k = _apply_rotary_emb(k, cos[:s], sin[:s])
# Expand KV heads to match Q heads (GQA)
if self.num_kv_heads != self.num_q_heads:
repeat = self.num_q_heads // self.num_kv_heads
k = k.repeat_interleave(repeat, dim=2)
v = v.repeat_interleave(repeat, dim=2)
# Scaled dot-product attention
scale = 1.0 / math.sqrt(self.head_dim)
att = torch.einsum('bthd,bshd->bhts', q, k) * scale
if attn_mask is not None:
att = att.masked_fill(attn_mask == 0, float('-inf'))
att = torch.softmax(att, dim=-1)
o = torch.einsum('bhts,bshd->bthd', att, v)
o = o.reshape(b, s, self.num_q_heads * self.head_dim)
return self.wo(o)
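As a quick sanity check of the KV-cache point above, here is how the excerpt might be wired together with the RoPE cache from earlier. The sizes are deliberately small and illustrative, not the actual gpt-oss configuration:
# title="GQA usage sketch (illustrative sizes)"
d_model, num_q_heads, num_kv_heads, head_dim = 512, 16, 4, 32
seq_len = 128

attn = Attention(d_model, num_q_heads, num_kv_heads, head_dim)
cos, sin = precompute_rope_cache(seq_len, head_dim)

x = torch.randn(1, seq_len, d_model)
y = attn(x, cos, sin)
print(y.shape)  # torch.Size([1, 128, 512])

# The KV cache only stores K/V at num_kv_heads, so relative to full multi-head
# attention (one KV pair per query head) it shrinks by num_q_heads / num_kv_heads.
print(num_q_heads // num_kv_heads)  # 4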
Mixture of Experts (MoE) with top-4 routing
The feed-forward block is replaced by a sparse mixture: many experts exist, but only the top-k (k=4) are activated per token by a router. This gives the model a very large parameter budget while keeping per-token compute modest.
- 128/32 experts (two released scales) with top-4 routing.
- Only ~5.1B / 3.6B parameters are active per token step despite much larger total parameters.
- Sparse gating often includes load-balancing loss to avoid degenerate expert collapse.
For a foundations refresher on routing and capacity, see my primer on Mixture of Experts.
High-level routing sketch:
token → router logits → softmax → select top-4 experts → combine expert outputs (weighted)
# title="SwiGLU MLP (expert)"
class SwiGLU(nn.Module):
def __init__(self, d_model: int, d_ff: int, device: torch.device | None = None):
super().__init__()
self.w1 = nn.Linear(d_model, d_ff, bias=False, device=device) # gate
self.w3 = nn.Linear(d_model, d_ff, bias=False, device=device) # up
self.w2 = nn.Linear(d_ff, d_model, bias=False, device=device) # down
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.w2(F.silu(self.w1(x)) * self.w3(x))
# title="Top-4 MoE router (excerpt)"
class TopKGate(nn.Module):
def __init__(self, d_model: int, num_experts: int, k: int = 4, device: torch.device | None = None):
super().__init__()
assert k <= num_experts
self.k = k
self.wg = nn.Linear(d_model, num_experts, bias=False, device=device)
def forward(self, x: torch.Tensor):
logits = self.wg(x)
gates = torch.softmax(logits, dim=-1)
topk_scores, topk_indices = torch.topk(gates, self.k, dim=-1)
return topk_scores, topk_indices
class MoE(nn.Module):
    def __init__(self, d_model: int, d_ff: int, num_experts: int, k: int = 4, device: torch.device | None = None):
        super().__init__()
        self.gate = TopKGate(d_model, num_experts, k=k, device=device)
        self.experts = nn.ModuleList([SwiGLU(d_model, d_ff, device=device) for _ in range(num_experts)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        scores, indices = self.gate(x)  # [B,S,K], [B,S,K]
        b, s, k = scores.shape
        out = torch.zeros_like(x)
        for i in range(k):
            expert_idx = indices[..., i]                  # which expert handles routing slot i
            expert_scores = scores[..., i].unsqueeze(-1)  # its routing weight
            # Dispatch token-wise to the selected expert
            expert_outputs = []
            for e in range(len(self.experts)):
                mask = (expert_idx == e)
                if not mask.any():
                    continue
                y = self.experts[e](x[mask])
                expert_outputs.append((mask, y))
            # Combine back, weighted by the router scores
            for mask, y in expert_outputs:
                out[mask] += expert_scores[mask] * y
        return out
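A tiny usage sketch of the classes above, with deliberately small toy sizes (the released models use 128/32 experts and much larger dimensions):
# title="MoE usage sketch (toy sizes)"
moe = MoE(d_model=64, d_ff=128, num_experts=8, k=4)

x = torch.randn(2, 10, 64)   # (batch, seq, d_model)
y = moe(x)
print(y.shape)               # torch.Size([2, 10, 64])

# Each token is processed by only k=4 of the 8 experts, so per-token compute
# (the "active" parameters) is bounded by those 4 experts plus the shared layers.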
YaRN and sliding-window attention (~131K tokens)
Long context comes from two complementary pieces: YaRN rescales the RoPE frequencies so positions far beyond the original training length remain well-behaved, while sliding-window attention keeps a fixed-size recent window fully attended and limits attention to distant tokens. This preserves locality (most useful at inference) and enables very long sequences without a quadratic blow-up.
- Effective context: ~131k tokens
- Pragmatic for production: KV cache and bandwidth remain bounded
Intuition: if $W$ is the window and $t$ the current step, attend densely over $[t-W, t]$ and sparsify/aggregate further back.
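To make that concrete, here is a minimal sketch of a banded causal mask in the shape expected by the attn_mask argument of the Attention excerpt above. The helper name is mine and it only illustrates the sliding-window idea; it is not the actual gpt-oss masking code:
# title="Sliding-window causal mask (illustrative sketch)"
import torch

def sliding_window_causal_mask(seq_len: int, window: int) -> torch.Tensor:
    # mask[t, s] == 1 means query position t may attend to key position s
    pos_q = torch.arange(seq_len).unsqueeze(1)   # [S, 1]
    pos_k = torch.arange(seq_len).unsqueeze(0)   # [1, S]
    causal = pos_k <= pos_q                      # never attend to the future
    in_window = (pos_q - pos_k) < window         # only the last `window` positions
    return (causal & in_window).to(torch.int64)

mask = sliding_window_causal_mask(seq_len=8, window=3)
print(mask)  # each row has at most 3 ones: the current token and the two before it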
MXFP4 (4-bit) quantization
The release highlights MXFP4, a 4-bit floating-point format from the microscaling (MX) family, which lets the high-parameter variants fit on common accelerators:
- 120B fits on 80 GB
- 20B fits on 16 GB
These formats typically use block-wise scales with shared exponents to preserve dynamic range while keeping memory and bandwidth low. Expect small accuracy deltas, but large throughput and cost benefits.
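To give a feel for block-wise scaling, here is a toy sketch that quantizes a tensor in blocks of 32 values, each block sharing a single power-of-two scale, with elements rounded to a small 4-bit-style value grid. The function names and the grid are mine, and this is a simplification for intuition only, not the actual MXFP4 specification or the kernels used by gpt-oss:
# title="Toy block-wise 4-bit-style quantization (not the real MXFP4 spec)"
import torch

# A small grid of representable magnitudes, loosely modelled on an E2M1-style layout.
GRID = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def quantize_blockwise(x: torch.Tensor, block: int = 32):
    assert x.numel() % block == 0
    blocks = x.reshape(-1, block)
    # One shared power-of-two scale per block, chosen so the block max lands within the grid.
    max_abs = blocks.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    scale = 2.0 ** torch.ceil(torch.log2(max_abs / GRID[-1]))
    scaled = blocks / scale
    # Round each scaled magnitude to the nearest grid value, keeping the sign.
    idx = (scaled.abs().unsqueeze(-1) - GRID).abs().argmin(dim=-1)
    q = GRID[idx] * scaled.sign()
    return q, scale

def dequantize_blockwise(q: torch.Tensor, scale: torch.Tensor, shape):
    return (q * scale).reshape(shape)

x = torch.randn(4, 64)
q, scale = quantize_blockwise(x)
x_hat = dequantize_blockwise(q, scale, x.shape)
print((x - x_hat).abs().mean())   # small but non-zero reconstruction error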
SwiGLU activation
Gated activations such as SwiGLU have become the default for modern LLM MLPs. Compared to GELU/ReLU, they improve expressivity with limited extra cost and pair well with MoE.
Putting it together
- Pre-norm decoder-only Transformer
- GQA + RoPE for scalable attention and efficient KV caches
- Sparse MoE with top-4 routing for large capacity but modest per-token compute
- YaRN for very long context with sliding windows
- Quantization via MXFP4 for practical deployment
If you are familiar with GPT-2 internals, this feels like a well-engineered, modernized path: keep what works, update attention, normalization, and FFNs to current best practices, and add sparse scaling.
Quick Q&A
Why keep QKV biases?
Biases can help with small distribution shifts and calibration. They are a small overhead relative to projection matrices and may stabilize training with very long context and quantization.
How do “active params” stay small?
With MoE, only a subset of experts process each token. If total experts are large but only the top-4 run, the effective compute and memory per step is bounded by those experts plus shared layers.
Does YaRN hurt tasks needing global context?
Some very long-range interactions can degrade. In practice, many tasks rely heavily on recent context; you can increase window size or combine with retrieval for global cues.
References
- [1] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, Illia Polosukhin. Attention Is All You Need. arXiv:1706.03762
- [2] Biao Zhang, Rico Sennrich. Root Mean Square Layer Normalization. arXiv:1910.07467