Parallel Sequence Modeling via Generalized Spatial Propagation Network

1NVIDIA, 2The University of Hong Kong, 3University of California, San Diego
(† Work done during an internship at NVIDIA)

Abstract


We present the Generalized Spatial Propagation Network (GSPN), a new attention mechanism optimized for vision tasks that inherently captures 2D spatial structures. Existing attention models, including transformers, linear attention, and state-space models like Mamba, process multi-dimensional data as 1D sequences, compromising spatial coherence and efficiency. GSPN overcomes these limitations by operating directly on spatially coherent image data and forming dense pairwise connections through a line-scan approach. Central to GSPN is the Stability-Context Condition, which ensures stable, context-aware propagation across 2D sequences and reduces the effective sequence length to $\sqrt{N}$, significantly enhancing computational efficiency. With learnable, input-dependent weights and no reliance on positional embeddings, GSPN achieves superior spatial fidelity and state-of-the-art performance in vision tasks, including ImageNet classification, class-guided image generation, and text-to-image generation. Notably, GSPN accelerates SD-XL with softmax attention by over 84× when generating 16K-resolution images.

Motivation


Transformers have revolutionized machine learning, natural language processing, and computer vision. However, the quadratic complexity of their attention mechanism hampers efficiency at large scales (e.g., high-resolution images), and their disregard for spatial structure diminishes their suitability for vision tasks. In this work, we introduce the Generalized Spatial Propagation Network (GSPN), a linear attention mechanism optimized for multi-dimensional data such as images. Central to GSPN is the Stability-Context Condition, which ensures both stability and effective long-range context propagation across 2D sequences by maintaining a consistent propagation weight norm. This condition allows information from distant elements to meaningfully influence large spatial areas while preventing exponential growth in dependencies, enabling the stable, context-aware propagation essential for vision tasks. With a linear line-scan operation, GSPN parallelizes propagation across rows and columns, reducing the effective sequence length to $\sqrt{N}$ and significantly enhancing computational efficiency. This makes GSPN a robust, scalable framework that overcomes the key limitations of existing attention mechanisms by inherently capturing 2D spatial structures.
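As a rough formalization of the propagation sketched above (a sketch using notation consistent with the module description below; the exact formulation in the paper may differ in detail), a single directional scan updates the $i$-th row of hidden states as $h_i = w_i\, h_{i-1} + \lambda \odot x_i$, where $w_i$ is a narrow (tridiagonal) propagation matrix and $\lambda$ is an input-dependent gain. Under this reading, the Stability-Context Condition amounts to keeping the entries of $w_i$ non-negative with each row summing to one, so the norm of the accumulated product $w_i w_{i-1} \cdots w_1$ stays bounded (stability) while contributions from distant rows do not vanish (context). Because the scan is sequential only over the $\sqrt{N}$ rows, with all elements within a row updated in parallel, the effective sequence length is $\sqrt{N}$.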

GSPN module with 2D Linear Propagation


GSPN guarantees dense pairwise connections via a 3-way connection and 4-directional scanning. Each scanning direction corresponds to a lower-triangular affinity matrix, and the final full matrix is obtained through a learnable linear aggregation of the four. For the $i$-th row, each element of the hidden state $h_i$ is computed as the sum of: (1) a weighted sum of three neighboring values from the previous hidden row $h_{i-1}$, where the weights form a normalized tridiagonal matrix $w_i$, and (2) the element-wise product of the current input $x_i$ with $\lambda$. Both $w_i$ and $\lambda$ are learnable, input-dependent parameters. The weights in $w_i$ are obtained by applying a sigmoid activation followed by row-wise normalization. A code sketch of a single scan is given below.
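A minimal PyTorch sketch of a single top-to-bottom scan, assuming the unnormalized 3-neighbor weights and the gains $\lambda$ are produced elsewhere by input-dependent projections (the tensor shapes and loop structure here are illustrative assumptions, not the official implementation):

import torch
import torch.nn.functional as F

def gspn_scan_down(x, w_raw, lam):
    """One top-to-bottom line scan of 2D linear propagation (illustrative sketch).

    x:     (B, C, H, W) input features
    w_raw: (B, C, 3, H, W) unnormalized 3-neighbor weights (left, center, right)
    lam:   (B, C, H, W) input-dependent gains
    Returns hidden states of shape (B, C, H, W).
    """
    B, C, H, W = x.shape
    # Sigmoid then normalize the 3 weights per pixel so they are non-negative
    # and sum to 1 (the row-wise normalization described above).
    w = torch.sigmoid(w_raw)
    w = w / w.sum(dim=2, keepdim=True).clamp_min(1e-6)

    h_prev = torch.zeros(B, C, W, device=x.device, dtype=x.dtype)
    out = []
    for i in range(H):  # sequential only over rows: sqrt(N) steps
        # Three neighbors of the previous hidden row (zero-padded at the borders).
        left   = F.pad(h_prev, (1, 0))[..., :W]   # h_{i-1}[j-1]
        center = h_prev                           # h_{i-1}[j]
        right  = F.pad(h_prev, (0, 1))[..., 1:]   # h_{i-1}[j+1]
        neighbors = torch.stack([left, center, right], dim=2)        # (B, C, 3, W)
        h = (w[..., i, :] * neighbors).sum(dim=2) + lam[..., i, :] * x[..., i, :]
        out.append(h)
        h_prev = h
    return torch.stack(out, dim=2)  # (B, C, H, W)

The other three directions follow by flipping or transposing the input, and the four resulting affinity structures are combined by a learnable linear aggregation as described above.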

Quantitative Results


GSPN is a new sub-quadratic attention block tailored for vision, so it is crucial to benchmark it comprehensively to demonstrate its effectiveness and efficiency. To this end, we conduct extensive evaluations across a diverse range of visual tasks, including deterministic tasks such as ImageNet classification and generative tasks such as class-conditional generation (DiT) and text-to-image (T2I) generation.

Image Classification

Class-conditional Generation (400K iterations)

Text-to-image Generation

Qualitative Results


SD-v1.5

SDXL

Heatmap


The heatmaps present a comprehensive analysis of GSPN across the four directional scans, revealing pronounced anisotropic behavior. An additional aggregated heatmap provides a holistic view, synthesizing the directional maps into a unified representation that captures long-range context and dense pairwise connections through the 3-way connection mechanism.

BibTeX

@article{wang2025parallel,
    author    = {Wang, Hongjun and Byeon, Wonmin and Xu, Jiarui and Gu, Jinwei and Cheung, Ka Chun and Wang, Xiaolong and Han, Kai and Kautz, Jan and Liu, Sifei},
    title     = {Parallel Sequence Modeling via Generalized Spatial Propagation Network},
    journal   = {arXiv preprint arXiv:2501.12381},
    year      = {2025}
}

Acknowledgement

This web page is based on the template from Nerfies. We thank the authors for their great work.