Skip to content

Instantly share code, notes, and snippets.

@dlibenzi
Created May 27, 2020 16:00
Show Gist options
Save dlibenzi/cbe12dd7a2db5a403f63c057366d2954 to your computer and use it in GitHub Desktop.
import torch
import torch_xla
import torch_xla.core.xla_builder as xb
import torch_xla.core.xla_op_registry as xor
import torch_xla.core.xla_model as xm
# --- Build a small XLA graph by hand with the xla_builder API ---
# XLA device used further down to place the runtime input tensors t1/t2.
device = xm.xla_device()
# 2x2 random CPU tensor; reused below as the payload of constant `cc`.
tt = torch.randn(2, 2)
# Print the xb shape derived from `tt` (presumably f32[2,2] — dtype/size of tt).
print(xb.tensor_shape(tt))
# Builder that the parameters and constants below are attached to.
b = xb.create_builder('BuilderTest')
# Parameters 0 and 1, both declared as f32 with shape (2, 2).
# NOTE: creation order matters — the index argument fixes parameter numbering.
p0 = xb.mkparam(b, 0, xb.mkshape('f32', (2, 2)))
p1 = xb.mkparam(b, 1, xb.mkshape('f32', (2, 2)))
# Inspect the shape object attached to a parameter op.
sh = p1.shape()
print(sh)
# Arithmetic on builder ops (presumably element-wise, as both are f32[2,2]).
a = p0 + p1
m = a * p1
# Flatten m to a rank-1 shape (4,); `dimensions` presumably gives the order in
# which source dimensions are collapsed (XLA Reshape semantics) — confirm.
r = m.reshape((4, ), dimensions=(0, 1))
# Embed CPU tensors as builder constants: `cc` from the 2x2 tt, `xcc` a scalar 1.2.
cc = xb.Op.constant(b, tt)
xcc = xb.Op.constant(b, torch.tensor(1.2))
# The scalar constant is broadcast to (2, 2) explicitly before being combined
# with the 2x2 ops.
cv = cc - a + xcc.broadcast(sizes=(2, 2))
def addme(a, b, **kwargs):
    """Return ``a + b``.

    Builder callback handed to ``xb.create_computation``. This script uses it
    both as the scalar f32 reducer ``cxz`` (for ``cv.reduce``) and as the body
    of the ``'xaddme'`` computation over ``cv``-shaped operands. Only the
    ``+`` operator is required of the arguments, so it works on builder ops as
    well as on plain Python values.

    Args:
      a: Left operand.
      b: Right operand. Inside this function it shadows the module-level
        builder ``b``.
      **kwargs: Extra keyword arguments a caller may forward; ignored.

    Returns:
      The sum ``a + b``.
    """
    return a + b
# Scalar reducer computation: addme over two f32 scalars (shape ()).
cxz = xb.create_computation('addme', addme, (xb.mkshape('f32', ()), xb.mkshape('f32', ())))
# Initial value for the reduction below.
fzero = xb.Op.constant(b, torch.tensor(0.0, dtype=torch.float))
# Reduce cv over dimensions (0, 1) with cxz starting from fzero; since cxz adds
# and fzero is 0.0, this presumably yields the sum of all elements of cv.
rr = cv.reduce(fzero, cxz, (0, 1))
# Same addme callback, but built over two operands shaped like cv, then invoked
# as a sub-computation on (cv, m) — presumably producing cv + m.
xxz = xb.create_computation('xaddme', addme, (cv.shape(), cv.shape()))
za = xb.Op.call(xxz, (cv, m))
# Pad cv using xcc as the padding value. Each per-dimension triple is presumably
# (low, high, interior) per XLA Pad semantics, which would make pc 4x4 — confirm.
pc = cv.pad(xcc, config=[[1, 1, 0], [1, 1, 0]])
# Static slice [0, 1) along dimension 0 (presumably the first row of cv).
q = cv.slice_in_dim(start_index=0, limit_index=1, dimno=0)
# Dynamic slice: the start indices are builder ops (int32 scalar constants),
# not Python ints; extracts a (1, 1) window starting at (0, 0).
zero = xb.Op.constant(b, torch.tensor(0, dtype=torch.int32))
qq = cv.dynamic_slice(start_indices=(zero, zero), slice_sizes=(1, 1))
# Gather every op built so far (plus a sin of r) into one tuple on builder b,
# so they all become outputs of the computation built below.
t = xb.Op.tuple((za, qq, rr, p0, p1, a, q, m, pc, r, cv, r.sin()), builder=b)
print(t)
print(a, m)
# Inspect the builder that owns op m.
ob = m.builder()
print(ob)
# Build the computation 'TestComputation' from tuple t and dump its HLO text.
c = t.build('TestComputation')
print(c)
print(xb.get_computation_hlo(c))
def add_op(a, b, **kwargs):
    """Return ``a + b``.

    Builder callback used twice in this script: as the body of the ``'addo'``
    computation (built with the extra keyword ``foo=11``) and as the
    implementation registered under ``'add'`` with ``xor.register``. Only the
    ``+`` operator is required of the arguments.

    Args:
      a: Left operand.
      b: Right operand. Inside this function it shadows the module-level
        builder ``b``.
      **kwargs: Extra keyword arguments (e.g. ``foo``) a caller may forward;
        ignored.

    Returns:
      The sum ``a + b``.
    """
    return a + b
# Standalone 'addo' computation over two f32[2,2] inputs. The extra keyword
# foo=11 is presumably forwarded by create_computation to add_op's **kwargs,
# where it is ignored — confirm against xla_builder.create_computation.
cx = xb.create_computation('addo',
add_op,
(xb.mkshape('f32',
(2, 2)), xb.mkshape('f32', (2, 2))),
foo=11)
print(cx)
print(xb.get_computation_hlo(cx))
# Runtime inputs: 2x2 random tensors moved to the XLA device.
t1 = torch.randn(2, 2).to(device)
t2 = torch.randn(2, 2).to(device)
# Apply the prebuilt computation cx to (t1, t2) under the op name 'xla::moose'.
# NOTE(review): _XLAC is a private C++ binding; its signature may change
# between torch_xla releases.
x = torch_xla._XLAC._xla_user_computation('xla::moose', (t1, t2), cx)
# Print the result next to its inputs so x can be checked by eye
# (presumably x == t1 + t2, since cx wraps add_op).
print(x)
print(t1)
print(t2)
def mul_op(a, b, **kwargs):
    """Return ``a * b``.

    Builder callback registered under ``'mul'`` with ``xor.register`` in this
    script. Only the ``*`` operator is required of the arguments.

    Args:
      a: Left operand.
      b: Right operand. Inside this function it shadows the module-level
        builder ``b``.
      **kwargs: Extra keyword arguments a caller may forward; ignored.

    Returns:
      The product ``a * b``.
    """
    return a * b
# Register the builder callbacks as ops; xor.register returns callables that
# are applied directly to the device tensors below.
zadd = xor.register('add', add_op)
zmul = xor.register('mul', mul_op)
# Rebinds x (previously the _xla_user_computation result).
x = zadd(t1, t2)
# Presumably (t1 + t2) * t2, given the registered callbacks.
y = zmul(x, t2)
print(y)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment