State Space Model

Introduced in Combining Recurrent, Convolutional, and Continuous-time Models with Linear State-Space Layers.

The State Space Model combines the continuous memory of [[Recurrent Neural Network]]s and efficient parallelization of [[Convolutional Neural Network]]s. It can be seen as a general learnable LTI-filter.

Linear State-Space Layer

The State Space Model is based on the Linear State-Space Layer (LSSL). This layer is a generalization of [[Convolutional Neural Network]]s and [[Recurrent Neural Network]]s, combining them in a single framework.

lssl.svg

The layer is defined by four matrices \(A\), \(B\), \(C\), and \(D\):

\[ \begin{align} \dot{x}(t) &= Ax(t)+Bu(t) \\ y(t)&=Cx(t)+Du(t) \end{align} \]

The matrix \(A\) is the recurrence weight of the model. The matrix \(B\) projects the input into the hidden dimension, and the matrix \(C\) projects the hidden state into the output dimension. The matrix \(D\) is the weight of the residual connection, combining the projected output with information from the input token.

Continuous Memory

The matrix \(A\) defines the evolution of the hidden context \(x(t)\), the so-called continuous memory of the model. As in RNNs, the recurrent weight \(A\) suffers from vanishing gradients when initialized randomly. To circumvent this issue, the authors employ the High-Order Polynomial Projection Operator ([[HiPPO]]) to initialize \(A\).

HiPPO provides a method to transform the entire history of \(u(t)\) into a fixed-size vector \(x(t)\), where \(x(t)\) is the optimal approximation of the signal's history projected onto a set of context-vectors for each input feature. So for an input of size \(L\times H\) (length x features), we get a matrix \(A\) that transforms the input into a context-vector of size \(H\times N\) (features x context size), which optimally encodes the history of the input.
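As an illustration, the sketch below builds one commonly used HiPPO matrix. The text above does not say which HiPPO variant is meant, so this assumes the LegS (scaled Legendre) form, written in the sign convention that matches \(\dot{x}(t) = Ax(t)+Bu(t)\). The function name `hippo_legs` is hypothetical.

```python
import numpy as np

def hippo_legs(N):
    """Return an (N, N) state matrix A and (N, 1) input matrix B.

    Assumption: HiPPO-LegS, with
      A[n, k] = -sqrt(2n+1) * sqrt(2k+1)  if n > k
      A[n, n] = -(n + 1)
      A[n, k] = 0                          if n < k
      B[n]    = sqrt(2n+1)
    """
    n = np.arange(N)
    scale = np.sqrt(2 * n + 1)
    # Strictly lower triangle of the outer product, negated.
    A = -np.tril(np.outer(scale, scale), k=-1)
    # Diagonal entries -(n + 1).
    A -= np.diag(n + 1)
    B = scale[:, None]
    return A, B

A, B = hippo_legs(4)
print(A)
print(B.ravel())
```

Because this matrix is lower triangular, its eigenvalues are the diagonal entries \(-(n+1)\), so the continuous-time state is stable by construction rather than depending on a lucky random draw.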

Discretization

This continuous-time formulation is discretized in practice through a discretization rule that depends on the step-size \(\Delta t\). Various rules can be applied, such as the [[Zero-Order Hold]]:

\[ \begin{align} \bar{A} &= \exp(\Delta t A)\\ \bar{B} &= (\Delta t A)^{-1}\left(\exp(\Delta t A)-\mathbf{I}\right)\cdot\Delta t B \end{align} \]
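A minimal NumPy sketch of the zero-order hold rule follows. To stay dependency-free it uses a small hand-written matrix exponential (truncated Taylor series with scaling and squaring), which is an assumption made for illustration; a library routine would normally be used instead. The names `expm` and `discretize_zoh` are hypothetical, and the formula requires \(A\) to be invertible.

```python
import numpy as np

def expm(M, terms=20, squarings=10):
    """Matrix exponential: Taylor series on M / 2**squarings, then repeated squaring."""
    M = np.asarray(M, dtype=float) / (2 ** squarings)
    E = np.eye(M.shape[0])
    term = np.eye(M.shape[0])
    for k in range(1, terms + 1):
        term = term @ M / k          # M^k / k!
        E = E + term
    for _ in range(squarings):
        E = E @ E                    # undo the scaling
    return E

def discretize_zoh(A, B, dt):
    """A_bar = exp(dt A), B_bar = (dt A)^{-1} (exp(dt A) - I) dt B."""
    I = np.eye(A.shape[0])
    A_bar = expm(dt * A)
    # Solve (dt A) X = (A_bar - I)(dt B) instead of forming the inverse.
    B_bar = np.linalg.solve(dt * A, (A_bar - I) @ (dt * B))
    return A_bar, B_bar

# Small illustrative system (values chosen arbitrarily).
A = np.array([[-1.0, 0.0], [1.0, -2.0]])
B = np.array([[1.0], [1.0]])
A_bar, B_bar = discretize_zoh(A, B, dt=0.1)
print(A_bar)
print(B_bar)
```

For the first state, which evolves as \(\dot{x}_0 = -x_0 + u\), the result can be checked by hand: \(\bar{A}_{00} = e^{-0.1}\) and \(\bar{B}_{0} = 1 - e^{-0.1}\).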

In discretized form, we can unfold the model:

lssl_discrete.svg

The timescale \(\Delta t\) can be a matrix of the input size, providing a different step-size for each feature of the input. We can then represent the prior LSSL formulation for discrete sequences as:

\[ \begin{align} x_{t} &= \bar{A}x_{t-1}+\bar{B}u_{t} \tag{1} \\ y_{t} &= Cx_{t}+Du_{t} \tag{2} \end{align} \]

Since the discretization has a variable step-size, the model is resolution independent.
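Equations (1) and (2) can be sketched as a plain loop over time-steps. This is a minimal illustration for a single scalar input channel, assuming \(x_{-1} = 0\) and a scalar \(D\); the function name `lssl_recurrent` and the toy values are hypothetical.

```python
import numpy as np

def lssl_recurrent(A_bar, B_bar, C, D, u):
    """Apply x_t = A_bar x_{t-1} + B_bar u_t and y_t = C x_t + D u_t to a 1-D input u."""
    x = np.zeros((A_bar.shape[0], 1))        # x_{-1} = 0
    ys = []
    for u_t in u:
        x = A_bar @ x + B_bar * u_t          # equation (1)
        ys.append((C @ x).item() + D * u_t)  # equation (2)
    return np.array(ys)

# One-dimensional state with decay 0.5, fed a unit impulse.
A_bar = np.array([[0.5]])
B_bar = np.array([[1.0]])
C = np.array([[1.0]])
D = 0.0
u = np.array([1.0, 0.0, 0.0])

y = lssl_recurrent(A_bar, B_bar, C, D, u)
print(y)  # impulse response: 1, 0.5, 0.25
```

Each step only needs the previous state, which is what makes this form cheap for autoregressive inference.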

Equivalence with CNNs and RNNs

The LSSL is equivalent to CNNs and RNNs. Applying the LSSL formulation with discretized matrices is just the same as a recurrent model, where \(D\) is the gating mechanism.

The equivalence to convolution is seen when unfolding the discrete formulation (1) and (2) over time. Assuming \(x_{-1} = 0\) and ignoring the residual connection \(D\), we get:

\[ \begin{align} x_{0} &= \bar{A}x_{-1}+\bar{B}u_{0} = \bar{B}u_{0} \\ y_{0} &= Cx_{0}=C\bar{B}u_{0} \\ x_{1} &= \bar{A}x_{0}+\bar{B}u_{1} = \bar{A}\bar{B}u_{0}+\bar{B}u_{1} \\ y_{1} &= Cx_{1}=C\bar{A}\bar{B}u_{0} + C\bar{B}u_{1} \\ &\ldots\\ x_{t} &= \bar{A}x_{t-1}+\bar{B}u_{t} = \bar{A}^{t}\bar{B}u_{0}+\bar{A}^{t-1}\bar{B}u_{1}+\ldots+\bar{A}\bar{B}u_{t-1} + \bar{B}u_{t}\\ y_{t} &= Cx_{t} = C\bar{A}^{t}\bar{B}u_{0}+C\bar{A}^{t-1}\bar{B}u_{1}+\ldots+C\bar{A}\bar{B}u_{t-1} + C\bar{B}u_{t} \end{align} \]

For fixed \(C,\bar{A},\bar{B}\), we can precompute the factors \(C\bar{A}^{k}\bar{B}\), giving us a convolutional kernel:

\[ K=(C\bar{B}, C\bar{A}\bar{B},\ldots,C\bar{A}^{T-1}\bar{B}) \]

for an input sequence of size \(T\). Instead of computing the LSSL in an auto-regressive manner, we can directly get the output of the LSSL by convolving the kernel with the input: \(y=K * u\).
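The equivalence can be checked numerically. The sketch below builds the kernel \(K\) by repeated multiplication with \(\bar{A}\), applies it as a causal convolution, and compares the result against the recurrent form with \(x_{-1}=0\) and \(D\) ignored. All names and the random test matrices are hypothetical, and the naive kernel construction is chosen for clarity rather than speed.

```python
import numpy as np

def lssl_kernel(A_bar, B_bar, C, T):
    """K = (C B_bar, C A_bar B_bar, ..., C A_bar^{T-1} B_bar) as a length-T vector."""
    K = []
    v = B_bar                          # holds A_bar^k B_bar
    for _ in range(T):
        K.append((C @ v).item())
        v = A_bar @ v
    return np.array(K)

def causal_conv(K, u):
    """y_t = sum_{k=0}^{t} K_k u_{t-k}: full convolution truncated to len(u)."""
    return np.convolve(K, u)[: len(u)]

def recurrent(A_bar, B_bar, C, u):
    """Reference: step the state one input at a time, starting from zero."""
    x = np.zeros_like(B_bar)
    ys = []
    for u_t in u:
        x = A_bar @ x + B_bar * u_t
        ys.append((C @ x).item())
    return np.array(ys)

rng = np.random.default_rng(0)
N, T = 4, 16
A_bar = 0.3 * rng.standard_normal((N, N))
B_bar = rng.standard_normal((N, 1))
C = rng.standard_normal((1, N))
u = rng.standard_normal(T)

y_conv = causal_conv(lssl_kernel(A_bar, B_bar, C, T), u)
y_rec = recurrent(A_bar, B_bar, C, u)
print(np.allclose(y_conv, y_rec))
```

The kernel depends only on \(C,\bar{A},\bar{B}\) and the sequence length, so it can be computed once and reused for every input sequence.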

Efficiency

The LSSL combines the strengths of RNNs (unbounded context) and CNNs (parallelizable training). However, the requirement for this transformation is that all matrices \(A\), \(B\), \(C\) and the step-size \(\Delta t\) are time-invariant (see LTI System), otherwise the kernel parameters couldn't be pre-computed.

The convolutional kernel effectively bypasses the state computation. Instead of recurrently expanding the state, which materialises a tensor of size BxLxDxN (an N-dimensional memory for each of the D input channels at each of the L time-steps, for a batch of size B), we only need a kernel of size BxLxD. Note that B and D here denote the batch size and the number of channels, not the matrices \(B\) and \(D\). The memory state is implicit in the convolution. The effectiveness of the memory state depends on its size (this is why transformers are so effective: they have a huge memory by caching previous key-value pairs), and since the dual representation allows us to bypass the state materialisation, much larger memory sizes can be used (N ~ 10 to 100).
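To make the size difference concrete, here is a small worked count. The batch, length, channel, and state sizes are hypothetical values picked for illustration only.

```python
# Hypothetical sizes: batch, sequence length, channels, state dimension.
batch, length, channels, state = 8, 2048, 512, 64

recurrent_elems = batch * length * channels * state  # B x L x D x N
conv_elems = batch * length * channels               # B x L x D

ratio = recurrent_elems // conv_elems
print(recurrent_elems, conv_elems, ratio)
```

The ratio is exactly the state dimension N, which is why avoiding the materialised state leaves room for a larger N.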

Strengths and Weaknesses

The LSSLs are efficient to train and work on very large context lengths (>10 tokens). They often still work when transformers already run out of memory. This is because, during training, the whole sequence can be processed efficiently in parallel and during inference, the next token can be predicted efficiently in the autoregressive recurrent mode.

However, they also inherit some of the weaknesses of RNNs, namely forgetting long-term history as the context evolves and not being able to selectively remember and forget. A solution is to make \(A\) input-dependent, deciding what context to keep and what to forget. This idea, reminiscent of [[LSTM]]s, is not straightforward to implement, as the convolutional formulation relies on a fixed \(\bar{A}\). A solution to this problem is presented in the Mamba model.