import matplotlib.pyplot as plt
import torch
from torch import nn
import torch.nn.functional as F
import random
import csv
from pathlib import Path
from collections import Counter
from tqdm.auto import trange, tqdm
Makemore Subreddits - Part 3 Activations and Gradients
Let’s dive deep into the Activations and Gradients in a Multi-Layer Perceptron language model for subreddit names.
This loosely follows part 3 of Andrej Karpathy’s excellent makemore; go and check that out first. However, he used a list of US names, whereas we’re going to use subreddit names. See Makemore Subreddits - Part 2 MLP for the original Multi-Layer Perceptron following Bengio et al. 2003.
This is a Jupyter notebook; you can download the notebook or view it on Kaggle.
Loading the Data
This is largely similar to Part 1 where we get the most common subreddit names from All Subreddits and Relations Between Them.
Filter to subreddits that:
- Have at least 1000 subscribers
- Are not archived
- Are safe for work
- And are not quarantined
Note that you need to have downloaded subreddits.csv (and uncompressed it if appropriate).
data_path = Path('./data')
min_subscribers = 1_000

with open(data_path / 'subreddits.csv', 'r') as f:
    names = [d['name'] for d in csv.DictReader(f)
             if int(d['subscribers'] or 0) >= min_subscribers
             and d['description']
             and d['type'] != 'archived'
             and d['nsfw'] == 'f'
             and d['quarantined'] == 'f']

len(names)

random.seed(42)
random.shuffle(names)

N = len(names)

names_train = names[:int(0.8*N)]
names_val = names[int(0.8*N):int(0.9*N)]
names_test = names[int(0.9*N):]

for name in names_train[:10]:
    print(name)
len(names_train), len(names_val), len(names_test)
splunk
thenwa
soylent
factorio
christinaricci
blues
vegancheesemaking
goldredditsays
reformed
nagoya
(26876, 3359, 3360)
Compile the Data
Now convert the dataset into something that the model can easily work with. First represent all the character tokens as consecutive integers. We create a special PAD_CHAR with index 0 to represent tokens outside of the sequence.
PAD_CHAR = '.'
PAD_IDX = 0

i2s = sorted(set(''.join(names_train)))
assert PAD_CHAR not in i2s
i2s.insert(PAD_IDX, PAD_CHAR)

s2i = {s:i for i, s in enumerate(i2s)}

V = len(i2s)

def compile_dataset(names, block_size, PAD_CHAR=PAD_CHAR, s2i=s2i):
    X, y = [], []
    for name in names:
        padded_name = PAD_CHAR * block_size + name + PAD_CHAR
        padded_tokens = [s2i[c] for c in padded_name]
        for *context, target in zip(*[padded_tokens[i:] for i in range(block_size+1)]):
            X.append(context)
            y.append(target)
    return torch.tensor(X), torch.tensor(y)

block_size = 3

X, y = compile_dataset(names_train, block_size)
X_val, y_val = compile_dataset(names_val, block_size)
X_test, y_test = compile_dataset(names_test, block_size)

X.shape, y.shape
(torch.Size([330143, 3]), torch.Size([330143]))
Review: Multi-Layer Perceptron
We will start with the Multi-Layer Perceptron implementation from part 2.
default_m = 30
default_h = 200

class MLP:
    def __init__(self, m=default_m, h=default_h, V=V, block_size=block_size):
        self.m = m
        self.h = h
        self.V = V
        self.block_size = block_size
        # Word embedding layer
        self.C = torch.randn(V, m)
        # First hidden layer
        self.H = torch.randn(block_size * m, h)
        self.d = torch.randn(h)
        # Second hidden layer
        self.U = torch.randn(h, V)
        self.b = torch.randn(V)

    def parameters(self):
        return [self.C, self.H, self.d, self.U, self.b]

    def requires_grad_(self, requires_grad=True):
        for p in self.parameters():
            p.requires_grad_(requires_grad)
        return self

    def zero_grad(self):
        for p in self.parameters():
            p.grad = None
        return self

    def forward(self, X):
        self.embeddings = self.C[X]
        self.hidden_layer = self.embeddings.view(X.shape[0], self.block_size * self.m) @ self.H + self.d
        self.hidden_activations = torch.tanh(self.hidden_layer)
        self.output_logits = self.hidden_activations @ self.U + self.b
        return self.output_logits

    def __call__(self, X):
        return self.forward(X)
So we can start from a randomly initialised MLP:
mlp = MLP().requires_grad_()

with torch.no_grad():
    preds_val = mlp(X_val)
    val_loss = F.cross_entropy(preds_val, y_val).item()
val_loss
27.776723861694336
And code to sample from it (which gives random output):
def sample(mlp, pad_idx=PAD_IDX, block_size=block_size, i2s=i2s, generator=None):
    ans = []
    state = torch.tensor([[pad_idx] * block_size])
    while True:
        probs = mlp(state).softmax(axis=1)
        next_idx = torch.multinomial(probs, 1, generator=generator)
        state = torch.concat([state, next_idx], axis=1)[:,1:]

        next_idx = next_idx[0,0].item()
        if next_idx == pad_idx:
            return ''.join(ans)
        ans.append(i2s[next_idx])

sample(mlp)
'qaxtgkseh9ekyf9uokp29euoy4a7igiareg'
And code to train it:
batch_size = 32
val_step = 100

def train(model, n_step, lr, batch_size=batch_size, val_step=val_step, X=X, y=y, X_val=X_val, y_val=y_val, callback=None):
    losses, val_losses = [], []

    for step in trange(n_step):
        model.training = True  # NEW: support models that are different at train and inference time
        idx = torch.randint(0, len(X), (batch_size,))

        model.zero_grad()
        logits = model(X[idx])
        loss = F.cross_entropy(input=logits, target=y[idx])
        losses.append((step, loss.item()))
        loss.backward()

        # NEW: Support injecting a callback to do some mutation
        if callback is not None:
            callback(**locals())

        for p in model.parameters():
            p.data -= p.grad * lr(step, n_step)

        if step % val_step == 0:
            model.training = False  # NEW: support models that are different at train and inference time
            with torch.no_grad():
                preds_val = model(X_val)
                val_loss = F.cross_entropy(preds_val, y_val).item()
            val_losses.append((step, val_loss))

    model.training = False  # NEW
    return losses, val_losses

losses, val_losses = train(mlp, n_step=50_000, lr=lambda step, n_step: 0.1)
The loss decreases quickly from a very high value, and then slowly descends.
val_loss_step, val_loss_value = zip(*val_losses)
plt.plot(val_loss_step, val_loss_value)

baseline_loss = val_loss_value[-1]
baseline_loss
2.6191823482513428
We get samples that are much less random, but still don’t really seem like subreddit names.
for _ in range(20):
    print(sample(mlp))
lolleckingurs
r64onfrimen
caryimalcuraogardcartfgatulgeterheisting
clan
voumg
hvur
stagb
monconts
vart
scaromark
uprycherce
wing
lok
crteyrar
ling
jerkeymentfumearnight
kurcudver
learovedx
osarenetroor
ct2ivzenmirdur34therrertifar
Fixing Initialisation
The loss starts off very high giving the “hockey-stick” shaped loss curve. We can fix this by starting the weights in a better place.
The initial loss is very high:
mlp = MLP()

with torch.no_grad():
    preds_val = mlp(X_val)
    val_loss = F.cross_entropy(preds_val, y_val).item()
val_loss
29.860868453979492
A good baseline model is the constant model that assigns each token in the vocabulary an equal probability
constant_model = 1/V * torch.ones(size=(len(X_val), V))
Such a model has a 1/V probability of having the correct prediction, and so the loss will be -log(1/V)
-torch.log(torch.tensor(1/V)).item()
3.6375861167907715
This is accurate to a very good approximation, and much better than the random weights.
F.cross_entropy(constant_model, y_val).item()
3.637585163116455
Walking through the model
We want to initialise the weights so that the model predicts close to a random distribution of outputs.
Let’s step through the layers of our current model for a batch of training data to understand what is currently happening:
self.embeddings = self.C[X]
self.hidden_layer = self.embeddings.view(X.shape[0], self.block_size * self.m) @ self.H + self.d
self.hidden_activations = torch.tanh(self.hidden_layer)
self.output_logits = self.hidden_activations @ self.U + self.b
with torch.no_grad():
    batch = mlp(X[:32])
The embeddings have a sort of standard normal distribution; slightly distorted by the item frequency.
self.C = torch.randn(V, m)
...
self.embeddings = self.C[X]
plt.hist(mlp.embeddings.view(-1), bins=50);
mlp.embeddings.shape, mlp.embeddings.std()
(torch.Size([32, 3, 30]), tensor(0.9805))
The first hidden layer performs a linear transformation.
# First hidden layer
self.H = torch.randn(block_size * m, h)
self.d = torch.randn(h)
...
self.hidden_layer = self.embeddings.view(X.shape[0], self.block_size * self.m) @ self.H + self.d
In index notation the hidden layer \(h\) looks like:
\[h_{i,k} = \sum_{j=1}^{m \times {\rm block\_size}} e_{i,j} H_{j,k} + d_{i,k}\]
And assuming the variables are all independent, and that the embeddings e and matrix H each consist of elements of zero mean and equal standard deviation then \(\mathbb{E}(h) = 0\) and
\[ \mathbb{V}(h_{i,k}) = (m \times {\rm block\_size}) \mathbb{V}(e) \mathbb{V}(H) + \mathbb{V}(d)\]
So in particular here we’ve set all the element variances to 1, and so the output variance should be:
mlp.block_size * mlp.m + 1
91
It’s pretty close to this (with some random error)
plt.hist(mlp.hidden_layer.view(-1), bins=50);
mlp.hidden_layer.shape, mlp.hidden_layer.var()
(torch.Size([32, 200]), tensor(83.0976))
We then apply a tanh transformation, which squashes values far from 0 towards ±1.
self.hidden_activations = torch.tanh(self.hidden_layer)
x = torch.arange(-3, 3, step=0.01)
plt.plot(x, torch.tanh(x));
Consequently we get all our values squished around -1 and 1
plt.hist(mlp.hidden_activations.view(-1), bins=50);
mlp.hidden_activations.shape, mlp.hidden_activations.var()
(torch.Size([32, 200]), tensor(0.9154))
Many of the activations are above 0.99, which means the gradient is \(\tanh'(x) = 1-\tanh^2(x) < 0.02\), which can lead to gradient underflow.
plt.matshow(mlp.hidden_activations.abs() > 0.99, cmap='gray', interpolation='nearest');
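As a quick numeric sketch of that claim, we can push a few values through tanh with autograd and look at the gradients that come back: once a unit saturates, the local derivative \(1 - \tanh^2(x)\) is close to zero, so almost no gradient flows through it.

x = torch.tensor([0.0, 1.0, 3.0], requires_grad=True)
torch.tanh(x).sum().backward()
# d/dx tanh(x) = 1 - tanh(x)^2: roughly 1.0, 0.42 and 0.01 for these inputs
x.grad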
Since the activations are mostly ±1, each output logit is roughly a sum of ±1 times the corresponding entries of U, plus b, and so the logits are approximately normal.
self.output_logits = self.hidden_activations @ self.U + self.b
It has a huge variance, approximately \[{\rm hidden\_size} \times \mathbb{V}(U) + \mathbb{V}(b) = \rm hidden\_size+ 1\]
mlp.h + 1
201
plt.hist(mlp.output_logits.view(-1), bins=50);
mlp.output_logits.shape, mlp.output_logits.var()
(torch.Size([32, 38]), tensor(190.1738))
This means the logits fluctuate wildly and the predictions are very extreme with a very high probability prediction:
plt.matshow(mlp.output_logits.softmax(axis=1), cmap='gray');
We can fix these simply by scaling down the activations; this is much more important in deep networks where these effects compound often leading to exploding or vanishing gradients.
This is explained clearly in Xavier Glorot and Yoshua Bengio’s Understanding the difficulty of training deep feedforward neural networks where they derive the variance ignoring non-linearities for the backward and forward pass and suggest initialising with
\[ w \sim U\left(-\frac{\sqrt{6}}{\sqrt{n_{\rm in} + n_{\rm out}}}, \frac{\sqrt{6}}{\sqrt{n_{\rm in} + n_{\rm out}}}\right)\]
Alternatively their analysis suggests you could also use
\[ w \sim \mathcal{N}\left(\mu=0, \sigma = \frac{\sqrt{2}}{\sqrt{n_{\rm in} + n_{\rm out}}}\right)\]
or approximately, \(w \sim \mathcal{N}\left(\mu=0, \sigma = 1/\sqrt{n_{\rm in}}\right)\) and this latter form is called Xavier initialisation or Glorot initialisation. They show these allow training CNNs up to 9 layers deep which was difficult without this.
In Delving Deep into Rectifiers He, Zhang, Ren, and Sun take into account the ReLU non-linearity and show you need to introduce a gain of \(\sqrt{2}\) (the He initialisation or Kaiming initialisation after the first author), which allows them to go from 22 layers to 30.
For other non-linearities, Siddharth Krishna Kumar derives the variance for a differentiable activation function \(g\) in On weight initialization in deep neural networks, and uses a local expansion to derive an initialisation of
\[ w \sim \mathcal{N}\left(\mu=0, \sigma = \frac{1}{\left|g'(0)\right|\sqrt{n (1 + g(0))^2}}\right) \]
For tanh this suggests a gain of 1, but as Andrej Karpathy argues in this lecture, since tanh is contractive the gain must be more than 1; capturing that would require a higher-order approximation.
PyTorch has a gain of 5/3, but no one remembers why. It’s likely an estimate of the variance of \(\tanh\) under a standard normally distributed input:
\[ \begin{align} \mathbb{V}\left[\tanh\right] &= \mathbb{E}\left[\left(\tanh - \mathbb{E}(\tanh)\right)^2\right] \\ &= \mathbb{E}\left[\tanh^2\right] \\ &= \int_{-\infty}^{\infty} \left(\frac{1}{\sqrt{2 \pi}} e^{-x^2/2} \right) \tanh^2(x)\,{\rm d}x \\ & \approx 0.394 \end{align}\]
To normalise the gain we need to divide by the standard deviation (the square root of the variance) \({\rm gain} \approx 1/\sqrt{0.394} \approx 1.59 \approx 5/3\) (where the last term introduces an error of around 5%).
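We can sanity-check this with a rough Monte Carlo sketch (the exact numbers will wobble with the random draw), comparing the empirical variance of tanh under a standard normal input, the implied gain, and the gain PyTorch reports for tanh:

z = torch.randn(1_000_000)
var_tanh = torch.tanh(z).var()
# Variance of tanh(z) for z ~ N(0, 1), the implied gain 1/sqrt(var), and PyTorch's tanh gain of 5/3
var_tanh.item(), (1 / var_tanh.sqrt()).item(), torch.nn.init.calculate_gain('tanh')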
In general even if it is difficult to calculate the variance (for example because we’re unsure of the input distribution), it can be empirically derived, as in All you Need is a Good Init by Mishkin and Matas with their iterative Layer-sequential Unit Variance (LSUV) method.
In any case it won’t matter with our 1-layer MLP, and other methods such as batch/layer normalisation, skip connections, and better optimisers have made these less important (but you can train a transformer without the normalisation and skip connections).
For simplicity we’ll use Kaiming initialisation, setting the initialisation of the biases to a very low number and further scaling down the output logits to get a better initial loss.
def fix_init(mlp, bias_variance=1e-4, output_variance=1e-1):
    with torch.no_grad():
        mlp.H *= 1/(mlp.block_size*mlp.m)**0.5
        mlp.d *= bias_variance ** 0.5

        mlp.U *= output_variance**0.5/(mlp.h**0.5)
        mlp.b *= bias_variance ** 0.5
    return mlp

mlp = fix_init(MLP()).requires_grad_()
We can see our initial loss is much lower, and closer to random.
with torch.no_grad():
    preds_val = mlp(X_val)
    val_loss = F.cross_entropy(preds_val, y_val).item()

val_loss
3.64780330657959
F.cross_entropy(torch.ones_like(preds_val), y_val).item()
3.637585163116455
As before we can step through the layers
with torch.no_grad():
    preds = mlp(X[:32])
The hidden pre-activations are now standard normal
plt.hist(mlp.hidden_layer.view(-1), bins=50);
mlp.hidden_layer.shape, mlp.hidden_layer.var()
(torch.Size([32, 200]), tensor(1.0259))
The activations are slightly saturated, but uniform-ish.
plt.hist(mlp.hidden_activations.view(-1), bins=50);
mlp.hidden_activations.shape, mlp.hidden_activations.var()
(torch.Size([32, 200]), tensor(0.3962))
The output layer is now approximately normal with the small variance we set.
plt.hist(mlp.output_logits.view(-1), bins=50);
mlp.output_logits.shape, mlp.output_logits.std()
(torch.Size([32, 38]), tensor(0.1852))
The predicted probabilities are much more uniform.
plt.matshow(mlp.output_logits.softmax(axis=1), cmap='gray')
<matplotlib.image.AxesImage at 0x7f37595b3910>
losses, val_losses = train(mlp, n_step=50_000, lr=lambda step, n_step: 0.1)
The loss doesn’t have as sharp an initial drop-off and reaches a lower value (before it was ~2.6)
val_loss_step, val_loss_value = zip(*val_losses)
plt.plot(val_loss_step, val_loss_value)

fix_init_loss = val_loss_value[-1]
fix_init_loss, f'{fix_init_loss / baseline_loss:0.2%} of baseline loss'
(2.4480857849121094, '93.47% of baseline loss')
We can run samples but they’re not qualitatively better than before:
for _ in range(20):
    print(sample(mlp))
cata
sokemolos
clsybeafwensterthewlukening
therez
breadrycoonn
kritefactreamancoonscotterry
balarcare
coreporn
clang
neleng
cretages
reddits
comporn
curans
ristoryardsteton
nongoreddit
terbateas
millippertoockpharnothenesek
acepapertrole
deundumpstbroknv
Batch Norm
Another way to control the distribution of the pre-activations is to rescale them to be in that distribution. The challenge here is that we need to estimate the distribution of the weights somehow. Batch Norm does this by calculating the statistics across a batch, which works both to normalise and regularise, but coupling examples across a batch makes the process more complicated and error prone.
To implement it we need to add:
- Learnable shift and scale parameters
- Fixed eps
- In the forward
- normalise based on batch statistics
- rescale with shift and scale parameters
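For reference, the transformation the class below implements is the standard Batch Norm transform: for each hidden dimension we normalise by the batch statistics and then apply the learnable scale and shift,
\[ \hat{x} = \frac{x - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}, \qquad y = {\rm scale} \cdot \hat{x} + {\rm shift} \]
where \(\mu_B\) and \(\sigma_B^2\) are the mean and variance over the current batch and \(\epsilon\) just guards against division by zero.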
Let’s also track how the statistics change over time
class MLPBatchNorm(MLP):
    def __init__(self, m=default_m, h=default_h, V=V, block_size=block_size, bn_eps=1e-8):
        super().__init__(m=m, h=h, V=V, block_size=block_size)
        # New stuff
        self.bn_scale = torch.ones((1,h))
        self.bn_shift = torch.zeros((1,h))
        self.bn_eps = bn_eps
        # Track statistics for debugging
        self.bn_means = []
        self.bn_vars = []

    def parameters(self):
        return super().parameters() + [self.bn_scale, self.bn_shift]

    def forward(self, X, bn_mean=None, bn_var=None):
        self.embeddings = self.C[X]
        self.hidden_layer = self.embeddings.view(X.shape[0], self.block_size * self.m) @ self.H + self.d
        # New stuff; allow passing in a batch norm mean and variance for debugging
        μ = self.hidden_layer.mean(dim=0, keepdim=True) if bn_mean is None else bn_mean
        σ2 = self.hidden_layer.var(dim=0, keepdim=True) if bn_var is None else bn_var
        self.hidden_bn = self.bn_scale * (self.hidden_layer - μ) / (σ2 + self.bn_eps) ** 0.5 + self.bn_shift
        self.hidden_activations = torch.tanh(self.hidden_bn)
        # Track statistics only in training (not validation) for debugging
        if self.training:
            self.bn_means.append(μ.detach())
            self.bn_vars.append(σ2.detach())
        # End new stuff
        self.output_logits = self.hidden_activations @ self.U + self.b
        return self.output_logits

mlp = fix_init(MLPBatchNorm()).requires_grad_()
losses, val_losses = train(mlp, n_step=50_000, lr=lambda step, n_step: 0.1)
Training this we get a similar loss value as before:
val_loss_step, val_loss_value = zip(*val_losses)
plt.plot(val_loss_step, val_loss_value)

bn_loss = val_loss_value[-1]
bn_loss, f'{bn_loss / fix_init_loss:0.2%} of fix_init loss'
(2.4308383464813232, '99.30% of fix_init loss')
But there’s a problem, we can no longer make predictions on a single example because we normalise it away when calculating statistics:
mlp(X_val[:1])
tensor([[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan]],
grad_fn=<AddBackward0>)
One way to handle this is to use the batch norm mean and variance from the training set
with torch.no_grad():
    preds = mlp(X)
    preact = mlp.hidden_layer
    bn_mean = preact.mean(dim=0)
    bn_var = preact.var(dim=0)
For each one of the 200 dimensions we have a mean and a variance
plt.bar(range(len(bn_mean)), bn_mean);

plt.bar(range(len(bn_mean)), bn_var);
We can use these population statistics as the batch norm mean and variance. This also means that the predictions will be independent of the other items in the batch.
mlp.forward(X_val[:1], bn_mean, bn_var)
tensor([[-2.2420, -4.6427, -2.1474, -1.2763, -2.2146, -2.6077, -3.2863, -3.3461,
-3.4700, -3.0928, -2.5705, -2.5806, 2.0783, 2.1519, 3.0222, 1.6728,
1.3640, 2.1493, 1.9894, 1.2325, 1.4388, 0.6066, 0.8470, 1.6142,
2.5117, 1.8061, 0.8732, 1.9734, -1.2449, 1.3586, 2.1918, 2.8410,
1.2704, 0.9921, 1.4396, -0.9299, -0.2703, -0.8756]],
grad_fn=<AddBackward0>)
We get essentially the same loss here as well (but if there were distribution shift between the training and validation sets then there could be a substantial difference using the validation statistics, which may not be possible in an online setting).
with torch.no_grad():
    preds_val = mlp.forward(X_val, bn_mean, bn_var)
    val_loss = F.cross_entropy(preds_val, y_val).item()

val_loss, f'{val_loss / bn_loss:0.2%} of bn_loss loss'
(2.4395534992218018, '100.36% of bn_loss loss')
Running statistics
Calculating the training statistics afterwards is an additional step; could we estimate them on the fly?
Let’s look at the batch norm means
train_bn_means = torch.concat(mlp.bn_means)
train_bn_vars = torch.concat(mlp.bn_vars)
train_bn_means.shape
torch.Size([50000, 200])
Let’s look at a single dimension of the hidden layer. Each point is the average value of the output over a single batch; it changes over time because of:
- random variations between batches
- changes in the parameters of the embeddings and hidden layer
We can see here that the value and variance change over time (the final training set mean, m_inf, is in red)
idx = 0
m = train_bn_means[:,0]
m_inf = bn_mean[0]

plt.plot(m)
plt.hlines(m_inf, 0, len(m), color='r');
This means if we took the simple average of points we would severely mis-estimate the final value because it’s moving.
m.mean(), m_inf
(tensor(-0.5176), tensor(-0.9794))
Similarly if we just took the last value it may mis-estimate because of the variance
m[-1], m_inf
(tensor(-0.6533), tensor(-0.9794))
One option between the two extremes is to use an Exponential Moving Average to track the value over time:
momentum = 0.1
m_ema = []
ema = m[0]
for m_i in m:
    ema = ema * (1 - momentum) + m_i * momentum
    m_ema.append(ema)

plt.plot(m)
plt.plot(m_ema)
plt.hlines(m_inf, 0, len(m), color='r')

ema.item(), m_inf.item()
(-0.978546679019928, -0.9794382452964783)
The one hyperparameter we have to tune is the momentum: how much weight to give the new value at each step (we keep a fraction 1 - momentum of the previous estimate).
- A momentum that is too high means we will get too much variance (momentum = 1 gives the last value)
- A momentum that is too low will not respond quickly enough to the changes of parameters (too much bias)
momentum = 0.00001
m_ema = []
ema = m[0]
for m_i in m:
    ema = ema * (1 - momentum) + m_i * momentum
    m_ema.append(ema)

plt.plot(m)
plt.plot(m_ema)
plt.hlines(m_inf, 0, len(m), color='r')

m_ema[-1].item(), m_inf.item()
(-0.43218255043029785, -0.9794382452964783)
The way this works is we get an exponentially decaying weight on old values; with a momentum \(\alpha\), the exponential moving average estimates \(y_i\) are given recursively as the weighted interpolation of the last estimate and the next value: \(y_i = (1-\alpha) y_{i-1} + \alpha x_i\) and so:
\[\begin{eqnarray} y_0 &=& x_0 \\ y_1 &=& \alpha x_1 + (1 - \alpha) y_0 \\ &=& \alpha x_1 + (1 - \alpha) x_0 \\ y_2 &=& \alpha x_2 + (1 - \alpha) y_1 \\ &=& \alpha x_2 + (1 - \alpha) \alpha x_1 + (1- \alpha) ^2 x_0 \\ &\vdots& \\ y_n &=& \alpha x_n + (1 - \alpha) y_{n-1} \\ &=& \alpha x_n + (1 - \alpha) \alpha x_{n-1} + (1 - \alpha)^2 \alpha x_{n-2} + \cdots + (1-\alpha)^{n-1} \alpha x_1 + (1- \alpha)^n x_0 \end{eqnarray}\]
The weights sum to 1 by the geometric series:
\[ 1 + (1 - \alpha) + (1-\alpha)^2 + \ldots + (1-\alpha)^{n-1} = \frac{1 - (1 - \alpha)^n}{\alpha}\]
For large enough \(n\) (relative to \(1/\alpha\)) we can ignore the last term and the remaining weights approximately sum to 1; with the very small momentum above and 50,000 steps they only sum to about 0.39:
sum((1-momentum) ** torch.arange(len(m) - 1, -1, -1)) * momentum
tensor(0.3933)
We can then calculate the exponential moving average quickly using the formula:
\[ y_n = \alpha \left((1-\alpha)^0 x_n + (1-\alpha)^1 x_{n-1} + \cdots + (1-\alpha)^n x_0 \right) + (1-\alpha)^{n+1} x_0\]
def fastema(z, momentum):
    weights = (1-momentum) ** torch.arange(z.shape[-1] - 1, -1, -1)
    return momentum * (z * weights).sum(axis=-1) + (1 - momentum)**(len(weights)) * z[...,0]
Check this gives the right answer
z = torch.rand(5)

momentum = 0.1
ema = z[0]
for m_i in z:
    ema = ema * (1 - momentum) + m_i * momentum

ema.item(), fastema(z, momentum).item(), torch.allclose(ema, fastema(z, momentum))
(0.5154919624328613, 0.5154920220375061, True)
It even works in two dimensions
z2 = torch.rand(5)
zz = torch.stack([z, z2])
zz.shape
torch.Size([2, 5])
Giving the same result over two dimensions
fastema(zz, momentum), fastema(z, momentum), fastema(z2, momentum)
(tensor([0.5155, 0.5326]), tensor(0.5155), tensor(0.5326))
We can now compare how good the estimate is for different values of momentum; in this case it’s best around \([10^{-3}, 10^{-2}]\)
momentums_log10 = torch.arange(-5, 0, 0.1)
momentums = 10**momentums_log10
emas = [fastema(train_bn_means.transpose(0, 1), m) for m in momentums]

emas_rms_error = [((ema - bn_mean)**2).mean()**0.5 for ema in emas]

plt.plot(momentums_log10, emas_rms_error);
We get a similar result for the variances
emas = [fastema(train_bn_vars.transpose(0, 1), m) for m in momentums]

emas_rms_error = [((ema - bn_var)**2).mean()**0.5 for ema in emas]

plt.plot(momentums_log10, emas_rms_error);
We can also look at the errors
ema = fastema(train_bn_means.transpose(0, 1), momentum=0.001)
plt.hist(ema - bn_mean, bins=20);
((ema - bn_mean)**2).mean() ** 0.5
tensor(0.0327)
Note that the optimum momentum will be a function of the batch size; as batch size increases:
- the variance within each step will decrease
- the number of steps in an epoch will decrease
- optimum momentum will increase
mlp = fix_init(MLPBatchNorm()).requires_grad_()
losses, val_losses = train(mlp, n_step=50_000//10, lr=lambda step, n_step: 0.1, batch_size=batch_size*10)

val_loss_step, val_loss_value = zip(*val_losses)
plt.plot(val_loss_step, val_loss_value);

train_bn_means = torch.concat(mlp.bn_means)
train_bn_vars = torch.concat(mlp.bn_vars)

with torch.no_grad():
    preds = mlp(X)
    preact = mlp.hidden_layer
    bn_mean = preact.mean(dim=0)
    bn_var = preact.var(dim=0)
We can see that the optimum momentum gets higher
momentums_log10 = torch.arange(-5, 0, 0.1)
momentums = 10**momentums_log10
emas = [fastema(train_bn_means.transpose(0, 1), m) for m in momentums]

emas_rms_error = [((ema - bn_mean)**2).mean()**0.5 for ema in emas]

plt.plot(momentums_log10, emas_rms_error);
emas = [fastema(train_bn_vars.transpose(0, 1), m) for m in momentums]

emas_rms_error = [((ema - bn_var)**2).mean()**0.5 for ema in emas]

plt.plot(momentums_log10, emas_rms_error);
Running statistics during training
We can now wrap this in our MLP:
- at training time collect the running statistics
- at inference time use the running statistics
Note that this only requires 2 extra variables per hidden dimension (1 for mean and 1 for variance).
class MLPBatchNorm(MLP):
    def __init__(self, m=default_m, h=default_h, V=V, block_size=block_size,
                 bn_eps=1e-8, bn_momentum=0.001):
        super().__init__(m=m, h=h, V=V, block_size=block_size)
        self.bn_scale = torch.ones((1,h))
        self.bn_shift = torch.zeros((1,h))
        self.bn_eps = bn_eps
        # New stuff
        self.training = True
        self.bn_runvar = torch.ones((1,h))
        self.bn_runmean = torch.zeros((1,h))
        self.bn_momentum = bn_momentum

    def parameters(self):
        return super().parameters() + [self.bn_scale, self.bn_shift]

    def forward(self, X):
        self.embeddings = self.C[X]
        self.hidden_layer = self.embeddings.view(X.shape[0], self.block_size * self.m) @ self.H + self.d
        if self.training:
            # Estimate batch mean and variance
            μ = self.hidden_layer.mean(dim=0, keepdim=True)
            σ2 = self.hidden_layer.var(dim=0, keepdim=True)

            # Update running totals
            with torch.no_grad():
                self.bn_runmean = (1 - self.bn_momentum) * self.bn_runmean + \
                                  self.bn_momentum * μ
                self.bn_runvar = (1 - self.bn_momentum) * self.bn_runvar + \
                                 self.bn_momentum * σ2
        else:
            μ = self.bn_runmean
            σ2 = self.bn_runvar

        self.hidden_bn = self.bn_scale * (self.hidden_layer - μ) / (σ2 + self.bn_eps) ** 0.5 + self.bn_shift
        self.hidden_activations = torch.tanh(self.hidden_bn)
        self.output_logits = self.hidden_activations @ self.U + self.b
        return self.output_logits
This gives a similar loss as before
mlp = fix_init(MLPBatchNorm()).requires_grad_()
losses, val_losses = train(mlp, n_step=50_000, lr=lambda step, n_step: 0.1)

val_loss_step, val_loss_value = zip(*val_losses)
plt.plot(val_loss_step, val_loss_value)

bn_run_loss = val_loss_value[-1]
bn_run_loss, f'{bn_run_loss / bn_loss:0.2%} of batch norm loss'
(2.427150011062622, '99.85% of batch norm loss')
But now we can evaluate it on single examples
mlp(X_val[:1])
tensor([[-1.8919, -4.8599, -1.5227, -1.5331, -1.6830, -2.4410, -3.0484, -2.9386,
-3.2406, -3.1063, -2.3455, -3.1913, 2.2928, 1.8896, 2.5483, 2.3309,
1.4567, 2.0108, 0.8814, 1.8238, 0.9417, 0.6373, 0.9153, 1.5311,
2.3831, 1.3458, 1.2359, 2.2754, -0.9118, 1.6478, 2.0318, 2.5369,
0.3390, 0.4008, 1.6127, -1.4685, -0.6556, -0.8556]],
grad_fn=<AddBackward0>)
And in evaluation mode the results are independent of the batch size:
torch.allclose(mlp(X_val[:1]), mlp(X_val[:100])[:1])
True
Let’s check our running statistics are similar to calculating them after the fact.
with torch.no_grad():
    preds = mlp(X)
    preact = mlp.hidden_layer
    bn_mean = preact.mean(dim=0, keepdim=True)
    bn_var = preact.var(dim=0, keepdim=True)
They’re mostly similar, though some are substantially different.
plt.hist(bn_mean - mlp.bn_runmean, bins=20);

((bn_mean - mlp.bn_runmean)**2).mean() ** 0.5
tensor(0.0261)
plt.hist(mlp.bn_runvar - bn_var, bins=20)

((bn_var - mlp.bn_runvar)**2).mean() ** 0.5
tensor(0.1085)
Pytorchifying
With Batchnorm it’s getting hard to maintain all this spaghetti code, so let’s make it more modular like PyTorch.
We’ll start off with a Module class that’s a simplified version of PyTorch’s nn.Module
class MyModule:
    def __init__(self):
        self.training = True
        self._parameters = []

    def train(self, mode=True):
        self.training = mode
        return self

    def parameters(self):
        return self._parameters

    def requires_grad_(self, requires_grad=True):
        for p in self.parameters():
            p.requires_grad_(requires_grad)
        return self

    def zero_grad(self):
        for p in self.parameters():
            p.grad = None
        return self

    def forward(self, X):
        raise NotImplementedError()

    def __call__(self, X):
        return self.forward(X)

    def __repr__(self):
        return f"{type(self).__name__}"
Linear Layer
Then for our MLP we’ll need a Linear layer, and we’ll copy PyTorch’s initialisation:
\[ w \sim U\left(-1/\sqrt{\rm in\_features}, 1/\sqrt{\rm in\_features}\right) \]
Note this is missing the \(\sqrt{3}\) correction for a uniform distribution (the resulting standard deviation is \(1/\sqrt{3\,{\rm in\_features}}\) rather than \(1/\sqrt{\rm in\_features}\)), and any activation-specific gain.
Torch doesn’t have a handy way of building a uniform distribution, so we will roll our own:
def rand_unif(shape, min_val, max_val):
    return torch.rand(shape) * (max_val - min_val) + min_val

plt.hist(rand_unif((10_000,), -1, 2), bins=20);
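To connect this back to the missing \(\sqrt{3}\): a uniform distribution on \((-a, a)\) has standard deviation \(a/\sqrt{3}\), so this default initialisation gives weights with standard deviation \(1/\sqrt{3\,{\rm in\_features}}\). A quick numeric sketch using the helper above (the sizes here are arbitrary):

a = 1 / 100 ** 0.5
w = rand_unif((100, 200), -a, a)
# Empirical standard deviation versus the theoretical a / sqrt(3)
w.std().item(), a / 3 ** 0.5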
class MyLinear(MyModule):
    def __init__(self, in_features: int, out_features: int,
                 bias: bool = True):
        super().__init__()
        scale = 1/(in_features)**0.5

        self.weight = rand_unif((in_features, out_features), -scale, scale)
        self._parameters = [self.weight]
        if bias:
            self.bias = rand_unif(out_features, -scale, scale)
            self._parameters.append(self.bias)

    def __repr__(self):
        return f"{type(self).__name__}({self.weight.shape[0]}, {self.weight.shape[1]}, bias={hasattr(self, 'bias')})"

    def forward(self, X):
        self.out = X @ self.weight
        if hasattr(self, "bias"):
            self.out += self.bias
        return self.out
We can create a linear layer, and it has the appropriate mean and standard deviation
linear = MyLinear(100, 200)
linear
MyLinear(100, 200, bias=True)
linear.weight.mean(), 10 * (3)**0.5 * linear.weight.std()
(tensor(0.0002), tensor(0.9963))
As do the biases
linear.bias.mean(), 10 * (3)**0.5 * linear.bias.std()
(tensor(-0.0005), tensor(1.0303))
And it has the required parameters
[p.shape for p in linear.parameters()]
[torch.Size([100, 200]), torch.Size([200])]
And it converts a batch of 100-dimensional tensors into a batch of 200-dimensional tensors
linear(torch.randn(32, 100)).shape
torch.Size([32, 200])
linear_nobias = MyLinear(200, 100, bias=False)
[p.shape for p in linear_nobias.parameters()]
[torch.Size([200, 100])]
Embedding Layer
We can similarly abstract the embedding layer, à la torch.nn.Embedding, which is much simpler:
class MyEmbedding(MyModule):
    def __init__(self, num_embeddings, embedding_dim):
        super().__init__()
        self.weight = torch.randn(size=(num_embeddings, embedding_dim))
        self._parameters = [self.weight]

    def __repr__(self):
        return f"{type(self).__name__}{tuple(self.weight.shape)}"

    def forward(self, X):
        self.out = self.weight[X]
        return self.out

MyEmbedding(2, 3)
MyEmbedding(2, 3)
Batch Norm
We can similarly implement BatchNorm1d, which is more complex, but hides all the state inside the object, which makes for a cleaner abstraction.
class MyBatchNorm1d(MyModule):
    def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.weight = torch.ones(num_features)
        self.bias = torch.zeros(num_features)
        if affine:
            self._parameters = [self.weight, self.bias]
        else:
            self._parameters = []
        self.track_running_stats = track_running_stats
        if track_running_stats:
            self.running_mean = torch.zeros(1, num_features)
            self.running_var = torch.ones(1, num_features)
        else:
            self.running_mean = None
            self.running_var = None

    def __repr__(self):
        return f"{type(self).__name__}({self.num_features}, eps={self.eps}, affine={self.affine})"

    def forward(self, X):
        if self.training:
            batch_mean = X.mean(dim=0, keepdim=True)
            batch_var = X.var(dim=0, keepdim=True, correction=0)

            if self.track_running_stats:
                with torch.no_grad():
                    self.running_mean *= 1 - self.momentum
                    self.running_mean += self.momentum * batch_mean.view(-1)
                    self.running_var *= 1 - self.momentum
                    # Following documentation in PyTorch BatchNorm1d: running variance uses the unbiased estimate
                    self.running_var += self.momentum * X.var(dim=0, keepdim=True, correction=1)
        else:
            batch_mean = self.running_mean
            batch_var = self.running_var

        # Normalise by the standard deviation (with eps for stability), then apply the affine scale and shift
        self.out = self.weight * (X - batch_mean) / (batch_var + self.eps) ** 0.5 + self.bias
        return self.out

MyBatchNorm1d(5)
MyBatchNorm1d(5, eps=1e-05, affine=True)
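As a quick usage sketch (the numbers are illustrative): in training mode the layer should map each feature of a batch to roughly zero mean and unit variance, regardless of the input scale.

bn = MyBatchNorm1d(5)
x = torch.randn(32, 5) * 3 + 2
out = bn(x)
# Per-feature mean should be close to 0 and (biased) variance close to 1
out.mean(dim=0), out.var(dim=0, correction=0)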
MLP
We now have most of the pieces; we just need to add a few more to create our MLP.
Firstly we will need our activations, such as nn.Tanh:
class MyTanh(MyModule):
    def forward(self, X):
        self.out = torch.tanh(X)
        return self.out
And a way to flatten our embeddings from each token (up to block_size) into a single tensor:
class MyFlatten(MyModule):
    def __init__(self, start_dim=1, end_dim=-1):
        super().__init__()
        self.start_dim = start_dim
        self.end_dim = end_dim

    def forward(self, X):
        return X.flatten(self.start_dim, self.end_dim)

embedding = MyEmbedding(V, default_m)(X[:32])
embedding.shape
torch.Size([32, 3, 30])
MyFlatten()(embedding).shape
torch.Size([32, 90])
And then we just need to stack them together in a Sequential sequence of layers:
class MySequential(MyModule):
    def __init__(self, *args):
        super().__init__()
        self.layers = args
        self._parameters = [params for layer in self.layers for params in layer.parameters()]

    def __repr__(self):
        return f"{type(self).__name__}({self.layers})"

    def __getitem__(self, idx):
        return self.layers[idx]

    def forward(self, X):
        result = X
        for layer in self.layers:
            result = layer(result)
        return result
We can then build an MLP for a given embedding dimension, and set of hidden dimensions:
def get_mlp(m=default_m, hs=(default_h,), batch_norm=False, bias=False, V=V, block_size=block_size, activation_factory=lambda: MyTanh()):
    # First we embed the vectors and then flatten them
    layers = [MyEmbedding(V, m), MyFlatten()]

    # Then add the hidden layers
    in_sizes = [block_size * m] + list(hs)
    out_sizes = list(hs) + [V]
    for h_in, h_out in zip(in_sizes, out_sizes):
        layers.append(MyLinear(h_in, h_out, bias=bias))
        if batch_norm:
            layers.append(MyBatchNorm1d(num_features=h_out))
        layers.append(activation_factory())

    # Drop the last activation, since this is passed to Softmax
    layers.pop()
    return MySequential(*layers)

mlp = get_mlp(bias=True).requires_grad_()
mlp
MySequential((MyEmbedding(38, 30), MyFlatten, MyLinear(90, 200, bias=True), MyTanh, MyLinear(200, 38, bias=True)))
We can train this to get similar results as before with fixed initialisation
losses, val_losses = train(mlp, n_step=50_000, lr=lambda step, n_step: 0.1)

val_loss_step, val_loss_value = zip(*val_losses)
plt.plot(val_loss_step, val_loss_value)

mymlp_loss = val_loss_value[-1]
mymlp_loss, f'{mymlp_loss / fix_init_loss:0.2%} of fixed init loss'
(2.4358863830566406, '99.50% of fixed init loss')
Going deeper
Now that we have the framework for building deeper MLPs, let’s try to train and analyse some.
Without correction
Let’s start with a plain 5 layer MLP without gain correction:
mlp = get_mlp(bias=True, m=10, hs=[100]*5).requires_grad_()
mlp
MySequential((MyEmbedding(38, 10), MyFlatten, MyLinear(30, 100, bias=True), MyTanh, MyLinear(100, 100, bias=True), MyTanh, MyLinear(100, 100, bias=True), MyTanh, MyLinear(100, 100, bias=True), MyTanh, MyLinear(100, 100, bias=True), MyTanh, MyLinear(100, 38, bias=True)))
with torch.no_grad():
    print('Initial Loss: %0.2f' % F.cross_entropy(mlp(X[:1000]), y[:1000]).item())
Initial Loss: 3.63
Let’s look at the initial activations
def show_layer(i, t):
    print(f'layer {i} ({layer}): mean {t.mean():0.2f}, std {t.std():0.2f}, saturated: {((t.abs() > 0.97) * 1.0).mean():0.2%}')
    hy, hx = torch.histogram(t, density=True)
    plt.plot(hx[:-1].detach(), hy.detach())
    legends.append(f'layer {i} ({layer})')

def show_layers(mlp, backward=False, classes=(MyTanh,), saturation_threshold=0.97, figsize=(20,4), X=X_val[:32], y=y_val[:32]):
    preds = mlp(X)
    for layer in mlp:
        if hasattr(layer, 'out'):
            layer.out.retain_grad()
    loss = F.cross_entropy(input=preds, target=y)
    loss.backward()

    plt.figure(figsize=figsize)  # width and height of the plot
    legends = []
    with torch.no_grad():
        for i, layer in enumerate(mlp):
            if isinstance(layer, classes):
                t = layer.out
                if backward:
                    t = t.grad
                print(f'layer {i} ({layer}): mean {t.mean():0.2f}, std {t.std():0.2f}, saturated: {((t.abs() > saturation_threshold) * 1.0).mean():0.2%}')
                hy, hx = torch.histogram(t, density=True)
                plt.plot(hx[:-1].detach(), hy.detach())
                legends.append(f'layer {i} ({layer})')

    plt.legend(legends);
    plt.title(('gradient' if backward else 'activation') + ' distribution')

    mlp.zero_grad()
show_layers(mlp)
layer 3 (MyTanh): mean 0.02, std 0.46, saturated: 0.00%
layer 5 (MyTanh): mean -0.00, std 0.27, saturated: 0.00%
layer 7 (MyTanh): mean 0.00, std 0.15, saturated: 0.00%
layer 9 (MyTanh): mean 0.01, std 0.10, saturated: 0.00%
layer 11 (MyTanh): mean 0.01, std 0.08, saturated: 0.00%
And gradients
show_layers(mlp, backward=True)
layer 3 (MyTanh): mean 0.00, std 0.00, saturated: 0.00%
layer 5 (MyTanh): mean 0.00, std 0.00, saturated: 0.00%
layer 7 (MyTanh): mean 0.00, std 0.00, saturated: 0.00%
layer 9 (MyTanh): mean 0.00, std 0.00, saturated: 0.00%
layer 11 (MyTanh): mean -0.00, std 0.00, saturated: 0.00%
And the weights
def show_weights(mlp, figsize=(20, 4), skip_embedding_layer=True, X=X_val[:32], y=y_val[:32]):
    preds = mlp(X)
    for layer in mlp:
        if hasattr(layer, 'out'):
            layer.out.retain_grad()
    loss = F.cross_entropy(input=preds, target=y)
    loss.backward()

    plt.figure(figsize=figsize)
    legends = []
    for i, p in enumerate(mlp.parameters()):
        if skip_embedding_layer and i == 0:
            continue
        t = p.grad
        if p.ndim == 2:
            print('weight %10s | mean %+f | std %e | grad:data ratio %e' % (tuple(p.shape), t.mean(), t.std(), t.std() / p.std()))
            hy, hx = torch.histogram(t, density=True)
            plt.plot(hx[:-1].detach(), hy.detach())
            legends.append(f'{i} {tuple(p.shape)}')

    plt.legend(legends)
    plt.title('weights gradient distribution')

    mlp.zero_grad()
show_weights(mlp)
weight (30, 100) | mean -0.000015 | std 8.145859e-04 | grad:data ratio 7.634363e-03
weight (100, 100) | mean -0.000002 | std 8.332252e-04 | grad:data ratio 1.447422e-02
weight (100, 100) | mean -0.000002 | std 9.669473e-04 | grad:data ratio 1.674323e-02
weight (100, 100) | mean +0.000016 | std 1.014291e-03 | grad:data ratio 1.749655e-02
weight (100, 100) | mean -0.000012 | std 1.328168e-03 | grad:data ratio 2.303185e-02
weight (100, 38) | mean +0.000000 | std 3.048425e-03 | grad:data ratio 5.262388e-02
Let’s try training it for 1000 steps and see how it changes; we will track the update-to-data ratio during training:
class UDCallback():
    def __init__(self):
        self.ud = []

    def __call__(self, model, lr, step, n_step, **kwargs):
        lr_value = lr(step, n_step)
        self.ud.append([(lr_value * p.grad.std() / p.data.std()).log10().item() for p in model.parameters()])

    def plot(self, figsize=(20, 4)):
        ud = self.ud
        plt.figure(figsize=figsize)
        legends = []
        for i, p in enumerate(mlp.parameters()):
            if p.ndim == 2:
                plt.plot([ud[j][i] for j in range(len(ud))])
                legends.append('param %d' % i)
        plt.plot([0, len(ud)], [-3, -3], 'k')  # these ratios should be ~1e-3, indicate on plot
        plt.legend(legends)
ud = UDCallback()

losses, val_losses = train(mlp, n_step=1_000, lr=lambda step, n_step: 0.1, callback=ud)

val_loss_step, val_loss_value = zip(*val_losses)
plt.plot(val_loss_step, val_loss_value)
The weights initially train at very different rates, but all except the output and embedding layers converge (to a slightly too-high rate)
ud.plot()
The activations and gradients have largely sorted themselves out
show_layers(mlp)
layer 3 (MyTanh): mean 0.02, std 0.47, saturated: 0.00%
layer 5 (MyTanh): mean 0.00, std 0.33, saturated: 0.00%
layer 7 (MyTanh): mean -0.01, std 0.30, saturated: 0.00%
layer 9 (MyTanh): mean 0.02, std 0.33, saturated: 0.00%
layer 11 (MyTanh): mean 0.07, std 0.42, saturated: 0.00%
show_layers(mlp, backward=True)
layer 3 (MyTanh): mean -0.00, std 0.00, saturated: 0.00%
layer 5 (MyTanh): mean -0.00, std 0.00, saturated: 0.00%
layer 7 (MyTanh): mean -0.00, std 0.00, saturated: 0.00%
layer 9 (MyTanh): mean 0.00, std 0.00, saturated: 0.00%
layer 11 (MyTanh): mean 0.00, std 0.00, saturated: 0.00%
And the weight gradients are more uniform
show_weights(mlp)
weight (30, 100) | mean +0.000075 | std 7.344535e-03 | grad:data ratio 6.576340e-02
weight (100, 100) | mean -0.000005 | std 3.532875e-03 | grad:data ratio 5.843860e-02
weight (100, 100) | mean +0.000001 | std 2.672232e-03 | grad:data ratio 4.419371e-02
weight (100, 100) | mean -0.000003 | std 2.625167e-03 | grad:data ratio 4.310565e-02
weight (100, 100) | mean +0.000025 | std 3.583072e-03 | grad:data ratio 5.814888e-02
weight (100, 38) | mean -0.000000 | std 1.289676e-02 | grad:data ratio 1.660140e-01
Even so it successfully trains
mlp = get_mlp(bias=True, m=10, hs=[100]*5).requires_grad_()

ud = UDCallback()

losses, val_losses = train(mlp, n_step=50_000, lr=lambda step, n_step: 0.1, callback=ud)

val_loss_step, val_loss_value = zip(*val_losses)
plt.plot(val_loss_step, val_loss_value)

l5_loss = val_loss_value[-1]
l5_loss, f'{l5_loss / fix_init_loss:0.2%} of fixed init loss'
(2.4107613563537598, '98.48% of fixed init loss')
ud.plot()
show_weights(mlp)
weight (30, 100) | mean -0.000254 | std 1.778142e-02 | grad:data ratio 5.134834e-02
weight (100, 100) | mean +0.000179 | std 1.043383e-02 | grad:data ratio 5.559371e-02
weight (100, 100) | mean -0.000025 | std 1.019966e-02 | grad:data ratio 5.705851e-02
weight (100, 100) | mean +0.000066 | std 1.106223e-02 | grad:data ratio 6.375358e-02
weight (100, 100) | mean +0.000054 | std 1.232356e-02 | grad:data ratio 8.169784e-02
weight (100, 38) | mean -0.000000 | std 2.402255e-02 | grad:data ratio 1.571375e-01
However when we get to 20 layers it fails to train
mlp = get_mlp(bias=True, m=10, hs=[100]*20).requires_grad_()

ud = UDCallback()

losses, val_losses = train(mlp, n_step=5_000, lr=lambda step, n_step: 0.1, callback=ud, val_step=500)
val_loss_step, val_loss_value = zip(*val_losses)
plt.plot(val_loss_step, val_loss_value)

l20_loss = val_loss_value[-1]
l20_loss, f'{l20_loss / fix_init_loss:0.2%} of fixed init loss'
(3.050510883331299, '124.61% of fixed init loss')
The gradients are spread all over the place
ud.plot()
Fixing Gain
We can fix the gain within layers and reduce the gain in the final layer to get a better initialisation:
def add_gain(mlp, gain=5/3, output_gain=0.1, update_layers=(MyLinear,)):
    with torch.no_grad():
        for layer in mlp[:-1]:
            if isinstance(layer, update_layers):
                layer.weight *= gain
        mlp[-1].weight *= output_gain

mlp = get_mlp(bias=True, m=10, hs=[100]*5).requires_grad_()
add_gain(mlp, output_gain=0.1)
with torch.no_grad():
    print('Initial Loss: %0.2f' % F.cross_entropy(mlp(X[:1000]), y[:1000]).item())
Initial Loss: 3.65
The layers are more uniform now
show_layers(mlp)
layer 3 (MyTanh): mean -0.01, std 0.62, saturated: 2.72%
layer 5 (MyTanh): mean -0.02, std 0.47, saturated: 0.03%
layer 7 (MyTanh): mean 0.01, std 0.39, saturated: 0.00%
layer 9 (MyTanh): mean 0.01, std 0.34, saturated: 0.00%
layer 11 (MyTanh): mean -0.01, std 0.30, saturated: 0.00%
show_layers(mlp, backward=True)
layer 3 (MyTanh): mean -0.00, std 0.00, saturated: 0.00%
layer 5 (MyTanh): mean -0.00, std 0.00, saturated: 0.00%
layer 7 (MyTanh): mean 0.00, std 0.00, saturated: 0.00%
layer 9 (MyTanh): mean 0.00, std 0.00, saturated: 0.00%
layer 11 (MyTanh): mean 0.00, std 0.00, saturated: 0.00%
The weights are more uniform except the output layer (because of the multiplication)
show_weights(mlp)
weight (30, 100) | mean +0.000014 | std 3.260477e-04 | grad:data ratio 1.849884e-03
weight (100, 100) | mean +0.000001 | std 3.304202e-04 | grad:data ratio 3.432189e-03
weight (100, 100) | mean -0.000002 | std 3.086994e-04 | grad:data ratio 3.204230e-03
weight (100, 100) | mean +0.000001 | std 3.112402e-04 | grad:data ratio 3.216269e-03
weight (100, 100) | mean -0.000001 | std 3.279222e-04 | grad:data ratio 3.374243e-03
weight (100, 38) | mean -0.000000 | std 9.346377e-03 | grad:data ratio 1.604743e+00
Let’s train it for a little while:
ud = UDCallback()

losses, val_losses = train(mlp, n_step=1_000, lr=lambda step, n_step: 0.1, callback=ud)

val_loss_step, val_loss_value = zip(*val_losses)
plt.plot(val_loss_step, val_loss_value)
The weight updates are similar to before but the weights move more in lockstep during the initial optimisation period
ud.plot()
The output layer gradients come down across these iterations
show_weights(mlp, skip_embedding_layer=True)
weight (30, 100) | mean -0.000254 | std 1.086286e-02 | grad:data ratio 6.070493e-02
weight (100, 100) | mean +0.000004 | std 7.106286e-03 | grad:data ratio 7.206615e-02
weight (100, 100) | mean +0.000033 | std 5.773973e-03 | grad:data ratio 5.865905e-02
weight (100, 100) | mean -0.000028 | std 5.451859e-03 | grad:data ratio 5.512737e-02
weight (100, 100) | mean +0.000017 | std 5.011663e-03 | grad:data ratio 5.035378e-02
weight (100, 38) | mean -0.000000 | std 2.294768e-02 | grad:data ratio 3.940181e-01
And the activations and gradients still look good
show_layers(mlp)
layer 3 (MyTanh): mean -0.00, std 0.62, saturated: 2.94%
layer 5 (MyTanh): mean -0.04, std 0.51, saturated: 0.25%
layer 7 (MyTanh): mean -0.00, std 0.49, saturated: 0.34%
layer 9 (MyTanh): mean 0.01, std 0.51, saturated: 0.19%
layer 11 (MyTanh): mean 0.02, std 0.55, saturated: 0.66%
show_layers(mlp, backward=True)
layer 3 (MyTanh): mean 0.00, std 0.00, saturated: 0.00%
layer 5 (MyTanh): mean -0.00, std 0.00, saturated: 0.00%
layer 7 (MyTanh): mean -0.00, std 0.00, saturated: 0.00%
layer 9 (MyTanh): mean -0.00, std 0.00, saturated: 0.00%
layer 11 (MyTanh): mean 0.00, std 0.00, saturated: 0.00%
For 5 layers we get a similar result as before
mlp = get_mlp(bias=True, m=10, hs=[100]*5).requires_grad_()
add_gain(mlp, output_gain=0.1)

ud = UDCallback()

losses, val_losses = train(mlp, n_step=50_000, lr=lambda step, n_step: 0.1, callback=ud)

val_loss_step, val_loss_value = zip(*val_losses)
plt.plot(val_loss_step, val_loss_value)

l5_fix_loss = val_loss_value[-1]
l5_fix_loss, f'{l5_fix_loss / fix_init_loss:0.2%} of fixed init loss'
(2.4011518955230713, '98.08% of fixed init loss')
However at 20 layers we get the loss decreasing to a similar level as our 2 layer MLP
mlp = get_mlp(bias=True, m=10, hs=[100]*20).requires_grad_()
add_gain(mlp)

ud = UDCallback()

losses, val_losses = train(mlp, n_step=50_000, lr=lambda step, n_step: 0.1, callback=ud, val_step=500)
val_loss_step, val_loss_value = zip(*val_losses)
plt.plot(val_loss_step, val_loss_value)

l20_fix_loss = val_loss_value[-1]
l20_fix_loss, f'{l20_fix_loss / fix_init_loss:0.2%} of fixed init loss'
(2.4759957790374756, '101.14% of fixed init loss')
There is still a bit more variance with the updates across layers in this deeper net, but with the better initialisation it is stable
ud.plot()
Batchnorm
Rather than being careful with initialisation we can use batch norm
mlp = get_mlp(bias=False, m=10, hs=[100]*5, batch_norm=True).requires_grad_()
#add_gain(mlp)
mlp
MySequential((MyEmbedding(38, 10), MyFlatten, MyLinear(30, 100, bias=False), MyBatchNorm1d(100, eps=1e-05, affine=True), MyTanh, MyLinear(100, 100, bias=False), MyBatchNorm1d(100, eps=1e-05, affine=True), MyTanh, MyLinear(100, 100, bias=False), MyBatchNorm1d(100, eps=1e-05, affine=True), MyTanh, MyLinear(100, 100, bias=False), MyBatchNorm1d(100, eps=1e-05, affine=True), MyTanh, MyLinear(100, 100, bias=False), MyBatchNorm1d(100, eps=1e-05, affine=True), MyTanh, MyLinear(100, 38, bias=False), MyBatchNorm1d(38, eps=1e-05, affine=True)))
show_layers(mlp)
layer 4 (MyTanh): mean -0.01, std 0.80, saturated: 32.00%
layer 7 (MyTanh): mean -0.00, std 0.83, saturated: 38.72%
layer 10 (MyTanh): mean -0.00, std 0.82, saturated: 34.44%
layer 13 (MyTanh): mean -0.01, std 0.82, saturated: 34.31%
layer 16 (MyTanh): mean -0.01, std 0.82, saturated: 36.94%
show_layers(mlp, backward=True)
layer 4 (MyTanh): mean 0.00, std 0.04, saturated: 0.00%
layer 7 (MyTanh): mean 0.00, std 0.03, saturated: 0.00%
layer 10 (MyTanh): mean 0.00, std 0.02, saturated: 0.00%
layer 13 (MyTanh): mean -0.00, std 0.01, saturated: 0.00%
layer 16 (MyTanh): mean 0.00, std 0.01, saturated: 0.00%
show_weights(mlp)
weight (30, 100) | mean -0.006068 | std 3.976782e-01 | grad:data ratio 3.796145e+00
weight (100, 100) | mean +0.002900 | std 3.224502e-01 | grad:data ratio 5.585790e+00
weight (100, 100) | mean +0.002294 | std 2.273206e-01 | grad:data ratio 3.929580e+00
weight (100, 100) | mean +0.000491 | std 1.619342e-01 | grad:data ratio 2.798622e+00
weight (100, 100) | mean -0.000863 | std 1.270476e-01 | grad:data ratio 2.212344e+00
weight (100, 38) | mean +0.002701 | std 1.532580e-01 | grad:data ratio 2.657254e+00
ud = UDCallback()

losses, val_losses = train(mlp, n_step=1_000, lr=lambda step, n_step: 0.1, callback=ud)

val_loss_step, val_loss_value = zip(*val_losses)
plt.plot(val_loss_step, val_loss_value)
The training looks more stable
ud.plot()
The saturation quickly reduces
show_layers(mlp)
layer 4 (MyTanh): mean -0.00, std 0.76, saturated: 20.53%
layer 7 (MyTanh): mean -0.01, std 0.68, saturated: 8.31%
layer 10 (MyTanh): mean 0.01, std 0.74, saturated: 16.50%
layer 13 (MyTanh): mean 0.00, std 0.73, saturated: 12.87%
layer 16 (MyTanh): mean -0.01, std 0.74, saturated: 12.53%
And we get a similar result when we train for a long time
mlp = get_mlp(batch_norm=True, bias=False, m=10, hs=[100]*5).requires_grad_()
add_gain(mlp, output_gain=0.1)

ud = UDCallback()

losses, val_losses = train(mlp, n_step=50_000, lr=lambda step, n_step: 0.1, callback=ud)

val_loss_step, val_loss_value = zip(*val_losses)
plt.plot(val_loss_step, val_loss_value)

l5_bn_loss = val_loss_value[-1]
l5_bn_loss, f'{l5_bn_loss / fix_init_loss:0.2%} of fixed init loss'
(2.4429738521575928, '99.79% of fixed init loss')
ud.plot()
And it successfully trains a 20 layer network too
mlp = get_mlp(batch_norm=True, bias=False, m=10, hs=[100]*20).requires_grad_()
add_gain(mlp, output_gain=0.1)

ud = UDCallback()

losses, val_losses = train(mlp, n_step=50_000, lr=lambda step, n_step: 0.1, callback=ud)

val_loss_step, val_loss_value = zip(*val_losses)
plt.plot(val_loss_step, val_loss_value)

l20_bn_loss = val_loss_value[-1]
l20_bn_loss, f'{l20_bn_loss / fix_init_loss:0.2%} of fixed init loss'
(2.5574026107788086, '104.47% of fixed init loss')
Using PyTorch nn Library
We can easily reuse Torch’s implementations which are likely more efficient
from torch.nn import Module, Linear, Embedding, Tanh, BatchNorm1d, Sequential, Flatten
def get_mlp(m=default_m, hs=(default_h,), batch_norm=False, bias=False, V=V, block_size=block_size, activation_factory=lambda: Tanh()):
    # First we embed the vectors and then flatten them
    layers = [Embedding(V, m), Flatten()]

    # Then add the hidden layers
    in_sizes = [block_size * m] + list(hs)
    out_sizes = list(hs) + [V]
    for h_in, h_out in zip(in_sizes, out_sizes):
        layers.append(Linear(h_in, h_out, bias=bias))
        if batch_norm:
            layers.append(BatchNorm1d(num_features=h_out))
        layers.append(activation_factory())

    # Drop the last activation, since this is passed to Softmax
    layers.pop()
    return Sequential(*layers)

mlp = get_mlp(bias=True, m=10, hs=[100]*5).requires_grad_()
mlp
Sequential(
(0): Embedding(38, 10)
(1): Flatten(start_dim=1, end_dim=-1)
(2): Linear(in_features=30, out_features=100, bias=True)
(3): Tanh()
(4): Linear(in_features=100, out_features=100, bias=True)
(5): Tanh()
(6): Linear(in_features=100, out_features=100, bias=True)
(7): Tanh()
(8): Linear(in_features=100, out_features=100, bias=True)
(9): Tanh()
(10): Linear(in_features=100, out_features=100, bias=True)
(11): Tanh()
(12): Linear(in_features=100, out_features=38, bias=True)
)
However we can’t show the layers the same way as before because PyTorch doesn’t store results in .out
try:
    show_layers(mlp, classes=(Tanh,))
except AttributeError as e:
    print(e)
'Tanh' object has no attribute 'out'
<Figure size 2000x400 with 0 Axes>
However PyTorch provides a mechanism to modify the behaviour of a module without editing the source code: hooks.
We can register a forward hook to capture the output:
def log_output(module, args, output):
    module.out = output

for layer in mlp:
    if isinstance(layer, Tanh):
        layer.register_forward_hook(log_output)
And then use our functions as before:
show_layers(mlp, classes=(Tanh,))
layer 3 (Tanh()): mean -0.03, std 0.49, saturated: 0.37%
layer 5 (Tanh()): mean -0.01, std 0.26, saturated: 0.00%
layer 7 (Tanh()): mean -0.01, std 0.16, saturated: 0.00%
layer 9 (Tanh()): mean -0.01, std 0.11, saturated: 0.00%
layer 11 (Tanh()): mean -0.02, std 0.08, saturated: 0.00%
show_layers(mlp, backward=True, classes=(Tanh,))
layer 3 (Tanh()): mean 0.00, std 0.00, saturated: 0.00%
layer 5 (Tanh()): mean 0.00, std 0.00, saturated: 0.00%
layer 7 (Tanh()): mean 0.00, std 0.00, saturated: 0.00%
layer 9 (Tanh()): mean 0.00, std 0.00, saturated: 0.00%
layer 11 (Tanh()): mean 0.00, std 0.00, saturated: 0.00%
show_weights(mlp)
weight (100, 30) | mean -0.000007 | std 1.011561e-03 | grad:data ratio 9.519964e-03
weight (100, 100) | mean -0.000002 | std 1.124272e-03 | grad:data ratio 1.947545e-02
weight (100, 100) | mean -0.000002 | std 1.041845e-03 | grad:data ratio 1.801275e-02
weight (100, 100) | mean -0.000002 | std 1.207655e-03 | grad:data ratio 2.074841e-02
weight (100, 100) | mean -0.000012 | std 1.380317e-03 | grad:data ratio 2.384296e-02
weight (38, 100) | mean -0.000000 | std 3.135142e-03 | grad:data ratio 5.340258e-02
We can fix the weights as before, updating Linear (rather than MyLinear) layers:
add_gain(mlp, output_gain=0.1, update_layers=(Linear,))
show_layers(mlp, classes=(Tanh,))
layer 3 (Tanh()): mean -0.03, std 0.64, saturated: 5.59%
layer 5 (Tanh()): mean -0.00, std 0.47, saturated: 0.12%
layer 7 (Tanh()): mean -0.01, std 0.39, saturated: 0.00%
layer 9 (Tanh()): mean -0.01, std 0.33, saturated: 0.00%
layer 11 (Tanh()): mean -0.01, std 0.30, saturated: 0.00%
show_layers(mlp, backward=True, classes=(Tanh,))
layer 3 (Tanh()): mean 0.00, std 0.00, saturated: 0.00%
layer 5 (Tanh()): mean 0.00, std 0.00, saturated: 0.00%
layer 7 (Tanh()): mean 0.00, std 0.00, saturated: 0.00%
layer 9 (Tanh()): mean 0.00, std 0.00, saturated: 0.00%
layer 11 (Tanh()): mean 0.00, std 0.00, saturated: 0.00%
show_weights(mlp)
weight (100, 30) | mean -0.000002 | std 4.227582e-04 | grad:data ratio 2.387187e-03
weight (100, 100) | mean -0.000001 | std 4.307418e-04 | grad:data ratio 4.476971e-03
weight (100, 100) | mean +0.000000 | std 3.819067e-04 | grad:data ratio 3.961735e-03
weight (100, 100) | mean +0.000001 | std 3.799068e-04 | grad:data ratio 3.916249e-03
weight (100, 100) | mean -0.000004 | std 3.589212e-04 | grad:data ratio 3.719902e-03
weight (38, 100) | mean +0.000000 | std 9.675195e-03 | grad:data ratio 1.648029e+00
And training is exactly the same:
ud = UDCallback()

losses, val_losses = train(mlp, n_step=50_000, lr=lambda step, n_step: 0.1, callback=ud)

val_loss_step, val_loss_value = zip(*val_losses)
plt.plot(val_loss_step, val_loss_value)

l5_torch_loss = val_loss_value[-1]
l5_torch_loss, f'{l5_torch_loss / l5_fix_loss:0.2%} of l5 fixed loss'
(2.4006731510162354, '99.98% of l5 fixed loss')
We can plot the update dynamics and it looks similar to before
ud.plot()
Batchnorm
Similarly we can train with batchnorm
mlp = get_mlp(batch_norm=True, bias=False, m=10, hs=[100]*5).requires_grad_()

ud = UDCallback()

losses, val_losses = train(mlp, n_step=50_000, lr=lambda step, n_step: 0.1, callback=ud)

val_loss_step, val_loss_value = zip(*val_losses)
plt.plot(val_loss_step, val_loss_value)

l5_torch_bn_loss = val_loss_value[-1]
l5_torch_bn_loss, f'{l5_torch_bn_loss / l5_bn_loss:0.2%} of l5 batchnorm loss'
(2.3883166313171387, '97.76% of l5 batchnorm loss')
Training a 5 layer network
Putting it all together, let’s see how low we can make the loss.
mlp = get_mlp(bias=True, m=default_m, hs=[default_h]*5).requires_grad_()
add_gain(mlp, update_layers=(Linear,))
Let’s check the initialisation:
hooks = []

for layer in mlp:
    if isinstance(layer, Tanh):
        hooks.append(layer.register_forward_hook(log_output))

show_layers(mlp, classes=(Tanh,))
layer 3 (Tanh()): mean 0.00, std 0.61, saturated: 2.58%
layer 5 (Tanh()): mean -0.00, std 0.46, saturated: 0.02%
layer 7 (Tanh()): mean -0.01, std 0.37, saturated: 0.00%
layer 9 (Tanh()): mean 0.01, std 0.32, saturated: 0.00%
layer 11 (Tanh()): mean -0.00, std 0.28, saturated: 0.00%
show_layers(mlp, classes=(Tanh,), backward=True)
layer 3 (Tanh()): mean -0.00, std 0.00, saturated: 0.00%
layer 5 (Tanh()): mean -0.00, std 0.00, saturated: 0.00%
layer 7 (Tanh()): mean -0.00, std 0.00, saturated: 0.00%
layer 9 (Tanh()): mean 0.00, std 0.00, saturated: 0.00%
layer 11 (Tanh()): mean -0.00, std 0.00, saturated: 0.00%
Remove the hooks so we don’t need to store any unnecessary outputs
for hook in hooks:
    hook.remove()
We’ll train it for a lot longer and get a slightly lower loss
ud = UDCallback()

losses, val_losses = train(mlp, n_step=200_000, lr=lambda step, n_step: 0.1, callback=ud)
val_loss_step, val_loss_value = zip(*val_losses)
plt.plot(val_loss_step, val_loss_value)

aloss = val_loss_value[-1]
aloss, f'{aloss / fix_init_loss:0.2%} of fixed init loss'
(2.3479623794555664, '95.91% of fixed init loss')
At the end the training has converged
val_loss_ema = [val_loss_value[0]]
momentum = 0.01

for v in val_loss_value[1:]:
    val_loss_ema.append(val_loss_ema[-1] * (1 - momentum) + v * momentum)

plt.plot(val_loss_step[-200:], val_loss_value[-200:])
plt.plot(val_loss_step[-200:], val_loss_ema[-200:])
These are looking slightly more human readable, and some are good, but in general don’t have long range coherence, suggesting that context length is the bottleneck.
for _ in range(20):
    print(sample(mlp))
star
worldonteslangendownpolittersing
wegamepiand_ete
bey_babybugramberdenvanceanimedynews
wresiranugue
disabamazmongonpraxgifsonavioloningwael
windian_botlauker
digifs
kawallorestages
mildlanenoumusicess
slustsurveeupjobbyhbosfronthavirthoppromartyle
eding
blianquely
lone
djerkfries
blacebakinds
knerdups
mints
thedomniada
tecons
What next?
We have looked closer at how the weights change during training, and how initialisation and batch norm can keep the weights in a better range during training. We’ve also changed everything into pure PyTorch and added some instrumentation for checking how weights and gradients change over time.
What we haven’t done is substantially reduce the loss; for that we’re likely going to need more than 3 characters of context.