LSTM-TD3: RL Agent for POMDP Challenges
- LSTM-TD3 is a reinforcement learning method that augments standard TD3 with an LSTM memory module to recover latent state information in POMDPs.
- It integrates actor and critic architectures with dedicated LSTM subnetworks that capture temporal dependencies from action-observation histories.
- Empirical results on PyBulletGym benchmarks show that LSTM-TD3 significantly outperforms standard TD3 and windowing methods in environments with noisy and missing observations.
The Twin Delayed Deep Deterministic Policy Gradient with Long Short-Term Memory (LSTM-TD3) agent is a reinforcement learning algorithm that augments the standard TD3 architecture with explicit memory integration via an LSTM, targeting the resolution of Partially Observable Markov Decision Processes (POMDPs). In POMDPs, the observable agent input at each timestep provides only a partial and potentially noisy view of the true system state. LSTM-TD3 introduces a learned memory subsystem to extract temporal dependencies and reconstruct latent states, thus enabling the agent to perform robustly in real-world scenarios where missing or corrupted sensory input is common (Meng et al., 2021).
1. Network Architecture and Memory Integration
LSTM-TD3 extends the canonical TD3 actor–critic framework by integrating a memory-extraction LSTM subnetwork into both the actor and each critic, which operate as follows:
- Actor ($\mu$):
  - Receives a length-$l$ history $h_t^l = (o_{t-l}, a_{t-l}, \dots, o_{t-1}, a_{t-1})$, processed by an LSTM memory-extraction subnetwork to yield a memory vector $m_t^\mu$.
  - The current observation $o_t$ is embedded via a compact MLP ("current-feature extractor") to yield $f_t^\mu$.
  - The concatenated vector $[m_t^\mu; f_t^\mu]$ is passed through an MLP ("perception integration"), and the output specifies the continuous action $a_t$.
- Critics ($Q_1$, $Q_2$):
  - Use a parallel LSTM structure (distinct weights from the actor) to process $h_t^l$, resulting in $m_t^Q$.
  - The pair $(o_t, a_t)$ is projected by an MLP ("current-feature extractor") to $f_t^Q$.
  - The concatenation $[m_t^Q; f_t^Q]$ feeds into a final MLP ("perception integration") producing the Q-value estimate $Q_j(o_t, a_t, h_t^l)$.
Both actor and critics employ two-layer ReLU-activated MLPs analogous in size to standard TD3 (e.g., 256–256 units) and an LSTM cell size of 128.
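The three-subnetwork split (memory extraction, current-feature extraction, perception integration) can be sketched in plain Python. This is a toy illustration with scalar-free list math, random weights, and made-up dimensions (`OBS`, `ACT`, `MEM`, `FEAT`), not the paper's implementation; the LSTM that produces the memory vector is stubbed out as a zero vector here.

```python
import math
import random

random.seed(0)

def linear(x, w, b):
    """Dense layer: w is out_dim x in_dim, b has length out_dim."""
    return [sum(wi * xi for wi, xi in zip(row, x)) + bi
            for row, bi in zip(w, b)]

def rand_layer(out_dim, in_dim):
    w = [[random.uniform(-0.1, 0.1) for _ in range(in_dim)]
         for _ in range(out_dim)]
    return w, [0.0] * out_dim

OBS, ACT, MEM, FEAT = 4, 2, 8, 8          # toy dimensions

w_cf, b_cf = rand_layer(FEAT, OBS)        # current-feature extractor
w_pi, b_pi = rand_layer(ACT, MEM + FEAT)  # perception-integration head

def actor_forward(obs, memory_vec):
    """Fuse the LSTM memory vector with current-observation features."""
    feat = [math.tanh(v) for v in linear(obs, w_cf, b_cf)]
    fused = memory_vec + feat             # concatenation [m; f]
    # tanh bounds the continuous action to [-1, 1]
    return [math.tanh(v) for v in linear(fused, w_pi, b_pi)]

obs = [0.5, -0.2, 0.1, 0.0]
memory = [0.0] * MEM                      # stand-in for the LSTM output
action = actor_forward(obs, memory)
assert len(action) == ACT and all(-1.0 <= a <= 1.0 for a in action)
```

The critics follow the same pattern, except the current-feature extractor receives the pair $(o_t, a_t)$ and the head outputs a single scalar Q-value.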
2. Mathematical Formulation of POMDPs
The environment is formalized as a tuple $(\mathcal{S}, \mathcal{A}, \mathcal{O}, P, \Omega, r, \gamma)$, with latent state $s_t \in \mathcal{S}$, action $a_t \in \mathcal{A}$ (continuous), and partial observation $o_t \in \mathcal{O}$. Transitions follow $s_{t+1} \sim P(\cdot \mid s_t, a_t)$ and $o_t \sim \Omega(\cdot \mid s_t)$. The policy receives the $l$-step history $h_t^l = (o_{t-l}, a_{t-l}, \dots, o_{t-1}, a_{t-1})$ (filled with dummy zero entries for $t < l$), and maximizes the discounted reward expectation:

$$J(\mu) = \mathbb{E}\left[\sum_{t=0}^{\infty} \gamma^t\, r(s_t, a_t)\right].$$
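The zero-padded history construction can be made concrete in a short sketch (`build_history` is an illustrative helper, not from the paper):

```python
def build_history(obs_seq, act_seq, t, l, obs_dim, act_dim):
    """l-step action-observation history ending just before time t,
    zero-padded at the front when fewer than l real steps exist (t < l)."""
    hist = []
    for k in range(t - l, t):
        if k < 0:
            hist.append(([0.0] * obs_dim, [0.0] * act_dim))  # dummy entry
        else:
            hist.append((obs_seq[k], act_seq[k]))
    return hist

obs_seq = [[float(i)] for i in range(3)]   # o_0, o_1, o_2 (obs_dim = 1)
act_seq = [[float(-i)] for i in range(3)]  # a_0, a_1, a_2 (act_dim = 1)

h = build_history(obs_seq, act_seq, t=2, l=5, obs_dim=1, act_dim=1)
assert len(h) == 5
assert h[0] == ([0.0], [0.0])      # padded dummy entry
assert h[-1] == ([1.0], [-1.0])    # most recent pair (o_1, a_1)
```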
3. Forward Pass and Actor-Critic Computation
Let $x_k = (o_k, a_k)$ denote the $k$-th history entry fed to the critic LSTM, and likewise for the actor. The LSTM processes the input recursively via the standard gates:

$$\begin{aligned}
i_k &= \sigma(W_i x_k + U_i \eta_{k-1} + b_i), \\
f_k &= \sigma(W_f x_k + U_f \eta_{k-1} + b_f), \\
g_k &= \tanh(W_g x_k + U_g \eta_{k-1} + b_g), \\
u_k &= \sigma(W_u x_k + U_u \eta_{k-1} + b_u), \\
c_k &= f_k \odot c_{k-1} + i_k \odot g_k, \\
\eta_k &= u_k \odot \tanh(c_k).
\end{aligned}$$

The memory vector is the final hidden state, $m_t = \eta_{t-1}$. The actor outputs $a_t = \mu(o_t, h_t^l)$; the critics yield $Q_j(o_t, a_t, h_t^l)$ for $j \in \{1, 2\}$.
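A single recurrence step of these gate equations, reduced to scalar toy dimensions so every gate is visible, might look as follows (the weights `W` are arbitrary placeholders):

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def lstm_step(x, h_prev, c_prev, W):
    """One LSTM step over scalar input/hidden state (toy sizes),
    following the standard gate equations.
    W maps each gate name to a (w_x, w_h, b) triple."""
    i = sigmoid(W['i'][0] * x + W['i'][1] * h_prev + W['i'][2])    # input gate
    f = sigmoid(W['f'][0] * x + W['f'][1] * h_prev + W['f'][2])    # forget gate
    g = math.tanh(W['g'][0] * x + W['g'][1] * h_prev + W['g'][2])  # candidate
    u = sigmoid(W['u'][0] * x + W['u'][1] * h_prev + W['u'][2])    # output gate
    c = f * c_prev + i * g
    h = u * math.tanh(c)
    return h, c

W = {k: (0.5, 0.5, 0.0) for k in 'ifgu'}
h, c = 0.0, 0.0
for x in [1.0, -0.5, 0.25]:   # a short history, flattened to scalars
    h, c = lstm_step(x, h, c, W)
# h now plays the role of the memory vector m_t extracted from the history
assert -1.0 < h < 1.0
```

Since $h = u \cdot \tanh(c)$ with $u \in (0, 1)$, the extracted memory component is always bounded in $(-1, 1)$.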
4. Optimization Procedures and Update Mechanisms
- Critic Loss: Each $Q_j$ minimizes the MSE to the clipped double-Q target,

$$L_j = \frac{1}{N} \sum_{i=1}^{N} \big( Q_j(o^i, a^i, h^i) - y^i \big)^2, \qquad y = r + \gamma (1 - d) \min_{k=1,2} Q_k^-\big(o', \tilde{a}', h'^l\big),$$

where $\tilde{a}' = \mu^-(o', h'^l) + \epsilon$ with $\epsilon \sim \mathrm{clip}\big(\mathcal{N}(0, \tilde{\sigma}), -c, c\big)$ (target policy smoothing).
- Actor Loss: The policy update maximizes $Q_1$ along the deterministic policy gradient, i.e.

$$L^\mu = -\frac{1}{N} \sum_{i=1}^{N} Q_1\big(o^i, \mu(o^i, h^i), h^i\big).$$
- Delayed Policy Update and Target Networks: As in TD3, the policy (and target networks) are updated every $d$ steps (commonly $d = 2$). Target networks undergo soft updates:

$$\theta^- \leftarrow \tau \theta + (1 - \tau)\, \theta^-,$$

with $\tau \ll 1$ (e.g., $\tau = 0.005$).
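The three update rules above fit in a few lines of plain Python; this is a sketch of the arithmetic only (the helper names and default constants are illustrative, matching standard TD3 conventions):

```python
import random

def td3_target(r, done, q1_targ, q2_targ, gamma=0.99):
    """Clipped double-Q target: r + gamma * (1 - d) * min(Q1-, Q2-)."""
    return r + gamma * (1.0 - done) * min(q1_targ, q2_targ)

def smoothed_target_action(mu, sigma=0.2, clip=0.5, low=-1.0, high=1.0):
    """Target-policy smoothing: clipped Gaussian noise on the target action."""
    eps = max(-clip, min(clip, random.gauss(0.0, sigma)))
    return max(low, min(high, mu + eps))

def soft_update(theta, theta_targ, tau=0.005):
    """Polyak averaging of (flattened) parameter lists."""
    return [tau * p + (1.0 - tau) * pt for p, pt in zip(theta, theta_targ)]

y = td3_target(r=1.0, done=0.0, q1_targ=10.0, q2_targ=12.0)
assert abs(y - (1.0 + 0.99 * 10.0)) < 1e-9   # min over critics picks 10.0

a_targ = smoothed_target_action(0.9)
assert -1.0 <= a_targ <= 1.0                 # noise is clipped, action bounded

params = soft_update([1.0, 2.0], [0.0, 0.0], tau=0.5)
assert params == [0.5, 1.0]
```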
5. Training Algorithm Pseudocode
Initialize θ^{Q₁},θ^{Q₂},θ^μ randomly
θ^{Q₁⁻}←θ^{Q₁}, θ^{Q₂⁻}←θ^{Q₂}, θ^{μ⁻}←θ^μ
Replay buffer D ← ∅
h₁ˡ ← zeros (dummy history), o₁ ← env.reset()
for t=1…T do
aₜ ← μ_{θ^μ}(oₜ, hₜˡ) + ϵ, ϵ∼N(0,σ)
observe rₜ,oₜ₊₁,dₜ ← env.step(aₜ)
store (oₜ,aₜ,rₜ,oₜ₊₁,dₜ) in D
if dₜ then
hₜ₊₁ˡ←zeros, oₜ₊₁←env.reset()
else
hₜ₊₁ˡ ← (hₜˡ minus oldest (o,a)) ∪ (oₜ,aₜ)
end
if |D|>batch_size then
sample N tuples and their histories {hᶦ, oᶦ, aᶦ, rᶦ, o'ᶦ, hᶦ_next, dᶦ}ᵢ₌₁ᴺ
for j in {1,2}:
Ŷᶦ ← rᶦ + γ(1−dᶦ)·min_k Q_k⁻(o'ᶦ, μ⁻(o'ᶦ,hᶦ_next)+ϵ̃, hᶦ_next),  ϵ̃ ∼ clip(N(0,σ̃),−c,c)
Lⱼ ← (1/N)∑ᶦ [Qⱼ(oᶦ,aᶦ,hᶦ) − Ŷᶦ]²
θ^{Qⱼ} ← Adam(∇_{θ^{Qⱼ}} Lⱼ)
end
if t mod d_μ ==0 then
L^μ ← −(1/N)∑ᶦ Q₁(oᶦ, μ(oᶦ,hᶦ),hᶦ)
θ^μ ← Adam(∇_{θ^μ} L^μ)
for j in {1,2}:
θ^{Qⱼ⁻}←τθ^{Qⱼ}+(1−τ)θ^{Qⱼ⁻}
end
θ^{μ⁻}←τθ^μ+(1−τ)θ^{μ⁻}
end
end
end
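Steps 7 and 14 of the loop above need a buffer that can hand back each transition together with its preceding $l$-step history. A minimal sketch of that bookkeeping (the class and its API are illustrative; the paper's buffer additionally avoids mixing histories across episode boundaries, which is omitted here):

```python
import random
from collections import deque

class HistoryReplayBuffer:
    """Stores flat transitions and reconstructs the l-step (o, a)
    history for each sampled index, zero-padding at the front."""

    def __init__(self, capacity, hist_len, obs_dim, act_dim):
        self.data = deque(maxlen=capacity)
        self.l = hist_len
        self.pad = ([0.0] * obs_dim, [0.0] * act_dim)

    def store(self, o, a, r, o_next, d):
        self.data.append((o, a, r, o_next, d))

    def history(self, idx):
        """(o, a) pairs preceding transition idx, zero-padded to length l."""
        start = max(0, idx - self.l)
        hist = [(t[0], t[1]) for t in list(self.data)[start:idx]]
        return [self.pad] * (self.l - len(hist)) + hist

    def sample(self, batch_size):
        idxs = random.sample(range(len(self.data)), batch_size)
        batch = list(self.data)
        return [(self.history(i),) + batch[i] for i in idxs]

buf = HistoryReplayBuffer(capacity=100, hist_len=3, obs_dim=1, act_dim=1)
for t in range(10):
    buf.store([float(t)], [0.0], 1.0, [float(t + 1)], 0.0)
sample = buf.sample(4)
assert len(sample) == 4
assert all(len(s[0]) == 3 for s in sample)   # every history has length l
```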
6. Hyperparameters and Memory Ablation
Principal hyperparameters include:
- History length $l = 5$ (shorter and longer histories additionally tested)
- Replay buffer size: $10^6$
- Batch size $N = 100$
- Discount factor $\gamma = 0.99$
- Policy noise $\tilde{\sigma} = 0.2$, noise clip $c = 0.5$
- Policy delay $d = 2$
- Target network update $\tau = 0.005$
- Actor/critic learning rates: $10^{-3}$ (Adam)
- MLP architecture: two hidden layers of 256 units (ReLU)
- LSTM hidden size: $128$
Ablation studies reveal:
- Removing the double-critic structure destabilizes learning (yielding LSTM-DDPG/RDPG).
- Omitting target policy smoothing produces a milder performance drop.
- Excluding the current-feature extractor degrades MDP performance severely.
- Removing past-action inputs from the history significantly impairs POMDP handling; both actor and critic require both past observations $o$ and past actions $a$ in their respective histories.
7. Empirical Evaluation and Baseline Comparisons
LSTM-TD3 was evaluated on five PyBulletGym benchmarks: HalfCheetah, Ant, Walker2D, Hopper, and InvertedDoublePendulum. Scenarios included:
- MDP: Full observations.
- POMDP-RV: Velocity entries removed.
- POMDP-FLK: Entire observations zeroed at random with fixed probability.
- POMDP-RN: Additive Gaussian noise on each observation entry.
- POMDP-RSM: Individual observation entries zeroed at random with fixed probability.
Baselines comprised DDPG, SAC, vanilla TD3, TD3-OW (recent observations concatenated), and TD3-OW+PA (recent actions also concatenated).
For HalfCheetah (average return after 1M training steps):
| Version | TD3 | LSTM-TD3(5) |
|---|---|---|
| MDP | 11,200±300 | 10,900±250 |
| POMDP-RV | 9,800±400 | 10,300±320 |
| POMDP-FLK | 1,200±500 | 9,500±400 |
| POMDP-RN | 4,000±800 | 9,800±350 |
| POMDP-RSM | 3,200±700 | 9,200±410 |
On pure MDPs, LSTM-TD3 matches state-of-the-art (TD3/SAC); on POMDPs with missing, noisy, or corrupted observations, LSTM-TD3 outperforms all baselines, sometimes by more than a factor of two. On tasks where underlying latent variables (e.g., velocity) are removed from observations, the memory module supports estimation via the action-observation sequence, recovering most of the performance lost by conventional architectures except possibly in high-frequency environments where the history window is too short for reliable inference. TD3-OW (observation windowing) slightly improves over naive TD3 in some POMDPs but fails catastrophically in high-noise/flickering settings, and TD3-OW+PA is usually inferior to TD3-OW in both MDP and POMDP regimes.
A plausible implication is that the explicit LSTM-based memory extraction enables true temporal inference necessary for POMDPs, a capability unattainable with mere windowing or static memory concatenation. TD3's architectural components—double critic, policy smoothing, delayed updates—remain critical to stability and sample efficiency under partial observability. Proximal Policy Optimization (PPO) was not included; on MuJoCo-style tasks, PPO requires 2–5x more samples to reach similar returns, so under the 1M step constraint it was not competitive (Meng et al., 2021).