Toward generalizable learning of all (linear) first-order methods via memory augmented Transformers
Abstract: We show that memory-augmented Transformers can implement the entire class of linear first-order methods (LFOMs), a class that contains gradient descent (GD) as well as more advanced methods such as conjugate gradient descent (CGD), momentum methods, and all other variants that linearly combine past gradients. Building on prior work that studies how Transformers simulate GD, we provide theoretical and empirical evidence that memory-augmented Transformers can learn these more advanced algorithms. We then take a first step toward turning the learned algorithms into practical methods by developing a mixture-of-experts (MoE) approach for test-time adaptation to out-of-distribution (OOD) samples. Lastly, we show that LFOMs can themselves be treated as learnable algorithms whose parameters can be learned from data to attain strong performance.
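As a rough illustration of "methods that linearly combine past gradients", the sketch below assumes one common parameterization, x_{k+1} = x_k - sum_{i<=k} c(k, i) * grad f(x_i), which is not necessarily the exact definition used in the paper. The function names `lfom` and `heavy_ball`, the coefficient functions, and the diagonal quadratic test problem are all hypothetical choices made for this example. GD corresponds to c(k, k) = eta with all other coefficients zero; heavy-ball momentum unrolls to c(k, i) = eta * beta^(k - i).

```python
from typing import Callable, List

Vector = List[float]


def lfom(grad: Callable[[Vector], Vector], x0: Vector,
         coeff: Callable[[int, int], float], steps: int) -> List[Vector]:
    """Hypothetical LFOM: x_{k+1} = x_k - sum_{i<=k} coeff(k, i) * grad(x_i)."""
    xs = [list(x0)]
    grads: List[Vector] = []  # every past gradient is kept (the "memory")
    for k in range(steps):
        grads.append(grad(xs[k]))
        step = [0.0] * len(x0)
        for i, g in enumerate(grads):  # linear combination of past gradients
            c = coeff(k, i)
            for j, gj in enumerate(g):
                step[j] += c * gj
        xs.append([xj - sj for xj, sj in zip(xs[k], step)])
    return xs


def heavy_ball(grad: Callable[[Vector], Vector], x0: Vector,
               eta: float, beta: float, steps: int) -> List[Vector]:
    """Direct heavy-ball update: x_{k+1} = x_k - eta*g_k + beta*(x_k - x_{k-1})."""
    xs = [list(x0)]
    prev = list(x0)  # x_{-1} = x_0, so the first step is a plain GD step
    for k in range(steps):
        x = xs[k]
        g = grad(x)
        xs.append([xj - eta * gj + beta * (xj - pj)
                   for xj, gj, pj in zip(x, g, prev)])
        prev = x
    return xs


# Assumed test problem: f(x) = 0.5 * sum_j A[j] * x[j]^2, so grad_j = A[j] * x[j].
A = [1.0, 10.0]


def quad_grad(x: Vector) -> Vector:
    return [a * xi for a, xi in zip(A, x)]


eta, beta, steps = 0.05, 0.5, 20
x0 = [1.0, 1.0]

xs_gd = lfom(quad_grad, x0, lambda k, i: eta if i == k else 0.0, steps)
xs_lfom_hb = lfom(quad_grad, x0, lambda k, i: eta * beta ** (k - i), steps)
xs_hb = heavy_ball(quad_grad, x0, eta, beta, steps)
```

In this sketch, the only thing that distinguishes GD from momentum is the coefficient function, which is what makes it natural to treat the coefficients themselves as learnable parameters.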