Created December 25, 2020 18:14
-
-
Save xrq-phys/c9d198dcd97647f73c0092733b77dec5 to your computer and use it in GitHub Desktop.
A very simple example of matrix multiplication with ML Compute. (Naming conventions are not strictly followed.)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import Foundation | |
import MLCompute | |
import PlaygroundSupport | |
// Keep the playground running after the top-level code finishes, so the asynchronous
// MLCompute completion handler gets a chance to run; it ends execution itself by
// calling iPage.finishExecution().
let iPage: PlaygroundPage = .current
iPage.needsIndefiniteExecution = true
/*
 * Apple says MLCMatMulLayer performs a "batch matrix multiplication"
 * but does not make clear what that means.
 * By trial and error, it seems to mean that MLCMatMulLayer
 * broadcasts GEMM operations along the first axis
 * (i.e. the 2nd and 3rd axes store the matrices to be multiplied).
 */
// Input tensors of shape [1, 2, 2]: a leading axis of size 1 and one 2x2 matrix each.
let tA = MLCTensor(shape: [1, 2, 2], dataType: .float32)
let tB = MLCTensor(shape: [1, 2, 2], dataType: .float32)
let tC = MLCTensor(shape: [1, 2, 2], dataType: .float32)

// Row-major contents of A, B and C.
let bufA: [Float] = [1, 2, 3, 4]
let bufB: [Float] = [1, 2, 3, 4]
let bufC: [Float] = [1, 1, 1, 1]

/// Wraps a copy of `values` as `MLCTensorData` backed by memory that outlives this call.
///
/// `UnsafeRawPointer(someArray)` only yields a pointer that is valid for the duration of
/// that initializer call, so passing it to a `NoCopy` initializer hands MLCompute a
/// dangling pointer. Instead, the values are copied into an explicitly allocated buffer.
/// That buffer is intentionally never deallocated: MLCompute may read it asynchronously
/// while `execute` runs, and this one-shot playground reclaims it when the process exits.
///
/// - Parameter values: The tensor contents, in the order MLCompute expects them.
/// - Returns: Tensor data whose backing storage stays valid for the rest of the program.
func makeTensorData(_ values: [Float]) -> MLCTensorData {
    let storage = UnsafeMutablePointer<Float>.allocate(capacity: values.count)
    storage.initialize(from: values, count: values.count)
    return MLCTensorData(immutableBytesNoCopy: UnsafeRawPointer(storage),
                         length: values.count * MemoryLayout<Float>.stride)
}

let datA = makeTensorData(bufA)
let datB = makeTensorData(bufB)
let datC = makeTensorData(bufC)
// Build the graph A·B + C: a (batched) matrix multiplication feeding an element-wise add.
let iGraph = MLCGraph()
guard let matMulLayer = MLCMatMulLayer(descriptor: MLCMatMulDescriptor()),
      let tAB = iGraph.node(with: matMulLayer, sources: [tA, tB]),
      iGraph.node(with: MLCArithmeticLayer(operation: .add), sources: [tAB, tC]) != nil
else {
    fatalError("Failed to build the A·B + C graph")
}

// Wrap the graph for inference. Input names "A", "B", "C" must match the keys passed to execute(...).
let iPlan = MLCInferenceGraph(graphObjects: [iGraph])
// Both calls report failure only through their Bool result; without these checks a
// failed setup would fall through to execute(...) on an unusable plan.
guard iPlan.addInputs(["A": tA, "B": tB, "C": tC]) else {
    fatalError("Failed to register the inference graph inputs A, B and C")
}
guard iPlan.compile(options: .debugLayers, device: MLCDevice()) else {
    fatalError("Failed to compile the inference graph")
}
/// Runs the compiled plan asynchronously and prints the resulting matrix.
///
/// The completion handler always calls `iPage.finishExecution()` (via `defer`), so the
/// playground terminates even when execution fails or yields no result tensor.
let didLaunch = iPlan.execute(inputsData: ["A": datA, "B": datB, "C": datC],
                              batchSize: 0,
                              options: []) { result, error, _ in
    defer { iPage.finishExecution() }
    print("Error: \(String(describing: error))")
    print("Result: \(String(describing: result))")
    // `result` may be nil (e.g. when execution failed); nothing to copy in that case.
    guard let result = result else { return }
    // The output A·B + C has the same [1, 2, 2] shape as C here, hence bufC.count elements.
    var output = [Float](repeating: 0, count: bufC.count)
    // Copy straight into the array's storage: no manual allocation, nothing to leak.
    let copied = output.withUnsafeMutableBytes { raw -> Bool in
        guard let base = raw.baseAddress else { return false }
        return result.copyDataFromDeviceMemory(toBytes: base,
                                               length: raw.count,
                                               synchronizeWithDevice: false)
    }
    guard copied else {
        print("Error: could not copy the result out of device memory")
        return
    }
    print(output)
}
// If execute(...) reports failure, the completion handler may never run (TODO confirm);
// release the page here so the playground does not wait indefinitely.
if !didLaunch {
    print("Error: execute(inputsData:batchSize:options:) returned false")
    iPage.finishExecution()
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment