Skip to content

Instantly share code, notes, and snippets.

View minjang's full-sized avatar

Minjang Kim minjang

  • Facebook
  • Menlo Park, CA
View GitHub Profile
@minjang
minjang / 03-matrix-multiplication-cpu.py
Created November 20, 2024 08:00
matmul_kernel for 03-matrix-multiplication-cpu.py without leaky_relu
@triton.jit
def matmul_kernel(
# Pointers to matrices
a_ptr, b_ptr, c_ptr,
# Matrix dimensions
M, N, K,
# The stride variables represent how much to increase the ptr by when moving by 1
# element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
# by to get the element one row down (A has M rows).
stride_am, stride_ak, #
@minjang
minjang / matmul_kernel.asm
Last active November 20, 2024 07:57
x86-64 (AVX512) for matmul_kernel (03-matrix-multiplication-cpu.py) from TTMIR
.text
.file "LLVMDialectModule"
.section .rodata,"a",@progbits
.p2align 6, 0x0 # -- Begin function matmul_kernel
.LCPI0_0:
.zero 4
.long 1 # 0x1
.long 2 # 0x2
.long 3 # 0x3
.long 4 # 0x4
@minjang
minjang / matmul_kernel.ttmir
Last active November 20, 2024 08:03
TTMIR for matmul_kernel (03-matrix-multiplication-cpu.py)
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 1], warpsPerCTA = [1, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [1, 1], warpsPerCTA = [1, 1], order = [1, 0]}>
#loc = loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":166:0)
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "cpu", "triton_gpu.threads-per-warp" = 1 : i32} {
tt.func public @matmul_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 8192 : i32} loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":166:0), %arg1: !tt.ptr<f32> {tt.divisibility = 8192 : i32} loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":166:0), %arg2: !tt.ptr<f32> {tt.divisibility = 8192 : i32} loc("/data/users/minjang/triton-oss/triton-cpu/python/tutorials/03-matrix-multiplication-cpu.py":166:0), %arg3: i32 {tt.
@minjang
minjang / matmul_kernel.llir
Last active November 20, 2024 07:56
LLVM IR for matmul_kernel (03-matrix-multiplication-cpu.py) from TTMIR
; ModuleID = 'LLVMDialectModule'
source_filename = "LLVMDialectModule"
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"
target triple = "x86_64-unknown-linux-gnu"
; Function Attrs: nofree norecurse nosync nounwind memory(argmem: readwrite)
define void @matmul_kernel(ptr addrspace(1) nocapture readonly %0, ptr addrspace(1) nocapture readonly %1, ptr addrspace(1) nocapture writeonly %2, i32 %3, i32 %4, i32 %5, i32 %6, i32 %7, i32 %8, i32 %9, i32 %10, i32 %11, i32 %12, i32 %13, i32 %14, i32 %15, i32 %16) local_unnamed_addr #0 !dbg !3 {
%18 = add i32 %3, 15, !dbg !6
%19 = sdiv i32 %18, 16, !dbg !10
%20 = add i32 %4, 15, !dbg !11
from typing import Optional, Union
import os
import numpy as np
import torch
import triton
import triton.language as tl
import math
from triton.runtime.jit import TensorWrapper, reinterpret
from numpy.random import RandomState
@minjang
minjang / patch.diff
Created May 23, 2024 01:10
A quick patch to make triton-cpu runnable for github.com/ienkovich/triton-cpu/tree/ienkovich/change-cast-test-size
diff --git a/include/triton/Conversion/CMakeLists.txt b/include/triton/Conversion/CMakeLists.txt
index ae31ac93..691104f3 100644
--- a/include/triton/Conversion/CMakeLists.txt
+++ b/include/triton/Conversion/CMakeLists.txt
@@ -1,4 +1,4 @@
-add_subdirectory(TritonCPUToLLVM)
+# add_subdirectory(TritonCPUToLLVM)
add_subdirectory(TritonGPUToLLVM)
-add_subdirectory(TritonToTritonCPU)
+# add_subdirectory(TritonToTritonCPU)
@minjang
minjang / dlmalloc.c
Created December 10, 2016 19:43
Version 2.8.6 Wed Aug 29 06:57:58 2012
/*
This is a version (aka dlmalloc) of malloc/free/realloc written by
Doug Lea and released to the public domain, as explained at
http://creativecommons.org/publicdomain/zero/1.0/ Send questions,
comments, complaints, performance data, etc to dl@cs.oswego.edu
* Version 2.8.6 Wed Aug 29 06:57:58 2012 Doug Lea
Note: There may be an updated version of this malloc obtainable at
ftp://gee.cs.oswego.edu/pub/misc/malloc.c
Check before installing!
@minjang
minjang / swap-mem2reg-inline-instcombine.ll
Created November 29, 2016 05:36
Fully optimized code
; Fully optimized form of test(): both swaps have been folded away, so the two
; values obtained from @get are passed to @process in their original order and
; the result of @process is returned.
define i32 @_Z4testv() #0 {
entry:
  %a = call i32 @get()
  %b = call i32 @get()
  ; The call result needs its own name: LLVM IR is in SSA form, so a local
  ; value such as %b may be defined only once within a function.
  %call = call i32 @process(i32 %a, i32 %b)
  ret i32 %call
}
; Step-by-step simplification of the xor chain below: %xor1 reduces to %b.
%a = call i32 @get()
%b = call i32 @get()
%xor = xor i32 %b, %a
%xor1 = xor i32 %a, %xor
; => %xor1 = %a ^ %xor
; => %xor1 = %a ^ (%b ^ %a) ; a ^ (b ^ a) = b ^ 0
; => %xor1 = %b ^ 0 ; b ^ 0 = b
; => %xor1 = %b ; afterwards, every use of %xor1 is replaced with %b
; => %xor1 is deleted
@minjang
minjang / swap-mem2reg-inline-mem2reg.ll
Created November 29, 2016 03:17
After mem2reg, inline, and mem2reg optimizations
define i32 @_Z4testv() #0 {
entry:
%call = call i32 @_Z3getv() ; a = get();
%call1 = call i32 @_Z3getv() ; b = get();
; temp_swap(a, b)는 사라짐
%xor.i = xor i32 %call1, %call ; xor_swap(a, b)가 xor_swap(b, a)로 바뀜
%xor1.i = xor i32 %call, %xor.i
%xor2.i = xor i32 %xor.i, %xor1.i
%call2 = call i32 @_Z7processii(i32 %xor2.i, i32 %xor1.i)
ret i32 %call2