nGPT: Normalized Transformer with Representation Learning on the Hypersphere
Abstract: We propose a novel neural network architecture, the normalized Transformer (nGPT) with representation learning on the hypersphere. In nGPT, all vectors forming the embeddings, MLP, attention matrices and hidden states are unit norm normalized. The input stream of tokens travels on the surface of a hypersphere, with each layer contributing a displacement towards the target output predictions. These displacements are defined by the MLP and attention blocks, whose vector components also reside on the same hypersphere. Experiments show that nGPT learns much faster, reducing the number of training steps required to achieve the same accuracy by a factor of 4 to 20, depending on the sequence length.
Summary
- The paper introduces a Transformer that enforces L2 normalization on token embeddings and weights, constraining representations to a unit hypersphere for faster convergence.
- The approach removes traditional normalization layers and weight decay, replacing standard residual updates with a hypersphere update rule featuring learnable per-dimension scaling.
- Experimental results demonstrate 4x-20x faster convergence, improved numerical stability, and effective length extrapolation, with ablations confirming the impact of learnable scaling factors.
The paper "nGPT: Normalized Transformer with Representation Learning on the Hypersphere" (2410.01131) introduces a modified Transformer architecture where key vector representations are constrained to reside on the surface of a unit hypersphere. This constraint is enforced through explicit L2 normalization applied to token embeddings, hidden states, and the vectors constituting the weight matrices of the attention and MLP blocks (along the embedding dimension). The core hypothesis is that operating on this manifold simplifies the optimization landscape and leads to faster convergence.
Normalization Strategy and Architectural Changes
The central modification in nGPT is the pervasive application of L2 normalization. Unlike standard Transformers which rely on LayerNorm or RMSNorm applied before attention or MLP blocks, nGPT removes these entirely. Instead, normalization is applied after computations and updates:
- Matrix Normalization: After each optimizer step, all weight matrices ($W_{\text{input}}$, $W_{\text{output}}$, $W_q$, $W_k$, $W_v$, $W_o$, $W_u$, $W_\nu$, $W_{o\text{MLP}}$) are normalized so that every vector along the embedding dimension ($d_{\text{model}}$) has unit L2 norm. For a matrix $W \in \mathbb{R}^{d_1 \times d_2}$ with $d_2 = d_{\text{model}}$, this means normalizing each of the $d_1$ row vectors; if the embedding dimension comes first instead, the column vectors are normalized. This normalization occurs outside the forward/backward pass, directly modifying the weights after the gradient update.
- Hidden State Normalization: The hidden state vector $h \in \mathbb{R}^{d_{\text{model}}}$ is explicitly normalized to unit norm after the attention and MLP block updates within the forward pass. This ensures the "information carrier" always remains on the hypersphere.
- Removal of Standard Normalization: LayerNorm and RMSNorm layers are completely removed from the architecture.
- Removal of Weight Decay: Because matrix rows/columns are constantly renormalized to unit norm, their magnitude is controlled. Consequently, weight decay (L2 regularization) is deemed unnecessary and removed (setting `weight_decay=0` in the optimizer, at which point Adam and AdamW coincide).
- Removal of Learning Rate Warmup: The paper reports successful training without learning rate warmup schedules.
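The weight-decay argument above can be checked numerically. The following is a minimal pure-Python sketch, not the paper's code: the helper names and the values of `lr` and `wd` are hypothetical, and it assumes decoupled (AdamW-style) decay considered on its own, without the gradient term. Under that assumption the decay only shrinks a weight vector by a scalar factor, which renormalization then cancels.

```python
import math

def l2_normalize(v):
    """Return v rescaled to unit L2 norm."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

def decay_then_normalize(w, lr, wd):
    """Apply decoupled weight decay (w <- w * (1 - lr * wd)), then renormalize."""
    shrunk = [x * (1.0 - lr * wd) for x in w]
    return l2_normalize(shrunk)

# A unit-norm weight row, as nGPT maintains after every optimizer step.
w = l2_normalize([3.0, -4.0, 12.0])

# Hypothetical hyperparameters, chosen only for illustration.
with_decay = decay_then_normalize(w, lr=1e-3, wd=0.1)
without_decay = l2_normalize(w)

# The decay only rescales the vector, so renormalization cancels it.
max_diff = max(abs(a - b) for a, b in zip(with_decay, without_decay))
```

Here `max_diff` is zero up to floating-point rounding, which is the sense in which renormalization makes the decay term redundant for the direction of each weight vector.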
Modified Update Rule and Optimization Perspective
The standard residual connection h←h+SubLayer(h) is replaced with a modified update rule that incorporates learnable per-dimension scaling and explicit normalization. This frames the layer update as a step on the hypersphere.
Let $h$ be the input hidden state to a block (Attention or MLP), assumed to be normalized ($\|h\|_2 = 1$). Let $h_{\text{suggestion}} = \text{Norm}(\text{SubLayer}(h))$ be the normalized output of the sub-layer's core computation (e.g., $\text{ATTN}(h)$ or $\text{MLP}(h)$). The update rule is:
$$h \leftarrow \text{Norm}\big(h + \alpha \odot (h_{\text{suggestion}} - h)\big)$$
Here:
- $\text{Norm}(x) = x / \|x\|_2$ denotes L2 normalization.
- $\alpha$ is a learnable vector of size $d_{\text{model}}$ (distinct for Attention, $\alpha_A$, and MLP, $\alpha_M$). These are termed "eigen learning rates".
- $\odot$ denotes element-wise multiplication.
- $(h_{\text{suggestion}} - h)$ represents the update direction suggested by the sub-layer.
- The $\alpha$ vector scales the contribution of this update direction along each dimension.
- The final $\text{Norm}(\cdot)$ projects the result of the scaled update back onto the unit hypersphere, acting as a retraction step in manifold optimization.
This formulation can be interpreted as a variable-metric optimization step on the hypersphere, where α represents the diagonal elements of a metric tensor that adapts the step size along different dimensions. The paper suggests this constrained optimization on the manifold contributes to the observed faster convergence.
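The update rule can be traced on a toy example. This is a minimal pure-Python sketch under stated assumptions: the function names, the three-dimensional vectors, and the constant `alpha` values are illustrative choices, not taken from the paper. It shows the two limiting cases (no step for `alpha = 0`, a full jump to the suggestion for `alpha = 1`) and that an intermediate step still lands on the unit sphere.

```python
import math

def norm(v):
    """L2-normalize a vector (the Norm(.) operation)."""
    length = math.sqrt(sum(x * x for x in v))
    return [x / length for x in v]

def hypersphere_update(h, sublayer_out, alpha):
    """h <- Norm(h + alpha * (Norm(sublayer_out) - h)), element-wise in alpha."""
    h_suggestion = norm(sublayer_out)
    moved = [hi + a * (si - hi) for hi, si, a in zip(h, h_suggestion, alpha)]
    return norm(moved)

h = norm([1.0, 0.0, 0.0])
sublayer_out = [0.0, 2.0, 0.0]  # toy sub-layer output, not yet normalized

stay = hypersphere_update(h, sublayer_out, alpha=[0.0, 0.0, 0.0])     # no step
jump = hypersphere_update(h, sublayer_out, alpha=[1.0, 1.0, 1.0])     # full step
part = hypersphere_update(h, sublayer_out, alpha=[0.25, 0.25, 0.25])  # partial step

part_length = math.sqrt(sum(x * x for x in part))  # stays on the unit sphere
```

With a partial step, `part` lies between `h` and the suggestion and still has unit length, which is what the final normalization guarantees.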
Scaling Factors for Degrees of Freedom
Since normalization forces all vectors to have unit magnitude, potentially losing important scaling information, nGPT introduces several learnable scaling factors to restore these degrees of freedom:
- Logit Scaling ($s_z$): A learnable scaling factor $s_z$ scales the final output logits before the softmax function: $\text{logits} = s_z \cdot (h_{\text{final}} W_{\text{output}}^T)$. This controls the sharpness or confidence of the final probability distribution.
- Query-Key Scaling ($s_{qk}$): In the attention mechanism, the scaled dot product is $\text{softmax}\left(\frac{(hW_q^T)(hW_k^T)^T}{\sqrt{d_k}}\right)$. With normalized $W_q$, $W_k$, the magnitude of the projected Q and K vectors is bounded. nGPT optionally normalizes the Q and K vectors themselves after projection and introduces a learnable scaling factor $s_{qk}$. The details vary slightly between the paper and its appendix, but the attention formula typically scales both Q and K (e.g., $s_{qk} \cdot Q$ and $s_{qk} \cdot K$) and changes the softmax scaling, sometimes using $\sqrt{d_k}$ instead of $1/\sqrt{d_k}$. An implementation may involve `scale_qk * Q @ K.T / sqrt(dk)` or a variant of it, depending on whether Q and K are normalized after projection. An important detail is that the softmax input is scaled up by $\sqrt{d_k}$, effectively reversing the standard scaling; this is justified because the similarities between normalized vectors are bounded cosines.
- MLP Scaling ($s_u$, $s_\nu$): Within the SwiGLU MLP variant (with projections $W_u$, $W_\nu$), learnable scaling factors $s_u$, $s_\nu$ scale the projection outputs before the SiLU activation and element-wise multiplication: $\text{MLP}(h) = \big((s_u \cdot hW_u^T) \odot \sigma(s_\nu \cdot hW_\nu^T)\big) W_{o\text{MLP}}^T$. A fixed scaling factor of $\sqrt{d_{\text{model}}}$ is also applied to the input of the SiLU activation ($\sigma$) so that its argument has enough variance to operate in its non-linear regime, since $hW_\nu^T$ is built from normalized vectors.
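The need for a fixed rescaling of the SiLU input can be illustrated numerically. The sketch below is an assumption-laden illustration rather than the paper's experiment: it uses random unit vectors as stand-ins for a hidden state and weight rows (trained weights are not random), and all helper names are hypothetical. For random directions in $d$ dimensions, dot products have a root-mean-square of about $1/\sqrt{d}$, so multiplying by $\sqrt{d}$ brings them back to order one.

```python
import math
import random

def random_unit_vector(d, rng):
    """Sample a direction uniformly on the unit hypersphere in d dimensions."""
    v = [rng.gauss(0.0, 1.0) for _ in range(d)]
    length = math.sqrt(sum(x * x for x in v))
    return [x / length for x in v]

def rms_dot_product(d, n_rows, rng):
    """Root-mean-square of h . w over n_rows random unit-norm rows w."""
    h = random_unit_vector(d, rng)
    total = 0.0
    for _ in range(n_rows):
        w = random_unit_vector(d, rng)
        dot = sum(a * b for a, b in zip(h, w))
        total += dot * dot
    return math.sqrt(total / n_rows)

rng = random.Random(0)  # fixed seed so the run is repeatable
d_model = 256           # illustrative width, not a value from the paper
rms = rms_dot_product(d_model, n_rows=2000, rng=rng)
rescaled = rms * math.sqrt(d_model)  # close to 1 for random directions
```

Without the rescaling, the SiLU argument would sit in a narrow band around zero, where the activation is close to linear.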
These scaling factors are learned during training alongside other parameters. Ablation studies show that while beneficial, simplifying or fixing some of these scales (e.g., fixing $s_{qk}$, $s_u$, $s_\nu$ to 1, or using a single scalar $s_z$) results in only minor performance degradation.
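One possible reading of the query-key scaling can be sketched in a few lines. This is an illustration under explicit assumptions, not the paper's implementation: Q and K are normalized after projection, both are multiplied by a scalar `s_qk` (treating it as a scalar is an assumption), and the dot product is multiplied by $\sqrt{d_k}$ rather than divided by it. The function names are hypothetical.

```python
import math

def norm(v):
    """L2-normalize a vector."""
    length = math.sqrt(sum(x * x for x in v))
    return [x / length for x in v]

def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def attention_weights(q, keys, s_qk):
    """Logits and attention weights for one query under the assumed scaling."""
    d_k = len(q)
    q_hat = [s_qk * x for x in norm(q)]
    logits = []
    for k in keys:
        k_hat = [s_qk * x for x in norm(k)]
        scaled_cosine = sum(a * b for a, b in zip(q_hat, k_hat))  # s_qk^2 * cos
        logits.append(math.sqrt(d_k) * scaled_cosine)
    return logits, softmax(logits)

q = [1.0, 0.0, 0.0, 0.0]
keys = [[2.0, 0.0, 0.0, 0.0],   # same direction as q
        [0.0, 5.0, 0.0, 0.0],   # orthogonal to q
        [-1.0, 0.0, 0.0, 0.0]]  # opposite direction to q
logits, weights = attention_weights(q, keys, s_qk=1.0)
```

Because the inputs are normalized, each logit is bounded in magnitude by `s_qk**2 * sqrt(d_k)` regardless of the original vector lengths, so the learnable factor is what controls how sharp the attention distribution can become.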
Implementation Considerations and Computational Cost
Implementing nGPT requires modifying standard Transformer code:
- Add normalization functions for vectors/matrices. Matrix normalization typically happens after `optimizer.step()`.

```python
import torch

@torch.no_grad()
def normalize_matrix(W, dim=1, eps=1e-12):
    # W shape: [dim_out, dim_in]. Rescales vectors along `dim` to unit L2 norm in place.
    # Pick `dim` so that it is the embedding (d_model) axis of the given matrix.
    norm = torch.linalg.vector_norm(W, dim=dim, keepdim=True)
    W.div_(norm.clamp_min(eps))  # clamp avoids division by zero
    return W

# After optimizer.step() (module names assume a nanoGPT-style model):
# normalize_matrix(model.transformer.wte.weight)
# normalize_matrix(model.lm_head.weight)
# for block in model.transformer.h:
#     normalize_matrix(block.attn.c_attn.weight)  # assuming fused QKV
#     normalize_matrix(block.attn.c_proj.weight)
#     normalize_matrix(block.mlp.c_fc1.weight)    # assuming fused up/gate in SwiGLU
#     normalize_matrix(block.mlp.c_proj.weight)
```
- Replace residual connections with the hypersphere update rule, incorporating learnable α vectors.

```python
import torch.nn.functional as F

# Forward pass of a Transformer block; h is the input state (assumed normalized)
def forward(self, h):
    # self.alpha_A is the learnable alpha vector for attention
    attn_output = self.attn(h)  # core attention computation
    h_suggestion_A = F.normalize(attn_output, dim=-1)
    h = F.normalize(h + self.alpha_A * (h_suggestion_A - h), dim=-1)

    # Same update for the MLP, using self.alpha_M
    mlp_output = self.mlp(h)  # core MLP computation
    h_suggestion_M = F.normalize(mlp_output, dim=-1)
    h = F.normalize(h + self.alpha_M * (h_suggestion_M - h), dim=-1)
    return h
```
- Remove LayerNorm/RMSNorm layers.
- Initialize and learn the scaling factors ($s_z$, $s_{qk}$, $s_u$, $s_\nu$) and $\alpha$ vectors.
- Configure the optimizer with `weight_decay=0` and potentially remove the learning rate warmup schedule.
A significant practical consideration is the computational overhead. The explicit normalization operations (especially matrix normalization after each step and hidden state normalizations within each layer) add computational cost. The paper reports a 60-80% increase in time per training step compared to a baseline GPT, depending on context length. This overhead arises from the normalization calls and memory transfers, which are currently not heavily optimized in standard deep learning frameworks. The authors suggest that fused kernels could mitigate this, and the relative overhead might decrease for very large models where matrix multiplications dominate computation time. However, the drastically reduced number of required training steps is argued to outweigh this per-step cost in terms of total training time and compute.
Experimental Results and Performance
The primary claim of nGPT is significantly accelerated convergence. Experiments on the OpenWebText dataset using 0.5B and 1B parameter models show:
- Faster Convergence: nGPT reaches target validation loss levels using substantially fewer training steps (and tokens processed) compared to a baseline GPT model:
- 4x fewer steps for 1k context length.
- 10x fewer steps for 4k context length.
- 20x fewer steps for 8k context length.
- Downstream Performance: This faster convergence translates to faster achievement of comparable performance on downstream tasks (ARC-E, HellaSwag, WinoGrande, etc.).
- Numerical Stability: nGPT matrices (embeddings, attention, MLP projections) exhibit significantly lower condition numbers compared to the baseline GPT, suggesting better-behaved, less degenerate representations and potentially improved numerical stability during training.
- Length Extrapolation: When evaluated on the PG19 dataset with sequences longer than the training context length (8k), nGPT showed more stable perplexity compared to the baseline GPT, without requiring specific positional encoding modifications like RoPE adjustments typically needed for extrapolation.
- Learned Parameters: The eigen learning rates ($\alpha_A$, $\alpha_M$) learn modest average values (around 0.2-0.37), indicating controlled step sizes. The scaling factors ($s_z$, $s_{qk}$, $s_u$, $s_\nu$) also converge to non-trivial values, confirming their role in restoring necessary scale information.
While the wall-clock time per step is higher, the substantial reduction in the number of steps needed makes nGPT potentially much faster overall for achieving a target performance level, especially at longer sequence lengths.
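The trade-off can be made concrete with simple arithmetic. The sketch below combines the reported step reductions with the reported 60-80% per-step overhead; the function name is hypothetical, and applying the full overhead range to every context length is an assumption, since the summary does not say which overhead goes with which length.

```python
def relative_training_time(step_reduction, per_step_overhead):
    """Total nGPT time as a fraction of baseline time.

    step_reduction: factor by which the number of steps shrinks (e.g. 4.0).
    per_step_overhead: extra time per step as a fraction (e.g. 0.8 for +80%).
    """
    return (1.0 + per_step_overhead) / step_reduction

# Step reductions reported per context length; the overhead range is applied
# in full to each case, which is an assumption made for illustration.
for context, reduction in [("1k", 4.0), ("4k", 10.0), ("8k", 20.0)]:
    best = relative_training_time(reduction, 0.60)
    worst = relative_training_time(reduction, 0.80)
    print(f"{context}: {best:.2f}x to {worst:.2f}x of baseline time")
```

Under these assumptions the 1k case comes out at roughly 0.40 to 0.45 of the baseline wall-clock time, and the 8k case at under 0.1, so the per-step overhead reduces the headline speedup but does not erase it.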
Ablation Studies
Ablations confirmed the utility of the introduced components:
- Removing or simplifying scaling factors led to slight performance drops, indicating their usefulness but also suggesting potential for simplification (e.g., using fixed scales or removing optional Q/K normalization).
- The hypersphere update mechanism and matrix normalizations were crucial for the observed speedups.
Conclusion
nGPT proposes a modification to the Transformer architecture centered around explicit L2 normalization of representations and weight matrices, framing the learning process as optimization on a hypersphere. This approach demonstrably accelerates convergence by a significant factor (4x-20x fewer steps) across different context lengths, albeit with an increased computational cost per step. The improved numerical stability and length extrapolation capabilities are additional potential benefits. The core trade-off lies between the reduced number of training iterations and the increased cost per iteration.