model structure
import math

import torch
import torch.nn as nn
# Sinusoidal positional encoding layer.
class PositionalEncoding(nn.Module):
    def __init__(self, embedding_dim, max_len=5000):
        super(PositionalEncoding, self).__init__()
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embedding_dim, 2) * -(math.log(10000.0) / embedding_dim))
        pe = torch.zeros(max_len, 1, embedding_dim)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        # Register as a buffer so the table follows the module across devices and is
        # stored in the state_dict without becoming a trainable parameter.
        self.register_buffer("pe", pe)

    def forward(self, x):
        # x: (seq_len, batch, embedding_dim); add the first seq_len position encodings.
        x = x + self.pe[:x.size(0), :]
        return x
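# Quick shape sketch (illustrative, not part of the original gist): the layer expects
# sequence-first input and returns the same shape, e.g. with embedding_dim=512:
#   pos = PositionalEncoding(embedding_dim=512)
#   out = pos(torch.zeros(10, 2, 512))  # (seq_len=10, batch=2, 512) -> (10, 2, 512)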
# Post-norm Transformer encoder layer: self-attention and a position-wise feed-forward
# network, each followed by a residual connection and LayerNorm.
class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, nhead):
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.ReLU(),
            nn.Linear(d_model * 4, d_model)
        )
        # A single LayerNorm instance is shared by both sub-layers.
        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(0.1)

    def forward(self, src, src_mask=None, src_key_padding_mask=None, is_causal=False):
        # `is_causal` is unused here; it is accepted only so the layer keeps working with
        # nn.TransformerEncoder, which may forward this keyword in newer PyTorch releases.
        # Self-attention sub-layer with dropout and residual connection.
        src2 = self.self_attn(src, src, src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout(src2)
        src = self.layer_norm(src)
        # Feed-forward sub-layer with dropout and residual connection.
        src2 = self.feed_forward(src)
        src = src + self.dropout(src2)
        src = self.layer_norm(src)
        return src
# Context integration model: prepends a learnable [CLS] token to the chunk embeddings,
# passes the sequence through a Transformer encoder, and returns a document embedding
# (from the [CLS] position) plus context-aware chunk embeddings.
class ContextIntegrationModel(nn.Module):
    def __init__(self, embedding_dim, num_heads, num_layers, dropout=0.1):
        super(ContextIntegrationModel, self).__init__()
        self.cls_token = nn.Parameter(torch.randn(1, 1, embedding_dim))
        self.pos_encoder = PositionalEncoding(embedding_dim)
        self.encoder_layer = TransformerEncoderLayer(embedding_dim, num_heads)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)
        self.new_embedding_generator = nn.Linear(embedding_dim, embedding_dim)
        self.layer_norm = nn.LayerNorm(embedding_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, chunks_embeddings):
        # chunks_embeddings: (batch, num_chunks, embedding_dim)
        batch_size = chunks_embeddings.size(0)
        # Prepend one [CLS] token per batch element.
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        chunks_embeddings = torch.cat((cls_tokens, chunks_embeddings), dim=1)
        # Switch to (seq_len, batch, embedding_dim) and add positional encodings.
        chunks_embeddings = self.pos_encoder(chunks_embeddings.permute(1, 0, 2))
        transformed_embeddings = self.transformer_encoder(chunks_embeddings)
        # Back to (batch, seq_len, embedding_dim), then residual + LayerNorm + dropout.
        transformed_embeddings = transformed_embeddings.permute(1, 0, 2)
        transformed_embeddings = self.layer_norm(transformed_embeddings + chunks_embeddings.permute(1, 0, 2))
        transformed_embeddings = self.dropout(transformed_embeddings)
        # [CLS] position -> document embedding; remaining positions -> new chunk embeddings.
        doc_embedding = transformed_embeddings[:, 0]
        new_embeddings = self.new_embedding_generator(transformed_embeddings[:, 1:])
        # L2-normalize both outputs.
        doc_embedding = torch.nn.functional.normalize(doc_embedding, p=2, dim=-1)
        new_embeddings = torch.nn.functional.normalize(new_embeddings, p=2, dim=-1)
        return doc_embedding, new_embeddings
# Top-level model: a thin wrapper around ContextIntegrationModel.
class TransformerModel(nn.Module):
    def __init__(self, embedding_dim, num_heads=8, num_layers=1):
        super(TransformerModel, self).__init__()
        self.context_integration = ContextIntegrationModel(embedding_dim, num_heads, num_layers)

    def forward(self, chunks_embedding):
        # Returns (doc_embedding, new_chunk_embeddings).
        transformed_embeddings = self.context_integration(chunks_embedding)
        return transformed_embeddings
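A minimal usage sketch (not part of the original gist; the batch size, number of chunks, and embedding dimension are illustrative) showing the expected input shape and the two returned tensors:

if __name__ == "__main__":
    # Hypothetical sizes: 4 documents, 12 chunks each, 768-dim chunk embeddings.
    model = TransformerModel(embedding_dim=768, num_heads=8, num_layers=1)
    chunks = torch.randn(4, 12, 768)       # (batch, num_chunks, embedding_dim)
    doc_embedding, new_embeddings = model(chunks)
    print(doc_embedding.shape)             # torch.Size([4, 768])
    print(new_embeddings.shape)            # torch.Size([4, 12, 768])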