Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save shunting314/c82917a1849b698bf4d2be2fde2fd2ba to your computer and use it in GitHub Desktop.
Save shunting314/c82917a1849b698bf4d2be2fde2fd2ba to your computer and use it in GitHub Desktop.
# AOT ID: ['0_inference']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from cmath import nanj
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
# Short module-level aliases for torch operator namespaces.
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
# Aliases for helpers exposed on torch._C._dynamo.guards: size/stride and
# alignment assertions, per-device strided allocators, and tensor
# reinterpretation.
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
assert_alignment = torch._C._dynamo.guards.assert_alignment
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
# Handle through which the Triton kernel sources in this module are registered
# (via async_compile.triton); it is deleted again once compilation has been
# waited on.
async_compile = AsyncCompile()
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
# kernel path: /tmp/torchinductor_shunting/4d/c4dzszdatmg24ritob4bg3vi2uvosxrussq2cxzji4i62z3xs4sw.py
# Topologically Sorted Source Nodes: [sum_1], Original ATen: [aten.sum]
# Source node to ATen node mapping:
# sum_1 => sum_1
# Graph fragment:
# %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%arg2_1, [0]), kwargs = {})
# Registers the source of the Triton kernel `triton_red_fused_sum_0`.
#
# What the kernel computes, per output index `xindex` (0 <= xindex < xnumel):
#   x1 = xindex // ks0, x0 = xindex % ks0
#   chunk = (255 + ks1) // 256          (i.e. ceil(ks1 / 256) for ks1 >= 0)
#   out_ptr0[xindex] = sum over r in [0, r0_numel) of
#                      in_ptr0[x0 + ks0 * (r + x1 * chunk)]
# Terms whose row index `r + x1 * chunk` is >= ks1 are masked out of the load
# and contribute 0.0, so the last chunk may be partial.
#
# The embedded source also carries standalone benchmarking helpers
# (`get_args`, `call`, `benchmark_all_configs` and a `__main__` entry point)
# that are local to the kernel file, not to this wrapper module.
triton_red_fused_sum_0 = async_compile.triton('triton_red_fused_sum_0', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch

@triton_heuristics.reduction(
    size_hints={'x': 131072, 'r0_': 16},
    reduction_hint=ReductionHint.OUTER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_sum_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '43284EFBB2F4A605713E05507A3DDC925900EFFAEA9FEAA732C9A879F43B2640', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.008912896}
)
@triton.jit
def triton_red_fused_sum_0(in_ptr0, out_ptr0, ks0, ks1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    x1 = xindex // ks0
    x0 = (xindex % ks0)
    _tmp5 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    x3 = xindex
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_2 = r0_index
        tmp0 = r0_2 + x1*((255 + ks1) // 256)
        tmp1 = ks1
        tmp2 = tmp0 < tmp1
        tmp3 = tl.load(in_ptr0 + (x0 + ks0*r0_2 + ks0*x1*((255 + ks1) // 256)), r0_mask & xmask & tmp2, eviction_policy='evict_last', other=0.0)
        tmp4 = tl.broadcast_to(tmp3, [XBLOCK, R0_BLOCK])
        tmp6 = _tmp5 + tmp4
        _tmp5 = tl.where(r0_mask & xmask, tmp6, _tmp5)
    tmp5 = tl.sum(_tmp5, 1)[:, None]
    tl.store(out_ptr0 + (x3), tmp5, xmask)


def get_args():
    arg_0 = rand_strided((4096, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg_1 = rand_strided((512, 256), (1, 512), device='cuda:0', dtype=torch.float32)
    arg_2 = 512
    arg_3 = 4096
    return arg_0, arg_1, arg_2, arg_3, 131072, 16,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_red_fused_sum_0.run(*args, stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_red_fused_sum_0.benchmark_all_configs(*args)


if __name__ == '__main__':
    from torch._inductor.runtime.benchmarking import benchmarker

    args = get_args()
    ms = benchmarker.benchmark_gpu(lambda: call(args), rep=40)
    num_gb = 0.008912896
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# kernel path: /tmp/torchinductor_shunting/rn/crnbrcdkm32ldaran3767a5emftwbaebj6r6ztonyudjq3iycx3x.py
# Topologically Sorted Source Nodes: [sum_1], Original ATen: [aten.sum]
# Source node to ATen node mapping:
# sum_1 => sum_1
# Graph fragment:
# %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%arg2_1, [0]), kwargs = {})
# Registers the source of the Triton kernel `triton_red_fused_sum_1`.
#
# What the kernel computes, per output index x0 (0 <= x0 < xnumel):
#   out_ptr0[x0] = sum over r in [0, 256) of in_ptr0[x0 + ks0 * r]
# The `r0_numel` argument is overwritten with the constant 256 on the first
# line of the kernel body, so the value passed by the caller is not used.
#
# The embedded source also carries standalone benchmarking helpers
# (`get_args`, `call`, `benchmark_all_configs` and a `__main__` entry point)
# that are local to the kernel file, not to this wrapper module.
triton_red_fused_sum_1 = async_compile.triton('triton_red_fused_sum_1', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

from torch._dynamo.testing import rand_strided
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch

@triton_heuristics.reduction(
    size_hints={'x': 512, 'r0_': 256},
    reduction_hint=ReductionHint.OUTER_TINY,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_sum_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': '43284EFBB2F4A605713E05507A3DDC925900EFFAEA9FEAA732C9A879F43B2640', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'kernel_num_gb': 0.000526336}
)
@triton.jit
def triton_red_fused_sum_1(in_ptr0, out_ptr0, ks0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    r0_numel = 256
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    x0 = xindex
    _tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        tmp0 = tl.load(in_ptr0 + (x0 + ks0*r0_1), r0_mask & xmask, eviction_policy='evict_first', other=0.0)
        tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
        tmp3 = _tmp2 + tmp1
        _tmp2 = tl.where(r0_mask & xmask, tmp3, _tmp2)
    tmp2 = tl.sum(_tmp2, 1)[:, None]
    tl.store(out_ptr0 + (x0), tmp2, xmask)


def get_args():
    arg_0 = rand_strided((512, 256), (1, 512), device='cuda:0', dtype=torch.float32)
    arg_1 = rand_strided((512,), (1,), device='cuda:0', dtype=torch.float32)
    arg_2 = 512
    return arg_0, arg_1, arg_2, 512, 256,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_raw_stream(0)
        triton_red_fused_sum_1.run(*args, stream=stream0)


def benchmark_all_configs(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        return triton_red_fused_sum_1.benchmark_all_configs(*args)


if __name__ == '__main__':
    from torch._inductor.runtime.benchmarking import benchmarker

    args = get_args()
    ms = benchmarker.benchmark_gpu(lambda: call(args), rep=40)
    num_gb = 0.000526336
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")
''', device_str='cuda')
# Hand the module globals to the compiler handle before any kernel is used.
# NOTE(review): presumably this blocks until the kernels registered above are
# compiled and rebinds their names in globals() to the compiled objects —
# confirm against AsyncCompile.wait.
async_compile.wait(globals())
# The handle is not needed after this point; drop the module-level name.
del async_compile
def call(args):
    """Sum a 2-D CUDA tensor over dim 0 using a two-stage reduction.

    Args:
        args: A mutable list ``[s77, s27, tensor]``. ``tensor`` must have
            size ``(s77, s27)`` and stride ``(s27, 1)``; this is checked with
            ``assert_size_stride``. The list is cleared as soon as its items
            have been unpacked. The kernels are declared for float32 data,
            but the dtype is not checked here — assumes a float32 input.

    Returns:
        A 1-tuple holding a float32 CUDA tensor of size ``(s27,)``.
    """
    arg0_1, arg1_1, arg2_1 = args
    # Drop the caller's references so inputs can be released once `del`-ed.
    args.clear()
    s77 = arg0_1
    s27 = arg1_1
    assert_size_stride(arg2_1, (s77, s27), (s27, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        # Stage 1: 256 partial sums per column, each over at most
        # ceil(s77 / 256) consecutive rows.
        buf0 = empty_strided_cuda((s27, 256), (1, s27), torch.float32)
        # Topologically Sorted Source Nodes: [sum_1], Original ATen: [aten.sum]
        triton_red_fused_sum_0_xnumel = 256*s27
        triton_red_fused_sum_0_r0_numel = (255 + s77) // 256
        stream0 = get_raw_stream(0)
        triton_red_fused_sum_0.run(arg2_1, buf0, s27, s77, triton_red_fused_sum_0_xnumel, triton_red_fused_sum_0_r0_numel, stream=stream0)
        del arg2_1
        # Stage 2: collapse the 256 partial sums of each column into one value.
        buf1 = empty_strided_cuda((s27, ), (1, ), torch.float32)
        # Topologically Sorted Source Nodes: [sum_1], Original ATen: [aten.sum]
        stream0 = get_raw_stream(0)
        triton_red_fused_sum_1.run(buf0, buf1, s27, s27, 256, stream=stream0)
        del buf0
    return (buf1, )
def benchmark_compiled_module(times=10, repeat=10):
    """Benchmark ``call`` on a random (4096, 512) float32 CUDA input.

    Args:
        times: Forwarded to ``print_performance`` as ``times``.
        repeat: Forwarded to ``print_performance`` as ``repeat``.

    Returns:
        Whatever ``print_performance`` returns for the benchmarked callable.
    """
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance

    arg0_1 = 4096
    arg1_1 = 512
    arg2_1 = rand_strided((4096, 512), (512, 1), device='cuda:0', dtype=torch.float32)

    def fn():
        # Build a fresh argument list on every invocation: call() clears the
        # list it is given.
        return call([arg0_1, arg1_1, arg2_1])

    return print_performance(fn, times=times, repeat=repeat)
# Script entry point: hand the module-level benchmark function to inductor's
# wrapper-benchmark driver. The first argument is the literal string 'None'.
if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main

    compiled_module_main('None', benchmark_compiled_module)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment