Taylor-Lagrange Neural ODE Solver
- TL-NODE is a neural ODE solver that leverages fixed-order Taylor expansion and a Lagrange remainder estimator for efficient integration and training.
- It minimizes computational overhead by reducing the number of adaptive function evaluations while preserving accuracy in both supervised and generative tasks.
- Empirical results show significant speedups in training and evaluation, making TL-NODE suitable for real-time and large-scale applications.
Taylor-Lagrange Neural Ordinary Differential Equations (TL-NODEs) are a class of neural ODE solvers that combine fixed-order Taylor expansions with a data-driven Lagrange remainder estimator to accelerate the integration and training of neural ODEs. TL-NODE addresses the computational bottlenecks of standard NODE training and evaluation, especially the high cost imposed by adaptive-step solvers and repeated neural network evaluations, while maintaining or improving accuracy across supervised and generative modeling tasks (Djeumou et al., 2022).
1. Motivation and Foundational Concepts
A standard neural ordinary differential equation (NODE) parametrizes continuous-time dynamics by a neural network:

$$\dot{x}(t) = f_\theta(x(t), t), \qquad x(t_0) = x_0,$$

where $x(t) \in \mathbb{R}^n$ and $f_\theta$ is a neural network with parameters $\theta$. Solving for $x(T)$ given $x_0$ typically requires numerically integrating $f_\theta$, often with adaptive schemes (e.g., Dormand–Prince “Dopri5”). These schemes provide accuracy but do so at the cost of numerous evaluations of $f_\theta$ per integration interval, leading to high compute and memory cost, especially with gradient-based training methods that require both forward and backward passes through the neural dynamics. This bottleneck becomes acute for large-scale learning or deployment settings where fast inference is critical.
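To make this cost concrete, a toy sketch (weights, sizes, and step counts here are illustrative, not from the paper) that counts function evaluations when a small tanh network is integrated with fixed-step RK4:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy "neural network" dynamics f_theta(x): one hidden tanh layer.
W1 = rng.normal(scale=0.5, size=(8, 2))
W2 = rng.normal(scale=0.5, size=(2, 8))

n_evals = 0  # number of function evaluations (NFE)

def f_theta(x):
    global n_evals
    n_evals += 1
    return W2 @ np.tanh(W1 @ x)

def rk4_step(x, dt):
    """One classical RK4 step: 4 evaluations of f_theta."""
    k1 = f_theta(x)
    k2 = f_theta(x + 0.5 * dt * k1)
    k3 = f_theta(x + 0.5 * dt * k2)
    k4 = f_theta(x + dt * k3)
    return x + dt / 6.0 * (k1 + 2 * k2 + 2 * k3 + k4)

x = np.array([1.0, -1.0])
for _ in range(50):          # 50 steps over one horizon
    x = rk4_step(x, 0.02)

print(n_evals)  # 200 evaluations for a single forward solve
```

Adaptive solvers such as Dopri5 evaluate $f_\theta$ even more often when tolerances are tight, and training multiplies this cost across every batch and epoch.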
TL-NODE mitigates this by replacing the adaptive solver with a fixed-order Taylor expansion plus an estimated Lagrange remainder, allowing for a constant and low number of network evaluations per step. The Taylor expansion advances the solution deterministically, while a small auxiliary neural network estimates and corrects for the truncation error, preserving the desired accuracy in only a few function and derivative evaluations per step.
2. Mathematical Formulation
The central update of TL-NODE approximates the flow of the ODE as:

$$x(t + \Delta t) = x(t) + \sum_{\ell=1}^{p-1} \frac{\Delta t^\ell}{\ell!} f^{[\ell]}(x(t)) + R_p(t, \Delta t),$$

where $f^{[\ell]}$ denotes the $\ell$-th total time derivative of the solution, and $R_p$ is the (unknown) Taylor-Lagrange remainder encapsulating the local truncation error (of order $\Delta t^p$).

To estimate the remainder, TL-NODE introduces a second neural network $g(\cdot;\phi)$, which predicts the appropriate “midpoint” state $\hat{x}_{\mathrm{mid}}$, lying between $x(t)$ and $x(t + \Delta t)$, where the $p$-th derivative should be evaluated for optimal local error correction. The corrected update then reads:

$$x(t + \Delta t) \approx x(t) + \sum_{\ell=1}^{p-1} \frac{\Delta t^\ell}{\ell!} f^{[\ell]}(x(t)) + \frac{\Delta t^p}{p!} f^{[p]}(\hat{x}_{\mathrm{mid}}),$$

where $\hat{x}_{\mathrm{mid}} = x(t) + g(x(t), \Delta t; \phi) \odot f^{[1]}(x(t))$ and each $f^{[\ell]}$ is computed via Taylor-mode automatic differentiation.
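As a worked scalar example (illustrative, not from the paper): for $\dot{x} = -x$ every time derivative is available in closed form, $f^{[\ell]}(x) = (-1)^\ell x$, so the effect of the remainder term can be checked directly against the exact flow $x(t+\Delta t) = e^{-\Delta t} x(t)$:

```python
import math

def taylor_step(x, dt, p):
    """Truncated Taylor update through order p-1 (no remainder term)."""
    return sum(dt**l / math.factorial(l) * (-1)**l * x for l in range(p))

def tl_step(x, dt, p, x_mid):
    """Taylor-Lagrange update: truncated sum plus the p-th order
    remainder with the derivative evaluated at a 'midpoint' state."""
    return taylor_step(x, dt, p) + dt**p / math.factorial(p) * (-1)**p * x_mid

x0, dt, p = 1.0, 0.5, 3
exact = math.exp(-dt) * x0

err_truncated = abs(taylor_step(x0, dt, p) - exact)
# The mean-value theorem guarantees some state between x(t) and
# x(t + dt) makes the remainder exact; here we probe an interior state.
err_corrected = abs(tl_step(x0, dt, p, 0.9 * x0) - exact)
print(err_truncated, err_corrected)
```

Even a rough interior state sharply reduces the local error; the learned network $g(\cdot;\phi)$ exists to pick that state well across the state space.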
The joint training objective alternates between fitting the main NODE parameters $\theta$ (using a standard supervised or likelihood loss, with a penalty on large higher-order derivatives for regularization) and fitting the remainder network parameters $\phi$ (by minimizing the squared error of the corrected update against high-accuracy solutions produced by a standard ODE solver).
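The remainder-fitting step can be sketched in the same closed-form scalar setting $\dot{x} = -x$; a single scalar parameter `tau` stands in for the network $g(\cdot;\phi)$, and a grid search stands in for gradient descent (all of this is an illustrative simplification, not the paper's training loop):

```python
import math

def tl_step(x, dt, p, tau):
    # Truncated Taylor sum plus the tau-parametrized remainder term.
    trunc = sum(dt**l / math.factorial(l) * (-1)**l * x for l in range(p))
    x_mid = x + tau * dt * (-x)                # midpoint prediction
    return trunc + dt**p / math.factorial(p) * (-1)**p * x_mid

def remainder_loss(tau, dt=0.5, p=3, xs=(0.5, 1.0, 2.0)):
    # Squared error against the exact ("high-accuracy solver") flow.
    return sum((tl_step(x, dt, p, tau) - math.exp(-dt) * x)**2 for x in xs)

# Gradient-free fit over a grid (stand-in for SGD on phi).
taus = [i / 100 for i in range(101)]
best_tau = min(taus, key=remainder_loss)
print(best_tau, remainder_loss(best_tau))
```

The fitted `tau` places the $p$-th derivative evaluation at the state that reproduces the exact Lagrange remainder for this dynamics.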
3. Training Procedure and Integration Algorithm
The TL-NODE integration pipeline partitions the interval $[t_0, T]$ into $H$ steps of size $\Delta t = (T - t_0)/H$. At each step, the procedure is:
- Compute Taylor coefficients $f^{[1]}, \dots, f^{[p]}$ at the current state $x$ via Taylor-mode automatic differentiation.
- Use $g(\cdot;\phi)$ to predict an in-state midpoint $\hat{x}_{\mathrm{mid}}$.
- Update with the truncated Taylor sum (up to order $p-1$) plus the $p$-th order term (using $f^{[p]}(\hat{x}_{\mathrm{mid}})$).
All operations are differentiable; standard backpropagation suffices, and there is no need for the adjoint-state method typical in standard neural ODEs. Memory usage is limited to storing relevant states and model parameters.
TL-NODE Forward Pass Pseudocode
```
function TL-NODE-Solve(x₀, t₀, T; θ, φ, p, H)
    Δt ← (T − t₀)/H
    x ← x₀
    for i = 0 to H−1 do
        {f^[1], …, f^[p]} ← TaylorModeAD(f_θ, x)
        x̂_mid ← x + g(x, Δt; φ) ⊙ f^[1]
        x ← x + Σ_{ℓ=1}^{p−1} (Δt^ℓ/ℓ!) · f^[ℓ] + (Δt^p/p!) · f^[p](x̂_mid)
    return x
```
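A runnable sketch of this loop for linear dynamics $\dot{x} = Ax$, where the Taylor coefficients have the closed form $f^{[\ell]}(x) = A^\ell x$; the midpoint predictor is replaced by a fixed, untrained stand-in, and the system matrix is illustrative:

```python
import math
import numpy as np

A = np.array([[0.0, 1.0], [-1.0, -0.1]])  # lightly damped oscillator

def taylor_coeffs(x, p):
    # f^[1..p] = A^l x for linear dynamics; this stands in for the
    # TaylorModeAD call in the pseudocode above.
    coeffs, v = [], x
    for _ in range(p):
        v = A @ v
        coeffs.append(v)
    return coeffs

def tl_node_solve(x0, t0, T, p=4, H=10, g_scale=0.5):
    dt = (T - t0) / H
    x = x0
    for _ in range(H):
        f = taylor_coeffs(x, p)
        # Stand-in for the learned midpoint predictor g(x, dt; phi):
        x_mid = x + g_scale * dt * f[0]
        x_next = x + sum(dt**l / math.factorial(l) * f[l - 1]
                         for l in range(1, p))
        x_next = x_next + dt**p / math.factorial(p) * taylor_coeffs(x_mid, p)[-1]
        x = x_next
    return x

def expm_series(M, terms=30):
    # Reference matrix exponential via its power series.
    S, term = np.eye(M.shape[0]), np.eye(M.shape[0])
    for k in range(1, terms + 1):
        term = term @ M / k
        S = S + term
    return S

x0 = np.array([1.0, 0.0])
approx = tl_node_solve(x0, 0.0, 1.0)
exact = expm_series(A) @ x0
print(np.linalg.norm(approx - exact))
```

Each step uses exactly one Taylor-coefficient pass at $x$ plus one $p$-th-derivative evaluation at the midpoint, matching the constant per-step cost described above.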
4. Computational Complexity
In contrast to conventional ODE solvers, TL-NODE uses a fixed and small number of network evaluations: the Taylor coefficients up to order $p$ are obtained in a single Taylor-mode AD pass per step, versus an input-dependent and often large number of function evaluations for adaptive integrators. The cost of that pass grows roughly quadratically in $p$ (or better with optimized Taylor-mode implementations), which is negligible for the small $p$ and practical network sizes used in practice. Memory overhead is similarly reduced; there is no need to store augmented continuous states or solver internals.
In summary, per time step:
| Method | Time Complexity | Memory Overhead |
|---|---|---|
| Standard (Dopri5) | $O(\mathrm{NFE} \cdot C_f)$ | $O(\mathrm{NFE})$ states (reverse-mode) |
| TL-NODE | $O(C_{\mathrm{AD}})$ per step | $O(1)$ beyond parameters |

Here $C_f$ is the cost of one $f_\theta$-evaluation, $C_{\mathrm{AD}}$ is the cost of a single order-$p$ Taylor-mode automatic differentiation pass, and $\mathrm{NFE}$ is the solver's number of function evaluations.
5. Empirical Results
TL-NODE was benchmarked against standard NODE solvers (Dopri5, RK4), fixed-order Taylor methods without the Lagrange correction, and hypersolver schemes on a range of tasks:
Stiff ODE Integration
- System: a stiff linear ODE with widely separated eigenvalues, integrated with one TL-NODE step per interval.
- TL-NODE: accuracy on par with the adaptive reference at a small, fixed evaluation cost.
- Dopri5: comparable error, but at $0.004$ s evaluation time, markedly slower than TL-NODE.
- RK4 and other fixed-step explicit methods: large errors (failed for long horizons).
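The failure mode of explicit fixed-step methods on stiff systems can be reproduced with a small self-contained experiment (the diagonal system below is illustrative, not the paper's benchmark):

```python
import numpy as np

# Stiff diagonal system: eigenvalues -1 and -1000 (stiffness ratio 1000).
lam = np.array([-1.0, -1000.0])

def f(x):
    return lam * x

def rk4_step(x, dt):
    # Classical explicit RK4; stable only when dt * lambda stays small.
    k1 = f(x)
    k2 = f(x + 0.5 * dt * k1)
    k3 = f(x + 0.5 * dt * k2)
    k4 = f(x + dt * k3)
    return x + dt / 6.0 * (k1 + 2 * k2 + 2 * k3 + k4)

dt, steps = 0.01, 50              # dt * (-1000) = -10: outside stability
x_rk4 = np.array([1.0, 1.0])
for _ in range(steps):
    x_rk4 = rk4_step(x_rk4, dt)

x_exact = np.exp(lam * dt * steps)  # exact flow of the diagonal system
print(x_rk4[1], x_exact[1])         # fast mode: RK4 explodes, truth decays
```

The slow mode stays accurate while the fast mode diverges; an adaptive solver avoids this only by shrinking its step, which is exactly where the function-evaluation count explodes.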
Learning Stiff Dynamics
- System: a 2D stiff linear ODE.
- TL-NODE: low MSE with a training time of $31.9$ s.
- Vanilla NODE (Dopri5): matched MSE, but a $609.8$ s training time (roughly $19\times$ slower).
- RK4 NODE and T-NODE (Taylor expansion without the Lagrange correction): similar or slightly faster training than TL-NODE, but with higher error.
Image Classification (MNIST)
- Model: 2-layer NODE (100→728 hidden units).
- Results:
| Method | Train Acc | Test Acc | Train Time | Eval Time | NFE |
|---|---|---|---|---|---|
| Vanilla NODE | 99.33% | 97.87% | 42.7 min | 16 ms | 110.6 |
| TL-NODE | 99.96% | 98.23% | 2.55 min | 1.04 ms | 62 |
TL-NODE achieves faster training, faster evaluation, and greater accuracy with lower function-evaluation counts.
Density Estimation (MiniBooNE)
- 43-dim continuous normalizing flow task.
- TL-NODE achieved the best loss of $9.62$ nats ($12.3$ min training, NFE $= 168$), versus $9.74$ nats for the vanilla NODE ($59.7$ min, NFE $= 184$).
6. Limitations and Prospective Developments
Current TL-NODE instantiations operate at a fixed Taylor expansion order $p$ and a fixed integration step size $\Delta t$. For highly stiff or long-horizon systems, a higher $p$ or smaller $\Delta t$ may be necessary to maintain accuracy, reducing the computational gains.
Envisaged future work includes:
- Adaptive order/step selection: dynamic adjustment of $p$ and/or $\Delta t$ based on local error criteria.
- Stiffness-aware extensions: Coupling with implicit Taylor methods or symplectic constraints for better performance on energy-conserving or stiff systems.
- Parallel higher-order AD: Optimizing Taylor-mode AD for larger expansion orders using efficient hardware (e.g., GPUs).
7. Significance
TL-NODE replaces expensive, black-box adaptive solvers with a hybrid of fixed-order Taylor expansion and data-driven Lagrange remainder, achieving up to one order-of-magnitude speedup for training and inference without loss of accuracy. This enables deployment of neural ODEs in real-time and large-scale scenarios across diverse domains, including modeling physical systems, supervised learning, and generative modeling (Djeumou et al., 2022).