Last active
May 30, 2025 12:47
-
-
Save Swoorup/6030a622346ef465cffad57440b52eb8 to your computer and use it in GitHub Desktop.
Gpu Buffer Tracking
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::ops::Range; | |
use derive_more::{Deref, DerefMut, Into}; | |
/// A half-open range of dirty indices, newtyped so the `one`/`range`
/// constructors can live on it.
#[derive(Debug, Hash, PartialEq, Eq, Clone)]
pub struct DirtyIndexRange(pub(crate) Range<usize>);

// Hand-written impls replacing the former `derive_more` `Deref`/`DerefMut`/
// `Into` derives — the external interface is identical, without the
// third-party derive dependency.
impl std::ops::Deref for DirtyIndexRange {
    type Target = Range<usize>;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl std::ops::DerefMut for DirtyIndexRange {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}

impl From<DirtyIndexRange> for Range<usize> {
    fn from(value: DirtyIndexRange) -> Self {
        value.0
    }
}

impl DirtyIndexRange {
    /// A range covering exactly the single index `index`.
    pub fn one(index: usize) -> Self {
        Self::range(index, index + 1)
    }

    /// The half-open range `[start, until)`.
    pub fn range(start: usize, until: usize) -> Self {
        Self(start..until)
    }
}

/// A sorted set of pairwise-disjoint dirty ranges. Ranges inserted within
/// `index_merge_distance` of an existing range are merged, so syncing issues
/// fewer, slightly larger GPU writes.
pub struct DirtyIndexSet {
    // Kept sorted by `start`; no two ranges overlap or lie within
    // `index_merge_distance` of each other.
    inner: Vec<DirtyIndexRange>,
    // Gaps of at most this size between ranges are merged away on insert.
    index_merge_distance: usize,
}

impl Default for DirtyIndexSet {
    /// Identical configuration to `new()`.
    ///
    /// The previously derived impl used a merge distance of 0 (the `usize`
    /// default), silently diverging from `new()`'s distance of 1 depending on
    /// which constructor a caller picked.
    fn default() -> Self {
        Self::new()
    }
}

impl DirtyIndexSet {
    /// Creates an empty set with the default merge distance.
    pub fn new() -> Self {
        const MERGE_DISTANCE: usize = 1; // TODO: Maybe increase this
        Self {
            inner: vec![],
            index_merge_distance: MERGE_DISTANCE,
        }
    }

    /// Removes all recorded ranges.
    pub fn clear(&mut self) {
        self.inner.clear();
    }

    /// True when no dirty ranges are recorded.
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }

    /// Iterates the ranges in ascending order.
    pub fn iter(&self) -> impl Iterator<Item = &Range<usize>> {
        self.inner.iter().map(|range| &range.0)
    }

    /// Clones the ranges into a plain `Vec`.
    pub fn to_vec(&self) -> Vec<Range<usize>> {
        self.iter().cloned().collect()
    }

    /// Inserts `range`, merging it with every existing range it overlaps or
    /// lies within `index_merge_distance` of. The set stays sorted and
    /// pairwise disjoint.
    pub fn insert(&mut self, range: impl Into<Range<usize>>) {
        let mut range = range.into();
        // Index of the first existing range whose end reaches the new
        // range's start.
        let mut insert_idx = self
            .inner
            .partition_point(|existing_range| existing_range.end < range.start);
        // Absorb the predecessor when it is close enough to merge.
        if insert_idx > 0
            && self.inner[insert_idx - 1].end + self.index_merge_distance >= range.start
        {
            insert_idx -= 1;
            range.start = self.inner[insert_idx].start.min(range.start);
            range.end = self.inner[insert_idx].end.max(range.end);
        }
        // Count how many following ranges the (growing) new range swallows,
        // then remove them with a single O(n) `drain` instead of calling
        // `Vec::remove` once per merged range (which was O(n^2)).
        let mut merge_end = insert_idx;
        while merge_end < self.inner.len()
            && self.inner[merge_end].start <= range.end + self.index_merge_distance
        {
            range.start = range.start.min(self.inner[merge_end].start);
            range.end = range.end.max(self.inner[merge_end].end);
            merge_end += 1;
        }
        self.inner.drain(insert_idx..merge_end);
        // Insert the fully merged range at its sorted position.
        self.inner.insert(insert_idx, DirtyIndexRange(range));
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Ranges farther apart than the merge distance stay as two entries.
    #[test]
    fn test_merged_insert_no_overlap() {
        let mut set = DirtyIndexSet::new();
        set.insert(DirtyIndexRange::range(0, 5));
        set.insert(DirtyIndexRange::range(10, 15));
        assert_eq!(
            set.inner,
            vec![DirtyIndexRange::range(0, 5), DirtyIndexRange::range(10, 15)]
        );
    }

    // Overlapping ranges collapse into a single covering range.
    #[test]
    fn test_merged_insert_with_overlap() {
        let mut set = DirtyIndexSet::new();
        set.insert(DirtyIndexRange::range(0, 5));
        set.insert(DirtyIndexRange::range(4, 10));
        assert_eq!(set.inner, vec![DirtyIndexRange::range(0, 10)]);
    }

    // A range bridging two existing ranges merges all three.
    #[test]
    fn test_merged_insert_with_multiple_overlaps() {
        let mut set = DirtyIndexSet::new();
        set.insert(DirtyIndexRange::range(0, 5));
        set.insert(DirtyIndexRange::range(10, 15));
        set.insert(DirtyIndexRange::range(4, 12));
        assert_eq!(set.inner, vec![DirtyIndexRange::range(0, 15)]);
    }

    // A range fully contained in an existing one changes nothing.
    #[test]
    fn test_merged_insert_inside_existing() {
        let mut set = DirtyIndexSet::new();
        set.insert(DirtyIndexRange::range(0, 10));
        set.insert(DirtyIndexRange::range(3, 7));
        assert_eq!(set.inner, vec![DirtyIndexRange::range(0, 10)]);
    }

    // Re-inserting an identical range is a no-op.
    #[test]
    fn test_merged_insert_same_as_existing() {
        let mut set = DirtyIndexSet::new();
        set.insert(DirtyIndexRange::range(0, 10));
        set.insert(DirtyIndexRange::range(0, 10));
        assert_eq!(set.inner, vec![DirtyIndexRange::range(0, 10)]);
    }

    // A single index already covered by a range is absorbed.
    #[test]
    fn test_merged_insert_single_inside_existing() {
        let mut set = DirtyIndexSet::new();
        set.insert(DirtyIndexRange::range(0, 10));
        set.insert(DirtyIndexRange::one(5));
        assert_eq!(set.inner, vec![DirtyIndexRange::range(0, 10)]);
    }

    // Adjacent/connecting ranges (including an empty 16..16 one) end up as
    // one contiguous range thanks to the merge distance of 1.
    #[test]
    fn test_merged_insert_connects_ranges() {
        let mut set = DirtyIndexSet::new();
        set.insert(DirtyIndexRange::range(0, 5));
        set.insert(DirtyIndexRange::range(10, 15));
        set.insert(DirtyIndexRange::range(5, 10));
        set.insert(DirtyIndexRange::range(16, 16));
        set.insert(DirtyIndexRange::range(16, 20));
        assert_eq!(set.inner, vec![DirtyIndexRange::range(0, 20)]);
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::mem::size_of; | |
use wgpu::util::{BufferInitDescriptor, DeviceExt}; | |
use wgpu::{Buffer, BufferAddress, BufferDescriptor, BufferUsages, Device, Queue}; | |
/// Typed wrapper around a raw wgpu `Buffer` holding elements of `T`.
pub struct GpuBuffer<T>
where
    T: Sized + Clone + bytemuck::Pod + bytemuck::Zeroable,
{
    // Underlying GPU-side allocation.
    buffer: Buffer,
    // Kept so reallocations (`resized`) can reuse the same debug label.
    label: String,
    // Records the element type without storing any `T`.
    phantom: std::marker::PhantomData<T>,
}
impl<T> GpuBuffer<T> | |
where | |
T: Sized + Clone + bytemuck::Pod + bytemuck::Zeroable, | |
{ | |
pub fn new(device: &Device, capacity: usize, usage: BufferUsages, label: &str) -> Self { | |
let buffer = device.create_buffer(&BufferDescriptor { | |
label: Some(label), | |
size: (size_of::<T>() * capacity).next_multiple_of(4) as BufferAddress, | |
usage, | |
mapped_at_creation: false, | |
}); | |
Self { | |
buffer, | |
label: label.into(), | |
phantom: Default::default(), | |
} | |
} | |
pub fn new_with_data( | |
device: &Device, | |
usage: BufferUsages, | |
label: &str, | |
data: &[T], | |
) -> Self { | |
let buffer = device.create_buffer_init(&BufferInitDescriptor { | |
label: Some(label), | |
contents: bytemuck::cast_slice(data), | |
usage, | |
}); | |
Self { | |
buffer, | |
label: label.into(), | |
phantom: Default::default(), | |
} | |
} | |
pub fn total_data_capacity(&self) -> usize { | |
self.buffer.size() as usize / size_of::<T>() | |
} | |
pub fn inner(&self) -> &Buffer { | |
&self.buffer | |
} | |
pub fn resized(&self, device: &Device, requested_length: usize) -> Option<Self> { | |
if requested_length > self.total_data_capacity() { | |
let new_buffer = | |
GpuBuffer::new(device, requested_length, self.buffer.usage(), &self.label); | |
Some(new_buffer) | |
} else { | |
None | |
} | |
} | |
pub fn write(&self, queue: &Queue, index: usize, data: &[T]) { | |
let start_byte_offset = (index * size_of::<T>()) as BufferAddress; | |
let data_slice = bytemuck::cast_slice(data); | |
queue.write_buffer(&self.buffer, start_byte_offset, data_slice); | |
} | |
pub fn read(&self, device: &Device, queue: &Queue) -> Vec<T> { | |
// Create a buffer that will receive the data from the GPU | |
let staging_buffer = device.create_buffer(&BufferDescriptor { | |
label: Some("staging buffer"), | |
size: self.buffer.size(), | |
usage: BufferUsages::COPY_DST | BufferUsages::MAP_READ, | |
mapped_at_creation: false, | |
}); | |
let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor { | |
label: Some("extract buffer encoder"), | |
}); | |
// Copy data from the GPU buffer to the buffer_copy | |
encoder.copy_buffer_to_buffer( | |
&self.buffer, | |
0, | |
&staging_buffer, | |
0, | |
self.buffer.size(), | |
); | |
// Submit the command encoder | |
queue.submit(Some(encoder.finish())); | |
let buffer_slice = staging_buffer.slice(..); | |
// Map the buffer_copy into memory | |
buffer_slice.map_async(wgpu::MapMode::Read, |_| ()); | |
// Make sure the GPU executes the command before we try to read the data | |
device.poll(wgpu::Maintain::Wait).panic_on_timeout(); | |
let data_view = buffer_slice.get_mapped_range().to_vec(); | |
let data = bytemuck::cast_slice(&data_view).to_vec(); | |
// Unmap and drop the buffer_copy | |
staging_buffer.unmap(); | |
data | |
} | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::hash::Hash; | |
use std::ops::RangeBounds; | |
use derive_more::{Index, IndexMut}; | |
use wgpu::naga::FastIndexMap; | |
use crate::wgpu_resources::buffer::DirtyIndexRange; | |
/// One variable-length data block plus the running total of elements up to
/// and including it, which locates the block inside a packed GPU buffer.
#[derive(Debug, Clone, PartialEq)]
pub(super) struct DataEntry<T: Clone + bytemuck::Pod + bytemuck::Zeroable> {
    // The block's elements.
    pub data: Vec<T>,
    // Total element count of all entries up to and including this one.
    pub cumulative_data_len: usize,
}
impl<T: Clone + bytemuck::Pod + bytemuck::Zeroable> DataEntry<T> { | |
#[inline] | |
pub fn index(&self) -> usize { | |
self.cumulative_data_len - self.data.len() | |
} | |
#[inline] | |
pub fn byte_offset(&self) -> usize { | |
self.index() * std::mem::size_of::<T>() | |
} | |
} | |
/// Insertion-ordered map from keys to variable-length data blocks, tracking
/// each block's cumulative element count so all blocks can be packed
/// contiguously into one GPU buffer. `Index`/`IndexMut` (derive_more) forward
/// positional indexing to the inner map.
#[derive(Index, IndexMut, Debug)]
pub(super) struct GpuDataMap<K: Hash + Eq, V: Clone + bytemuck::Pod + bytemuck::Zeroable>(
    // wgpu re-exports naga's FastIndexMap (an IndexMap with a fast hasher).
    FastIndexMap<K, DataEntry<V>>,
);
impl<K: Hash + Eq, V: Clone + bytemuck::Pod + bytemuck::Zeroable> GpuDataMap<K, V> { | |
pub fn new() -> Self { | |
Self(FastIndexMap::default()) | |
} | |
pub fn new_with_data(data: Vec<(K, Vec<V>)>) -> Self { | |
let mut map = Self::new(); | |
for (key, data) in data { | |
map.create_or_update(key, data); | |
} | |
map | |
} | |
fn cumulative_data_len_at_index(&self, index: usize) -> usize { | |
self | |
.0 | |
.get_index(index) | |
.map(|(_, entry)| entry.cumulative_data_len) | |
.unwrap_or(0) | |
} | |
fn recalculate_cumulative_data_len_from_index(&mut self, start_index: usize) { | |
let mut cumulative_data_len = match start_index.checked_sub(1) { | |
Some(index) => self.cumulative_data_len_at_index(index), | |
None => 0, | |
}; | |
for (_, entry) in self.0.iter_mut() { | |
cumulative_data_len += entry.data.len(); | |
entry.cumulative_data_len = cumulative_data_len; | |
} | |
} | |
fn recompute_cumulative_data_and_get_dirty_index_range( | |
&mut self, | |
index: usize, | |
old_data_len: usize, | |
new_data_len: usize, | |
) -> DirtyIndexRange { | |
if old_data_len != new_data_len { | |
// update the dirty indices to entire range since the current. | |
self.recalculate_cumulative_data_len_from_index(index); | |
DirtyIndexRange::range(index, self.0.len()) | |
} else { | |
// update the dirty indices to the single index. | |
DirtyIndexRange::one(index) | |
} | |
} | |
pub fn create_or_update(&mut self, key: K, data: Vec<V>) -> DirtyIndexRange { | |
if let Some((index, _, entry)) = self.0.get_full_mut(&key) { | |
let old_data_len = entry.data.len(); | |
let new_data_len = data.len(); | |
entry.data = data; | |
self.recompute_cumulative_data_and_get_dirty_index_range( | |
index, | |
old_data_len, | |
new_data_len, | |
) | |
} else { | |
let cumulative_data_len = self.total_data_len() + data.len(); | |
self.0.insert( | |
key, | |
DataEntry { | |
data, | |
cumulative_data_len, | |
}, | |
); | |
DirtyIndexRange::one(self.len() - 1) | |
} | |
} | |
pub fn delete(&mut self, key: &K) -> Option<DirtyIndexRange> { | |
if let Some((index, _, entry)) = self.0.swap_remove_full(key) { | |
let (_, replaced_at_element) = self.0.get_index(index)?; | |
Some(self.recompute_cumulative_data_and_get_dirty_index_range( | |
index, | |
entry.data.len(), | |
replaced_at_element.data.len(), | |
)) | |
} else { | |
None | |
} | |
} | |
#[inline] | |
pub fn len(&self) -> usize { | |
self.0.len() | |
} | |
pub fn total_data_len(&self) -> usize { | |
self | |
.0 | |
.last() | |
.map(|(_, entry)| entry.cumulative_data_len) | |
.unwrap_or(0) | |
} | |
pub fn total_byte_len(&self) -> usize { | |
self.total_data_len() * std::mem::size_of::<V>() | |
} | |
pub fn write_to_contiguous_buffer<R: RangeBounds<usize>>( | |
&self, | |
range: R, | |
buffer: &mut Vec<V>, | |
) { | |
let Some(slice) = self.0.get_range(range) else { | |
return; | |
}; | |
for (_, entry) in slice { | |
buffer.extend(&entry.data); | |
} | |
} | |
pub fn collect<B: FromIterator<(K, DataEntry<V>)>>(&mut self) -> B | |
where | |
K: Clone, | |
{ | |
self.0.iter().map(|(k, v)| (k.clone(), v.clone())).collect() | |
} | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::mem::size_of; | |
use wgpu::util::{BufferInitDescriptor, DeviceExt}; | |
use wgpu::{Buffer, BufferAddress, BufferUsages, Device, Queue}; | |
use super::DirtyIndexSet; | |
use crate::v_assert; | |
// TODO: Possibly be able to do granular field updates, than the entire uniform | |
// using https://docs.rs/field_types/latest/field_types/ | |
/// A single CPU-shadowed uniform value of type `T` mirrored into a GPU
/// uniform buffer.
pub struct GpuUniform<T: Sized + Clone + bytemuck::Pod + bytemuck::Zeroable> {
    // GPU-side uniform buffer (UNIFORM | COPY_DST).
    buffer: Buffer,
    // CPU-side copy of the uniform contents.
    data: T,
    // Debug label; currently only stored, not re-used (buffer never resizes).
    label: String,
    // Byte ranges of `data` not yet uploaded; see `sync`.
    dirty_offset_set: DirtyIndexSet,
}
impl<T: Sized + Clone + bytemuck::Pod + bytemuck::Zeroable> GpuUniform<T> { | |
pub fn new(device: &Device, label: &str) -> Self | |
where | |
T: Default, | |
{ | |
Self::new_with_data(device, label, Default::default()) | |
} | |
pub fn new_with_data(device: &Device, label: &str, data: T) -> Self { | |
let buffer = device.create_buffer_init(&BufferInitDescriptor { | |
label: Some(label), | |
contents: bytemuck::bytes_of(&data), | |
usage: BufferUsages::UNIFORM | BufferUsages::COPY_DST, | |
}); | |
Self { | |
buffer, | |
data, | |
label: label.into(), | |
dirty_offset_set: DirtyIndexSet::default(), | |
} | |
} | |
pub fn buffer(&self) -> &Buffer { | |
v_assert!(self.dirty_offset_set.is_empty(), "Buffer is dirty, call sync() first"); | |
&self.buffer | |
} | |
// update the buffer entirely | |
pub fn set(&mut self, data: T) { | |
self.data = data; | |
self.dirty_offset_set.insert(0..size_of::<T>()); | |
} | |
pub fn sync(&mut self, queue: &Queue) { | |
if !self.dirty_offset_set.is_empty() { | |
for range in self.dirty_offset_set.iter() { | |
let start = range.start; | |
let until = range.end; | |
let start_byte_offset = start as BufferAddress; | |
let data_bytes = bytemuck::bytes_of(&self.data); | |
queue.write_buffer(&self.buffer, start_byte_offset, &data_bytes[start..until]); | |
} | |
self.dirty_offset_set.clear(); | |
} | |
} | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::fmt::Debug; | |
use std::hash::Hash; | |
use std::mem::size_of; | |
use wgpu::{ | |
BindGroupEntry, BindGroupLayoutEntry, BindingType, Buffer, BufferBindingType, | |
BufferUsages, Device, Queue, ShaderStages, | |
}; | |
use crate::v_assert; | |
use crate::wgpu_resources::buffer::gpu_buffer::GpuBuffer; | |
use crate::wgpu_resources::buffer::{DirtyIndexSet, GpuDataMap}; | |
/// `GpuVariableVec` is a structure that manages a single GPU buffer with variable-length data blocks.
/// It uses a `GpuDataMap` to map keys of type `K` to values of type `V`.
pub struct GpuVariableVec<
    K: Debug + Eq + Hash,
    V: Clone + bytemuck::Pod + bytemuck::Zeroable,
> {
    // CPU-side packed blocks keyed by `K`, with cumulative offsets.
    data: GpuDataMap<K, V>,
    // GPU-side mirror of the packed blocks.
    buffer: GpuBuffer<V>,
    // Entry-index ranges (NOT byte ranges) pending upload; see `sync`.
    dirty_index_set: DirtyIndexSet,
}
impl<K: Debug + Eq + Hash, V: Clone + bytemuck::Pod + bytemuck::Zeroable> | |
GpuVariableVec<K, V> | |
{ | |
pub fn new(device: &Device, capacity: usize, usage: BufferUsages, label: &str) -> Self { | |
Self { | |
data: GpuDataMap::new(), | |
buffer: GpuBuffer::new(device, capacity, usage, label), | |
dirty_index_set: DirtyIndexSet::new(), | |
} | |
} | |
pub fn new_with_data( | |
device: &Device, | |
usage: BufferUsages, | |
label: &str, | |
data: Vec<(K, Vec<V>)>, | |
) -> Self { | |
let data = GpuDataMap::new_with_data(data); | |
let mut staging_buffer_data = vec![]; | |
data.write_to_contiguous_buffer(0..data.len(), &mut staging_buffer_data); | |
Self { | |
data, | |
buffer: GpuBuffer::new_with_data(device, usage, label, &staging_buffer_data), | |
dirty_index_set: DirtyIndexSet::new(), | |
} | |
} | |
pub fn total_data_len(&self) -> usize { | |
self.data.total_data_len() | |
} | |
pub fn buffer(&self) -> &Buffer { | |
v_assert!(self.dirty_index_set.is_empty(), "Buffer is dirty, call sync() first"); | |
&self.buffer.inner() | |
} | |
pub fn create_or_update(&mut self, key: K, data: Vec<V>) { | |
let dirty_index_range = self.data.create_or_update(key, data); | |
self.dirty_index_set.insert(dirty_index_range); | |
} | |
fn delete(&mut self, key: &K) { | |
if let Some(dirty_index_range) = self.data.delete(key) { | |
self.dirty_index_set.insert(dirty_index_range); | |
} else { | |
tracing::warn!("Key {:?} not found in buffer data map", key); | |
}; | |
} | |
fn realloc_if_necessary(&mut self, device: &Device) -> bool { | |
if self.data.total_data_len() > self.buffer.total_data_capacity() { | |
self.buffer = self | |
.buffer | |
.resized(device, self.data.total_data_len()) | |
.unwrap(); | |
true | |
} else { | |
false | |
} | |
} | |
pub fn bind_group_layout_entry(binding: u32) -> BindGroupLayoutEntry { | |
BindGroupLayoutEntry { | |
binding, | |
visibility: ShaderStages::VERTEX | ShaderStages::FRAGMENT, | |
ty: BindingType::Buffer { | |
ty: BufferBindingType::Storage { read_only: true }, | |
has_dynamic_offset: false, | |
min_binding_size: Some(std::num::NonZeroU64::new(size_of::<V>() as u64).unwrap()), | |
}, | |
count: None, | |
} | |
} | |
pub fn bind_group_layout(&self, binding: u32) -> BindGroupEntry { | |
BindGroupEntry { | |
binding, | |
resource: self.buffer.inner().as_entire_binding(), | |
} | |
} | |
pub fn sync(&mut self, device: &Device, queue: &Queue) -> bool { | |
let mut realloc = false; | |
if !self.dirty_index_set.is_empty() { | |
if self.realloc_if_necessary(device) { | |
// If we had to reallocate, we need to re-sync the entire buffer | |
self.dirty_index_set.insert(0..self.data.len()); | |
realloc = true; | |
} | |
let mut staging_buffer_data: Vec<V> = vec![]; | |
// Populate staging buffer with updated data and write to GPU: | |
for range in self.dirty_index_set.iter() { | |
let start = range.start; | |
let until = range.end.min(self.data.len()); // cap until to the length of the data_map | |
self | |
.data | |
.write_to_contiguous_buffer(start..until, &mut staging_buffer_data); | |
let entry = &self.data[start]; | |
self | |
.buffer | |
.write(queue, entry.index(), &staging_buffer_data); | |
staging_buffer_data.clear(); | |
} | |
self.dirty_index_set.clear(); // Clear dirty flags after syncing | |
} | |
realloc | |
} | |
// only useful when shader writes to this buffer and we need to read back | |
// or in the case of test_util | |
pub fn extract_buffer_to_host(&self, device: &Device, queue: &Queue) -> Vec<V> { | |
self.buffer.read(device, queue) | |
} | |
} | |
#[cfg(all(test, feature = "test_util"))]
mod tests {
    use test_context::test_context;
    use wgpu::BufferUsages;

    use crate::test_util::GpuTestContext;
    use crate::wgpu_resources::buffer::gpu_variable_vec::GpuVariableVec;
    use crate::wgpu_resources::buffer::DataEntry;

    // Two blocks packed back-to-back end up contiguous on the GPU.
    #[test_context(GpuTestContext)]
    #[test]
    fn test_insert(ctx: &mut GpuTestContext) {
        let mut manager: GpuVariableVec<i32, u16> = GpuVariableVec::new(
            &ctx.device,
            4,
            BufferUsages::STORAGE | BufferUsages::COPY_DST | BufferUsages::COPY_SRC,
            "test",
        );
        manager.create_or_update(0, vec![1, 1]);
        manager.create_or_update(1, vec![2, 2]);
        manager.sync(&ctx.device, &ctx.queue);
        let buffer = manager.extract_buffer_to_host(&ctx.device, &ctx.queue);
        assert_eq!(buffer, vec![1, 1, 2, 2]);
    }

    // Exceeding the initial capacity (2 elements) forces a reallocation and
    // a full rewrite during sync.
    #[test_context(GpuTestContext)]
    #[test]
    fn test_insert_realloc(ctx: &mut GpuTestContext) {
        let mut manager: GpuVariableVec<i32, u16> = GpuVariableVec::new(
            &ctx.device,
            2,
            BufferUsages::STORAGE | BufferUsages::COPY_DST | BufferUsages::COPY_SRC,
            "test",
        );
        manager.create_or_update(0, vec![1, 1]);
        manager.create_or_update(1, vec![2, 2]);
        manager.create_or_update(2, vec![3, 3]);
        manager.sync(&ctx.device, &ctx.queue);
        let buffer = manager.extract_buffer_to_host(&ctx.device, &ctx.queue);
        assert_eq!(buffer, vec![3, 3, 2, 2, 0, 0, 0, 0]);
    }

    // Updating an existing key in place rewrites only its block; untouched
    // capacity beyond the packed data stays zeroed.
    #[test_context(GpuTestContext)]
    #[test]
    fn test_update(ctx: &mut GpuTestContext) {
        let mut manager: GpuVariableVec<i32, u8> = GpuVariableVec::new(
            &ctx.device,
            8,
            BufferUsages::STORAGE | BufferUsages::COPY_DST | BufferUsages::COPY_SRC,
            "test",
        );
        manager.create_or_update(0, vec![1, 1]);
        manager.create_or_update(1, vec![2, 2]);
        manager.create_or_update(0, vec![3, 3]);
        manager.sync(&ctx.device, &ctx.queue);
        let buffer = manager.extract_buffer_to_host(&ctx.device, &ctx.queue);
        assert_eq!(buffer, vec![3, 3, 2, 2, 0, 0, 0, 0]);
    }

    // Deleting a key swap-removes it from the map; the survivor's block is
    // re-uploaded at the front and stale bytes beyond it are ignored.
    #[test_context(GpuTestContext)]
    #[test]
    fn test_delete(ctx: &mut GpuTestContext) {
        let mut manager: GpuVariableVec<i32, u8> = GpuVariableVec::new(
            &ctx.device,
            8,
            BufferUsages::STORAGE | BufferUsages::COPY_DST | BufferUsages::COPY_SRC,
            "test",
        );
        manager.create_or_update(0, vec![1, 1]);
        manager.create_or_update(1, vec![2, 2, 2, 2]);
        manager.delete(&0);
        assert_eq!(
            manager.data.collect::<Vec<_>>(),
            vec![(
                1,
                DataEntry {
                    data: vec![2, 2, 2, 2],
                    cumulative_data_len: 4,
                }
            )]
        );
        assert_eq!(manager.dirty_index_set.to_vec(), vec![0..2]);
        manager.sync(&ctx.device, &ctx.queue);
        let buffer = manager.extract_buffer_to_host(&ctx.device, &ctx.queue);
        assert_eq!(buffer, vec![2, 2, 2, 2, 0, 0, 0, 0]);
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::ops::Range; | |
use wgpu::{Buffer, BufferUsages, Device, Queue}; | |
use super::DirtyIndexSet; | |
use crate::v_assert; | |
use crate::wgpu_resources::buffer::gpu_buffer::GpuBuffer; | |
// TODO: Replace with https://github.com/teoxoy/encase | |
// Once it has better performance: https://github.com/teoxoy/encase/issues/53 | |
/// CPU-shadowed growable vector mirrored into a single GPU buffer.
pub struct GpuVec<T: Sized + Clone + bytemuck::Pod + bytemuck::Zeroable> {
    // GPU-side mirror of `data`.
    buffer: GpuBuffer<T>,
    // CPU-side copy of the elements.
    data: Vec<T>,
    // Element-index ranges pending upload; see `sync`.
    dirty_index_set: DirtyIndexSet,
}
impl<T: Sized + Clone + bytemuck::Pod + bytemuck::Zeroable> GpuVec<T> { | |
pub fn new(device: &Device, capacity: usize, usage: BufferUsages, label: &str) -> Self { | |
let buffer = GpuBuffer::new(device, capacity, usage, label); | |
Self { | |
buffer, | |
data: vec![], | |
dirty_index_set: Default::default(), | |
} | |
} | |
pub fn new_with_data( | |
device: &Device, | |
usage: BufferUsages, | |
label: &str, | |
data: Vec<T>, | |
) -> Self { | |
let buffer = GpuBuffer::new_with_data(device, usage, label, data.as_slice()); | |
Self { | |
buffer, | |
data, | |
dirty_index_set: Default::default(), | |
} | |
} | |
pub fn total_data_len(&self) -> usize { | |
self.data.len() | |
} | |
pub fn buffer(&self) -> &Buffer { | |
v_assert!(self.dirty_index_set.is_empty(), "Buffer is dirty, call sync() first"); | |
self.buffer.inner() | |
} | |
fn realloc_if_necessary(&mut self, device: &Device) -> bool { | |
if self.data.len() > self.buffer.total_data_capacity() { | |
self.buffer = self.buffer.resized(device, self.data.len()).unwrap(); | |
true | |
} else { | |
false | |
} | |
} | |
pub fn insert(&mut self, data: impl IntoIterator<Item = T> + ExactSizeIterator) { | |
if !data.is_empty() { | |
let offset = self.data.len(); | |
self.data.extend(data); | |
self.dirty_index_set.insert(offset..self.data.len()); | |
} | |
} | |
pub fn update_at( | |
&mut self, | |
index: usize, | |
data: impl IntoIterator<Item = T> + ExactSizeIterator, | |
) { | |
if !data.is_empty() { | |
let length = data.len(); | |
self.data.splice(index..index + length, data); | |
self.dirty_index_set.insert(index..self.data.len()); | |
} | |
} | |
pub fn delete(&mut self, range: Range<usize>) { | |
assert!(range.end <= self.data.len()); | |
let offset = range.start; | |
let length = range.len(); | |
self.data.splice(offset..offset + length, vec![]); | |
self.dirty_index_set.insert(offset..self.data.len()); | |
} | |
pub fn sync(&mut self, device: &Device, queue: &Queue) -> bool { | |
let mut realloc = false; | |
if !self.dirty_index_set.is_empty() { | |
if self.realloc_if_necessary(device) { | |
self.dirty_index_set.insert(0..self.data.len()); | |
realloc = true; | |
} | |
for range in self.dirty_index_set.iter() { | |
let start = range.start; | |
let until = range.end.min(self.data.len()); // cap until to the length of the data | |
self.buffer.write(queue, start, &self.data[start..until]); | |
} | |
self.dirty_index_set.clear(); | |
} | |
realloc | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment