Created
April 2, 2024 16:16
-
-
Save JosiahParry/7886dc2d57c70a52cc76bd7d9d77ab64 to your computer and use it in GitHub Desktop.
Recreating the `vacc` function from Advanced R with SIMD. This isn't very well done: it only returns results in multiples of 4, probably because `array_chunks::<4>()` drops any leftover elements that do not fill a whole chunk. It's also surprisingly slower than anticipated?
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#![feature(portable_simd)] | |
#![feature(array_chunks)] | |
#[extendr] | |
fn vacc(age: &[f64], female: &[u8], ily: &[f64]) -> Vec<f64> { | |
age.array_chunks::<4>() | |
.map(|&a| f64x4::from_array(a)) | |
.zip(female.array_chunks::<4>().map(|&f| u8x4::from_array(f))) | |
.zip(ily.array_chunks::<4>().map(|&i| f64x4::from_array(i))) | |
.map(|((a, f), i)| { | |
// 0.25 + 0.3 * 1.0 | |
let num = f64x4::splat(0.24) + f64x4::splat(0.3) * f64x4::splat(1.0); | |
let denom = | |
(f64x4::splat(1.0) - (f64x4::splat(0.04) * a).exp()) + f64x4::splat(0.1) * i; | |
let coef = if f == u8x4::splat(0) { | |
f64x4::splat(1.25) | |
} else { | |
f64x4::splat(0.75) | |
}; | |
let p = num / denom; | |
(p * coef) | |
.simd_max(f64x4::splat(0.0)) | |
.simd_min(f64x4::splat(1.0)) | |
}) | |
.flat_map(|x| x.to_array()) | |
.collect::<Vec<_>>() | |
} | |
#[extendr] | |
fn vacc_(age: &[f64], female: &[f64], ily: &[f64]) -> Vec<f64> { | |
let num = f64x4::splat(0.55); | |
let one_quarter = f64x4::splat(0.25); | |
let three_tenths = f64x4::splat(0.3); | |
let one = f64x4::splat(1.0); | |
let a_coef = f64x4::splat(0.04); | |
let i_coef = f64x4::splat(0.1); | |
let male_coef = f64x4::splat(1.25); | |
let female_coef = f64x4::splat(0.75); | |
let min_val = f64x4::splat(0.0); | |
let zero = u8x4::splat(0); | |
let max_val = f64x4::splat(1.0); | |
age.array_chunks::<4>() | |
.map(|&a| f64x4::from_array(a)) | |
.zip(female.array_chunks::<4>().map(|&f| f64x4::from_array(f))) | |
.zip(ily.array_chunks::<4>().map(|&i| f64x4::from_array(i))) | |
.map(|((a, f), i)| { | |
let p = one_quarter + three_tenths * one / (one - (a_coef * a).exp()) + (i_coef * i); | |
let coef = f * male_coef + (f - one).abs() * female_coef; | |
(p * coef).simd_max(min_val).simd_min(max_val) | |
}) | |
.flat_map(|x| x.to_array()) | |
.collect::<Vec<_>>() | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment