This is the first post in a series of posts for the year 2024 that I’m calling “A Paper a Week”. My goal is to read an academic paper each week of the year in an effort to gain a deeper understanding of the current state of AI and to increase the “academia” in my otherwise professional life. Along with each of these papers I plan on writing a blog post to accompany it. I’m not exactly sure what form the blog posts should take - I think at the very least they will contain a summary of the paper along with any thoughts I have after reading (as well as a fun DALL-E image generated by prompting with the title of the paper).

Introduction

To start this series, I read a very topical and relevant paper: “Attention Is All You Need” by Vaswani et al., otherwise known as the paper credited with creating the transformer architecture. The transformer has quickly taken over almost every field of deep learning - computer vision, natural language processing, and even reinforcement learning - due to its incredible ability to ingest sequential data and maintain long-range dependencies and information states far better than any previous sequence-based architecture.

Summary

The motivating problem for the transformer architecture was finding a way to scale sequential models. In NLP, the dominant architectures before the transformer were the RNN (Recurrent Neural Network), the LSTM (Long Short-Term Memory), and the GRU (Gated Recurrent Unit). These architectures, in one way or another, were structured around processing the input, x, from the current timestep as well as all previous timesteps using a context state, c. While these models can perform well with short sequences, they are, by nature, sequential, which means they can’t parallelize computation well, and more importantly, as the sequence length grows, the information storage capability of the context state, or hidden state, quickly diminishes. In other words, as the sequence length, N, increases, the sequential computation scales with N and the amount of information that has to be stored in the context state also scales with N. This leads to an information bottleneck in the hidden state due to the variable-length input and the fixed-length nature of the hidden state. A real-world analogy would be continuously zooming out on an image. As the amount of content we are trying to store in the image grows, the pixels we have available to represent the image data remain fixed. Thus, objects in the image will get blurrier since they have fewer pixels available to represent their shape or color (i.e. information loss).
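To make the bottleneck concrete, here is a minimal toy RNN cell (not the paper’s code, and the weights/sizes are made up for illustration). Notice that the hidden state h has a fixed size no matter how long the input sequence is, and that each step depends on the previous one, so the loop can’t be parallelized:

```python
import numpy as np

# Toy RNN cell: the hidden state h has a fixed size (here 8 floats),
# no matter how long the input sequence grows.
rng = np.random.default_rng(0)
hidden_size, input_size = 8, 4
W_h = rng.normal(scale=0.1, size=(hidden_size, hidden_size))
W_x = rng.normal(scale=0.1, size=(hidden_size, input_size))

def rnn_forward(sequence):
    h = np.zeros(hidden_size)  # fixed-length context state
    for x in sequence:         # sequential: step t depends on step t-1
        h = np.tanh(W_h @ h + W_x @ x)
    return h

short_seq = rng.normal(size=(5, input_size))
long_seq = rng.normal(size=(500, input_size))

# Both summaries are the same shape: 500 timesteps must be squeezed
# into the same 8 numbers as 5 timesteps (the information bottleneck).
print(rnn_forward(short_seq).shape)  # (8,)
print(rnn_forward(long_seq).shape)   # (8,)
```

This is the “zoomed-out image” analogy in code: the sequence can grow without bound, but the context it gets compressed into cannot.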

To solve this problem, Vaswani et al. ditched the recurrent structure entirely and decided to rely solely on the attention mechanism. This enabled the model to create global dependencies between input and output tokens (i.e. no more fixed-length context/hidden state) as well as parallel processing of inputs. The actual architecture used was an encoder-decoder architecture, which was already established at a high level with RNNs and LSTMs, where an input sequence, x, is encoded into a continuous sequential representation, z, and then decoded back into an output sequence, y, one element at a time. However, a lot of different mechanisms happen under the hood in the transformer architecture, including a positional encoding (so that the model can maintain a sense of the order of inputs), residual connections around each attention sub-layer, and finally multi-head attention, which is similar to using multiple channels in a CNN. In other words, more complex representations can be learned since each attention head can learn a different task, similar to stacking convolution channels in CV tasks.
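The core of all of this is the scaled dot-product attention from the paper: Attention(Q, K, V) = softmax(QK^T / sqrt(d_k))V. Here is a minimal NumPy sketch (toy sizes, single head, no masking or learned projections) just to show how every token’s output is built from all tokens at once, with no recurrence:

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V):
    """Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)  # (N, N): every token scores every token
    # Row-wise softmax (subtract the max for numerical stability).
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V               # each output is a weighted mix of ALL values

N, d_k = 6, 16  # 6 tokens, 16-dim queries/keys/values (toy sizes)
rng = np.random.default_rng(1)
Q, K, V = (rng.normal(size=(N, d_k)) for _ in range(3))
out = scaled_dot_product_attention(Q, K, V)
print(out.shape)  # (6, 16): one output per token, computed in parallel
```

In the real architecture, Q, K, and V come from learned linear projections of the token embeddings, and multi-head attention just runs several of these in parallel on smaller projections and concatenates the results.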

Of course, there are also limitations to using attention. Namely, one has to set a maximum input size, N, of tokens to be processed by the model in a forward pass. This limit sets the length of long-range dependencies that can be learned, since it represents a hard cutoff for what information can be observed by the model. However, in practice, these maximum input sizes are much larger than what RNNs are capable of handling accurately, which is why the architecture is so dominant comparatively. Plus, theoretically, the input size, N, could be extremely large (even infinite); it would just mean that the computational cost of the model would be extremely high, since self-attention compares every token with every other token and thus scales quadratically with N.

Overall, the transformer architecture represents a more efficient, and more accurate, way of processing sequential data than any of the previous leading architectures in NLP.

Questions/Notes I Have

  1. I want to dig deeper into self-attention and learn what motivated it in the first place.
  2. I wanted a better understanding of how inputs are actually being processed in attention layers, so I found a pretty cool resource that implements self-attention in PyTorch, and I also asked ChatGPT (which, ironically, gave me the best answer).