PMI-Masking: Data-Driven MLM Strategy
- PMI-Masking is a data-driven strategy that uses pointwise mutual information to identify statistically significant token spans or collocations for masking in language model pretraining.
- By prioritizing high-PMI spans, the method improves convergence rates and downstream performance metrics compared to traditional random or heuristic masking techniques.
- Despite its effectiveness in deepening semantic encoding, PMI-Masking requires extensive preprocessing and fine-tuning to adapt to various domains and low-resource settings.
PMI-Masking is a principled data-driven strategy for selecting tokens or token spans to mask during masked language model (MLM) pretraining. Rather than employing uniform random masking or heuristics such as whole-word or random-span masking, PMI-Masking leverages pointwise mutual information (PMI) to identify and jointly mask statistically significant collocations. This approach aims to force MLMs to encode deeper semantic dependencies and world knowledge by focusing optimization on contextually informative and hard-to-predict units, mitigating the inefficiencies associated with traditional random masking. Several variants have been proposed, including the original PMI-Masking for multi-token spans (Levine et al., 2020) and informative token masking strategies such as InforMask (Sadeq et al., 2022).
1. Theoretical Basis: Pointwise Mutual Information
PMI measures the association strength between tokens or spans based on their co-occurrence frequency relative to what independent occurrence would predict. For tokens $x$ and $y$ in a vocabulary, PMI is defined as

$$\mathrm{PMI}(x, y) = \log \frac{p(x, y)}{p(x)\,p(y)},$$

where $p(x, y)$ is the empirical joint probability and $p(x)$, $p(y)$ are the respective marginals. For an n-gram span $w_1 \dots w_n$, the principled extension evaluates the minimum PMI over all contiguous segmentations $\sigma$:

$$\mathrm{PMI}_n(w_1 \dots w_n) = \min_{\sigma \in \mathrm{seg}(w_1 \dots w_n)} \log \frac{p(w_1 \dots w_n)}{\prod_{s \in \sigma} p(s)}.$$

This formulation ensures that only spans with genuinely strong internal bonding are selected, avoiding “spurious” high scores when a single sub-span dominates the statistics (Levine et al., 2020). High-PMI units typically correspond to multi-word expressions, named entities, or technical collocations.
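As an illustrative sketch (not the authors' code), the pair-level PMI and its minimum-over-segmentations extension can be computed from empirical span probabilities. The `prob` lookup table mapping token tuples to probabilities is a toy assumption standing in for corpus counts:

```python
import math
from itertools import combinations

def pmi(p_joint, p_x, p_y):
    """Pointwise mutual information of a token pair from empirical probabilities."""
    return math.log(p_joint / (p_x * p_y))

def segmentations(span):
    """All ways to split a tuple of tokens into >= 2 contiguous segments."""
    n = len(span)
    for k in range(1, n):                        # number of cut points
        for cuts in combinations(range(1, n), k):
            bounds = (0,) + cuts + (n,)
            yield [span[bounds[i]:bounds[i + 1]] for i in range(len(bounds) - 1)]

def pmi_n(span, prob):
    """Minimum PMI over all contiguous segmentations of `span`.

    `prob` maps token tuples to empirical probabilities; a span scores high
    only if every way of splitting it still leaves strong internal bonding.
    """
    p_span = prob[tuple(span)]
    return min(
        math.log(p_span / math.prod(prob[tuple(s)] for s in seg))
        for seg in segmentations(tuple(span))
    )
```

Note how a trigram whose probability mass is explained by one strong sub-bigram is penalized: the segmentation that isolates that bigram yields the (low) minimum.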
2. PMI-Masking Algorithms
a) High-PMI Span Masking
The original PMI-Masking procedure operates as follows:
- Span Candidate Selection: All n-grams of length 2 to 5 occurring at least 10 times in a large corpus (e.g., English Wikipedia + BookCorpus) are considered. For each, compute $\mathrm{PMI}_n$ as above.
- Vocabulary Construction: Candidates are ranked by descending $\mathrm{PMI}_n$ and the top $M = 800{,}000$ are retained as the masking vocabulary. This threshold empirically maximizes the balance between retrieving true collocations and excluding non-collocations.
- Segmentation and Masking: During each pretraining batch, the text is segmented into the longest non-overlapping units found in the masking vocabulary, with tokens not covered by any vocabulary n-gram treated as singleton units. Units are sampled without replacement to reach a target masking rate (e.g., 15%) (Levine et al., 2020).
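The segmentation-and-masking step above can be sketched as follows; this is a minimal illustration assuming a precomputed masking vocabulary of token tuples, with hypothetical function names rather than the authors' implementation:

```python
import random

def segment(tokens, vocab, max_len=5):
    """Greedy left-to-right segmentation into the longest units found in `vocab`;
    tokens not covered by any vocabulary n-gram become singleton units."""
    units, i = [], 0
    while i < len(tokens):
        for L in range(min(max_len, len(tokens) - i), 1, -1):
            if tuple(tokens[i:i + L]) in vocab:
                units.append(tokens[i:i + L])
                i += L
                break
        else:                       # no vocabulary n-gram starts here
            units.append(tokens[i:i + 1])
            i += 1
    return units

def pmi_mask(tokens, vocab, rate=0.15, rng=random):
    """Mask whole units, sampled without replacement, until ~`rate` of tokens are masked."""
    units = segment(tokens, vocab)
    starts, pos = [], 0             # start offset of each unit in the token sequence
    for u in units:
        starts.append(pos)
        pos += len(u)
    order = list(range(len(units)))
    rng.shuffle(order)
    budget = max(1, int(rate * len(tokens)))
    masked = set()
    for idx in order:
        if len(masked) >= budget:
            break
        masked.update(range(starts[idx], starts[idx] + len(units[idx])))
    return [("[MASK]" if i in masked else t) for i, t in enumerate(tokens)]
```

Because whole units are masked jointly, a collocation such as "new york" is never split: either both tokens are hidden or neither is.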
b) Informative Token Masking (InforMask)
InforMask generalizes PMI-Masking to the token level for BERT-style MLMs:
- PMI Matrix Computation: A PMI matrix over the most frequent vocabulary tokens is computed from co-occurrence counts in a large corpus, using a sliding window (typically of length 11).
- Sample-and-Score: For a sample of sentences, enumerate candidate mask sets of tokens per sentence and score each candidate set $S$ by summing the PMI between each masked and unmasked token:

$$\mathrm{score}(S) = \sum_{m \in S} \sum_{u \notin S} \mathrm{PMI}(m, u).$$

The candidate with the maximum score is retained, and the masked tokens’ statistics are recorded.
- Token-Specific Masking Rates: For each token $t$, estimate the masking probability as the fraction of sample-and-score trials in which $t$ was selected for masking. Pretraining then samples mask positions according to these per-token probabilities, yielding a static, token-specific approximation of informativeness that is dramatically more efficient than per-epoch sample-and-score (Sadeq et al., 2022).
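A hypothetical reimplementation of the sample-and-score procedure, consistent with the description above; the PMI lookup (a dict over token pairs), the candidate count, and the function names are illustrative assumptions, not the paper's code:

```python
import random
from collections import Counter

def informative_relevance(mask_set, tokens, pmi):
    """Sum of PMI between every masked and every unmasked token.
    `pmi` maps ordered token pairs to scores; missing pairs count as 0."""
    unmasked = [t for i, t in enumerate(tokens) if i not in mask_set]
    return sum(pmi.get((tokens[i], u), 0.0) for i in mask_set for u in unmasked)

def sample_and_score(sentences, pmi, rate=0.15, n_candidates=20, rng=random):
    """For each sentence, keep the random candidate mask set with the highest
    score, then tally how often each token ends up masked vs. observed."""
    masked_ct, seen_ct = Counter(), Counter()
    for toks in sentences:
        k = max(1, int(rate * len(toks)))
        best = max(
            (frozenset(rng.sample(range(len(toks)), k)) for _ in range(n_candidates)),
            key=lambda s: informative_relevance(s, toks, pmi),
        )
        seen_ct.update(toks)
        masked_ct.update(toks[i] for i in best)
    # static token-specific masking rates, reused directly during pretraining
    return {t: masked_ct[t] / seen_ct[t] for t in seen_ct}
```

The returned dictionary is the one-time preprocessing artifact: pretraining thereafter draws mask positions from these per-token rates with no further scoring.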
3. Comparison to Prior and Alternative Masking Strategies
PMI-Masking unifies and extends earlier approaches:
| Method | Masking Unit | Selection Principle |
|---|---|---|
| Random-Token | individual token | uniform random |
| Whole-Word | all subword tokens | uniform per word |
| Entity/Phrase | parser-derived spans | external tool / heuristic |
| Random-Span | contiguous random spans | geometric distribution |
| PMI-Masking | high-PMI n-grams | data-driven collocation (PMI) |
Random-token and whole-word masking lack sensitivity to contextually significant units, whereas entity/phrase masking depends on NLP tools and can be error-prone. Random-span masking heuristically increases multi-token coverage but often splits collocations. PMI-Masking automatically selects the most informationally dense and collocated units, covering approximately 50% of corpus tokens with the top 800k n-grams (Levine et al., 2020).
4. Experimental Results and Empirical Findings
PMI-Masking and InforMask demonstrate accelerated learning dynamics and superior end-task performance relative to standard masking methods.
- Convergence: PMI-Masking matches random-span’s end-of-training SQuAD2.0 F1 in roughly 50% of the pretraining steps at all data scales. With 16GB data and 2.4M steps: end F1 for PMI-Masking is 83.6 versus random-span’s 82.8; RACE accuracy is 70.9 versus 68.7 (Levine et al., 2020).
- Downstream Performance: InformBERT (InforMask) achieves a LAMA mean reciprocal rank (MRR) of 0.698 versus 0.553 for BERT-base and 0.592 for RoBERTa-base, despite using only 10% of RoBERTa’s corpus (Sadeq et al., 2022).
- Ablations: High-PMI masking consistently outperforms frequency-based span masking and naive (non-minimum segmentation) PMI extensions. Static token-specific masking rates outperform repeated per-epoch sample-and-score (Sadeq et al., 2022).
- Token Coverage: Stop words show significant reductions in mask rates, whereas named entities and technical terms are masked at substantially elevated rates (Sadeq et al., 2022).
5. Theoretical Motivation and Limitations
Principled masking with PMI forces the MLM to leverage nontrivial statistical dependencies, thus reducing reliance on local cues and shallow statistics. Masking entire collocations (e.g., “New York City”) collapses internal redundancy and requires broad contextual reasoning. For token-level schemes, high PMI correlates with informativeness—by concentrating learning capacity on non-trivial spans or tokens, each prediction maximally benefits model generalization to semantic and factual recall benchmarks.
Limitations include substantial preprocessing cost (computing PMI for all $n$-grams and vocabulary pairs over large corpora) and the static nature of the resulting masking vocabulary, which may perform suboptimally under domain shift or in low-resource settings. The $M = 800$k vocabulary-size threshold is empirically tuned and might require per-task adaptation.
6. Extensions and Practical Implications
Potential extensions comprise dynamic or curriculum-based masking (modulating PMI threshold over epochs), context-dependent or local PMI calculation, adaptation to morphologically-rich or low-resource languages, and hybridization with parser-based or supervised entity span strategies. PMI-Masking can be integrated with span-prediction objectives beyond MLM (e.g., replaced token detection) (Levine et al., 2020), and the static masking probabilities of InforMask enable negligible runtime overhead after a one-time preprocessing step (Sadeq et al., 2022).
The empirical gains achieved by PMI-Masking strategies suggest a robust route to improving the factual and knowledge retention capabilities of LLMs, with wide applicability across NLU and QA benchmarks.