Implementing Google's BERT architecture. One of the most popular NLP models of the past few years, it stacks transformer encoder layers and trains bidirectionally by masking out a fraction of the input tokens at train time and learning to predict them.
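As a rough illustration of that masking step (not part of the model code below, and simplified relative to BERT's actual 80/10/10 corruption scheme; the mask_id argument is a stand-in for the tokenizer's [MASK] id), something like:

import torch

def mask_tokens(token_ids, mask_id, mask_prob=0.15):
    "Randomly replace ~mask_prob of the tokens with mask_id; return the corrupted ids and the boolean mask."
    mask = torch.rand(token_ids.shape) < mask_prob
    corrupted = token_ids.clone()
    corrupted[mask] = mask_id
    return corrupted, mask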
import math

import torch
from torch import nn


class TransformerLayer(nn.Module):
    def __init__(self, num_heads, emb_size, key_size, Activation=nn.ReLU, drop_p=0.5):
        "Transformer layer: (multi-headed) self-attention followed by a position-wise feed-forward block"
        super().__init__()
        self.emb_size = emb_size
        self.key_size = key_size
        self.out_size = num_heads * key_size

        # Query, key and value projections for all heads at once.
        self.query = nn.Linear(emb_size, self.out_size)
        self.key = nn.Linear(emb_size, self.out_size)
        self.value = nn.Linear(emb_size, self.out_size)

        # Softmax over the key dimension of the attention scores.
        self.softmax = nn.Softmax(dim=-1)

        self.combine_heads = nn.Linear(self.out_size, emb_size)

        self.layer_norm = nn.LayerNorm(emb_size)
        self.layer_norm_final = nn.LayerNorm(emb_size)

        self.fcn = nn.Sequential(nn.Linear(emb_size, emb_size), Activation(), nn.Linear(emb_size, emb_size))

        self.drop = nn.Dropout(drop_p)

    def forward(self, xb):
        # xb: (batch, seq_len, emb_size)
        query = self.query(xb)
        key = self.key(xb)
        value = self.value(xb)

        # Scaled dot-product attention. Note that the heads share one big projection here
        # rather than being split into num_heads separate attention maps -- a simplification.
        attention = query @ key.transpose(-1, -2)
        scaled_attention = attention / math.sqrt(self.key_size)
        normalized_attention = self.softmax(scaled_attention)
        values = normalized_attention @ value

        # Residual connection + layer norm around the attention sub-layer.
        final_attention = self.layer_norm(self.drop(self.combine_heads(values)) + xb)

        # Position-wise feed-forward network with its own residual + layer norm.
        fcn_output = self.fcn(final_attention)
        skip = self.drop(fcn_output) + final_attention

        return self.layer_norm_final(skip)
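A quick shape check, in the same spirit as the one at the end of the notebook: the layer should map a batch of embeddings back to the same shape.

layer = TransformerLayer(num_heads=3, emb_size=10, key_size=5)
layer(torch.randn(16, 20, 10)).shape
torch.Size([16, 20, 10])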
class BERTEmbeddings(nn.Module):
    def __init__(self, depth, max_seq_size, vocab_size):
        "Custom BERT embeddings: the sum of positional, token and segment embeddings (segments distinguish the two sentences in sentence-pair tasks)"
        super().__init__()
        self.pos_embedding = nn.Embedding(max_seq_size, depth)
        self.token_embedding = nn.Embedding(vocab_size, depth)
        self.segment_embedding = nn.Embedding(2, depth)

    def forward(self, xb, token_types=None):
        # xb: (batch, seq_len) token ids; token_types: (batch, seq_len) segment ids (0 or 1)
        token_types = torch.zeros_like(xb) if token_types is None else token_types
        positions = torch.arange(xb.shape[1], device=xb.device).expand_as(xb)
        return self.pos_embedding(positions) + self.token_embedding(xb) + self.segment_embedding(token_types)
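Another quick sanity check: embeddings for a batch of 4 sequences of length 8 should come out with shape (4, 8, depth).

emb = BERTEmbeddings(depth=10, max_seq_size=16, vocab_size=100)
emb(torch.randint(0, 100, (4, 8))).shape
torch.Size([4, 8, 10])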
class BERT(nn.Module):
    def __init__(self, num_layers, num_heads, emb_size, key_size, max_seq_size, vocab_size, **kwargs):
        "BERT in all its glory: embeddings followed by a stack of transformer layers"
        super().__init__()
        self.embedding = BERTEmbeddings(emb_size, max_seq_size, vocab_size)
        transformer_layers = [TransformerLayer(num_heads, emb_size, key_size, **kwargs) for _ in range(num_layers)]
        self.encoder = nn.Sequential(*transformer_layers)

    def forward(self, xb, token_types=None):
        embeddings = self.embedding(xb, token_types)
        output = self.encoder(embeddings)
        return output
bert = BERT(3, 3, 10, 5, 10, 10)
bert(torch.zeros(16, 10).long()).shape
torch.Size([16, 10, 10])
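To tie this back to the masked-token training mentioned at the top: a masked-language-model head (not included in this notebook, sketched here purely as an illustration) is just a linear projection from the encoder output to the vocabulary, with the loss taken only at the masked positions.

mlm_head = nn.Linear(10, 10)                   # emb_size -> vocab_size for the toy config above (hypothetical head)
ids = torch.randint(0, 10, (16, 10))
corrupted, mask = mask_tokens(ids, mask_id=0)  # mask_tokens from the sketch in the intro; 0 stands in for a [MASK] id
logits = mlm_head(bert(corrupted))
loss = nn.functional.cross_entropy(logits[mask], ids[mask])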
!python notebook2script.py BERT.ipynb