Encoder-decoder and attention mechanisms

Nous nous intéressons maintenant à un problème de transduction d'une séquence de symbole en une autre (autrement dit de la traduction). Traduire de séquences de mots d'une langue à une autre n'est pas raisonnable sans GPU. Nous allons donc nous intéresser au problème de la phonétisation automatique. Étant donné une séquence de graphèmes, nous devons générer la séquence de phonèmes correspondant.

Ce problème peut être vu comme la composition de deux problèmes vus dans les notebooks précédents :

  • le problème de prédiction de la polarité d'un tweet : lire une séquence en entrée et produire une représentation à partir de cette séquence (ici les graphèmes)
  • le problème de modèlisation du langage : partir d'une représentation cachée puis générer une séquence de symboles (ici les phonèmes)

On appelle souvent ce cadre "encodeur-décodeur" ou "seq2seq.

Comme les modèles de langages conditionnés (par leur état caché initial) ne fonctionnent pas très bien car ils doivent emagasiner toute l'information sur la séquence vue en entrée dans une représentation de taille fixe, nous dans une deuxième étape augmenter le modèle d'un macanisme d'attention.

Commençons par télécharger un dictionnaire phonétisé de petite taille créé à partir d'un sous-ensemble du dictionnaire de CMU (utilisé dans l'ASR sphinx). Ce dictionnaire, regénérable avec les commandes en commentaire, contient sur chaque ligne un mot, suivi d'un séparateur "|||" suivi d'une sèquence de phonèmes.

In [ ]:
%matplotlib inline
In [ ]:
%%bash
# wget -q http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b
# iconv -f latin1 cmudict-0.7b | grep "^[A-Z]" | awk 'NF < 16 {print}' | sed 's/([0-9]*)//;s/  / ||| /' | shuf | head -5500 > cmudict-0.7b.filtered
[ -f cmudict-0.7b.filtered ] || wget -q https://raw.githubusercontent.com/benob/dl4nlp-tutorials-data/master/cmudict-0.7b.filtered
head cmudict-0.7b.filtered

Le chargement des données nécessite de convertir les mots en listes de caractères, et la séquence de phonème en liste de chaînes de caractères.

In [ ]:
words = []
phonemes = []

with open("cmudict-0.7b.filtered") as fp:
    for line in fp:
        word, phones = line.strip().split(' ||| ')
        words.append(list(word))
        phonemes.append(phones.split())

print(words[42], phonemes[42])
len(words)

Si on regarde la distribution des tailles de mots et de phonétisations, on peut voir que l'on couvre la plupart des cas avec une longueur maximale de 16.

In [ ]:
from matplotlib import pyplot as plt

plt.hist([len(x) for x in words])
plt.show()
plt.hist([len(x) for x in phonemes])
plt.show()

La conversion des entrées et sorties du système en séquences d'entiers se fait comme pour l'analyse de sentiment et le modèle de langage. Notez la présence du symbole <start> pour la partie modèle de langage.

In [ ]:
import collections

letter_vocab = collections.defaultdict(lambda: len(letter_vocab))
letter_vocab['<eos>'] = 0

phoneme_vocab = collections.defaultdict(lambda: len(phoneme_vocab))
phoneme_vocab['<eos>'] = 0
phoneme_vocab['<start>'] = 1

int_words = []
int_phonemes = []

for word, phones in zip(words, phonemes):
    int_words.append([letter_vocab[x] for x in word])
    int_phonemes.append([phoneme_vocab[x] for x in phones])

print(len(letter_vocab), len(phoneme_vocab))
print(int_words[42], int_phonemes[42])

rev_letter_vocab = {y: x for x, y in letter_vocab.items()}
rev_phoneme_vocab = {y: x for x, y in phoneme_vocab.items()}

print([rev_letter_vocab[x] for x in int_words[42]], [rev_phoneme_vocab[x] for x in int_phonemes[42]])

Nous allons utiliser des hyperparamètres de magnitude réduite pour pouvoir entraîner le système sur CPU. Sur GPU, on pourrait prendre de bien plus grands états cachés. De plus, rien ne nous empêche d'avoir des tailles d'embedding et d'état caché différentes selon que l'on est dans la partie "encodeur" ou "décodeur".

In [ ]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

max_len = 16
batch_size = 16
embed_size = 16
hidden_size = 32

Une fois que l'on a des listes de listes d'entiers, il est relativement simple de les mettre dans des tenseurs avec le padding habituel. Le choix de coller les séquences à gauche est complètement arbitraire.

In [ ]:
X = torch.zeros((len(int_words), max_len)).long()
Y = torch.zeros((len(int_phonemes), max_len)).long()

for i, (word, phones) in enumerate(zip(int_words, int_phonemes)):
    word_length = min(max_len, len(word))
    X[i,0:word_length] = torch.LongTensor(word[:word_length])
    phones_length = min(max_len, len(phones))
    Y[i,0:phones_length] = torch.LongTensor(phones[:phones_length])

print(X[42].tolist())
print(Y[42].tolist())

Le corpus est divisé en un ensemble d'entraînement et de validation, et nous utilisons les facilités proposées par pytorch pour la génération des batches.

In [ ]:
X_train = X[:5000]
Y_train = Y[:5000]
X_valid = X[5000:]
Y_valid = Y[5000:]

from torch.utils.data import TensorDataset, DataLoader
train_set = TensorDataset(X_train, Y_train)
valid_set = TensorDataset(X_valid, Y_valid)

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_set, batch_size=batch_size)

Voici notre premier modèle. C'est un encodeur-decodeur classique qui utilise une couche d'embedding pour projeter les caractères du mot vers un espace d'embedding, puis une couche recurrente bidirectionnelle pour créer une représentation de l'intégralité du mot. La représentation issue de cette couche sera de taille (num_layers * num_directions, batch_size, hidden_size), donc il faut que la seconde couche récurrente qui va générer les phonèmes ait une couche cachée de taille 2 * hidden_size. Cette dernière est construite comme un modèle de langage : elle commence par le symbole <start>, le projette dans un espace d'embedding, le passe dans la couche récurrente, puis la sortie de cette dernière dans une couche de décision qui génère un vecteur de scores de la taille du vocabulaire des phonèmes.

L'inférence est divisée en deux, la partie encodage qui renvoie l'état caché à l'issue de la lecture du mot, et la partie décodage qui renvoie la décision à chaque position pour la séquence de phonème ainsi que l'état caché à la fin (pour pouvoir faire un décodage phonème par phonème comme on l'a fait dans le modèle de langage).

In [ ]:
class Seq2SeqModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.letter_embed = nn.Embedding(len(letter_vocab), embed_size, padding_idx=letter_vocab['<eos>'])
        self.phoneme_embed = nn.Embedding(len(phoneme_vocab), embed_size)
        self.letter_rnn = nn.GRU(embed_size, hidden_size, num_layers=1, bias=False, bidirectional=True, batch_first=True)
        self.phoneme_rnn = nn.GRU(embed_size, 2 * hidden_size, num_layers=1, batch_first=True)
        self.dropout = nn.Dropout(0.3)
        # size of hidden state: (num_layers * num_directions, batch_size, hidden_size)
        self.decision = nn.Linear(hidden_size * 2 * 1, len(phoneme_vocab))
    
    def encode(self, word):
        embed = self.letter_embed(word)
        output, h_n = self.letter_rnn(embed)
        return self.dropout(h_n.transpose(0, 1).contiguous().view(1, word.size(0), -1))
    
    def decode(self, phones, h_0):
        embed = self.phoneme_embed(phones)
        output, h_n = self.phoneme_rnn(embed, h_0)
        return self.dropout(self.decision(output)), h_n
    
    def forward(self, word, phones):
        output, h_n = self.decode(phones, self.encode(word))
        return output

seq2seq_model = Seq2SeqModel()
seq2seq_model

On peut vérifier que le modèle renvoie bien un tenseur de taille (batch_size, sequence_length, num_phonemes). Pour cela nous passons $Y$ mais ce dernier représente les phonèmes à générer, pas les phonèmes précédents.

In [ ]:
with torch.no_grad():
  print(seq2seq_model(X[:3], Y[:3]).size())

L'évaluation sur les données de validation peut renvoyer la perplexité (ou un taux d'erreur). Par contre, le loader renvoie des paires $(x, y)$ contenant des batches de mots et phonétisations correspondantes. Donc il est nécessaire de créer une nouvelle variable qui contient les phonèmes décalés vers la gauche (phonème précédent) précédés du symbole <start>. Pour conserver la taille de séquence, on utilise le sous-tenseur y[:,:-1] qui représente tous les éléments de y sauf le dernier (sur la dimension 1), pour le batch en intégralité (dimension 0).

In [ ]:
import math

def perf(model, loader):
    criterion = nn.CrossEntropyLoss()
    model.eval()
    total_loss = num = 0
    for x, y in loader:
      with torch.no_grad():
        x2 = torch.cat([phoneme_vocab['<start>'] * torch.ones(y.size(0), 1).long(), y[:,:-1]], 1)
        y_scores = model(x, x2)
        loss = criterion(y_scores.view(y.size(0) * y.size(1), -1), y.view(y.size(0) * y.size(1)))
        total_loss += loss.item()
        num += len(y)
    return total_loss / num, math.exp(total_loss / num)

perf(seq2seq_model, valid_loader)

Il est alors nécessaire de modifier la fonction d'apprentissage du modèle de langage de la même manière.

In [ ]:
def fit(model, epochs):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters())
    for epoch in range(epochs):
        model.train()
        total_loss = num = 0
        for x, y in train_loader:
            x2 = torch.cat([phoneme_vocab['<start>'] * torch.ones(y.size(0), 1).long(), y[:,:-1]], 1)
            optimizer.zero_grad()
            y_scores = model(x, x2)
            loss = criterion(y_scores.view(y.size(0) * y.size(1), -1), y.view(y.size(0) * y.size(1)))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num += len(y)
        print(epoch, total_loss / num, *perf(model, valid_loader))

fit(seq2seq_model, 10)

Nous pouvons faire une fonction de génération pour ce modèle. La différence avec le modèle de langage est que nous commençons avec un état caché généré par l'encodeur, et que la fonction decode() permet de générer la séquence phonème par phonème.

Une fois le modèle entraîné, on s'aperçoit que le générateur n'est pas si bon. En général, il commence bien les phonétisations mais n'arrive pas à les terminer. Il ajoute souvent des sons qui n'apparaissent pas. Ce phénomène est dû à deux problèmes :

  • l'état caché qui sert à encoder le mot entier est limité et partagé avec celui utilisé pour le décodage (on pourrait concaténer les embeddings des phonèmes passés en entrée avec une copie de cet état caché à chaque étape, pour ne pas perdre la mémoire). Entraîner le modèle plus longtemps avec plus de données et un plus grand état caché pourrait améliorer la situation.
  • il y a une différence entre les conditions d'apprentissage et de prédiction car en apprentissage on utilise le symbole précédent de référence (méthode "teacher forcing") alors qu'en test, on utilise le symbole prédit, potentiellement faux. Des méthodes ont été proposées pour passer de la distribution forcée à la distribution réelle au court de l'apprentissage, mais c'est compliqué à mettre en oeuvre.

Heureusement, il y a les mécanismes d'attention.

In [ ]:
def generate_seq2seq(model, word):
    int_word = [letter_vocab[letter] for letter in word]
    x = torch.LongTensor(int_word).view(1, -1)
    hidden = model.encode(x)
    
    x2 = torch.zeros((1, 1)).long()
    x2[0, 0] = phoneme_vocab['<start>']

    with torch.no_grad():
      for i in range(200):
        y_scores, hidden = model.decode(x2, hidden)
        y_pred = torch.max(y_scores, 2)[1]
        selected = y_pred.data[0, 0].item()
        if selected == phoneme_vocab['<eos>']:
            break
        print(rev_phoneme_vocab[selected], end=' ')
        x2[0, 0] = selected
    print()

generate_seq2seq(seq2seq_model, 'TALL')

Les mécanismes d'attention reprennent l'idée provenant de la cognition humaine que nous pouvons porter notre attention sur un sous-ensemble des entrées plutôt que devoir en appréhender l'intégralité en permanence. Pour un problème seq2seq, ceci va se traduire en l'utilisation sélective des états cachés des caractères du mot en entrée en fonction du phonème que l'on est en train de générer, plutôt que de prendre le dernier état caché de la séquence.

L'encodeur va cette fois renvoyer les sorties du RNN (son état caché à chaque indice) plutôt que le dernier état caché. Le décodeur les sorties du RNN sur les phonèmes, puis transforme ces sorties avec une couche linéaire appelée attn. On réalise la multiplication de matrice entre cette sortie transformée et chacun des état cachés sur l'entrée (les caractères du mot) et l'on prend le softmax du résultat (ce type d'attention est appelé attention multiplicative). Ceci donne une distribution sur les entrée que l'on appelle poids d'attention. On peut alors calculer la somme pondérée des états cachés en entrée par ces poids d'attention pour obtenir une représentation de l'entrée contextuelle pour le décodage du phonème courant. C'est la concaténation de cette représentation contextuelle et de la sortie du RNN sur les phonèmes qui est utilisée pour prendre la décision finale.

Les choses se compliquent un peu car nous faisons des traitements par batch, et donc les séquences de caractères contiennent du padding. Même si l'on applique la technique vue précédemment pour que l'état caché correspondant aux symboles de padding soit nul, le mécanisme d'attention risque d'utiliser les sorties du RNN à cet endroit pour apprendre des régularités sur la distribution a priori ou la longueur des entrées. Il faut donc s'assurer que le softmax donnera un poids de zéro aux états cachés issus du padding. On utilise un masque calculé sur les entrées par l'encodeur qui est vrai pour chaque symbole de padding, faux sinon. La fonction masked_fill_() permet alors de fixer les poids d'attention à $-\infty$ avant de faire le softmax. Comme le numérateur de ce dernier prend l'exponentielle de ses entrées, on a bien un poids à zéro. Ceci permet aussi de couper la propagation du gradient pour ces composantes.

Pour ce qui est de la taille des différentes couches, le RNN sur les caractères est bidirectionnel donc sa sortie est 2 fois la taille de la couche cachée. La couche de transformation s'occupe de projeter l'état caché du RNN sur les phonèmes qui est unidirectionnel (donc elle passe de hidden_size à 2 * hidden_size). La multiplication de matrice traite des matrices de la taille (batch_size, sequence_length, 2 * hidden_size) et (batch_size, 2 * hidden_size, sequence_length) après transposition. Il en résulte donc une matrice de poids de taille (batch_size, sequence_length, sequence_length). Comme nous avons la même taille de séquences pour les mots et les phonétisations, ce n'est pas facile à interpréter, mais celà correspond en fait à (batch_size, phoneme_size, word_size), donc c'est bien sur la dimensions correspondant aux mots que l'on veut faire le softmax. Finalement, le contexte créé est la somme pondérée des états cachés sur les entrées, de taille 2 * hidden_size, et l'état caché du RNN sur les phonèmes est de taille hidden_size, donc la couche de décision a une entrée de 3 * hidden_size.

Pour rappel sur les RNN :

  • taille de l'état caché : (num_layers * num_directions, batch_size, hidden_size)
  • taille de la sortie : (batch_size, sequence_length, num_directions * hidden_size)
In [ ]:
class AttnModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.letter_embed = nn.Embedding(len(letter_vocab), embed_size, padding_idx=letter_vocab['<eos>'])
        self.phoneme_embed = nn.Embedding(len(phoneme_vocab), embed_size)
        self.letter_rnn = nn.GRU(embed_size, hidden_size, num_layers=1, bias=False, bidirectional=True, batch_first=True)
        self.phoneme_rnn = nn.GRU(embed_size, hidden_size, num_layers=1, batch_first=True)
        self.dropout = nn.Dropout(0.3)
        self.attn = nn.Linear(hidden_size, hidden_size * 2)
        self.decision = nn.Linear(hidden_size * 3, len(phoneme_vocab))
    
    def encode(self, word):
        mask = (word == 0)
        embed = self.letter_embed(word)
        output, h_n = self.letter_rnn(embed)
        return self.dropout(output), mask
    
    def decode(self, phones, states, mask, h_0=None):
        embed = self.phoneme_embed(phones)
        output, h_n = self.phoneme_rnn(embed, h_0)
        output = self.dropout(output)
        
        a1 = self.attn(output)
        a2 = a1.bmm(states.transpose(1, 2))
        a2.data.masked_fill_(mask.unsqueeze(1).data, -float('inf'))
        attn_weights = F.softmax(a2, 2)

        context = attn_weights.bmm(states)
        scores = self.decision(torch.cat([context, output], 2))
        return scores, h_n, attn_weights
    
    def forward(self, word, phones):
        states, mask = self.encode(word)
        output, h_n, attn_weights = self.decode(phones, states, mask)
        return output

attn_model = AttnModel()
print(attn_model)
with torch.no_grad():
  print(attn_model(X[:3], Y[:3]).size())

On peut entraîner ce modèle et l'on doit normalement obtenir des loss plus faibles sur l'ensemble de validation que pour l'encodeur-décodeur sans attention.

In [ ]:
fit(attn_model, 10)

Exercice 1

Faire une fonction de génération sur le modèle de generate_seq2seq() pour le modèle avec attention. Il n'y a qu'à changer l'appel aux fonction encode et decode pour passer les bons paramètres, et produire un état caché de départ à zéro pour le décodeur comme dans le modèle de langage.

Plutôt que de prendre le phonème le plus probable en décodage, nous pourrions le tirer aléatoirement dans la distribution de scores. Modifier la fonction en ce sens, et collectez des statistiques sur les résultats de 100 tirages pour le mot "BONJOUR". Quelle est la phonéistation la plus couramment générée ?

In [ ]:
 

On peut aussi demander au modèle de nous renvoyer la matrice d'attention pour pouvoir analyser les états cachés utilisés par le modèle pour faire ses prédictions. Il est intéressant de voir que le modèle apprend à ignorer les muettes. L'attention multiplicatie calcule une similarité entre les états cachés en entrée et en sortie (transformés) et a donc tendance à être forte lorsque les symboles sont systématiquement la traduction l'un de l'autre car le modèle peut apprendre une représentation similaire pour un caractère et un phonème.

In [ ]:
def show_attn(attn_model, word):
    int_word = [letter_vocab[letter] for letter in word]
    x = torch.LongTensor(int_word).view(1, -1)
    states, mask = attn_model.encode(x)
    
    x2 = torch.zeros((1, 1)).long()
    x2[0, 0] = phoneme_vocab['<start>']
    hidden = torch.zeros(1, 1, hidden_size)
    result = []
    attn_matrix = []
    with torch.no_grad():
      for i in range(200):
        y_scores, hidden, attn = attn_model.decode(x2, states, mask, hidden)
        attn_matrix.append(attn.squeeze().data.tolist())
        y_pred = torch.max(y_scores, 2)[1]
        selected = y_pred.data[0, 0].item()
        result.append(rev_phoneme_vocab[selected])
        if selected == phoneme_vocab['<eos>']:
            break
        x2[0, 0] = selected
    plt.matshow(attn_matrix)
    plt.xticks(range(len(word)), word)
    plt.yticks(range(len(result)), result)
    plt.show()

show_attn(attn_model, 'THOROUGH')

Exercice 2

Le modèle seq2seq est bon pour traiter des problèmes de transduction pour lesquels il y a un alignement entre les entrées et les sorties. Un tel modèle est-il capable d'apprendre des concepts plus abstraits comme calculer le résutlat d'une expression mathématique à partir des caractères qui la constituent (par exemple $2\times(3+4) \to 14$) ? Génrérez des données de ce type et vérifiez les propriétés de généralisation du modèle (par exemple en nombre de termes, taille des termes, opérations)

In [ ]:
import random
for i in range(100):
    a = random.randint(0,9)
    b = random.randint(0,9)
    c = a + b
    x = "%d+%d" % (a, b)
    y = str(c)
    print(x, "=>", y)
    

Exercice 3

Apprendre un système qui régénère les mots à partir de la séquence de phonèmes. Lexique3 (http://www.lexique.org/, prétraité ici https://pageperso.lis-lab.fr/benoit.favre/files/lexique-phonetise.txt) contient par exemple des phonétisations pour le français qui est beaucoup plus ambigu que l'anglais.

In [ ]:
#Note: this cell shows how to load words and phonemes encoded in utf8 
!curl -O https://pageperso.lis-lab.fr/benoit.favre/files/lexique-phonetise.txt
#https://stackoverflow.com/questions/51294483/python-splitting-a-string-with-accented-letters
import re
pattern = re.compile(r'(\w[\u02F3\u1D53\u0300\u2013\u032E\u208D\u203F\u0311\u0323\u035E\u031C\u02FC\u030C\u02F9\u0328\u032D:\u02F4\u032F\u0330\u035C\u0302\u0327\u03572\u0308\u0351\u0304\u02F2\u0352\u0355\u00B7\u032C\u030B\u2019\u0339\u00B4\u0301\u02F1\u0303\u0306\u030A7\u0325\u0307\u0354`\u02F0]+|\w|\W)', re.UNICODE | re.IGNORECASE)
import unicodedata
import random

words = []
phonemes = []

with open("lexique-phonetise.txt") as fp:
  lines = fp.readlines()
  random.shuffle(lines)
  for line in lines:
      line = unicodedata.normalize('NFC', line)
      word, phones = line.strip().split(' = ')
      words.append(list(word))
      phonemes.append(list(pattern.findall(phones)))

for i in range(20):
  print(words[i], phonemes[i])
len(words)
In [ ]:
 

Pour aller plus loin

  • Ce modèle est la base d'un système de traduction comme openNMT (http://opennmt.net/). On peut entraîner un petit système de traduction sur les mots et obtenir des bons résultats. Il est recommandé d'appliquer la méthode du "byte pair encoding" pour créer des symboles pour les facteurs de mots fréquents et limiter ainsi le nombre de symboles différents à prédire.
  • On peut implémenter plusieurs têtes d'attention avec des paramètres différents qui agissent en parallèle et peuvent se focaliser sur des états cachés localisés à différents endroits des entrées.
In [ ]: