Transformer in ~80 lines of code from Thomas Wolf's tweet https://twitter.com/Thom_Wolf/status/1129658539142766592
""" | |
Transformer in ~80 lines of code. | |
From Thomas Wolf's tweet https://twitter.com/Thom_Wolf/status/1129658539142766592. | |
""" | |
import torch | |
from torch import nn | |
class Transformer(nn.Module):
    """
    Transformer (GPT-2 architecture).

    Args:
        embed_dim: Dimensionality of the embeddings.
        hidden_dim: Dimensionality of the hidden states.
        num_embed: Vocabulary size of `x`.
        num_pos: Number of positional embeddings.
        num_heads: Number of attention heads for each attention layer in the Transformer encoder.
        num_layers: Number of hidden layers in the Transformer encoder.
        dropout: The dropout probability for all layers.
    """
    def __init__(self, embed_dim=768, hidden_dim=768, num_embed=4992, num_pos=768, num_heads=6, num_layers=6, dropout=0.1):
        super().__init__()
        self.token_embeddings = nn.Embedding(num_embed, embed_dim)
        self.position_embeddings = nn.Embedding(num_pos, embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.attentions, self.feed_forwards = nn.ModuleList(), nn.ModuleList()
        self.ln_1, self.ln_2 = nn.ModuleList(), nn.ModuleList()
        for _ in range(num_layers):
            self.attentions.append(nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout))
            self.feed_forwards.append(nn.Sequential(nn.Linear(embed_dim, hidden_dim),
                                                    nn.ReLU(),
                                                    nn.Linear(hidden_dim, embed_dim)))
            self.ln_1.append(nn.LayerNorm(embed_dim, eps=1e-12))
            self.ln_2.append(nn.LayerNorm(embed_dim, eps=1e-12))
        # The final projection maps embed_dim-sized hidden states back to the vocabulary.
        self.head = nn.Linear(embed_dim, num_embed)
    def forward(self, x):
        # x: (seq_len, batch_size) token ids; nn.MultiheadAttention defaults to sequence-first inputs.
        positions = torch.arange(len(x), device=x.device).unsqueeze(-1)
        h = self.token_embeddings(x)
        h = h + self.position_embeddings(positions).expand_as(h)
        h = self.dropout(h)
        attn_mask = torch.full((len(x), len(x)), -float('Inf'), device=h.device, dtype=h.dtype)
        attn_mask = torch.triu(attn_mask, diagonal=1)
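        # The two lines above build GPT-2's causal mask: an additive mask with -inf strictly
        # above the diagonal, so position i can only attend to positions <= i. For a 3-token
        # sequence, for example:
        #   [[0., -inf, -inf],
        #    [0.,   0., -inf],
        #    [0.,   0.,   0.]]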
        for layer_norm_1, attention, layer_norm_2, feed_forward in zip(self.ln_1, self.attentions,
                                                                       self.ln_2, self.feed_forwards):
            # Self-attention sub-block (pre-layer-norm, residual connection).
            h = layer_norm_1(h)
            x, _ = attention(h, h, h, attn_mask=attn_mask, need_weights=False)
            x = self.dropout(x)
            h = x + h
            # Feed-forward sub-block (pre-layer-norm, residual connection).
            h = layer_norm_2(h)
            x = feed_forward(h)
            x = self.dropout(x)
            h = x + h
        return self.head(h)

# test
transformer = Transformer()
seq_len = 32
batch_size = 4
# nn.MultiheadAttention defaults to sequence-first, so the tokens are laid out (seq_len, batch_size).
x = torch.randint(low=0, high=transformer.token_embeddings.num_embeddings, size=(seq_len, batch_size))
y = transformer(x)
y.shape  # torch.Size([32, 4, 4992])
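
As a rough sketch of how this module would actually be trained, here is one next-token language-modeling step. The random batch, the Adam learning rate, and the shift-by-one targets are illustrative assumptions, not part of the original gist; it reuses `transformer`, `seq_len`, and `batch_size` from the test above.

# A minimal training-step sketch (illustrative only): predict token t+1 from tokens <= t.
import torch.nn.functional as F

optimizer = torch.optim.Adam(transformer.parameters(), lr=2.5e-4)  # assumed hyperparameters

tokens = torch.randint(0, transformer.token_embeddings.num_embeddings, size=(seq_len + 1, batch_size))
inputs, targets = tokens[:-1], tokens[1:]             # shift by one for next-token targets

logits = transformer(inputs)                          # (seq_len, batch_size, num_embed)
loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
loss.backward()
optimizer.step()
optimizer.zero_grad()

The shift by one token is what makes the causal mask meaningful: each position's logits are scored against the token that follows it, using only the tokens it was allowed to attend to.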