BERT in Rust
use std::borrow::Borrow;

use tch::nn::{self, ModuleT, VarStore};
use tch::{Device, Kind, Tensor};
#[derive(Debug)]
pub struct Dropout {
    dropout_prob: f64,
}

impl Dropout {
    pub fn new(p: f64) -> Dropout {
        Dropout { dropout_prob: p }
    }
}

impl ModuleT for Dropout {
    fn forward_t(&self, input: &Tensor, train: bool) -> Tensor {
        input.dropout(self.dropout_prob, train)
    }
}
/// BERT model configuration. The `Default` values below match bert-base-uncased.
pub struct Config {
    pub vocab_size: i64,
    pub hidden_size: i64,
    pub num_hidden_layers: i64,
    pub num_attention_heads: i64,
    pub intermediate_size: i64,
    pub hidden_act: String,
    pub hidden_dropout_prob: f64,
    pub attention_probs_dropout_prob: f64,
    pub max_position_embeddings: i64,
    pub type_vocab_size: i64,
    pub initializer_range: f64,
    pub layer_norm_eps: f64,
}

/// Sum of word, position and token-type embeddings, followed by LayerNorm and dropout.
pub struct BertEmbeddings {
    word_embeddings: nn::Embedding,
    position_embeddings: nn::Embedding,
    token_type_embeddings: nn::Embedding,
    layer_norm: nn::LayerNorm,
    dropout: Dropout,
}
impl Default for Config {
    fn default() -> Self {
        Config {
            vocab_size: 30522,
            hidden_size: 768,
            num_hidden_layers: 12,
            num_attention_heads: 12,
            intermediate_size: 3072,
            hidden_act: "gelu".to_string(),
            hidden_dropout_prob: 0.1,
            attention_probs_dropout_prob: 0.1,
            max_position_embeddings: 512,
            type_vocab_size: 2,
            initializer_range: 0.02,
            layer_norm_eps: 1e-12,
        }
    }
}
impl Config {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        vocab_size: i64,
        hidden_size: i64,
        num_hidden_layers: i64,
        num_attention_heads: i64,
        intermediate_size: i64,
        hidden_act: String,
        hidden_dropout_prob: f64,
        attention_probs_dropout_prob: f64,
        max_position_embeddings: i64,
        type_vocab_size: i64,
        initializer_range: f64,
        layer_norm_eps: f64,
    ) -> Self {
        Config {
            vocab_size,
            hidden_size,
            num_hidden_layers,
            num_attention_heads,
            intermediate_size,
            hidden_act,
            hidden_dropout_prob,
            attention_probs_dropout_prob,
            max_position_embeddings,
            type_vocab_size,
            initializer_range,
            layer_norm_eps,
        }
    }
}
impl BertEmbeddings {
    pub fn new<'p, P>(p: P, config: &Config) -> BertEmbeddings
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow();
        let word_embeddings = nn::embedding(
            p / "word_embeddings",
            config.vocab_size,
            config.hidden_size,
            Default::default(),
        );
        let position_embeddings = nn::embedding(
            p / "position_embeddings",
            config.max_position_embeddings,
            config.hidden_size,
            Default::default(),
        );
        let token_type_embeddings = nn::embedding(
            p / "token_type_embeddings",
            config.type_vocab_size,
            config.hidden_size,
            Default::default(),
        );
        let layer_norm_config = nn::LayerNormConfig {
            eps: config.layer_norm_eps,
            ..Default::default()
        };
        let layer_norm =
            nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
        let dropout = Dropout::new(config.hidden_dropout_prob);
        Self {
            word_embeddings,
            position_embeddings,
            token_type_embeddings,
            layer_norm,
            dropout,
        }
    }
    pub fn forward_t(
        &self,
        input_ids: &Tensor,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        train: bool,
    ) -> Result<Tensor, &'static str> {
        let input_shape = input_ids.size();
        let seq_length = input_shape[1];
        let device = input_ids.device();
        let input_ids = input_ids.view((-1, seq_length));
        // Default position ids are 0..seq_length, expanded over the batch dimension.
        let position_ids = match position_ids {
            Some(position_ids) => position_ids.view((-1, seq_length)),
            None => Tensor::arange(seq_length, (Kind::Int64, device))
                .unsqueeze(0)
                .expand(&input_shape, true),
        };
        // Default token type ids are all zeros, i.e. a single segment.
        let token_type_ids = match token_type_ids {
            Some(token_type_ids) => token_type_ids.view((-1, seq_length)),
            None => Tensor::zeros(&input_shape, (Kind::Int64, device)),
        };
        let word_embeddings = input_ids.apply(&self.word_embeddings);
        let position_embeddings = position_ids.apply(&self.position_embeddings);
        let token_type_embeddings = token_type_ids.apply(&self.token_type_embeddings);
        let embeddings = word_embeddings + position_embeddings + token_type_embeddings;
        Ok(embeddings
            .apply(&self.layer_norm)
            .apply_t(&self.dropout, train))
    }
}
// Smoke test: for a batch of one sequence of length 10, the embedding output
// should have shape [1, 10, hidden_size].
fn test_bert_embeddings() {
    let config = Config::new(
        30522,
        768,
        12,
        12,
        3072,
        "gelu".to_string(),
        0.1,
        0.1,
        512,
        2,
        0.02,
        1e-12,
    );
    let vs = VarStore::new(Device::Cpu);
    let root = vs.root();
    let embeddings = BertEmbeddings::new(&root, &config);
    let input_ids = Tensor::of_slice(&[0_i64, 1, 2, 3, 4, 5, 6, 7, 8, 9]).unsqueeze(0);
    let token_type_ids = Tensor::of_slice(&[0_i64; 10]).unsqueeze(0);
    let position_ids = Tensor::of_slice(&[0_i64, 1, 2, 3, 4, 0, 1, 2, 3, 4]).unsqueeze(0);
    let output = embeddings.forward_t(
        &input_ids,
        Some(&token_type_ids),
        Some(&position_ids),
        false,
    );
    let expected_shape = vec![1, 10, 768];
    assert_eq!(output.unwrap().size(), expected_shape.as_slice());
}
fn main() {
    test_bert_embeddings();
}
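
The sketch below is not part of the original file; it shows one way this embeddings module could be used with pretrained weights. It assumes a weights file already converted to libtorch's .ot format whose tensor names match the paths registered in BertEmbeddings::new ("word_embeddings", "position_embeddings", "token_type_embeddings", "LayerNorm"); the file name "bert_embeddings.ot" and the helper name load_pretrained_embeddings are made up for illustration.

// Usage sketch (assumption: a pre-converted .ot weights file with matching tensor names).
#[allow(dead_code)]
fn load_pretrained_embeddings() -> Result<Tensor, Box<dyn std::error::Error>> {
    let mut vs = VarStore::new(Device::cuda_if_available());
    let config = Config::default();
    let embeddings = BertEmbeddings::new(vs.root(), &config);
    // Fill the variables created above with the tensors stored in the file.
    vs.load("bert_embeddings.ot")?;
    let input_ids = Tensor::of_slice(&[101_i64, 2023, 2003, 102])
        .unsqueeze(0)
        .to(vs.device());
    // No explicit position or token-type ids: forward_t builds the defaults.
    Ok(embeddings.forward_t(&input_ids, None, None, false)?)
}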