Last active
May 13, 2024 07:07
-
-
Save deepankarsharma/7955e64b423bf39a8bc32304d3be9fe3 to your computer and use it in GitHub Desktop.
avx2 running sum
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// Returns the sum of all 32 bytes of `a`, each treated as an unsigned 8-bit value.
///
/// `_mm256_sad_epu8` against an all-zero vector adds each group of eight bytes
/// into the low 16 bits of its 64-bit element. The four partial sums sit at
/// 16-bit positions 0, 4, 8 and 12 and are added together here. The result is
/// at most `32 * 255 = 8160`.
///
/// # Safety
///
/// The CPU must support AVX2.
#[target_feature(enable = "avx2")]
pub unsafe fn avx_hsum(a: __m256i) -> i32 {
    let zero = _mm256_setzero_si256();
    // Sum of absolute differences against zero is simply the sum of the bytes.
    let sad = _mm256_sad_epu8(a, zero);
    _mm256_extract_epi16::<0>(sad)
        + _mm256_extract_epi16::<4>(sad)
        + _mm256_extract_epi16::<8>(sad)
        + _mm256_extract_epi16::<12>(sad)
}
/// Negates each of the 32 signed 8-bit lanes of `v` by computing `0 - v` per lane.
///
/// The subtraction wraps, so a lane holding `-128` stays `-128` (bit pattern
/// `0x80`, i.e. 128 when read back as an unsigned byte).
///
/// # Safety
///
/// The CPU must support AVX2.
#[inline(always)]
unsafe fn negate_8bit_ints(v: __m256i) -> __m256i {
    let zero = _mm256_setzero_si256();
    _mm256_sub_epi8(zero, v)
}
pub unsafe fn count_newlines_memmap_avx2_running_sum(filename: &str) -> Result<usize, Error> { | |
let file = File::open(filename)?; | |
let mmap = unsafe { Mmap::map(&file)? }; | |
mmap.advise(Advice::Sequential)?; | |
let newline_byte = b'\n'; | |
let newline_vector = _mm256_set1_epi8(newline_byte as i8); | |
let mut newline_count = 0; | |
let mut running_sum = _mm256_setzero_si256(); | |
let mut ptr = mmap.as_ptr(); | |
let end_ptr = unsafe { ptr.add(mmap.len()) }; | |
let mut iteration_count = 0; | |
while ptr <= end_ptr.sub(32) { | |
let data = unsafe { _mm256_loadu_si256(ptr as *const __m256i) }; | |
// cmp_result will have -1's for newlines, 0's otherwise | |
let cmp_result = _mm256_cmpeq_epi8(data, newline_vector); | |
// since cmp_result has -1's we accumulate negative values here | |
// we fix those during the call to avx_hsum | |
running_sum = _mm256_add_epi8(running_sum, cmp_result); | |
ptr = unsafe { ptr.add(32) }; | |
iteration_count += 1; | |
if iteration_count % 128 == 0 { | |
let fixed_running_sum = negate_8bit_ints(running_sum); | |
newline_count += avx_hsum(fixed_running_sum) as usize; | |
running_sum = _mm256_setzero_si256(); | |
} | |
} | |
// Process remaining iterations | |
if iteration_count % 128 != 0 { | |
let fixed_running_sum = negate_8bit_ints(running_sum); | |
newline_count += avx_hsum(fixed_running_sum) as usize; | |
} | |
// Count remaining bytes | |
let remaining_bytes = end_ptr as usize - ptr as usize; | |
newline_count += mmap[mmap.len() - remaining_bytes..].iter().filter(|&&b| b == newline_byte).count(); | |
reset_file_caches(); | |
Ok(newline_count) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment