Last active
November 20, 2024 08:03
-
-
Save minjang/25715aa9d618c6040a570c7188f03197 to your computer and use it in GitHub Desktop.
TTMIR for matmul_kernel (03-matrix-multiplication-cpu.py)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Layout aliases used by the tensor types in the module below.
// #blocked  : sizePerThread = [1, 1]; used for the 16x16 f32 accumulator and the tt.dot result.
// #blocked1 : sizePerThread = [1, 16]; used for the index, mask and pointer tensors and for loads/stores.
// Both use one thread per warp, one warp per CTA and order = [1, 0].
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 1], warpsPerCTA = [1, 1], order = [1, 0]}> | |
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [1, 1], warpsPerCTA = [1, 1], order = [1, 0]}> | |
// #loc is the source location attached to the function and the module (tutorial file, line 166, column 0).
#loc = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":166:0) | |
// Module holding a single kernel. The attributes set num-ctas, num-warps and
// threads-per-warp to 1 and the target to "cpu".
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "cpu", "triton_gpu.threads-per-warp" = 1 : i32} { | |
// matmul_kernel: takes three f32 pointers (%arg0..%arg2) and six i32 scalars (%arg3..%arg8).
// As used in the body: %arg0 and %arg1 are only loaded from, %arg2 is only stored to;
// %arg3 and %arg4 bound the store mask; %arg5 bounds the reduction loop;
// %arg6, %arg7 and %arg8 scale the row indices of the %arg0, %arg1 and %arg2 tiles respectively,
// while column indices are added unscaled.
// NOTE(review): presumably these are a_ptr, b_ptr, c_ptr, M, N, K and the row strides of the
// tutorial kernel -- confirm against the Python source.
tt.func public @matmul_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 8192 : i32} loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":166:0), %arg1: !tt.ptr<f32> {tt.divisibility = 8192 : i32} loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":166:0), %arg2: !tt.ptr<f32> {tt.divisibility = 8192 : i32} loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":166:0), %arg3: i32 {tt.divisibility = 8192 : i32} loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":166:0), %arg4: i32 {tt.divisibility = 8192 : i32} loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":166:0), %arg5: i32 {tt.divisibility = 8192 : i32} loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":166:0), %arg6: i32 {tt.divisibility = 8192 : i32} loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":166:0), %arg7: i32 {tt.divisibility = 8192 : i32} loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":166:0), %arg8: i32 {tt.divisibility = 8192 : i32} loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":166:0)) attributes {noinline = false} { | |
// --- Constants: scalars 8, 16, 15, 1, 0; two zero-filled 16x16 f32 tensors (one per layout);
// --- and a 16x16 i32 tensor filled with 16 (used to advance the %arg0 pointers in the loop).
%c8_i32 = arith.constant 8 : i32 loc(#loc1) | |
%c16_i32 = arith.constant 16 : i32 loc(#loc1) | |
%cst = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #blocked> loc(#loc1) | |
%cst_0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #blocked1> loc(#loc1) | |
%c15_i32 = arith.constant 15 : i32 loc(#loc1) | |
%c1_i32 = arith.constant 1 : i32 loc(#loc1) | |
%c0_i32 = arith.constant 0 : i32 loc(#loc1) | |
%cst_1 = arith.constant dense<16> : tensor<16x16xi32, #blocked1> loc(#loc1) | |
// --- Map the linear program id (%0) to a tile row index (%11) and a tile column index (%13).
// %2 = (%arg3 + 15) / 16 and %4 = (%arg4 + 15) / 16: tile counts along the two output dimensions.
// %5 = %4 * 8 is the number of programs per group; %6 is the group index; %7 the group's first tile row;
// %9 = min(%2 - %7, 8) is the number of tile rows actually present in this group.
%0 = tt.get_program_id x : i32 loc(#loc2) | |
%1 = arith.addi %arg3, %c15_i32 : i32 loc(#loc56) | |
%2 = arith.divsi %1, %c16_i32 : i32 loc(#loc57) | |
%3 = arith.addi %arg4, %c15_i32 : i32 loc(#loc58) | |
%4 = arith.divsi %3, %c16_i32 : i32 loc(#loc59) | |
%5 = arith.muli %4, %c8_i32 : i32 loc(#loc7) | |
%6 = arith.divsi %0, %5 : i32 loc(#loc8) | |
%7 = arith.muli %6, %c8_i32 : i32 loc(#loc9) | |
%8 = arith.subi %2, %7 : i32 loc(#loc10) | |
%9 = arith.minsi %8, %c8_i32 : i32 loc(#loc11) | |
// Tile row = %7 + (program id rem %9); tile column = (program id rem %5) / %9.
%10 = arith.remsi %0, %9 : i32 loc(#loc12) | |
%11 = arith.addi %7, %10 : i32 loc(#loc13) | |
%12 = arith.remsi %0, %5 : i32 loc(#loc14) | |
%13 = arith.divsi %12, %9 : i32 loc(#loc15) | |
// --- Per-tile index vectors.
// %15 / %16 are the range 0..15 in the two slice layouts of #blocked1 (dim 1 and dim 0 removed).
// %18 = %11 * 16 + range are the 16 row indices of this tile; %20 = %18 rem %arg3 wraps them for loading.
%14 = arith.muli %11, %c16_i32 : i32 loc(#loc16) | |
%15 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> loc(#loc17) | |
%16 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> loc(#loc17) | |
%17 = tt.splat %14 : i32 -> tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> loc(#loc18) | |
%18 = arith.addi %17, %15 : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> loc(#loc18) | |
%19 = tt.splat %arg3 : i32 -> tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> loc(#loc19) | |
%20 = arith.remsi %18, %19 : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> loc(#loc19) | |
// %23 = %13 * 16 + range are the 16 column indices of this tile; %25 = %23 rem %arg4 wraps them for loading.
%21 = arith.muli %13, %c16_i32 : i32 loc(#loc20) | |
%22 = tt.splat %21 : i32 -> tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> loc(#loc21) | |
%23 = arith.addi %22, %16 : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> loc(#loc21) | |
%24 = tt.splat %arg4 : i32 -> tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> loc(#loc22) | |
%25 = arith.remsi %23, %24 : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> loc(#loc22) | |
// --- Initial 16x16 pointer tile into %arg0: element (i, j) = %arg0 + %20[i] * %arg6 + j  (result: %34).
%26 = tt.expand_dims %20 {axis = 1 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<16x1xi32, #blocked1> loc(#loc23) | |
%27 = tt.splat %arg6 : i32 -> tensor<16x1xi32, #blocked1> loc(#loc24) | |
%28 = arith.muli %26, %27 : tensor<16x1xi32, #blocked1> loc(#loc24) | |
%29 = tt.expand_dims %16 {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x16xi32, #blocked1> loc(#loc25) | |
%30 = tt.broadcast %28 : tensor<16x1xi32, #blocked1> -> tensor<16x16xi32, #blocked1> loc(#loc26) | |
%31 = tt.broadcast %29 : tensor<1x16xi32, #blocked1> -> tensor<16x16xi32, #blocked1> loc(#loc26) | |
%32 = arith.addi %30, %31 : tensor<16x16xi32, #blocked1> loc(#loc26) | |
%33 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<16x16x!tt.ptr<f32>, #blocked1> loc(#loc27) | |
%34 = tt.addptr %33, %32 : tensor<16x16x!tt.ptr<f32>, #blocked1>, tensor<16x16xi32, #blocked1> loc(#loc27) | |
// --- Initial 16x16 pointer tile into %arg1: element (i, j) = %arg1 + i * %arg7 + %25[j]  (result: %43).
%35 = tt.expand_dims %15 {axis = 1 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<16x1xi32, #blocked1> loc(#loc28) | |
%36 = tt.splat %arg7 : i32 -> tensor<16x1xi32, #blocked1> loc(#loc29) | |
%37 = arith.muli %35, %36 : tensor<16x1xi32, #blocked1> loc(#loc29) | |
%38 = tt.expand_dims %25 {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x16xi32, #blocked1> loc(#loc30) | |
%39 = tt.broadcast %37 : tensor<16x1xi32, #blocked1> -> tensor<16x16xi32, #blocked1> loc(#loc31) | |
%40 = tt.broadcast %38 : tensor<1x16xi32, #blocked1> -> tensor<16x16xi32, #blocked1> loc(#loc31) | |
%41 = arith.addi %39, %40 : tensor<16x16xi32, #blocked1> loc(#loc31) | |
%42 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<16x16x!tt.ptr<f32>, #blocked1> loc(#loc32) | |
%43 = tt.addptr %42, %41 : tensor<16x16x!tt.ptr<f32>, #blocked1>, tensor<16x16xi32, #blocked1> loc(#loc32) | |
// --- Loop setup: %45 = (%arg5 + 15) / 16 is the loop's upper bound;
// --- %47 = splat(%arg7 * 16) is the per-iteration pointer step for the %arg1 tile.
%44 = arith.addi %arg5, %c15_i32 : i32 loc(#loc60) | |
%45 = arith.divsi %44, %c16_i32 : i32 loc(#loc61) | |
%46 = arith.muli %arg7, %c16_i32 : i32 loc(#loc34) | |
%47 = tt.splat %46 : i32 -> tensor<16x16xi32, #blocked1> loc(#loc35) | |
// --- Reduction loop over %arg9 in [0, %45) with step 1. Loop-carried values:
// ---   %arg10 = accumulator (starts as the zero tensor %cst),
// ---   %arg11 = pointer tile into %arg0 (starts at %34),
// ---   %arg12 = pointer tile into %arg1 (starts at %43).
%48:3 = scf.for %arg9 = %c0_i32 to %45 step %c1_i32 iter_args(%arg10 = %cst, %arg11 = %34, %arg12 = %43) -> (tensor<16x16xf32, #blocked>, tensor<16x16x!tt.ptr<f32>, #blocked1>, tensor<16x16x!tt.ptr<f32>, #blocked1>) : i32 { | |
// %67 = %arg5 - %arg9 * 16: how many reduction elements remain at the start of this iteration.
%66 = arith.muli %arg9, %c16_i32 : i32 loc(#loc37) | |
%67 = arith.subi %arg5, %66 : i32 loc(#loc38) | |
// Masked load of the %arg0 tile: columns whose index (0..15) is not below %67 read the zero fill %cst_0.
%68 = tt.splat %67 : i32 -> tensor<1x16xi32, #blocked1> loc(#loc39) | |
%69 = arith.cmpi slt, %29, %68 : tensor<1x16xi32, #blocked1> loc(#loc39) | |
%70 = tt.broadcast %69 : tensor<1x16xi1, #blocked1> -> tensor<16x16xi1, #blocked1> loc(#loc40) | |
%71 = tt.load %arg11, %70, %cst_0 : tensor<16x16x!tt.ptr<f32>, #blocked1> loc(#loc40) | |
// Masked load of the %arg1 tile: rows whose index (0..15) is not below %67 read the zero fill %cst_0.
%72 = tt.splat %67 : i32 -> tensor<16x1xi32, #blocked1> loc(#loc41) | |
%73 = arith.cmpi slt, %35, %72 : tensor<16x1xi32, #blocked1> loc(#loc41) | |
%74 = tt.broadcast %73 : tensor<16x1xi1, #blocked1> -> tensor<16x16xi1, #blocked1> loc(#loc42) | |
%75 = tt.load %arg12, %74, %cst_0 : tensor<16x16x!tt.ptr<f32>, #blocked1> loc(#loc42) | |
// Convert both loaded tiles to the dot-operand layouts (opIdx 0 and 1) whose parent is #blocked.
%76 = triton_gpu.convert_layout %71 : tensor<16x16xf32, #blocked1> -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> loc(#loc40) | |
%77 = triton_gpu.convert_layout %75 : tensor<16x16xf32, #blocked1> -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> loc(#loc42) | |
// tt.dot of the two tiles with the running accumulator %arg10 as third operand; inputPrecision is tf32.
%78 = tt.dot %76, %77, %arg10, inputPrecision = tf32 : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<16x16xf32, #blocked> loc(#loc43) | |
// Advance the %arg0 pointers by 16 elements and the %arg1 pointers by %arg7 * 16 elements.
%79 = tt.addptr %arg11, %cst_1 : tensor<16x16x!tt.ptr<f32>, #blocked1>, tensor<16x16xi32, #blocked1> loc(#loc44) | |
%80 = tt.addptr %arg12, %47 : tensor<16x16x!tt.ptr<f32>, #blocked1>, tensor<16x16xi32, #blocked1> loc(#loc35) | |
scf.yield %78, %79, %80 : tensor<16x16xf32, #blocked>, tensor<16x16x!tt.ptr<f32>, #blocked1>, tensor<16x16x!tt.ptr<f32>, #blocked1> loc(#loc45) | |
} loc(#loc36) | |
// --- Output pointer tile into %arg2: element (i, j) = %arg2 + %18[i] * %arg8 + %23[j]  (result: %57).
// --- Note this uses the unwrapped indices %18 / %23, not the rem-wrapped %20 / %25 used for loading.
%49 = tt.expand_dims %18 {axis = 1 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<16x1xi32, #blocked1> loc(#loc46) | |
%50 = tt.splat %arg8 : i32 -> tensor<16x1xi32, #blocked1> loc(#loc47) | |
%51 = arith.muli %50, %49 : tensor<16x1xi32, #blocked1> loc(#loc47) | |
%52 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<16x1x!tt.ptr<f32>, #blocked1> loc(#loc48) | |
%53 = tt.addptr %52, %51 : tensor<16x1x!tt.ptr<f32>, #blocked1>, tensor<16x1xi32, #blocked1> loc(#loc48) | |
%54 = tt.expand_dims %23 {axis = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x16xi32, #blocked1> loc(#loc49) | |
%55 = tt.broadcast %53 : tensor<16x1x!tt.ptr<f32>, #blocked1> -> tensor<16x16x!tt.ptr<f32>, #blocked1> loc(#loc50) | |
%56 = tt.broadcast %54 : tensor<1x16xi32, #blocked1> -> tensor<16x16xi32, #blocked1> loc(#loc50) | |
%57 = tt.addptr %55, %56 : tensor<16x16x!tt.ptr<f32>, #blocked1>, tensor<16x16xi32, #blocked1> loc(#loc50) | |
// --- Store mask: (row index < %arg3) AND (column index < %arg4), broadcast to 16x16  (result: %64).
%58 = tt.splat %arg3 : i32 -> tensor<16x1xi32, #blocked1> loc(#loc51) | |
%59 = arith.cmpi slt, %49, %58 : tensor<16x1xi32, #blocked1> loc(#loc51) | |
%60 = tt.splat %arg4 : i32 -> tensor<1x16xi32, #blocked1> loc(#loc52) | |
%61 = arith.cmpi slt, %54, %60 : tensor<1x16xi32, #blocked1> loc(#loc52) | |
%62 = tt.broadcast %59 : tensor<16x1xi1, #blocked1> -> tensor<16x16xi1, #blocked1> loc(#loc53) | |
%63 = tt.broadcast %61 : tensor<1x16xi1, #blocked1> -> tensor<16x16xi1, #blocked1> loc(#loc53) | |
%64 = arith.andi %62, %63 : tensor<16x16xi1, #blocked1> loc(#loc53) | |
// --- Convert the final accumulator (%48#0) from #blocked to #blocked1 and store it under the mask.
%65 = triton_gpu.convert_layout %48#0 : tensor<16x16xf32, #blocked> -> tensor<16x16xf32, #blocked1> loc(#loc54) | |
tt.store %57, %65, %64 : tensor<16x16x!tt.ptr<f32>, #blocked1> loc(#loc54) | |
tt.return loc(#loc55) | |
} loc(#loc) | |
} loc(#loc) | |
// Debug-location aliases referenced by the trailing loc(#locN) on the operations above.
// #loc1 is the unknown location (used by the constants). #loc2..#loc55 are file:line:column
// positions in the tutorial script or in triton/language/standard.py.
#loc1 = loc(unknown) | |
#loc2 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":188:24) | |
#loc3 = loc("/data/users/minjang/triton-oss/triton-cpu/python/triton/language/standard.py":40:22) | |
#loc4 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":189:27) | |
#loc5 = loc("/data/users/minjang/triton-oss/triton-cpu/python/triton/language/standard.py":40:28) | |
#loc6 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":190:27) | |
#loc7 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":191:38) | |
#loc8 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":192:22) | |
#loc9 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":193:29) | |
#loc10 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":194:35) | |
#loc11 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":194:48) | |
#loc12 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":195:33) | |
#loc13 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":195:27) | |
#loc14 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":196:19) | |
#loc15 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":196:40) | |
#loc16 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":205:23) | |
#loc17 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":205:51) | |
#loc18 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":205:38) | |
#loc19 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":205:68) | |
#loc20 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":206:23) | |
#loc21 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":206:38) | |
#loc22 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":206:68) | |
#loc23 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":208:30) | |
#loc24 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":208:41) | |
#loc25 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":208:60) | |
#loc26 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":208:53) | |
#loc27 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":208:22) | |
#loc28 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":209:29) | |
#loc29 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":209:40) | |
#loc30 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":209:60) | |
#loc31 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":209:52) | |
#loc32 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":209:22) | |
#loc33 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":217:33) | |
#loc34 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":227:33) | |
#loc35 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":227:18) | |
#loc36 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":217:22) | |
#loc37 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":221:59) | |
#loc38 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":221:55) | |
#loc39 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":221:51) | |
#loc40 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":221:20) | |
#loc41 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":222:51) | |
#loc42 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":222:20) | |
#loc43 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":224:35) | |
#loc44 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":226:18) | |
#loc45 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":227:8) | |
#loc46 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":236:41) | |
#loc47 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":236:33) | |
#loc48 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":236:21) | |
#loc49 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":236:72) | |
#loc50 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":236:52) | |
#loc51 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":238:33) | |
#loc52 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":238:58) | |
#loc53 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":238:39) | |
#loc54 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":239:21) | |
#loc55 = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":239:4) | |
// Call-site locations: the two positions on standard.py line 40 (#loc3, #loc5) as called from
// tutorial lines 189 (#loc4), 190 (#loc6) and 217 (#loc33). They tag the three add-15 / divide-by-16 pairs.
#loc56 = loc(callsite(#loc3 at #loc4)) | |
#loc57 = loc(callsite(#loc5 at #loc4)) | |
#loc58 = loc(callsite(#loc3 at #loc6)) | |
#loc59 = loc(callsite(#loc5 at #loc6)) | |
#loc60 = loc(callsite(#loc3 at #loc33)) | |
#loc61 = loc(callsite(#loc5 at #loc33)) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment