Skip to content

Instantly share code, notes, and snippets.

@YouJiacheng
Last active January 3, 2025 19:42
Show Gist options
  • Save YouJiacheng/44540bc155248252283e967a894f5f4a to your computer and use it in GitHub Desktop.
Save YouJiacheng/44540bc155248252283e967a894f5f4a to your computer and use it in GitHub Desktop.
def abs_cdf(t: Tensor, thresholds: list[float]):
    """Empirical CDF of ``|t|`` evaluated at each threshold.

    Returns a floating-point tensor of length ``len(thresholds) + 1``. Entry ``k``
    (for ``k < len(thresholds)``) is the fraction of elements with
    ``|x| <= thresholds[k]``. The final entry counts every element, so it is 1 for a
    non-empty ``t``; an empty ``t`` yields NaN entries (0 / 0).

    ``thresholds`` must be sorted ascending, since ``torch.bucketize`` requires
    monotonically increasing boundaries. They are converted to ``|t|``'s dtype and
    device via ``new_tensor``.
    """
    magnitudes = t.abs()
    boundaries = magnitudes.new_tensor(thresholds)
    # With right=False, bucketize maps each value to the number of boundaries strictly
    # below it, i.e. sum(x > v for v in thresholds).
    buckets = torch.bucketize(magnitudes, boundaries, out_int32=True).flatten()
    # minlength guarantees one slot per bucket even when the top buckets are empty.
    per_bucket = torch.bincount(buckets, minlength=len(thresholds) + 1)
    running_total = torch.cumsum(per_bucket, dim=0)
    return running_total / magnitudes.numel()
# reference: https://github.com/pytorch/pytorch/issues/69519#issuecomment-2500366519
def histogram(input: Tensor, bins: Tensor, *, weight: Optional[Tensor] = None, density: bool = False):
    """Histogram of ``input`` over explicit bin edges, following ``torch.histogram``.

    Built from ``torch.bucketize`` + ``torch.bincount`` so it can run where
    ``torch.histogram`` is unavailable.

    Bin convention (same as ``torch.histogram`` / ``numpy.histogram``): for edges
    ``b_0 < b_1 < ... < b_K`` the ``K`` bins are ``[b_i, b_{i+1})``, except the last,
    which is closed: ``[b_{K-1}, b_K]``. Values outside ``[b_0, b_K]`` are ignored.

    Args:
        input: values to histogram; any shape (flattened before counting).
        bins: 1-D, monotonically increasing tensor of ``K + 1`` edges (``K >= 1``).
        weight: optional per-element weights with the same number of elements as
            ``input``; each element contributes its weight instead of 1.
        density: if True, normalize so the histogram integrates to 1 over the edges.

    Returns:
        ``(hist, bins)``. ``hist`` has ``K`` entries: counts (int64) when ``weight`` is
        None, summed weights otherwise, or the normalized density when ``density`` is
        True. ``bins`` is returned unchanged.
    """
    flat_input = input.reshape(-1)
    flat_weight = None if weight is None else weight.reshape(-1)

    # right=True maps x to i with bins[i-1] <= x < bins[i]: index 0 collects x < b_0,
    # indices 1..K are the left-closed bins, and index K+1 collects x >= b_K.
    bucket_indices = torch.bucketize(flat_input, bins, right=True)
    counts = torch.bincount(bucket_indices, weights=flat_weight, minlength=bins.size(0) + 1)
    counts = counts[1:-1]

    # The last bin is closed on the right: values exactly equal to b_K went to the
    # overflow bucket, so add them back (weighted if weights were given). Exact
    # equality avoids double-counting values just below b_K, and comparing with the
    # 0-d tensor instead of .item() avoids a host-device sync.
    on_right_edge = flat_input == bins[-1]
    if flat_weight is None:
        counts[-1] += on_right_edge.sum()
    else:
        counts[-1] += torch.where(on_right_edge, flat_weight, torch.zeros_like(flat_weight)).sum()

    if not density:
        return counts, bins
    width = bins[1:] - bins[:-1]
    return counts / width / counts.sum(), bins
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment