Created
June 1, 2020 19:05
-
-
Save dlibenzi/a01647cd204ebfb9e3bc0d1a8cb3eb51 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys | |
import torch | |
import torch_xla | |
import torch_xla.core.functions as xf | |
import torch_xla.core.xla_model as xm | |
import torch_xla.distributed.xla_multiprocessing as xmp | |
def big_mm(w, x, split=1):
    """Multiply a column-sharded weight by `x` using XLA collectives.

    Shapes (following the naming used in the inline comments):
      * `w` is `N x Ko`: this replica's column slice of the full weight.
      * `x` is `WG x M`, where `WG = Ko * WORLD_SIZE`.

    For every column chunk of `x`, the chunks of all replicas are gathered
    along dim 1, this replica multiplies its `Ko` matching rows by `w`, the
    partial products are summed across replicas, and the columns belonging
    to this replica are sliced back out.

    Args:
      w: 2-D tensor of shape `N x Ko`.
      x: 2-D tensor of shape `(Ko * WORLD_SIZE) x M`.
      split: number of equal column chunks `x` is processed in. Must be
        >= 1 and divide `x.size(1)` exactly.

    Returns:
      An `N x M` tensor.

    Raises:
      AssertionError: if the shapes of `w` and `x` are inconsistent with the
        world size, or if `split` is not a positive divisor of `x.size(1)`.
    """
    ordinal = xm.get_ordinal()
    world_size = xm.xrt_world_size()
    ko = w.size(1)
    # Compare against the exact product instead of a floor division, so that a
    # row count which is not a multiple of the world size is rejected rather
    # than having its trailing rows silently ignored.
    assert x.size(0) == ko * world_size, (
        'x has {} rows, expected w.size(1) * world_size = {} * {}'.format(
            x.size(0), ko, world_size))
    assert split >= 1, 'split must be >= 1, got {}'.format(split)

    if split == 1:
        chunks = [x]
    else:
        total_cols = x.size(1)
        assert total_cols % split == 0, (
            'x.size(1)={} is not divisible by split={}'.format(
                total_cols, split))
        chunks = torch.split(x, total_cols // split, dim=1)

    # First row of the gathered input that pairs with this replica's `w`.
    row_start = ordinal * ko
    results = []
    for xs in chunks:
        chunk_cols = xs.size(1)
        # xg = WG x (M * WORLD_SIZE)
        xg = xf.all_gather(xs, dim=1)
        # xgn = Ko x (M * WORLD_SIZE)
        xgn = torch.narrow(xg, 0, row_start, ko)
        # wxg = N x (M * WORLD_SIZE), partial product of this replica
        wxg = w @ xgn
        # rwxg = N x (M * WORLD_SIZE), summed over all replicas
        rwxg = xf.all_reduce(xm.REDUCE_SUM, wxg)
        # wx = N x M, the columns that came from this replica's chunk
        results.append(torch.narrow(rwxg, 1, ordinal * chunk_cols, chunk_cols))
    return torch.cat(results, dim=1) if len(results) > 1 else results[0]
def _mp_fn(index):
    """Per-process entry point: compare big_mm() with a plain matmul.

    Builds a `3 x (KO * WORLD_SIZE)` weight and a `(KO * WORLD_SIZE) x 4`
    input, hands big_mm() the `KO` weight columns selected by `index`, and
    exits the process with status 1 if the result differs from `wg @ x`
    beyond the given tolerances.

    Args:
      index: process index supplied by `xmp.spawn`; used to select this
        process's column slice of the weight.
    """
    device = xm.xla_device()
    if xm.xla_device_hw(device) == 'CPU':
        print(
            'Default device {} does not support replication'.format(device),
            file=sys.stderr)
        return

    torch_xla._XLAC._xla_set_use_full_mat_mul_precision(
        use_full_mat_mul_precision=True)
    # Every process seeds with the same constant, presumably so that all
    # replicas generate identical `wg` and `x` -- TODO confirm.
    torch.manual_seed(11)
    xm.set_rng_state(11)

    cols_per_replica = 2
    total_cols = cols_per_replica * xm.xrt_world_size()
    full_weight = torch.randn(
        3, total_cols, device=device, requires_grad=True)
    local_weight = torch.narrow(
        full_weight, 1, index * cols_per_replica, cols_per_replica)
    inputs = torch.randn(total_cols, 4, device=device)

    # Build both results before moving either one to the CPU.
    reference = full_weight @ inputs
    sharded = big_mm(local_weight, inputs, split=1)
    reference_cpu = reference.cpu()
    sharded_cpu = sharded.cpu()

    if reference_cpu.allclose(sharded_cpu, rtol=1e-04, atol=1e-04):
        return
    print('big_mm() produced wrong result', file=sys.stderr)
    print(
        '[{}]\n{}\n{}'.format(index, reference_cpu, sharded_cpu),
        file=sys.stderr)
    sys.exit(1)
if __name__ == '__main__':
    # Launch _mp_fn in one process per device; nprocs is left at None so the
    # process count is chosen by torch_xla rather than fixed here.
    xmp.spawn(_mp_fn, nprocs=None)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment