In the great paper "The Truth is in There: Improving Reasoning in Language Models with Layer-Selective Rank Reduction", Sharma, Ash, and Misra (2023) present compelling evidence that their LAyer SElective Rank Reduction (LASER) approach can markedly improve the performance of language models. Essentially, LASER replaces certain weight matrices in language models with lower-rank approximations obtained via singular value decomposition (SVD). Although these modified matrices contain less information than the original ones, the performance of the model surprisingly increases, without any further pre-training or fine-tuning. Computing these matrix approximations is straightforward (achievable with a few lines of Python code) and fast, which makes the proposed method easy to implement.
In this post, I delve into the world of Singular Value Decomposition (SVD) — explaining its theoretical underpinnings and bringing it to life with 3D visualizations. To top it off, I'll demonstrate the application of LASER to a BERT model. This practical example aims to showcase not just the simplicity of implementing LASER, but also its effectiveness in enhancing language models.
Intuition Behind Singular Value Decomposition (SVD)¶
1. Matrix as Transformation:¶
In the paper, the authors apply LASER on weight matrices within typical transformer-based language models, like the query, key, value or output matrices in the self-attention block, or the weight matrices of the MLP block of a transformer layer. Think of such a matrix $M$ as a linear transformation $ T: \mathbb{R}^n \rightarrow \mathbb{R}^m, x \longmapsto Mx$. $M$ can stretch, shrink, rotate or reflect vectors in a space.
I will visualize this in 3D space, so for our purposes let $n$ and $m$ both equal 3. All of the results we see in this post generalize to arbitrary values of $n$ and $m$, with only minor adjustments when $n \neq m$.
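As a brief aside on the rectangular case: NumPy's SVD handles non-square matrices as well; with `full_matrices=False` it returns the reduced form, where the number of singular values equals $\min(n, m)$. A minimal sketch of my own (not used in the cube example below):
import numpy as np
# Reduced SVD of a rectangular matrix: U has shape (m, k), S holds the k singular values,
# and Vt has shape (k, n), where k = min(m, n)
A = np.random.rand(4, 3)
U_r, S_r, Vt_r = np.linalg.svd(A, full_matrices=False)
print(U_r.shape, S_r.shape, Vt_r.shape)  # expected: (4, 3) (3,) (3, 3)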
I start with the 3D unit cube, a cube with one corner at the origin (0,0,0) and sides that are 1 unit long. As a first step, I define a function that plots a general rectangular cuboid given the coordinates of its vertices and the lower and upper limits for all axes:
%matplotlib notebook
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
def plot_rectangular(vertices, limits):
    # Define the edges connecting the vertices
    edges = [[vertices[:,i], vertices[:,j]] for i, j in [
        (0, 1), (1, 2), (2, 3), (3, 0), # bottom face
        (4, 5), (5, 6), (6, 7), (7, 4), # top face
        (0, 4), (1, 5), (2, 6), (3, 7)  # side faces
    ]]
    # Create a 3D plot
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    # Plot the edges
    for edgeIdx, edge in enumerate(edges):
        if edgeIdx < 4:
            ax.plot(*zip(*edge), color='darkblue')
        elif edgeIdx > 7:
            ax.plot(*zip(*edge), color='gold')
        else:
            ax.plot(*zip(*edge), color='darkred')
    # Set plot display parameters
    ax.set_xlabel('X axis')
    ax.set_ylabel('Y axis')
    ax.set_zlabel('Z axis')
    ax.set_xlim(limits)
    ax.set_ylim(limits)
    ax.set_zlim(limits)
    plt.show()
Note that I use different colors for the bottom, top and side edges. This makes it easier to recognize the object after the transformation. Additionally, I allow the axis limits to be set manually so that the different figures are comparable.
Plot of the 3D unit cube:
limits = (-2,3)
vertices_3d_unit_cube = np.array([[0, 0, 0],
[1, 0, 0],
[1, 1, 0],
[0, 1, 0],
[0, 0, 1],
[1, 0, 1],
[1, 1, 1],
[0, 1, 1]])
# Transpose, as data points should be given as column vectors
vertices_3d_unit_cube = np.transpose(vertices_3d_unit_cube)
# Plot
plot_rectangular(vertices_3d_unit_cube,limits)
Next, we define a transformation matrix $M$ with which we transform the unit cube. $M$ combines a 45° rotation around the z-axis with a 45° rotation around the y-axis. Additionally, the x, y and z axes are scaled by factors of 1.2, 1.4 and 0.8, respectively.
# Angle for rotation
angle_deg = 45
angle_rad = np.radians(angle_deg)
# Rotation matrix for 45° rotation around z-axis
rotation_matrix_z_45 = np.array([
[np.cos(angle_rad), -np.sin(angle_rad), 0],
[np.sin(angle_rad), np.cos(angle_rad), 0],
[0, 0, 1]
])
# Rotation matrix for 45° rotation around y-axis
rotation_matrix_y_45 = np.array([
[np.cos(angle_rad), 0, np.sin(angle_rad)],
[0, 1, 0],
[-np.sin(angle_rad), 0, np.cos(angle_rad)]
])
# Scaling all axes
scaling_matrix = np.diag([1.2,1.4,0.8])
# Create M
M = rotation_matrix_z_45.dot(rotation_matrix_y_45).dot(scaling_matrix)
# Transform unit cube
vertices_transformed = M.dot(vertices_3d_unit_cube)
# Plot
plot_rectangular(vertices_transformed,limits)
We see that the matrix multiplication both stretched and rotated the 3D unit cube.
2. Decomposing Transformation with SVD:¶
SVD allows us to decompose $M$ into:
- Rotation/Reflection matrices $U$ and $V^T$: These matrices are orthogonal (i.e. $U U^T = Id$). One can also interpret these matrices as basis changes of the vector space, with the columns of $U$ and $V$ forming orthonormal bases. More on that later.
- Scaling matrix $\Sigma$: A diagonal matrix with $M$'s singular values on the diagonal, in decreasing order. Singular values are non-negative real numbers.
Mathematical Representation: $M = U \Sigma V^T$
In Python, calculating the SVD of a matrix $M$ is simple:
# Do singular value decomposition
U, S, Vt = np.linalg.svd(M)
# S is given as an array of the singular values. Make it a diagonal matrix
S = np.diag(S)
# Print SVD decomposition
print("U:")
print(U)
print("\nSigma:")
print(S)
print("\nV^T")
print(Vt)
# Check whether M is indeed equal to U*S*Vt by calculating the sum of the absolute elementwise differences
print(f"\nSum of elementwise absolute differences between M and its SVD decompositon: {np.sum(np.abs(M-U.dot(S).dot(Vt)))}")
U:
[[ 7.07106781e-01 -5.00000000e-01 -5.00000000e-01]
 [-7.07106781e-01 -5.00000000e-01 -5.00000000e-01]
 [-5.55111512e-17  7.07106781e-01 -7.07106781e-01]]

Sigma:
[[1.4 0.  0. ]
 [0.  1.2 0. ]
 [0.  0.  0.8]]

V^T
[[-0. -1.  0.]
 [-1. -0.  0.]
 [-0.  0. -1.]]

Sum of elementwise absolute differences between M and its SVD decomposition: 7.993605777301127e-16
We see that $M$ is indeed equal to $U \Sigma V^T$.
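We can also quickly verify the claim from above that $U$ and $V^T$ are orthogonal (a small check of my own, reusing the U and Vt computed in the cell above):
# Orthogonality check: U U^T and V^T V should equal the identity up to floating point error
print(np.allclose(U.dot(U.T), np.eye(3)))   # should print True
print(np.allclose(Vt.T.dot(Vt), np.eye(3))) # should print True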
As a next step, let's check visually that:
- $U$ and $V^T$ rotate/reflect the object
- $\Sigma$ scales the object
Rotate unit cube with $V^T$:
vertices_transformed_Vt = Vt.dot(vertices_3d_unit_cube)
plot_rectangular(vertices_transformed_Vt,limits)
Scale with $\Sigma$:
vertices_transformed_S = S.dot(vertices_transformed_Vt)
plot_rectangular(vertices_transformed_S,limits)
Rotate with $U$:
vertices_transformed_U = U.dot(vertices_transformed_S)
plot_rectangular(vertices_transformed_U,limits)
3. Approximating $M$ by Removing Smaller Singular Values¶
Applying Singular Value Decomposition (SVD) to the weight matrices of language models won't have any effect, as SVD perfectly replicates the original matrix $M$. However, the real utility of SVD emerges in this case when seeking a lower-rank approximation of $M$. This is achieved by leveraging the fact that the larger singular values in the diagonal matrix $\Sigma$ have a more significant impact on the transformation represented by $M$. By strategically discarding components associated with smaller singular values, we retain only the most influential aspects of $M$'s transformation. This process effectively simplifies $M$, capturing its essential characteristics while reducing the complexity of the data.
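To make this precise: if $M = U \Sigma V^T$ has singular values $\sigma_1 \geq \sigma_2 \geq \ldots$, then keeping only the $k$ largest singular values yields the rank-$k$ matrix $M_k = \sum_{i=1}^{k} \sigma_i U_i V_i^T$ (with $U_i$ and $V_i$ the $i$-th columns of $U$ and $V$). By the Eckart-Young theorem, this $M_k$ is the best rank-$k$ approximation of $M$ with respect to both the spectral and the Frobenius norm.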
To illustrate this in practical terms, consider again the example of transforming the 3D unit cube with $M$. If we eliminate the smallest singular value, the transformation with the lower-rank approximation of $M$ still yields a shape close to the original rectangular cuboid (at least if the smallest singular value is close to 0). However, the result is now flattened into a plane, as one dimension of the cuboid is effectively 'lost' in this transformation.
I plan to visualize this concept using several different transformation matrices $M$. The examples will start with simpler matrices and progressively increase in complexity to demonstrate the effects of lower-rank approximations in various scenarios.
First, I define a function to apply the SVD approximation of $M$ by dropping the smallest singular value:
def approx_matrix_SVD(M):
    # Computing the Singular Value Decomposition (SVD) of the matrix M
    U, S, Vt = np.linalg.svd(M)
    # Recreate M, but without using the last singular value
    M_approx = U[:,:-1].dot(np.diag(S[:-1])).dot(Vt[:-1,:])
    return M_approx
Note that this function can be easily extended to eliminate an arbitrary number of the smallest singular values. This extension will be necessary, for instance, when applied to language models. However, for the current demonstration, I will only remove the smallest singular value, as this most effectively visualizes the concept.
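For completeness, here is a minimal sketch of such a generalization (the function name and the parameter k are my own; the tensor-based variant actually used for BERT appears later in this post):
def approx_matrix_SVD_topk(M, k):
    # Keep only the k largest singular values and drop the rest
    U, S, Vt = np.linalg.svd(M)
    return U[:,:k].dot(np.diag(S[:k])).dot(Vt[:k,:])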
Example 1: No rotation, only scale z axis
In this example, I apply a scaling transformation to the unit cube, specifically scaling the z-axis by a factor of 0.3. This operation does not involve any rotation, so the singular values corresponding to the x and y axes remain at one, while the singular value associated with the z-axis is 0.3. When employing SVD for rank reduction, we eliminate the smallest singular value. As a result, the z-dimension of the rectangular cuboid, having the smallest singular value (0.3), effectively vanishes:
M = np.array([[1, 0, 0],
[0, 1, 0],
[0, 0, 0.3]])
vertices_transformed_ex1 = M.dot(vertices_3d_unit_cube)
vertices_approx_transformed_ex1 = approx_matrix_SVD(M).dot(vertices_3d_unit_cube)
limits=(0,1)
print("Original transformation of the 3D unit cube:")
plot_rectangular(vertices_transformed_ex1,limits)
print("Approximated transformation of the 3D unit cube:")
plot_rectangular(vertices_approx_transformed_ex1,limits)
Original transformation of the 3D unit cube:
Approximated transformation of the 3D unit cube:
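As a quick numerical sanity check on top of the plots (my own addition), the rank of $M$ indeed drops from 3 to 2 after the approximation:
# The original M has full rank 3, the approximation has rank 2
print(np.linalg.matrix_rank(M))                     # expected: 3
print(np.linalg.matrix_rank(approx_matrix_SVD(M)))  # expected: 2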
Example 2: No rotation, only scale y axis
This approach operates independently of the specific axis chosen. For instance, if the y-axis possesses the smallest singular value, then it is this axis that gets eliminated.
M = np.array([[1, 0, 0],
[0, 0.3, 0],
[0, 0, 1]])
vertices_transformed_ex2 = M.dot(vertices_3d_unit_cube)
vertices_approx_transformed_ex2 = approx_matrix_SVD(M).dot(vertices_3d_unit_cube)
limits=(0,1)
print("Original transformation of the 3D unit cube:")
plot_rectangular(vertices_transformed_ex2,limits)
print("Approximated transformation of the 3D unit cube:")
plot_rectangular(vertices_approx_transformed_ex2,limits)
Original transformation of the 3D unit cube:
Approximated transformation of the 3D unit cube:
Example 3: The information loss is driven by the size of the smallest singular value
Let's compare two cases: First, the smallest singular value is close to 0, say 0.05. We will see that the approximated transformation will be very close to the true transformation:
M = np.array([[1, 0, 0],
[0, 1, 0],
[0, 0, 0.05]])
vertices_transformed_ex3_1 = M.dot(vertices_3d_unit_cube)
vertices_approx_transformed_ex3_1 = approx_matrix_SVD(M).dot(vertices_3d_unit_cube)
limits=(0,1)
print("Original transformation of the 3D unit cube:")
plot_rectangular(vertices_transformed_ex3_1,limits)
print("Approximated transformation of the 3D unit cube:")
plot_rectangular(vertices_approx_transformed_ex3_1,limits)
Original transformation of the 3D unit cube:
Approximated transformation of the 3D unit cube:
On the other hand, if the smallest singular value is rather large, the approximation loss will also be larger:
M = np.array([[1, 0, 0],
[0, 1, 0],
[0, 0, 0.8]])
vertices_transformed_ex3_2 = M.dot(vertices_3d_unit_cube)
vertices_approx_transformed_ex3_2 = approx_matrix_SVD(M).dot(vertices_3d_unit_cube)
limits=(0,1)
print("Original transformation of the 3D unit cube:")
plot_rectangular(vertices_transformed_ex3_2,limits)
print("Approximated transformation of the 3D unit cube:")
plot_rectangular(vertices_approx_transformed_ex3_2,limits)
Original transformation of the 3D unit cube:
Approximated transformation of the 3D unit cube:
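This observation can be made quantitative: when a single singular value $\sigma_3$ is dropped, the approximation error $\|M - M_{approx}\|$ equals exactly $\sigma_3$ (in the spectral norm, and here also in the Frobenius norm, since only one singular value is removed). A small check of my own, redefining the two example matrices locally:
# The approximation error equals the dropped singular value
for s in [0.05, 0.8]:
    M_check = np.diag([1, 1, s])
    err = np.linalg.norm(M_check - approx_matrix_SVD(M_check), ord='fro')
    print(f"Smallest singular value: {s}, approximation error: {err:.4f}")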
Example 4: Rotation/reflection and scaling
I conclude with some arbitrary matrix $M$.
M = np.array([[1, 2, 3],
[3, 1, 2],
[2, 3, 1]])
# Apply transformations
vertices_transformed_ex4= M.dot(vertices_3d_unit_cube)
vertices_approx_transformed_ex4 = approx_matrix_SVD(M).dot(vertices_3d_unit_cube)
# Plot
limits=(0,6)
print("Original transformation of the 3D unit cube:")
plot_rectangular(vertices_transformed_ex4,limits)
print("Approximated transformation of the 3D unit cube:")
plot_rectangular(vertices_approx_transformed_ex4,limits)
Original transformation of the 3D unit cube:
Approximated transformation of the 3D unit cube:
In this example, the outcome is similar to the previous examples: the resulting rectangular cuboid is approximated by projecting it onto a plane. However, in this case the plane does not align with any of the standard coordinate planes (x,y), (x,z) or (y,z). Intuitively, this plane is chosen so as to minimize the information lost by dropping one dimension, i.e. to minimize the distance between the plane and the vertices of the original rectangular cuboid (in fact, the plane is spanned by the first two columns of $U$). Seen this way, the resulting plane looks intuitive, as the cuboid is flattened along the direction in which it is thinnest.
Another way to think about this is to consider every rotation/reflection as a basis change. In the present example, if we changed the basis of our 3D space such that the resulting plane coincides with the (x,y) plane, we would recover the intuition of the previous examples (flattening out the dimension with the smallest singular value).
SVD provides exactly that. Remember that the columns of $U = (U_1, U_2, U_3)$ and $V = (V_1, V_2, V_3)$ from the singular value decomposition of some transformation $M$ ($M = U \Sigma V^T$) provide orthonormal bases of, in our case, $\mathbb{R}^3$.
If we express matrix $M$ as some linear transformation $ T: \mathbb{R}^3 \rightarrow \mathbb{R}^3, x \longmapsto Mx$, then this map has a very simple description with respect to these orthonormal bases: $T(V_i) = \sigma_i U_i$, where $\sigma_i$ is the i-th diagonal entry of $\Sigma$.
So in our example, if we first express the 3D unit cube w.r.t. $V$'s basis, apply $M$, and then express the result w.r.t. $U$'s basis, we are back in the situation of the previous examples, where the resulting plane is the (x,y) plane.
Therefore, SVD can be understood as a basis change under which the information lost by dropping the dimension with the smallest singular value is minimized.
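Before visualizing this, here is a short numerical check of the relation $T(V_i) = \sigma_i U_i$ for the example-4 matrix (the SVD is recomputed in this cell so it is self-contained):
# Check that M V_i = sigma_i * U_i holds for every column i
U, S, Vt = np.linalg.svd(M)
V = Vt.T
for i in range(3):
    print(np.allclose(M.dot(V[:, i]), S[i] * U[:, i]))  # should print True three times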
Let's again visualize the same transformation of example 4, but this time after changing the basis:
# Apply SVD
U, S, Vt = np.linalg.svd(M)
# To change the basis, we need to multiply with the inverse of the matrix whose columns are the new basis vectors
V_inv = np.linalg.inv(np.transpose(Vt))
U_inv = np.linalg.inv(U)
# Before transformation, change basis to V
vertices_unit_cube_base_V = V_inv.dot(vertices_3d_unit_cube)
# Transform with M and approximated M
vertices_base_V_transformed = M.dot(vertices_unit_cube_base_V)
vertices_base_V_approx_transformed = approx_matrix_SVD(M).dot(vertices_unit_cube_base_V)
# After transformation, change basis to U
vertices_base_U_transformed = U_inv.dot(vertices_base_V_transformed)
vertices_base_U_approx_transformed = U_inv.dot(vertices_base_V_approx_transformed)
# Plot
print("Original transformation of the 3D unit cube in base of U:")
plot_rectangular(vertices_base_U_transformed,limits)
print("Approximated transformation of the 3D unit cube in base of U:")
plot_rectangular(vertices_base_U_approx_transformed,limits)
Original transformation of the 3D unit cube in base of U:
Approximated transformation of the 3D unit cube in base of U:
Why does LASER work when applied to language models?¶
Now we understand what LASER does: it approximates transformation matrices $M$ by removing singular values that explain only a small share of the transformation's variance. However, it is still not clear why removing this information should be beneficial for the performance of language models.
The authors explain the beneficial effect of removing certain information from language models through the concept of denoising. In the context of open-ended question answering, they note that improvements often arise on questions whose answers are supported by data that occurs less frequently in the training set. When higher-order components (i.e. components associated with small singular values) are eliminated, the model seems to recover this "hidden" or less frequent information.
The higher-order components of a model are suspected to often capture incorrect but semantically similar responses. When all components are used, these may lead to conflicting responses and the model hence produces generic, high-frequency tokens like "a", "the" or "of". By removing these higher-order components, the model's internal conflict is resolved. This allows the model to more accurately respond with the correct entity rather than defaulting to these generic responses.
For further information, have a look at the paper.
Simple Application: BERT-large Performance on SST-2 with and without LASER¶
To explore the effectiveness of the LASER approach, we'll apply it to an actual language model and observe the results.
We begin by defining the essential functions needed for our experiment. The first one is an adaptation of the SVD approximation, tailored to work with matrices represented as PyTorch tensors. This modification is crucial since the weight matrices in language models typically use this data format.
Another enhancement in our function is the ability to remove more than one singular value. This feature aligns with the methodology outlined in the paper. We introduce a parameter, denoted as $\rho$, to manage the fraction of the maximum rank to be eliminated during the low-rank approximation process. This addition offers more flexibility and control over the degree of approximation applied to the model.
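To get a feeling for the numbers (assuming BERT-large's usual dimensions of hidden size 1024 and intermediate size 4096): the MLP weight matrices satisfy `min(M.shape) = 1024`, so $\rho = 0.9$ keeps $r = \mathrm{round}(0.1 \cdot 1024) = 102$ singular values, while the default $\rho = 0.995$ keeps only $r = \mathrm{round}(0.005 \cdot 1024) = 5$.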
import torch
import evaluate
from torch.utils.data import DataLoader
from transformers import BertTokenizer, BertForSequenceClassification
from datasets import load_dataset
import numpy as np
def approx_matrix_SVD(M, rho=0.995):
    r = round((1-rho) * min(M.shape))
    # Computing the Singular Value Decomposition (SVD) of the matrix M
    U, S, Vt = torch.linalg.svd(M)
    # Recreate M, but only using the first r singular values
    M_approx = U[:,:r] @ torch.diag(S[:r]) @ Vt[:r,:]
    return M_approx
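A quick sanity check of this function on a random tensor (my own sketch, not part of the evaluation below): the rank of the approximation should equal round((1 - rho) * min(M.shape)).
# For a 1024 x 4096 matrix and rho = 0.9, we expect a rank of round(0.1 * 1024) = 102
W = torch.randn(1024, 4096)
W_approx = approx_matrix_SVD(W, rho=0.9)
print(torch.linalg.matrix_rank(W_approx))  # expected: tensor(102)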
Moving forward, I'll introduce a function designed to execute and assess the performance of a given model using the SST-2 benchmark. This step is fairly standard in the process.
def evaluate_model(model, tokenizer, dataset_name, batch_size=8, split='validation'):
    # Load dataset and metric
    dataset = load_dataset("glue", dataset_name)
    metric = evaluate.load('glue', dataset_name)
    # Tokenize the input texts
    def tokenize(batch):
        return tokenizer(batch['sentence'], padding=True, truncation=True)
    dataset[split] = dataset[split].map(tokenize, batched=True)
    dataset[split].set_format('torch', columns=['input_ids', 'attention_mask', 'label'])
    # DataLoader
    dataloader = DataLoader(dataset[split], batch_size=batch_size)
    # Evaluation
    model.eval()
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            logits = outputs.logits
            predictions = torch.argmax(logits, axis=-1)
            metric.add_batch(predictions=predictions, references=labels)
    final_score = metric.compute()
    return final_score['accuracy']
I apply the LASER methodology to a BERT-large model, which has been fine-tuned specifically for the SST-2 benchmark.
# Setting up the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Adjust the batch size as needed
batch_size = 16 # Example batch size
# Model and Dataset
model_name = 'assemblyai/bert-large-uncased-sst2'
dataset_name = 'sst2'
# Load tokenizer and model
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForSequenceClassification.from_pretrained(model_name).to(device)
Using device: cuda
Before we dive into the application of the LASER approach, it's crucial to establish a baseline for comparison. To do this, we calculate the validation accuracy of our BERT model in its original state, without the LASER modifications.
# Run without LASER
model = BertForSequenceClassification.from_pretrained(model_name).to(device)
val_score_without_laser = evaluate_model(model, tokenizer, dataset_name)
print(f"Validation score without laser: {val_score_without_laser:.2%}")
Validation score without laser: 92.89%
Now, we are set to apply the LASER technique to our BERT model. LASER can be applied to any linear transformation within a language model. This led the authors to employ a grid-search strategy, exploring various layers ($l$), matrices ($\tau$) and fractions ($\rho$) to pinpoint the most effective LASER intervention. They didn't stop there; they also experimented with different combinations of LASER interventions to further optimize their results.
However, it's important to note that conducting such a comprehensive grid search is an immensely time-consuming endeavor. Given that this post is more of an illustrative exercise, we adopt a more streamlined approach.
In our case, we apply LASER specifically to the output projection of the MLP in a transformer layer (`output.dense` in the Hugging Face BERT implementation). This focus is based on the authors' findings, which highlighted significant improvements when LASER was applied to the MLP matrices. Our grid search is limited to varying the layer number $l$ (ranging from 0 to 23) and the fraction $\rho$ (0.1, 0.5 or 0.9).
# Run with LASER
layers = list(range(24)) # Transformer block index (0-based)
rhos = [0.1,0.5,0.9]
layer_num = []
rho_val = []
accuracy = []
for layer in layers:
    for rho in rhos:
        model_laser = BertForSequenceClassification.from_pretrained(model_name).to(device)
        # Apply LASER to the MLP output projection of the current layer
        model_laser.bert.encoder.layer[layer].output.dense.weight.data = approx_matrix_SVD(model_laser.bert.encoder.layer[layer].output.dense.weight.data, rho = rho)
        layer_num.append(layer) # Track current layer
        rho_val.append(rho) # Track current rho
        accuracy.append(evaluate_model(model_laser, tokenizer, dataset_name)) # Track accuracy
# Print laser results
import pandas as pd
results_laser = pd.DataFrame({"Layer" : layer_num, "Rho" : rho_val, "Accuracy" : accuracy})
results_laser.sort_values(by=['Accuracy'], inplace = True, ignore_index = True, ascending=False)
print(results_laser.head(10))
   Layer  Rho  Accuracy
0      1  0.1  0.932339
1      1  0.5  0.932339
2      1  0.9  0.932339
3      4  0.1  0.931193
4     21  0.9  0.931193
5     21  0.5  0.931193
6     21  0.1  0.931193
7      4  0.5  0.931193
8      4  0.9  0.931193
9      5  0.5  0.930046
It turns out that several LASER interventions can indeed improve BERT's performance, and notably, this is achieved without additional pretraining or finetuning. This finding suggests that targeted adjustments can effectively enhance a language model's capabilities, offering a practical approach to model optimization.
That's it! I hope you liked my first blog post. If you have any questions, comments, or suggestions for improvement, then it's best to contact me via email.