Incorporating Hierarchical Semantics in Sparse Autoencoder Architectures
Abstract: Sparse dictionary learning (and, in particular, sparse autoencoders) attempts to learn a set of human-understandable concepts that can explain variation on an abstract space. A basic limitation of this approach is that it neither exploits nor represents the semantic relationships between the learned concepts. In this paper, we introduce a modified SAE architecture that explicitly models a semantic hierarchy of concepts. Application of this architecture to the internal representations of LLMs shows both that semantic hierarchy can be learned, and that doing so improves both reconstruction and interpretability. Additionally, the architecture leads to significant improvements in computational efficiency.
Summary
- The paper introduces a hierarchical SAE that combines a top-level encoder with expert-specific low-level autoencoders to enhance feature interpretability.
- It employs a Mixture-of-Experts activation with projection matrices to reduce feature splitting and achieve computational efficiency.
- Empirical results demonstrate improved reconstruction performance and reduced feature absorption compared to standard sparse autoencoders.
This paper introduces a Hierarchical Sparse Autoencoder (H-SAE) architecture designed to improve the interpretability and reconstruction performance of Sparse Autoencoders (SAEs) by explicitly modeling the hierarchical structure inherent in semantic concepts. Standard SAEs learn a flat set of features, which can lead to "feature splitting" (a single concept represented by multiple specialized features) and a trade-off between reconstruction accuracy and feature interpretability. The H-SAE aims to mitigate these issues.
The core idea is inspired by findings that LLMs represent categorical concepts with a parent feature (indicating concept activation) and a low-rank subspace containing child features (specific instances of the concept) (Shafayat et al., 2024). The H-SAE architecture mirrors this structure:
- Top-Level SAE: A standard SAE with a relatively small number of features, designed to capture high-level concepts.
- Projection Matrices: For each feature (expert) in the top-level SAE, there are learnable down-projection ($\Pi_j^{\text{down}}$) and up-projection ($\Pi_j^{\text{up}}$) matrices. These map the input to a lower-dimensional subspace associated with the high-level concept and then back to the original space.
- Low-Level SAEs (Experts): Each high-level feature has an associated low-level SAE that operates on the projected low-dimensional subspace. These low-level SAEs learn finer-grained features (sub-latents) specific to the activated high-level concept.
A key aspect is the Mixture-of-Experts (MoE) style activation: a low-level SAE is only activated if its corresponding high-level feature is among the top-k activated features. This respects the conceptual hierarchy (e.g., "corgi" can only be active if "dog" is active) and significantly improves computational efficiency. The low-level SAEs use a $\text{TopK}_1$ operation, meaning only a single sub-latent is chosen per activated expert.
The forward pass for the H-SAE is given by:
$$\text{H-SAE}(x) = \sum_{j \in \text{TopKIndices}_k} \left( z_j d_j + \Pi_j^{\text{up}} \, \text{SAE}_1^j\!\left(\Pi_j^{\text{down}} x\right) \right)$$
where $x$ is the input, $\text{TopKIndices}_k$ are the indices of the top $k$ activated high-level features, $z_j$ is the activation of the $j$-th high-level feature, $d_j$ is the $j$-th high-level decoder vector (feature), and $\text{SAE}_1^j$ is the expert-specific low-level autoencoder for the $j$-th high-level feature.
The training objective is:
$$L = L_{\text{recon}} + \lambda_1 L_{\text{ortho}} + \lambda_2 L_{\text{sparse}}$$
where:
- $L_{\text{recon}} = \|x - \text{H-SAE}(x)\|_2^2 + \beta \|x - \hat{x}^{\text{high}}\|_2^2$: The reconstruction loss includes a term for the overall reconstruction and a term ($\beta = 0.1$) specifically for the reconstruction from only the top-level SAE, where $\hat{x}^{\text{high}} = D \, \text{TopK}_k(\text{LeakyReLU}_\alpha(E(x - b)))$. This encourages the top-level features to be meaningful on their own.
- $L_{\text{ortho}} = \dfrac{\|ED - \text{diag}(ED)\|_F}{m_{\text{top}}^2 - m_{\text{top}}}$: A bi-orthogonality penalty on the top-level encoder ($E$) and decoder ($D$) matrices to discourage semantic redundancy and reduce dead features. $m_{\text{top}}$ is the number of top-level features.
- $L_{\text{sparse}}$: An $\ell_1$ penalty on latent activations (both top and low-level) outside the top-k to encourage further specialization.
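As a rough illustration of the bi-orthogonality term, the sketch below evaluates the penalty on two toy dictionaries, following the formula as stated in this summary (unsquared Frobenius norm of the off-diagonal part of ED, divided by the number of off-diagonal entries). The function name `ortho_penalty` and the toy matrices are assumptions made for this example and are not taken from the paper's code.

```python
import numpy as np


def ortho_penalty(E, D):
    """Frobenius norm of the off-diagonal part of E @ D, over m_top**2 - m_top.

    E: (m_top, d) encoder matrix, D: (d, m_top) decoder matrix.
    """
    ED = E @ D
    off_diag = ED - np.diag(np.diag(ED))
    m_top = ED.shape[0]
    return np.linalg.norm(off_diag, "fro") / (m_top**2 - m_top)


# Toy case 1: three orthonormal features in d = 4, with E = D^T,
# so E @ D is the identity and nothing lies off the diagonal.
D_ortho = np.eye(4)[:, :3]
E_ortho = D_ortho.T
penalty_ortho = ortho_penalty(E_ortho, D_ortho)

# Toy case 2: feature 1 is a copy of feature 0 (a redundant pair),
# which puts two nonzero entries off the diagonal of E @ D.
D_dup = D_ortho.copy()
D_dup[:, 1] = D_dup[:, 0]
E_dup = D_dup.T
penalty_dup = ortho_penalty(E_dup, D_dup)

print(penalty_ortho, penalty_dup)
```

The orthonormal dictionary incurs no penalty, while the dictionary with a duplicated feature does, which is the sense in which the term discourages redundant top-level features.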
Algorithm 1 details the forward pass and loss computation:
```python
import numpy as np


def get_top_k(activations, k):
    """Return (sparse_activations, indices): the k largest entries kept, all others zeroed."""
    indices = np.argsort(activations)[-k:]
    sparse_activations = np.zeros_like(activations)
    sparse_activations[indices] = activations[indices]
    return sparse_activations, indices


def forward_pass(x, E_top, D_top, top_k, E_experts, D_experts, Pi_down, Pi_up, leaky_relu_alpha):
    # High-level encoding. x is assumed to be already centered, i.e. (x - b).
    z_top_raw = leaky_relu_alpha(E_top @ x)
    z_top, top_k_indices = get_top_k(z_top_raw, top_k)  # TopK_k operation

    # High-level reconstruction from the top-k activations only.
    x_hat_high = D_top @ z_top

    # Low-level reconstruction, accumulated over the active experts.
    x_hat_low = np.zeros_like(x)
    z_low_raw_levels = []  # pre-TopK_1 expert activations, kept for the sparsity loss
    z_low_levels = []      # post-TopK_1 expert activations

    for j in top_k_indices:
        # Project to the subspace of expert j.
        x_sub_j = Pi_down[j] @ x

        # Low-level encoding. SAE^j_1 applies TopK_1, so only the most
        # strongly activated sub-latent of the expert is kept.
        z_low_raw_j = leaky_relu_alpha(E_experts[j] @ x_sub_j)
        z_low_j, _ = get_top_k(z_low_raw_j, 1)
        z_low_raw_levels.append(z_low_raw_j)
        z_low_levels.append(z_low_j)

        # Reconstruct in the subspace, project back, and accumulate.
        x_hat_low += Pi_up[j] @ (D_experts[j] @ z_low_j)

    # Combined reconstruction.
    x_hat = x_hat_high + x_hat_low
    return x_hat, x_hat_high, z_top_raw, z_top, z_low_raw_levels, z_low_levels, top_k_indices


def compute_loss(x, x_hat, x_hat_high, z_top_raw, z_top, z_low_raw_levels, z_low_levels,
                 E_top, D_top, beta, lambda1, lambda2):
    # Reconstruction loss: full reconstruction plus beta times the top-level-only term.
    l_recon_total = np.sum((x - x_hat) ** 2)
    l_recon_top = np.sum((x - x_hat_high) ** 2)
    l_recon = l_recon_total + beta * l_recon_top

    # Sparsity loss: L1 on activations outside the top-k. Each sparse vector
    # equals the raw vector on the kept entries, so (raw - sparse) isolates the rest.
    l_sparse = np.sum(np.abs(z_top_raw - z_top))
    for z_raw_j, z_j in zip(z_low_raw_levels, z_low_levels):
        l_sparse += np.sum(np.abs(z_raw_j - z_j))

    # Orthogonality loss: off-diagonal Frobenius norm of E @ D, as in the
    # L_ortho formula above, divided by the number of off-diagonal entries.
    ED_prod = E_top @ D_top
    off_diag = ED_prod - np.diag(np.diag(ED_prod))
    m_top = D_top.shape[1]  # Number of top-level features
    l_ortho = np.linalg.norm(off_diag, 'fro') / (m_top ** 2 - m_top)

    return l_recon + lambda1 * l_ortho + lambda2 * l_sparse
```
Experimental Setup and Results:
- Data: 1 billion residual stream vectors from layer 20 of Gemma 2-2B, extracted from Wikipedia articles. Vectors are normalized to unit norm.
- Baseline: TopK SAE.
- Reconstruction: H-SAE shows significantly better reconstruction performance than standard SAEs, measured both by the fraction of unexplained variance (1 − explained variance) and by the LLM cross-entropy loss when reconstructed activations are substituted for the originals. For example, an H-SAE with 8k top-level features and 64 sub-latents per expert performs comparably to a standard SAE with 32k features, at one quarter of the compute cost for the top-level.
- Interpretability:
- Qualitative: Visualizations (Figures 1, 2, 3, 7, and 8) show H-SAE learns meaningful hierarchical features (e.g., "marriage" high-level, "divorce," "engagement" low-level; "airports" high-level, "US airport," "airport size" low-level).
- Feature Absorption: H-SAE shows less feature absorption (a general feature failing to fire on inputs where a more specific feature has taken over its role) on the SAEBench first-letter classification task, with a lower "Mean Absorption Fraction Score" than the baseline (Figure 5a).
- Cross-Lingual Redundancy: H-SAE activates more similar sets of features for the same token in different languages (English, French, Spanish, German), indicating less redundancy and better composability (Figure 5b). It achieved lower mean set differences.
- Computational Efficiency: Due to the sparse activation of experts, the H-SAE adds negligible computational overhead compared to a standard SAE with the same number of top-level features, while offering a much larger effective dictionary size ($m_{\text{top}} \times m_{\text{low}}$).
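To make the efficiency argument concrete, here is a back-of-the-envelope count of multiply-accumulate operations per input vector. It is a sketch under stated assumptions rather than the paper's cost model: the function names, the choice to count only matrix-vector products, the model width d = 2304 (taken to be the hidden size of Gemma 2-2B), and reading "8k" and "32k" as 8192 and 32768 are all assumptions for illustration. The values k = 32, subspace dimension s = 8, and 64 sub-latents per expert follow the settings reported in this summary.

```python
def flat_sae_macs(d, m, k):
    """Multiply-accumulates for a flat TopK SAE: dense encoder, sparse decoder."""
    encode = m * d  # E @ x over all m features
    decode = k * d  # weighted sum of the k active decoder vectors
    return encode + decode


def hsae_macs(d, m_top, k, s, m_low):
    """Multiply-accumulates for an H-SAE in which only k experts run per input."""
    top = flat_sae_macs(d, m_top, k)
    per_expert = (
        s * d        # down-projection into the s-dimensional subspace
        + m_low * s  # expert encoder over m_low sub-latents
        + s          # decoding the single TopK_1 sub-latent
        + d * s      # up-projection back to model space
    )
    return top + k * per_expert


# Assumed sizes (see the lead-in for which values are assumptions).
d, k, s, m_low = 2304, 32, 8, 64
m_top = 8192

base = flat_sae_macs(d, m_top, k)
hier = hsae_macs(d, m_top, k, s, m_low)
flat_32k = flat_sae_macs(d, 32768, k)

overhead = hier / base - 1.0    # extra cost contributed by the experts
effective_size = m_top * m_low  # number of (expert, sub-latent) pairs

print(f"expert overhead over a flat SAE of the same width: {overhead:.1%}")
print(f"H-SAE cost relative to a flat 32k SAE: {hier / flat_32k:.2f}")
print(f"effective dictionary size: {effective_size}")
```

Under these assumptions the experts add only a few percent to the cost of the top-level SAE, because the dense encoder over all top-level features dominates and only k small experts run per input.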
Implementation Considerations:
- Implemented in JAX and Equinox.
- Trained with a batch size of 32,512 and top-k of 32 for high-level features.
- Subspace dimension (s) was 4 for 16 sub-latents per expert, and 8 otherwise.
- $\lambda_1$ (orthogonality) = 0.1, $\beta$ (top-level recon) = 0.1, $\lambda_2$ (L1 sparsity) = 0.001.
- Adam optimizer, learning rate $5 \cdot 10^{-4}$ with warmup and cosine decay.
- Ablation studies on token unembeddings (Appendix B) suggest that whitening the input (multiplying by the inverse square root of the covariance matrix) is crucial for learning meaningful features in that context, aligning with the concept of a "causal inner product." The orthogonality and $\ell_1$ regularizers were not strictly necessary for interpretability in these ablations but were kept for practical benefits like reducing dead latents.
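The whitening step mentioned in the ablation bullet can be sketched as follows. This is a minimal NumPy illustration, not the paper's implementation: it assumes the covariance is estimated from the data matrix itself, that the rows are centered first, and that a small `eps` is added to the eigenvalues for numerical stability. The function name `whiten`, the `eps` value, and the synthetic data are all hypothetical.

```python
import numpy as np


def whiten(X, eps=1e-8):
    """Center the rows of X and multiply by the inverse square root of their covariance."""
    X_centered = X - X.mean(axis=0, keepdims=True)
    cov = X_centered.T @ X_centered / (X.shape[0] - 1)
    # The covariance is symmetric, so eigh gives real eigenvalues and
    # orthonormal eigenvectors: cov^(-1/2) = V diag(1 / sqrt(w)) V^T.
    eigvals, eigvecs = np.linalg.eigh(cov)
    inv_sqrt = eigvecs @ np.diag(1.0 / np.sqrt(eigvals + eps)) @ eigvecs.T
    return X_centered @ inv_sqrt


# Synthetic correlated data: Gaussian noise mixed by a fixed, well-conditioned matrix.
rng = np.random.default_rng(0)
A = np.diag([1.0, 2.0, 3.0, 4.0, 5.0]) + 0.5
X = rng.normal(size=(2000, 5)) @ A

X_white = whiten(X)
cov_white = np.cov(X_white, rowvar=False)  # close to the identity matrix
```

After whitening, the Euclidean inner product between transformed vectors corresponds to an inner product weighted by the inverse covariance in the original space, which is the connection to the "causal inner product" mentioned above.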
Limitations and Future Work:
- The results are improved but imperfect: some hard-to-interpret features remain, and reconstruction is still lossy.
- The paper suggests exploring non-Euclidean reconstruction objectives or more sophisticated objectives from causal representation learning.
In summary, the H-SAE architecture provides a practical method to improve SAEs by incorporating semantic hierarchy. This leads to better reconstruction, improved interpretability (less feature splitting/absorption, more composable features), and significant computational efficiency, allowing for effectively larger and more fine-grained dictionaries.
Authors (4)