Skip to content

Instantly share code, notes, and snippets.

@minjang
Last active November 20, 2024 08:03
Show Gist options
  • Save minjang/25715aa9d618c6040a570c7188f03197 to your computer and use it in GitHub Desktop.
Save minjang/25715aa9d618c6040a570c7188f03197 to your computer and use it in GitHub Desktop.
TTMIR for matmul_kernel (03-matrix-multiplication-cpu.py)
// Attribute aliases used throughout the module below.
//
// #blocked  : blocked layout with one element per thread in both dimensions.
//             In this file it tags the zero-initialised accumulator constant,
//             the loop-carried accumulator / tt.dot result, and is the parent
//             of the dot_op operand layouts.
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 1], warpsPerCTA = [1, 1], order = [1, 0]}>
// #blocked1 : blocked layout with sizePerThread = [1, 16]. In this file it tags
//             the index/pointer/mask tensors, the loaded tiles and the stored
//             result tile.
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [1, 1], warpsPerCTA = [1, 1], order = [1, 0]}>
// #loc : source position (line 166, column 0 of the tutorial file) attached to
//        the function arguments' enclosing function and module via loc(#loc).
#loc = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":166:0)
// Module holding a single kernel. The module attributes record
// num-ctas = 1, num-warps = 1, threads-per-warp = 1 and target = "cpu".
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "cpu", "triton_gpu.threads-per-warp" = 1 : i32} {
// @matmul_kernel: each program instance accumulates one 16x16 f32 tile with
// tt.dot over a loop and stores it with a bounds mask.
//
// Argument roles as they are used in the body below (conventional matmul names
// in parentheses are inferred from usage, not stated in this IR):
//   %arg0 : f32 pointer, base of the tile fed to tt.dot as operand 0 ("A")
//   %arg1 : f32 pointer, base of the tile fed to tt.dot as operand 1 ("B")
//   %arg2 : f32 pointer, base of the stored result tile ("C")
//   %arg3 : i32, row bound: wraps row offsets and masks stored rows ("M")
//   %arg4 : i32, column bound: wraps column offsets and masks stored columns ("N")
//   %arg5 : i32, reduction extent: sets the loop trip count and load masks ("K")
//   %arg6 : i32, multiplies the row offsets of the operand-0 tile (row stride of A)
//   %arg7 : i32, multiplies the reduction offsets of the operand-1 tile (row stride of B)
//   %arg8 : i32, multiplies the row offsets of the result tile (row stride of C)
// No stride is applied along the second dimension of any tile; those offsets
// are added to the pointers unscaled.
// Every argument carries the attribute tt.divisibility = 8192.
tt.func public @matmul_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 8192 : i32} loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":166:0), %arg1: !tt.ptr<f32> {tt.divisibility = 8192 : i32} loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":166:0), %arg2: !tt.ptr<f32> {tt.divisibility = 8192 : i32} loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":166:0), %arg3: i32 {tt.divisibility = 8192 : i32} loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":166:0), %arg4: i32 {tt.divisibility = 8192 : i32} loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":166:0), %arg5: i32 {tt.divisibility = 8192 : i32} loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":166:0), %arg6: i32 {tt.divisibility = 8192 : i32} loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":166:0), %arg7: i32 {tt.divisibility = 8192 : i32} loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":166:0), %arg8: i32 {tt.divisibility = 8192 : i32} loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":166:0)) attributes {noinline = false} {
// ---- Constants ----
// 8 is the number of row-tiles per group; 16 is the tile edge length.
%c8_i32 = arith.constant 8 : i32 loc(#loc1)
%c16_i32 = arith.constant 16 : i32 loc(#loc1)
// %cst   : zero tile in #blocked, the initial accumulator of the loop.
// %cst_0 : zero tile in #blocked1, the fill value for masked-off load elements.
%cst = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #blocked> loc(#loc1)
%cst_0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #blocked1> loc(#loc1)
%c15_i32 = arith.constant 15 : i32 loc(#loc1)
%c1_i32 = arith.constant 1 : i32 loc(#loc1)
%c0_i32 = arith.constant 0 : i32 loc(#loc1)
// %cst_1 : per-iteration element step (16) added to the operand-0 pointers.
%cst_1 = arith.constant dense<16> : tensor<16x16xi32, #blocked1> loc(#loc1)
// ---- Map the linear program id to a (row-tile, column-tile) pair ----
%0 = tt.get_program_id x : i32 loc(#loc2)
// %2 = (%arg3 + 15) / 16 and %4 = (%arg4 + 15) / 16: tile counts along rows
// and columns (ceiling division for non-negative sizes).
%1 = arith.addi %arg3, %c15_i32 : i32 loc(#loc56)
%2 = arith.divsi %1, %c16_i32 : i32 loc(#loc57)
%3 = arith.addi %arg4, %c15_i32 : i32 loc(#loc58)
%4 = arith.divsi %3, %c16_i32 : i32 loc(#loc59)
// %5 = programs per group (8 row-tiles x all column-tiles).
%5 = arith.muli %4, %c8_i32 : i32 loc(#loc7)
// %6 = group index; %7 = first row-tile index of that group.
%6 = arith.divsi %0, %5 : i32 loc(#loc8)
%7 = arith.muli %6, %c8_i32 : i32 loc(#loc9)
// %9 = number of row-tiles in this group: min(remaining row-tiles, 8).
%8 = arith.subi %2, %7 : i32 loc(#loc10)
%9 = arith.minsi %8, %c8_i32 : i32 loc(#loc11)
// %11 = row-tile index    = first row-tile of the group + (pid mod group size).
// %13 = column-tile index = (pid mod programs-per-group) / group size.
%10 = arith.remsi %0, %9 : i32 loc(#loc12)
%11 = arith.addi %7, %10 : i32 loc(#loc13)
%12 = arith.remsi %0, %5 : i32 loc(#loc14)
%13 = arith.divsi %12, %9 : i32 loc(#loc15)
// ---- Row / column offsets of this tile ----
// %14 = first row of the tile.
%14 = arith.muli %11, %c16_i32 : i32 loc(#loc16)
// %15 and %16 both hold 0..15; they differ only in slice layout
// (%15 expands to 16x1, %16 expands to 1x16).
%15 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> loc(#loc17)
%16 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> loc(#loc17)
// %18 = row offsets (unwrapped); %20 = row offsets modulo %arg3, used for loads.
%17 = tt.splat %14 : i32 -> tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> loc(#loc18)
%18 = arith.addi %17, %15 : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> loc(#loc18)
%19 = tt.splat %arg3 : i32 -> tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> loc(#loc19)
%20 = arith.remsi %18, %19 : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> loc(#loc19)
// %23 = column offsets (unwrapped); %25 = column offsets modulo %arg4, used for loads.
%21 = arith.muli %13, %c16_i32 : i32 loc(#loc20)
%22 = tt.splat %21 : i32 -> tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> loc(#loc21)
%23 = arith.addi %22, %16 : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> loc(#loc21)
%24 = tt.splat %arg4 : i32 -> tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> loc(#loc22)
%25 = arith.remsi %23, %24 : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> loc(#loc22)
// ---- Initial pointers for the operand-0 tile ----
// %34[i, j] = %arg0 + %20[i] * %arg6 + j   (j = 0..15 along the reduction axis)
%26 = tt.expand_dims %20 {axis = 1 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<16x1xi32, #blocked1> loc(#loc23)
%27 = tt.splat %arg6 : i32 -> tensor<16x1xi32, #blocked1> loc(#loc24)
%28 = arith.muli %26, %27 : tensor<16x1xi32, #blocked1> loc(#loc24)
// %29 (1x16 tensor of 0..15) is reused inside the loop to build the load mask.
%29 = tt.expand_dims %16 {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x16xi32, #blocked1> loc(#loc25)
%30 = tt.broadcast %28 : tensor<16x1xi32, #blocked1> -> tensor<16x16xi32, #blocked1> loc(#loc26)
%31 = tt.broadcast %29 : tensor<1x16xi32, #blocked1> -> tensor<16x16xi32, #blocked1> loc(#loc26)
%32 = arith.addi %30, %31 : tensor<16x16xi32, #blocked1> loc(#loc26)
%33 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<16x16x!tt.ptr<f32>, #blocked1> loc(#loc27)
%34 = tt.addptr %33, %32 : tensor<16x16x!tt.ptr<f32>, #blocked1>, tensor<16x16xi32, #blocked1> loc(#loc27)
// ---- Initial pointers for the operand-1 tile ----
// %43[i, j] = %arg1 + i * %arg7 + %25[j]   (i = 0..15 along the reduction axis)
// %35 (16x1 tensor of 0..15) is reused inside the loop to build the load mask.
%35 = tt.expand_dims %15 {axis = 1 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<16x1xi32, #blocked1> loc(#loc28)
%36 = tt.splat %arg7 : i32 -> tensor<16x1xi32, #blocked1> loc(#loc29)
%37 = arith.muli %35, %36 : tensor<16x1xi32, #blocked1> loc(#loc29)
%38 = tt.expand_dims %25 {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x16xi32, #blocked1> loc(#loc30)
%39 = tt.broadcast %37 : tensor<16x1xi32, #blocked1> -> tensor<16x16xi32, #blocked1> loc(#loc31)
%40 = tt.broadcast %38 : tensor<1x16xi32, #blocked1> -> tensor<16x16xi32, #blocked1> loc(#loc31)
%41 = arith.addi %39, %40 : tensor<16x16xi32, #blocked1> loc(#loc31)
%42 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<16x16x!tt.ptr<f32>, #blocked1> loc(#loc32)
%43 = tt.addptr %42, %41 : tensor<16x16x!tt.ptr<f32>, #blocked1>, tensor<16x16xi32, #blocked1> loc(#loc32)
// ---- Reduction loop setup ----
// %45 = (%arg5 + 15) / 16: loop trip count.
%44 = arith.addi %arg5, %c15_i32 : i32 loc(#loc60)
%45 = arith.divsi %44, %c16_i32 : i32 loc(#loc61)
// %47 = 16 * %arg7 splatted: per-iteration step for the operand-1 pointers
// (computed once here, outside the loop).
%46 = arith.muli %arg7, %c16_i32 : i32 loc(#loc34)
%47 = tt.splat %46 : i32 -> tensor<16x16xi32, #blocked1> loc(#loc35)
// Loop-carried values: %arg10 = accumulator (starts at zero),
// %arg11 = operand-0 pointers, %arg12 = operand-1 pointers.
%48:3 = scf.for %arg9 = %c0_i32 to %45 step %c1_i32 iter_args(%arg10 = %cst, %arg11 = %34, %arg12 = %43) -> (tensor<16x16xf32, #blocked>, tensor<16x16x!tt.ptr<f32>, #blocked1>, tensor<16x16x!tt.ptr<f32>, #blocked1>) : i32 {
// %67 = %arg5 - 16 * iteration: reduction elements still to be consumed.
%66 = arith.muli %arg9, %c16_i32 : i32 loc(#loc37)
%67 = arith.subi %arg5, %66 : i32 loc(#loc38)
// Operand-0 load: keep columns whose offset (0..15) is < %67; masked-off
// elements take the value from %cst_0 (zeros).
%68 = tt.splat %67 : i32 -> tensor<1x16xi32, #blocked1> loc(#loc39)
%69 = arith.cmpi slt, %29, %68 : tensor<1x16xi32, #blocked1> loc(#loc39)
%70 = tt.broadcast %69 : tensor<1x16xi1, #blocked1> -> tensor<16x16xi1, #blocked1> loc(#loc40)
%71 = tt.load %arg11, %70, %cst_0 : tensor<16x16x!tt.ptr<f32>, #blocked1> loc(#loc40)
// Operand-1 load: same bound, applied along rows (16x1 mask broadcast to 16x16).
%72 = tt.splat %67 : i32 -> tensor<16x1xi32, #blocked1> loc(#loc41)
%73 = arith.cmpi slt, %35, %72 : tensor<16x1xi32, #blocked1> loc(#loc41)
%74 = tt.broadcast %73 : tensor<16x1xi1, #blocked1> -> tensor<16x16xi1, #blocked1> loc(#loc42)
%75 = tt.load %arg12, %74, %cst_0 : tensor<16x16x!tt.ptr<f32>, #blocked1> loc(#loc42)
// Re-encode both tiles into the dot-operand layouts (opIdx 0 and 1, parent #blocked).
%76 = triton_gpu.convert_layout %71 : tensor<16x16xf32, #blocked1> -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> loc(#loc40)
%77 = triton_gpu.convert_layout %75 : tensor<16x16xf32, #blocked1> -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> loc(#loc42)
// %78 = %76 x %77 + %arg10 (accumulating dot); the op carries inputPrecision = tf32.
%78 = tt.dot %76, %77, %arg10, inputPrecision = tf32 : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<16x16xf32, #blocked> loc(#loc43)
// Advance both pointer tiles: operand 0 by 16 elements, operand 1 by 16 * %arg7.
%79 = tt.addptr %arg11, %cst_1 : tensor<16x16x!tt.ptr<f32>, #blocked1>, tensor<16x16xi32, #blocked1> loc(#loc44)
%80 = tt.addptr %arg12, %47 : tensor<16x16x!tt.ptr<f32>, #blocked1>, tensor<16x16xi32, #blocked1> loc(#loc35)
scf.yield %78, %79, %80 : tensor<16x16xf32, #blocked>, tensor<16x16x!tt.ptr<f32>, #blocked1>, tensor<16x16x!tt.ptr<f32>, #blocked1> loc(#loc45)
} loc(#loc36)
// ---- Result pointers ----
// %57[i, j] = %arg2 + %arg8 * %18[i] + %23[j]
// Uses the unwrapped offsets %18 / %23 (not the modulo-wrapped %20 / %25);
// out-of-range positions are excluded by the mask %64 below.
%49 = tt.expand_dims %18 {axis = 1 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<16x1xi32, #blocked1> loc(#loc46)
%50 = tt.splat %arg8 : i32 -> tensor<16x1xi32, #blocked1> loc(#loc47)
%51 = arith.muli %50, %49 : tensor<16x1xi32, #blocked1> loc(#loc47)
%52 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<16x1x!tt.ptr<f32>, #blocked1> loc(#loc48)
%53 = tt.addptr %52, %51 : tensor<16x1x!tt.ptr<f32>, #blocked1>, tensor<16x1xi32, #blocked1> loc(#loc48)
%54 = tt.expand_dims %23 {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x16xi32, #blocked1> loc(#loc49)
%55 = tt.broadcast %53 : tensor<16x1x!tt.ptr<f32>, #blocked1> -> tensor<16x16x!tt.ptr<f32>, #blocked1> loc(#loc50)
%56 = tt.broadcast %54 : tensor<1x16xi32, #blocked1> -> tensor<16x16xi32, #blocked1> loc(#loc50)
%57 = tt.addptr %55, %56 : tensor<16x16x!tt.ptr<f32>, #blocked1>, tensor<16x16xi32, #blocked1> loc(#loc50)
// ---- Store mask: (row offset < %arg3) AND (column offset < %arg4) ----
%58 = tt.splat %arg3 : i32 -> tensor<16x1xi32, #blocked1> loc(#loc51)
%59 = arith.cmpi slt, %49, %58 : tensor<16x1xi32, #blocked1> loc(#loc51)
%60 = tt.splat %arg4 : i32 -> tensor<1x16xi32, #blocked1> loc(#loc52)
%61 = arith.cmpi slt, %54, %60 : tensor<1x16xi32, #blocked1> loc(#loc52)
%62 = tt.broadcast %59 : tensor<16x1xi1, #blocked1> -> tensor<16x16xi1, #blocked1> loc(#loc53)
%63 = tt.broadcast %61 : tensor<1x16xi1, #blocked1> -> tensor<16x16xi1, #blocked1> loc(#loc53)
%64 = arith.andi %62, %63 : tensor<16x16xi1, #blocked1> loc(#loc53)
// Convert the final accumulator from #blocked to #blocked1 so its layout
// matches the pointer and mask tensors, then store it under the mask.
%65 = triton_gpu.convert_layout %48#0 : tensor<16x16xf32, #blocked> -> tensor<16x16xf32, #blocked1> loc(#loc54)
tt.store %57, %65, %64 : tensor<16x16x!tt.ptr<f32>, #blocked1> loc(#loc54)
tt.return loc(#loc55)
} loc(#loc)
} loc(#loc)
// Location aliases referenced by the trailing loc(#locN) annotations on the
// operations of @matmul_kernel.
//
// #loc1 is an unknown location; in the kernel it is attached to the
// arith.constant ops.
#loc1 = loc(unknown)
// #loc2 .. #loc55 are file:line:column positions, either in the tutorial
// script (03-matrix-multiplication-cpu.py) or in triton/language/standard.py.
// #loc3 and #loc5 (standard.py line 40) are referenced only through the
// callsite aliases at the end of this table; the ops tagged with those
// callsites are the "+ 15" add and the "/ 16" divide.
#loc2 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":188:24)
#loc3 = loc("/data/users/minjang/triton-oss/triton-cpu/python/triton/language/standard.py":40:22)
#loc4 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":189:27)
#loc5 = loc("/data/users/minjang/triton-oss/triton-cpu/python/triton/language/standard.py":40:28)
#loc6 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":190:27)
#loc7 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":191:38)
#loc8 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":192:22)
#loc9 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":193:29)
#loc10 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":194:35)
#loc11 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":194:48)
#loc12 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":195:33)
#loc13 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":195:27)
#loc14 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":196:19)
#loc15 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":196:40)
#loc16 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":205:23)
#loc17 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":205:51)
#loc18 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":205:38)
#loc19 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":205:68)
#loc20 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":206:23)
#loc21 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":206:38)
#loc22 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":206:68)
#loc23 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":208:30)
#loc24 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":208:41)
#loc25 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":208:60)
#loc26 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":208:53)
#loc27 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":208:22)
#loc28 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":209:29)
#loc29 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":209:40)
#loc30 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":209:60)
#loc31 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":209:52)
#loc32 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":209:22)
#loc33 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":217:33)
#loc34 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":227:33)
#loc35 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":227:18)
#loc36 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":217:22)
#loc37 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":221:59)
#loc38 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":221:55)
#loc39 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":221:51)
#loc40 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":221:20)
#loc41 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":222:51)
#loc42 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":222:20)
#loc43 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":224:35)
#loc44 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":226:18)
#loc45 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":227:8)
#loc46 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":236:41)
#loc47 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":236:33)
#loc48 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":236:21)
#loc49 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":236:72)
#loc50 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":236:52)
#loc51 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":238:33)
#loc52 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":238:58)
#loc53 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":238:39)
#loc54 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":239:21)
#loc55 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":239:4)
// Callsite locations: each pairs a position inside standard.py line 40
// (#loc3 or #loc5) with the tutorial-script position it was called from
// (#loc4, #loc6 or #loc33). They tag the three add/divide pairs that compute
// the row-tile count, the column-tile count and the loop trip count.
#loc56 = loc(callsite(#loc3 at #loc4))
#loc57 = loc(callsite(#loc5 at #loc4))
#loc58 = loc(callsite(#loc3 at #loc6))
#loc59 = loc(callsite(#loc5 at #loc6))
#loc60 = loc(callsite(#loc3 at #loc33))
#loc61 = loc(callsite(#loc5 at #loc33))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment