This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ast | |
import math | |
import random | |
from infini_gram.engine import InfiniGramEngine | |
from transformers import AutoTokenizer | |
def compute_longest_prefix(query, doc): | |
"""helper function for computing longest prefix of query that exists | |
within a document""" |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import json | |
from collections import Counter | |
import tempfile | |
from transformers import AutoTokenizer | |
# load tokenizer / data | |
enc = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", add_bos_token=False, add_eos_token=False) | |
data_rows = [{'text': 'here is some training data'}, ...] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math | |
import torch | |
from torch import nn | |
import torch.nn.functional as F | |
class CrossAttention(nn.Module): | |
def __init__(self, d): | |
""" | |
Arguments: |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math | |
import torch | |
from torch import nn | |
import torch.nn.functional as F | |
class SelfAttention(nn.Module): | |
def __init__(self, d): | |
""" | |
Arguments: |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch import nn | |
class MoEBlock(nn.Module): | |
def __init__( | |
self, | |
d, | |
H, | |
C, | |
n_exp, |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Based upon ColossalAI OpenMoE | |
""" | |
from torch import nn | |
class MOELayer(nn.Module): | |
def __init__( | |
self, | |
d, |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Computes ST-MoE router z loss (https://arxiv.org/abs/2202.08906) | |
See equation (5) on page 7 | |
""" | |
import torch | |
# constants | |
B = 16 # batch size | |
C = 256 # sequence length |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Computes Switch Transformer auxiliary loss (https://arxiv.org/abs/2101.03961) | |
See equations (4)-(6) on page 7 | |
""" | |
import torch | |
import torch.nn.functional as F | |
# constants | |
B = 16 # batch size |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math | |
import torch | |
from torch import nn | |
from torch.nn import functional as F | |
class Router(nn.Module): | |
def __init__( | |
self, | |
d, |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch import nn | |
from torch.nn import functional as F | |
class BasicSoftmaxRouter(nn.Module): | |
def __init__( | |
self, | |
d, | |
n_exp = 8, | |
top_k = 2, |
NewerOlder