Neurochirurgie sur les Transformers¶

Introduction¶

Aujourd'hui, nous allons explorer les SAE (Auto-Encodeurs Parcimonieux). Cette technique a récemment fait parler d'elle grâce à un article d'Anthropic qui a révélé que Claude, leur modèle d'IA, avait identifié le concept de "Golden Bridge" grâce à cette méthode, et s'était même mis à prétendre qu'il était lui-même le fameux pont de San Francisco...

Notre objectif est de comprendre comment les SAE peuvent nous aider à décoder le fonctionnement interne des réseaux de neurones, en particulier des Transformers, qui sont à la base des LLMs comme ChatGPT, mais également de modèles de vision.

Notre approche¶

  1. Entraînement d'un Transformer : Nous allons d'abord entraîner un modèle Transformer sur le classique dataset MNIST (reconnaissance de chiffres manuscrits).

  2. Extraction des activations : Une fois le modèle entraîné, nous allons extraire les activations du MLP (Multi-Layer Perceptron) lorsqu'on lui présente une image issue du MNIST.

  3. Entraînement d'un SAE : Nous utiliserons ces activations comme données d'entrée pour entraîner un Auto-Encodeur Parcimonieux. Le SAE va apprendre de lui-même quels sont les concepts importants encodés dans cette couche du MLP.

  4. Intervention (neuro-chirurgie) : Nous utiliserons cette séparation faite par le SAE pour intervenir sur les neurones par le biais du décodeur du SAE.

Cette approche nous permettra de "ouvrir la boîte noire" du Transformer et de comprendre quels types de caractéristiques ou de concepts il a appris à reconnaître dans les images de chiffres.

Pourquoi est-ce important ?¶

L'interprétabilité des modèles d'apprentissage profond est un enjeu majeur en IA. Les SAE offrent une méthode prometteuse pour décomposer les représentations complexes apprises par ces modèles en concepts plus simples et interprétables. Cela pourrait nous aider à mieux comprendre comment ces modèles prennent leurs décisions, et potentiellement à les améliorer ou à détecter des biais.

Dans les sections suivantes, nous allons détailler chaque étape de ce processus, en commençant par la préparation des données MNIST et l'architecture de notre Transformer.

SAE.png

Entraînement, probing, intervention avec SAE¶

Mettre les choses en place¶

In [ ]:
%%capture
!pip install datasets
In [ ]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import random
from datasets import load_dataset
import numpy as np

Le MNIST¶

MNIST (Modified National Institute of Standards and Technology) est un jeu de données emblématique en apprentissage automatique, introduit en 1998 par Yann LeCun, Corinna Cortes et Christopher Burges. Il consiste en 70 000 images en niveaux de gris de chiffres manuscrits (0 à 9), chacune de taille 28x28 pixels. Le jeu est divisé en 60 000 images d'entraînement et 10 000 images de test.

MNIST a joué un rôle crucial dans le développement de techniques de reconnaissance de formes et reste un "Hello World" de l'apprentissage profond.

In [ ]:
# Chargement et préparation du dataset
dataset = load_dataset("mnist")
train_dataset = dataset["train"]
test_dataset = dataset["test"]

def preprocess(examples):
    examples["pixel_values"] = [torch.tensor(np.array(img), dtype=torch.float32).view(1, 28, 28) / 255.0
                                for img in examples["image"]]
    return examples

train_dataset = train_dataset.map(preprocess, batched=True, remove_columns=["image"])
test_dataset = test_dataset.map(preprocess, batched=True, remove_columns=["image"])
train_dataset.set_format(type="torch", columns=["pixel_values", "label"])
test_dataset.set_format(type="torch", columns=["pixel_values", "label"])

batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Fonction pour afficher des échantillons
def show_samples(dataset, num_samples=5):
    fig, axes = plt.subplots(1, num_samples, figsize=(15, 3))
    for i in range(num_samples):
        idx = np.random.randint(len(dataset))
        img = dataset[idx]["pixel_values"].squeeze().numpy()
        label = dataset[idx]["label"]
        axes[i].imshow(img, cmap='gray')
        axes[i].set_title(f"Label: {label}")
        axes[i].axis('off')
    plt.show()

# Afficher des échantillons du jeu d'entraînement
show_samples(train_dataset)

# Afficher des échantillons du jeu de test
show_samples(test_dataset)
/usr/local/lib/python3.12/dist-packages/huggingface_hub/utils/_auth.py:94: UserWarning: 
The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
  warnings.warn(
README.md: 0.00B [00:00, ?B/s]
mnist/train-00000-of-00001.parquet:   0%|          | 0.00/15.6M [00:00<?, ?B/s]
mnist/test-00000-of-00001.parquet:   0%|          | 0.00/2.60M [00:00<?, ?B/s]
Generating train split:   0%|          | 0/60000 [00:00<?, ? examples/s]
Generating test split:   0%|          | 0/10000 [00:00<?, ? examples/s]
Map:   0%|          | 0/60000 [00:00<?, ? examples/s]
Map:   0%|          | 0/10000 [00:00<?, ? examples/s]
No description has been provided for this image
No description has been provided for this image

Le Transformer¶

Qu'est-ce qu'un Transformer ? C'est une architecture assez complexe de modèle, qui comprend trois étapes principales :

Le plongement (ou embedding) des mots dans des espaces en grande dimension. Concrètement, chaque mot se voit associer un vecteur qui vit dans un espace de grande dimension $\mathbb{R}^n$. Ces vecteurs sont appris durant l'entraînement. Cela permet en quelque sorte d'encoder la sémantique de chaque mot. Soit $w$ un mot, son plongement $e_w$ est défini par :

$$e_w = E[w]$$

où $E$ est la matrice de plongement apprise.

Le mécanisme d'attention, qui permet de calculer les dépendances entre les différents mots d'une phrase. Ce sont trois matrices $Q$, $K$ et $V$ qui sont apprises durant l'entraînement, et qui permettent de projeter l'importance relative des mots précédents sur le suivant. L'attention est calculée comme suit :

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

où $d_k$ est la dimension de la matrice $K$.

Le perceptron multi-couche, qui permet de manipuler les concepts abstraits. C'est un réseau de neurones on ne peut plus classique, entièrement connecté, traditionnellement à deux couches et faisant quatre fois la taille du plongement. Sa fonction peut être représentée comme :

$$\text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2$$

où $W_1$, $W_2$, $b_1$, et $b_2$ sont les paramètres appris.

In [ ]:
# Paramètres
batch_size = 1024
num_epochs = 5
learning_rate = 0.001
d_model = 64
nhead = 2
num_layers = 2

Quelques mots sur ces paramètres, appelés hyper-paramètres :

  • La taille du lot (batch_size) désigne combien d'exemples notre Transformer va voir en parallèle à chaque étape d'entraînement. Ici 1024 exemples.
  • Le nombre d'époques (num_epochs) désigne le nombre de fois où le jeu de données en entier est vu par le modèle. Ici il va voir 5 fois le jeu de données au complet.
  • Le taux d'apprentissage (learning_rate) désigne quelle est la force de l'apprentissage, c'est un paramètre multiplicatif devant le gradient.
  • La dimension du modèle transformer (d_model) désigne la dimension des vecteurs d'embedding, mais donc aussi lors des différentes normalisations.
  • Le nombre de têtes d'attention (nhead) désigne le nombre de mécanismes d'attention appris en parallèle. Deux têtes signifie qu'il y deux matrices $Q$, deux matrices $K$, et deux matrices $V$.
  • Le nombre de couches (num_layers) désigne le nombre de fois où l'on répète un transformer, l'un à la suite de l'autre, ici deux fois.
In [ ]:
# Modèle Transformer
class MNISTTransformer(nn.Module):
    def __init__(self, d_model, nhead, num_layers, num_classes=10):
        super(MNISTTransformer, self).__init__()
        self.embedding = nn.Linear(28*28, d_model)
        self.pos_encoder = nn.Parameter(torch.randn(1, d_model))
        self.transformer_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
            for _ in range(num_layers)
        ])
        self.fc = nn.Linear(d_model, num_classes)

    def forward(self, x, return_activations=False):
        x = x.view(x.size(0), -1)  # Aplatir l'image (batch_size, 784)
        x = self.embedding(x)  # (batch_size, d_model)
        x = x + self.pos_encoder  # Diffusion du pos_encoder
        x = x.unsqueeze(1)  # Ajouter une dimension de séquence (batch_size, 1, d_model)

        activations = None
        for i, layer in enumerate(self.transformer_layers):
            x = layer(x)
            if i == 1 and return_activations:
                activations = x.clone()

        x = x.squeeze(1)  # Retirer la dimension de séquence (batch_size, d_model)
        x = self.fc(x)

        if return_activations:
            return x, activations
        return x

Le SAE¶

Qu'est-ce qu'un SAE ? En français Auto-Encodeur Parcimonieux, le SAE est un modèle qui prend en entrée une donnée $x$, et doit rendre la même donnée en sortie $\hat{x}$ (d'où le "Auto"). L'idée derrière ce modèle est qu'il encode l'information contenue dans la donnée, en tout cas sa substantifique moelle, dans une représentation cachée $h$.

Mathématiquement, cela peut s'exprimer comme :

$$h = f(x) = \sigma(Wx + b)$$ $$\hat{x} = g(h) = \sigma(W'h + b')$$

où $f$ est la fonction d'encodage et $g$ la fonction de décodage.

La parcimonie garantit que la solution trouvée sera très simple, appliquant ainsi le rasoir d'Occam, qui nous dit de privilégier les solutions simples lorsque plusieurs solutions s'offrent à nous. Cette parcimonie est obtenue en ajoutant un terme de régularisation $\Omega(h)$ à la fonction de perte (loss) :

$$\min_{W, W', b, b'} L(x, \hat{x}) + \lambda \Omega(h)$$

où $L$ est une fonction de perte mesurant la différence (distance) entre l'entrée et la sortie reconstruite, et $\lambda$ contrôle le degré de parcimonie.

Nous avons donc une couche linéaire pour encoder l'information, et une couche linéaire pour la décoder.

In [ ]:
sae_hidden_dim = 100 # Une dimension qui doit impérativement être un carré, pour la visualisation

Pour résumer le SAE, nous avons donc :

  • Une couche d'entrée de la dimension d'une couche du Transformer qu'on a défini plus haut, c'est-à-dire 64
  • Une transformation linéaire qui prend cette entrée et qui rend :
  • Une couche cachée, celle qui nous intéresse pour le code, car elle encode l'information de l'entrée, de dimension 100
  • Une transformation linéaire qui prend cet encodage et qui rend :
  • Un décodage, qui est de taille 64 comme l'entrée, et qui doit en principe lui ressembler le plus possible
In [ ]:
# Modèle SAE
class SparseAutoencoder(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(SparseAutoencoder, self).__init__()
        self.encoder = nn.Linear(input_dim, hidden_dim)
        self.decoder = nn.Linear(hidden_dim, input_dim)

    def forward(self, x):
        encoded = torch.relu(self.encoder(x))
        decoded = self.decoder(encoded)
        return decoded, encoded

L'entraînement¶

Cette fonction a pour but d'évaluer à quel point le Transformer prédit bien les étiquettes associées aux images du MNIST. C'est un fonction de précision, et ce qu'elle rend est un pourcentage d'exemples correctement classifiés.

In [ ]:
# Fonction d'évaluation pour le Transformer
def evaluate_transformer(model, data_loader, criterion, device):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in data_loader:
            inputs, labels = batch['pixel_values'].to(device), batch['label'].to(device)
            outputs, _ = model(inputs, return_activations=True)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    avg_loss = total_loss / len(data_loader)
    accuracy = 100 * correct / total
    return avg_loss, accuracy

Le critère de la fonction de perte pour le Transformer est l'entropie croisée (nn.CrossEntropyLoss()), qui permet de calculer la distance entre deux distributions. Dans ce contexte, nous comparons :

  1. La distribution attendue, qui est une distribution certaine puisqu'une seule étiquette est prévue pour chaque image. Elle est représentée par un vecteur unitaire $y$ (one-hot encoding) :

    $$y = (0, 0, ..., 1, ..., 0)$$

  2. La distribution à la sortie du Transformer, que nous appellerons $\hat{y}$ :

    $$\hat{y} = (\hat{y}_1, \hat{y}_2, ..., \hat{y}_n)$$

    où $n$ est le nombre de classes possibles.

L'entropie croisée entre ces deux distributions est définie comme suit :

$$H(y, \hat{y}) = -\sum_{i=1}^n y_i \log(\hat{y}_i)$$

Dans le cas d'une distribution certaine (one-hot encoding), ce qui est le cas ici, cette formule se simplifie à :

$$H(y, \hat{y}) = -\log(\hat{y}_c)$$

où $c$ est l'indice de la classe correcte.

L'objectif de l'entraînement est de minimiser cette entropie croisée, ce qui revient à maximiser la probabilité que le modèle assigne à la classe correcte.

In [ ]:
# Initialisation des modèles
transformer = MNISTTransformer(d_model, nhead, num_layers)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(transformer.parameters(), lr=learning_rate)

# Entraînement du Transformer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
transformer.to(device)
Out[ ]:
MNISTTransformer(
  (embedding): Linear(in_features=784, out_features=64, bias=True)
  (transformer_layers): ModuleList(
    (0-1): 2 x TransformerEncoderLayer(
      (self_attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
      )
      (linear1): Linear(in_features=64, out_features=2048, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (linear2): Linear(in_features=2048, out_features=64, bias=True)
      (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (dropout1): Dropout(p=0.1, inplace=False)
      (dropout2): Dropout(p=0.1, inplace=False)
    )
  )
  (fc): Linear(in_features=64, out_features=10, bias=True)
)
In [ ]:
for epoch in range(num_epochs):
    transformer.train()
    total_loss = 0
    for batch in train_loader:
        inputs, labels = batch['pixel_values'].to(device), batch['label'].to(device)

        optimizer.zero_grad()
        outputs, _ = transformer(inputs, return_activations=True)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    train_loss = total_loss / len(train_loader)

    # Évaluation sur l'ensemble de test
    test_loss, test_accuracy = evaluate_transformer(transformer, test_loader, criterion, device)

    print(f"Transformer Epoch [{epoch+1}/{num_epochs}]")
    print(f"  Train Loss: {train_loss:.4f}")
    print(f"  Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")

print("Entraînement du Transformer terminé!")
Transformer Epoch [1/5]
  Train Loss: 0.3148
  Test Loss: 0.1609, Test Accuracy: 95.13%
Transformer Epoch [2/5]
  Train Loss: 0.1451
  Test Loss: 0.1290, Test Accuracy: 96.24%
Transformer Epoch [3/5]
  Train Loss: 0.1107
  Test Loss: 0.0871, Test Accuracy: 97.48%
Transformer Epoch [4/5]
  Train Loss: 0.0875
  Test Loss: 0.0965, Test Accuracy: 97.14%
Transformer Epoch [5/5]
  Train Loss: 0.0750
  Test Loss: 0.0844, Test Accuracy: 97.47%
Entraînement du Transformer terminé!

Une fois le transformer entraîné, on collecte les activations de la couche de MLP, c'est à dire les états des vecteurs lorsqu'on donne une image du MNIST à notre Transformer. C'est comme si on prenait un instantané de la pensée de notre Transformer.

In [ ]:
# Collecter les activations
transformer.eval()
activations = []
with torch.no_grad():
    for batch in train_loader:
        inputs = batch['pixel_values'].to(device)
        _, batch_activations = transformer(inputs, return_activations=True)
        activations.append(batch_activations.squeeze(1).cpu())

activations = torch.cat(activations, dim=0)

Avec ces activations, on peut entraîner notre SAE, qui va s'efforcer de rendre la même activation, en l'encodant avec le moins de neurones activés possible (la plupart vaut zéro comme on le verra).

In [ ]:
# Fonction d'entraînement du SAE
def train_sae(sae, activation_loader, num_epochs=300, learning_rate=0.001, sparsity_weight=0.1):
    optimizer = optim.Adam(sae.parameters(), lr=learning_rate)
    criterion = nn.MSELoss()

    for epoch in range(num_epochs):
        total_loss = 0
        for batch in activation_loader:
            batch = batch[0].to(device)
            optimizer.zero_grad()
            decoded, encoded = sae(batch)
            reconstruction_loss = criterion(decoded, batch)
            sparsity_loss = torch.mean(torch.abs(encoded))
            loss = reconstruction_loss + sparsity_weight * sparsity_loss
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        print(f"SAE Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/len(activation_loader):.4f}")
In [ ]:
# Initialiser et entraîner le SAE
sae = SparseAutoencoder(d_model, sae_hidden_dim)
sae.to(device)
activation_dataset = TensorDataset(activations)
activation_loader = DataLoader(activation_dataset, batch_size=batch_size, shuffle=True)
train_sae(sae, activation_loader)

print("Entraînement du SAE terminé!")
SAE Epoch [1/300], Loss: 0.5661
SAE Epoch [2/300], Loss: 0.0721
SAE Epoch [3/300], Loss: 0.0597
SAE Epoch [4/300], Loss: 0.0524
SAE Epoch [5/300], Loss: 0.0467
SAE Epoch [6/300], Loss: 0.0421
SAE Epoch [7/300], Loss: 0.0384
SAE Epoch [8/300], Loss: 0.0353
SAE Epoch [9/300], Loss: 0.0326
SAE Epoch [10/300], Loss: 0.0304
SAE Epoch [11/300], Loss: 0.0284
SAE Epoch [12/300], Loss: 0.0267
SAE Epoch [13/300], Loss: 0.0252
SAE Epoch [14/300], Loss: 0.0239
SAE Epoch [15/300], Loss: 0.0227
SAE Epoch [16/300], Loss: 0.0216
SAE Epoch [17/300], Loss: 0.0206
SAE Epoch [18/300], Loss: 0.0196
SAE Epoch [19/300], Loss: 0.0188
SAE Epoch [20/300], Loss: 0.0180
SAE Epoch [21/300], Loss: 0.0173
SAE Epoch [22/300], Loss: 0.0166
SAE Epoch [23/300], Loss: 0.0160
SAE Epoch [24/300], Loss: 0.0154
SAE Epoch [25/300], Loss: 0.0148
SAE Epoch [26/300], Loss: 0.0143
SAE Epoch [27/300], Loss: 0.0138
SAE Epoch [28/300], Loss: 0.0133
SAE Epoch [29/300], Loss: 0.0129
SAE Epoch [30/300], Loss: 0.0125
SAE Epoch [31/300], Loss: 0.0121
SAE Epoch [32/300], Loss: 0.0117
SAE Epoch [33/300], Loss: 0.0113
SAE Epoch [34/300], Loss: 0.0110
SAE Epoch [35/300], Loss: 0.0107
SAE Epoch [36/300], Loss: 0.0104
SAE Epoch [37/300], Loss: 0.0102
SAE Epoch [38/300], Loss: 0.0099
SAE Epoch [39/300], Loss: 0.0097
SAE Epoch [40/300], Loss: 0.0095
SAE Epoch [41/300], Loss: 0.0092
SAE Epoch [42/300], Loss: 0.0090
SAE Epoch [43/300], Loss: 0.0088
SAE Epoch [44/300], Loss: 0.0087
SAE Epoch [45/300], Loss: 0.0085
SAE Epoch [46/300], Loss: 0.0083
SAE Epoch [47/300], Loss: 0.0082
SAE Epoch [48/300], Loss: 0.0080
SAE Epoch [49/300], Loss: 0.0079
SAE Epoch [50/300], Loss: 0.0077
SAE Epoch [51/300], Loss: 0.0076
SAE Epoch [52/300], Loss: 0.0075
SAE Epoch [53/300], Loss: 0.0074
SAE Epoch [54/300], Loss: 0.0073
SAE Epoch [55/300], Loss: 0.0072
SAE Epoch [56/300], Loss: 0.0070
SAE Epoch [57/300], Loss: 0.0069
SAE Epoch [58/300], Loss: 0.0068
SAE Epoch [59/300], Loss: 0.0067
SAE Epoch [60/300], Loss: 0.0067
SAE Epoch [61/300], Loss: 0.0066
SAE Epoch [62/300], Loss: 0.0065
SAE Epoch [63/300], Loss: 0.0064
SAE Epoch [64/300], Loss: 0.0063
SAE Epoch [65/300], Loss: 0.0062
SAE Epoch [66/300], Loss: 0.0061
SAE Epoch [67/300], Loss: 0.0061
SAE Epoch [68/300], Loss: 0.0060
SAE Epoch [69/300], Loss: 0.0059
SAE Epoch [70/300], Loss: 0.0058
SAE Epoch [71/300], Loss: 0.0058
SAE Epoch [72/300], Loss: 0.0057
SAE Epoch [73/300], Loss: 0.0056
SAE Epoch [74/300], Loss: 0.0056
SAE Epoch [75/300], Loss: 0.0055
SAE Epoch [76/300], Loss: 0.0054
SAE Epoch [77/300], Loss: 0.0054
SAE Epoch [78/300], Loss: 0.0053
SAE Epoch [79/300], Loss: 0.0052
SAE Epoch [80/300], Loss: 0.0052
SAE Epoch [81/300], Loss: 0.0051
SAE Epoch [82/300], Loss: 0.0051
SAE Epoch [83/300], Loss: 0.0050
SAE Epoch [84/300], Loss: 0.0050
SAE Epoch [85/300], Loss: 0.0049
SAE Epoch [86/300], Loss: 0.0049
SAE Epoch [87/300], Loss: 0.0048
SAE Epoch [88/300], Loss: 0.0048
SAE Epoch [89/300], Loss: 0.0047
SAE Epoch [90/300], Loss: 0.0047
SAE Epoch [91/300], Loss: 0.0046
SAE Epoch [92/300], Loss: 0.0046
SAE Epoch [93/300], Loss: 0.0045
SAE Epoch [94/300], Loss: 0.0045
SAE Epoch [95/300], Loss: 0.0044
SAE Epoch [96/300], Loss: 0.0044
SAE Epoch [97/300], Loss: 0.0044
SAE Epoch [98/300], Loss: 0.0043
SAE Epoch [99/300], Loss: 0.0043
SAE Epoch [100/300], Loss: 0.0042
SAE Epoch [101/300], Loss: 0.0042
SAE Epoch [102/300], Loss: 0.0042
SAE Epoch [103/300], Loss: 0.0041
SAE Epoch [104/300], Loss: 0.0041
SAE Epoch [105/300], Loss: 0.0040
SAE Epoch [106/300], Loss: 0.0040
SAE Epoch [107/300], Loss: 0.0040
SAE Epoch [108/300], Loss: 0.0039
SAE Epoch [109/300], Loss: 0.0039
SAE Epoch [110/300], Loss: 0.0039
SAE Epoch [111/300], Loss: 0.0038
SAE Epoch [112/300], Loss: 0.0038
SAE Epoch [113/300], Loss: 0.0038
SAE Epoch [114/300], Loss: 0.0037
SAE Epoch [115/300], Loss: 0.0037
SAE Epoch [116/300], Loss: 0.0037
SAE Epoch [117/300], Loss: 0.0036
SAE Epoch [118/300], Loss: 0.0036
SAE Epoch [119/300], Loss: 0.0036
SAE Epoch [120/300], Loss: 0.0036
SAE Epoch [121/300], Loss: 0.0035
SAE Epoch [122/300], Loss: 0.0035
SAE Epoch [123/300], Loss: 0.0035
SAE Epoch [124/300], Loss: 0.0034
SAE Epoch [125/300], Loss: 0.0034
SAE Epoch [126/300], Loss: 0.0034
SAE Epoch [127/300], Loss: 0.0034
SAE Epoch [128/300], Loss: 0.0033
SAE Epoch [129/300], Loss: 0.0033
SAE Epoch [130/300], Loss: 0.0033
SAE Epoch [131/300], Loss: 0.0033
SAE Epoch [132/300], Loss: 0.0032
SAE Epoch [133/300], Loss: 0.0032
SAE Epoch [134/300], Loss: 0.0032
SAE Epoch [135/300], Loss: 0.0032
SAE Epoch [136/300], Loss: 0.0031
SAE Epoch [137/300], Loss: 0.0031
SAE Epoch [138/300], Loss: 0.0031
SAE Epoch [139/300], Loss: 0.0031
SAE Epoch [140/300], Loss: 0.0031
SAE Epoch [141/300], Loss: 0.0030
SAE Epoch [142/300], Loss: 0.0030
SAE Epoch [143/300], Loss: 0.0030
SAE Epoch [144/300], Loss: 0.0030
SAE Epoch [145/300], Loss: 0.0030
SAE Epoch [146/300], Loss: 0.0029
SAE Epoch [147/300], Loss: 0.0029
SAE Epoch [148/300], Loss: 0.0029
SAE Epoch [149/300], Loss: 0.0029
SAE Epoch [150/300], Loss: 0.0028
SAE Epoch [151/300], Loss: 0.0028
SAE Epoch [152/300], Loss: 0.0028
SAE Epoch [153/300], Loss: 0.0028
SAE Epoch [154/300], Loss: 0.0028
SAE Epoch [155/300], Loss: 0.0028
SAE Epoch [156/300], Loss: 0.0027
SAE Epoch [157/300], Loss: 0.0027
SAE Epoch [158/300], Loss: 0.0027
SAE Epoch [159/300], Loss: 0.0027
SAE Epoch [160/300], Loss: 0.0027
SAE Epoch [161/300], Loss: 0.0027
SAE Epoch [162/300], Loss: 0.0026
SAE Epoch [163/300], Loss: 0.0026
SAE Epoch [164/300], Loss: 0.0026
SAE Epoch [165/300], Loss: 0.0026
SAE Epoch [166/300], Loss: 0.0026
SAE Epoch [167/300], Loss: 0.0026
SAE Epoch [168/300], Loss: 0.0025
SAE Epoch [169/300], Loss: 0.0025
SAE Epoch [170/300], Loss: 0.0025
SAE Epoch [171/300], Loss: 0.0025
SAE Epoch [172/300], Loss: 0.0025
SAE Epoch [173/300], Loss: 0.0025
SAE Epoch [174/300], Loss: 0.0025
SAE Epoch [175/300], Loss: 0.0025
SAE Epoch [176/300], Loss: 0.0024
SAE Epoch [177/300], Loss: 0.0024
SAE Epoch [178/300], Loss: 0.0024
SAE Epoch [179/300], Loss: 0.0024
SAE Epoch [180/300], Loss: 0.0024
SAE Epoch [181/300], Loss: 0.0024
SAE Epoch [182/300], Loss: 0.0024
SAE Epoch [183/300], Loss: 0.0023
SAE Epoch [184/300], Loss: 0.0023
SAE Epoch [185/300], Loss: 0.0023
SAE Epoch [186/300], Loss: 0.0023
SAE Epoch [187/300], Loss: 0.0023
SAE Epoch [188/300], Loss: 0.0023
SAE Epoch [189/300], Loss: 0.0023
SAE Epoch [190/300], Loss: 0.0023
SAE Epoch [191/300], Loss: 0.0023
SAE Epoch [192/300], Loss: 0.0022
SAE Epoch [193/300], Loss: 0.0022
SAE Epoch [194/300], Loss: 0.0022
SAE Epoch [195/300], Loss: 0.0022
SAE Epoch [196/300], Loss: 0.0022
SAE Epoch [197/300], Loss: 0.0022
SAE Epoch [198/300], Loss: 0.0022
SAE Epoch [199/300], Loss: 0.0022
SAE Epoch [200/300], Loss: 0.0022
SAE Epoch [201/300], Loss: 0.0022
SAE Epoch [202/300], Loss: 0.0021
SAE Epoch [203/300], Loss: 0.0021
SAE Epoch [204/300], Loss: 0.0021
SAE Epoch [205/300], Loss: 0.0021
SAE Epoch [206/300], Loss: 0.0021
SAE Epoch [207/300], Loss: 0.0021
SAE Epoch [208/300], Loss: 0.0021
SAE Epoch [209/300], Loss: 0.0021
SAE Epoch [210/300], Loss: 0.0021
SAE Epoch [211/300], Loss: 0.0021
SAE Epoch [212/300], Loss: 0.0021
SAE Epoch [213/300], Loss: 0.0021
SAE Epoch [214/300], Loss: 0.0021
SAE Epoch [215/300], Loss: 0.0020
SAE Epoch [216/300], Loss: 0.0020
SAE Epoch [217/300], Loss: 0.0020
SAE Epoch [218/300], Loss: 0.0020
SAE Epoch [219/300], Loss: 0.0020
SAE Epoch [220/300], Loss: 0.0020
SAE Epoch [221/300], Loss: 0.0020
SAE Epoch [222/300], Loss: 0.0020
SAE Epoch [223/300], Loss: 0.0020
SAE Epoch [224/300], Loss: 0.0020
SAE Epoch [225/300], Loss: 0.0020
SAE Epoch [226/300], Loss: 0.0020
SAE Epoch [227/300], Loss: 0.0019
SAE Epoch [228/300], Loss: 0.0019
SAE Epoch [229/300], Loss: 0.0019
SAE Epoch [230/300], Loss: 0.0019
SAE Epoch [231/300], Loss: 0.0019
SAE Epoch [232/300], Loss: 0.0019
SAE Epoch [233/300], Loss: 0.0019
SAE Epoch [234/300], Loss: 0.0019
SAE Epoch [235/300], Loss: 0.0019
SAE Epoch [236/300], Loss: 0.0019
SAE Epoch [237/300], Loss: 0.0019
SAE Epoch [238/300], Loss: 0.0019
SAE Epoch [239/300], Loss: 0.0019
SAE Epoch [240/300], Loss: 0.0019
SAE Epoch [241/300], Loss: 0.0019
SAE Epoch [242/300], Loss: 0.0019
SAE Epoch [243/300], Loss: 0.0018
SAE Epoch [244/300], Loss: 0.0019
SAE Epoch [245/300], Loss: 0.0018
SAE Epoch [246/300], Loss: 0.0018
SAE Epoch [247/300], Loss: 0.0018
SAE Epoch [248/300], Loss: 0.0018
SAE Epoch [249/300], Loss: 0.0018
SAE Epoch [250/300], Loss: 0.0018
SAE Epoch [251/300], Loss: 0.0019
SAE Epoch [252/300], Loss: 0.0018
SAE Epoch [253/300], Loss: 0.0018
SAE Epoch [254/300], Loss: 0.0018
SAE Epoch [255/300], Loss: 0.0018
SAE Epoch [256/300], Loss: 0.0018
SAE Epoch [257/300], Loss: 0.0018
SAE Epoch [258/300], Loss: 0.0018
SAE Epoch [259/300], Loss: 0.0018
SAE Epoch [260/300], Loss: 0.0018
SAE Epoch [261/300], Loss: 0.0018
SAE Epoch [262/300], Loss: 0.0018
SAE Epoch [263/300], Loss: 0.0017
SAE Epoch [264/300], Loss: 0.0017
SAE Epoch [265/300], Loss: 0.0018
SAE Epoch [266/300], Loss: 0.0017
SAE Epoch [267/300], Loss: 0.0017
SAE Epoch [268/300], Loss: 0.0018
SAE Epoch [269/300], Loss: 0.0017
SAE Epoch [270/300], Loss: 0.0017
SAE Epoch [271/300], Loss: 0.0017
SAE Epoch [272/300], Loss: 0.0017
SAE Epoch [273/300], Loss: 0.0017
SAE Epoch [274/300], Loss: 0.0017
SAE Epoch [275/300], Loss: 0.0017
SAE Epoch [276/300], Loss: 0.0017
SAE Epoch [277/300], Loss: 0.0017
SAE Epoch [278/300], Loss: 0.0017
SAE Epoch [279/300], Loss: 0.0017
SAE Epoch [280/300], Loss: 0.0017
SAE Epoch [281/300], Loss: 0.0017
SAE Epoch [282/300], Loss: 0.0017
SAE Epoch [283/300], Loss: 0.0017
SAE Epoch [284/300], Loss: 0.0017
SAE Epoch [285/300], Loss: 0.0017
SAE Epoch [286/300], Loss: 0.0017
SAE Epoch [287/300], Loss: 0.0016
SAE Epoch [288/300], Loss: 0.0017
SAE Epoch [289/300], Loss: 0.0016
SAE Epoch [290/300], Loss: 0.0016
SAE Epoch [291/300], Loss: 0.0017
SAE Epoch [292/300], Loss: 0.0016
SAE Epoch [293/300], Loss: 0.0016
SAE Epoch [294/300], Loss: 0.0016
SAE Epoch [295/300], Loss: 0.0016
SAE Epoch [296/300], Loss: 0.0016
SAE Epoch [297/300], Loss: 0.0016
SAE Epoch [298/300], Loss: 0.0017
SAE Epoch [299/300], Loss: 0.0016
SAE Epoch [300/300], Loss: 0.0016
Entraînement du SAE terminé!

Visualisation¶

Ce code permet de visualiser toutes les étapes intéressantes.

In [ ]:
def visualize_process(transformer, sae, test_loader, num_samples=10):
    device = next(transformer.parameters()).device
    transformer.eval()
    sae.eval()

    # Sélectionner des échantillons aléatoires
    all_samples = []
    for batch in test_loader:
        all_samples.extend(list(zip(batch['pixel_values'], batch['label'])))
    random_samples = random.sample(all_samples, num_samples)

    fig, axs = plt.subplots(num_samples, 5, figsize=(20, 4*num_samples))

    for i, (image, label) in enumerate(random_samples):
        image = image.to(device)

        # Image originale
        axs[i, 0].imshow(image.cpu().squeeze().numpy(), cmap='gray')
        axs[i, 0].set_title(f"Input (Label: {label})")

        # Transformer activation
        with torch.no_grad():
            transformer_output = transformer(image.unsqueeze(0), return_activations=True)
            transformer_activation = transformer_output[-1].squeeze().cpu()

        # SAE encoding et decoding
        with torch.no_grad():
            sae_output, sae_encoding = sae(transformer_activation.to(device))

        # Affichage de l'activation du Transformer
        axs[i, 1].imshow(transformer_activation.numpy().reshape(int(np.sqrt(transformer_activation.shape[0])), -1), cmap='viridis')
        axs[i, 1].set_title("Transformer Activation")

        # Affichage de l'encodage SAE
        axs[i, 2].imshow(sae_encoding.cpu().numpy().reshape(int(np.sqrt(sae_encoding.shape[0])), -1), cmap='viridis')
        axs[i, 2].set_title("SAE Encoding")

        # Affichage de la sortie SAE
        axs[i, 3].imshow(sae_output.cpu().numpy().reshape(int(np.sqrt(sae_output.shape[0])), -1), cmap='viridis')
        axs[i, 3].set_title("SAE Output")

        # Affichage de la sortie du Transformer
        with torch.no_grad():
            transformer_output = transformer(image.unsqueeze(0))
        predicted_class = torch.argmax(transformer_output).item()
        axs[i, 4].bar(range(10), nn.functional.softmax(transformer_output.squeeze(), dim=0).cpu().numpy())
        axs[i, 4].set_title(f"Transformer Output\nPredicted: {predicted_class}")
        axs[i, 4].set_xticks(range(10))

    for ax in axs.flat:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# Utilisation de la fonction
visualize_process(transformer, sae, test_loader)
No description has been provided for this image

Description :

  • première colonne : input du transformer
  • deuxième : état de la couche de sortie du MLP dans le transformer
  • troisième : état de la couche de code du SAE, qui a été entraîné sur ladite couche de sortie du MLP
  • quatrième : couche prédite par le SAE après décodage (identique à l'entrée donc)
  • cinquième : output du transformer, c'est-à dire une distribution sur les labels des chiffres. Tellement bien entraîné qu'il est souvent certain...

Chirurgie du Transformer ! (à vous de jouer)¶

Voici le code qui permet d'intervenir directement sur les neurones, et qui permet de modifier le "flux de pensée" du Transformer : on peut lui faire croire qu'il voit un 0 alors qu'il a un autre chiffre en entrée !

In [ ]:
def visualize_process_with_intervention(transformer, sae, test_loader, num_samples=20, neuron_position=None, intensity=1.0):
    device = next(transformer.parameters()).device
    transformer.eval()
    sae.eval()

    all_samples = []
    for batch in test_loader:
        all_samples.extend(list(zip(batch['pixel_values'], batch['label'])))
    random_samples = random.sample(all_samples, num_samples)

    fig, axs = plt.subplots(num_samples, 7, figsize=(28, 4*num_samples))
    fig.suptitle(f"Visualisation du processus Transformer-SAE avec intervention (Neurone: {neuron_position}, Intensité: {intensity})", fontsize=16)

    for i, (image, label) in enumerate(random_samples):
        image = image.to(device)

        # Image originale
        axs[i, 0].imshow(image.cpu().squeeze().numpy(), cmap='gray')
        axs[i, 0].set_title(f"Input (Label: {label})")

        # Transformer activation
        with torch.no_grad():
            original_output, transformer_activation = transformer(image.unsqueeze(0), return_activations=True)
            transformer_activation = transformer_activation.squeeze().cpu()

        # SAE encoding et decoding
        with torch.no_grad():
            sae_output, sae_encoding = sae(transformer_activation.to(device))

        # Affichage de l'activation du Transformer
        axs[i, 1].imshow(transformer_activation.numpy().reshape(int(np.sqrt(transformer_activation.shape[0])), -1), cmap='viridis')
        axs[i, 1].set_title("Transformer Activation")

        # Affichage de l'encodage SAE original
        axs[i, 2].imshow(sae_encoding.cpu().numpy().reshape(int(np.sqrt(sae_encoding.shape[0])), -1), cmap='viridis')
        axs[i, 2].set_title("SAE Encoding (Original)")

        # Intervention manuelle sur l'encodage SAE
        modified_encoding = sae_encoding.clone()
        if neuron_position is not None:
            modified_encoding[neuron_position] += intensity
        else:
            # Si aucun neurone spécifique n'est choisi, on modifie le premier quart comme avant
            region_size = sae_encoding.shape[0] // 4
            modified_encoding[:region_size] *= intensity

        # Affichage de l'encodage SAE modifié
        axs[i, 3].imshow(modified_encoding.cpu().numpy().reshape(int(np.sqrt(modified_encoding.shape[0])), -1), cmap='viridis')
        axs[i, 3].set_title("SAE Encoding (Modified)")

        # Décodage de l'encodage modifié
        with torch.no_grad():
            modified_output = sae.decoder(modified_encoding.to(device))

        # Affichage de la sortie SAE modifiée
        axs[i, 4].imshow(modified_output.cpu().numpy().reshape(int(np.sqrt(modified_output.shape[0])), -1), cmap='viridis')
        axs[i, 4].set_title("Modified SAE Output")

        # Affichage de la sortie du Transformer original
        original_predicted_class = torch.argmax(original_output).item()
        axs[i, 5].bar(range(10), nn.functional.softmax(original_output.squeeze(), dim=0).cpu().numpy())
        axs[i, 5].set_title(f"Original Output\nPredicted: {original_predicted_class}")
        axs[i, 5].set_xticks(range(10))

        # Calcul de la sortie du Transformer avec intervention
        with torch.no_grad():
            # Préparation des données pour le passage à travers le transformer
            x = image.view(1, -1)  # Aplatir l'image (1, 784)
            x = transformer.embedding(x)  # (1, d_model)
            x = x + transformer.pos_encoder  # Diffusion du pos_encoder
            x = x.unsqueeze(1)  # Ajouter une dimension de séquence (1, 1, d_model)

            # Passage à travers les couches du transformer
            for j, layer in enumerate(transformer.transformer_layers):
                if j == 1:
                    # Remplacer l'activation de la deuxième couche par l'activation modifiée
                    x = modified_output.unsqueeze(0).unsqueeze(0)
                else:
                    x = layer(x)

            x = x.squeeze(1)  # Retirer la dimension de séquence (1, d_model)
            modified_transformer_output = transformer.fc(x)

        modified_predicted_class = torch.argmax(modified_transformer_output).item()
        axs[i, 6].bar(range(10), nn.functional.softmax(modified_transformer_output.squeeze(), dim=0).cpu().numpy())
        axs[i, 6].set_title(f"Modified Output\nPredicted: {modified_predicted_class}")
        axs[i, 6].set_xticks(range(10))

    for ax in axs.flat:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# Utilisation de la fonction
# Pour modifier un neurone spécifique, utilisez :
# visualize_process_with_intervention(transformer, sae, test_loader, neuron_position=42, intensity=2.0)
# Pour modifier une région, utilisez :
# visualize_process_with_intervention(transformer, sae, test_loader, intensity=2.0)
# visualize_process_with_intervention(transformer, sae, test_loader)

Comment persuader un transformer qu'il voit un zéro, alors qu'il a un autre chiffre devant ses yeux ?¶

Le neurone lié au 0 dans le SAE (lorsque je l'ai entraîné) est le neurone 28. L'intensité est fixée à 2, ce qui permet de compenser les autres neurones déjà activer, pour forcer les choses.

Usage : Pour modifier un neurone, mettez neuron_position = index du neurone que vous voulez triturer, et changez son intensité !

Note : à chaque entraînement du modèle, la place du neurone change, il faut donc observer attentivement la position du neurone en question dans la sortie au dessus.

In [ ]:
visualize_process_with_intervention(transformer, sae, test_loader, neuron_position=28, intensity=2.0)
No description has been provided for this image

Et voilà comment on finit par croire qu'on est le Golden Bridge...

neurones1-408773ccd7729726.jpeg

Meta-SAE (pour aller encore plus loin)¶

In [ ]:
import torch
import torch.nn as nn
import torch.optim as optim

class MetaSAE(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(MetaSAE, self).__init__()
        self.encoder = nn.Linear(input_dim, hidden_dim)
        self.decoder = nn.Linear(hidden_dim, input_dim)

    def forward(self, x):
        encoded = self.batch_top_k(self.encoder(x), k=4)  # BatchTopK avec k=4
        decoded = self.decoder(encoded)
        return decoded, encoded

    def batch_top_k(self, x, k):
        top_k_values, _ = torch.topk(x, k, dim=1)
        kth_values = top_k_values[:, -1].unsqueeze(1)
        return torch.where(x >= kth_values, x, torch.zeros_like(x))
In [ ]:
def train_meta_sae(meta_sae, sae, num_epochs=2000, learning_rate=0.001, sparsity_weight=0.1):
    optimizer = optim.Adam(meta_sae.parameters(), lr=learning_rate)
    criterion = nn.MSELoss()

    # Obtenir les poids du décodeur du SAE original
    decoder_weights = sae.decoder.weight.t().detach()

    for epoch in range(num_epochs):
        optimizer.zero_grad()

        # Forward pass
        decoded, encoded = meta_sae(decoder_weights)

        # Calcul des pertes
        reconstruction_loss = criterion(decoded, decoder_weights)
        sparsity_loss = torch.mean(torch.abs(encoded))
        loss = reconstruction_loss + sparsity_weight * sparsity_loss

        # Backward pass et optimisation
        loss.backward()
        optimizer.step()

        if (epoch + 1) % 100 == 0:
            print(f"Meta-SAE Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")
In [ ]:
# Créer et entraîner le meta-SAE
input_dim = sae.decoder.weight.shape[0]  # Taille des latents du SAE original
hidden_dim = 100  # Comme dans l'article original
meta_sae = MetaSAE(input_dim, hidden_dim).to(device)
train_meta_sae(meta_sae, sae)
Meta-SAE Epoch [100/2000], Loss: 0.1126
Meta-SAE Epoch [200/2000], Loss: 0.0408
Meta-SAE Epoch [300/2000], Loss: 0.0290
Meta-SAE Epoch [400/2000], Loss: 0.0235
Meta-SAE Epoch [500/2000], Loss: 0.0190
Meta-SAE Epoch [600/2000], Loss: 0.0161
Meta-SAE Epoch [700/2000], Loss: 0.0152
Meta-SAE Epoch [800/2000], Loss: 0.0131
Meta-SAE Epoch [900/2000], Loss: 0.0128
Meta-SAE Epoch [1000/2000], Loss: 0.0120
Meta-SAE Epoch [1100/2000], Loss: 0.0110
Meta-SAE Epoch [1200/2000], Loss: 0.0102
Meta-SAE Epoch [1300/2000], Loss: 0.0100
Meta-SAE Epoch [1400/2000], Loss: 0.0095
Meta-SAE Epoch [1500/2000], Loss: 0.0093
Meta-SAE Epoch [1600/2000], Loss: 0.0092
Meta-SAE Epoch [1700/2000], Loss: 0.0085
Meta-SAE Epoch [1800/2000], Loss: 0.0082
Meta-SAE Epoch [1900/2000], Loss: 0.0083
Meta-SAE Epoch [2000/2000], Loss: 0.0076
In [ ]:
import torch
import matplotlib.pyplot as plt
import numpy as np

def visualize_meta_latents(meta_sae, num_latents=10):
    meta_sae.eval()

    # Obtenir les poids du décodeur du meta-SAE
    with torch.no_grad():
        decoder_weights = meta_sae.decoder.weight.t().cpu().numpy()

    # Calculer le nombre de lignes et de colonnes pour le subplot
    num_rows = (num_latents + 4) // 5  # 5 images par ligne
    num_cols = min(num_latents, 5)

    fig, axs = plt.subplots(num_rows, num_cols, figsize=(15, 3*num_rows))
    axs = axs.flatten() if num_rows > 1 else [axs]

    for i in range(num_latents):
        # Redimensionner le poids en image carrée
        img_size = int(np.sqrt(decoder_weights.shape[0]))
        img = decoder_weights[:, i].reshape(img_size, img_size)

        # Afficher l'image
        axs[i].imshow(img, cmap='viridis')
        axs[i].axis('off')
        axs[i].set_title(f'Meta-latent {i}')

    # Cacher les subplots vides
    for i in range(num_latents, len(axs)):
        axs[i].axis('off')

    plt.tight_layout()
    plt.show()

# Appeler la fonction
visualize_meta_latents(meta_sae, num_latents=20)
No description has been provided for this image
In [ ]:
def analyze_meta_latents_activation(meta_sae, sae, transformer, test_loader, num_samples=1000):
    meta_sae.eval()
    sae.eval()
    transformer.eval()

    all_activations = []
    all_labels = []

    with torch.no_grad():
        for batch in test_loader:
            inputs, labels = batch['pixel_values'].to(device), batch['label']

            if len(all_labels) >= num_samples:
                break

            # Obtenir les activations du transformer
            _, transformer_activations = transformer(inputs, return_activations=True)

            # Passer à travers le SAE
            sae_encoded, _ = sae(transformer_activations.squeeze(1))

            # Passer à travers le meta-SAE
            _, meta_encoded = meta_sae(sae_encoded)

            all_activations.append(meta_encoded.cpu())
            all_labels.extend(labels.numpy())

    all_activations = torch.cat(all_activations, dim=0).numpy()
    all_labels = np.array(all_labels)

    # Calculer l'activation moyenne pour chaque chiffre
    mean_activations = np.array([all_activations[all_labels == i].mean(axis=0) for i in range(10)])

    # Afficher un heatmap des activations moyennes
    plt.figure(figsize=(15, 8))
    plt.imshow(mean_activations, aspect='auto', cmap='viridis')
    plt.colorbar(label='Activation moyenne')
    plt.xlabel('Meta-latents')
    plt.ylabel('Chiffres')
    plt.title('Activation moyenne des meta-latents pour chaque chiffre')
    plt.yticks(range(10))
    plt.show()

# Appeler la fonction
analyze_meta_latents_activation(meta_sae, sae, transformer, test_loader)
No description has been provided for this image
In [ ]:
# Fonction pour visualiser les meta-latents
def visualize_meta_latents(sae, meta_sae, num_latents=5):
    sae.eval()
    meta_sae.eval()

    with torch.no_grad():
        decoder_weights = sae.decoder.weight.t().detach()
        _, meta_encoded = meta_sae(decoder_weights)

    fig, axs = plt.subplots(num_latents, 2, figsize=(10, 4*num_latents))

    for i in range(num_latents):
        # SAE latent
        axs[i, 0].imshow(decoder_weights[:, i].cpu().numpy().reshape(int(np.sqrt(decoder_weights.shape[0])), -1), cmap='viridis')
        axs[i, 0].set_title(f"SAE Latent {i}")

        # Meta-SAE decomposition
        axs[i, 1].imshow(meta_encoded[i].cpu().numpy().reshape(int(np.sqrt(meta_encoded.shape[1])), -1), cmap='viridis')
        axs[i, 1].set_title(f"Meta-SAE Decomposition {i}")

    for ax in axs.flat:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# Visualiser les meta-latents
visualize_meta_latents(sae, meta_sae)
No description has been provided for this image
In [ ]:
def visualize_max_activation(meta_sae, num_iterations=100):
    meta_sae.eval()  # Assurez-vous que le meta-SAE est en mode évaluation

    # Créer une entrée aléatoire qui nécessite un gradient
    input = torch.randn(1, meta_sae.encoder.in_features, requires_grad=True, device=device)
    optimizer = optim.Adam([input], lr=0.1)

    for _ in range(num_iterations):
        optimizer.zero_grad()

        # Forward pass
        with torch.set_grad_enabled(True):
            _, encoded = meta_sae(input)
            loss = -encoded.abs().mean()  # Maximiser l'activation moyenne absolue

        # Backward pass
        loss.backward()
        optimizer.step()

    # Visualiser l'entrée optimisée
    plt.figure(figsize=(6, 6))
    plt.imshow(input.detach().cpu().numpy().reshape(int(np.sqrt(input.shape[1])), -1), cmap='viridis')
    plt.title("Input that maximizes overall meta-SAE activation")
    plt.axis('off')
    plt.show()

# Visualiser l'entrée qui maximise l'activation globale du meta-SAE
visualize_max_activation(meta_sae)
No description has been provided for this image
In [ ]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

def intervene_and_observe(meta_sae, sae, transformer, test_loader, latent_index, scale=5):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    meta_sae.to(device)
    sae.to(device)
    transformer.to(device)

    meta_sae.eval()
    sae.eval()
    transformer.eval()

    for batch in test_loader:
        image, label = batch['pixel_values'].to(device), batch['label']

        with torch.no_grad():
            print("1. Original prediction")
            original_output = transformer(image)
            original_pred = torch.argmax(original_output, dim=1)
            print(f"Original output shape: {original_output.shape}")
            print(f"Original prediction: {original_pred[0].item()}")

            print("\n2. Transformer activations")
            _, activations = transformer(image, return_activations=True)
            print(f"Transformer activations shape: {activations.shape}")

            print("\n3. SAE encoding")
            sae_encoded, _ = sae(activations.squeeze(1))
            print(f"SAE encoded shape: {sae_encoded.shape}")

            print("\n4. Meta-SAE encoding")
            _, meta_encoded = meta_sae(sae_encoded)
            print(f"Meta-SAE encoded shape: {meta_encoded.shape}")

            print("\n5. Intervention on meta-latent")
            modified_meta_encoded = meta_encoded.clone()
            modified_meta_encoded[0, latent_index] += scale
            print(f"Modified meta-encoded shape: {modified_meta_encoded.shape}")

            print("\n6. Meta-SAE decoding")
            reconstructed_sae = meta_sae.decoder(modified_meta_encoded)
            print(f"Reconstructed SAE shape: {reconstructed_sae.shape}")

            print("\n7. Adjusting reconstructed SAE shape")
            adjusted_reconstructed_sae = F.pad(reconstructed_sae, (0, 36))  # Add 36 zeros to make it 100
            print(f"Adjusted reconstructed SAE shape: {adjusted_reconstructed_sae.shape}")

            print("\n8. SAE decoding")
            reconstructed_activation = sae.decoder(adjusted_reconstructed_sae)
            print(f"Reconstructed activation shape: {reconstructed_activation.shape}")

            print("\n9. Reshaping for transformer")
            reshaped_activation = reconstructed_activation.unsqueeze(1)  # Add sequence dimension
            print(f"Reshaped activation shape: {reshaped_activation.shape}")

            print("\n10. Passing through transformer layers")
            for layer in transformer.transformer_layers:
                reshaped_activation = layer(reshaped_activation)
            print(f"Transformed activation shape: {reshaped_activation.shape}")

            print("\n11. Passing through final linear layer")
            modified_output = transformer.fc(reshaped_activation.squeeze(1))
            print(f"Modified output shape: {modified_output.shape}")

            modified_pred = torch.argmax(modified_output, dim=1)
            print(f"Modified prediction: {modified_pred[0].item()}")

        # Visualisation
        plt.figure(figsize=(15, 5))

        # Image originale
        plt.subplot(1, 3, 1)
        plt.imshow(image[0].cpu().squeeze().numpy(), cmap='gray')
        plt.title(f"Original: {original_pred[0].item()}")
        plt.axis('off')

        # Image reconstruite
        plt.subplot(1, 3, 2)
        # Modification ici pour gérer la dimensionnalité
        reconstructed_image = reconstructed_activation[0].cpu().numpy()
        if reconstructed_image.shape[0] == 64:
            reconstructed_image = reconstructed_image.reshape(8, 8)
        elif reconstructed_image.shape[0] == 784:
            reconstructed_image = reconstructed_image.reshape(28, 28)
        else:
            print(f"Unexpected shape: {reconstructed_image.shape}")
            reconstructed_image = reconstructed_image.reshape(-1, int(reconstructed_image.shape[0]**0.5))
        plt.imshow(reconstructed_image, cmap='viridis')
        plt.title(f"Reconstructed")
        plt.axis('off')

        # Sortie modifiée
        plt.subplot(1, 3, 3)
        plt.imshow(modified_output.cpu().numpy().T, cmap='viridis', aspect='auto')
        plt.title(f"Modified Output: {modified_pred[0].item()}")
        plt.colorbar(label='Logit value')
        plt.yticks(range(10))

        plt.tight_layout()
        plt.show()

        break  # Juste un exemple

# Utilisation de la fonction
latent_index = 0  # Choisissez l'index du meta-latent à modifier
scale = 5  # Choisissez l'ampleur de la modification

intervene_and_observe(meta_sae, sae, transformer, test_loader, latent_index, scale)
Using device: cuda
1. Original prediction
Original output shape: torch.Size([64, 10])
Original prediction: 7

2. Transformer activations
Transformer activations shape: torch.Size([64, 1, 64])

3. SAE encoding
SAE encoded shape: torch.Size([64, 64])

4. Meta-SAE encoding
Meta-SAE encoded shape: torch.Size([64, 100])

5. Intervention on meta-latent
Modified meta-encoded shape: torch.Size([64, 100])

6. Meta-SAE decoding
Reconstructed SAE shape: torch.Size([64, 64])

7. Adjusting reconstructed SAE shape
Adjusted reconstructed SAE shape: torch.Size([64, 100])

8. SAE decoding
Reconstructed activation shape: torch.Size([64, 64])

9. Reshaping for transformer
Reshaped activation shape: torch.Size([64, 1, 64])

10. Passing through transformer layers
Transformed activation shape: torch.Size([64, 1, 64])

11. Passing through final linear layer
Modified output shape: torch.Size([64, 10])
Modified prediction: 2
No description has been provided for this image
In [ ]:
from sklearn.cluster import KMeans
import numpy as np
import matplotlib.pyplot as plt

def cluster_meta_latents(meta_sae, n_clusters=10):
    meta_sae.eval()

    with torch.no_grad():
        # Obtenir les poids du décodeur du meta-SAE
        meta_latents = meta_sae.decoder.weight.t().cpu().numpy()

    # Appliquer K-means
    kmeans = KMeans(n_clusters=n_clusters, random_state=42)
    clusters = kmeans.fit_predict(meta_latents)

    # Visualiser les centroïdes des clusters
    plt.figure(figsize=(15, 3 * ((n_clusters + 4) // 5)))
    for i in range(n_clusters):
        cluster_latents = meta_latents[clusters == i]
        centroid = cluster_latents.mean(axis=0)

        plt.subplot(((n_clusters + 4) // 5), 5, i+1)
        img_size = int(np.sqrt(centroid.shape[0]))
        plt.imshow(centroid.reshape(img_size, img_size), cmap='viridis')
        plt.title(f"Cluster {i}")
        plt.axis('off')

    plt.tight_layout()
    plt.show()

    # Afficher la distribution des clusters
    unique, counts = np.unique(clusters, return_counts=True)
    plt.figure(figsize=(10, 5))
    plt.bar(unique, counts)
    plt.xlabel("Cluster")
    plt.ylabel("Number of meta-latents")
    plt.title("Distribution of meta-latents across clusters")
    plt.show()

    return clusters

# Appeler la fonction
clusters = cluster_meta_latents(meta_sae)
No description has been provided for this image
No description has been provided for this image

Questions qu'on se pose¶

Deux modèles entraînés avec les mêmes paramètres ont-ils les mêmes activations ?¶

In [ ]:
# Initialisation des modèles
transformer1 = MNISTTransformer(d_model, nhead, num_layers)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(transformer1.parameters(), lr=learning_rate)

# Entraînement du Transformer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
transformer1.to(device)
Out[ ]:
MNISTTransformer(
  (embedding): Linear(in_features=784, out_features=64, bias=True)
  (transformer_layers): ModuleList(
    (0-1): 2 x TransformerEncoderLayer(
      (self_attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
      )
      (linear1): Linear(in_features=64, out_features=2048, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (linear2): Linear(in_features=2048, out_features=64, bias=True)
      (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (dropout1): Dropout(p=0.1, inplace=False)
      (dropout2): Dropout(p=0.1, inplace=False)
    )
  )
  (fc): Linear(in_features=64, out_features=10, bias=True)
)
In [ ]:
for epoch in range(num_epochs):
    transformer1.train()
    total_loss = 0
    for batch in train_loader:
        inputs, labels = batch['pixel_values'].to(device), batch['label'].to(device)

        optimizer.zero_grad()
        outputs, _ = transformer1(inputs, return_activations=True)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    train_loss = total_loss / len(train_loader)

    # Évaluation sur l'ensemble de test
    test_loss, test_accuracy = evaluate_transformer(transformer1, test_loader, criterion, device)

    print(f"Transformer1 Epoch [{epoch+1}/{num_epochs}]")
    print(f"  Train Loss: {train_loss:.4f}")
    print(f"  Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")

print("Entraînement du Transformer1 terminé!")
Transformer1 Epoch [1/5]
  Train Loss: 0.3262
  Test Loss: 0.1777, Test Accuracy: 94.68%
Transformer1 Epoch [2/5]
  Train Loss: 0.1494
  Test Loss: 0.1367, Test Accuracy: 95.88%
Transformer1 Epoch [3/5]
  Train Loss: 0.1122
  Test Loss: 0.1273, Test Accuracy: 96.38%
Transformer1 Epoch [4/5]
  Train Loss: 0.0918
  Test Loss: 0.1059, Test Accuracy: 96.85%
Transformer1 Epoch [5/5]
  Train Loss: 0.0778
  Test Loss: 0.0863, Test Accuracy: 97.47%
Entraînement du Transformer1 terminé!
In [ ]:
# Initialisation des modèles
transformer2 = MNISTTransformer(d_model, nhead, num_layers)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(transformer2.parameters(), lr=learning_rate)

# Entraînement du Transformer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
transformer2.to(device)
Out[ ]:
MNISTTransformer(
  (embedding): Linear(in_features=784, out_features=64, bias=True)
  (transformer_layers): ModuleList(
    (0-1): 2 x TransformerEncoderLayer(
      (self_attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
      )
      (linear1): Linear(in_features=64, out_features=2048, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (linear2): Linear(in_features=2048, out_features=64, bias=True)
      (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (dropout1): Dropout(p=0.1, inplace=False)
      (dropout2): Dropout(p=0.1, inplace=False)
    )
  )
  (fc): Linear(in_features=64, out_features=10, bias=True)
)
In [ ]:
for epoch in range(num_epochs):
    transformer2.train()
    total_loss = 0
    for batch in train_loader:
        inputs, labels = batch['pixel_values'].to(device), batch['label'].to(device)

        optimizer.zero_grad()
        outputs, _ = transformer2(inputs, return_activations=True)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    train_loss = total_loss / len(train_loader)

    # Évaluation sur l'ensemble de test
    test_loss, test_accuracy = evaluate_transformer(transformer2, test_loader, criterion, device)

    print(f"Transformer2 Epoch [{epoch+1}/{num_epochs}]")
    print(f"  Train Loss: {train_loss:.4f}")
    print(f"  Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")

print("Entraînement du Transformer2 terminé!")
Transformer2 Epoch [1/5]
  Train Loss: 0.3153
  Test Loss: 0.1776, Test Accuracy: 94.45%
Transformer2 Epoch [2/5]
  Train Loss: 0.1472
  Test Loss: 0.1358, Test Accuracy: 96.06%
Transformer2 Epoch [3/5]
  Train Loss: 0.1124
  Test Loss: 0.1173, Test Accuracy: 96.32%
Transformer2 Epoch [4/5]
  Train Loss: 0.0913
  Test Loss: 0.0972, Test Accuracy: 97.20%
Transformer2 Epoch [5/5]
  Train Loss: 0.0787
  Test Loss: 0.0977, Test Accuracy: 96.98%
Entraînement du Transformer2 terminé!
In [ ]:
def visualize_process(transformer1, transformer2, test_loader, num_samples=10):
    device = next(transformer1.parameters()).device
    transformer1.eval()
    transformer2.eval()

    # Sélectionner des échantillons aléatoires
    all_samples = []
    for batch in test_loader:
        all_samples.extend(list(zip(batch['pixel_values'], batch['label'])))
    random_samples = random.sample(all_samples, num_samples)

    fig, axs = plt.subplots(num_samples, 5, figsize=(20, 4*num_samples))

    for i, (image, label) in enumerate(random_samples):
        image = image.to(device)

        # Image originale
        axs[i, 0].imshow(image.cpu().squeeze().numpy(), cmap='gray')
        axs[i, 0].set_title(f"Input (Label: {label})")

        # Transformer1 activation
        with torch.no_grad():
            transformer1_output = transformer1(image.unsqueeze(0), return_activations=True)
            transformer1_activation = transformer1_output[-1].squeeze().cpu()

        # Transformer2 activation
        with torch.no_grad():
            transformer2_output = transformer2(image.unsqueeze(0), return_activations=True)
            transformer2_activation = transformer2_output[-1].squeeze().cpu()

        # Affichage de l'activation du Transformer1
        axs[i, 1].imshow(transformer1_activation.numpy().reshape(int(np.sqrt(transformer1_activation.shape[0])), -1), cmap='viridis')
        axs[i, 1].set_title("Transformer1 Activation")

        # Affichage de l'activation du Transformer2
        axs[i, 2].imshow(transformer2_activation.numpy().reshape(int(np.sqrt(transformer2_activation.shape[0])), -1), cmap='viridis')
        axs[i, 2].set_title("Transformer2 Activation")

        # Affichage de la sortie du Transformer1
        with torch.no_grad():
            transformer1_output = transformer1(image.unsqueeze(0))
            predicted_class1 = torch.argmax(transformer1_output).item()
        axs[i, 3].bar(range(10), nn.functional.softmax(transformer1_output.squeeze(), dim=0).cpu().numpy())
        axs[i, 3].set_title(f"Transformer1 Output\nPredicted: {predicted_class1}")
        axs[i, 3].set_xticks(range(10))

        # Affichage de la sortie du Transformer2
        with torch.no_grad():
            transformer2_output = transformer2(image.unsqueeze(0))
            predicted_class2 = torch.argmax(transformer2_output).item()
        axs[i, 4].bar(range(10), nn.functional.softmax(transformer2_output.squeeze(), dim=0).cpu().numpy())
        axs[i, 4].set_title(f"Transformer2 Output\nPredicted: {predicted_class2}")
        axs[i, 4].set_xticks(range(10))

    for ax in axs.flat:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# Utilisation de la fonction
visualize_process(transformer1, transformer2, test_loader)
No description has been provided for this image

La réponse à la question est donc NON, deux transformers de même architecture ne traitent pas les données de la même manière.

Mais c'est un peu court ! On sent bien que d'une certaine manière, ces deux transformers réalisent la même fonction, mais l'implémentent d'une manière différente en quelque sorte.

En fait, nous allons voir que ce sont les représentations qui convergent vers la même chose, et non, dans l'absolu, les activations des neurones.

Deux modèles entraînés avec les mêmes paramètres ont-ils les mêmes représentations ?¶

Reprenons nos deux modèle. Il existe de nombreuses métriques pour savoir si deux modèles partagent une même représentation. Une méthode simple pour cela est de calculer les plus proches voisins.

Qu'est-ce que Cycle KNN ?¶

Cycle KNN (K-Nearest Neighbors) est une métrique utilisée pour comparer la similarité entre deux ensembles de représentations vectorielles, généralement issues de deux modèles différents ou de deux couches différentes d'un même modèle.

Comment fonctionne Cycle KNN ?¶

Le processus se déroule en plusieurs étapes :

  1. Pour chaque point dans l'ensemble A, on trouve ses K plus proches voisins dans l'ensemble B.
  2. Pour chacun de ces K voisins dans B, on trouve leurs K plus proches voisins dans A.
  3. On vérifie si le point original de A est parmi ces K voisins.

Formulation mathématique¶

Soit $A = \{a_1, ..., a_n\}$ et $B = \{b_1, ..., b_n\}$ deux ensembles de vecteurs.

Définissons $NN_k(x, Y)$ comme l'ensemble des k plus proches voisins de x dans Y.

La précision Cycle KNN est alors définie comme :

$$\text{CycleKNN}(A, B, k) = \frac{1}{n} \sum_{i=1}^n \mathbb{1}[a_i \in NN_k(NN_k(a_i, B), A)]$$

où $\mathbb{1}[\cdot]$ est la fonction indicatrice.

Interprétation¶

  • Une valeur élevée (proche de 1) indique une forte correspondance entre les structures de voisinage dans A et B.
  • Une valeur faible (proche de 0) suggère des structures de voisinage très différentes.

Pourquoi est-ce utile dans ce contexte ?¶

  1. Comparaison de modèles : Cycle KNN permet de comparer les représentations internes de différents modèles (ici, transformer1 et transformer2) sur les mêmes données d'entrée.

  2. Analyse de la préservation de structure : Il aide à comprendre si les modèles préservent des structures similaires dans leurs espaces de représentation respectifs.

  3. Évaluation de la cohérence : Une forte similarité Cycle KNN suggère que les deux modèles ont appris des représentations cohérentes, malgré d'éventuelles différences dans leurs architectures ou leurs processus d'entraînement.

  4. Détection de divergences : Une faible similarité peut indiquer des divergences significatives dans la façon dont les modèles traitent et représentent les informations.

  5. Analyse de transfert d'apprentissage : Dans le contexte du transfert d'apprentissage, Cycle KNN peut aider à évaluer à quel point les représentations apprises sont transférables entre les modèles.

En résumé, Cycle KNN fournit une mesure quantitative de la similarité structurelle entre les espaces de représentation de deux modèles, offrant ainsi des insights précieux sur leurs comportements internes et leurs capacités de généralisation.

In [ ]:
import torch
import torch.nn.functional as F
import random

def compute_nearest_neighbors(feats, topk=100):
    feats = F.normalize(feats, p=2, dim=1)
    sim_matrix = feats @ feats.T
    sim_matrix.fill_diagonal_(-float('inf'))
    return sim_matrix.argsort(dim=1, descending=True)[:, :topk]

def compute_knn_accuracy(knn):
    n = knn.shape[0]
    acc = knn == torch.arange(n, device=knn.device).view(-1, 1, 1)
    return acc.float().view(n, -1).max(dim=1).values.mean()

def cycle_knn(feats_A, feats_B, topk):
    print(f"Shape of feats_A: {feats_A.shape}, feats_B: {feats_B.shape}")
    print(f"Range of values in feats_A: [{feats_A.min():.4f}, {feats_A.max():.4f}]")
    print(f"Range of values in feats_B: [{feats_B.min():.4f}, {feats_B.max():.4f}]")

    knn_A = compute_nearest_neighbors(feats_A, topk)
    knn_B = compute_nearest_neighbors(feats_B, topk)

    print(f"Shape of knn_A: {knn_A.shape}, knn_B: {knn_B.shape}")

    accuracy = compute_knn_accuracy(knn_A[knn_B])
    print(f"Raw accuracy: {accuracy.item()}")

    return accuracy.item()

def collect_activations(model, data_loader, device, num_samples=500):
    model.eval()
    activations = []
    sample_count = 0
    with torch.no_grad():
        for batch in data_loader:
            inputs = batch['pixel_values'].to(device)
            _, batch_activations = model(inputs, return_activations=True)
            activations.append(batch_activations.squeeze(1).cpu())
            sample_count += batch_activations.shape[0]
            if sample_count >= num_samples:
                break
    activations = torch.cat(activations, dim=0)[:num_samples]
    return activations

# Utilisation
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Collecte des activations pour un sous-ensemble
activations1 = collect_activations(transformer1, train_loader, device, num_samples=500)
activations2 = collect_activations(transformer2, train_loader, device, num_samples=500)

print("Activations collected.")
print(f"activations1 shape: {activations1.shape}")
print(f"activations2 shape: {activations2.shape}")

# Calcul de la similarité
similarity = cycle_knn(activations1, activations2, topk=40)
print(f"La similarité cycle_knn entre transformer1 et transformer2 est : {similarity}")
Activations collected.
activations1 shape: torch.Size([500, 64])
activations2 shape: torch.Size([500, 64])
Shape of feats_A: torch.Size([500, 64]), feats_B: torch.Size([500, 64])
Range of values in feats_A: [-3.3848, 3.1561]
Range of values in feats_B: [-4.2418, 3.4259]
Shape of knn_A: torch.Size([500, 40]), knn_B: torch.Size([500, 40])
Raw accuracy: 0.9100000262260437
La similarité cycle_knn entre transformer1 et transformer2 est : 0.9100000262260437
In [ ]:
# Initialisation des modèles
transformer_alea = MNISTTransformer(d_model, nhead, num_layers)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(transformer_alea.parameters(), lr=learning_rate)

# Entraînement du Transformer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
transformer_alea.to(device)
Out[ ]:
MNISTTransformer(
  (embedding): Linear(in_features=784, out_features=64, bias=True)
  (transformer_layers): ModuleList(
    (0-1): 2 x TransformerEncoderLayer(
      (self_attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
      )
      (linear1): Linear(in_features=64, out_features=2048, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (linear2): Linear(in_features=2048, out_features=64, bias=True)
      (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (dropout1): Dropout(p=0.1, inplace=False)
      (dropout2): Dropout(p=0.1, inplace=False)
    )
  )
  (fc): Linear(in_features=64, out_features=10, bias=True)
)
In [ ]:
# Collecte des activations pour un sous-ensemble
activations1 = collect_activations(transformer_alea, train_loader, device, num_samples=500)
activations2 = collect_activations(transformer2, train_loader, device, num_samples=500)

print("Activations collected.")
print(f"activations1 shape: {activations1.shape}")
print(f"activations2 shape: {activations2.shape}")

# Calcul de la similarité
similarity = cycle_knn(activations1, activations2, topk=40)
print(f"La similarité cycle_knn entre transformer1 et transformer2 est : {similarity}")
Activations collected.
activations1 shape: torch.Size([500, 64])
activations2 shape: torch.Size([500, 64])
Shape of feats_A: torch.Size([500, 64]), feats_B: torch.Size([500, 64])
Range of values in feats_A: [-3.4254, 2.5900]
Range of values in feats_B: [-3.9892, 3.2260]
Shape of knn_A: torch.Size([500, 40]), knn_B: torch.Size([500, 40])
Raw accuracy: 0.8519999980926514
La similarité cycle_knn entre transformer1 et transformer2 est : 0.8519999980926514