Embedding GPT-2 in Godot via Rust
The GDNative side lives in lib.rs: a ChatBot node that forwards user text to a background inference thread over a channel and blocks on the reply.
mod ml_thread;

use gdnative::prelude::{godot_print, methods, NativeClass, Node as GDNode, InitHandle, godot_init};
use ml_thread::start_language_model_thread;
use std::sync::mpsc::{Receiver, Sender};

const MAX_INPUT_LENGTH: usize = 512;
const BATCH_SIZE: usize = 1;

// Holds the channels that connect this node to the inference worker thread.
#[derive(NativeClass)]
#[inherit(GDNode)]
pub struct ChatBot {
    message_tx: Sender<(String, bool)>,
    response_rx: Receiver<String>,
}

// Only one impl block can have #[methods].
#[methods]
impl ChatBot {
    /// The "constructor" of the class. Spawns the inference thread.
    pub fn new(_base: &GDNode) -> Self {
        let (tx, rx) = start_language_model_thread();
        ChatBot {
            message_tx: tx,
            response_rx: rx,
        }
    }

    /// Sends the prompt to the worker thread and blocks until a reply arrives.
    pub fn make_reply(&self, text: &str, maxent: bool) -> String {
        // If maxent is set, the worker picks the most likely token.
        // Otherwise it samples probabilistically. No beam search yet.
        self.message_tx.send((text.to_string(), maxent)).expect("Child runner crashed.");
        // recv() returns a Result, not an Option; an Err means the worker hung up.
        if let Ok(msg) = self.response_rx.recv() {
            return msg;
        }
        "".into()
    }

    #[method]
    fn _ready(&self, #[base] base: &GDNode) {
        // The `godot_print!` macro works like `println!` but prints to the Godot-editor output tab as well.
        godot_print!("Hello world from node {}!", base.to_string());
    }

    #[method]
    fn process_user_query(&self, #[base] _base: &GDNode, user_str: String) -> String {
        // We could use GodotString, but it has different performance characteristics. Let's try String!
        //godot_print!("Got a call to the process_user_query endpoint with {}", &user_str);
        self.make_reply(&user_str, false)
    }
}

// Function that registers all exposed classes with Godot.
fn init(handle: InitHandle) {
    handle.add_class::<ChatBot>();
}

// Macro that creates the entry points of the dynamic library. The crate must be
// built as a cdylib (crate-type = ["cdylib"] in Cargo.toml) for Godot to load it.
godot_init!(init);

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn it_works() {
        // ChatBot::new needs a live Godot node, so drive the worker thread directly.
        let (tx, rx) = start_language_model_thread();
        tx.send(("I give you one yike:".to_string(), true)).expect("Worker thread hung up.");
        let result = rx.recv().expect("No reply from worker thread.");
        println!("{result}");
    }
}
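One caveat with the code above: make_reply blocks until inference finishes, which would stall Godot's main loop for however long the model takes. A minimal non-blocking sketch, assuming the same response_rx field; poll_reply is a hypothetical helper, not part of the gist:

use std::sync::mpsc::TryRecvError;

// Hypothetical second impl block (only one impl block may carry #[methods],
// but a plain impl block is fine): call poll_reply from `_process` each frame
// instead of blocking inside make_reply.
impl ChatBot {
    pub fn poll_reply(&self) -> Option<String> {
        match self.response_rx.try_recv() {
            Ok(msg) => Some(msg),                    // A finished completion was waiting.
            Err(TryRecvError::Empty) => None,        // Still generating; check again next frame.
            Err(TryRecvError::Disconnected) => None, // Worker thread died; nothing will arrive.
        }
    }
}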
The worker thread lives in ml_thread.rs: it owns the tokenizer and the embedded ONNX model, and services one request per message on the channel.
use std::io::Cursor;
use std::str::FromStr;
use std::sync::mpsc::{channel, Receiver, Sender};
use std::thread;
use tokenizers::tokenizer::Tokenizer;
use tract_onnx::prelude::*;
use rand::{Rng, thread_rng};
use crate::{BATCH_SIZE, MAX_INPUT_LENGTH};

fn run_language_model(input_channel: Receiver<(String, bool)>, result_channel: Sender<String>) {
    let tokenizer: Tokenizer = Tokenizer::from_str(include_str!("../model/tokenizer_gpt2.json")).expect("Failed to load packed tokenizer. Library may be corrupt.");

    // A little info on GPT-2:
    // input1 - type: int64[input1_dynamic_axes_1,input1_dynamic_axes_2,input1_dynamic_axes_3]
    // output1 - type: float32[input1_dynamic_axes_1,input1_dynamic_axes_2,input1_dynamic_axes_3,50257]
    // i.e. one 50257-dim logit vector per input position; we sample from the last real position.
    let mut model_buf = Cursor::new(include_bytes!("../model/gpt-neo-2.onnx"));
    let model = tract_onnx::onnx()
        .model_for_read(&mut model_buf).expect("Unable to read from model built into binary. This indicates corruption.")
        // Alternative: load from disk with .model_for_path("model/model.onnx") instead of embedding.
        .with_input_fact(0, i64::fact(&[BATCH_SIZE, MAX_INPUT_LENGTH]).into()).expect("Defining input fact size on preloaded language model.")
        .with_input_fact(1, i64::fact(&[BATCH_SIZE, MAX_INPUT_LENGTH]).into()).expect("Defining input mask size on preloaded language model.")
        //.into_optimized().expect("Converting packaged model to optimized build failed.")
        .into_runnable().expect("Converting optimized model into runnable model failed.");

    loop {
        match input_channel.recv() {
            Ok((msg, maxent)) => {
                let tokenizer_output = tokenizer.encode(msg, true).expect("Unable to encode input string.");
                let token_ids = tokenizer_output.get_ids();
                // Attention mask: 1 over the real tokens, 0 over the padding.
                let mask: Tensor = tract_ndarray::Array2::from_shape_fn((1, MAX_INPUT_LENGTH), |idx| if idx.1 < token_ids.len() { 1i64 } else { 0i64 }).into();
                // Token tensor: the encoded ids, right-padded with zeros to the fixed input length.
                let token_tensor: Tensor = tract_ndarray::Array2::from_shape_fn((1, MAX_INPUT_LENGTH), |idx| if idx.1 < token_ids.len() { token_ids[idx.1] as i64 } else { 0i64 }).into();
                let outputs = model.run(tvec!(token_tensor, mask)).expect("Failed to run model on token tensor.");
                let logits = outputs[0].to_array_view::<f32>().expect("Unable to convert tensor output to f32 array.");
                // Take the logit vector at the last real token position: shape [50257].
                let last_position = token_ids.len().saturating_sub(1);
                let next_token_logits = logits.slice(tract_ndarray::s![0, last_position, ..]);
                let word_id = if maxent {
                    // Greedy: pick the single most likely token.
                    next_token_logits.iter().zip(0u32..).max_by(|a, b| a.0.partial_cmp(b.0).unwrap()).unwrap().1
                } else {
                    // Sample from the softmax distribution via roulette-wheel selection.
                    let max_logit = next_token_logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
                    let probs: Vec<f32> = next_token_logits.iter().map(|&l| (l - max_logit).exp()).collect();
                    let total: f32 = probs.iter().sum();
                    let mut rng = thread_rng();
                    let mut energy = rng.gen::<f32>() * total;
                    let mut selected_token = 0;
                    for (idx, p) in probs.iter().enumerate() {
                        if *p > energy {
                            selected_token = idx;
                            break;
                        }
                        energy -= *p;
                    }
                    selected_token as u32
                };
                let word = tokenizer.id_to_token(word_id).unwrap_or_else(|| " ".into());
                result_channel.send(word).expect("Failed to send result.");
            },
            Err(_) => {
                // The sender was dropped; shut the worker thread down.
                return;
            },
        }
    }
}

pub fn start_language_model_thread() -> (Sender<(String, bool)>, Receiver<String>) {
    let (user_input_tx, user_input_rx) = channel();
    let (ai_completion_tx, ai_completion_rx) = channel();
    thread::spawn(move || run_language_model(user_input_rx, ai_completion_tx));
    (user_input_tx, ai_completion_rx)
}
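As written, each request produces exactly one new token, so a full reply means calling the model repeatedly and feeding each sampled token back into the context. A minimal sketch of that loop, assuming it sits in ml_thread.rs next to the code above; sample_next_token is a hypothetical closure standing in for the tensor-building, model.run, and sampling code inside the match arm:

// Hypothetical autoregressive loop: not part of the gist, just an illustration.
// `sample_next_token` takes the current context ids and returns one sampled id.
fn generate(tokenizer: &Tokenizer, sample_next_token: impl Fn(&[u32]) -> u32, prompt: &str, max_new_tokens: usize) -> String {
    let encoding = tokenizer.encode(prompt, true).expect("Unable to encode input string.");
    let mut ids: Vec<u32> = encoding.get_ids().to_vec();
    let mut reply = String::new();
    for _ in 0..max_new_tokens {
        if ids.len() >= MAX_INPUT_LENGTH {
            break; // Never overflow the fixed 512-token input window.
        }
        let next = sample_next_token(&ids); // Re-runs the model on the padded context.
        ids.push(next);
        // Mirror the single-token path above: raw BPE token strings, so the output
        // keeps BPE markers; a real implementation would properly detokenize.
        reply.push_str(&tokenizer.id_to_token(next).unwrap_or_else(|| " ".into()));
    }
    reply
}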