Implementing and testing $\mathrm{grokfast}$ for LLMs on small models trained on wikitext.
Quick summary:
- $\mathrm{grokfast}$ doesn't help performance on wikitext, even after $260$ epochs
- It also doesn't hurt performance
- Therefore, it might be helpful for improving performance on algorithmic tasks without hurting it on regular language tasks (though at the cost of an increased memory footprint)
There are two hyperparameters: alpha ($\alpha$), the decay of an exponential moving average (EMA) of the gradients, and gain ($\lambda$), which controls how strongly that EMA is added back onto the current gradient.
Mathematically, this works like this:

$$\mu_t = \alpha \, \mu_{t-1} + (1 - \alpha) \, g_t, \qquad \hat{g}_t = g_t + \lambda \, \mu_t$$

where $g_t$ is the gradient at step $t$, $\mu_t$ is the exponential moving average of the gradients, and $\hat{g}_t$ is the gradient that actually gets passed to the optimizer.
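For intuition: a gradient component that points the same way step after step drives the EMA toward that same value, so it ends up amplified by roughly a factor of $1 + \lambda$, while a component that flips sign every step largely cancels out in the EMA and passes through almost unchanged. A minimal numeric sketch of this (my own illustration, not part of the experiments below):

```python
# Sketch: how the EMA update treats a slow (constant) vs. a fast (sign-flipping) gradient.
alpha, gain = 0.9, 2.0

def amplified(grads, alpha, gain):
    ema = 0.0
    for g in grads:
        ema = alpha * ema + (1 - alpha) * g
    return grads[-1] + gain * ema  # the last gradient, after amplification

slow = [1.0] * 1000                                         # constant component
fast = [1.0 if t % 2 == 0 else -1.0 for t in range(1001)]   # sign-flipping component

print(amplified(slow, alpha, gain))  # ~3.0  -> scaled by roughly (1 + gain)
print(amplified(fast, alpha, gain))  # ~1.1  -> left almost unchanged
```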
Here is how I've implemented it (done before the paper's code was released):
import torch


@torch.no_grad()
def grokfast(net: torch.nn.Module, alpha: float, gain: float) -> None:
    for parameter in net.parameters():
        if parameter.grad is None:
            continue  # skip parameters that didn't receive a gradient
        if not hasattr(parameter, "grad_ema"):
            # First call: initialize the EMA with a float32 copy of the gradient.
            parameter.grad_ema = torch.empty_like(
                parameter.grad, dtype=torch.float32
            ).copy_(parameter.grad)
        else:
            # Update the EMA of the gradients.
            parameter.grad_ema = alpha * parameter.grad_ema + (1 - alpha) * parameter.grad
        # Amplify the gradient with the slow-moving EMA.
        parameter.grad += gain * parameter.grad_ema
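For reference, here is a minimal sketch of where this would be called in a standard PyTorch training loop (the toy model, data, and hyperparameters are placeholders, not the actual hlb-gpt-based setup used for the experiments below): after `loss.backward()` and before `optimizer.step()`, so that the optimizer sees the amplified gradients.

```python
import torch

# Hypothetical toy setup, purely to show where grokfast() fits.
model = torch.nn.Linear(16, 16)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

for step in range(100):
    x = torch.randn(8, 16)
    loss = model(x).pow(2).mean()

    optimizer.zero_grad()
    loss.backward()
    grokfast(model, alpha=0.9, gain=0.1)  # amplify the slow gradients in place
    optimizer.step()
```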
The authors recommend using
They show that this dramatically accelerates grokking on algorithmic tasks such as modular addition.
Here is an LLM training run without $\mathrm{grokfast}$, compared to runs with it at various settings of alpha and gain:
Runs where alpha and gain are high tend to fare worse. To quantify this, here is the average validation loss for each setting (a sketch of how this aggregation might look follows the table):
alpha | gain | grokfast | mean val loss |
---|---|---|---|
0.0 | 0.0 | False | 3.20 |
0.9 | 0.5 | True | 3.21 |
0.9 | 0.1 | True | 3.22 |
0.8 | 0.1 | True | 3.23 |
0.8 | 0.5 | True | 3.23 |
0.9 | 2.0 | True | 3.26 |
0.8 | 2.0 | True | 3.30 |
0.99 | 0.1 | True | 3.54 |
0.9 | 5.0 | True | 3.90 |
0.99 | 0.5 | True | 4.39 |
0.8 | 5.0 | True | 4.67 |
0.99 | 2.0 | True | 6.11 |
0.99 | 5.0 | True | 6.34 |
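The aggregation behind this table is nothing special. A minimal sketch with pandas, assuming the per-run results have been collected as records like the ones below (the field names and values are made up for illustration, not the actual logging format of these experiments):

```python
import pandas as pd

# Hypothetical per-run records; in reality these would come from the training logs.
runs = [
    {"alpha": 0.0, "gain": 0.0, "grokfast": False, "val_loss": 3.19},
    {"alpha": 0.0, "gain": 0.0, "grokfast": False, "val_loss": 3.21},
    {"alpha": 0.9, "gain": 0.5, "grokfast": True, "val_loss": 3.22},
    # ... one record per run ...
]

summary = (
    pd.DataFrame(runs)
    .groupby(["alpha", "gain", "grokfast"], as_index=False)["val_loss"]
    .mean()
    .sort_values("val_loss")
)
print(summary)
```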
Clearly, not using $\mathrm{grokfast}$ comes out ahead here, though the best $\mathrm{grokfast}$ settings (alpha at 0.8 or 0.9 with a small gain) are only barely behind.
It's possible that the best of these settings still matches or even beats the run without $\mathrm{grokfast}$. Any difference, however, is small enough to be explained by random noise. Is it better at all? Let's look at the stats again:
alpha | gain | grokfast | mean val loss |
---|---|---|---|
0.9 | 0.1 | True | 3.23 |
0.8 | 0.5 | True | 3.24 |
0.9 | 0.5 | True | 3.24 |
0.8 | 0.1 | True | 3.25 |
0.0 | 0.0 | False | 3.27 |
0.9 | 2.0 | True | 3.28 |
0.8 | 2.0 | True | 3.33 |
0.99 | 0.1 | True | 3.67 |
0.9 | 5.0 | True | 4.65 |
0.99 | 0.5 | True | 4.77 |
0.8 | 5.0 | True | 5.25 |
0.99 | 2.0 | True | 6.11 |
0.99 | 5.0 | True | 6.34 |
The ordering shifts a little and a few $\mathrm{grokfast}$ settings now come out slightly ahead of the baseline, but the differences at the top of the table are tiny.
Is it any better in the first ~2 epochs?
Again: no. The two runs look identical, and the statistics tell the same story:
alpha | gain | grokfast | mean val loss |
---|---|---|---|
0.9 | 0.1 | True | 4.89 |
0.0 | 0.0 | False | 4.90 |
0.8 | 0.1 | True | 4.90 |
0.8 | 0.5 | True | 4.90 |
0.9 | 0.5 | True | 4.92 |
0.8 | 2.0 | True | 5.00 |
0.9 | 2.0 | True | 5.01 |
0.99 | 0.1 | True | 5.28 |
0.99 | 0.5 | True | 6.10 |
0.9 | 5.0 | True | 6.37 |
0.8 | 5.0 | True | 6.47 |
0.99 | 2.0 | True | 6.97 |
0.99 | 5.0 | True | 7.17 |
On the one hand, $\mathrm{grokfast}$ doesn't improve performance on wikitext at any of the settings I tried, even after $260$ epochs.

This isn't super surprising; wikitext cannot really be reduced to a simple algorithm, the way that modular addition can, so what would there even be to grok?

On the other hand, moderate settings of alpha and gain don't hurt performance either, so $\mathrm{grokfast}$ could plausibly be used to improve performance on algorithmic tasks without degrading it on regular language tasks.
That seems like a big win!
However, it comes at the cost of keeping an EMA of every gradient in memory, which is significant: with the implementation above, that's one extra float32 value per parameter.
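To put a rough number on that, here is a small back-of-the-envelope helper (my own sketch, not part of the package below) that estimates the extra memory the `grad_ema` buffers would occupy for a given model:

```python
import torch

def grad_ema_overhead_bytes(net: torch.nn.Module) -> int:
    # One float32 value per trainable parameter, as in grokfast() above.
    return sum(p.numel() * 4 for p in net.parameters() if p.requires_grad)

# Example with an arbitrary small model:
model = torch.nn.Transformer(d_model=256, nhead=4, num_encoder_layers=4, num_decoder_layers=4)
print(f"{grad_ema_overhead_bytes(model) / 2**20:.1f} MiB of extra gradient-EMA state")
```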
The code for these experiments is available as a package.
This package is based on my own hlb-gpt-cli, which is in turn based on Fern's hlb-gpt, which can be cited in the following way:
cff-version: 1.2.0
message: "Citations would be appreciated if you end up using this tool! I currently go by Fern, no last name given."
authors:
  - given-names: "Fern"
title: "hlb-gpt"
version: 0.4.0
date-released: 2023-03-05
url: "https://github.com/tysam-code/hlb-gpt"