Skip to content

Instantly share code, notes, and snippets.

@komakai
Created December 13, 2024 03:43
Show Gist options
  • Save komakai/f0046a57a904c7052dab50a95b206589 to your computer and use it in GitHub Desktop.
Save komakai/f0046a57a904c7052dab50a95b206589 to your computer and use it in GitHub Desktop.
Metal compute shader (add_arrays) with Swift host code
// Element-wise vector addition: each GPU thread writes result[i] = inA[i] + inB[i],
// where i is the thread's position in the dispatch grid.
// There is no bounds check, so the host must launch no more threads than the
// shortest of the three buffers has elements.
kernel void add_arrays(device const float* inA    [[buffer(0)]],
                       device const float* inB    [[buffer(1)]],
                       device float*       result [[buffer(2)]],
                       uint                index  [[thread_position_in_grid]])
{
    const float lhs = inA[index];
    const float rhs = inB[index];
    result[index] = lhs + rhs;
}
import Metal

// Host driver for the `add_arrays` compute kernel: uploads two random Float arrays,
// adds them element-wise on the GPU, and verifies every result against the CPU sum.
// Setup failures stop the program with a message instead of being skipped by optional
// chaining and later surfacing as an unexplained force-unwrap crash.

let bufferLength = 1024 * 512
let bufferSize = MemoryLayout<Float>.stride * bufferLength

guard let device = MTLCreateSystemDefaultDevice() else {
    fatalError("No Metal device is available on this system")
}
// makeDefaultLibrary() returns nil when no compiled .metal source is bundled with the app.
guard let defaultLibrary = device.makeDefaultLibrary() else {
    fatalError("Could not load the default Metal library")
}
guard let addFunction = defaultLibrary.makeFunction(name: "add_arrays") else {
    fatalError("Kernel function 'add_arrays' was not found in the default library")
}
// Report the pipeline-creation error instead of discarding it with `try?`.
let addFunctionPSO: MTLComputePipelineState = {
    do {
        return try device.makeComputePipelineState(function: addFunction)
    } catch {
        fatalError("Failed to create the compute pipeline state: \(error)")
    }
}()
guard let commandQueue = device.makeCommandQueue() else {
    fatalError("Failed to create a command queue")
}

// Inputs drawn from [0, 1), so every expected sum lies in [0, 2).
let bufferDataA = (0..<bufferLength).map { _ in Float.random(in: 0..<1) }
let bufferDataB = (0..<bufferLength).map { _ in Float.random(in: 0..<1) }

guard
    let bufferA = device.makeBuffer(bytes: bufferDataA, length: bufferSize, options: .storageModeShared),
    let bufferB = device.makeBuffer(bytes: bufferDataB, length: bufferSize, options: .storageModeShared),
    let bufferResult = device.makeBuffer(length: bufferSize, options: .storageModeShared)
else {
    fatalError("Failed to allocate Metal buffers")
}

guard
    let commandBuffer = commandQueue.makeCommandBuffer(),
    let computeEncoder = commandBuffer.makeComputeCommandEncoder()
else {
    fatalError("Failed to create a command buffer or compute encoder")
}
computeEncoder.setComputePipelineState(addFunctionPSO)
// Indices must match the [[buffer(n)]] attributes declared by the kernel.
computeEncoder.setBuffer(bufferA, offset: 0, index: 0)
computeEncoder.setBuffer(bufferB, offset: 0, index: 1)
computeEncoder.setBuffer(bufferResult, offset: 0, index: 2)

// Describe the grid in threads (one per element) rather than in whole threadgroups:
// `bufferLength / threadgroupWidth` truncates and would leave the tail elements
// uncomputed whenever bufferLength is not a multiple of the width. The kernel does not
// bounds-check its index, so the grid must be exactly bufferLength threads wide.
// NOTE: dispatchThreads relies on GPU support for non-uniform threadgroup sizes.
let maxTotalThreadsPerThreadgroup = addFunctionPSO.maxTotalThreadsPerThreadgroup
let threadsPerThreadgroup = MTLSize(width: min(maxTotalThreadsPerThreadgroup, bufferLength), height: 1, depth: 1)
let gridSize = MTLSize(width: bufferLength, height: 1, depth: 1)
computeEncoder.dispatchThreads(gridSize, threadsPerThreadgroup: threadsPerThreadgroup)
computeEncoder.endEncoding()
commandBuffer.commit()
commandBuffer.waitUntilCompleted()

guard commandBuffer.status == .completed else {
    fatalError("GPU execution failed: \(String(describing: commandBuffer.error))")
}

// The result buffer was allocated untyped, so bind it to Float before reading it.
let resultPointer = bufferResult.contents().bindMemory(to: Float.self, capacity: bufferLength)
let result = Array(UnsafeBufferPointer(start: resultPointer, count: bufferLength))

let maxReportedMismatches = 10
var mismatchCount = 0
for index in 0..<bufferLength {
    let expected = bufferDataA[index] + bufferDataB[index]
    let actual = result[index]
    // Tolerance scales with the magnitude of the sum (floored at 1). The negated `<=`
    // also flags NaN results, which a plain `>` comparison would silently accept.
    let tolerance = Float.ulpOfOne * max(1, abs(expected))
    if !(abs(expected - actual) <= tolerance) {
        // Print only the first few mismatches; a systematic failure would otherwise
        // emit one line per element.
        if mismatchCount < maxReportedMismatches {
            print("Error at index \(index): expected \(expected), got \(actual)")
        }
        mismatchCount += 1
    }
}
if mismatchCount == 0 {
    print("Ok")
} else {
    print("Error: \(mismatchCount) of \(bufferLength) elements differ")
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment