@t-kalinowski
Last active October 15, 2024 23:53
LLaMA implemented in R Tensorflow and Keras
## Setup
Sys.setenv(CUDA_VISIBLE_DEVICES='')
options(tensorflow.extract.warn_tensors_passed_asis = FALSE)
library(dplyr, warn.conflicts = FALSE)
library(purrr)
library(glue)
library(envir)
library(tensorflow)
library(tfautograph)
library(keras)
reticulate::use_virtualenv("./.venv", required = TRUE)
attach_eval({
  np <- reticulate::import("numpy", convert = FALSE)
  import_from(withr, with_options, local_options)
  import_from(keras$layers, Dense)
  import_from(tf$compiler$tf2xla$python$xla, dynamic_update_slice)
  nlist <- \(...) rlang::dots_list(..., .named = TRUE)
  seq_len0 <- \(x) seq.int(from = 0L, length.out = x)
})
precompute_rotary_freqs <- function(seqlen, feature_dim, theta = 10000) {
  repeat_each_twice <- function(x)
    tf$`repeat`(x, 2L, axis = -1L)

  t <- tf$range(seqlen, dtype = tf$float32)
  freqs <- tf$range(start = 0, limit = 1,
                    delta = 1 / (feature_dim %/% 2),
                    dtype = tf$float32)
  tf_assert(tf$size(freqs) == feature_dim %/% 2)
  freqs <- 1 / (theta ^ freqs)

  # outer product; (seqlen, head_size/2)
  freqs <- tf$einsum('a,b->ab', t, freqs)

  # prep to recycle across the head_size axis and
  # broadcast across the batch_size and n_heads axes
  list(cos = tf$cos(freqs),
       sin = tf$sin(freqs)) |>
    lapply(repeat_each_twice) |>
    lapply(\(m) m[tf$newaxis, , tf$newaxis, ]) # (1, seqlen, 1, head_size)
}
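
## Optional shape check: `cos` and `sin` each come back as
## (1, seqlen, 1, feature_dim), ready to broadcast against
## (batch_size, seqlen, n_heads, head_size) activations.
str(lapply(precompute_rotary_freqs(seqlen = 8L, feature_dim = 16L), dim))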
apply_rotary_embedding <- function(x, freqs) {
  rotate_every_two <- function(x) {
    x1 <- x[all_dims(), `::2`]
    x2 <- x[all_dims(), `2::2`]
    x_ <- tf$stack(list(-x2, x1), axis = -1L)
    tf$reshape(x_, tf$shape(x))
  }
  (x * freqs$cos) + (rotate_every_two(x) * freqs$sin)
}
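
## Optional check of the rotation identity: at position 0 the angles are all
## zero, so applying the embedding there leaves the input unchanged.
local({
  freqs <- precompute_rotary_freqs(seqlen = 4L, feature_dim = 8L)
  pos0  <- lapply(freqs, \(f) f[, NA:1, , ]) # keep only position 0
  x <- tf$random$normal(c(1L, 1L, 1L, 8L))
  all(as.array(apply_rotary_embedding(x, pos0) == x)) # TRUE
})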
make_mask <- function(seqlen, position_index = 0L, dtype = k_floatx()) {
  x <- tf$range(seqlen)
  i <- x[, tf$newaxis] + position_index
  j <- x[tf$newaxis, ]
  mask <- tf$where(i < j,
                   tf$constant(-Inf, dtype = dtype),
                   tf$constant(0, dtype = dtype))
  mask[tf$newaxis, tf$newaxis, , ] # (1, 1, seqlen_q, seqlen_q)
}
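
## A look at the causal mask for a toy seqlen of 4: rows are query positions,
## columns are key positions; -Inf above the diagonal means "can't attend to
## future tokens" once the mask is added to the attention scores.
make_mask(4L)[1, 1, , ]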
RMSNorm(keras$layers$Layer) %py_class% {

  initialize <-
    function(eps = 1e-6, ..., block_id = NULL, feeds_into = NULL) {
      super$initialize(...)
      self$eps <- eps
      self$block_id <- block_id
      self$feeds_into <- feeds_into
    }

  build <- function(input_shape) {
    # input_shape == (batch_size, seqlen, params$dim)
    # self$w will broadcast over the batch_size and seqlen dims.
    # w_shape == (1, 1, params$dim)
    w_shape <- rep(1L, length(input_shape))
    w_shape[length(input_shape)] <- as.integer(input_shape) |> tail(1L)

    # initializer that loads the pretrained weights
    # if we supplied `block_id` and `feeds_into`
    import_from({self}, block_id, feeds_into)
    initializer <-
      if (is.null(self$block_id)) {
        "ones"
      } else if (block_id >= 0) {
        \(...) weights_path("7B/layers.{block_id}.{feeds_into}_norm.weight.npy") |>
          np$load() |> np$expand_dims(0:1)
      } else if (block_id == -1) {
        # weights for the final output norm, which is not part of a TransformerBlock
        \(...) weights_path("7B/norm.weight.npy") |>
          np$load() |> np$expand_dims(0:1)
      }

    self$w <- self$add_weight(shape = w_shape,
                              initializer = initializer,
                              trainable = TRUE)
  }

  rrms <- function(x) {
    # reciprocal root mean square along the last axis
    x %>%
      tf$math$square() %>%
      tf$reduce_mean(axis = -1L, keepdims = TRUE) %>%
      tf$math$add(self$eps) %>% # for numerical stability
      tf$math$rsqrt()
  }

  call <- function(x) {
    x * self$rrms(x) * self$w
  }
}
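
## Optional sanity check: with the default "ones" weight (no pretrained
## weights requested), RMSNorm output has ~unit root-mean-square along the
## feature axis.
local({
  norm <- RMSNorm()
  x <- tf$random$normal(c(1L, 3L, 8L))
  tf$reduce_mean(tf$square(norm(x)), axis = -1L) # every value ~= 1
})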
FeedForward(keras$layers$Layer) %py_class% {

  initialize <- function(hidden_dim, multiple_of = 256L, ..., block_id = NULL) {
    super$initialize()
    if (!is.null(multiple_of)) {
      # rescale hidden_dim as in the reference implementation:
      # 2/3 of the requested size, rounded up to a multiple of `multiple_of`
      hidden_dim <- hidden_dim %>%
        { as.integer(. * (2/3)) } %>%
        { (. + multiple_of - 1) %/% multiple_of } %>%
        { . * multiple_of }
    }
    self$hidden_dim <- hidden_dim
    self$block_id <- block_id
  }

  build <- function(input_shape) {
    output_dim <- input_shape |> as.integer() |> tail(1)

    # fall back to random initialization if no `block_id` was supplied
    load_weight <- \(name) NULL
    if (!is.null(self$block_id))
      load_weight <- \(name) \(...) np$load(weights_path(
        "7B/layers.{self$block_id}.feed_forward.{name}.weight.npy"))$`T`

    self$w1 <- Dense(self$hidden_dim, use_bias = FALSE,
                     kernel_initializer = load_weight("w1"))
    self$w2 <- Dense(output_dim, use_bias = FALSE,
                     kernel_initializer = load_weight("w2"))
    self$w3 <- Dense(self$hidden_dim, use_bias = FALSE,
                     kernel_initializer = load_weight("w3"))
    super$build(input_shape)
  }

  call <- function(x) {
    import_from({self}, w1, w2, w3)
    import_from(tf$nn, silu)
    x %>%
      { silu(w1(.)) * w3(.) } %>% # SwiGLU
      w2()
  }
}
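
## Worked example of the hidden_dim rescaling above: for the 7B model
## (dim = 4096), TransformerBlock requests hidden_dim = 4 * 4096 = 16384,
## which becomes the 11008 used by the released feed-forward weights.
local({
  hidden_dim <- 4L * 4096L
  multiple_of <- 256L
  hidden_dim <- as.integer(hidden_dim * (2/3))                                # 10922
  hidden_dim <- ((hidden_dim + multiple_of - 1) %/% multiple_of) * multiple_of
  hidden_dim                                                                  # 11008
})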
Attention(keras$layers$Layer) %py_class% {

  initialize <- function(head_size, n_heads, ..., block_id = NULL) {
    super$initialize(...)
    self$head_size <- head_size
    self$n_heads <- n_heads

    if (is.null(block_id))
      load_weight <- function(name) NULL
    else
      load_weight <- \(name) \(...) np$load(weights_path(
        "7B/layers.{block_id}.attention.{name}.weight.npy"))$`T`

    Dense <- function(name) keras$layers$Dense(
      units = n_heads * head_size,
      use_bias = FALSE,
      kernel_initializer = load_weight(name)
    )

    self$wq <- Dense("wq")
    self$wk <- Dense("wk")
    self$wv <- Dense("wv")
    self$wo <- Dense("wo")
  }

  call <- function(x, ...,
                   freqs = NULL,
                   cache = NULL,
                   cache_index = NULL,
                   mask = NULL) {
    c(batch_size, seqlen_q, n_features) %<-% tf$unstack(tf$shape(x))
    seqlen_k <- seqlen_v <- cache_index + seqlen_q

    # project to queries, keys, values, and split out the heads
    split_heads_shape <- c(batch_size, seqlen_q, self$n_heads, self$head_size)
    q <- x |> self$wq() |> tf$reshape(split_heads_shape)
    k <- x |> self$wk() |> tf$reshape(split_heads_shape)
    v <- x |> self$wv() |> tf$reshape(split_heads_shape)

    q %<>% apply_rotary_embedding(freqs) # (bsz, seqlen_q, n_heads, head_size)
    k %<>% apply_rotary_embedding(freqs) # (bsz, seqlen_q, n_heads, head_size)

    if (!is.null(cache)) {
      # append k,v to respective caches; fetch full k,v from cache
      cache$k %<>% dynamic_update_slice(k, c(0L, cache_index, 0L, 0L))
      cache$v %<>% dynamic_update_slice(v, c(0L, cache_index, 0L, 0L))
      k <- cache$k[, NA:seqlen_k, , ]
      v <- cache$v[, NA:seqlen_v, , ]
    }

    v <- tf$transpose(v, c(0L, 2L, 1L, 3L)) # (bsz, n_heads, seqlen_v, head_size)
    q <- tf$transpose(q, c(0L, 2L, 1L, 3L)) # (bsz, n_heads, seqlen_q, head_size)
    k <- tf$transpose(k, c(0L, 2L, 3L, 1L)) # (bsz, n_heads, head_size, seqlen_k)

    scores <- (q %*% k) / sqrt(self$head_size) # (bsz, n_heads, seqlen_q, seqlen_k)

    # apply the causal mask, so the model can't "look ahead" during training
    if (!is.null(mask))
      scores %<>% { . + mask }

    scores <- tf$nn$softmax(scores, axis = -1L)

    # weight the values tensor by the attention scores
    # scores (bsz, n_heads, seqlen_q, seqlen_k)
    # v      (bsz, n_heads, seqlen_v, head_size)
    output <- scores %*% v # (bsz, n_heads, seqlen_q, head_size)

    # combine heads back into a single features dim,
    # so Attention output_shape == input_shape
    # (needed so that residuals can be added in TransformerBlock)
    output <- output |>
      tf$transpose(c(0L, 2L, 1L, 3L)) |> # (bsz, seqlen_q, n_heads, head_size)
      tf$reshape(c(batch_size, seqlen_q, # (bsz, seqlen_q, n_heads * head_size)
                   self$n_heads * self$head_size))

    # one more trainable linear projection for good luck
    output <- self$wo(output) # (bsz, seqlen_q, n_heads * head_size)

    if (is.null(cache))
      output
    else
      list(output, cache)
  }
}
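
## Optional shape check with randomly initialized weights (block_id = NULL):
## Attention preserves the input shape, which is what lets TransformerBlock
## add its output back onto the input as a residual.
local({
  attn <- Attention(head_size = 8L, n_heads = 2L)
  x <- tf$random$normal(c(1L, 4L, 16L))
  freqs <- precompute_rotary_freqs(seqlen = 4L, feature_dim = 8L)
  dim(attn(x, freqs = freqs, mask = make_mask(4L), cache_index = 0L)) # 1 4 16
})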
TransformerBlock(keras$layers$Layer) %py_class% {

  initialize <- function(attn_head_size, attn_n_heads,
                         norm_eps = k_epsilon(), ...,
                         block_id = NULL) {
    super$initialize(...)

    self$attention <- Attention(attn_head_size, attn_n_heads,
                                block_id = block_id)
    self$feed_forward <- FeedForward(
      hidden_dim = 4 * attn_head_size * attn_n_heads,
      block_id = block_id)

    self$attention_norm <- RMSNorm(eps = norm_eps, block_id = block_id,
                                   feeds_into = "attention")
    self$feed_forward_norm <- RMSNorm(eps = norm_eps, block_id = block_id,
                                      feeds_into = "ffn")
  }

  call <- function(x, ..., cache = NULL) {
    # norm and attention
    x2 <- x |>
      self$attention_norm() |>
      self$attention(..., cache = cache)

    # maybe unpack the cache returned by Attention
    if (!is.null(cache))
      c(x2, cache) %<-% x2

    x <- x + x2 # add residual

    # norm and SwiGLU projection
    x2 <- x %>%
      self$feed_forward_norm() %>%
      self$feed_forward()

    x <- x + x2 # residual again

    if (is.null(cache)) x else list(x, cache)
  }
}
TransformerDecoder(keras$Model) %py_class% {

  initialize <- function(vocab_size, n_blocks, n_heads, head_size, norm_eps) {
    super$initialize()
    self$head_size <- head_size
    self$n_heads <- n_heads

    self$tok_embeddings <- keras$layers$Embedding(
      input_dim = vocab_size,
      output_dim = n_heads * head_size,
      embeddings_initializer =
        \(...) np$load(weights_path("7B/tok_embeddings.weight.npy")))

    self$blocks <- lapply(seq_len0(n_blocks), function(block_id) {
      TransformerBlock(attn_head_size = head_size,
                       attn_n_heads = n_heads,
                       norm_eps = norm_eps,
                       block_id = block_id)
    })

    self$norm <- RMSNorm(block_id = -1, eps = norm_eps)
    self$output_proj <- Dense(
      vocab_size, use_bias = FALSE,
      kernel_initializer = \(...)
        np$load(weights_path("7B/output.weight.npy"))$`T`)

    self$freqs <- precompute_rotary_freqs(feature_dim = head_size,
                                          seqlen = 2048L)
  }

  call <- function(tokens) {
    c(bsz, seqlen) %<-% tf$unstack(tf$shape(tokens))
    freqs <- self$freqs |> lapply(\(f) f[, NA:seqlen, , ])
    mask <- make_mask(seqlen)

    x <- tokens |>
      self$tok_embeddings()

    for (block in self$blocks)
      x <- block(x, freqs = freqs, mask = mask)

    local_options(c(tensorflow.extract.warn_negatives_pythonic = FALSE))
    x |>
      self$norm() |>
      _[, -1, ] |> # keep only the last position's features
      self$output_proj()
  }

  call_with_cache <- function(tokens, cache, position) {
    c(batch_size, seqlen) %<-% tf$unstack(tf$shape(tokens))

    # Sanity check: after the initial seeding of the cache with the prompt, we
    # should only be running inference on one token at a time.
    tf_assert(position == 0 | seqlen == 1)

    if (is.numeric(position) && position == 0L) {
      # initial cache seeding
      mask <- make_mask(seqlen)
      freqs <- self$freqs |> lapply(\(f) f[, NA:seqlen, , ])
    } else {
      # inference with one token
      position %<>% as_tensor(dtype = "int32")
      freqs <- self$freqs |> lapply(\(f) f[, position, , ])
      mask <- NULL
    }

    blocks <- self$blocks
    stopifnot(is.list(cache), length(cache) == length(blocks))

    x <- tokens |>
      self$tok_embeddings()

    for (i in seq_along(blocks)) {
      c(x, cache[[i]]) %<-% blocks[[i]](x, cache = cache[[i]],
                                        cache_index = position,
                                        freqs = freqs,
                                        mask = mask)
    }

    local_options(c(tensorflow.extract.warn_negatives_pythonic = FALSE))
    output <- x |>
      self$norm() |>
      _[, -1, ] |>
      self$output_proj()

    list(output, cache)
  }

  .make_cache <- function(prompt_tokens, max_seqlen = 2048L) {
    c(batch_size, seqlen) %<-% tf$unstack(tf$shape(prompt_tokens))
    import_from({self}, head_size, n_heads)

    max_seqlen <- min(max_seqlen + seqlen, 2048L)
    cache_shape <- c(batch_size, max_seqlen, n_heads, head_size)
    cache <- lapply(seq_along(self$blocks), \(.) {
      list(k = tf$zeros(cache_shape), v = tf$zeros(cache_shape))
    })

    tokens_with_preallocated_space <-
      tf$zeros(c(batch_size, max_seqlen), dtype = "int32") |>
      dynamic_update_slice(update = prompt_tokens, indices = c(0L, 0L))

    # run the first forward pass to seed the cache with the initial prompt
    # returns (tokens, next_token_probs, cache)
    c(tokens_with_preallocated_space,
      self$call_with_cache(prompt_tokens, cache = cache, position = 0L))
  }

  private$sampler_fn <- \(logits) logits |>
    tf$argmax(axis = -1L, output_type = "int32") |>
    tf$expand_dims(-1L)

  sampler %<-active% function(fn) {
    if (missing(fn))
      private$sampler_fn
    else
      private$sampler_fn <- fn
  }

  generate <- function(prompt, max_len = 20L) {
    max_len %<>% as_tensor("int32")
    prompt %<>% as_tensor()

    # accept either tokens or a string
    if (prompt$dtype$name == "string") {
      if (length(dim(prompt)) == 0) # ensure a batch dim
        prompt %<>% .[tf$newaxis]
      tokens <- tokenizer$tokenize(prompt)$to_tensor()
    } else {
      tokens <- prompt
      if (length(dim(prompt)) == 1) # ensure a batch dim
        tokens %<>% .[tf$newaxis, ]
    }

    c(batch_size, initial_prompt_len) %<-% tf$unstack(tf$shape(tokens))
    max_seqlen <- min(max_len + initial_prompt_len, 2048L)
    c(tokens, next_token_probs, cache) %<-% self$.make_cache(tokens, max_len)

    i <- initial_prompt_len
    autograph({
      # enable `if` and `for` to accept tensors
      for (i in tf$range(initial_prompt_len, max_seqlen, dtype = "int32")) {
        next_token <- self$sampler(next_token_probs)
        tokens %<>% dynamic_update_slice(next_token, c(0L, i))
        if (any(next_token == 2L))
          break # end-of-sequence token
        c(next_token_probs, cache) %<-%
          self$call_with_cache(next_token, cache, i)
      }
    })

    tokens %<>% .[, NA:(i + 1)] # drop unused preallocated space

    if (prompt$dtype$name == "string")
      # return a string if we were supplied a string
      tokenizer$detokenize(tokens)
    else
      tokens
  }
}
# ---- load
weights_path <- function(rel_path) {
  normalizePath(
    file.path(
      "~/github/facebookresearch/llama/weights/LLaMA/",
      glue::glue(rel_path, .envir = parent.frame())
    ),
    mustWork = TRUE
  )
}
params <- jsonlite::read_json(weights_path("7B/params.json"))
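
## For reference, the 7B params.json is expected to contain something like
## (values from the released LLaMA-1 7B config):
## {"dim": 4096, "multiple_of": 256, "n_heads": 32, "n_layers": 32,
##  "norm_eps": 1e-06, "vocab_size": -1}
str(params)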
tf_text <- reticulate::import("tensorflow_text")
tokenizer_path <- weights_path("tokenizer.model")
tokenizer <- tf_text$SentencepieceTokenizer(
  tf$io$gfile$GFile(tokenizer_path, "rb")$read(),
  add_bos = TRUE, add_eos = FALSE,
)
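
## Quick tokenizer sanity check: tokenize + detokenize should round-trip the
## text (the BOS token added by `add_bos = TRUE` is dropped when detokenizing).
tokenizer$tokenize("The best way to attract bees") |>
  tokenizer$detokenize() |> as.character()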
llama <- TransformerDecoder(vocab_size = tokenizer$vocab_size(),
                            n_blocks = params$n_layers,
                            n_heads = params$n_heads,
                            head_size = params$dim %/% params$n_heads,
                            norm_eps = params$norm_eps)
prompt <- "The best way to attract bees"
test_generate <- function() {
  prompt |>
    tokenizer$tokenize() |>
    llama$generate(as_tensor(17L)) |>
    tokenizer$detokenize() |>
    as.character() |>
    strwrap(60) |> writeLines()
}
test_generate()
## expected output with the argmax() sampler:
# The best way to attract bees to your garden is to plant a
# variety of flowers that bloom at different times.
# Timings on CPU:
print(system.time(test_generate()))
# user system elapsed
# 99.562 0.149 89.057
# Compile to XLA
llama$generate %<>% tf_function(jit_compile = TRUE)
# First call includes tracing time
print(system.time(test_generate()))
# user system elapsed
# 64.944 0.809 55.314
# Second call is pure graph mode
print(system.time(test_generate()))
# user system elapsed
# 28.754 0.120 18.453
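
## The default sampler is greedy argmax, which is why the generated text above
## is deterministic. The `sampler` active property accepts a replacement; a
## minimal top-k / temperature sampler might look like the sketch below
## (`k = 40` and `temperature = 0.7` are arbitrary illustrative choices).
llama$sampler <- function(logits) {            # logits: (batch_size, vocab_size)
  top <- tf$math$top_k(logits / 0.7, k = 40L)  # keep the 40 most likely tokens
  idx <- tf$random$categorical(top$values, num_samples = 1L, dtype = "int32")
  tf$gather(top$indices, idx, batch_dims = 1L) # (batch_size, 1) sampled token ids
}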
@philippmuench

Hi Tomasz,

Thank you for sharing your work! I’ve tried both the version from your blog post and the code in this Gist. I believe I have installed the correct Python and TensorFlow versions as outlined in the blog, but I’m running into an issue during the weight loading step.

Would you happen to have any insights on what might be going wrong?

Thanks so much for your help!

Philipp

torch <- reticulate::import("torch", convert = FALSE)
with_dir("/home/pmuench/github/meta-llama/llama/llama-2-7b/", {
   pretrained_weights <- torch$load("consolidated.00.pth",
                                    map_location = "cpu")
   for (name in names(pretrained_weights)) {
     filename <- sprintf("%s.npy", name)
     array <- pretrained_weights[[name]]$numpy()
     np$save(filename, array)
     message(glue(
       "wrote: '{basename(filename)}' with shape: {array$shape}"))
   }
 })

with the output

sys:1: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
Error in py_call_impl(callable, call_args$unnamed, call_args$named) : 
  TypeError: Got unsupported ScalarType BFloat16
Run `reticulate::py_last_error()` for details.

with

> reticulate::py_last_error()

── Python Exception Message ─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
TypeError: Got unsupported ScalarType BFloat16

── R Traceback ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
    ?
 1. ├─withr::with_dir(...)
 2. │ └─base::force(code)
 3. └─pretrained_weights[[name]]$numpy()
 4.   └─reticulate:::py_call_impl(callable, call_args$unnamed, call_args$named)
See `reticulate::py_last_error()$r_trace$full_call` for more details.

When following the Gist, I get an error at the test_generate() step:

WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1729036102.836315 3861519 service.cc:146] XLA service 0x56001110ba90 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1729036102.836382 3861519 service.cc:154]   StreamExecutor device (0): Host, Default Version

I0000 00:00:1729036103.430079 3861519 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
Error in py_call_impl(callable, call_args$unnamed, call_args$named) : 
  ValueError: Invalid dtype: <property object at 0x1534f04b8950>
Run `reticulate::py_last_error()` for details.
> sessionInfo()
R version 4.4.1 (2024-06-14)
Platform: x86_64-conda-linux-gnu
Running under: AlmaLinux 8.9 (Midnight Oncilla)

Matrix products: default
BLAS/LAPACK: /home/pmuench/miniconda3/envs/tf_py310/lib/libopenblasp-r0.3.27.so;  LAPACK version 3.12.0

locale:
 [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C              
 [3] LC_TIME=en_US.UTF-8        LC_COLLATE=en_US.UTF-8    
 [5] LC_MONETARY=en_US.UTF-8    LC_MESSAGES=en_US.UTF-8   
 [7] LC_PAPER=en_US.UTF-8       LC_NAME=C                 
 [9] LC_ADDRESS=C               LC_TELEPHONE=C            
[11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C       

time zone: Europe/Berlin
tzcode source: system (glibc)

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
[1] purrr_1.0.2            keras_2.15.0           tfautograph_0.3.2     
[4] tensorflow_2.16.0.9000 envir_0.3.0            glue_1.8.0            
[7] dplyr_1.1.4           

loaded via a namespace (and not attached):
 [1] vctrs_0.6.5            cli_3.6.3              zeallot_0.1.0         
 [4] rlang_1.1.4            png_0.1-8              generics_0.1.3        
 [7] jsonlite_1.8.9         backports_1.5.0        fansi_1.0.6           
[10] grid_4.4.1             tfruns_1.5.3           tibble_3.2.1          
[13] base64enc_0.1-3        lifecycle_1.0.4        whisker_0.4.1         
[16] compiler_4.4.1         Rcpp_1.0.13            pkgconfig_2.0.3       
[19] lattice_0.22-6         R6_2.5.1               reticulate_1.39.0.9000
[22] tidyselect_1.2.1       utf8_1.2.4             pillar_1.9.0          
[25] magrittr_2.0.3         Matrix_1.7-0           withr_3.0.1           
[28] tools_4.4.1
> tf$version$VERSION
[1] "2.17.0"
> py_config()$version
[1] "3.10"
