Masked Contrastive Predictive Coding Overview
- Masked CPC is a self-supervised framework that predicts masked components in high-dimensional data using contrastive objectives.
- It extends classic CPC with flexible temporal and spatial masking schemes, integrating autoregressive modeling and negative sampling for robust representation learning.
- The approach enhances performance in downstream tasks such as ASR and digital pathology by unifying elements of contrastive, autoencoding, and autoregressive paradigms.
Masked Contrastive Predictive Coding (CPC) is a self-supervised representation learning framework that employs a contrastive objective to predict masked components of high-dimensional sequential or spatial data given their context. This approach extends classic CPC, which leverages autoregressive modeling and negative sampling to maximize mutual information between context and future or hidden representations. Masked CPC incorporates explicit masking schemes in the input or latent space, enabling flexible pretext tasks such as in-filling (spatial) or masked token prediction (temporal), and can be integrated with supervised objectives for efficient transfer to downstream tasks.
1. Foundations: Contrastive Predictive Coding
Classic Contrastive Predictive Coding, introduced by van den Oord et al. (Oord et al., 2018), consists of two primary modules: a non-linear encoder $g_{\text{enc}}$ mapping inputs $x_t$ to latent representations $z_t = g_{\text{enc}}(x_t)$, and an autoregressive model $g_{\text{ar}}$ producing a context vector $c_t = g_{\text{ar}}(z_{\le t})$ by ingesting the latent sequence up to time $t$. The core task is to learn representations by predicting future (or masked) latent vectors from this context, formulating the problem in terms of mutual information.
The framework employs the InfoNCE loss

$$\mathcal{L}_N = -\,\mathbb{E}\left[\log \frac{f_k(z_{t+k}, c_t)}{\sum_{z_j \in Z} f_k(z_j, c_t)}\right],$$

where $f_k(z, c)$ is a critic function representing an unnormalized density ratio $f_k(z_{t+k}, c_t) \propto p(z_{t+k} \mid c_t)\,/\,p(z_{t+k})$. Positive samples correspond to the actual future (or masked) latents, and negatives are sampled from a proposal distribution, often other elements in the minibatch or sequence.
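The InfoNCE objective can be sketched as a cross-entropy over one positive and $K$ negatives. This is a minimal illustration, not the reference implementation; the dot-product critic is an assumption (bilinear critics $z^\top W_k c$ are also common in CPC).

```python
import numpy as np

def info_nce_loss(context, positives, negatives):
    """InfoNCE loss for a batch of prediction steps (minimal sketch).

    context:   (B, D) context vectors c_t
    positives: (B, D) true future/masked latents
    negatives: (B, K, D) latents drawn from the proposal distribution
    The critic f is a plain dot product here (an illustrative choice).
    """
    pos_logits = np.sum(context * positives, axis=-1, keepdims=True)  # (B, 1)
    neg_logits = np.einsum("bd,bkd->bk", context, negatives)          # (B, K)
    logits = np.concatenate([pos_logits, neg_logits], axis=1)         # (B, 1+K)
    # Cross-entropy with the positive always at index 0.
    log_probs = logits - np.log(np.sum(np.exp(logits), axis=1, keepdims=True))
    return -np.mean(log_probs[:, 0])
```

When the context aligns with its positive far better than with the negatives, the loss approaches zero; with uninformative representations it approaches $\log(1+K)$.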
2. Masking Schemes and Masked CPC Objectives
Masked CPC generalizes the classic autoregressive context by introducing explicit masking in the latent space, modifying the context construction and prediction tasks.
Temporal Masking (ASR Context)
In "Joint Masked CPC and CTC Training for ASR" (Talnikar et al., 2020), given an encoded sequence $Z = (z_1, \dots, z_T)$, masking is implemented as follows:
- A set of time indices $I \subset \{1, \dots, T\}$ is sampled.
- For each $t \in I$, $z_t$ is replaced by a learned mask vector $m$ to create $\tilde{Z}$.
- $\tilde{Z}$ is processed by a context network (Transformer-based), resulting in context vectors $c_1, \dots, c_T$.
- For each $t \in I$, $c_t$ is used to predict the original (unmasked) $z_t$ by contrasting it against negatives (features of other unmasked frames).
The loss is

$$\mathcal{L}_{\text{CPC}} = -\sum_{t \in I} \log \frac{\exp(c_t^\top z_t)}{\exp(c_t^\top z_t) + \sum_{j \in N_t} \exp(c_t^\top z_j)},$$

where $N_t$ is a set of negative indices, typically other non-masked frames in the same sample.
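The temporal masking steps above can be sketched as follows. This is a toy illustration under stated assumptions: the learned mask embedding is fixed to zeros, and the masked sequence stands in for the output of the Transformer context network.

```python
import numpy as np

rng = np.random.default_rng(0)

def mask_and_contrast(Z, mask_prob=0.15):
    """Sample mask indices I and replace the selected frames of Z with a
    mask embedding (zeros here; a learned vector in practice).
    Z: (T, D) encoder outputs for one utterance."""
    T, D = Z.shape
    m = np.zeros(D)                          # stand-in for the learned mask vector
    I = np.where(rng.random(T) < mask_prob)[0]
    Z_masked = Z.copy()
    Z_masked[I] = m
    return Z_masked, I

def masked_cpc_loss(C, Z, I, num_negatives=10):
    """Contrast each context vector c_t (t in I) against the true z_t and
    negatives drawn from other frames of the same utterance."""
    losses = []
    for t in I:
        candidates = np.delete(np.arange(len(Z)), t)
        neg_idx = rng.choice(candidates, size=num_negatives, replace=False)
        logits = np.concatenate([[C[t] @ Z[t]], Z[neg_idx] @ C[t]])
        # Negative log-softmax of the positive (index 0).
        losses.append(-(logits[0] - np.log(np.exp(logits).sum())))
    return float(np.mean(losses)) if losses else 0.0
```

In a real system `C` would come from the bidirectional Transformer applied to `Z_masked`; negatives here are drawn within the same utterance, as in the within-sample scheme described above.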
Spatial Masking (Vision/Pathology Context)
Masked CPC can also employ spatial masks. In "Unsupervised Representation Learning from Pathology Images with Multi-directional CPC" (Carse et al., 2021):
- Input images are split into overlapping grids of patches $\{x_{i,j}\}$.
- Binary masks $M \in \{0,1\}^{R \times C}$ specify context ($M_{i,j}=1$) vs. masked ($M_{i,j}=0$) positions.
- Top-down: the context is, e.g., the first rows of the grid only.
- In-filling: only the spatial perimeter is context; inner patches are masked (to be predicted).
- The autoregressor takes the masked input to form context vectors $c_{i,j}$, predicting features at all masked grid positions.
Objective:

$$\mathcal{L} = -\sum_{(i,j) \in \mathcal{M}} \log \frac{\exp\!\big((W c_{i,j})^\top z_{i,j}\big)}{\sum_{z' \in \mathcal{N}_{i,j}} \exp\!\big((W c_{i,j})^\top z'\big)},$$

where $\mathcal{M}$ indexes all masked positions and $W$ is a linear projection.
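The two spatial masking schemes can be made concrete as binary context masks over the patch grid. A minimal sketch (grid sizes and the `context_rows` parameter are illustrative, not taken from the paper):

```python
import numpy as np

def make_context_mask(rows, cols, scheme="in-filling", context_rows=None):
    """Binary context mask over a patch grid: 1 = context, 0 = masked/predicted.

    "top-down"   reveals only the first rows of the grid;
    "in-filling" reveals only the perimeter ring, so inner patches
    must be predicted from their surroundings.
    """
    mask = np.zeros((rows, cols), dtype=int)
    if scheme == "top-down":
        k = context_rows if context_rows is not None else rows // 2
        mask[:k, :] = 1                  # first rows are context
    elif scheme == "in-filling":
        mask[:, :] = 1
        mask[1:-1, 1:-1] = 0             # only the spatial perimeter is context
    else:
        raise ValueError(f"unknown scheme: {scheme}")
    return mask
```

For a 5x5 grid, the in-filling mask keeps the 16 perimeter patches as context and leaves the 9 interior patches to be predicted.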
3. Architectural and Implementation Details
The encoder and context networks in Masked CPC vary by modality but share key design motifs:
- Audio (ASR, Speech):
- Encoder: stack of convolutional layers (e.g., 7 layers, 512 channels each, with strides yielding roughly 20 ms temporal granularity).
- Context network: Transformer-based, with bidirectional attention (e.g., 12 or 24 layers), or single-layer GRU for classic CPC.
- Mask vector: a learned embedding $m$ substituted for masked frames.
- Masking covers randomly sampled spans of consecutive frames per utterance (Talnikar et al., 2020).
- Vision/Pathology:
- Encoder: ResNeXt-101 backbone extracting features from each grid patch (Carse et al., 2021).
- Context network: Stack of multi-directional PixelCNN blocks (with orientation-rotated, masked convolutions per block), aggregating context from all cardinal directions.
- Masking can be top-down or in-filling; the in-filling mask reveals only the spatial perimeter to the context network.
- Negative Sampling:
- Negatives are sampled from within the minibatch or from non-masked positions in the same sequence/image: within-utterance frames for audio, within- and cross-image patches for pathology (Talnikar et al., 2020; Carse et al., 2021).
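Within-sample negative sampling can be sketched as follows; the function name and interface are illustrative, not from either paper.

```python
import numpy as np

def sample_negatives(features, anchor_index, num_negatives, rng):
    """Draw negatives for one anchor from other positions in the same
    sequence or image (within-sample sampling).

    features: (N, D) candidate latents; the anchor itself is excluded
    so the positive never appears among its own negatives."""
    candidates = np.delete(np.arange(features.shape[0]), anchor_index)
    idx = rng.choice(candidates, size=num_negatives, replace=False)
    return features[idx]
```

Cross-sample (minibatch) sampling works the same way with the candidate pool widened to latents from other items in the batch.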
4. Training Protocols and Optimization
Optimizing Masked CPC involves the following procedural components:
- For unlabeled (self-supervised) batches, compute the masked-CPC loss and update encoder/context parameters with Adam optimizer (learning rates and betas as per modality and model scale).
- For supervised tasks (e.g., ASR), incorporate labels and a Connectionist Temporal Classification (CTC) loss in an alternating or interleaved regime, usually with a much smaller supervised learning rate, two separate optimizers, and careful control of momentum states (Talnikar et al., 2020).
- Data augmentations are standard in vision settings (random rotations, flips) (Carse et al., 2021).
- Training length is determined by early stopping on validation loss, with learning rate warmup/decay as required.
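The alternating regime with two separate optimizers can be sketched in miniature. This is a toy model, not the authors' training code: the "model" is a single weight vector, the two losses are stand-in quadratics, and SGD with momentum replaces Adam; the point is that each objective keeps its own optimizer state and learning rate, so the supervised and unsupervised updates never mix momentum.

```python
import numpy as np

class SGDMomentum:
    """Heavy-ball SGD as a stand-in for Adam; each instance owns its state."""
    def __init__(self, lr, beta=0.9):
        self.lr, self.beta, self.v = lr, beta, None

    def step(self, params, grad):
        if self.v is None:
            self.v = np.zeros_like(params)
        self.v = self.beta * self.v + grad
        return params - self.lr * self.v

params = np.array([4.0, -2.0])
unsup_opt = SGDMomentum(lr=0.1)      # unsupervised (masked-CPC) optimizer
sup_opt = SGDMomentum(lr=0.01)       # supervised (CTC) optimizer, much smaller rate

for step in range(200):
    if step % 2 == 0:                # unlabeled batch: masked-CPC update
        grad = 2 * params            # stand-in for d(cpc_loss)/d(params)
        params = unsup_opt.step(params, grad)
    else:                            # labeled batch: CTC update
        grad = 2 * (params - 1.0)    # stand-in for d(ctc_loss)/d(params)
        params = sup_opt.step(params, grad)
```

Because the two objectives pull toward different optima, sharing a single momentum buffer would let one loss's gradient history contaminate the other's update direction; separate optimizer instances avoid this.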
Empirical training parameters (examples):
- Audio (Base): $M = 10$ masks per utterance; separate Adam optimizers for the unsupervised (masked-CPC) and supervised (CTC) objectives, with the supervised learning rate substantially smaller; negatives drawn from non-masked frames of the same utterance (Talnikar et al., 2020).
- Pathology: batch size 16 for CPC pretraining, Adam optimizer, within-batch negatives, 20 epochs (Carse et al., 2021).
5. Empirical Results and Comparative Analyses
Audio: Automatic Speech Recognition
- Joint masked-CPC and CTC training on LibriSpeech (100 h labeled + 960 h unlabeled) yields word error rates (WER) competitive with wav2vec 2.0:
- Dev-clean: 6.1% (masked CPC) vs. 6.1% (wav2vec 2.0); Dev-other: 13.7% vs. 13.5% (no LM).
- With 4-gram LM: 3.0%/7.7% (masked CPC) vs. 3.2%/8.9% (wav2vec 2.0).
- The contrastive loss acts as a regularizer, with higher training CTC loss but lower validation loss than pure supervised training. On fully labeled LibriSpeech (960 h), joint training achieves WER 5.8% (dev-other, 4-gram LM) vs. 7.2% for supervised-only (Talnikar et al., 2020).
Vision: Digital Pathology
- In multi-directional CPC, ring (in-filling) masking and orientation-invariant context modeling enable more stable training and higher downstream classification accuracy, especially in regimes with limited labeled data.
- On Patch Camelyon:
- With only 10 labeled examples, multi-directional/top-down achieves 0.566 accuracy vs. 0.518 (no pretrain) and 0.509 (single-direction/top-down).
- With 32 examples, multi-directional/in-filling reaches 0.614.
- At 1000 labeled, multi-directional/in-filling: 0.769, outperforming other variants.
- Multi-directional autoregression consistently provides lower CPC validation loss and faster convergence (Carse et al., 2021).
6. Theoretical Insights, Variants, and Extensions
Masked CPC's predictive formulation amounts to maximizing a lower bound on the mutual information between masked/future target representations and the context. By formulating the contrastive loss in this density-ratio framework, the models select for representations that capture temporally or spatially persistent, high-level features ("slow features"), providing robustness and transferability (Oord et al., 2018).
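The bound can be stated explicitly. With one positive and $N-1$ negatives per contrastive comparison, the InfoNCE loss $\mathcal{L}_N$ satisfies (Oord et al., 2018):

```latex
I(z_{t+k};\, c_t) \;\ge\; \log N \,-\, \mathcal{L}_N
```

Driving the loss down therefore tightens the mutual-information bound, and the bound itself improves as the number of negatives $N$ grows, which is one motivation for large negative pools.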
Framework flexibility:
- Masking can be applied in arbitrary patterns (random, blockwise, in-filling), and the context generator may be chosen as RNN, Transformer, PixelCNN, or other masked/self-attention models.
- This flexibility enables adaptations for diverse domains and data structures: audio (temporal), vision (spatial), text (sequential, masked tokens).
- Integrating masked CPC with supervised losses yields efficient single-stage training for tasks such as ASR, removing the necessity for separate pretraining and finetuning (Talnikar et al., 2020).
A plausible implication is that masked CPC and related masked-prediction models unify contrastive, autoencoding, and autoregressive paradigms, providing a scalable and regularizing approach for diverse data modalities and annotation regimes.
7. Summary Table: Representative Masked CPC Schemes
| Domain | Masking Type | Context Network | Negative Sampling |
|---|---|---|---|
| Speech/ASR | Random temporal masking | Transformer (bidirectional) | Within-sample frames |
| Vision | Top-down, in-filling (ring) | Multi-dir. PixelCNN blocks | Other images' patches |
| Pathology | Top-down, in-filling (ring) | Multi-dir. PixelCNN blocks | Other images' patches |
This organization highlights the adaptability of Masked CPC to multiple data types, the role of masking in the context construction, and the integration of advanced context networks such as Transformers and multi-directional PixelCNNs (Oord et al., 2018, Talnikar et al., 2020, Carse et al., 2021).