Last active
August 1, 2024 19:38
-
-
Save Ryu1845/09d51411f78252f5f98f03ae5527abae to your computer and use it in GitHub Desktop.
ZerO Initialization copied from the original repo (https://github.com/jiaweizzhao/ZerO-initialization/)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math | |
import torch | |
def hadamard(n: int, dtype=torch.int8):
    """Return the ``n x n`` Hadamard matrix built by Sylvester's construction.

    This function is a port of ``scipy.linalg.hadamard``. Every entry is +1 or
    -1 and the rows are mutually orthogonal, so ``H @ H.T == n * I``.

    Args:
        n: Order of the matrix; must be a positive power of two.
        dtype: Dtype of the returned tensor (default ``torch.int8``, the most
            compact type that can hold +1/-1).

    Returns:
        A tensor of shape ``(n, n)`` and dtype ``dtype``.

    Raises:
        ValueError: if ``n`` is not a positive power of two.
    """
    if n < 1:
        lg2 = 0
    else:
        # math.log2 is more accurate than math.log(n, 2), which divides two
        # rounded logarithms; truncating that quotient with int() can yield the
        # wrong exponent and reject a genuine power of two.
        lg2 = int(math.log2(n))
    if 2 ** lg2 != n:
        raise ValueError("n must be an positive integer, and n must be "
                         "a power of 2")
    H = torch.tensor([[1]], dtype=dtype)
    # Sylvester's construction: H_{2k} = [[H_k, H_k], [H_k, -H_k]].
    for _ in range(lg2):
        H = torch.vstack((torch.hstack((H, H)), torch.hstack((H, -H))))
    return H
@torch.no_grad()
def linear_ZerO_init_(tensor: torch.Tensor) -> torch.Tensor:
    """Overwrite a 2-D weight in place with the ZerO initialization.

    Implements Algorithm 1 of the ZerO paper for an ``(m, n)`` matrix:

    * ``m <= n``: the partial identity ``I[:m, :n]``.
    * ``m > n``: the top-left ``m x n`` block of the Sylvester Hadamard matrix
      of order ``p = 2**ceil(log2(m))``, scaled by ``2**(-ceil(log2(m)) / 2)``.

    Works for any floating dtype and device of ``tensor``. The computation is
    done on CPU and the result is copied into ``tensor``.

    Args:
        tensor: 2-D tensor to overwrite (e.g. an ``nn.Linear`` weight).

    Returns:
        ``tensor`` itself, following the ``torch.nn.init`` convention.

    Raises:
        AssertionError: if ``tensor`` is not 2-D.
    """
    assert len(tensor.shape) == 2, "linear_ZerO_init_ only works on 2D tensors"
    m, n = tensor.shape
    if m <= n:
        torch.nn.init.eye_(tensor)
    else:  # m > n
        clog_m = math.ceil(math.log2(m))
        p = 2 ** clog_m
        # eye(m, p) @ H @ eye(p, n) merely selects H[:m, :n] (p >= m > n).
        # Slicing avoids two matmuls and the float32-vs-tensor.dtype mismatch
        # they raised for half/bfloat16/float64 weights. H is built in int8 to
        # keep the p x p intermediate small, and is cast only after slicing.
        block = hadamard(p)[:m, :n].to(tensor.dtype)
        tensor.copy_(block / (2 ** (clog_m / 2)))
    return tensor
@torch.no_grad()
def conv2d_ZerO_init_(tensor: torch.Tensor) -> torch.Tensor:
    """Overwrite a 4-D conv weight in place with the ZerO initialization.

    Source: https://github.com/jiaweizzhao/ZerO-initialization/issues/1#issuecomment-1405598940

    The weight is unpacked as ``(m, n, k, l)``. This presumably follows the
    ``nn.Conv2d`` layout (out channels, in channels, kernel height, kernel
    width), but that is not verified here. Every entry is set to zero except
    the centre tap ``[:, :, k // 2, l // 2]``, which receives the same
    ``(m, n)`` matrix as the 2-D ZerO init:

    * ``m <= n``: the partial identity ``I[:m, :n]``;
    * ``m > n``: ``H[:m, :n] / 2**(ceil(log2(m)) / 2)``, where ``H`` is the
      Sylvester Hadamard matrix of order ``2**ceil(log2(m))``.

    Args:
        tensor: 4-D tensor to overwrite.

    Returns:
        ``tensor`` itself, following the ``torch.nn.init`` convention.

    Raises:
        AssertionError: if ``tensor`` is not 4-D.
    """
    assert len(tensor.shape) == 4, "conv2d_ZerO_init_ only works on 4D tensors"
    m, n, k, l = tensor.shape
    # Centre tap computed per spatial dimension so non-square kernels are
    # handled (using k for both axes misplaces it, or indexes out of range).
    row, col = k // 2, l // 2
    # ZerO uses only zeros plus the centre matrix; without this, whatever the
    # weight held before (e.g. a default random init) would survive.
    tensor.zero_()
    if m <= n:
        center = torch.eye(m, n, dtype=tensor.dtype)
    else:  # m > n
        clog_m = math.ceil(math.log2(m))
        p = 2 ** clog_m
        # eye(m, p) @ H @ eye(p, n) == H[:m, :n]; slicing avoids the
        # float32-vs-tensor.dtype matmul mismatch for non-float32 weights.
        center = hadamard(p)[:m, :n].to(tensor.dtype) / (2 ** (clog_m / 2))
    tensor[:, :, row, col] = center
    return tensor
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment