Last active
February 21, 2025 12:59
-
-
Save uvolchyk/8ce59857f800ca5cc939ef55150823ac to your computer and use it in GitHub Desktop.
Source code for the article: https://uvolchyk.me/blog/bursting-fireworks-with-metal
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import MetalKit | |
/// A typed write cursor over a contiguous run of `T` elements.
///
/// `data` points at the first element, `length` is how many elements fit,
/// and `position` is the index of the next slot to be written.
struct ConstantBuffer<T> {
    /// Capacity, measured in elements of `T`.
    let length: Int
    /// Start of the element storage.
    var data: UnsafeMutablePointer<T>
    /// Index of the next slot to write; starts at zero.
    var position: Int

    /// Distance in bytes between consecutive `T` elements.
    static var stepSize: Int {
        return MemoryLayout<T>.stride
    }

    /// Wraps the contents of a Metal buffer, binding its memory to `T`.
    ///
    /// The capacity is the buffer's byte length divided by `T`'s stride,
    /// rounded down.
    init(_ buffer: MTLBuffer) {
        let capacity = buffer.length / Self.stepSize
        let typedStart = buffer.contents().bindMemory(to: T.self, capacity: capacity)
        self.init(buffer: typedStart, elementsCount: capacity)
    }

    /// Wraps `elementsCount` elements starting at `buffer`, cursor at zero.
    init(buffer: UnsafeMutablePointer<T>, elementsCount: Int) {
        self.length = elementsCount
        self.data = buffer
        self.position = 0
    }

    /// Number of element slots remaining after `position`.
    var availableSpace: UInt {
        let remaining = length - position
        return UInt(remaining)
    }

    /// Whether at least `count` more elements can still be written.
    func hasSpace(for count: UInt) -> Bool {
        return count <= availableSpace
    }
}
extension ConstantBuffer {
    /// Stores `value` at the cursor and advances the cursor by one element.
    ///
    /// - Parameters:
    ///   - value: The element to store. It is taken `inout` only for source
    ///     compatibility with existing callers; it is not modified.
    ///   - instance: Unused; kept so existing call sites keep compiling.
    /// - Precondition: At least one free slot remains (`hasSpace(for: 1)`).
    mutating func write(
        value: inout T,
        instance: Int = 0
    ) {
        // `data` is a raw pointer: writing past `length` would corrupt memory
        // beyond the wrapped buffer, so trap instead of scribbling silently.
        precondition(hasSpace(for: 1), "ConstantBuffer overflow: no space left to write a value")
        data[position] = value
        position += 1
    }

    /// Appends `value` if there is room; otherwise drops it silently.
    mutating func append(_ value: T) {
        guard hasSpace(for: 1) else { return }
        appendRaw(value)
    }

    /// Appends `value` without the silent-drop capacity handling of `append`.
    ///
    /// - Precondition: At least one free slot remains. A violation traps
    ///   instead of writing past the end of the underlying memory.
    mutating func appendRaw(_ value: T) {
        precondition(position < length, "ConstantBuffer overflow: appendRaw past capacity")
        data[position] = value
        position += 1
    }
}
extension ConstantBuffer where T: SIMDScalar {
    /// Appends the three lanes of `vector` (x, y, z) in order.
    /// If fewer than three slots remain, nothing is written.
    mutating func append(_ vector: SIMD3<T>) {
        appendLanes(of: vector)
    }

    /// Appends the four lanes of `vector` (x, y, z, w) in order.
    /// If fewer than four slots remain, nothing is written.
    mutating func append(_ vector: SIMD4<T>) {
        appendLanes(of: vector)
    }

