Implementing Google's famous BERT transformer architecture. BERT has been one of the most popular NLP models of the past few years: it is a stack of transformer encoder layers trained bidirectionally by masking out tokens at training time and asking the model to predict them.
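To make the masking idea concrete, here is a minimal sketch of how a masked-language-modelling batch can be built: a random subset of token ids is replaced by a mask id, and only those positions contribute to the loss. The mask_tokens helper, the mask_id, mask_prob and ignore_index values are illustrative assumptions, not part of the implementation below (BERT's actual recipe also sometimes keeps or randomly replaces the chosen tokens instead of always masking them).

import torch

def mask_tokens(xb, mask_id, mask_prob=0.15, ignore_index=-100):
    "Replace a random subset of token ids with mask_id; labels keep the original ids only at masked positions."
    labels = xb.clone()
    masked = torch.rand(xb.shape) < mask_prob    # choose roughly mask_prob of the positions
    labels[~masked] = ignore_index               # unmasked positions are ignored by the loss
    corrupted = xb.clone()
    corrupted[masked] = mask_id                  # hide the original token from the model
    return corrupted, labels

# e.g. a fake batch of token ids in [0, 9), with id 9 reserved as the mask token
inputs, labels = mask_tokens(torch.randint(0, 9, (16, 10)), mask_id=9)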
import math
import torch
from torch import nn

class TransformerLayer(nn.Module):
    def __init__(self, num_heads, emb_size, key_size, Activation=nn.ReLU, drop_p=0.5):
        "Transformer layer: (multi-headed) self-attention followed by a position-wise feed-forward network"
        super().__init__()
        self.num_heads = num_heads
        self.emb_size = emb_size
        self.key_size = key_size
        self.out_size = num_heads * key_size
        self.query = nn.Linear(emb_size, self.out_size)
        self.key = nn.Linear(emb_size, self.out_size)
        self.value = nn.Linear(emb_size, self.out_size)
        self.softmax = nn.Softmax(dim=-1)  # normalize attention weights over the key positions
        self.combine_heads = nn.Linear(self.out_size, emb_size)
        self.layer_norm = nn.LayerNorm(emb_size)
        self.layer_norm_final = nn.LayerNorm(emb_size)
        self.fcn = nn.Sequential(nn.Linear(emb_size, emb_size), Activation(), nn.Linear(emb_size, emb_size))
        self.drop = nn.Dropout(drop_p)

    def forward(self, xb):
        bs, seq_len, _ = xb.shape
        # project, then split the last dimension into heads: (batch, num_heads, seq_len, key_size)
        query = self.query(xb).view(bs, seq_len, self.num_heads, self.key_size).transpose(1, 2)
        key = self.key(xb).view(bs, seq_len, self.num_heads, self.key_size).transpose(1, 2)
        value = self.value(xb).view(bs, seq_len, self.num_heads, self.key_size).transpose(1, 2)
        # scaled dot-product attention, computed independently for each head
        attention = query @ key.transpose(-1, -2)
        scaled_attention = attention / math.sqrt(self.key_size)
        normalized_attention = self.softmax(scaled_attention)
        values = normalized_attention @ value
        # merge the heads back into a single vector per position
        values = values.transpose(1, 2).reshape(bs, seq_len, self.out_size)
        # residual connection and layer norm around the attention sub-layer
        final_attention = self.layer_norm(xb + self.drop(self.combine_heads(values)))
        # feed-forward sub-layer with its own residual connection and layer norm
        fcn_output = self.fcn(final_attention)
        skip = self.drop(fcn_output) + final_attention
        return self.layer_norm_final(skip)
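A quick sanity check of the layer on a random batch (the sizes below are arbitrary); the layer maps (batch, seq_len, emb_size) back to the same shape, so layers can be stacked:

layer = TransformerLayer(num_heads=2, emb_size=16, key_size=8)
dummy = torch.randn(4, 12, 16)   # (batch, seq_len, emb_size)
layer(dummy).shape               # torch.Size([4, 12, 16])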
class BERTEmbeddings(nn.Module):
    def __init__(self, depth, max_seq_size, vocab_size):
        "Custom BERT embeddings: the sum of token, positional and segment embeddings (segments distinguish the two sentences in sentence-pair tasks)"
        super().__init__()
        self.pos_embedding = nn.Embedding(max_seq_size, depth)
        self.token_embedding = nn.Embedding(vocab_size, depth)
        self.segment_embedding = nn.Embedding(2, depth)

    def forward(self, xb, token_types=None):
        # default to a single segment (all zeros) when no token types are given
        token_types = torch.zeros_like(xb) if token_types is None else token_types
        # position ids 0..seq_len-1, broadcast across the batch
        positions = torch.arange(xb.shape[1], device=xb.device)
        return self.token_embedding(xb) + self.pos_embedding(positions) + self.segment_embedding(token_types)
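And a quick check that the embedding layer handles a sentence-pair input (sizes are arbitrary; segment ids must be 0 or 1):

emb = BERTEmbeddings(depth=16, max_seq_size=12, vocab_size=100)
ids = torch.randint(0, 100, (4, 12))
segments = torch.cat([torch.zeros(4, 6), torch.ones(4, 6)], dim=1).long()   # first / second sentence
emb(ids, segments).shape   # torch.Size([4, 12, 16])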
class BERT(nn.Module):
    def __init__(self, num_layers, num_heads, emb_size, key_size, max_seq_size, vocab_size, **kwargs):
        "BERT in all its glory: an embedding layer followed by a stack of transformer encoder layers"
        super().__init__()
        self.embedding = BERTEmbeddings(emb_size, max_seq_size, vocab_size)
        transformer_layers = [TransformerLayer(num_heads, emb_size, key_size, **kwargs) for _ in range(num_layers)]
        self.encoder = nn.Sequential(*transformer_layers)

    def forward(self, xb, token_types=None):
        embeddings = self.embedding(xb, token_types)
        output = self.encoder(embeddings)
        return output
# tiny sanity check: 3 layers, 3 heads, embedding size 10, key size 5, max sequence length 10, vocab size 10
bert = BERT(3, 3, 10, 5, 10, 10)
bert(torch.zeros(16, 10).long()).shape   # torch.Size([16, 10, 10]) — (batch, seq_len, emb_size)
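To connect the encoder back to the masked-token objective from the intro, a masked-LM head can be sketched as a linear projection from the encoder output to vocabulary logits. This is a simplified, illustrative head, reusing the mask_tokens helper sketched at the top of this section (BERT's real head adds a dense + GELU + LayerNorm transform and ties its weights to the token embeddings):

mlm_head = nn.Linear(10, 10)    # emb_size -> vocab_size for the tiny model above
inputs, labels = mask_tokens(torch.randint(0, 9, (16, 10)), mask_id=9)
logits = mlm_head(bert(inputs))    # (batch, seq_len, vocab_size)
loss = nn.functional.cross_entropy(logits.view(-1, 10), labels.view(-1), ignore_index=-100)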
!python notebook2script.py BERT.ipynb