Multi-mask Tensorized Self-Attention (MTSA)
- The paper introduces MTSA, a self-attention mechanism that unifies dot-product and additive scoring within a tensorized framework to efficiently model dependencies.
- It leverages multi-head and multi-mask architectures to capture both local and global context while maintaining parallelizable, memory-efficient computations.
- Empirical evaluations demonstrate state-of-the-art performance on NLP tasks with significant improvements in speed and resource usage compared to RNN and CNN models.
Multi-mask Tensorized Self-Attention (MTSA) is an advanced self-attention mechanism designed to model both local and global dependencies in sequence data efficiently. MTSA innovatively combines dot-product (token2token) and additive (source2token) attention within a unified tensorized framework. By distributing these computations across multiple heads, each with its own positional mask, MTSA achieves high representational capacity while retaining parallelizability and efficiency comparable to convolutional architectures. MTSA has demonstrated state-of-the-art or competitive performance across a diverse suite of natural language processing benchmarks, with superior memory- and time-efficiency over both RNNs and standard self-attention models (Shen et al., 2018).
1. Motivation and Background
Self-attention models, particularly the scaled dot-product attention of Vaswani et al. (2017), have eclipsed CNNs and RNNs in many NLP applications due to their ability to directly capture arbitrary long-range token dependencies and provide parallelizable computation. However, classic attention mechanisms traditionally assign a single scalar weight to each token pair, which limits their expressiveness for modeling the fine-grained, feature-wise interactions that can be crucial for tasks such as semantic role labeling and sequence classification. Prior attempts at multi-dimensional attention (e.g., additive or MLP-based attention) can, in principle, output a full vector for each token pair (resulting in a rank-3 alignment tensor), but naïve implementations incur memory and computational costs that scale as O(n^2 · d) for sequence length n and feature dimension d, becoming prohibitive for long sequences or large models.
MTSA addresses these challenges by (a) representing both token2token (pairwise) and source2token (global) dependencies as a tensorized alignment, (b) distributing the computation across multiple masked heads to control resource usage, and (c) implementing all operations via matrix multiplications, circumventing the need to materialize the full alignment tensor in memory (Shen et al., 2018).
2. Formal Definition of MTSA
Let X = [x_1, x_2, ..., x_n] be the sequence of input token embeddings, with x_i in R^d, sequence length n, and total feature dimension d.
2.1 Head-wise Projections: For each head h = 1, ..., H (per-head dimension d_h = d/H), the following projections are used:
- Query: q^h_i = W^h_q x_i, with q^h_i in R^{d_h}
- Key: k^h_i = W^h_k x_i, with k^h_i in R^{d_h}
- Value: v^h_i = W^h_v x_i, with v^h_i in R^{d_h}
2.2 Token2token Dot-Product Score: For each position pair (i, j):
s^h_{ij} = (q^h_i)^T k^h_j / sqrt(d_h)
This provides an efficient, parallelizable scalar measure of pairwise dependency.
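As a minimal NumPy sketch of this scoring step (shapes and variable names are illustrative, with one row per position):

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_h = 6, 8                       # sequence length, per-head dimension
Q = rng.standard_normal((n, d_h))   # rows are the queries q^h_i
K = rng.standard_normal((n, d_h))   # rows are the keys k^h_j

# S[i, j] = q_i . k_j / sqrt(d_h): one scalar score per token pair
S = Q @ K.T / np.sqrt(d_h)
assert S.shape == (n, n)
```

A single matrix product yields all n^2 pairwise scores at once, which is the source of the mechanism's parallelizability.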
2.3 Source2token (Global) Additive Score: For each position j:
t^h_j = W^h sigma(W^h_1 x_j + b^h_1) + b^h, with t^h_j in R^{d_h}
Here, W^h, W^h_1, b^h, and b^h_1 are learned parameters; sigma denotes the MLP nonlinearity.
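In NumPy, this feature-wise (vector-valued) score per token can be sketched as follows; the parameter names and ReLU stand-in for sigma are illustrative assumptions:

```python
import numpy as np

rng = np.random.default_rng(1)
n, d, d_h = 6, 16, 8
X = rng.standard_normal((n, d))        # rows are token embeddings x_j
# Parameter shapes follow the definition above (names are illustrative).
W1 = rng.standard_normal((d, d_h)); b1 = np.zeros(d_h)
W2 = rng.standard_normal((d_h, d_h)); b2 = np.zeros(d_h)

def sigma(z):                          # MLP nonlinearity (ReLU as a stand-in)
    return np.maximum(z, 0.0)

# T[j] = W2 @ sigma(W1 @ x_j + b1) + b2: a d_h-dim score vector per token
T = sigma(X @ W1 + b1) @ W2 + b2
assert T.shape == (n, d_h)
```

Unlike the token2token score, this produces a d_h-dimensional score vector per token, one entry per feature, which is what makes the combined alignment a rank-3 object.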
2.4 Combined Tensorized Alignment: For each triple (i, j, k), where k = 1, ..., d_h indexes features:
f^h_{ijk} = sigma_s(s^h_{ij}) + sigma_t(t^h_{jk}) + M^h_{ij}
M^h in {0, -inf}^{n x n} is a positional mask matrix for head h; sigma_s and sigma_t are scaling functions (identity or log-sigmoid).
2.5 Normalization and Output: For each query position i and feature k, the alignment is normalized over j by softmax:
p^h_{ijk} = exp(f^h_{ijk}) / sum_{j'} exp(f^h_{ij'k})
The final output vector at position i has features:
u^h_{ik} = sum_j p^h_{ijk} v^h_{jk}
Because f^h_{ijk} decomposes additively, exp(f^h_{ijk}) = exp(sigma_s(s^h_{ij}) + M^h_{ij}) * exp(sigma_t(t^h_{jk})), so both the numerator and denominator of the softmax reduce to products of n x n and n x d_h matrices, and the rank-3 tensor is never materialized.
2.6 Multi-Head Output: Each head yields U^h = [u^h_1, ..., u^h_n]; the concatenation across heads is projected to form the final representation:
U = W_o [U^1; U^2; ...; U^H]
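Putting 2.1 through 2.5 together, a single head can be computed with only n x n and n x d_h matrices, exploiting exp(a + b) = exp(a) * exp(b). A minimal NumPy sketch, assuming identity scaling functions, a ReLU MLP, a forward mask, and illustrative shapes:

```python
import numpy as np

def mtsa_head(X, Wq, Wk, Wv, W1, b1, W2, b2, M):
    """One MTSA head. X: (n, d) embeddings; M: (n, n) mask with 0 / -inf.
    Identity scaling functions are assumed for sigma_s and sigma_t."""
    d_h = Wq.shape[1]
    Q, K, V = X @ Wq, X @ Wk, X @ Wv                  # (n, d_h) each
    S = Q @ K.T / np.sqrt(d_h)                        # token2token scores (n, n)
    T = np.maximum(X @ W1 + b1, 0.0) @ W2 + b2        # source2token scores (n, d_h)
    # exp(f_ijk) = exp(S_ij + M_ij) * exp(T_jk), so the softmax over j factorizes
    # into matrix products (max-subtraction is only for numerical stability):
    A = np.exp(S + M - np.max(S + M, axis=1, keepdims=True))  # (n, n)
    E = np.exp(T - np.max(T, axis=0, keepdims=True))          # (n, d_h)
    return (A @ (E * V)) / (A @ E)                    # (n, d_h) head output

rng = np.random.default_rng(2)
n, d, d_h = 5, 12, 4
X = rng.standard_normal((n, d))
Wq, Wk, Wv, W1 = (0.1 * rng.standard_normal((d, d_h)) for _ in range(4))
W2 = 0.1 * rng.standard_normal((d_h, d_h))
b1, b2 = np.zeros(d_h), np.zeros(d_h)
# Forward mask: query i attends to j < i, plus itself (a simplification).
M = np.where(np.arange(n)[:, None] > np.arange(n)[None, :], 0.0, -np.inf)
np.fill_diagonal(M, 0.0)
U = mtsa_head(X, Wq, Wk, Wv, W1, b1, W2, b2, M)
assert U.shape == (n, d_h) and np.isfinite(U).all()
```

Note that the n x n x d_h tensor of probabilities p^h_{ijk} never appears: the two exponentiated factors A and E are combined directly through matrix multiplication, which is the memory-saving trick described in Section 1.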
3. Multi-Head and Multi-Mask Architecture
Each MTSA head operates on a distinct subspace and is assigned a unique positional mask M^h. Common choices include forward masks (M^fw_{ij} = 0 if j < i, -inf elsewhere, so that each query attends only to preceding positions), backward masks (the mirror image), and local window masks. In empirical setups, half the heads receive forward masks and half receive backward masks, thus capturing left-to-right and right-to-left sequential order in parallel.
By distributing the alignment computation across heads, each of subspace dimension d_h = d/H, MTSA efficiently processes both structural and sequential priors without explicit dense tensor storage. Each head computes its token2token and source2token contributions independently, and the final representation is aggregated over all heads.
This architecture enables the simultaneous modeling of multiple types of dependencies and positional priors (such as directionality and localness), each encoded independently in separate heads.
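The forward/backward masks described above can be built as follows; whether the diagonal is kept (each token attending to itself) is a design choice, and this sketch keeps it:

```python
import numpy as np

def directional_masks(n, keep_diag=True):
    """Forward mask: query i may attend only to j < i (optionally itself);
    the backward mask is the mirror image. 0 = allowed, -inf = blocked."""
    i, j = np.meshgrid(np.arange(n), np.arange(n), indexing="ij")
    fw = np.where(j < i, 0.0, -np.inf)
    bw = np.where(j > i, 0.0, -np.inf)
    if keep_diag:
        np.fill_diagonal(fw, 0.0)
        np.fill_diagonal(bw, 0.0)
    return fw, bw

fw, bw = directional_masks(4)
# fw row 2 permits positions 0, 1 (and 2 itself); position 3 is blocked.
```

Since the mask is simply added to the scores before the softmax, any 0/-inf pattern (e.g., a sliding-window mask) can be substituted without changing the rest of the computation.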
4. Efficiency and Computational Complexity
MTSA achieves a favorable computational profile:
- Token2token per head: O(n^2 · d_h)
- Source2token per head: O(n · d · d_h)
- Normalization and aggregation: O(n^2 · d_h)
- Total time complexity: O(n^2 · d) across all heads (with d_h = d/H)
- Total memory: O(n^2 · H + n · d), as the full n x n x d_h tensor is never explicitly materialized.
Compared to standard multi-head dot-product attention, which also costs O(n^2 · d), MTSA delivers similar scaling while preserving the full multi-dimensional alignment structure. Empirical profiling indicates that MTSA matches the speed and memory efficiency of CNNs as sequence length n increases, while outperforming both CNN and RNN baselines in representational power and downstream accuracy (Shen et al., 2018).
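The memory gap versus a naively materialized rank-3 alignment is easy to quantify. A back-of-the-envelope comparison, assuming float32 storage and arbitrary illustrative sizes:

```python
n, d, H = 1000, 512, 8   # illustrative sequence length, model dim, head count
bytes_per = 4            # float32

naive = n * n * d * bytes_per            # full n x n x d alignment tensor
mtsa = (H * n * n + n * d) * bytes_per   # H score matrices + per-token scores

print(f"naive tensor: {naive / 2**30:.1f} GiB")  # ~1.9 GiB
print(f"MTSA:         {mtsa / 2**20:.1f} MiB")   # ~32.5 MiB
```

At these sizes the factorized form needs roughly 60x less memory for the alignment, and the gap widens with d since only the n x d term depends on it.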
5. Empirical Evaluation and Applications
MTSA has been applied as the main encoder component in a variety of NLP benchmarks, with model sizes kept comparable to the baseline encoders in each case.
Key empirical results:
- SNLI (Natural Language Inference): Test accuracy 86.3%, which is state-of-the-art among sentence-encoding models.
- MultiNLI: Matched accuracy 76.7%, mismatched accuracy 76.4%.
- Semantic Role Labeling (SRL), CoNLL-05: improved F1 over comparable self-attention encoders on both the WSJ and Brown test sets.
- Sentence classification (various benchmarks): MTSA achieves new bests on 4 of 5 tasks, e.g., SST-5 51.3%.
- Machine Translation (WMT14 En–De): Replacing Transformer encoder’s multi-head attention with MTSA yields BLEU = 24.09 versus 23.64 for the baseline (statistically significant improvement).
Efficiency findings:
- Inference time on SNLI is 1.6s (MTSA), compared to 9.1s (Bi-LSTM) and 1.4s (CNN).
- GPU memory usage (batch size 64, sequence length 64): 558MB (MTSA), versus 942MB (Bi-LSTM) and 208MB (CNN).
In all tested scenarios, MTSA matches or exceeds the performance of dot-product, additive, and previous multi-dimensional self-attention architectures, with no requirement for RNN or CNN modules (Shen et al., 2018).
6. Design Criteria, Strengths, and Limitations
MTSA demonstrates several design achievements:
- Expressiveness: Jointly models pairwise and feature-wise global dependencies within a unified tensor alignment.
- Parallelizability: Retains the parallelism inherent to self-attention; unlike recurrent models, computation does not proceed sequentially over positions, so all tokens are processed simultaneously.
- Scalability: Efficiently processes long sequences, as memory and time are controlled by head parallelism and matrix ops.
- Configurable Priors: Head-specific positional masks enable explicit modeling of directionality or local context, and can be extended to alternative priors such as sliding windows or relative distance constraints.
Reported limitations and future directions:
- Mask expressiveness: Current studies have used only simple forward/backward and local window masks. Extension to richer or dynamically learned mask schemes may improve inductive bias.
- Deeper or hierarchical architectures: Stacking MTSA layers, or integrating with block-wise/hierarchical SA (e.g., Bi-BloSA), remains underexplored.
- Language modeling integration: Combining MTSA with large-scale pretrained language models (such as BERT or GPT) is a promising direction for leveraging universal representations.
- Head adaptation: Investigating dynamic or learned mask/head allocation for adaptive dependency modeling.
A plausible implication is that further innovation in mask design and architectural stacking could enable MTSA to extend its efficiency and accuracy to even more challenging or structured data domains (Shen et al., 2018).
7. Relationship to Related Approaches
MTSA generalizes and extends the functionality of standard multi-head self-attention and prior multi-dimensional attention methods:
- Vs. dot-product self-attention: Unlike the scalar scoring of the Transformer, MTSA generates vector-valued alignment for each token pair, enriching representational capacity.
- Vs. additive/Multi-dimensional attention (DiSA, Bi-BloSA): MTSA matches or improves on these approaches while drastically reducing memory and compute overhead by distributing tensorization via multi-heads and matrix multiplications.
- Positional encoding: Where Transformers employ additive position embeddings, MTSA encodes sequential and structural order directly via head-specific masks.
- Efficiency: MTSA achieves practical efficiency on par with CNNs for long sequences, while offering the flexibility and long-range capacity of self-attention.
These design choices position MTSA as a compelling framework for highly expressive yet tractable sequence modeling in large-scale NLP settings (Shen et al., 2018).