Neighborhood Attention: dynamic restriction of self attention
Ali Hassani
Committee: Humphrey Shi (chair), Thien Nguyen, Thanh Nguyen
Area Exam (Jul 2023)
Keywords: Self-attention, computer vision

Transformers, and attention-based models more generally, are ubiquitous in modern deep learning, dominating a wide range of applications spanning language, vision, and speech. Self attention, one of the primary operators in these models, is often cited for its quadratic complexity with respect to input size. This cost has often been avoided through local and sparse attention patterns, which can be effective but usually either discard useful properties and inductive biases or are difficult to implement and scale.
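To make the complexity gap concrete, here is a worked comparison. With n tokens of dimension d, self attention forms an n × n matrix of attention weights, while a local pattern in which each query sees only k keys forms an n × k one. The 56 × 56 feature map and 7 × 7 window below are illustrative choices, not figures reported in this work.

```latex
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d}}\right) V,
\qquad Q, K, V \in \mathbb{R}^{n \times d}

\text{Self attention: } O(n^2 d), \qquad \text{local window of } k \text{ keys per query: } O(n k d)

n = 56 \times 56 = 3136,\; k = 7 \times 7 = 49:\quad
n^2 = 9{,}834{,}496 \;\text{ vs. }\; nk = 153{,}664 \quad (\text{a factor of } n/k = 64)
```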

We propose neighborhood attention, a restriction of self attention to each token's nearest neighbors, which has linear time complexity for a fixed neighborhood size and can retain many of the properties of self attention. This pattern can be thought of as a flexible sliding window, aimed at capturing consistent local context throughout the input. Previous attempts at sliding window attention were held back by the relative difficulty of implementing such patterns efficiently on parallel hardware, which dominates deep learning training. To that end, we propose and implement a range of algorithms for neighborhood attention, starting with naive implementations.
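A minimal sketch of the idea in plain NumPy, in the spirit of a naive implementation. It assumes a single head, one spatial dimension, an odd window size, and no dilation or positional bias. These simplifications, and the function name `neighborhood_attention_1d`, are hypothetical choices for illustration and not NATTEN's API. Near the borders the window is shifted to stay inside the input rather than padded, so every query attends to the same number of keys.

```python
import numpy as np

def neighborhood_attention_1d(q, k, v, kernel_size):
    """Naive 1D neighborhood attention (single head, no dilation, no bias).

    q, k, v: float arrays of shape (n, d). kernel_size: odd int <= n.
    Each query attends to exactly `kernel_size` keys. Near the borders the
    window is shifted (clamped) instead of zero-padded.
    """
    n, d = q.shape
    assert kernel_size % 2 == 1 and kernel_size <= n
    half = kernel_size // 2
    scale = 1.0 / np.sqrt(d)
    out = np.empty_like(v)
    for i in range(n):
        # Window nominally centered on i, clamped to lie inside [0, n).
        start = min(max(i - half, 0), n - kernel_size)
        idx = np.arange(start, start + kernel_size)
        logits = (k[idx] @ q[i]) * scale          # (kernel_size,)
        weights = np.exp(logits - logits.max())   # numerically stable softmax
        weights /= weights.sum()
        out[i] = weights @ v[idx]                 # (d,)
    return out

# Usage: 16 tokens with 8 channels, window of 5 neighbors per query.
rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((16, 8)) for _ in range(3))
print(neighborhood_attention_1d(q, k, v, kernel_size=5).shape)  # (16, 8)
```

Each query computes `kernel_size` dot products instead of n, so the cost grows linearly with n for a fixed window. Setting `kernel_size` equal to n makes every window cover the whole input, which recovers full self attention.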

We then formulate neighborhood attention as an implicit general matrix-matrix multiplication (GEMM) problem and implement its kernels in CUTLASS. This implementation improves latency by up to 6X on 1D problems and up to 4X on 2D problems compared to naive GPU kernels. We package and release all of our implementations as a Python package, NATTEN, which allows researchers to quickly set up and use neighborhood attention. Finally, we demonstrate some of the many possible applications of neighborhood attention by introducing a family of hierarchical vision transformers, the Neighborhood Attention Transformer (NAT), and report their performance on image classification, object detection, and image segmentation.
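The abstract does not detail the implicit GEMM formulation itself. The sketch below is a simplified illustration only. It uses plain NumPy in 1D with a single head and no dilation, it is not the NATTEN CUTLASS kernels, and the function names are hypothetical. It shows one reason neighborhood attention can be mapped onto dense matrix multiplies. Each window's start position is non-decreasing along the input and advances by at most one position per query, so a tile of T consecutive queries touches at most T + kernel_size - 1 consecutive keys. Each tile can therefore be computed as a small dense GEMM followed by a mask.

```python
import numpy as np

def window_start(i, n, kernel_size):
    # Clamped start of query i's neighborhood (no dilation).
    return min(max(i - kernel_size // 2, 0), n - kernel_size)

def na_1d_tiled(q, k, v, kernel_size, tile=4):
    """Illustrative tiled neighborhood attention built from dense GEMMs.

    For each tile of consecutive queries, multiply against the contiguous
    span of keys their windows cover, then mask keys outside each query's
    own neighborhood before the softmax.
    """
    n, d = q.shape
    scale = 1.0 / np.sqrt(d)
    out = np.empty_like(v)
    for t0 in range(0, n, tile):
        t1 = min(t0 + tile, n)
        # Key span covered by this tile: at most (t1 - t0) + kernel_size - 1 keys.
        k0 = window_start(t0, n, kernel_size)
        k1 = window_start(t1 - 1, n, kernel_size) + kernel_size
        logits = (q[t0:t1] @ k[k0:k1].T) * scale      # dense (T, span) GEMM
        # Key j is valid for query i iff start(i) <= j < start(i) + kernel_size.
        starts = np.array([window_start(i, n, kernel_size) for i in range(t0, t1)])
        cols = np.arange(k0, k1)
        valid = (cols[None, :] >= starts[:, None]) & (
            cols[None, :] < starts[:, None] + kernel_size
        )
        logits = np.where(valid, logits, -np.inf)
        weights = np.exp(logits - logits.max(axis=1, keepdims=True))
        weights /= weights.sum(axis=1, keepdims=True)
        out[t0:t1] = weights @ v[k0:k1]               # second dense GEMM
    return out
```

With `tile=1` this degenerates to one tiny product per query, much like a naive kernel. Larger tiles trade a few masked-out entries for denser matrix multiplies, the kind of work that GEMM libraries such as CUTLASS are designed to run efficiently.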