import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
import matplotlib.pyplot as plt
import torch.nn.functional as F
from einops import rearrange, repeat
Mamba Language Model Inference
Mamba is an alternative kind of neural network model to Transformers, inspired by State Space Models. It can train as efficiently as a transformer (thanks to some hardware-aware tricks) and scales as well with compute in language modelling, but it can handle very long contexts because computation is linear in the length of the input, not quadratic as in transformers. These models are related to Recurrent Neural Networks: at inference time they only need to store a relatively small state at each step, which makes them much faster at generating text than transformers, which need to reprocess the whole input.
The ideas behind Mamba are quite novel and I found working through the code illuminating. The source code covers both efficient training and inference, but is a little opaque because it is written for high performance. The Mamba minimal repository makes it easier to understand efficient parallel training, but doesn’t have the linear time inference code.
This notebook breaks the Mamba model down and builds it up again using elementary PyTorch. The first half of this notebook starts with the official implementation and works through each of the pieces in detail, slowly expanding out the code into PyTorch. The second half then puts the pieces back together to reproduce predictions for inference. The final model is much less efficient and less numerically stable than the original model, but is useful for understanding conceptually how Mamba works.
This is a Jupyter notebook; you can download the notebook.
Running Mamba
We start by importing the MambaLMHeadModel from the source repository
For testing we will load the smallest model released (130 million parameters) on CUDA for an NVIDIA GPU
model_name = "state-spaces/mamba-130m"
device = 'cuda'
We can load the model, which uses the GPT-NeoX tokenizer, from Huggingface:
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = MambaLMHeadModel.from_pretrained(model_name,
                                         device=device,
                                         dtype=torch.float16)
Given a prompt:
prompt = "So I was made to"
We can convert it into tokens:
input_tokens = tokenizer(prompt, return_tensors="pt").to(device)
input_ids = input_tokens['input_ids']
input_ids
tensor([[2598, 309, 369, 1160, 281]], device='cuda:0')
Which we then use to generate text:
with torch.inference_mode():
    preds = model.generate(input_ids, len(input_tokens[0]) + 11)

for result in tokenizer.batch_decode(preds):
    print(result)
So I was made to feel like I was a part of the team. I
This works by passing the input into the model to get the logits
with torch.no_grad():
    logits = model(input_ids).logits.cpu().float()

logits.shape
torch.Size([1, 5, 50280])
For each token the softmax of the logits gives a probability distribution over the next token. We can find the most likely next tokens in the sequence:
probs = logits[:,-1].softmax(dim=-1)

print(f'Prob.\t ',
      "{:10s}".format("Token"),
      '\t', "Id")
for idx in torch.argsort(probs[0], descending=True)[:10]:
    print(f'{probs[0, idx].item():0.2%}\t',
          "{:10s}".format(tokenizer.decode([idx])),
          '\t', idx.item())
Prob. Token Id
17.99% feel 1928
5.24% think 1158
5.16% believe 2868
3.33% do 513
2.80% go 564
2.51% sit 1790
2.36% understand 2096
2.08% see 923
2.05% work 789
1.90% take 1379
To extend the sequence this is done in a loop; for Mamba, by keeping track of the state, we can do this efficiently with minimal recomputation.
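For a rough picture of what such a loop looks like, here is a minimal greedy sketch of my own (not the library's generate implementation) that naively re-runs the whole prefix at every step; the rest of this notebook shows how Mamba's recurrent state avoids that recomputation.
# Naive greedy decoding sketch (my own illustration, not the official generate):
# each step re-runs the full prefix, which is what Mamba's state lets us avoid.
with torch.inference_mode():
    ids = input_ids
    for _ in range(11):
        next_id = model(ids).logits[:, -1].argmax(dim=-1, keepdim=True)
        ids = torch.cat([ids, next_id], dim=-1)
print(tokenizer.batch_decode(ids)[0])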
Deconstructing the model
The Mamba model consists of:
- A backbone (MixerModel) with
- An embedding model mapping the vocabulary to dense embeddings (768 here)
- Multiple mixer (Mamba) layers, that we will dig more into later
- A final norm_f RMSNorm
- A language modelling head (lm_head) mapping the embeddings back to the vocabulary
model
MambaLMHeadModel(
(backbone): MixerModel(
(embedding): Embedding(50280, 768)
(layers): ModuleList(
(0-23): 24 x Block(
(mixer): Mamba(
(in_proj): Linear(in_features=768, out_features=3072, bias=False)
(conv1d): Conv1d(1536, 1536, kernel_size=(4,), stride=(1,), padding=(3,), groups=1536)
(act): SiLU()
(x_proj): Linear(in_features=1536, out_features=80, bias=False)
(dt_proj): Linear(in_features=48, out_features=1536, bias=True)
(out_proj): Linear(in_features=1536, out_features=768, bias=False)
)
(norm): RMSNorm()
)
)
(norm_f): RMSNorm()
)
(lm_head): Linear(in_features=768, out_features=50280, bias=False)
)
We can re-implement the model by:
- Passing the input into the embedding
- Iteratively passing the hidden state and residuals through the mixer layers
- Applying the final RMSNorm (norm_f) to the hidden_state + residual
- Using the lm_head to map the embeddings back to the vocabulary
backbone = model.backbone

embedding = backbone.embedding
mixers = backbone.layers
norm_f = backbone.norm_f
lm_head = model.lm_head
from mamba_ssm.ops.triton.layernorm import rms_norm_fn
with torch.inference_mode():
    hs = embedding(input_ids)
    rs = None
    for mixer in mixers:
        hs, rs = mixer(hs, rs)
    prelogits = rms_norm_fn(hs, norm_f.weight, norm_f.bias,
                            eps=norm_f.eps, residual=rs,
                            prenorm=False, residual_in_fp32=True)
    logits2 = lm_head(prelogits)

logits2 = logits2.cpu().float()
assert (logits2 == logits).all()
RMS Norm
The rms_norm_fn is a CUDA fused version of add and norm for efficiency. We can calculate it directly using PyTorch, which gives similar results, although there is a substantial difference of up to 10% for extreme values.
with torch.inference_mode():
    prelogits2 = norm_f(hs + rs)

actual = prelogits[0,-1].cpu()
target = prelogits2[0,-1].cpu()
diff = actual - target

plt.title("Difference in fused RMS Norm with PyTorch");
plt.xlabel("Input value")
plt.ylabel("Difference in output")
plt.scatter(actual, diff);
This kind of plot, showing how much two variables differ as a function of the first variable, is useful for getting a picture of the absolute and relative errors.
We’ll use this a lot so let’s wrap it in a function:
def plot_diff(actual, pred,
              xlabel="Input value",
              ylabel="Difference in output",
              axs=None):
    if axs is None:
        axs = plt.subplot()

    actual = actual.cpu().detach().float().flatten()
    pred = pred.cpu().detach().float().flatten()

    if actual.shape != pred.shape:
        raise ValueError("Incompatible shapes")
    diff = actual - pred

    axs.set_xlabel(xlabel)
    axs.set_ylabel(ylabel)
    axs.scatter(actual, diff)
    return axs
def max_diff(actual, pred):
    actual = actual.cpu().detach().float().flatten()
    pred = pred.cpu().detach().float().flatten()
    diff = actual - pred
    return diff.abs().max().item()

plot_diff(actual, target)
max_diff(actual, target)
0.00704193115234375
Even though the logits were different there is only a very small difference in probabilities:
with torch.inference_mode():
    probs2 = lm_head(prelogits2.half())[:,-1].softmax(dim=-1).cpu()

plot_diff(probs, probs2)
max_diff(probs, probs2)
1.1391937732696533e-05
RMSNorm is essentially just LayerNorm with no shift:
\[ \operatorname{RMSNorm}(x_i) = \frac{x_i}{\sqrt{\frac{1}{N}\sum_{j=1}^{N} x_j^2 + \epsilon}} g_i \]
This is straightforward to implement directly
def rms_norm(x, weight, eps=1e-5):
    inv_rms = (x.pow(2).mean(axis=-1, keepdim=True) + eps) ** -0.5
    return x * inv_rms * weight
with torch.inference_mode():
    prelogits3 = rms_norm(hs + rs, weight=norm_f.weight)
assert torch.allclose(prelogits3, prelogits2)
Mixer Blocks
The Mixer Layer, along with its component S6 block, is the fundamental contribution of the Mamba paper.
From Section 3.4:
This architecture involves expanding the model dimension \(D\) by a controllable expansion factor \(E\). For each block, most of the parameters (\(3ED^2\)) are in the linear projections (\(2ED^2\) for input projections, \(ED^2\) for output projection) while the inner SSM contributes less. The number of SSM parameters (projections for \(\Delta, \mathbf{B}, \mathbf{C}\), and the matrix \(\mathbf{A}\)) are much smaller in comparison. We repeat this block, interleaved with standard normalization and residual connections, to form the Mamba architecture.
In our model we have model dimension \(D=768\), expansion factor \(E=2\) and SSM state dimension \(N=16\). Each mixer block then consists of :
- in_proj (the bottom trapezoids): The projections from \(D\) to \(2ED\) for the combined hidden and residual state
- conv1d: One dimensional convolutions (with bias) for each of the \(2ED\) layers, here with a kernel size of 4
- act: The activation for both the hidden state and the residual
- x_proj: Mappings for x onto \(B\) (dimension \(N\)), \(C\) (dimension \(N\)) and \(\Delta\) (here dimension 48)
- out_proj (the top trapezoid): Projection from \(ED\) back to \(D\)
- norm: Pre-Norm applied at the start of the layer
mixers[0]
Block(
(mixer): Mamba(
(in_proj): Linear(in_features=768, out_features=3072, bias=False)
(conv1d): Conv1d(1536, 1536, kernel_size=(4,), stride=(1,), padding=(3,), groups=1536)
(act): SiLU()
(x_proj): Linear(in_features=1536, out_features=80, bias=False)
(dt_proj): Linear(in_features=48, out_features=1536, bias=True)
(out_proj): Linear(in_features=1536, out_features=768, bias=False)
)
(norm): RMSNorm()
)
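As a rough sanity check (my own addition, not in the original notebook) we can verify that the two projections dominate the parameter count, roughly \(3ED^2\) per block with \(D=768\) and \(E=2\):
# Rough parameter count check (my addition): in_proj holds about 2*E*D*D weights
# and out_proj about E*D*D, so the projections alone are roughly 3*E*D*D per block.
d_model, expand = 768, 2
in_proj_params = sum(p.numel() for p in mixers[0].mixer.in_proj.parameters())
out_proj_params = sum(p.numel() for p in mixers[0].mixer.out_proj.parameters())
print(in_proj_params == 2 * expand * d_model**2, out_proj_params == expand * d_model**2)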
We can step through a couple of layers to deconstruct what is going on:
with torch.inference_mode():
    hs0, rs0 = embedding(input_ids), None
    hs1, rs1 = mixers[0](hs0, rs0)
    hs2, rs2 = mixers[1](hs1, rs1)
    #...
Pre-norm
Before going into the mixer block the input, made up of the hidden states and residual, is normalised with RMSNorm
with torch.no_grad():
    hidden_states, residual = rms_norm_fn(
        hs1,
        mixers[1].norm.weight,
        bias=None,
        residual=rs1,
        prenorm=True,
        residual_in_fp32=True,
        eps=1e-5
    )
assert (hs2 == mixers[1].mixer(hidden_states)).all()
assert (rs2 == residual).all()
As before this is a fused operation, and we can get a similar result with the rms_norm function implemented previously:
with torch.inference_mode():
    residual2 = (hs1 + rs1).float()
    hidden_states2 = rms_norm(residual2, mixers[1].norm.weight)

assert (residual2 == residual).all()

plot_diff(hidden_states, hidden_states2)
max_diff(hidden_states, hidden_states2)
0.00048744678497314453
When passed through the mixer this gives a very similar result:
with torch.no_grad():
    hs2_ = mixers[1].mixer(hidden_states2.half())
assert torch.allclose(hs2, hs2_)
Next we can step into the actual mixer
mixer = mixers[1].mixer
mixer
Mamba(
(in_proj): Linear(in_features=768, out_features=3072, bias=False)
(conv1d): Conv1d(1536, 1536, kernel_size=(4,), stride=(1,), padding=(3,), groups=1536)
(act): SiLU()
(x_proj): Linear(in_features=1536, out_features=80, bias=False)
(dt_proj): Linear(in_features=48, out_features=1536, bias=True)
(out_proj): Linear(in_features=1536, out_features=768, bias=False)
)
Input projection and block
The normed hidden states are then projected into two expanded inputs. The in_proj takes hidden states of size \((B, L, D)\) and produces outputs of size \((B, L, 2ED)\):
with torch.inference_mode():
    xz = mixer.in_proj(hidden_states)

hidden_states.shape, xz.shape
(torch.Size([1, 5, 768]), torch.Size([1, 5, 3072]))
This is then passed through a fast CUDA implementation of the Mamba layer:
from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn
with torch.no_grad():
    hs2_ = mamba_inner_fn(rearrange(xz, "b l d -> b d l"),
                          mixer.conv1d.weight,
                          mixer.conv1d.bias,
                          mixer.x_proj.weight,
                          mixer.dt_proj.weight,
                          mixer.out_proj.weight,
                          mixer.out_proj.bias,
                          -torch.exp(mixer.A_log.float()),
                          None,  # input-dependent B
                          None,  # input-dependent C
                          mixer.D.float(),
                          delta_bias=mixer.dt_proj.bias.float(),
                          delta_softplus=True,
                          )
assert (hs2_ == hs2).all()
But this isn’t very illuminating; let’s break this mamba_inner_fn down further.
Mamba inner function
First we can separate the pre_x and z vectors (each of the two trapezoids at the bottom of the Mamba block diagram):
with torch.inference_mode():
    pre_x, z = xz.chunk(2, dim=-1)

pre_x.shape, z.shape
(torch.Size([1, 5, 1536]), torch.Size([1, 5, 1536]))
Next we want to perform the convolution on the left input (discarding the extra elements due to padding at the end), and apply the activation
with torch.inference_mode():
    x = rearrange(
            mixer.act(
                mixer.conv1d(
                    rearrange(pre_x, 'b l d -> b d l')
                )[..., :pre_x.size(1)]),
            'b d l -> b l d')
x.shape
torch.Size([1, 5, 1536])
SiLU Activation
Note that the activation function is SiLU which is like a smoothed ReLU
mixer.act
SiLU()
grid = torch.arange(-10, 10, 0.01)
plt.plot(grid, mixer.act(grid));
This is defined as
\[ \operatorname{SiLU}(x) = \frac{x}{1+\exp(-x)} \]
And we can get similar results with a direct implementation:
def mysilu(x):
    return x / (1 + torch.exp(-x))

plot_diff(mixer.act(grid), mysilu(grid))
<Axes: xlabel='Input value', ylabel='Difference in output'>
Applying SSM
Next the SSM is applied to x. First the data-dependent transition matrices are calculated:
\[\begin{align} B &= s_B(x) \\ C &= s_C(x) \\ \Delta &= \tau_\Delta({\rm Parameter} + s_\Delta(x)) \\ \end{align}\]
Where here all the \(s\) are linear and combined in the x_proj layer:
with torch.inference_mode():
    dt, B, C = torch.split(mixer.x_proj(x), [mixer.dt_rank, mixer.d_state, mixer.d_state], dim=-1)

dt.shape, B.shape, C.shape
(torch.Size([1, 5, 48]), torch.Size([1, 5, 16]), torch.Size([1, 5, 16]))
These inputs can be put into an efficient CUDA selective scan function to run the SSM (after reshaping everything to (batch, channel, length)):
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
with torch.inference_mode():
    y, ssm_state = selective_scan_fn(
        rearrange(x, "b l d -> b d l").contiguous(),
        rearrange(dt @ mixer.dt_proj.weight.t(),
                  "b l d -> b d l").contiguous(),
        -torch.exp(mixer.A_log.float()),
        rearrange(B, "b l d -> b d l").contiguous(),
        rearrange(C, "b l d -> b d l").contiguous(),
        mixer.D.float(),
        z=rearrange(z, "b l d -> b d l").contiguous(),
        delta_bias=mixer.dt_proj.bias.float(),
        delta_softplus=True,
        return_last_state=True,
    )
y.shape, ssm_state.shape
(torch.Size([1, 1536, 5]), torch.Size([1, 1536, 16]))
This gives a very similar result to passing it through the model directly:
with torch.inference_mode():
    out = mixer.out_proj(rearrange(y, "b d l -> b l d"))

plot_diff(hs2, out)
max_diff(hs2, out)
0.0009765625
Running the State Space Model
The selective scan function runs the SSM, let’s step through how that works.
Softplus
Firstly we need to calculate \(\Delta = \operatorname{softplus}(s_\Delta(dt) + b)\).
Softplus is a smooth approximation to the ReLU:
\[\operatorname{softplus}(x) = \log\left(1+\exp(x)\right) \]
grid = torch.arange(-10, 10, step=0.01)
sp = F.softplus(grid)
sp2 = F.relu(grid)
plt.plot(grid, sp)
plt.plot(grid, sp2)
assert torch.allclose(sp, (1 + grid.exp()).log(), atol=1e-7)
We apply this to dt:
with torch.inference_mode():
    delta = F.softplus(mixer.dt_proj(dt))
delta.shape
torch.Size([1, 5, 1536])
Discretization
The next step is to discretize the SSM, following equation 4 of the paper:
\[\begin{align} \overline{A} &= \exp(\Delta A) \\ \overline{B} &= (\Delta A)^{-1} \left(\exp(\Delta A) - I\right) \Delta B\\ \end{align} \]
The paper assumes a diagonal structure for \(A\):
Structure and Dimensions. Finally, we note that structured SSMs are so named because computing them efficiently also requires imposing structure on the A matrix. The most popular form of structure is diagonal, which we also use. In this case, the \(A \in \mathbb{R}^{N\times N}\), \(B \in \mathbb{R}^{N \times 1}\), \(C \in \mathbb{R}^{1 \times N}\) matrices can all be represented by \(N\) numbers. To operate over an input sequence \(x\) of batch size \(B\) and length \(L\) with \(D\) channels, the SSM is applied independently to each channel. Note that in this case, the total hidden state has dimension \(DN\) per input, and computing it over the sequence length requires \(O(BLDN)\) time and memory; this is the root of the fundamental efficiency bottleneck addressed in Section 3.3.
Then the matrix exponential of the diagonal matrix \(A\) reduces to the elementwise exponential, and
\[\begin{align} \overline{B}_i &= \frac{1}{\Delta A_i} \left(\exp(\Delta A_i) - 1\right) \Delta B_i \\ &= \frac{1}{\Delta A_i} \left(1 + \Delta A_i + O\left((\Delta A_i)^2\right) - 1\right) \Delta B_i \\ &\approx \Delta B_i + O\left(\Delta^2 A_i B_i\right) \end{align}\]
Note that the model stores the negative log of A; this means the values of \(A\) are all less than 0 and the exponential won’t overflow (but many values will underflow to zero), and we expect the error, of order \(\Delta^2 A_i B_i\), to be small.
with torch.inference_mode():
    A = -torch.exp(mixer.A_log.float())

A.shape, B.shape
(torch.Size([1536, 16]), torch.Size([1, 5, 16]))
with torch.inference_mode():
    A_bar = torch.exp(torch.einsum("b l d, d n -> b l d n", delta, A))
    B_bar = torch.einsum("b l d, b l n -> b l d n", delta, B)
With the discrete (data dependent) transition matrices, the discrete State Space Model transition can be applied:
\[\begin{align} h_{t} &= \overline{A} h_{t-1} + \overline{B} x_t \\ y_{t} &= C h_t + D x_t \\ \end{align}\]
We need to be a bit careful about implicit shapes: this is applied independently across each element of the batch (here size 1) and each channel (here 1536).
print(f'''
{A_bar.shape=}
{B_bar.shape=}
{x.shape=}
{C.shape=}
{mixer.D.shape=}
''')
A_bar.shape=torch.Size([1, 5, 1536, 16])
B_bar.shape=torch.Size([1, 5, 1536, 16])
x.shape=torch.Size([1, 5, 1536])
C.shape=torch.Size([1, 5, 16])
mixer.D.shape=torch.Size([1536])
We can calculate \(h_t\) iteratively for each time step \(t\) and use this to calculate \(y_t\):
b, l, d, n = A_bar.shape

h = torch.zeros(b, d, n, device=A_bar.device)
states = []

with torch.inference_mode():
    for i in range(l):
        # At t=i, expand x along n (the hidden dimension)
        h = A_bar[:, i] * h + B_bar[:, i] * x[:, i].unsqueeze(-1)
        states.append(h)

    # Combine all states into a single tensor
    h_t = torch.stack(states, dim=1)

    y_t = torch.einsum("b l n, b l d n -> b l d", C.float(), h_t) + \
          torch.einsum('d, b l d -> b l d', mixer.D, x)
Then we can combine the output of the SSM by multiplying with the gated MLP as in Figure 3:
with torch.inference_mode():
    mixer_out = rearrange(y_t * mixer.act(z), 'b l d -> b d l')
The state we get is similar to what we got from selective_scan_fn:
plot_diff(ssm_state, h)
max_diff(ssm_state, h)
9.191036224365234e-05
Our output is also similar, with deviations at very small values of the output (which are relatively small):
plot_diff(y, mixer_out)
max_diff(y, mixer_out)
0.0007761716842651367
Interlude: State selection
The Mamba paper argues that Mamba performs much better than S4 because it can selectively ignore or focus on inputs depending on the data. We can look at the transition matrices to understand how this works.
The \(\Delta\) function at each timestep impacts how much the state gets updated in each dimension.
For \(\Delta \approx 0\), \(\overline{A} = \exp(\Delta A) \approx I\) and \(\overline{B} = \Delta B \approx 0\), so \(h_t \approx h_{t-1}\) and we ignore the input in that dimension.
On the other hand, for large \(\Delta\), since \(A_i < 0\) we get \(\overline{A} \approx 0\) and \(\overline{B} = \Delta B\), so \(h_t \approx \Delta B x_t\) and the state is reset.
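A tiny numeric illustration of these two regimes, with made-up scalar values rather than the model's actual parameters (my own sketch):
# Toy illustration (made-up numbers): with A_i < 0, a small delta keeps the
# previous state while a large delta discards it and resets to ~delta * B * x.
A_i, B_i = torch.tensor(-1.0), torch.tensor(0.5)
x_t, h_prev = torch.tensor(2.0), torch.tensor(3.0)
for delta_i in (1e-3, 10.0):
    A_bar_i = torch.exp(delta_i * A_i)  # ~1 for small delta, ~0 for large delta
    B_bar_i = delta_i * B_i             # ~0 for small delta
    h_new = A_bar_i * h_prev + B_bar_i * x_t
    print(f"delta={delta_i}: A_bar={A_bar_i.item():.3f}, h_new={h_new.item():.3f}")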
We can see here that the distribution of \(\Delta\) across hidden dimensions changes over time (here the tokens).
for i in range(delta.size(1)):
    plt.hist(delta[0,i].cpu(), bins=50, alpha=0.7,
             label=f"t={i}, {tokenizer.decode(input_ids[0,i])}")
plt.xlabel("delta")
plt.ylabel("freq")
plt.legend(loc="upper right");
The values of \(A\) indicate how much state to keep for each embedding dimension: very negative values correspond to forgetting state, values close to 0 correspond to keeping state.
Note that the majority of values are very close to 0 (keeping state) but it varies by channel.
fig, ax = plt.subplots(4, 4, figsize=(15,10))

for i in range(4):
    for j in range(4):
        ax[i,j].hist(A[:, i*4+j].cpu().exp(), bins=50)

fig.suptitle("Distribution of exp(A), each subplot is a different dimension");
B is applied uniformly across channels (so its per-channel impact is controlled by \(\Delta\)), but it varies over time and across hidden dimensions.
In particular when \(B \approx 0\) the input is ignored for that hidden dimension.
fig, ax = plt.subplots(4, 4, figsize=(15,10))

for i in range(4):
    for j in range(4):
        ax[i,j].bar(range(B.size(1)), B[0, :, 4*i+j].cpu())

fig.suptitle('B over time; each subplot is a different hidden dimension');
Reconstructing the model
Now we’ve taken the model apart top-to-bottom let’s put it back together again bottom-to-top using elementary PyTorch. In particular we will focus on an iterative inference scenario for generating text, rather than the parallel scan for training.
Implementing Discrete SSM
The Discrete State Space model is straightforward to implement, being careful about dimensions:
\[\begin{align} h_0 &= 0 \\ h_{t} &= \overline{A} h_{t-1} + \overline{B} x_{t} \\ y_{t} &= C h_t + D x_t \end{align}\]
def ssm(A_bar,  # (B, D, N)
        B_bar,  # (B, D, N)
        C,      # (B, N)
        D,      # (D)
        x,      # (B, D)
        h=None, # (B,D,N)
       ):       # -> y=(B, D), h=(B,D,N)
    if h is None:
        h = torch.zeros_like(A_bar)
    #   (B,D,N) * (B,D,N)   (B,D,N) * (B,D,1)
    h = A_bar * h + B_bar * x.unsqueeze(-1)
    #                                      (D) * (B, D)
    y = torch.einsum("b n, b d n -> b d", C, h) + D * x

    return y, h
Or equivalently, avoiding the einsum:
def ssm(A_bar,  # (B, D, N)
        B_bar,  # (B, D, N)
        C,      # (B, N)
        D,      # (D)
        x,      # (B, D)
        h=None, # (B,D,N)
       ):       # -> y=(B, D), h=(B,D,N)
    if h is None:
        h = torch.zeros_like(A_bar)
    #   (B,D,N) * (B,D,N)   (B,D,N) * (B,D,1)
    h = A_bar * h + B_bar * x.unsqueeze(-1)
    #   (B, 1, N)  (B,D,N) => (B,D)    (D) * (B, D)
    y = (C.unsqueeze(1) * h).sum(axis=-1) + D * x

    return y, h
We can then get the same output as before by stepping through the model token by token:
with torch.inference_mode():
    h = None
    l = A_bar.size(1)
    ys = [None] * l
    for i in range(l):
        ys[i], h = ssm(A_bar[:, i], B_bar[:, i],
                       C[:, i].float(), mixer.D, x[:, i], h)
assert torch.allclose(y_t, torch.stack(ys, axis=1))
S6 (S4 with Selection)
We run a modified version of Algorithm 2 from the paper, where the length dimension \(L\) is omitted since we advance one step (one token) at a time:
Input: x : (B, D)
Output: y : (B, D)
1: A : (D, N) ← Parameter    (represents a diagonal N×N matrix)
2: B : (B, N) ← s_B(x)
3: C : (B, N) ← s_C(x)
4: D : (D) ← Parameter
5: ∆ : (B, D) ← τ_∆(Parameter + s_∆(x))
6: A_bar, B_bar : (B, D, N) ← discretize(∆, A, B)
7: y ← SSM(A_bar, B_bar, C, D)(x)
8: return y
One particular detail is that \(s_\Delta\) is produced via a low-rank factorization \(s_\Delta = U_\Delta V_\Delta\), where the inner rank is much smaller than the inner dimension, here 32 times smaller.
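As a quick size check (my own addition): the factorization uses rank dt_rank = 48 against an inner dimension of \(ED = 1536\), a factor of 32 fewer.
# Size check for the low-rank factorisation of s_Delta (my addition):
# dt_rank is 32x smaller than the inner dimension E*D.
print(mixer.dt_rank, (mixer.in_proj.out_features // 2) // mixer.dt_rank)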
We can apply this in a straightforward manner (again being careful with dimensions).
def s6(x,       # (B,D)
       A,       # (D,N)
       s_B,     # (D, N)
       s_C,     # (D, N)
       U_Δ,     # (D, R)
       V_Δ,     # (R, D)
       bias_Δ,  # (1, D)
       D,       # (D)
       # Softplus
       t_Δ = lambda x: torch.log(1+torch.exp(x)),
       h = None, # (B,D,N)
      ):
    B = x @ s_B # (B, N)
    C = x @ s_C # (B, N)

    Δ = t_Δ((x @ U_Δ) @ V_Δ + bias_Δ).unsqueeze(-1) # (B, D, 1)

    # Discretize
    #                 (B,D,1) (D,N)
    A_bar = torch.exp(Δ * A) # (B,D,N)
    #       (B,D,1) (B, 1, N)
    B_bar = Δ * B.unsqueeze(1) # (B,D,N)

    # Apply SSM
    return ssm(A_bar, B_bar, C.float(), D, x, h)
To run this we need to be able to extract the parameter matrices out from a mixer. The \(U_\Delta, s_B, s_C\) matrices are stored together for computational efficiency, so we separate out the weights for clarity based on dimensions.
def get_s6_params(mixer):
    A = -torch.exp(mixer.A_log)
    n = A.size(-1)

    return dict(
        A = A,
        U_Δ = mixer.x_proj.weight[:-2*n].T,
        s_B = mixer.x_proj.weight[-2*n:-n].T,
        s_C = mixer.x_proj.weight[-n:].T,
        V_Δ = mixer.dt_proj.weight.T,
        bias_Δ = mixer.dt_proj.bias,
        D = mixer.D
    )
Then the S6 model can be run token by token, keeping track of the state hs, to get a similar output y_t as before.
hs = None
outs = []
with torch.inference_mode():
    for x_i in [x[:, i] for i in range(pre_x.size(1))]:
        out, hs = s6(x_i, **get_s6_params(mixer), h=hs)
        outs.append(out)

plot_diff(y_t, torch.stack(outs, axis=1));
Iterative convolution
The convolution looks back kernel_size steps in the input, and as we progress iteratively through the sequence we need to store that window as the state of the convolution. If we have those inputs we can pass them through the convolution layer:
with torch.inference_mode():
    conv_output = rearrange(
        mixer.conv1d(rearrange(pre_x, 'b l d -> b d l')),
        'b d l -> b l d')
The convolution consists of weights (for each dimension, across the kernel) and a bias (for each dimension)
conv_weight = mixer.conv1d.weight # (D, 1, kernel_size)
conv_bias = mixer.conv1d.bias # (D)

conv_weight.shape, conv_bias.shape
(torch.Size([1536, 1, 4]), torch.Size([1536]))
The convolution works by sliding the weights over the embeddings and performing a multiply and add. It can be calculated by:
- left padding the input with zeros
- for each time step in the convolution
- get the next convolution weight
- shift the input 1 place left
- multiply together and add to the result
- add the convolution bias
kernel_size = conv_weight.size(-1)

b, l, d = pre_x.size()
pad_sz = mixer.conv1d.padding[0]

with torch.inference_mode():
    # Left pad the input with zeros
    pad_x = torch.zeros(b, l+pad_sz, d, device=conv_weight.device)
    pad_x[:, -l:, :] = pre_x

    result = torch.zeros_like(pad_x)
    # for each time step
    for i in range(kernel_size):
        shifted_input = torch.roll(pad_x, shifts=-i, dims=1)
        # add to result: convolution weight * input shifted left
        result += conv_weight[:, 0, i] * shifted_input
    # add the bias
    result += conv_bias
plot_diff(conv_output, result)
max_diff(conv_output, result)
0.0014514923095703125
Note that we drop the padding on the right (since those items aren’t fully convolved)
result_unpad = result[:, :-pad_sz]
result_unpad.shape
torch.Size([1, 5, 1536])
This is how we obtained x from pre_x before, by applying the activation, which gives relatively similar results:
plot_diff(x, mixer.act(result_unpad));
Note that in particular the last value is just obtained by a dot product of the convolution weights with the last kernel_size (here 4) elements of the input:
with torch.no_grad():
    convolution_end = torch.einsum('b k d, d k -> b d',
                                   pre_x[:, -kernel_size:, :],
                                   conv_weight.squeeze(1)
                                  ) + conv_bias

plot_diff(result_unpad[:, -1], convolution_end)
max_diff(result_unpad[:, -1], convolution_end)
0.0013649463653564453
So if we store the last kernel_size elements of the input pre_x, for each dimension and batch item, we can perform the convolution at each step with a simple multiplication:
conv_state = torch.zeros(b, kernel_size, d, device=pre_x.device).half()

with torch.inference_mode():
    conv_output = []
    for pre_x_i in [pre_x[:, i] for i in range(pre_x.size(1))]:
        # Move the convolution state back
        conv_state = conv_state.roll(shifts=-1, dims=1)
        conv_state[:, -1] = pre_x_i

        conv_output.append(
            torch.einsum('b k d, d k -> b d',
                         conv_state,
                         conv_weight.squeeze(1))
            + conv_bias)

plot_diff(result_unpad, torch.stack(conv_output, dim=1));
Mixer
With the iterative convolution we can now go back and implement the Mixer block step by step
def mixer_step(
        x,           # (B,D)
        left_proj,   # (D, E*D)
        right_proj,  # (D, E*D)
        conv_weight, # (E*D, K)
        conv_bias,   # (E*D)
        out_proj,    # (E*D,D)
        s6_params,
        σ=lambda x: x/(1+torch.exp(-x)), # SiLU
        h=None,          # (B,E*D,N)
        conv_state=None, # (B, E*D, K)
       ): # -> y, h, conv_state
    # Left branch
    left_input = x @ left_proj # (B, E*D)

    # Initialise convolution state to 0
    if conv_state is None:
        b = x.size(0)
        ed, k = conv_weight.shape
        conv_state = torch.zeros((b, ed, k),
                                 dtype=x.dtype,
                                 device=x.device,
                                 requires_grad=x.requires_grad)

    # Conv, σ
    conv_state = conv_state.roll(shifts=-1, dims=-1)
    conv_state[..., -1] = left_input
    conv_output = σ( (conv_state * conv_weight).sum(axis=-1) +
                     conv_bias) # (B, E*D)

    # SSM
    ssm_output, h = s6(conv_output, **s6_params, h=h)

    # Multiplication
    block_act = ssm_output * σ(x @ right_proj) # (B, E*D)

    # Output
    return block_act.to(out_proj.dtype) @ out_proj, h, conv_state
As before we need to extract the matrices from the Mixer, in particular breaking in_proj apart for clarity:
def get_mixer_params(mixer):
    ed = mixer.in_proj.out_features // 2

    return dict(
        left_proj=mixer.in_proj.weight[:ed].T,
        right_proj=mixer.in_proj.weight[ed:].T,
        conv_weight=mixer.conv1d.weight.squeeze(1),
        conv_bias=mixer.conv1d.bias,
        out_proj=mixer.out_proj.weight.T,
        s6_params=get_s6_params(mixer),
    )
Running our version gets similar results to running the Mixer directly
h = None
conv_state = None
outputs = []

with torch.no_grad():
    for hs_i in [hidden_states[:, i] for i in range(hidden_states.size(1))]:
        output, h, conv_state = mixer_step(hs_i,
                                           **get_mixer_params(mixer),
                                           h=h, conv_state=conv_state)
        outputs.append(output)

plot_diff(mixer(hidden_states), torch.stack(outputs, axis=1));
For simplicity we can pass the mixer into the function to extract all the parameters for mixer_step:
def my_mixer(mixer, x, h=None, conv_state=None):
    return mixer_step(x, **get_mixer_params(mixer), h=h, conv_state=conv_state)
We can then call it as before:
h = None
conv_state = None
outputs = []

with torch.no_grad():
    for hs_i in [hidden_states[:, i] for i in range(hidden_states.size(1))]:
        output, h, conv_state = my_mixer(mixer, hs_i, h=h, conv_state=conv_state)
        outputs.append(output)

plot_diff(mixer(hidden_states), torch.stack(outputs, axis=1));
Connecting the blocks
Recall that before we walked through the mixers, keeping the neural network hidden state (hs) and the unnormed residual (rs):
with torch.inference_mode():
    hs0, rs0 = embedding(input_ids), None
    hs1, rs1 = mixers[0](hs0, rs0)
    hs2, rs2 = mixers[1](hs1, rs1)
    #...
We can reproduce this using my_mixer by keeping track of the hidden states and convolution states in each layer:
hs1_ = []
hs2_ = []

with torch.inference_mode():
    hs1_h, hs1_conv = None, None
    hs2_h, hs2_conv = None, None
    # ...
    for token in [input_ids[:,i] for i in range(input_ids.size(1))]:
        hs0_i = embedding(token)
        rs0_i = torch.zeros_like(hs0_i)

        rs1_i = (hs0_i + rs0_i).float()
        rs1_i_norm = rms_norm(rs1_i.half(), mixers[0].norm.weight)
        hs1_i, hs1_h, hs1_conv = my_mixer(mixers[0].mixer, rs1_i_norm,
                                          h=hs1_h, conv_state=hs1_conv)
        hs1_.append(hs1_i)

        rs2_i = (hs1_i + rs1_i).float()
        rs2_i_norm = rms_norm(rs2_i.half(), mixers[1].norm.weight)
        hs2_i, hs2_h, hs2_conv = my_mixer(mixers[1].mixer, rs2_i_norm,
                                          h=hs2_h, conv_state=hs2_conv)
        hs2_.append(hs2_i)
        #...
This gives a relatively similar result to running it directly:
plot_diff(hs1, torch.stack(hs1_, axis=1))
plot_diff(hs2, torch.stack(hs2_, axis=1));
Going through all the layers
We can write this programmatically using the backbone layers:
backbone.layers
ModuleList(
(0-23): 24 x Block(
(mixer): Mamba(
(in_proj): Linear(in_features=768, out_features=3072, bias=False)
(conv1d): Conv1d(1536, 1536, kernel_size=(4,), stride=(1,), padding=(3,), groups=1536)
(act): SiLU()
(x_proj): Linear(in_features=1536, out_features=80, bias=False)
(dt_proj): Linear(in_features=48, out_features=1536, bias=True)
(out_proj): Linear(in_features=1536, out_features=768, bias=False)
)
(norm): RMSNorm()
)
)
Our implementation isn’t as numerically stable so let’s convert everything to float (removing the fp16):
backbone = backbone.float()
We need to keep an SSM state and convolution state for each block separately to proceed token by token.
hs = [None] * len(backbone.layers)
conv_state = [None] * len(backbone.layers)

with torch.inference_mode():
    for token in [input_ids[:,i] for i in range(input_ids.size(1))]:
        hsi = embedding(token)
        rsi = torch.zeros_like(hsi)

        for i, layer in enumerate(backbone.layers):
            rsi = (hsi + rsi)
            rsi_norm = rms_norm(rsi, layer.norm.weight)
            hsi, hs[i], conv_state[i] = my_mixer(layer.mixer, rsi_norm,
                                                 h=hs[i], conv_state=conv_state[i])
As before we can RMS normalise the last layer's output and pass it through the language modelling head to get the logits (which have a small relative difference to the actual values):
my_prelogits = rms_norm(hsi + rsi, weight=norm_f.weight)
my_logits = my_prelogits @ lm_head.weight.T

plot_diff(logits[:, -1], my_logits);
This gives reasonably similar probabilities for common tokens:
my_probs = my_logits.softmax(dim=-1)

plot_diff(probs, my_probs)
<Axes: xlabel='Input value', ylabel='Difference in output'>
Aside: Size of the state
Note that each convolution state has size (B, E*D, K) (where K is the kernel size), and each SSM hidden state has size (B, E*D, N). So the total state to track per block is \(BED(N+K)\), which is typically a small fraction of the roughly \(3ED^2\) parameters per block.
b, ed, k = conv_state[0].shape
b, ed, n = hs[0].shape

# E=2 here
d = ed/2

f'The state is {ed*(n+k) / (3*ed*d):0.1%} of the parameters per item in batch'
'The state is 0.9% of the parameters per item in batch'
We can check that the parameters per block really are roughly \(3ED^2\):
sum([len(x.flatten()) for x in backbone.layers[0].parameters()]) / (3*ed*d)
1.0657552083333333
Greedy decoding
Now that we have the state we can iteratively continue token by token to generate the text:
output_tokens = [my_logits.argmax(dim=-1)]
num_tokens = 11 - 1

with torch.inference_mode():
    for _ in range(num_tokens):
        # Embedding
        hsi = embedding(output_tokens[-1]).float()
        rsi = torch.zeros_like(hsi)

        # Backbone
        for i, layer in enumerate(backbone.layers):
            rsi = (hsi + rsi)
            rsi_norm = rms_norm(rsi, layer.norm.weight)
            hsi, hs[i], conv_state[i] = my_mixer(layer.mixer, rsi_norm,
                                                 h=hs[i], conv_state=conv_state[i])

        # LM Head
        my_prelogits = rms_norm(hsi + rsi, weight=norm_f.weight)
        my_logits = my_prelogits @ lm_head.weight.T

        output_tokens.append(my_logits.argmax(dim=-1))
for result in tokenizer.batch_decode(torch.stack(output_tokens, axis=-1)):
    print(result)
feel like I was a part of the team. I
Which generates exactly the same result as the model.
with torch.inference_mode():
    preds = model.generate(input_ids, len(input_tokens[0]) + 11)

for result in tokenizer.batch_decode(preds):
    print(result)
So I was made to feel like I was a part of the team. I
Note there is nothing special about greedy decoding here; we could equally use sampling or beam search in a similar way.
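For example, a temperature sampling variant (a sketch of my own, with an illustrative temperature value) would simply replace the argmax inside the generation loop above with a draw from the softmax distribution:
# Sampling sketch (my own, illustrative temperature): replace the greedy
# argmax inside the generation loop with a multinomial draw.
temperature = 0.8
next_probs = (my_logits / temperature).softmax(dim=-1)
next_token = torch.multinomial(next_probs, num_samples=1)[:, 0]  # shape (B,), like the argmax above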
What next?
Mamba is a really promising alternative direction to Transformers, especially for very long texts (as an alternative to techniques for increasing the size of the context window, such as LongRoPE). Given the recent history of Large Language Models there are some obvious things to try with Mamba (such as scaling up the model size and training, instruction tuning, RLHF or DPO), which are straightforward but require compute, work on making the training more stable, and adapting parameter efficient techniques such as LoRA. There are also potentially straightforward architecture optimisations; there have been many detailed studies on the transformer architecture (such as in the T5 paper) that have helped make it better, and these remain to be explored for Mamba. Because the model is straightforward and works on many kinds of long inputs (such as audio and DNA), there’s also a lot of potential in multimodal models and other applications such as code generation. There are also some more speculative questions I find interesting, related to overcoming the issues I have had using encoder transformers on long documents:
- can this produce good representations for retrieval in large documents (such as dense vector embeddings, or in techniques like SPLADE or ColBERT)?
- how can this be used for span classification for long documents (e.g. extracting relevant text from HTML pages)?
- for document understanding does it make sense to use bidirectional Mamba models (in analogy to bidirectional LSTMs)? How does it compare with BERT?
Mamba also opens new research directions for understanding and interpreting high performing language models, which so far has concentrated almost solely on transformers. State selection gives an insight into the hidden states; it would be interesting to understand how this interacts with representations of underlying features of the text (such as grammar) through the layers of the model. I look forward to seeing the new research to come on Mamba.