Attention sampling speeds up the processing of large inputs by processing only a fraction of the input in high resolution. To achieve this, we use an attention distribution, predicted in low resolution, that indicates how useful or important each part of the image is.
The following paragraphs introduce the theory behind the model and describe how each part is implemented in this Python library. Full details can be found in our paper "Processing Megapixel Images with Deep Attention-Sampling Models", as well as in our poster and presentation.
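Firstly, we need to define what attention means in the context of neural networks. Given an input $x$ and features $f(x) \in \mathbb{R}^{K \times D}$, we aggregate those features using an attention distribution $a(x) \in \mathbb{R}_+^{K}$ such that $\sum_{i=1}^{K} a(x)_i = 1$, in the following way:
$$\Psi(x) = \sum_{i=1}^{K} f(x)_i \, a(x)_i = \mathbb{E}_{I \sim a(x)}\left[ f(x)_I \right]$$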
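As a minimal sketch of this definition, the NumPy snippet below uses random arrays as stand-ins for the outputs of the feature and attention networks. The helper names softmax and attention_aggregate are hypothetical and not part of the library; the point is only that $\Psi(x)$ is an attention-weighted sum over the $K$ positions.

```python
import numpy as np


def softmax(logits):
    """Turn arbitrary scores into a valid attention distribution."""
    z = np.exp(logits - logits.max())
    return z / z.sum()


def attention_aggregate(features, attention):
    """Compute Psi(x) = sum_i a(x)_i f(x)_i.

    features:  (K, D) array, one feature vector f(x)_i per position.
    attention: (K,) array, non-negative and summing to 1.
    """
    assert np.all(attention >= 0) and np.isclose(attention.sum(), 1.0)
    # (K,) @ (K, D) -> (D,): the attention-weighted average of the features,
    # i.e. the expectation of f(x)_I for I drawn from a(x)
    return attention @ features


rng = np.random.default_rng(0)
K, D = 6, 3
f = rng.normal(size=(K, D))       # stand-in for the feature network output
a = softmax(rng.normal(size=K))   # stand-in for the attention network output
psi = attention_aggregate(f, a)
```

With a one-hot attention distribution the result is simply the feature vector of the selected position.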
The above definition of attention does not use the "queries" and "keys" that are common in NLP-related papers; however, the query can be included in $a(\cdot)$ as an extra input, and the definition then becomes equally general. For our purposes, we assume a single, fixed query, namely: which feature positions in our input are the most useful?
The functions $a(\cdot)$ and $f(\cdot)$ defined above are implemented as neural networks, which we refer to as the attention network and the feature network, respectively.
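Under the assumption that $f(x)$ is expensive to compute, we aim to save computational resources by not computing the features for all positions. Instead, we can approximate $\Psi(x)$ using Monte Carlo sampling:
$$Q = \{ q_i \sim a(x) \mid i \in \{1, 2, \ldots, N\} \}$$
$$\Psi(x) \approx \frac{1}{N} \sum_{q \in Q} f(x)_q$$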
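A minimal NumPy sketch of this Monte Carlo estimate follows. It assumes the features can be computed one position at a time through a callable (feature_fn, a hypothetical name), and counts how many times that callable is invoked to make the savings visible: only $N$ of the $K$ feature vectors are ever computed.

```python
import numpy as np


def mc_attention_estimate(feature_fn, attention, n_samples, rng):
    """Approximate Psi(x) = E_{I ~ a(x)}[f(x)_I] with N i.i.d. samples.

    feature_fn: callable returning the feature vector f(x)_i for one
                position i (the expensive part we want to call rarely).
    attention:  (K,) array, the attention distribution a(x).
    """
    K = attention.shape[0]
    # Draw Q = {q_1, ..., q_N} from a(x), with replacement
    q = rng.choice(K, size=n_samples, replace=True, p=attention)
    # Only the sampled positions are ever passed to feature_fn
    feats = np.stack([feature_fn(i) for i in q])
    return feats.mean(axis=0), q


rng = np.random.default_rng(0)
K, D, N = 10_000, 4, 100
table = rng.normal(size=(K, D))  # stand-in for f(x); pretend it is expensive
calls = []


def feature_fn(i):
    calls.append(i)              # record every feature evaluation
    return table[i]


a = rng.random(K)
a /= a.sum()                     # a valid attention distribution a(x)
approx, q = mc_attention_estimate(feature_fn, a, N, rng)
exact = a @ table                # the exact Psi(x), which needs all K features
```

The estimate is unbiased for any $N$; increasing $N$ trades computation for lower variance.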
In our paper, we show that:
- the approximation above is optimal (of minimum variance) if we normalize the features
- we can derive an unbiased gradient estimator that uses only the samples $Q$ to train our models end-to-end
- we can derive similar unbiased estimators for sampling without replacement, which ensures that we compute each $f(x)_i$ at most once
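The exact without-replacement estimator is derived in the paper; the sketch below uses one standard unbiased construction, which may differ in form from the paper's. Positions $q_1, \ldots, q_N$ are drawn sequentially without replacement (each draw proportional to $a(x)$ over the positions not yet drawn), and the estimate is the average of $T_n = \sum_{k<n} a(x)_{q_k} f(x)_{q_k} + \left(1 - \sum_{k<n} a(x)_{q_k}\right) f(x)_{q_n}$. Given the earlier draws, each $T_n$ has expectation $\Psi(x)$. All function names are hypothetical; exact_expectation enumerates every ordered sample to check unbiasedness on a tiny example.

```python
import itertools

import numpy as np


def sample_without_replacement(attention, n, rng):
    """Draw n distinct positions; each draw is proportional to a(x)
    restricted to the positions not yet drawn."""
    p = np.asarray(attention, dtype=float).copy()
    picks = []
    for _ in range(n):
        i = int(rng.choice(p.size, p=p / p.sum()))
        picks.append(i)
        p[i] = 0.0  # remove the drawn position from future draws
    return picks


def wor_estimate(feature_fn, attention, q):
    """Average of T_n = sum_{k<n} a_{q_k} f_{q_k} + (1 - sum_{k<n} a_{q_k}) f_{q_n}."""
    total = 0.0
    seen_sum = 0.0   # sum_{k<n} a_{q_k} f_{q_k}
    seen_mass = 0.0  # sum_{k<n} a_{q_k}
    for qn in q:
        f_qn = feature_fn(qn)  # each sampled position is evaluated once
        total = total + seen_sum + (1.0 - seen_mass) * f_qn
        seen_sum = seen_sum + attention[qn] * f_qn
        seen_mass += attention[qn]
    return total / len(q)


def exact_expectation(features, attention, n):
    """E[wor_estimate], computed by enumerating every ordered sample."""
    total = np.zeros(features.shape[1])
    for seq in itertools.permutations(range(len(attention)), n):
        prob, mass = 1.0, 0.0
        for i in seq:
            prob *= attention[i] / (1.0 - mass)
            mass += attention[i]
        total += prob * wor_estimate(lambda j: features[j], attention, seq)
    return total


rng = np.random.default_rng(0)
K, D, N = 5, 2, 3
f = rng.normal(size=(K, D))
a = rng.random(K) + 0.1
a /= a.sum()
q = sample_without_replacement(a, N, rng)
estimate = wor_estimate(lambda j: f[j], a, q)
```

For $N = 1$ the estimator reduces to the plain Monte Carlo estimate $f(x)_{q_1}$.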
In most cases, the attention weights $a(x)$ are a function of the features $f(x)$; thus, in order to compute the attention distribution and sample feature positions, we would need to compute all the features, which makes our approximation pointless.
Instead, we propose computing the attention in low resolution directly from the input, which is consequently much faster than computing all the features. The full pipeline for images is depicted in Figure 1.
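The pipeline can be sketched in NumPy under heavy simplifying assumptions: toy_attention and toy_features are hypothetical stand-ins for the attention and feature networks, the low-resolution view is an average-pooled copy of the image, each sampled low-resolution pixel is mapped to a patch centred on the corresponding block of the full image, and the output is the Monte Carlo average of the patch features. None of these helpers belong to the ats library.

```python
import numpy as np


def downsample(x, s):
    """Average-pool an (H, W, C) image by an integer factor s."""
    H, W, C = x.shape
    x = x[: H - H % s, : W - W % s]  # drop rows/cols that don't fill a block
    return x.reshape(H // s, s, W // s, s, C).mean(axis=(1, 3))


def toy_attention(x_low):
    """Hypothetical attention network: softmax of mean intensity per pixel."""
    logits = x_low.mean(axis=-1).ravel()
    z = np.exp(logits - logits.max())
    return z / z.sum()


def toy_features(patch):
    """Hypothetical feature network: per-channel mean of a patch."""
    return patch.mean(axis=(0, 1))


def extract_patch(x, center, size):
    """Crop a size x size patch around `center`, shifted to stay inside x."""
    H, W, _ = x.shape
    r = int(np.clip(center[0] - size // 2, 0, H - size))
    c = int(np.clip(center[1] - size // 2, 0, W - size))
    return x[r : r + size, c : c + size]


def attention_sampling_sketch(x, s, n_patches, patch_size, rng):
    x_low = downsample(x, s)                     # cheap low-resolution view
    a = toy_attention(x_low)                     # attention over low-res pixels
    w_low = x_low.shape[1]
    q = rng.choice(a.size, size=n_patches, p=a)  # sampled low-res positions
    feats = []
    for idx in q:
        i, j = divmod(int(idx), w_low)
        # Map the low-res pixel to the centre of its s x s block in x
        center = (i * s + s // 2, j * s + s // 2)
        feats.append(toy_features(extract_patch(x, center, patch_size)))
    # Monte Carlo estimate: average of the sampled patch features
    return np.mean(feats, axis=0), q


rng = np.random.default_rng(0)
x = rng.random((512, 512, 3))  # stand-in for a large high-resolution image
psi_hat, q = attention_sampling_sketch(x, s=8, n_patches=10, patch_size=32, rng=rng)
```

Only n_patches crops of the full-resolution image reach the feature function; the attention itself is computed entirely on the downsampled view.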
The above pipeline is implemented by the function ats.core.attention_sampling(...), which accepts as parameters the attention network, the feature network, and the size and number of the patches to be sampled. See the API documentation for more details.