Created
October 1, 2019 23:40
-
-
Save leungi/d927a9bb22c0cb5747f55734b129d4fb to your computer and use it in GitHub Desktop.
RBERT demo application
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Install RBERT from https://github.com/jonathanbratt/RBERT
library(RBERT)

# |- Python ----
# Bind reticulate to the conda env that carries TensorFlow 1.11, then echo
# the resolved Python configuration as a sanity check.
reticulate::use_condaenv("r-tensorflow-1.11")
reticulate::py_config()
# |- model ----
# Local directory holding the downloaded BERT-base (uncased) checkpoint.
# FIX: dropped the trailing slash on "output_data" -- with it, file.path()
# builds "output_data//BERT_checkpoints/..." (doubled separator).
BERT_PRETRAINED_DIR <- file.path(
  "output_data",
  "BERT_checkpoints",
  "uncased_L-12_H-768_A-12"
)

# Download the pre-trained model on first run only.
if (!dir.exists(BERT_PRETRAINED_DIR)) {
  RBERT::download_BERT_checkpoint(
    model = "bert_base_uncased",
    destination = "output_data"
  )
}

# Inputs RBERT needs: wordpiece vocabulary, TF checkpoint, model config.
vocab_file <- file.path(BERT_PRETRAINED_DIR, "vocab.txt")
init_checkpoint <- file.path(BERT_PRETRAINED_DIR, "bert_model.ckpt")
bert_config_file <- file.path(BERT_PRETRAINED_DIR, "bert_config.json")
# |- analyze ----
# One sentence containing three different senses of the word "bank".
text_to_process <- "After stealing money from the bank vault, the bank robber was seen fishing on the Mississippi river bank."

# Wrap the raw text into RBERT example objects, then pull contextual
# embeddings from every transformer layer (1-12).
bert_examples <- RBERT::make_examples_simple(text_to_process)
BERT_feats <- RBERT::extract_features(
  examples = bert_examples,
  vocab_file = vocab_file,
  bert_config_file = bert_config_file,
  init_checkpoint = init_checkpoint,
  layer_indexes = as.list(1:12),
  batch_size = 2L
)
# |- analyze ----
library(tidyverse)

# Shorthand handles into the nested output structure.
first_example <- BERT_feats$layer_outputs$example_1
first_token_layers <- first_example$features$token_1$layers

# Number of examples (sentences) processed.
num_batch <- length(BERT_feats$layer_outputs)
# Number of wordpiece tokens in the first example.
num_token <- length(first_example$features)
# Transformer layers per token. The first entry ("layer_output_0") is the
# bare token embedding rather than a transformer layer, hence -1L.
# https://github.com/jonathanbratt/RBERT/issues/6
num_layer <- length(first_token_layers) - 1L
# Width of each embedding vector (768 for BERT base).
num_hidden_unit <- length(first_token_layers$layer_output_0$values)
berts <- tibble(layer = BERT_feats$layer_outputs)

# Rectangle the nested list into one row per (token, layer), with each
# embedding vector kept in the `values` list-column; then turn the
# "token_N" / "layer_output_N" ids into plain integers.
my_layers <- berts %>%
  hoist(layer, feat = "features") %>%
  select(-layer) %>%
  unnest_longer(feat) %>%
  unnest_wider(feat) %>%
  unnest_longer(layers) %>%
  hoist(layers, values = "values") %>%
  mutate(
    feat_id = as.integer(str_extract(feat_id, "\\d+")),
    layers_id = as.integer(str_extract(layers_id, "\\d+"))
  )
# http://mccormickml.com/2019/05/14/BERT-word-embeddings-tutorial/
# Two per-token word vectors built from the last four transformer layers:
# one by concatenating them (4 x 768 values) and one by summing them
# elementwise (768 values).
word_embed <- my_layers %>%
  filter(layers_id %in% 9:12) %>%
  group_by(feat_id, token) %>%
  summarise(
    concatenated_last_4_layers = list(flatten_dbl(values)),
    summed_last_4_layers = list(reduce(values, `+`))
  )
# Sentence embedding: mean-pool the token vectors of the second-to-last
# transformer layer, per the McCormick tutorial.
# FIX: the original divided the token-summed vector by num_hidden_unit
# (the embedding width, 768). A mean over tokens must divide by the number
# of tokens, i.e. the row count n() after filtering to one layer. The
# no-op `%>% {.}` stage is dropped as well.
sent_embed_2nd_last_layer <- my_layers %>%
  filter(layers_id == num_layer - 1L) %>%
  summarise(emb_mean = list(reduce(values, `+`) / n()))
# Show the first 15 entries of a token's summed-last-4-layers vector,
# selected by its feat_id (token position).
peek_summed_embedding <- function(id) {
  row <- filter(word_embed, feat_id == id)
  head(row$summed_last_4_layers[[1]], 15)
}

print("First fifteen values of 'bank' as in 'bank robber':")
peek_summed_embedding(11)
print("First fifteen values of 'bank' as in 'bank vault':")
peek_summed_embedding(7)
print("First fifteen values of 'bank' as in 'river bank':")
peek_summed_embedding(20)
# Cosine similarity between two numeric vectors a and b.
#
# FIX: the original's last expression was an assignment
# (`result <- ...`), so the value was returned invisibly and never used;
# it also returned crossprod's 1x1 matrix. Return a visible plain numeric
# scalar instead (backward compatible with the glue() interpolation below).
cosine_sim <- function(a, b) {
  dot_ab <- as.numeric(crossprod(a, b))
  norm_prod <- sqrt(as.numeric(crossprod(a)) * as.numeric(crossprod(b)))
  dot_ab / norm_prod
}
# Fetch a token's summed-last-4-layers vector by feat_id.
# FIX: the original indexed word_embed by row position ([[11]], [[7]],
# [[20]]), which only matches feat_id because summarise() happens to sort
# by it; look up by feat_id explicitly, consistent with the filter-based
# access used in the printing section above.
summed_for_token <- function(id) {
  word_embed$summed_last_4_layers[word_embed$feat_id == id][[1]]
}

# feat_id 11 = "bank" (robber), 7 = "bank" (vault), 20 = "bank" (river).
same_bank <- cosine_sim(summed_for_token(11), summed_for_token(7))
different_bank <- cosine_sim(summed_for_token(11), summed_for_token(20))

glue::glue("Similarity of 'bank' as in 'bank robber' to 'bank' as in 'bank vault': {same_bank}")
# Similarity of 'bank' as in 'bank robber' to 'bank' as in 'bank vault': 0.938482861219369
glue::glue("Similarity of 'bank' as in 'bank robber' to 'bank' as in 'river bank': {different_bank}")
# Similarity of 'bank' as in 'bank robber' to 'bank' as in 'river bank': 0.69301255768969
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment