I've been keeping a weekly work journal while at the Recurse Center to help me organize notes as I learn. These are essentially snapshots of what I've been looking at. A fellow recurser pointed out to me that these might be worth putting out into the world on their own, so this is an experiment in that.
This post is going to be more or less those raw notes, lightly edited.
This week, we have:
I extended my RC batch by another 6 weeks. Judging by the first half, 6 weeks is a lot less time than you’d think, so I’m writing this out to force myself to be a bit more picky about what I’m going to spend time on here. The main culprits are CS336 and ARENA, but I’d also like to have read through the original Circuits and Transformer Circuits threads, and to feel like I can read recent research papers and generally understand what's going on.
Paper reading goals:
There's basics (implementing), systems (GPUs), scaling (laws), data (gathering), and RL.
I may try to gloss over the GPU and data-gathering portions, but I do want to do the RL part more deeply. I feel somewhat torn about the GPU assignment: I think it would really sharpen my intuition on GPUs, but it’s probably a distraction and would take me quite a while. Assignment 3 seems pretty short and sweet. Assignment 4, on data cleanup, I will probably skim, then focus on assignment 5.
There's Chapter 1 (transformer interpretability), Chapter 2 (RL), and Chapter 3 (LLM evals).
I feel more interested in chapters 1 and 2, and may skim the work on evals.
Much of this content seems to overlap, which works in my favor, though more as a way to review the material than as a time savings. I would also like to spend some time writing and/or working on a final project to cap off this work.
Here’s a sample schedule to see if these goals are actually feasible:
Schedule:
It’s possible that I should spend more time on RL and put that earlier instead, depending on how far out of order the Stanford class can be taken. This seems vaguely doable, but it's still pretty aggressive scheduling.
Instead of looking at attention as a unit that scales out, this paper prefers to look at the transformer using the QK circuit and OV circuit, which it states are roughly independent.
The QK circuit is the one that makes up the attention scores (with softmax), i.e. what should the head attend to?
The OV circuit is the one that tells the attention head how much that attention should update the output logits. OV is made up of the values matrix and the output projection matrix that goes from d_head to d_model (i.e. dimension of the residual stream).
The destination token attends to the source token. The QK is a function of both, but OV is a function of only the source. In other words, the destination token is attending to the source by writing some information about the source to itself. Note that “tokens” here refer to their positions in the context window.
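To make this concrete, here's a toy sketch of the two circuits for a single head. This is my own code with made-up shapes and variable names, not the paper's notation, but it shows what each circuit is built from:

```python
import torch

# A minimal sketch (not the paper's code) of the two circuits for a single
# attention head, assuming weights with these made-up shapes:
#   W_E: (d_vocab, d_model)   token embedding
#   W_U: (d_model, d_vocab)   unembedding
#   W_Q, W_K, W_V: (d_model, d_head)
#   W_O: (d_head, d_model)    projection back to the residual stream
d_vocab, d_model, d_head = 1000, 64, 16
W_E = torch.randn(d_vocab, d_model)
W_U = torch.randn(d_model, d_vocab)
W_Q = torch.randn(d_model, d_head)
W_K = torch.randn(d_model, d_head)
W_V = torch.randn(d_model, d_head)
W_O = torch.randn(d_head, d_model)

# QK circuit: for each (destination token, source token) pair, how strongly
# does the destination want to attend to the source (before the softmax)?
qk_circuit = (W_E @ W_Q) @ (W_E @ W_K).T    # (d_vocab, d_vocab)

# OV circuit: if a source token is attended to, how does it move the output
# logits? Only the source token enters this path.
ov_circuit = W_E @ W_V @ W_O @ W_U          # (d_vocab, d_vocab)
```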
In a one-layer model, values can reach the model output either via the residual connection or by entering through one attention head.
Note that “trigram” is a bit odd here, because it’s still a prediction based only on the source and destination tokens. Neel mentions it might be better to call this a “skip bigram” instead.
One cool behavior we see is the model doing what looks like “normalizing” the tokenizer output, where some tokens are the ~same in some sense (“ Ralph”, “Ralph”, and “RALPH” should be the ~same conceptually, though not grammatically).
A lot of the heads seem to be doing copying behavior like above, where the destination token copies the source token into its predicted output (like the “perfect” example above).
A limit of skip-trigram behavior is that it produces some incorrect predictions too.
This seems to happen because the destination token only enters the QK circuit, not the OV one. So it can tell the head “where” to look, but not “what” to write, which means that step doesn’t always produce useful output. (The paper’s examples are along the lines of: having learned both “keep … in mind” and “keep … at bay”, the model will also put some probability on “keep … in bay”.)
The paper also does some analysis of eigenvectors/values with the intuition that:
| | Negative | Imaginary | Positive |
|---|---|---|---|
| QK Circuit | Avoid same-token attention | Other tokens | Prefer same-token attention |
| OV Circuit | Anti-copying | Other tokens | Copying |
It’s unclear to me exactly how this was computed or how generally useful it is, but the input/output vectors are on the vocab space, not the context window.
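My rough guess at one way this kind of eigenvalue check could be done (this is my sketch of the idea, not necessarily the paper's exact method; same made-up shapes as the earlier snippet):

```python
import torch

# Toy eigenvalue check for copying behavior in a head's OV circuit.
d_vocab, d_model, d_head = 1000, 64, 16
W_E = torch.randn(d_vocab, d_model)   # embedding
W_U = torch.randn(d_model, d_vocab)   # unembedding
W_V = torch.randn(d_model, d_head)
W_O = torch.randn(d_head, d_model)

# The vocab-space OV circuit W_E @ W_V @ W_O @ W_U is d_vocab x d_vocab, but
# the nonzero eigenvalues of A @ B match those of B @ A, so the small
# d_model x d_model product has the same nonzero spectrum.
small_ov = W_V @ W_O @ W_U @ W_E

# Mostly-positive eigenvalues: attending to a token boosts that same token's
# logit (copying); mostly-negative: anti-copying.
eigvals = torch.linalg.eigvals(small_ov)              # complex in general
copying_score = (eigvals.real.sum() / eigvals.abs().sum()).item()
print(f"copying score (~+1 copying, ~-1 anti-copying): {copying_score:.2f}")
```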
The paper then talks about the fact that when you go to 2-layer models, you have 3 paths: the residual stream, the individual attention heads, and the “virtual attention heads”, which are every combination of a layer-1 head composed with a layer-2 head (values that pass through one head in each layer).
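Roughly, as I understand it, the expansion looks something like this (my notation, not a faithful copy of the paper's, and ignoring that layer-1 outputs also change layer-2 attention patterns via Q/K composition):

```latex
T \;\approx\;
    \underbrace{\mathrm{Id} \otimes W_U W_E}_{\text{direct residual path}}
  + \underbrace{\sum_{h} A^{h} \otimes W_U W_{OV}^{h} W_E}_{\text{individual heads (either layer)}}
  + \underbrace{\sum_{h_2,\,h_1} \left(A^{h_2} A^{h_1}\right) \otimes W_U W_{OV}^{h_2} W_{OV}^{h_1} W_E}_{\text{virtual heads}}
```

Here T maps input tokens to output logits, A^h is a head's attention pattern (mixing across positions), W_OV^h is its combined value/output map (acting on the residual stream), and the virtual-head sum runs over layer-2 heads composed with layer-1 heads. The three groups of terms are the three paths above.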
There’s then some “term importance” math to decide whether the virtual heads matter; empirically, they use some tricks to see which terms affect the output loss the most. The individual attention heads seem most important, followed by the residual stream, followed by the virtual attention heads.
Following skip trigrams, the paper posits that the 2-layer model spends its time composing the layers and forming “induction heads” which guess the next output by trying to look at previous examples within the context window. There are a few interesting parts to this.
In-context learning: this is, in some sense, “learning” within the context window, because the model infers possible outcomes for the text based on its surrounding words.
Learned feature from composition: this behavior uses the previous layer to figure out what’s going on. It relies on a previous layer’s head that does “previous token lookup”.
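A toy way to check for induction-like behavior is to feed the model a random sequence repeated twice and see whether a head attends back to the token just after the previous occurrence of the current token. This is my own sketch, and `get_attention_pattern(tokens, layer, head)` is a stand-in for whatever hook/API your model exposes (assumed to return a (seq, seq) attention matrix):

```python
import torch

def induction_score(get_attention_pattern, layer, head, d_vocab=1000, T=50):
    # Random tokens repeated twice: an induction head at position i in the
    # second half should attend back to position i - T + 1, i.e. the token
    # *after* the earlier occurrence of tokens[i].
    tokens = torch.randint(0, d_vocab, (T,))
    tokens = torch.cat([tokens, tokens])               # length 2T
    attn = get_attention_pattern(tokens, layer, head)  # (2T, 2T), rows = dest
    dests = torch.arange(T, 2 * T)
    targets = dests - T + 1
    # Average attention mass on the "induction target"; ~1.0 means a strong
    # induction head, ~0.0 means no induction behavior.
    return attn[dests, targets].mean().item()
```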
No notes, but interesting overview of a tool they made at Anthropic to help with inspecting models. This sounds pretty neat to work on.
Rough idea is that they analyzed what context prefixes would trigger a given neuron in MLP layers in an LLM to get a sense for what each is trying to do. They took each neuron, found which inputs best activated that neuron, then tried to get humans to see if they did something interpretable in English.
More concretely,
Key-value pairs correspond to what individual neurons are doing.
For instance, in a transformer the key might be “military bases in” and the value distribution would be over places where bases might be. For MNIST, a key might be “inputs that have a closed loop” and the corresponding values would be digits like 0, 6, 8, or 9.
Even more concretely,
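Here's my own restatement of the key-value framing as code (not the paper's code; shapes and the ReLU are assumptions). Each hidden neuron has a “key” it matches against the input and a “value” it writes to the residual stream in proportion to its activation:

```python
import torch
import torch.nn.functional as F

d_model, d_mlp = 64, 256
W_in = torch.randn(d_mlp, d_model)    # row i = key k_i for neuron i
W_out = torch.randn(d_mlp, d_model)   # row i = value v_i for neuron i
x = torch.randn(d_model)              # residual stream at one position

activations = F.relu(W_in @ x)        # (d_mlp,) how well each key matches x
mlp_output = activations @ W_out      # weighted sum of the value vectors
# equivalently: sum_i activations[i] * W_out[i]
```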
With this, they saw that:
Note that the simpler behaviors here also feel a bit like the “skip trigrams” that the Anthropic mathematical framework paper proposes, but those live in the attention heads, not the MLP. Also, this paper doesn’t talk about in-context learning in the attention heads, whereas the MLP is doing the opposite: word-specific learning. How do these interact? It somewhat feels like they are learning similar things but in different ways.
They also sort of discuss that later layers seem more focused on output than the earlier layers.
They also suggest that MLPs contribute small parts to the residual stream, which get combined into the answer. This fits pretty well with our current mental model of the residual stream, which implies that vectors get sharpened while flowing through the network, with the residual stream used as a basis.
This seems to be the paper that one of the NGW2 talks (the one applying this to MNIST) was based on.
Modern GPUs tend to be memory-bound rather than compute-bound, because compute has scaled faster than memory bandwidth.
Major tricks:
Use all of the above tricks when calculating the attention head. In particular, reduce the cost of the sequence^2 attention scores by fusing the entire attention head into one kernel.
To do this, the matrices need to be tiled. Softmax is difficult because it requires computing across the entire row, but online softmax avoids this with a clever math trick, keeping a running calculation of the denominator.
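Here's a small sketch of the online softmax trick (my own toy version, not FlashAttention's actual kernel): process the scores in a streaming fashion, keeping a running max and a running denominator, and rescale the old denominator whenever a bigger max shows up.

```python
import math

def online_softmax_denominator(scores):
    m = float("-inf")   # running max seen so far
    d = 0.0             # running sum of exp(score - m)
    for s in scores:
        m_new = max(m, s)
        # Rescale the old partial sum to the new max, then add the new term.
        d = d * math.exp(m - m_new) + math.exp(s - m_new)
        m = m_new
    return m, d         # softmax(scores)[i] == exp(scores[i] - m) / d

# e.g. online_softmax_denominator([1.0, 3.0, 2.0]) matches the usual
# two-pass max-then-sum computation.
```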
FlashAttention also prefers recomputing certain parts (attention scores) during the backprop instead of storing/loading from memory.
FlashAttention is different from approximation-based fast-attention methods in that it is just as correct/precise as vanilla attention, but with less memory access and therefore better arithmetic intensity (how much math per byte moved).
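As a rough back-of-envelope (my numbers, just to illustrate why not materializing the score matrix matters): for one head with sequence length 2048 and head dim 64 in fp32, the full attention score matrix dwarfs Q, K, and V combined.

```python
# Approximate per-head memory for one example, in fp32.
N, d, bytes_per_float = 2048, 64, 4
scores_bytes = N * N * bytes_per_float       # ~16.8 MB for the N x N scores
qkv_bytes = 3 * N * d * bytes_per_float      # ~1.6 MB for Q, K, and V
print(scores_bytes / 1e6, qkv_bytes / 1e6)
```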
This was mostly review from assignment 1 of cs336 so I won’t have much here.
Keep the k best partial sequences sampled so far and, at the end, return the full output with the best loss. This makes it more likely to get a sequence with good overall loss, and it uncovers “hidden” sequences with good loss that greedy decoding would miss.
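A minimal sketch of that idea (my own toy version, not the course's reference code). `log_prob_next(prefix)` is a stand-in for the model: it's assumed to take a list of token ids and return a `(vocab,)` tensor of next-token log-probs, and `bos_id` / `eos_id` are assumed special-token ids:

```python
import torch

def beam_search(log_prob_next, bos_id, eos_id, k=4, max_len=50):
    beams = [([bos_id], 0.0)]              # (tokens, total log-prob so far)
    finished = []
    for _ in range(max_len):
        candidates = []
        for tokens, score in beams:
            logp = log_prob_next(tokens)   # (vocab,) next-token log-probs
            top_logp, top_ids = torch.topk(logp, k)
            for lp, tid in zip(top_logp.tolist(), top_ids.tolist()):
                candidates.append((tokens + [tid], score + lp))
        # Keep only the k best partial sequences; park any that hit EOS.
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for tokens, score in candidates[:k]:
            (finished if tokens[-1] == eos_id else beams).append((tokens, score))
        if not beams:
            break
    finished.extend(beams)                 # count unfinished beams too
    # "Best loss" here = highest total log-prob over the whole sequence.
    return max(finished, key=lambda c: c[1])
```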