Neural Architecture Search (NAS) has achieved great success in finding optimal architectures for specific tasks by selecting the best operations across different layers. In the era of Large Language Models (LLMs), however, people tend to use a single model with a unified architecture, e.g., the transformer, and apply it to diverse tasks. This does not mean that NAS cannot contribute to LLM development. At its core, an operation simply describes how one token interacts with other tokens, so the operations in a search space need not be limited to the layer level; they can be defined at the token level instead.
In this blog post, we present Neural Attention Search (NAtS), a framework that designs the search space at this finer-grained, token level. We show that NAtS can find a more efficient transformer model while preserving most of its performance.
Token-Level Search Space Design
Softmax attention first computes the correlations between different tokens. It then applies an attention mask that controls whether one specific token may attend to another. A commonly used attention mask is the causal mask, which prevents earlier tokens from receiving information from later tokens. Each column of the attention mask thus describes how long one token remains visible to the other tokens, so we can view each column as an operation applied to that token.
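To make this concrete, the snippet below (a plain illustration, not NAtS itself) builds a standard causal mask in PyTorch and reads out one of its columns: under the causal mask, every token stays visible to all later queries until the end of the sequence.

```python
import torch

# A causal mask for a sequence of length 5. Entry (i, j) is True when
# query token i may attend to key token j.
seq_len = 5
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
print(causal_mask)

# Column j describes the "lifetime" of token j: under the causal mask,
# every token stays visible to all later queries.
print(causal_mask[:, 2])  # tensor([False, False,  True,  True,  True])
```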

Instead of drawing from different architectural families (e.g., convolutional or recurrent layers), we design a search space composed entirely of softmax-attention variants. More specifically, we define three token types: (i) global tokens, which are preserved until the very end of the sequence; (ii) local tokens, which are preserved only until the next global token appears; and (iii) sliding-window tokens, which are preserved for a fixed number of time steps. This search space is applied to every token in the input sequence to generate unique attention masks that best fit that sequence.
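The sketch below shows one way such a mask could be assembled from per-token type assignments. The type indices, the window size, and the exact treatment of local tokens at the position of the next global token are our assumptions for illustration and may differ from the construction used in the paper.

```python
import torch

GLOBAL, LOCAL, SLIDING = 0, 1, 2

def build_token_type_mask(token_types: torch.Tensor, window: int) -> torch.Tensor:
    """Build a causal attention mask where column j encodes token j's lifetime.

    token_types: (seq_len,) tensor with values GLOBAL / LOCAL / SLIDING.
    Returns a boolean (seq_len, seq_len) mask; entry (i, j) is True when
    query i may attend to key j.
    """
    seq_len = token_types.shape[0]
    mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)
    for j in range(seq_len):
        if token_types[j] == GLOBAL:
            end = seq_len                           # visible until the very end
        elif token_types[j] == SLIDING:
            end = min(j + window, seq_len)          # visible for `window` steps
        else:  # LOCAL: visible until (and including) the next global token
            later = (token_types[j + 1:] == GLOBAL).nonzero()
            end = j + 2 + int(later[0]) if len(later) > 0 else seq_len
        mask[j:end, j] = True                       # causal: only queries i >= j
    return mask

# Example: tokens 0 and 3 are global, token 1 is local, tokens 2 and 4 slide.
types = torch.tensor([GLOBAL, LOCAL, SLIDING, GLOBAL, SLIDING])
print(build_token_type_mask(types, window=2).int())
```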
Searching for the Optimal Attention Maps
We use a linear layer (the scorer layer) that maps the attention input features to a set of scores determining each token's attention type: the operation with the highest score is applied to that token. As in one-shot supernet and once-for-all NAS approaches, all attention operations share the same set of weights. We then jointly optimize the model and scorer-layer parameters using gradient descent.
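A minimal sketch of such a scorer layer is shown below, using a straight-through argmax so that the discrete type selection remains trainable with gradient descent; the exact relaxation used in NAtS may differ.

```python
import torch
import torch.nn as nn

NUM_TOKEN_TYPES = 3  # global, local, sliding-window

class TokenTypeScorer(nn.Module):
    """Sketch of a scorer layer: maps each token's attention input features to
    per-type scores and picks the highest-scoring type, with a straight-through
    estimator so the hard choice stays differentiable."""

    def __init__(self, hidden_dim: int):
        super().__init__()
        self.proj = nn.Linear(hidden_dim, NUM_TOKEN_TYPES)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        scores = self.proj(x)                                # (batch, seq, 3)
        probs = scores.softmax(dim=-1)
        hard = torch.zeros_like(probs).scatter_(
            -1, probs.argmax(dim=-1, keepdim=True), 1.0)     # one-hot argmax
        # Straight-through: forward uses the hard choice, backward uses probs.
        return hard + probs - probs.detach()

scorer = TokenTypeScorer(hidden_dim=64)
x = torch.randn(1, 5, 64)
token_type_one_hot = scorer(x)  # (1, 5, 3), one-hot per token in the forward pass
```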
To better trade off accuracy and efficiency, we introduce a regularization term that is applied directly to the gradient of each operation type, encouraging more sliding-window tokens. In the end, the scorer layer automatically determines the optimal token type for each input token. This yields an end-to-end sparse attention model that focuses on the tokens most likely to provide useful information in the context. Hence, NAtS can be applied to pre-train a new model from scratch or to approximate the output of an existing transformer model.
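One hypothetical way to realize such a gradient-level regularization is a backward hook that biases the score gradients toward the sliding-window type. The constant `REG_STRENGTH` and the hook-based form below are our assumptions, not the paper's exact formulation; `scorer` and `x` are reused from the sketch above.

```python
import torch

SLIDING = 2          # index of the sliding-window type, as assumed above
REG_STRENGTH = 1e-3  # assumed regularization strength

def sparsity_grad(grad: torch.Tensor) -> torch.Tensor:
    # Add a constant penalty to the gradients of global/local scores and a
    # bonus to the sliding-window score, so gradient descent drifts toward
    # selecting more sliding-window tokens.
    reg = torch.full_like(grad, REG_STRENGTH)
    reg[..., SLIDING] = -REG_STRENGTH
    return grad + reg

scores = scorer.proj(x)              # (1, 5, 3) raw scores from the scorer layer
scores.register_hook(sparsity_grad)  # applied during backward() of the training loss
```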
By removing tokens that can no longer influence the following time steps, we can efficiently reduce the time required for pre-filling and decoding while keeping the memory cost of the KV cache low. In experiments, we show that NAtS preserves most of a full transformer's capability with a much smaller KV cache.
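The following sketch illustrates how a KV cache could be pruned according to the searched token types during decoding; the function name and eviction rules are assumptions derived from the token-type definitions above, not the paper's implementation.

```python
import torch

GLOBAL, LOCAL, SLIDING = 0, 1, 2

def prune_kv_cache(keys, values, token_types, positions, current_step, window):
    """Keep global tokens forever, drop local tokens once a newer global token
    exists, and drop sliding-window tokens older than `window` decoding steps."""
    has_global = (token_types == GLOBAL).any()
    last_global = positions[token_types == GLOBAL].max() if has_global else -1
    keep = (
        (token_types == GLOBAL)
        | ((token_types == LOCAL) & (positions > last_global))
        | ((token_types == SLIDING) & (positions > current_step - window))
    )
    return keys[keep], values[keep], token_types[keep], positions[keep]

# Example: a cache of 5 tokens at decoding step 5 with window size 2.
types = torch.tensor([GLOBAL, LOCAL, SLIDING, GLOBAL, SLIDING])
pos = torch.arange(5)
k, v = torch.randn(5, 8), torch.randn(5, 8)
k, v, types, pos = prune_kv_cache(k, v, types, pos, current_step=5, window=2)
print(pos)  # positions of the tokens still in the cache: tensor([0, 3, 4])
```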
For further details, please check out our paper Neural Attention Search.