Batch iterators
import torch


class BatchIter:
    """Iterates over mini-batches of the given tensors.

    tensors: feature tensors (each with shape: num_instances x *)
    """
    def __init__(self, *tensors, batch_size, shuffle=True):
        self.tensors = tensors

        device = tensors[0].device
        n = tensors[0].size(0)
        if shuffle:
            idxs = torch.randperm(n, device=device)
        else:
            idxs = torch.arange(n, device=device)
        self.idxs = idxs.split(batch_size)

    def __len__(self):
        return len(self.idxs)

    def __iter__(self):
        tensors = self.tensors
        for batch_idxs in self.idxs:
            yield tuple(x[batch_idxs, ...] for x in tensors)
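# Usage sketch (illustrative, not part of the original gist): iterating over
# mini-batches of two aligned tensors; X and y below are made-up example data.
X = torch.randn(10, 3)
y = torch.randint(0, 2, (10,))
for batch_x, batch_y in BatchIter(X, y, batch_size=4):
    print(batch_x.shape, batch_y.shape)  # batches of 4, 4, and 2 rows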
def lexsort(*keys, dim=-1):
    '''
    Computes the lexicographic sorting order of the given tensors, from the
    least significant key to the most significant one.
    '''
    if len(keys) == 0:
        raise ValueError(f"Must have at least 1 key, but {len(keys)=}.")

    idx = keys[0].argsort(dim=dim, stable=True)
    for k in keys[1:]:
        idx = idx.gather(dim, k.gather(dim, idx).argsort(dim=dim, stable=True))

    return idx
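# Usage sketch (illustrative): keys are passed from least to most significant,
# so the call below sorts primarily by b, breaking ties by a.
a = torch.tensor([3, 1, 2, 0])
b = torch.tensor([1, 0, 1, 0])
print(lexsort(a, b))  # tensor([3, 1, 2, 0]): b == 0 rows first, each sorted by a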
def view_as_bytes(x):
    '''
    Decomposes the given tensor into its constituent bytes. The byte order depends
    on the machine architecture and is not consistent, so the result isn't useful
    for persistence, only for in-memory computations. The bytes are added as the
    last dimension.
    '''
    element_bytes = x.dtype.itemsize
    bytes_tensor = x.view(torch.uint8).view(x.shape + (element_bytes,))
    return bytes_tensor.unbind(dim=-1)
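# Illustrative check (assumes little-endian hardware, e.g. x86 or ARM): the
# int16 value 258 = 0x0102 decomposes into a low byte 2 and a high byte 1.
t = torch.tensor([258], dtype=torch.int16)
print([byte.item() for byte in view_as_bytes(t)])  # [2, 1] on little-endian machines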
def fnv_hash(tensor):
    """
    Computes the FNV hash of each element of a PyTorch tensor of integers.

    Args:
        tensor: the tensor whose elements we hash element-wise.
    Returns:
        A PyTorch tensor of the same size and dtype as the input, containing the
        FNV hash of each element.
    """
    # Define the FNV prime and offset basis
    FNV_PRIME = torch.tensor(0x01000193, dtype=torch.uint32)
    FNV_OFFSET = torch.tensor(0x811c9dc5, dtype=torch.uint32)

    # Initialize the hash value with the offset basis (same size and dtype as tensor)
    hash_value = torch.full_like(tensor, FNV_OFFSET)

    # Mix in the constituent bytes of each element, one at a time
    for byte in view_as_bytes(tensor):
        hash_value = torch.bitwise_xor(hash_value * FNV_PRIME, byte)

    # No need to reshape - the output already has the same size and dtype as the input
    return hash_value
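# Usage sketch (illustrative): equal inputs hash to equal values, which is what
# makes the hash usable as a pseudo-random yet consistent sort key per group.
ids = torch.tensor([0, 1, 0, 2])
hashes = fnv_hash(ids)
print(hashes[0] == hashes[2])  # tensor(True): identical ids, identical hashes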
def group_idx(group_id):
    '''
    Given a sequence of group ids, where each group occupies one consecutive block,
    compute the boundary indices of the groups: the index where each group begins,
    followed by the total length.
    '''
    _, counts = group_id.unique_consecutive(return_counts=True)
    idx = torch.cumsum(counts, dim=-1)
    return torch.nn.functional.pad(idx, (1, 0))
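# Illustrative example: three consecutive groups of sizes 2, 3 and 1 produce the
# boundaries [0, 2, 5, 6]; entry i is where group i starts, and the last entry
# is the total length.
g = torch.tensor([7, 7, 3, 3, 3, 9])
print(group_idx(g))  # tensor([0, 2, 5, 6])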
def batch_endpoint_indices(group_idx, batch_size):
    '''
    Given a tensor of indices where each group begins, and a batch size, compute
    the start and end points of each mini-batch, each mini-batch consisting of
    the specified number of groups.
    '''
    # pad group_idx with copies of its last element, up to the next multiple of batch_size
    padding_size = batch_size - (len(group_idx) - batch_size * (len(group_idx) // batch_size))
    if padding_size > 0:
        padding = group_idx[-1].expand(padding_size)
        group_idx = torch.cat((group_idx, padding), dim=-1)

    # extract the start and end point of each mini-batch
    start_points = group_idx[0:-1:batch_size]
    end_points = group_idx[batch_size::batch_size]

    # return them as lists, so we can iterate over them
    return start_points.tolist(), end_points.tolist()
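# Illustrative example: boundaries [0, 2, 5, 6] describe three groups; with two
# groups per batch the zipped endpoints give row ranges [0, 5) and [5, 6). The
# extra trailing start point is harmless, since zip() drops it.
starts, ends = batch_endpoint_indices(torch.tensor([0, 2, 5, 6]), batch_size=2)
print(starts, ends)  # [0, 5, 6] [5, 6]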
class GroupBatchIter:
    def __init__(self, group_id, *tensors, batch_size=1, shuffle=True, shuffle_seed=42):
        self.group_id = group_id
        self.tensors = tensors

        if shuffle:
            # shuffle the groups while keeping each group's rows consecutive: hash the
            # (seed-shifted) group ids, then sort primarily by hash, breaking ties by id
            self.idxs = lexsort(group_id, fnv_hash(group_id + shuffle_seed))
        else:
            self.idxs = torch.arange(len(group_id), device=group_id.device)

        group_start_indices = group_idx(group_id[self.idxs])
        self.batch_start, self.batch_end = batch_endpoint_indices(group_start_indices, batch_size)

    def __len__(self):
        return len(self.batch_start)

    def __iter__(self):
        # we create mini-batches containing both the group ids and the additional tensors
        tensors = (self.group_id,) + self.tensors

        # iterate over batch endpoints, and yield the corresponding rows of each tensor
        for start, end in zip(self.batch_start, self.batch_end):
            batch_idxs = self.idxs[start:end]
            if len(batch_idxs) > 0:
                yield tuple(x[batch_idxs, ...] for x in tensors)
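# Usage sketch (illustrative): iterating over query-grouped ranking data, two
# groups per mini-batch, with each group's rows kept together in one batch.
qid = torch.tensor([0, 0, 1, 1, 1, 2])
feats = torch.randn(6, 4)
for batch_qid, batch_feats in GroupBatchIter(qid, feats, batch_size=2):
    print(batch_qid.tolist(), batch_feats.shape)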