Batch iterators
import torch


class BatchIter:
    """
    Iterates over (optionally shuffled) mini-batches of rows of the given feature tensors.

    tensors: feature tensors (each with shape: num_instances x *)
    """
    def __init__(self, *tensors, batch_size, shuffle=True):
        self.tensors = tensors

        device = tensors[0].device
        n = tensors[0].size(0)
        if shuffle:
            idxs = torch.randperm(n, device=device)
        else:
            idxs = torch.arange(n, device=device)
        self.idxs = idxs.split(batch_size)

    def __len__(self):
        return len(self.idxs)

    def __iter__(self):
        tensors = self.tensors
        for batch_idxs in self.idxs:
            yield tuple(x[batch_idxs, ...] for x in tensors)


def lexsort(*keys, dim=-1):
    ''' Computes the lexicographical sorting order of the given tensors, starting from the
    least significant to the most significant ones.
    '''
    if len(keys) == 0:
        raise ValueError(f"Must have at least 1 key, but {len(keys)=}.")

    idx = keys[0].argsort(dim=dim, stable=True)
    for k in keys[1:]:
        idx = idx.gather(dim, k.gather(dim, idx).argsort(dim=dim, stable=True))
    return idx
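
# A small worked example of lexsort (illustrative values, not from the original gist):
# with `a` as the least-significant key and `b` as the most-significant key, the result
# orders primarily by `b` and breaks ties by `a`:
#   a = torch.tensor([1, 0, 1, 0])
#   b = torch.tensor([1, 1, 0, 0])
#   lexsort(a, b)   # -> tensor([3, 2, 1, 0]): element 3 has (b=0, a=0), element 2 has (b=0, a=1), ...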


def view_as_bytes(x):
    '''
    Decomposes the given tensor into its constituent bytes. The byte order depends on the
    computer architecture, and is not consistent. So the result isn't useful for persistence,
    just for in-memory computations. The bytes are returned as a tuple of tensors, one tensor
    per byte position.
    '''
    element_bytes = x.dtype.itemsize
    bytes_tensor = x.view(torch.uint8).view(x.shape + (element_bytes,))
    return bytes_tensor.unbind(dim=-1)


def fnv_hash(tensor):
    """
    Computes the FNV hash for each component of a PyTorch tensor of integers.

    Args:
        tensor: torch.Tensor - the tensor for which we compute the element-wise FNV hash

    Returns:
        A PyTorch tensor of the same size and dtype as the input tensor, containing the FNV hash for each element.
    """
    # Define the FNV prime and offset basis
    FNV_PRIME = torch.tensor(0x01000193, dtype=torch.uint32)
    FNV_OFFSET = torch.tensor(0x811c9dc5, dtype=torch.uint32)

    # Initialize the hash value with the FNV offset basis (same size and dtype as tensor)
    hash_value = torch.full_like(tensor, FNV_OFFSET)
    for byte in view_as_bytes(tensor):
        hash_value = torch.bitwise_xor(hash_value * FNV_PRIME, byte)

    # No need to reshape, output already has the same size and dtype as input
    return hash_value


def group_idx(group_id):
    '''
    Given a sequence of group ids, each group given in consecutive order, compute the index
    where each group begins. The result starts with 0 and ends with the total number of
    elements, so that consecutive pairs are the endpoints of each group.
    '''
    values, counts = group_id.unique_consecutive(return_counts=True)
    idx = torch.cumsum(counts, dim=-1)
    return torch.nn.functional.pad(idx, (1, 0))
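
# For example (illustrative values): three consecutive groups of sizes 2, 3 and 1:
#   group_idx(torch.tensor([5, 5, 2, 2, 2, 7]))   # -> tensor([0, 2, 5, 6])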


def batch_endpoint_indices(group_idx, batch_size):
    '''
    Given a tensor of indices where each group begins, and a batch size - compute
    the start and end points of each mini-batch, each consisting of the specified
    number of groups
    '''
    # pad the group start indices so that the number of groups (len(group_idx) - 1)
    # becomes a multiple of batch_size
    num_groups = len(group_idx) - 1
    padding_size = -num_groups % batch_size
    if padding_size > 0:
        padding = group_idx[-1].expand(padding_size)
        group_idx = torch.cat((group_idx, padding), dim=-1)

    # extract start and end points
    start_points = group_idx[0:-1:batch_size]
    end_points = group_idx[batch_size::batch_size]

    # return them as lists, so we can iterate over them
    return start_points.tolist(), end_points.tolist()
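
# For example (illustrative values): with group start indices [0, 2, 5, 6] (three groups)
# and batch_size=2, the first mini-batch spans rows 0..5 (two groups) and the second spans
# rows 5..6 (the remaining group):
#   batch_endpoint_indices(torch.tensor([0, 2, 5, 6]), batch_size=2)   # -> ([0, 5], [5, 6])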


class GroupBatchIter:
    """
    Iterates over mini-batches of whole groups. Rows sharing a group id (given in consecutive
    order) always land in the same mini-batch, and each mini-batch contains up to `batch_size`
    groups. The group-id tensor is yielded as the first element of every mini-batch.
    """
    def __init__(self, group_id, *tensors, batch_size=1, shuffle=True, shuffle_seed=42):
        self.group_id = group_id
        self.tensors = tensors

        if shuffle:
            # shuffle the order of the groups (but not the rows within each group) by
            # sorting on an FNV hash of the group ids, salted with the shuffle seed
            self.idxs = lexsort(group_id, fnv_hash(group_id + shuffle_seed))
        else:
            self.idxs = torch.arange(len(group_id), device=group_id.device)

        group_start_indices = group_idx(group_id[self.idxs])
        self.batch_start, self.batch_end = batch_endpoint_indices(group_start_indices, batch_size)

    def __len__(self):
        return len(self.batch_start)

    def __iter__(self):
        # we create mini-batches containing both the group-id and the additional tensors
        tensors = (self.group_id,) + self.tensors
        # iterate over batch endpoints, and yield tensors
        for start, end in zip(self.batch_start, self.batch_end):
            batch_idxs = self.idxs[start:end]
            if len(batch_idxs) > 0:
                yield tuple(x[batch_idxs, ...] for x in tensors)
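

# A minimal usage sketch (not part of the original code): the tensors, shapes and batch
# sizes below are made-up illustrative values. BatchIter shuffles and batches rows
# independently, while GroupBatchIter keeps all rows of a group in the same mini-batch
# and yields the group-id tensor as the first element of each batch.
if __name__ == "__main__":
    # plain row-wise mini-batches
    X = torch.randn(10, 3)
    y = torch.randn(10)
    for xb, yb in BatchIter(X, y, batch_size=4):
        print(xb.shape, yb.shape)   # e.g. torch.Size([4, 3]) torch.Size([4])

    # group-wise mini-batches: rows of each group id are consecutive.
    # shuffle=False keeps the run deterministic; shuffle=True reorders whole groups via
    # the FNV hash, which assumes a PyTorch version with torch.uint32 support.
    group_id = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2, 3, 3])
    scores = torch.randn(10)
    loader = GroupBatchIter(group_id, scores, batch_size=2, shuffle=False)
    print(len(loader))   # 2 mini-batches of (up to) 2 groups each
    for gid_batch, score_batch in loader:
        print(gid_batch.tolist(), score_batch.shape)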