Last active
September 28, 2024 01:44
-
-
Save nihalpasham/133a935304e22054b0fe92efde43caec to your computer and use it in GitHub Desktop.
Expanding CubeCL's gelu example
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// 🦀 Generated by Rust Macro Expand 🦀
// 🦀 Timestamp: 16/09/2024, 12:50:30 🦀
#![allow(warnings)] | |
#![feature(print_internals)] | |
#![feature(panic_internals)] | |
#![feature(prelude_import)] | |
#[prelude_import] | |
use std::prelude::rust_2021::*; | |
#[macro_use] | |
extern crate std; | |
use cubecl::prelude::*; | |
#[allow(dead_code, clippy::too_many_arguments)] | |
fn gelu_array<F: Float>(input: &Array<F>, output: &mut Array<F>) { | |
use cubecl::prelude::{CubeIndex as _, CubeIndexMut as _}; | |
if ABSOLUTE_POS < input.len() { | |
*output.cube_idx_mut(ABSOLUTE_POS) = gelu_scalar::<F>(*input.cube_idx(ABSOLUTE_POS)); | |
} | |
} | |
mod gelu_array { | |
use super::*; | |
#[allow(unused, clippy::all)] | |
pub fn expand<F: Float>( | |
context: &mut cubecl::prelude::CubeContext, | |
input: <Array<F> as cubecl::prelude::CubeType>::ExpandType, | |
output: <Array<F> as cubecl::prelude::CubeType>::ExpandType, | |
) -> <() as cubecl::prelude::CubeType>::ExpandType { | |
use cubecl::prelude::IntoRuntime as _; | |
{ | |
{ | |
let _cond = { | |
let _lhs = ABSOLUTE_POS::expand(context); | |
let _rhs = { input.clone().__expand_len_method(context) }; | |
cubecl::frontend::lt::expand(context, _lhs, _rhs) | |
}; | |
cubecl::frontend::branch::if_expand(context, _cond.into(), |context| { | |
{ | |
let _array = output; | |
let _index = ABSOLUTE_POS::expand(context); | |
let _value = { | |
let _arg_0 = { | |
let _array = input; | |
let _index = ABSOLUTE_POS::expand(context); | |
cubecl::frontend::index::expand(context, _array, _index) | |
}; | |
gelu_scalar::expand::<F>(context, _arg_0.into()) | |
}; | |
cubecl::frontend::index_assign::expand(context, _array, _index, _value) | |
}; | |
() | |
}); | |
}; | |
() | |
} | |
} | |
///gelu_array Kernel | |
pub struct GeluArray<F: Float, __R: cubecl::prelude::Runtime> { | |
settings: cubecl::prelude::KernelSettings, | |
__ty: ::core::marker::PhantomData<(__R, F)>, | |
} | |
impl<F: Float, __R: cubecl::prelude::Runtime> GeluArray<F, __R> { | |
pub fn new(settings: cubecl::prelude::KernelSettings) -> Self { | |
Self { | |
settings, | |
__ty: ::core::marker::PhantomData, | |
} | |
} | |
} | |
impl<F: Float, __R: cubecl::prelude::Runtime> cubecl::Kernel for GeluArray<F, __R> { | |
fn define(&self) -> cubecl::prelude::KernelDefinition { | |
let mut builder = cubecl::prelude::KernelBuilder::default(); | |
let mut inputs: ::std::collections::BTreeMap< | |
usize, | |
std::sync::Arc<dyn core::any::Any>, | |
> = std::collections::BTreeMap::new(); | |
let mut outputs: ::std::collections::BTreeMap< | |
usize, | |
std::sync::Arc<dyn core::any::Any>, | |
> = std::collections::BTreeMap::new(); | |
#[allow(unused)] | |
let register_input = |builder: &mut cubecl::prelude::KernelBuilder, | |
settings: &cubecl::prelude::KernelSettings, | |
position: usize| | |
-> ::std::sync::Arc<dyn ::core::any::Any> { | |
match position { | |
0usize => ::std::sync::Arc::new( | |
<Array<F> as cubecl::prelude::LaunchArgExpand>::expand( | |
builder, | |
settings.vectorization_input(0usize), | |
), | |
), | |
_ => { | |
{ | |
::core::panicking::panic_fmt(format_args!( | |
"Input {0} is invalid", | |
position | |
)); | |
}; | |
} | |
} | |
}; | |
#[allow(unused)] | |
let register_output = |builder: &mut cubecl::prelude::KernelBuilder, | |
settings: &cubecl::prelude::KernelSettings, | |
position: usize| | |
-> ::std::sync::Arc<dyn ::core::any::Any> { | |
match position { | |
0usize => ::std::sync::Arc::new( | |
<Array<F> as cubecl::prelude::LaunchArgExpand>::expand_output( | |
builder, | |
settings.vectorization_output(0usize), | |
), | |
), | |
_ => { | |
{ | |
::core::panicking::panic_fmt(format_args!( | |
"Input {0} is invalid", | |
position | |
)); | |
}; | |
} | |
} | |
}; | |
for i in 0..1usize { | |
inputs.insert(i, register_input(&mut builder, &self.settings, i)); | |
} | |
for mapping in &self.settings.mappings { | |
let input = inputs.get(&mapping.pos_input).unwrap(); | |
outputs.insert(mapping.pos_output, input.clone()); | |
} | |
for i in 0..1usize { | |
if !outputs.contains_key(&i) { | |
outputs.insert(i, register_output(&mut builder, &self.settings, i)); | |
} | |
} | |
let input: &<Array<F> as cubecl::prelude::CubeType>::ExpandType = inputs | |
.get(&0usize) | |
.unwrap() | |
.downcast_ref() | |
.expect( | |
"Output type should be correct. It could be caused by an invalid kernel input/output alias.", | |
); | |
let output: &<Array<F> as cubecl::prelude::CubeType>::ExpandType = outputs | |
.get(&0usize) | |
.unwrap() | |
.downcast_ref() | |
.expect( | |
"Output type should be correct. It could be caused by an invalid kernel input/output alias.", | |
); | |
expand::<F>(&mut builder.context, input.clone(), output.clone()); | |
builder.build(self.settings.clone()) | |
} | |
fn id(&self) -> cubecl::KernelId { | |
cubecl::KernelId::new::<Self>().info((self.settings.clone())) | |
} | |
} | |
#[allow(clippy::too_many_arguments)] | |
///Launch the kernel [gelu_array()] on the given runtime | |
pub unsafe fn launch_unchecked<'kernel, F: Float, __R: cubecl::prelude::Runtime>( | |
__client: &cubecl::prelude::ComputeClient<__R::Server, __R::Channel>, | |
__cube_count: cubecl::prelude::CubeCount<__R::Server>, | |
__cube_dim: cubecl::prelude::CubeDim, | |
input: cubecl::RuntimeArg<'kernel, Array<F>, __R>, | |
output: cubecl::RuntimeArg<'kernel, Array<F>, __R>, | |
) -> () { | |
use cubecl::frontend::ArgSettings as _; | |
let mut __settings = cubecl::prelude::KernelSettings::default().cube_dim(__cube_dim); | |
__settings = | |
cubecl::prelude::ArgSettings::<__R>::configure_input(&input, 0usize, __settings); | |
__settings = | |
cubecl::prelude::ArgSettings::<__R>::configure_output(&output, 0usize, __settings); | |
let kernel = GeluArray::<F, __R>::new(__settings); | |
let mut launcher = cubecl::prelude::KernelLauncher::<__R>::default(); | |
input.register(&mut launcher); | |
output.register(&mut launcher); | |
launcher.launch_unchecked(__cube_count, kernel, __client); | |
} | |
} | |
#[allow(dead_code, clippy::too_many_arguments)] | |
fn gelu_scalar<F: Float>(x: F) -> F { | |
use cubecl::prelude::{CubeIndex as _, CubeIndexMut as _}; | |
x * F::erf(x / F::new(2.0f32.sqrt()) + F::new(1.0)) / F::new(2.0) | |
} | |
mod gelu_scalar { | |
use super::*; | |
#[allow(unused, clippy::all)] | |
pub fn expand<F: Float>( | |
context: &mut cubecl::prelude::CubeContext, | |
x: <F as cubecl::prelude::CubeType>::ExpandType, | |
) -> <F as cubecl::prelude::CubeType>::ExpandType { | |
use cubecl::prelude::IntoRuntime as _; | |
{ | |
{ | |
let _lhs = { | |
let _lhs = x.clone(); | |
let _rhs = { | |
let _arg_0 = { | |
let _lhs = { | |
let _lhs = x; | |
let _rhs = { | |
let _arg_0 = 2.0f32.sqrt(); | |
F::__expand_new(context, _arg_0.into()) | |
}; | |
cubecl::frontend::div::expand(context, _lhs, _rhs) | |
}; | |
let _rhs = { | |
let _arg_0 = 1.0; | |
F::__expand_new(context, _arg_0.into()) | |
}; | |
cubecl::frontend::add::expand(context, _lhs, _rhs) | |
}; | |
F::__expand_erf(context, _arg_0.into()) | |
}; | |
cubecl::frontend::mul::expand(context, _lhs, _rhs) | |
}; | |
let _rhs = { | |
let _arg_0 = 2.0; | |
F::__expand_new(context, _arg_0.into()) | |
}; | |
cubecl::frontend::div::expand(context, _lhs, _rhs) | |
} | |
} | |
} | |
} | |
pub fn launch<R: Runtime>(device: &R::Device) { | |
let client = R::client(device); | |
let input = &[-1., 0., 1., 5.]; | |
let output_handle = client.empty(input.len() * core::mem::size_of::<f32>()); | |
let input_handle = client.create(f32::as_bytes(input)); | |
unsafe { | |
gelu_array::launch_unchecked::<f32, R>( | |
&client, | |
CubeCount::Static(1, 1, 1), | |
CubeDim::new(input.len() as u32, 1, 1), | |
ArrayArg::from_raw_parts(&input_handle, input.len(), 1), | |
ArrayArg::from_raw_parts(&output_handle, input.len(), 1), | |
) | |
}; | |
let bytes = client.read(output_handle.binding()); | |
let output = f32::from_bytes(&bytes); | |
{ | |
::std::io::_print(format_args!( | |
"Executed gelu with runtime {0:?} => {1:?}\n", | |
R::name(), | |
output, | |
)); | |
}; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment