Created April 2, 2026 14:57
# from https://medium.com/@NoamShazeer/shape-suffixes-good-coding-style-f836e72e24fd
"""Example Transformer code with shape suffixes.

This code is incomplete and possibly has bugs. Don't try to run it.
Its purpose is to illustrate shape suffixes.

Dimension key:
  B: batch size
  L: sequence length
  M: memory length (length of sequence being attended to)
  D: model dimension (sometimes called d_model or embedding_dim)
  V: vocabulary size
  F: feed-forward subnetwork hidden size
  H: number of attention heads in a layer
  K: size of each attention key or value (sometimes called d_kv)
"""
import torch


def transformer(input_token_id_BL, params):
    # Embed tokens, run the residual stack, then project back onto the vocabulary.
    hidden_BLD = params.embedding_VD[input_token_id_BL]
    for layer_num in range(params.num_layers):
        hidden_BLD += attention(hidden_BLD, params.attention_params[layer_num])
        hidden_BLD += ffn(hidden_BLD, params.ffn_params[layer_num])
    hidden_BLD = layer_norm(hidden_BLD, params.final_layernorm_params)
    logits_BLV = torch.matmul(hidden_BLD, params.embedding_VD.T)
    return logits_BLV


def ffn(input_BLD, params):
    # Pre-norm feed-forward block: D -> F -> D with a GELU nonlinearity.
    input_BLD = layer_norm(input_BLD, params.layernorm_params)
    hidden_BLF = torch.nn.functional.gelu(torch.matmul(input_BLD, params.w_in_DF))
    output_BLD = torch.matmul(hidden_BLF, params.w_out_FD)
    return output_BLD


def attention(input_BLD, params):
    # Pre-norm multi-head self-attention with a causal mask (here M == L,
    # since the sequence attends to itself).
    input_BLD = layer_norm(input_BLD, params.layernorm_params)
    query_BLHK = torch.einsum('BLD,DHK->BLHK', input_BLD, params.w_q_DHK)
    key_BMHK = torch.einsum('BMD,DHK->BMHK', input_BLD, params.w_k_DHK)
    value_BMHK = torch.einsum('BMD,DHK->BMHK', input_BLD, params.w_v_DHK)
    logits_BHLM = torch.einsum('BLHK,BMHK->BHLM', query_BLHK, key_BMHK)
    B, L, H, K = query_BLHK.shape
    logits_BHLM /= K ** 0.5
    masked_out_LM = torch.arange(L).unsqueeze(1) < torch.arange(L).unsqueeze(0)
    logits_BHLM += torch.where(masked_out_LM, float('-inf'), 0.0)
    weights_BHLM = torch.softmax(logits_BHLM, dim=-1)
    wtd_values_BLHK = torch.einsum('BMHK,BHLM->BLHK', value_BMHK, weights_BHLM)
    out_BLD = torch.einsum('BLHK,HKD->BLD', wtd_values_BLHK, params.w_o_HKD)
    return out_BLD