📐 Muon Optimizer: The Power of Collective Momentum
Let's continue our optimizer discussion.
The Great Model Inflation
Adam was introduced in 2014, and for its first few years optimizer memory was not a real concern. Then came the Transformer.
| Year | Model | Parameters | Adam Memory (fp32) |
|---|---|---|---|
| 2014 | VGG-19 | 144M | ~1.7 GB |
| 2017 | Transformer | 213M | ~2.5 GB |
| 2018 | BERT-Large | 340M | ~4 GB |
| 2019 | GPT-2 | 1.5B | ~18 GB |
| 2020 | GPT-3 | 175B | ~2 TB |
| 2023 | Llama 2 70B | 70B | ~840 GB |
| 2024 | Llama 3 405B | 405B | ~4.9 TB |
Between 2018 and 2024, model sizes exploded 1000×. The 3-copy memory requirement of Adam (parameters + 2 optimizer states) suddenly became the dominant cost.
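To see where numbers like "~2 TB" come from, here's a quick back-of-the-envelope calculation (assuming fp32, i.e. 4 bytes, for the parameters and both Adam moment buffers):

```python
# 3 copies (parameters + first moment + second moment) at 4 bytes each = 12 bytes/param.
params = 175e9                               # GPT-3
adam_memory_bytes = params * 3 * 4
print(f"{adam_memory_bytes / 1e12:.1f} TB")  # ~2.1 TB
```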
The burning question: Can we match Adam's performance while using less memory?
Muon: The Basic Idea
Here's Muon's core insight:
Instead of tracking individual per-parameter statistics (first moment + second moment), track the collective momentum of the entire parameter matrix, then orthogonalize it.
Let's break this down:
Adam's approach:
- Track first moment m: "Which direction is each parameter moving?"
- Track second moment v: "How volatile is each parameter?"
- Treats each parameter independently
- Memory: 3 copies (θ + m + v)
Muon's approach:
- Track first moment m: "Which direction is the entire matrix moving?"
- Orthogonalize m: Leverage the collective behavior of all parameters in the matrix
- Treats parameters as a coordinated group
- Memory: 2 copies (θ + m)
The key shift: From individual accounting to group dynamics.
In neural networks, and especially in transformers, weight matrices don't act as collections of isolated numbers: they form linear transformations, and their parameters work together. Muon exploits this. Instead of asking "how should I adjust this one parameter?", it asks "how should I adjust this transformation?" Concretely, Muon orthogonalizes the momentum matrix with Newton-Schulz, a fast approximate orthogonalization algorithm.
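To make this concrete, here is a minimal, hypothetical sketch of a Muon-style step on a single weight matrix. The learning rate and momentum coefficient are illustrative, and the orthogonalization is done exactly via SVD for clarity; the real optimizer approximates it with Newton-Schulz (sketched later) and includes further details omitted here.

```python
import torch

def muon_style_step(weight, grad, momentum_buf, lr=0.02, beta=0.95):
    """One hypothetical Muon-style update on a single 2D weight matrix."""
    # The momentum buffer is the ONLY extra optimizer state (2 copies total).
    momentum_buf.mul_(beta).add_(grad)
    # Orthogonalize the whole momentum matrix: exact version via SVD.
    U, _, Vt = torch.linalg.svd(momentum_buf, full_matrices=False)
    update = U @ Vt                  # nearest semi-orthogonal matrix to the momentum
    weight.add_(update, alpha=-lr)
```

Adam, by contrast, would also keep a second buffer of per-element variances and divide each element of its update by the square root of that variance.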
Orthogonalization: A 2D View
To understand what orthogonalization does, let's look at a simple 2×2 matrix and visualize its column vectors in 2D space.
Before Orthogonalization
Consider this matrix:
Matrix M:
[10.0 0.5]
[ 0.5 0.1]
- Vector 1 (the first column, (10.0, 0.5)) is long (length ≈ 10) and lies almost along the horizontal axis
- Vector 2 (the second column, (0.5, 0.1)) is short (length ≈ 0.5) and points in nearly the same direction
- They're far from perpendicular; in fact they're almost parallel
After Orthogonalization
Orthogonalization transforms these vectors to make them:
- Perpendicular (orthogonal) to each other
- Unit length (normalized)
Orthogonalized M:
[0.998 -0.050]
[0.050 0.998]
- Both vectors now have length = 1
- They are perpendicular (90° apart)
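As a sanity check, this small NumPy sketch (my addition) reproduces the orthogonalized matrix above via a QR decomposition, which performs exactly this Gram-Schmidt-style orthogonalization of the columns:

```python
import numpy as np

M = np.array([[10.0, 0.5],
              [ 0.5, 0.1]])

Q, R = np.linalg.qr(M)
Q = Q * np.sign(np.diag(R))   # flip column signs so Q matches classical Gram-Schmidt

print(np.round(Q, 3))         # approximately [[0.999, -0.05], [0.05, 0.999]]
```

Up to rounding, this matches the orthogonalized M shown above. (Muon's Newton-Schulz step targets a closely related object, the nearest semi-orthogonal matrix UVᵀ, but the 2D intuition is the same.)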
What Just Happened?
Orthogonalization did two things:
- Made vectors perpendicular: They now point in independent directions
- Normalized lengths: All vectors have equal magnitude (length = 1)
In the momentum context, this means:
- Directions that accumulated large momentum get scaled down
- Directions that had small momentum get scaled up
- All directions become equally weighted in the update
This is how one operation provides both momentum smoothing AND adaptive step sizing—by restructuring the matrix to have balanced, orthogonal directions rather than a few dominant ones.
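In matrix terms, orthogonalization replaces the momentum matrix M = UΣVᵀ with UVᵀ: it keeps the directions but sets every singular value to 1. A small NumPy sketch, using a random matrix as a stand-in for a momentum buffer:

```python
import numpy as np

rng = np.random.default_rng(0)
G = rng.normal(size=(4, 4))               # stand-in for an accumulated momentum matrix

U, S, Vt = np.linalg.svd(G)
O = U @ Vt                                # orthogonalized version of G

print(np.round(S, 2))                     # uneven: a couple of directions dominate
print(np.round(np.linalg.svd(O)[1], 2))   # all 1.0: every direction equally weighted
```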
The Power of Group Momentum
Why do we want to treat parameters collectively? Let's examine the challenges of high-dimensional optimization—particularly saddle points.
The Saddle Point Challenge
Modern research shows that in high-dimensional neural networks, the real obstacle isn't local minima—it's saddle points.
What training looks like in high dimensions:
╲___╱╲___ ← Saddle point (everywhere!)
At a saddle point:
- One direction curves down (escape route)
- Many directions are flat (nearly zero gradient)
- The optimizer must find and follow the descent direction
At saddle points, most gradient components are nearly zero. The challenge is identifying which directions lead to progress versus which are just noise.
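A toy illustration (my own, with made-up numbers) of what "mostly flat" means in practice: near a saddle of a quadratic, almost every gradient component is tiny, and only the negative-curvature direction offers real progress.

```python
import numpy as np

# f(x) = 0.5 * x^T H x with 9 nearly flat directions and 1 escape direction.
H = np.diag([1e-4] * 9 + [-1.0])
x = np.full(10, 0.1)          # a point near the saddle at the origin
grad = H @ x
print(np.round(grad, 5))      # nine components of about 1e-5, one of -0.1
```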
Two Approaches, Both Effective
Adam's element-wise approach:
- Tracks variance for each parameter independently
- Adapts step size: parameters with high variance get smaller steps
- Works well: Amplifies small but consistent gradients, helping escape flat regions
Muon's matrix-wise approach:
- Views the momentum matrix holistically, as a single object
- Orthogonalizes to balance all directions equally
- Also works well: Prevents any direction from dominating, forces balanced exploration
Both approaches succeed at saddle points, just through different mechanisms. Adam uses statistical adaptation (per-element variance), while Muon uses geometric rebalancing (matrix orthogonalization).
The Group Advantage
What makes Muon's approach distinct is how it leverages matrix structure:
Momentum matrix at saddle point:
[●●●●●●●●] ← Large singular value (one direction dominates)
[●● ] ← Medium singular value
[● ] ← Small singular values (underexplored directions)
The momentum naturally reveals:
- Strong patterns = Directions heavily explored
- Weak patterns = Directions barely tried
Orthogonalization rebalances all directions to equal weight, ensuring the optimizer explores the full parameter space rather than getting stuck repeatedly trying the same dominant directions.
The key insight: Parameters in a matrix form a coordinated transformation. By treating them as a group rather than individuals, Muon exploits structural information that element-wise methods cannot access.
Trading Compute for Memory
Nothing is free. Muon's memory savings come at a computational cost.
The Tradeoff
Memory savings: 33%
- AdamW: 3 copies = ~84 GB for 7B model
- Muon: 2 copies = ~56 GB for 7B model
Computational overhead: ~1%
- Newton-Schulz: ~5 iterations of a few matrix multiplications per weight matrix per step
- Adds <1% to total training FLOPs
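For reference, here is a sketch of the quintic Newton-Schulz iteration in the spirit of Keller Jordan's reference implementation (the coefficients below are taken from that post; treat this as an illustration, not the canonical code):

```python
import torch

def newton_schulz_orthogonalize(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    """Approximately map G to the nearest semi-orthogonal matrix (U @ V^T of its SVD)."""
    a, b, c = 3.4445, -4.7750, 2.0315     # quintic iteration coefficients
    X = G / (G.norm() + 1e-7)             # normalize so the iteration converges
    transposed = G.size(0) > G.size(1)
    if transposed:                        # work in the wide orientation
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X
    return X.T if transposed else X
```

Each of the ~5 iterations costs a handful of matmuls at the weight matrix's shape, which is where the small compute overhead comes from.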
Implementation Details
Muon applies only to 2D weight matrices, i.e., the linear layers in transformers. For everything else we still need AdamW:
Muon optimizes:
- Q, K, V projection matrices
- Feed-forward layer matrices
- Other hidden layer weight matrices
AdamW optimizes:
- Embedding layers
- Output/classifier heads
- Biases (1D)
- Layer normalization parameters (1D)
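In PyTorch terms, the split might look like the hypothetical helper below (the name filters `embed` and `lm_head` are placeholders for whatever your model actually calls those modules, and `Muon` stands for whichever implementation you use):

```python
import torch

def split_param_groups(model: torch.nn.Module):
    """Route 2D hidden-layer weights to Muon and everything else to AdamW."""
    muon_params, adamw_params = [], []
    for name, p in model.named_parameters():
        is_hidden_matrix = p.ndim == 2 and "embed" not in name and "lm_head" not in name
        (muon_params if is_hidden_matrix else adamw_params).append(p)
    return muon_params, adamw_params

# muon_params, adamw_params = split_param_groups(model)
# opt_muon  = Muon(muon_params, lr=0.02)                 # hypothetical Muon implementation
# opt_adamw = torch.optim.AdamW(adamw_params, lr=3e-4)
```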
In distributed training:
Pipeline parallelism splits layers across GPUs. Each GPU independently:
- Runs Muon on its Linear layer weights
- Runs AdamW on its embeddings/biases/norms
- Stores only its own parameters and optimizer states
Muon's orthogonalization happens locally on each GPU. Since all stages perform similar matrix operations, Newton-Schulz adds uniform compute overhead across stages—no new bottleneck emerges.
The memory benefit: each GPU needs 33% less memory for its optimizer states, allowing larger batch sizes or fitting bigger models per stage.
Departing Thoughts
Transformers organize parameters into large matrices. This is an inductive bias that also happens to be extremely GPU-friendly, and Muon is among the first major optimizers to exploit that structure directly.
As architectures continue to evolve (state space models, mixture of experts, hybrid designs), expect more optimizer innovations that exploit this kind of macro-level structure.
Further Reading
- Muon optimizer by Keller Jordan et al.: https://kellerjordan.github.io/posts/muon/
- Deriving Muon by Jeremy Bernstein: https://jeremybernste.in/writing/deriving-muon
- Muon is Scalable for LLM Training (Moonlight paper): https://arxiv.org/abs/2502.16982
- Old Optimizer, New Norm by Bernstein & Newhouse: https://arxiv.org/abs/2409.20325
- Identifying and attacking the saddle point problem in high-dimensional non-convex optimization by Dauphin et al.: https://arxiv.org/abs/1406.2572