    /// Writes every lane of `vector` in lane order, all-or-nothing.
    ///
    /// Checking the space for the whole vector up front avoids leaving a
    /// partially written vector at the end of a nearly full buffer.
    private mutating func appendLanes<V: SIMD>(of vector: V) where V.Scalar == T {
        guard hasSpace(for: UInt(vector.scalarCount)) else { return }
        for lane in vector.indices {
            appendRaw(vector[lane])
        }
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import simd | |
/// Describes one firework: a curved launch trail followed by a radial
/// "palm" burst that starts at the end of the trail.
struct FireworkScene {
    /// Per-instance values written into the uniform buffer.
    ///
    /// NOTE(review): field order and types define the memory layout, which
    /// presumably mirrors a shader-side struct — keep the two in sync.
    struct Uniform {
        let color: SIMD4<Float>
        let transform: simd_float4x4
        let time: Float
    }

    /// First point of the launch trail's quadratic Bézier curve.
    var trailStartPoint: SIMD2<Float>
    /// Last point of the launch trail; the burst originates here.
    var trailEndPoint: SIMD2<Float>
    /// Control point that bends the launch trail.
    var trailControlPoint: SIMD2<Float>
    /// Factor applied to every emitted vertex's x coordinate.
    var ratio: Float
    /// Number of straight pieces each curve is tessellated into.
    var segments = 8
    /// Ribbon thickness of every trail.
    var trailWidth: Float = 0.05
    /// Number of trails fanning out in the burst.
    var burstPalmTrails = 16
    /// Scale of the burst trails' initial velocity.
    var burstPalmRadius: Float = 0.7
}
extension FireworkScene {
    /// Tessellates the whole firework into the given buffers.
    ///
    /// The launch trail is emitted first and the burst second. Every emitted
    /// vertex adds one position to `vertexBuffer` and one progress value to
    /// `progressBuffer`.
    func draw(
        vertexBuffer: inout ConstantBuffer<Float>,
        progressBuffer: inout ConstantBuffer<Float>
    ) {
        // Order matters: the trail's vertices precede the burst's.
        drawLaunchTrail(vertexBuffer: &vertexBuffer, progressBuffer: &progressBuffer)
        drawBurstPalm(vertexBuffer: &vertexBuffer, progressBuffer: &progressBuffer)
    }
}
private extension FireworkScene {
    /// Emits the launch trail: a ribbon that follows the quadratic Bézier
    /// curve from `trailStartPoint` to `trailEndPoint`, pulled towards
    /// `trailControlPoint`. Its progress values span `0...0.5`.
    func drawLaunchTrail(
        vertexBuffer: inout ConstantBuffer<Float>,
        progressBuffer: inout ConstantBuffer<Float>
    ) {
        appendRibbon(
            progressRange: 0.0...0.5,
            vertexBuffer: &vertexBuffer,
            progressBuffer: &progressBuffer
        ) { t in
            let point = quadraticBezier(trailStartPoint, trailControlPoint, trailEndPoint, t)
            let tangent = quadraticBezierTangent(trailStartPoint, trailControlPoint, trailEndPoint, t)
            return (point: point, tangent: tangent)
        }
    }

    /// Emits the "palm" burst. `burstPalmTrails` ribbons start at
    /// `trailEndPoint` with evenly spaced directions, and each follows a
    /// ballistic arc under a constant downward acceleration. Progress values
    /// span `0.5...1`.
    private func drawBurstPalm(
        vertexBuffer: inout ConstantBuffer<Float>,
        progressBuffer: inout ConstantBuffer<Float>
    ) {
        // `0..<burstPalmTrails` traps for negative counts; zero draws nothing.
        guard burstPalmTrails > 0 else { return }

        let gravity: Float = 0.8
        let origin = trailEndPoint

        for i in 0..<burstPalmTrails {
            let angle = (Float.pi * 2 * Float(i)) / Float(burstPalmTrails)
            let initialVelocity = SIMD2<Float>(cos(angle), sin(angle)) * burstPalmRadius

            appendRibbon(
                progressRange: 0.5...1.0,
                vertexBuffer: &vertexBuffer,
                progressBuffer: &progressBuffer
            ) { t in
                // Projectile motion: p(t) = p0 + v0·t − ½·g·t² (vertical only);
                // the tangent is its derivative, v0 − g·t (vertical only).
                let point = SIMD2<Float>(
                    origin.x + initialVelocity.x * t,
                    origin.y + initialVelocity.y * t - (gravity * t * t) / 2
                )
                let tangent = SIMD2<Float>(
                    initialVelocity.x,
                    initialVelocity.y - gravity * t
                )
                return (point: point, tangent: tangent)
            }
        }
    }

    /// Tessellates a parametric curve into a ribbon `trailWidth` wide.
    ///
    /// The curve is sampled at `segments + 1` evenly spaced parameters
    /// `t` in `0...1`. Each sample is widened into two points offset half a
    /// width along ∓ its unit normal. Each pair of consecutive samples becomes
    /// a quad of two triangles (six vertices). Every vertex also receives a
    /// progress value that is linearly interpolated across `progressRange`.
    ///
    /// - Parameters:
    ///   - progressRange: Progress at the first (`t == 0`) and last
    ///     (`t == 1`) samples.
    ///   - vertexBuffer: Receives one position per vertex via `appendVertex`.
    ///   - progressBuffer: Receives one progress value per vertex.
    ///   - sample: Returns the curve point and its (not necessarily unit)
    ///     tangent at `t`.
    func appendRibbon(
        progressRange: ClosedRange<Float>,
        vertexBuffer: inout ConstantBuffer<Float>,
        progressBuffer: inout ConstantBuffer<Float>,
        sample: (Float) -> (point: SIMD2<Float>, tangent: SIMD2<Float>)
    ) {
        // A negative count would make `0...segments` trap; zero yields no quads.
        guard segments > 0 else { return }

        let tIncrement = 1.0 / Float(segments)
        let progressSpan = progressRange.upperBound - progressRange.lowerBound
        let halfWidth = trailWidth / 2

        // Edge points of the previous sample, once one exists.
        var previousEdge: (p1: SIMD2<Float>, p2: SIMD2<Float>)?
        // Normal reused when a tangent is degenerate; starts as the normal of
        // a tangent pointing along +x.
        var lastNormal = SIMD2<Float>(0, 1)

        for i in 0...segments {
            let t = Float(i) * tIncrement
            let (point, tangent) = sample(t)
            let normal = ribbonNormal(for: tangent, fallback: lastNormal)
            lastNormal = normal

            // p1 sits half a width along −normal, p2 along +normal.
            let p1 = point - normal * halfWidth
            let p2 = point + normal * halfWidth

            if let previous = previousEdge {
                let progressPrevious = progressRange.lowerBound
                    + progressSpan * Float(i - 1) / Float(segments)
                let progressCurrent = progressRange.lowerBound
                    + progressSpan * Float(i) / Float(segments)

                // First triangle: previous p1, current p1, previous p2.
                appendVertex(previous.p1, to: &vertexBuffer)
                appendVertex(p1, to: &vertexBuffer)
                appendVertex(previous.p2, to: &vertexBuffer)
                progressBuffer.appendRaw(progressPrevious)
                progressBuffer.appendRaw(progressCurrent)
                progressBuffer.appendRaw(progressPrevious)

                // Second triangle: previous p2, current p1, current p2.
                appendVertex(previous.p2, to: &vertexBuffer)
                appendVertex(p1, to: &vertexBuffer)
                appendVertex(p2, to: &vertexBuffer)
                progressBuffer.appendRaw(progressPrevious)
                progressBuffer.appendRaw(progressCurrent)
                progressBuffer.appendRaw(progressCurrent)
            }

            previousEdge = (p1: p1, p2: p2)
        }
    }

    /// Unit normal of `tangent`, i.e. the unit tangent rotated 90°
    /// counter-clockwise.
    ///
    /// A (near-)zero tangent cannot be normalized: `normalize` would return
    /// NaNs, and every vertex built from them would be NaN as well. In that
    /// case `fallback` is returned instead.
    func ribbonNormal(
        for tangent: SIMD2<Float>,
        fallback: SIMD2<Float>
    ) -> SIMD2<Float> {
        let lengthSquared = (tangent * tangent).sum()
        guard lengthSquared >= Float.leastNormalMagnitude else { return fallback }
        let unitTangent = normalize(tangent)
        return SIMD2<Float>(-unitTangent.y, unitTangent.x)
    }
}
private extension FireworkScene {
    /// Point at parameter `t` on the quadratic Bézier curve with endpoints
    /// `p0`, `p2` and control point `p1`:
    /// B(t) = (1−t)²·p0 + 2(1−t)t·p1 + t²·p2.
    func quadraticBezier(
        _ p0: SIMD2<Float>,
        _ p1: SIMD2<Float>,
        _ p2: SIMD2<Float>,
        _ t: Float
    ) -> SIMD2<Float> {
        let u = 1 - t
        let weight0 = u * u
        let weight1 = 2 * u * t
        let weight2 = t * t
        return weight0 * p0 + weight1 * p1 + weight2 * p2
    }

    /// Derivative of the same curve with respect to `t`:
    /// B'(t) = 2(1−t)·(p1−p0) + 2t·(p2−p1).
    func quadraticBezierTangent(
        _ p0: SIMD2<Float>,
        _ p1: SIMD2<Float>,
        _ p2: SIMD2<Float>,
        _ t: Float
    ) -> SIMD2<Float> {
        let u = 1 - t
        let firstLeg = p1 - p0
        let secondLeg = p2 - p1
        return 2 * u * firstLeg + 2 * t * secondLeg
    }

    /// Writes `point` as four floats: (x · ratio, y, 0, 1).
    /// This is presumably consumed as a homogeneous clip-space position;
    /// confirm against the vertex shader.
    func appendVertex(
        _ point: SIMD2<Float>,
        to buffer: inout ConstantBuffer<Float>
    ) {
        let position = SIMD4<Float>(point.x * ratio, point.y, 0, 1)
        for lane in position.indices {
            buffer.appendRaw(position[lane])
        }
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import MetalKit | |
import MetalPerformanceShaders | |
import SwiftUI | |
// Platform-neutral base class for the firework view controller:
// UIKit everywhere except macOS, where AppKit is used.
#if !os(macOS)
typealias ViewController = UIViewController
#else
typealias ViewController = NSViewController
#endif
/// Hosts an `MTKView` that renders the animated firework scene.
final class FireworkViewController: ViewController {
    // MARK: Metal core objects
    private var device: (any MTLDevice)!
    private var commandQueue: (any MTLCommandQueue)!

    // MARK: Pipeline states
    private var scenePipelineState: (any MTLRenderPipelineState)!
    private var compositePipelineState: (any MTLRenderPipelineState)!

    // MARK: Buffers
    private var vertexBuffer: (any MTLBuffer)!
    private var progressBuffer: (any MTLBuffer)!
    private var uniformBuffer: (any MTLBuffer)!
    private var sceneBuffer: (any MTLBuffer)!

    // MARK: Offscreen textures
    private var sceneTexture: (any MTLTexture)!
    private var glowTexture: (any MTLTexture)!

    // MARK: View and timing
    private lazy var canvasView: MTKView = MTKView()
    /// Length in seconds of one animation loop.
    let duration: Float = 3
    /// Media time at which this controller was created; the animation clock origin.
    private let initialTime = CACurrentMediaTime()
}
extension FireworkViewController {
    /// Uses the Metal-backed `canvasView` as this controller's root view.
    override func loadView() {
        view = canvasView
    }

    /// Creates the Metal device and command queue, wires `canvasView` to them,
    /// and builds the pipeline states and buffers.
    override func viewDidLoad() {
        super.viewDidLoad()
        // Fail with a clear message rather than crashing on the first use of
        // the implicitly unwrapped `device`.
        guard let device = MTLCreateSystemDefaultDevice() else {
            fatalError("Metal is not supported on this device")
        }
        self.device = device
        commandQueue = device.makeCommandQueue()
        canvasView.device = device
        canvasView.delegate = self
        buildPipelineStates()
        buildBuffers()
    }

    #if os(macOS)
    /// Resizes the offscreen targets after layout.
    override func viewDidLayout() {
        super.viewDidLayout()
        // Use the drawable's pixel size, the same size that
        // `mtkView(_:drawableSizeWillChange:)` passes, rather than the view's
        // point-sized bounds, which are too small on Retina displays.
        buildResources(size: canvasView.drawableSize)
    }
    #else
    /// Resizes the offscreen targets after layout.
    override func viewDidLayoutSubviews() {
        super.viewDidLayoutSubviews()
        // Use the drawable's pixel size, the same size that
        // `mtkView(_:drawableSizeWillChange:)` passes, rather than the view's
        // point-sized bounds, which are too small on Retina displays.
        buildResources(size: canvasView.drawableSize)
    }
    #endif
}
extension FireworkViewController: MTKViewDelegate {
    /// Renders one frame in four steps:
    /// 1. Tessellates the firework on the CPU.
    /// 2. Draws it, instanced once per color layer, into `sceneTexture`.
    /// 3. Blurs `sceneTexture` into `glowTexture`.
    /// 4. Composites both textures into the view's drawable.
    func draw(in view: MTKView) {
        // Nothing can be drawn before the offscreen targets exist, or while the
        // view has zero width (the aspect ratio below would be inf/NaN).
        guard
            view.bounds.width > 0,
            let sceneTexture = sceneTexture,
            let glowTexture = glowTexture,
            let commandBuffer = commandQueue.makeCommandBuffer(),
            let drawable = view.currentDrawable
        else { return }

        // One rendered instance per layer: color, placement, and start delay
        // (in seconds). The instance count below is derived from this table.
        let layers: [(color: SIMD4<Float>, transform: simd_float4x4, delay: Float)] = [
            (
                color: SIMD4<Float>(208.0 / 255.0, 80.0 / 255.0, 111.0 / 255.0, 1.0),
                transform: .init(translationX: 0.1, y: 0.0),
                delay: 0.0
            ),
            (
                color: SIMD4<Float>(100.0 / 255.0, 167.0 / 255.0, 230.0 / 255.0, 1.0),
                transform: .init(translationX: -0.1, y: 0.0),
                delay: 0.2
            ),
            (
                color: SIMD4<Float>(90.0 / 255.0, 98.0 / 255.0, 198.0 / 255.0, 1.0),
                transform: .init(translationX: 0.0, y: 0.2),
                delay: 0.4
            ),
        ]

        // NOTE(review): the vertex/progress/uniform buffers are rewritten every
        // frame without waiting for the GPU to finish the previous frame. If
        // frames overlap, consider triple buffering with a semaphore.
        let elapsedTime = Float(CACurrentMediaTime() - initialTime)
            .truncatingRemainder(dividingBy: duration)

        var uBufferWrapper = ConstantBuffer<FireworkScene.Uniform>(uniformBuffer)
        for layer in layers {
            var uniform = FireworkScene.Uniform(
                color: layer.color,
                transform: layer.transform,
                time: easeOutCubic((elapsedTime - layer.delay) / duration)
            )
            uBufferWrapper.write(value: &uniform)
        }

        var vBufferWrapper = ConstantBuffer<Float>(vertexBuffer)
        var pBufferWrapper = ConstantBuffer<Float>(progressBuffer)
        let firework = FireworkScene(
            trailStartPoint: SIMD2<Float>(-0.4, -0.8),
            trailEndPoint: SIMD2<Float>(0.1, 0.5),
            trailControlPoint: SIMD2<Float>(-0.3, -0.1),
            ratio: Float(view.bounds.height / view.bounds.width),
            trailWidth: 0.03
        )
        firework.draw(
            vertexBuffer: &vBufferWrapper,
            progressBuffer: &pBufferWrapper
        )

        // Each vertex occupies four floats in the vertex buffer.
        let vertexCount = vBufferWrapper.position / 4

        // Pass 1: firework geometry into the offscreen scene texture.
        if
            let renderPassDescriptor = sceneRenderPassDescriptor(sceneTexture),
            let renderEncoder = commandBuffer.makeRenderCommandEncoder(
                descriptor: renderPassDescriptor
            )
        {
            renderEncoder.setRenderPipelineState(scenePipelineState)
            renderEncoder.setVertexBuffer(vertexBuffer, offset: 0, index: 0)
            renderEncoder.setVertexBuffer(progressBuffer, offset: 0, index: 1)
            renderEncoder.setVertexBuffer(uniformBuffer, offset: 0, index: 2)
            renderEncoder.drawPrimitives(
                type: .triangle,
                vertexStart: 0,
                vertexCount: vertexCount,
                // One instance per uniform written above.
                instanceCount: layers.count
            )
            renderEncoder.endEncoding()
        }

        // Pass 2: blur the scene into the glow texture.
        let kernel = MPSImageGaussianBlur(
            device: device,
            sigma: 40.0
        )
        kernel.encode(
            commandBuffer: commandBuffer,
            sourceTexture: sceneTexture,
            destinationTexture: glowTexture
        )

        // Pass 3: composite scene and glow into the drawable.
        if
            let renderPassDescriptor = sceneRenderPassDescriptor(drawable.texture),
            let renderEncoder = commandBuffer.makeRenderCommandEncoder(descriptor: renderPassDescriptor)
        {
            renderEncoder.setRenderPipelineState(compositePipelineState)
            renderEncoder.setVertexBuffer(sceneBuffer, offset: 0, index: 0)
            renderEncoder.setFragmentTexture(sceneTexture, index: 0)
            renderEncoder.setFragmentTexture(glowTexture, index: 1)
            renderEncoder.drawPrimitives(
                type: .triangleStrip,
                vertexStart: 0,
                vertexCount: 4
            )
            renderEncoder.endEncoding()
        }

        commandBuffer.present(drawable)
        commandBuffer.commit()
    }

    /// Recreates the offscreen targets to match the new drawable size.
    func mtkView(
        _ view: MTKView,
        drawableSizeWillChange size: CGSize
    ) {
        buildResources(size: size)
    }
}
private extension FireworkViewController {
    /// Builds both render pipelines from the default Metal library.
    ///
    /// The scene pipeline draws alpha-blended firework geometry. The
    /// composition pipeline samples the scene and glow textures. Any failure
    /// is fatal, because `draw(in:)` cannot work without either pipeline.
    func buildPipelineStates() {
        guard let library = device.makeDefaultLibrary() else {
            // Without a library no pipeline can be built, and `draw(in:)`
            // would later crash on the nil pipeline states. Fail here instead,
            // with context.
            fatalError("Failed to load the default Metal library")
        }

        let scenePipelineDescriptor = MTLRenderPipelineDescriptor()
        scenePipelineDescriptor.vertexFunction = library.makeFunction(name: "Firework::vertexScene")
        scenePipelineDescriptor.fragmentFunction = library.makeFunction(name: "Firework::fragmentScene")
        scenePipelineDescriptor.colorAttachments[0].pixelFormat = .bgra8Unorm
        // "Over" blending: src · srcAlpha + dst · (1 − srcAlpha).
        scenePipelineDescriptor.colorAttachments[0].isBlendingEnabled = true
        scenePipelineDescriptor.colorAttachments[0].rgbBlendOperation = .add
        scenePipelineDescriptor.colorAttachments[0].alphaBlendOperation = .add
        scenePipelineDescriptor.colorAttachments[0].sourceRGBBlendFactor = .sourceAlpha
        scenePipelineDescriptor.colorAttachments[0].sourceAlphaBlendFactor = .sourceAlpha
        scenePipelineDescriptor.colorAttachments[0].destinationRGBBlendFactor = .oneMinusSourceAlpha
        scenePipelineDescriptor.colorAttachments[0].destinationAlphaBlendFactor = .oneMinusSourceAlpha
        do {
            scenePipelineState = try device.makeRenderPipelineState(descriptor: scenePipelineDescriptor)
        } catch {
            fatalError("Failed to create scene pipeline state: \(error)")
        }

        let compositePipelineDescriptor = MTLRenderPipelineDescriptor()
        compositePipelineDescriptor.vertexFunction = library.makeFunction(
            name: "Firework::vertexComposition"
        )
        compositePipelineDescriptor.fragmentFunction = library.makeFunction(
            name: "Firework::fragmentComposition"
        )
        compositePipelineDescriptor.colorAttachments[0].pixelFormat = .bgra8Unorm
        do {
            compositePipelineState = try device.makeRenderPipelineState(
                descriptor: compositePipelineDescriptor
            )
        } catch {
            fatalError("Failed to create composite pipeline state: \(error)")
        }
    }

    /// Allocates the per-frame geometry/uniform buffers and the static
    /// full-screen quad used by the composition pass.
    func buildBuffers() {
        // Fixed capacity (in bytes) shared by the CPU-written buffers.
        let dynamicBufferLength = 10 * 1000 * 1000
        vertexBuffer = device.makeBuffer(length: dynamicBufferLength)
        progressBuffer = device.makeBuffer(length: dynamicBufferLength)
        uniformBuffer = device.makeBuffer(length: dynamicBufferLength)

        // Full-screen quad drawn as a triangle strip, four floats per vertex.
        // Presumably (x, y, u, v); confirm against `vertexComposition`.
        let quadVertices: [Float] = [
            -1.0, 1.0, 0.0, 0.0,
            1.0, 1.0, 1.0, 0.0,
            -1.0, -1.0, 0.0, 1.0,
            1.0, -1.0, 1.0, 1.0,
        ]
        sceneBuffer = device.makeBuffer(
            bytes: quadVertices,
            length: MemoryLayout<Float>.stride * quadVertices.count,
            options: []
        )
    }

    /// (Re)creates the offscreen scene and glow textures at `size`.
    ///
    /// Does nothing for a zero-sized area, or when the existing textures
    /// already have the requested dimensions.
    func buildResources(size: CGSize) {
        let width = Int(size.width)
        let height = Int(size.height)

        // Metal rejects zero-sized textures; a zero size may be reported
        // before the view has been laid out.
        guard width > 0, height > 0 else { return }

        // Layout callbacks can fire repeatedly without a size change, so keep
        // the existing targets instead of reallocating them.
        if let existing = sceneTexture,
           glowTexture != nil,
           existing.width == width,
           existing.height == height {
            return
        }

        let textureDescriptor = MTLTextureDescriptor.texture2DDescriptor(
            pixelFormat: .bgra8Unorm,
            width: width,
            height: height,
            mipmapped: false
        )
        // Render target + MPS blur source/destination.
        textureDescriptor.usage = [.renderTarget, .shaderRead, .shaderWrite]
        // Only the GPU touches these textures, so GPU-private storage suffices.
        textureDescriptor.storageMode = .private
        sceneTexture = device.makeTexture(descriptor: textureDescriptor)
        glowTexture = device.makeTexture(descriptor: textureDescriptor)
    }

    /// Render pass that clears `texture` to opaque black and stores the result.
    func sceneRenderPassDescriptor(_ texture: any MTLTexture) -> MTLRenderPassDescriptor? {
        let descriptor = MTLRenderPassDescriptor()
        descriptor.colorAttachments[0].texture = texture
        descriptor.colorAttachments[0].loadAction = .clear
        descriptor.colorAttachments[0].storeAction = .store
        descriptor.colorAttachments[0].clearColor = MTLClearColorMake(0, 0, 0, 1)
        return descriptor
    }
}
#if os(macOS)
/// SwiftUI wrapper that hosts `FireworkViewController` on macOS.
struct FireworkView: NSViewControllerRepresentable {
    func makeNSViewController(context: Context) -> some NSViewController {
        let controller = FireworkViewController()
        return controller
    }

    /// The wrapper has no stored state, so there is nothing to push into the controller.
    func updateNSViewController(_ nsViewController: NSViewControllerType, context: Context) {
        // Intentionally empty.
    }
}
#else
/// SwiftUI wrapper that hosts `FireworkViewController` on UIKit platforms.
struct FireworkView: UIViewControllerRepresentable {
    func makeUIViewController(context: Context) -> some UIViewController {
        let controller = FireworkViewController()
        return controller
    }

    /// The wrapper has no stored state, so there is nothing to push into the controller.
    func updateUIViewController(_ uiViewController: UIViewControllerType, context: Context) {
        // Intentionally empty.
    }
}
#endif
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.
drawn-multiple.mp4