Created
May 31, 2024 13:47
-
-
Save malfet/25bcbce305e7425acf8616c9d5517652 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import Metal | |
import MetalPerformanceShadersGraph | |
/// Computes `exp(x)` element-wise on the GPU using a Metal kernel compiled from source at runtime.
///
/// - Parameters:
///   - device: Metal device used to compile the kernel, create the queue and run the dispatch.
///   - ibuf: Input buffer; the kernel reads `nelem` `float` values starting at offset 0.
///   - obuf: Output buffer; the kernel writes `nelem` `float` values starting at offset 0.
///   - nelem: Number of elements to process. When it is `<= 0` nothing is dispatched.
///   - fastMathEnabled: Forwarded to `MTLCompileOptions.fastMathEnabled`, allowing the
///     compiler to use faster, less precise math for `exp`.
///
/// Blocks until the command buffer completes, and traps if the GPU reports an error,
/// so callers can read `obuf` immediately afterwards.
func calculateExpMetal(device: MTLDevice, ibuf: MTLBuffer, obuf: MTLBuffer, nelem: Int, fastMathEnabled: Bool = false) {
    // A zero-sized grid is not a valid dispatch; there is no work to do.
    guard nelem > 0 else { return }
    let shader_source = """
    #include <metal_stdlib>
    using namespace metal;
    kernel void do_exp(constant float *input [[buffer(0)]],
                       device float *output [[buffer(1)]],
                       uint thread_index [[thread_position_in_grid]]) {
        output[thread_index] = exp(input[thread_index]);
    }
    """
    let options = MTLCompileOptions()
    options.languageVersion = .version3_1
    options.fastMathEnabled = fastMathEnabled
    let library = try! device.makeLibrary(source: shader_source, options: options)
    guard let mfunc = library.makeFunction(name: "do_exp") else { fatalError("Can't find function") }
    let pipeline = try! device.makeComputePipelineState(function: mfunc)
    guard let queue = device.makeCommandQueue() else { fatalError("Can't make queue") }
    guard let cmdBuffer = queue.makeCommandBuffer() else { fatalError("Can't make command buffer") }
    guard let computeEncoder = cmdBuffer.makeComputeCommandEncoder() else { fatalError("Can't make compute encoder") }
    computeEncoder.setComputePipelineState(pipeline)
    computeEncoder.setBuffer(ibuf, offset: 0, index: 0)
    computeEncoder.setBuffer(obuf, offset: 0, index: 1)
    // A threadgroup may not exceed the pipeline's maxTotalThreadsPerThreadgroup, so using
    // `nelem` directly breaks for large inputs. dispatchThreads launches exactly `nelem`
    // threads and covers the remainder with a partial threadgroup.
    let threadgroupWidth = min(nelem, pipeline.maxTotalThreadsPerThreadgroup)
    computeEncoder.dispatchThreads(MTLSizeMake(nelem, 1, 1), threadsPerThreadgroup: MTLSizeMake(threadgroupWidth, 1, 1))
    computeEncoder.endEncoding()
    cmdBuffer.commit()
    cmdBuffer.waitUntilCompleted()
    // Without this check a failed GPU execution would silently leave `obuf` unchanged.
    if let error = cmdBuffer.error {
        fatalError("Command buffer failed: \(error)")
    }
}
/// Computes `exp(x)` element-wise through a one-op MPSGraph (placeholder -> exponent).
///
/// The graph works on a 1-D float32 tensor of `nelem` elements. `ibuf` is fed as the
/// placeholder's value and `obuf` is passed in `resultsDictionary` as the tensor data for
/// the exponent result.
/// NOTE(review): the caller reads `obuf` right after this returns, which relies on
/// `MPSGraph.run(with:feeds:targetOperations:resultsDictionary:)` being synchronous and
/// writing into the provided result tensor data — confirm against the MPSGraph docs.
func calculateExpMPS(device: MTLDevice, ibuf: MTLBuffer, obuf: MTLBuffer, nelem: Int) {
    let tensorShape: [NSNumber] = [nelem as NSNumber]

    // Graph definition.
    let expGraph = MPSGraph()
    let inputTensor = expGraph.placeholder(shape: tensorShape, dataType: .float32, name: nil)
    let resultTensor = expGraph.exponent(with: inputTensor, name: nil)

    // Bind the caller's buffers as graph input/output data.
    let inputData = MPSGraphTensorData(ibuf, shape: tensorShape, dataType: .float32)
    let outputData = MPSGraphTensorData(obuf, shape: tensorShape, dataType: .float32)

    guard let commandQueue = device.makeCommandQueue() else { fatalError("Can't make queue") }

    let feeds: [MPSGraphTensor: MPSGraphTensorData] = [inputTensor: inputData]
    let results: [MPSGraphTensor: MPSGraphTensorData] = [resultTensor: outputData]
    expGraph.run(with: commandQueue, feeds: feeds, targetOperations: nil, resultsDictionary: results)
}
// Driver: compares GPU exp() computed by fast-math Metal, precise Metal and MPSGraph
// against the CPU result.
guard let device = MTLCopyAllDevices().first else { fatalError("No Metal device found") }
print("Using device \(device.name)")
let nelem = 256

/// Allocates a CPU-visible (shared storage mode) buffer sized for `count` floats; traps on failure.
func makeFloatBuffer(_ device: MTLDevice, count: Int) -> MTLBuffer {
    guard let buffer = device.makeBuffer(length: count * MemoryLayout<Float>.stride, options: [.storageModeShared]) else {
        fatalError("Can't alloc")
    }
    return buffer
}

let ibuf = makeFloatBuffer(device, count: nelem)
// Freshly allocated buffer memory is untyped; bind it to Float before typed access.
let ibuf_data = ibuf.contents().bindMemory(to: Float.self, capacity: nelem)
// Inputs are log(0.1), log(0.2), ..., so the exact exp() results are 0.1, 0.2, ...
for i in 0..<nelem {
    ibuf_data[i] = log(Float(i) * 0.1 + 0.1)
}

let obuf_fast = makeFloatBuffer(device, count: nelem)
let obuf_prec = makeFloatBuffer(device, count: nelem)
let obuf_mps = makeFloatBuffer(device, count: nelem)

calculateExpMPS(device: device, ibuf: ibuf, obuf: obuf_mps, nelem: nelem)
calculateExpMetal(device: device, ibuf: ibuf, obuf: obuf_fast, nelem: nelem, fastMathEnabled: true)
calculateExpMetal(device: device, ibuf: ibuf, obuf: obuf_prec, nelem: nelem, fastMathEnabled: false)

let obuf_fast_data = obuf_fast.contents().bindMemory(to: Float.self, capacity: nelem)
let obuf_prec_data = obuf_prec.contents().bindMemory(to: Float.self, capacity: nelem)
let obuf_mps_data = obuf_mps.contents().bindMemory(to: Float.self, capacity: nelem)

// Report at most the first 100 elements, never reading past the end of the buffers.
let reportCount = min(100, nelem)
for i in 0..<reportCount {
    let cpu_exp = exp(ibuf_data[i])
    let fast_prec_diff = obuf_fast_data[i] - obuf_prec_data[i]
    let mps_prec_diff = obuf_mps_data[i] - obuf_prec_data[i]
    let prec_cpu_diff = obuf_prec_data[i] - cpu_exp
    print("exp(\(ibuf_data[i])) = \(cpu_exp) cpu_prec_diff = \(prec_cpu_diff) fast vs prec diff = \(fast_prec_diff) mps diff = \(mps_prec_diff)")
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment